From 4411f926b5603932ce81629cd27ea5c9f9e8b3c4 Mon Sep 17 00:00:00 2001 From: Tony Xiao Date: Wed, 7 Dec 2022 10:55:19 -0500 Subject: [PATCH] feat(profiling): Set active thread id for ASGI frameworks When running in ASGI sync views, the transaction gets started in the main thread then the request is dispatched to a handler thread. We want to set the handler thread as the active thread id to ensure that profiles will show it on first render. --- sentry_sdk/client.py | 4 +- sentry_sdk/integrations/asgi.py | 3 +- sentry_sdk/integrations/django/asgi.py | 13 ++-- sentry_sdk/integrations/django/views.py | 16 +++-- sentry_sdk/integrations/fastapi.py | 19 ++++++ sentry_sdk/integrations/quart.py | 68 ++++++++++++++++--- sentry_sdk/integrations/starlette.py | 6 ++ sentry_sdk/profiler.py | 26 ++++--- sentry_sdk/scope.py | 30 ++++---- tests/integrations/django/asgi/test_asgi.py | 37 ++++++++++ tests/integrations/django/myapp/urls.py | 6 ++ tests/integrations/django/myapp/views.py | 23 +++++++ tests/integrations/fastapi/test_fastapi.py | 46 +++++++++++++ tests/integrations/quart/test_quart.py | 44 ++++++++++++ .../integrations/starlette/test_starlette.py | 48 +++++++++++++ tests/integrations/wsgi/test_wsgi.py | 2 +- 16 files changed, 345 insertions(+), 46 deletions(-) diff --git a/sentry_sdk/client.py b/sentry_sdk/client.py index d32d014d96..8af7003156 100644 --- a/sentry_sdk/client.py +++ b/sentry_sdk/client.py @@ -433,9 +433,7 @@ def capture_event( if is_transaction: if profile is not None: - envelope.add_profile( - profile.to_json(event_opt, self.options, scope) - ) + envelope.add_profile(profile.to_json(event_opt, self.options)) envelope.add_transaction(event_opt) else: envelope.add_event(event_opt) diff --git a/sentry_sdk/integrations/asgi.py b/sentry_sdk/integrations/asgi.py index cfeaf4d298..f34f10dc85 100644 --- a/sentry_sdk/integrations/asgi.py +++ b/sentry_sdk/integrations/asgi.py @@ -14,6 +14,7 @@ from sentry_sdk.hub import Hub, _should_send_default_pii from sentry_sdk.integrations._wsgi_common import _filter_headers from sentry_sdk.integrations.modules import _get_installed_modules +from sentry_sdk.profiler import start_profiling from sentry_sdk.sessions import auto_session_tracking from sentry_sdk.tracing import ( SOURCE_FOR_STYLE, @@ -175,7 +176,7 @@ async def _run_app(self, scope, callback): with hub.start_transaction( transaction, custom_sampling_context={"asgi_scope": scope} - ): + ), start_profiling(transaction, hub): # XXX: Would be cool to have correct span status, but we # would have to wrap send(). That is a bit hard to do with # the current abstraction over ASGI 2/3. diff --git a/sentry_sdk/integrations/django/asgi.py b/sentry_sdk/integrations/django/asgi.py index 5803a7e29b..955d8d19e8 100644 --- a/sentry_sdk/integrations/django/asgi.py +++ b/sentry_sdk/integrations/django/asgi.py @@ -7,6 +7,7 @@ """ import asyncio +import threading from sentry_sdk import Hub, _functools from sentry_sdk._types import MYPY @@ -89,10 +90,14 @@ def wrap_async_view(hub, callback): async def sentry_wrapped_callback(request, *args, **kwargs): # type: (Any, *Any, **Any) -> Any - with hub.start_span( - op=OP.VIEW_RENDER, description=request.resolver_match.view_name - ): - return await callback(request, *args, **kwargs) + with hub.configure_scope() as sentry_scope: + if sentry_scope.profile is not None: + sentry_scope.profile.active_thread_id = threading.current_thread().ident + + with hub.start_span( + op=OP.VIEW_RENDER, description=request.resolver_match.view_name + ): + return await callback(request, *args, **kwargs) return sentry_wrapped_callback diff --git a/sentry_sdk/integrations/django/views.py b/sentry_sdk/integrations/django/views.py index 33ddce24d6..735822aa72 100644 --- a/sentry_sdk/integrations/django/views.py +++ b/sentry_sdk/integrations/django/views.py @@ -1,3 +1,5 @@ +import threading + from sentry_sdk.consts import OP from sentry_sdk.hub import Hub from sentry_sdk._types import MYPY @@ -73,9 +75,15 @@ def _wrap_sync_view(hub, callback): @_functools.wraps(callback) def sentry_wrapped_callback(request, *args, **kwargs): # type: (Any, *Any, **Any) -> Any - with hub.start_span( - op=OP.VIEW_RENDER, description=request.resolver_match.view_name - ): - return callback(request, *args, **kwargs) + with hub.configure_scope() as sentry_scope: + # set the active thread id to the handler thread for sync views + # this isn't necessary for async views since that runs on main + if sentry_scope.profile is not None: + sentry_scope.profile.active_thread_id = threading.current_thread().ident + + with hub.start_span( + op=OP.VIEW_RENDER, description=request.resolver_match.view_name + ): + return callback(request, *args, **kwargs) return sentry_wrapped_callback diff --git a/sentry_sdk/integrations/fastapi.py b/sentry_sdk/integrations/fastapi.py index d38e978fbf..077e2bc23f 100644 --- a/sentry_sdk/integrations/fastapi.py +++ b/sentry_sdk/integrations/fastapi.py @@ -1,3 +1,6 @@ +import asyncio +import threading + from sentry_sdk._types import MYPY from sentry_sdk.hub import Hub, _should_send_default_pii from sentry_sdk.integrations import DidNotEnable @@ -62,6 +65,22 @@ def patch_get_request_handler(): def _sentry_get_request_handler(*args, **kwargs): # type: (*Any, **Any) -> Any + dependant = kwargs.get("dependant") + if dependant and not asyncio.iscoroutinefunction(dependant.call): + old_call = dependant.call + + def _sentry_call(*args, **kwargs): + # type: (*Any, **Any) -> Any + hub = Hub.current + with hub.configure_scope() as sentry_scope: + if sentry_scope.profile is not None: + sentry_scope.profile.active_thread_id = ( + threading.current_thread().ident + ) + return old_call(*args, **kwargs) + + dependant.call = _sentry_call + old_app = old_get_request_handler(*args, **kwargs) async def _sentry_app(*args, **kwargs): diff --git a/sentry_sdk/integrations/quart.py b/sentry_sdk/integrations/quart.py index e1d4228651..7e126f205a 100644 --- a/sentry_sdk/integrations/quart.py +++ b/sentry_sdk/integrations/quart.py @@ -1,5 +1,8 @@ from __future__ import absolute_import +import inspect +import threading + from sentry_sdk.hub import _should_send_default_pii, Hub from sentry_sdk.integrations import DidNotEnable, Integration from sentry_sdk.integrations._wsgi_common import _filter_headers @@ -11,6 +14,7 @@ event_from_exception, ) +from sentry_sdk._functools import wraps from sentry_sdk._types import MYPY if MYPY: @@ -34,6 +38,7 @@ request, websocket, ) + from quart.scaffold import Scaffold # type: ignore from quart.signals import ( # type: ignore got_background_exception, got_request_exception, @@ -41,6 +46,7 @@ request_started, websocket_started, ) + from quart.utils import is_coroutine_function # type: ignore except ImportError: raise DidNotEnable("Quart is not installed") @@ -71,18 +77,62 @@ def setup_once(): got_request_exception.connect(_capture_exception) got_websocket_exception.connect(_capture_exception) - old_app = Quart.__call__ + patch_asgi_app() + patch_scaffold_route() + + +def patch_asgi_app(): + # type: () -> None + old_app = Quart.__call__ + + async def sentry_patched_asgi_app(self, scope, receive, send): + # type: (Any, Any, Any, Any) -> Any + if Hub.current.get_integration(QuartIntegration) is None: + return await old_app(self, scope, receive, send) + + middleware = SentryAsgiMiddleware(lambda *a, **kw: old_app(self, *a, **kw)) + middleware.__call__ = middleware._run_asgi3 + return await middleware(scope, receive, send) + + Quart.__call__ = sentry_patched_asgi_app + + +def patch_scaffold_route(): + # type: () -> None + old_route = Scaffold.route + + def _sentry_route(*args, **kwargs): + # type: (*Any, **Any) -> Any + old_decorator = old_route(*args, **kwargs) + + def decorator(old_func): + # type: (Any) -> Any + + if inspect.isfunction(old_func) and not is_coroutine_function(old_func): + + @wraps(old_func) + def _sentry_func(*args, **kwargs): + # type: (*Any, **Any) -> Any + hub = Hub.current + integration = hub.get_integration(QuartIntegration) + if integration is None: + return old_func(*args, **kwargs) + + with hub.configure_scope() as sentry_scope: + if sentry_scope.profile is not None: + sentry_scope.profile.active_thread_id = ( + threading.current_thread().ident + ) + + return old_func(*args, **kwargs) + + return old_decorator(_sentry_func) - async def sentry_patched_asgi_app(self, scope, receive, send): - # type: (Any, Any, Any, Any) -> Any - if Hub.current.get_integration(QuartIntegration) is None: - return await old_app(self, scope, receive, send) + return old_decorator(old_func) - middleware = SentryAsgiMiddleware(lambda *a, **kw: old_app(self, *a, **kw)) - middleware.__call__ = middleware._run_asgi3 - return await middleware(scope, receive, send) + return decorator - Quart.__call__ = sentry_patched_asgi_app + Scaffold.route = _sentry_route def _set_transaction_name_and_source(scope, transaction_style, request): diff --git a/sentry_sdk/integrations/starlette.py b/sentry_sdk/integrations/starlette.py index 155c840461..b35e1c9fac 100644 --- a/sentry_sdk/integrations/starlette.py +++ b/sentry_sdk/integrations/starlette.py @@ -2,6 +2,7 @@ import asyncio import functools +import threading from sentry_sdk._compat import iteritems from sentry_sdk._types import MYPY @@ -403,6 +404,11 @@ def _sentry_sync_func(*args, **kwargs): return old_func(*args, **kwargs) with hub.configure_scope() as sentry_scope: + if sentry_scope.profile is not None: + sentry_scope.profile.active_thread_id = ( + threading.current_thread().ident + ) + request = args[0] _set_transaction_name_and_source( diff --git a/sentry_sdk/profiler.py b/sentry_sdk/profiler.py index 43bedcf383..5c47dcd020 100644 --- a/sentry_sdk/profiler.py +++ b/sentry_sdk/profiler.py @@ -47,7 +47,6 @@ from typing import Sequence from typing import Tuple from typing_extensions import TypedDict - import sentry_sdk.scope import sentry_sdk.tracing StackId = int @@ -305,6 +304,7 @@ def __init__( self.scheduler = scheduler self.transaction = transaction self.hub = hub + self.active_thread_id = None # type: Optional[int] self._start_ns = None # type: Optional[int] self._stop_ns = None # type: Optional[int] @@ -312,6 +312,14 @@ def __init__( def __enter__(self): # type: () -> None + hub = self.hub or sentry_sdk.Hub.current + + _, scope = hub._stack[-1] + old_profile = scope.profile + scope.profile = self + + self._context_manager_state = (hub, scope, old_profile) + self._start_ns = nanosecond_time() self.scheduler.start_profiling() @@ -320,8 +328,13 @@ def __exit__(self, ty, value, tb): self.scheduler.stop_profiling() self._stop_ns = nanosecond_time() - def to_json(self, event_opt, options, scope): - # type: (Any, Dict[str, Any], Optional[sentry_sdk.scope.Scope]) -> Dict[str, Any] + _, scope, old_profile = self._context_manager_state + del self._context_manager_state + + scope.profile = old_profile + + def to_json(self, event_opt, options): + # type: (Any, Dict[str, Any]) -> Dict[str, Any] assert self._start_ns is not None assert self._stop_ns is not None @@ -333,9 +346,6 @@ def to_json(self, event_opt, options, scope): profile["frames"], options["in_app_exclude"], options["in_app_include"] ) - # the active thread id from the scope always take priorty if it exists - active_thread_id = None if scope is None else scope.active_thread_id - return { "environment": event_opt.get("environment"), "event_id": uuid.uuid4().hex, @@ -369,8 +379,8 @@ def to_json(self, event_opt, options, scope): "trace_id": self.transaction.trace_id, "active_thread_id": str( self.transaction._active_thread_id - if active_thread_id is None - else active_thread_id + if self.active_thread_id is None + else self.active_thread_id ), } ], diff --git a/sentry_sdk/scope.py b/sentry_sdk/scope.py index f5ac270914..7d9b4f5177 100644 --- a/sentry_sdk/scope.py +++ b/sentry_sdk/scope.py @@ -27,6 +27,7 @@ Type, ) + from sentry_sdk.profiler import Profile from sentry_sdk.tracing import Span from sentry_sdk.session import Session @@ -94,10 +95,7 @@ class Scope(object): "_session", "_attachments", "_force_auto_session_tracking", - # The thread that is handling the bulk of the work. This can just - # be the main thread, but that's not always true. For web frameworks, - # this would be the thread handling the request. - "_active_thread_id", + "_profile", ) def __init__(self): @@ -129,7 +127,7 @@ def clear(self): self._session = None # type: Optional[Session] self._force_auto_session_tracking = None # type: Optional[bool] - self._active_thread_id = None # type: Optional[int] + self._profile = None # type: Optional[Profile] @_attr_setter def level(self, value): @@ -235,15 +233,15 @@ def span(self, span): self._transaction = transaction.name @property - def active_thread_id(self): - # type: () -> Optional[int] - """Get/set the current active thread id.""" - return self._active_thread_id + def profile(self): + # type: () -> Optional[Profile] + return self._profile - def set_active_thread_id(self, active_thread_id): - # type: (Optional[int]) -> None - """Set the current active thread id.""" - self._active_thread_id = active_thread_id + @profile.setter + def profile(self, profile): + # type: (Optional[Profile]) -> None + + self._profile = profile def set_tag( self, @@ -464,8 +462,8 @@ def update_from_scope(self, scope): self._span = scope._span if scope._attachments: self._attachments.extend(scope._attachments) - if scope._active_thread_id is not None: - self._active_thread_id = scope._active_thread_id + if scope._profile: + self._profile = scope._profile def update_from_kwargs( self, @@ -515,7 +513,7 @@ def __copy__(self): rv._force_auto_session_tracking = self._force_auto_session_tracking rv._attachments = list(self._attachments) - rv._active_thread_id = self._active_thread_id + rv._profile = self._profile return rv diff --git a/tests/integrations/django/asgi/test_asgi.py b/tests/integrations/django/asgi/test_asgi.py index 70fd416188..0652a5fdcb 100644 --- a/tests/integrations/django/asgi/test_asgi.py +++ b/tests/integrations/django/asgi/test_asgi.py @@ -1,3 +1,5 @@ +import json + import django import pytest from channels.testing import HttpCommunicator @@ -70,6 +72,41 @@ async def test_async_views(sentry_init, capture_events, application): } +@pytest.mark.parametrize("application", APPS) +@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"]) +@pytest.mark.asyncio +@pytest.mark.skipif( + django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1" +) +async def test_active_thread_id(sentry_init, capture_envelopes, endpoint, application): + sentry_init( + integrations=[DjangoIntegration()], + traces_sample_rate=1.0, + _experiments={"profiles_sample_rate": 1.0}, + ) + + envelopes = capture_envelopes() + + comm = HttpCommunicator(application, "GET", endpoint) + response = await comm.get_response() + assert response["status"] == 200, response["body"] + + await comm.wait() + + data = json.loads(response["body"]) + + envelopes = [envelope for envelope in envelopes] + assert len(envelopes) == 1 + + profiles = [item for item in envelopes[0].items if item.type == "profile"] + assert len(profiles) == 1 + + for profile in profiles: + transactions = profile.payload.json["transactions"] + assert len(transactions) == 1 + assert str(data["active"]) == transactions[0]["active_thread_id"] + + @pytest.mark.asyncio @pytest.mark.skipif( django.VERSION < (3, 1), reason="async views have been introduced in Django 3.1" diff --git a/tests/integrations/django/myapp/urls.py b/tests/integrations/django/myapp/urls.py index 376261abcf..ee357c843b 100644 --- a/tests/integrations/django/myapp/urls.py +++ b/tests/integrations/django/myapp/urls.py @@ -58,6 +58,7 @@ def path(path, *args, **kwargs): views.csrf_hello_not_exempt, name="csrf_hello_not_exempt", ), + path("sync/thread_ids", views.thread_ids_sync, name="thread_ids_sync"), ] # async views @@ -67,6 +68,11 @@ def path(path, *args, **kwargs): if views.my_async_view is not None: urlpatterns.append(path("my_async_view", views.my_async_view, name="my_async_view")) +if views.thread_ids_async is not None: + urlpatterns.append( + path("async/thread_ids", views.thread_ids_async, name="thread_ids_async") + ) + # rest framework try: urlpatterns.append( diff --git a/tests/integrations/django/myapp/views.py b/tests/integrations/django/myapp/views.py index bee5e656d3..dbf266e1ab 100644 --- a/tests/integrations/django/myapp/views.py +++ b/tests/integrations/django/myapp/views.py @@ -1,3 +1,6 @@ +import json +import threading + from django import VERSION from django.contrib.auth import login from django.contrib.auth.models import User @@ -159,6 +162,16 @@ def csrf_hello_not_exempt(*args, **kwargs): return HttpResponse("ok") +def thread_ids_sync(*args, **kwargs): + response = json.dumps( + { + "main": threading.main_thread().ident, + "active": threading.current_thread().ident, + } + ) + return HttpResponse(response) + + if VERSION >= (3, 1): # Use exec to produce valid Python 2 exec( @@ -173,6 +186,16 @@ def csrf_hello_not_exempt(*args, **kwargs): await asyncio.sleep(1) return HttpResponse('Hello World')""" ) + + exec( + """async def thread_ids_async(request): + response = json.dumps({ + "main": threading.main_thread().ident, + "active": threading.current_thread().ident, + }) + return HttpResponse(response)""" + ) else: async_message = None my_async_view = None + thread_ids_async = None diff --git a/tests/integrations/fastapi/test_fastapi.py b/tests/integrations/fastapi/test_fastapi.py index bc61cfc263..9c24ce2e44 100644 --- a/tests/integrations/fastapi/test_fastapi.py +++ b/tests/integrations/fastapi/test_fastapi.py @@ -1,3 +1,6 @@ +import json +import threading + import pytest from sentry_sdk.integrations.fastapi import FastApiIntegration @@ -23,6 +26,20 @@ async def _message_with_id(message_id): capture_message("Hi") return {"message": "Hi"} + @app.get("/sync/thread_ids") + def _thread_ids_sync(): + return { + "main": str(threading.main_thread().ident), + "active": str(threading.current_thread().ident), + } + + @app.get("/async/thread_ids") + async def _thread_ids_async(): + return { + "main": str(threading.main_thread().ident), + "active": str(threading.current_thread().ident), + } + return app @@ -135,3 +152,32 @@ def test_legacy_setup( (event,) = events assert event["transaction"] == "/message/{message_id}" + + +@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"]) +def test_active_thread_id(sentry_init, capture_envelopes, endpoint): + sentry_init( + traces_sample_rate=1.0, + _experiments={"profiles_sample_rate": 1.0}, + ) + app = fastapi_app_factory() + asgi_app = SentryAsgiMiddleware(app) + + envelopes = capture_envelopes() + + client = TestClient(asgi_app) + response = client.get(endpoint) + assert response.status_code == 200 + + data = json.loads(response.content) + + envelopes = [envelope for envelope in envelopes] + assert len(envelopes) == 1 + + profiles = [item for item in envelopes[0].items if item.type == "profile"] + assert len(profiles) == 1 + + for profile in profiles: + transactions = profile.payload.json["transactions"] + assert len(transactions) == 1 + assert str(data["active"]) == transactions[0]["active_thread_id"] diff --git a/tests/integrations/quart/test_quart.py b/tests/integrations/quart/test_quart.py index 6d2c590a53..bda2c1013e 100644 --- a/tests/integrations/quart/test_quart.py +++ b/tests/integrations/quart/test_quart.py @@ -1,3 +1,6 @@ +import json +import threading + import pytest import pytest_asyncio @@ -41,6 +44,20 @@ async def hi_with_id(message_id): capture_message("hi with id") return "ok with id" + @app.get("/sync/thread_ids") + def _thread_ids_sync(): + return { + "main": str(threading.main_thread().ident), + "active": str(threading.current_thread().ident), + } + + @app.get("/async/thread_ids") + async def _thread_ids_async(): + return { + "main": str(threading.main_thread().ident), + "active": str(threading.current_thread().ident), + } + return app @@ -523,3 +540,30 @@ async def dispatch_request(self): assert event["message"] == "hi" assert event["transaction"] == "hello_class" + + +@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"]) +async def test_active_thread_id(sentry_init, capture_envelopes, endpoint, app): + sentry_init( + traces_sample_rate=1.0, + _experiments={"profiles_sample_rate": 1.0}, + ) + + envelopes = capture_envelopes() + + async with app.test_client() as client: + response = await client.get(endpoint) + assert response.status_code == 200 + + data = json.loads(response.content) + + envelopes = [envelope for envelope in envelopes] + assert len(envelopes) == 1 + + profiles = [item for item in envelopes[0].items if item.type == "profile"] + assert len(profiles) == 1 + + for profile in profiles: + transactions = profile.payload.json["transactions"] + assert len(transactions) == 1 + assert str(data["active"]) == transactions[0]["active_thread_id"] diff --git a/tests/integrations/starlette/test_starlette.py b/tests/integrations/starlette/test_starlette.py index e41e6d5d19..a279142995 100644 --- a/tests/integrations/starlette/test_starlette.py +++ b/tests/integrations/starlette/test_starlette.py @@ -3,6 +3,7 @@ import functools import json import os +import threading import pytest @@ -108,6 +109,22 @@ async def _message_with_id(request): capture_message("hi") return starlette.responses.JSONResponse({"status": "ok"}) + def _thread_ids_sync(request): + return starlette.responses.JSONResponse( + { + "main": threading.main_thread().ident, + "active": threading.current_thread().ident, + } + ) + + async def _thread_ids_async(request): + return starlette.responses.JSONResponse( + { + "main": threading.main_thread().ident, + "active": threading.current_thread().ident, + } + ) + app = starlette.applications.Starlette( debug=debug, routes=[ @@ -115,6 +132,8 @@ async def _message_with_id(request): starlette.routing.Route("/custom_error", _custom_error), starlette.routing.Route("/message", _message), starlette.routing.Route("/message/{message_id}", _message_with_id), + starlette.routing.Route("/sync/thread_ids", _thread_ids_sync), + starlette.routing.Route("/async/thread_ids", _thread_ids_async), ], middleware=middleware, ) @@ -824,3 +843,32 @@ def test_legacy_setup( (event,) = events assert event["transaction"] == "/message/{message_id}" + + +@pytest.mark.parametrize("endpoint", ["/sync/thread_ids", "/async/thread_ids"]) +def test_active_thread_id(sentry_init, capture_envelopes, endpoint): + sentry_init( + traces_sample_rate=1.0, + _experiments={"profiles_sample_rate": 1.0}, + ) + app = starlette_app_factory() + asgi_app = SentryAsgiMiddleware(app) + + envelopes = capture_envelopes() + + client = TestClient(asgi_app) + response = client.get(endpoint) + assert response.status_code == 200 + + data = json.loads(response.content) + + envelopes = [envelope for envelope in envelopes] + assert len(envelopes) == 1 + + profiles = [item for item in envelopes[0].items if item.type == "profile"] + assert len(profiles) == 1 + + for profile in profiles: + transactions = profile.payload.json["transactions"] + assert len(transactions) == 1 + assert str(data["active"]) == transactions[0]["active_thread_id"] diff --git a/tests/integrations/wsgi/test_wsgi.py b/tests/integrations/wsgi/test_wsgi.py index 9eba712616..3ca9c5e9e7 100644 --- a/tests/integrations/wsgi/test_wsgi.py +++ b/tests/integrations/wsgi/test_wsgi.py @@ -297,8 +297,8 @@ def sample_app(environ, start_response): ], ) def test_profile_sent( - capture_envelopes, sentry_init, + capture_envelopes, teardown_profiling, profiles_sample_rate, profile_count,