From 0714d9f6d38c65d87fc4523e9d9b471d535dcc8a Mon Sep 17 00:00:00 2001 From: Johnny Deuss Date: Thu, 19 Jan 2023 12:50:56 +0000 Subject: [PATCH] Fix middleware being patched multiple times when using FastAPI (#1841) * Fix middleware being patched multiple times when using FastAPI --- sentry_sdk/integrations/starlette.py | 118 ++++++++++++++------------- 1 file changed, 63 insertions(+), 55 deletions(-) diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index b35e1c9fac..aec194a779 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -168,62 +168,66 @@ def patch_exception_middleware(middleware_class): """ old_middleware_init = middleware_class.__init__ - def _sentry_middleware_init(self, *args, **kwargs): - # type: (Any, Any, Any) -> None - old_middleware_init(self, *args, **kwargs) + not_yet_patched = "_sentry_middleware_init" not in str(old_middleware_init) - # Patch existing exception handlers - old_handlers = self._exception_handlers.copy() + if not_yet_patched: - async def _sentry_patched_exception_handler(self, *args, **kwargs): + def _sentry_middleware_init(self, *args, **kwargs): # type: (Any, Any, Any) -> None - exp = args[0] - - is_http_server_error = ( - hasattr(exp, "status_code") and exp.status_code >= 500 - ) - if is_http_server_error: - _capture_exception(exp, handled=True) - - # Find a matching handler - old_handler = None - for cls in type(exp).__mro__: - if cls in old_handlers: - old_handler = old_handlers[cls] - break - - if old_handler is None: - return - - if _is_async_callable(old_handler): - return await old_handler(self, *args, **kwargs) - else: - return old_handler(self, *args, **kwargs) + old_middleware_init(self, *args, **kwargs) - for key in self._exception_handlers.keys(): - self._exception_handlers[key] = _sentry_patched_exception_handler + # Patch existing exception handlers + old_handlers = self._exception_handlers.copy() - middleware_class.__init__ = _sentry_middleware_init + async def _sentry_patched_exception_handler(self, *args, **kwargs): + # type: (Any, Any, Any) -> None + exp = args[0] - old_call = middleware_class.__call__ - - async def _sentry_exceptionmiddleware_call(self, scope, receive, send): - # type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None - # Also add the user (that was eventually set by be Authentication middle - # that was called before this middleware). This is done because the authentication - # middleware sets the user in the scope and then (in the same function) - # calls this exception middelware. In case there is no exception (or no handler - # for the type of exception occuring) then the exception bubbles up and setting the - # user information into the sentry scope is done in auth middleware and the - # ASGI middleware will then send everything to Sentry and this is fine. - # But if there is an exception happening that the exception middleware here - # has a handler for, it will send the exception directly to Sentry, so we need - # the user information right now. - # This is why we do it here. - _add_user_to_sentry_scope(scope) - await old_call(self, scope, receive, send) - - middleware_class.__call__ = _sentry_exceptionmiddleware_call + is_http_server_error = ( + hasattr(exp, "status_code") and exp.status_code >= 500 + ) + if is_http_server_error: + _capture_exception(exp, handled=True) + + # Find a matching handler + old_handler = None + for cls in type(exp).__mro__: + if cls in old_handlers: + old_handler = old_handlers[cls] + break + + if old_handler is None: + return + + if _is_async_callable(old_handler): + return await old_handler(self, *args, **kwargs) + else: + return old_handler(self, *args, **kwargs) + + for key in self._exception_handlers.keys(): + self._exception_handlers[key] = _sentry_patched_exception_handler + + middleware_class.__init__ = _sentry_middleware_init + + old_call = middleware_class.__call__ + + async def _sentry_exceptionmiddleware_call(self, scope, receive, send): + # type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None + # Also add the user (that was eventually set by be Authentication middle + # that was called before this middleware). This is done because the authentication + # middleware sets the user in the scope and then (in the same function) + # calls this exception middelware. In case there is no exception (or no handler + # for the type of exception occuring) then the exception bubbles up and setting the + # user information into the sentry scope is done in auth middleware and the + # ASGI middleware will then send everything to Sentry and this is fine. + # But if there is an exception happening that the exception middleware here + # has a handler for, it will send the exception directly to Sentry, so we need + # the user information right now. + # This is why we do it here. + _add_user_to_sentry_scope(scope) + await old_call(self, scope, receive, send) + + middleware_class.__call__ = _sentry_exceptionmiddleware_call def _add_user_to_sentry_scope(scope): @@ -268,12 +272,16 @@ def patch_authentication_middleware(middleware_class): """ old_call = middleware_class.__call__ - async def _sentry_authenticationmiddleware_call(self, scope, receive, send): - # type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None - await old_call(self, scope, receive, send) - _add_user_to_sentry_scope(scope) + not_yet_patched = "_sentry_authenticationmiddleware_call" not in str(old_call) + + if not_yet_patched: + + async def _sentry_authenticationmiddleware_call(self, scope, receive, send): + # type: (Dict[str, Any], Dict[str, Any], Callable[[], Awaitable[Dict[str, Any]]], Callable[[Dict[str, Any]], Awaitable[None]]) -> None + await old_call(self, scope, receive, send) + _add_user_to_sentry_scope(scope) - middleware_class.__call__ = _sentry_authenticationmiddleware_call + middleware_class.__call__ = _sentry_authenticationmiddleware_call def patch_middlewares():