Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix middleware being patched multiple times when using FastAPI #1841

Merged
merged 8 commits into from Jan 19, 2023
77 changes: 42 additions & 35 deletions sentry_sdk/integrations/starlette.py
Expand Up @@ -168,44 +168,47 @@ def patch_exception_middleware(middleware_class):
"""
old_middleware_init = middleware_class.__init__

def _sentry_middleware_init(self, *args, **kwargs):
# type: (Any, Any, Any) -> None
old_middleware_init(self, *args, **kwargs)

# Patch existing exception handlers
old_handlers = self._exception_handlers.copy()
not_yet_patched = "_sentry_middleware_init" not in str(old_middleware_init)

async def _sentry_patched_exception_handler(self, *args, **kwargs):
if not_yet_patched:
def _sentry_middleware_init(self, *args, **kwargs):
# type: (Any, Any, Any) -> None
exp = args[0]
old_middleware_init(self, *args, **kwargs)

is_http_server_error = (
hasattr(exp, "status_code") and exp.status_code >= 500
)
if is_http_server_error:
_capture_exception(exp, handled=True)

# Find a matching handler
old_handler = None
for cls in type(exp).__mro__:
if cls in old_handlers:
old_handler = old_handlers[cls]
break
# Patch existing exception handlers
old_handlers = self._exception_handlers.copy()

async def _sentry_patched_exception_handler(self, *args, **kwargs):
# type: (Any, Any, Any) -> None
exp = args[0]

if old_handler is None:
return
is_http_server_error = (
hasattr(exp, "status_code") and exp.status_code >= 500
)
if is_http_server_error:
_capture_exception(exp, handled=True)

if _is_async_callable(old_handler):
return await old_handler(self, *args, **kwargs)
else:
return old_handler(self, *args, **kwargs)
# Find a matching handler
old_handler = None
for cls in type(exp).__mro__:
if cls in old_handlers:
old_handler = old_handlers[cls]
break

for key in self._exception_handlers.keys():
self._exception_handlers[key] = _sentry_patched_exception_handler
if old_handler is None:
return

middleware_class.__init__ = _sentry_middleware_init
if _is_async_callable(old_handler):
return await old_handler(self, *args, **kwargs)
else:
return old_handler(self, *args, **kwargs)

old_call = middleware_class.__call__
for key in self._exception_handlers.keys():
self._exception_handlers[key] = _sentry_patched_exception_handler

middleware_class.__init__ = _sentry_middleware_init

old_call = middleware_class.__call__

async def _sentry_exceptionmiddleware_call(self, scope, receive, send):
# type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None
Expand Down Expand Up @@ -268,12 +271,16 @@ def patch_authentication_middleware(middleware_class):
"""
old_call = middleware_class.__call__

async def _sentry_authenticationmiddleware_call(self, scope, receive, send):
# type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None
await old_call(self, scope, receive, send)
_add_user_to_sentry_scope(scope)
not_yet_patched = "_sentry_authenticationmiddleware_call" not in str(old_call)

if not_yet_patched:

async def _sentry_authenticationmiddleware_call(self, scope, receive, send):
# type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None
await old_call(self, scope, receive, send)
_add_user_to_sentry_scope(scope)

middleware_class.__call__ = _sentry_authenticationmiddleware_call
middleware_class.__call__ = _sentry_authenticationmiddleware_call


def patch_middlewares():
Expand Down