diff --git a/.github/workflows/test-integrations-data-processing.yml b/.github/workflows/test-integrations-data-processing.yml
index ebcd89efea..1f618bd93d 100644
--- a/.github/workflows/test-integrations-data-processing.yml
+++ b/.github/workflows/test-integrations-data-processing.yml
@@ -58,6 +58,10 @@ jobs:
         run: |
           set -x # print commands that are executed
           ./scripts/runtox.sh "py${{ matrix.python-version }}-huey-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
+      - name: Test langchain latest
+        run: |
+          set -x # print commands that are executed
+          ./scripts/runtox.sh "py${{ matrix.python-version }}-langchain-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
       - name: Test openai latest
         run: |
           set -x # print commands that are executed
@@ -114,6 +118,10 @@ jobs:
         run: |
           set -x # print commands that are executed
           ./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-huey" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
+      - name: Test langchain pinned
+        run: |
+          set -x # print commands that are executed
+          ./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-langchain" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
       - name: Test openai pinned
         run: |
           set -x # print commands that are executed
diff --git a/mypy.ini b/mypy.ini
index c1444d61e5..844e140de2 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -48,6 +48,8 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-asgiref.*]
 ignore_missing_imports = True
+[mypy-langchain_core.*]
+ignore_missing_imports = True
 [mypy-executing.*]
 ignore_missing_imports = True
 [mypy-asttokens.*]
diff --git a/scripts/split-tox-gh-actions/split-tox-gh-actions.py b/scripts/split-tox-gh-actions/split-tox-gh-actions.py
index 6b456c5544..288725d2c5 100755
--- a/scripts/split-tox-gh-actions/split-tox-gh-actions.py
+++ b/scripts/split-tox-gh-actions/split-tox-gh-actions.py
@@ -70,6 +70,7 @@
         "beam",
         "celery",
         "huey",
+        "langchain",
         "openai",
         "rq",
     ],
diff --git a/sentry_sdk/ai/__init__.py b/sentry_sdk/ai/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sentry_sdk/ai/monitoring.py b/sentry_sdk/ai/monitoring.py
new file mode 100644
index 0000000000..f5f9cd7aad
--- /dev/null
+++ b/sentry_sdk/ai/monitoring.py
@@ -0,0 +1,77 @@
+from functools import wraps
+
+import sentry_sdk.utils
+from sentry_sdk import start_span
+from sentry_sdk.tracing import Span
+from sentry_sdk.utils import ContextVar
+from sentry_sdk._types import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Optional, Callable, Any
+
+_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)
+
+
+def set_ai_pipeline_name(name):
+    # type: (Optional[str]) -> None
+    _ai_pipeline_name.set(name)
+
+
+def get_ai_pipeline_name():
+    # type: () -> Optional[str]
+    return _ai_pipeline_name.get()
+
+
+def ai_track(description, **span_kwargs):
+    # type: (str, Any) -> Callable[..., Any]
+    def decorator(f):
+        # type: (Callable[..., Any]) -> Callable[..., Any]
+        @wraps(f)
+        def wrapped(*args, **kwargs):
+            # type: (Any, Any) -> Any
+            curr_pipeline = _ai_pipeline_name.get()
+            # pop "op" so it is not passed to start_span twice via **span_kwargs
+            op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")
+            with start_span(description=description, op=op, **span_kwargs) as span:
+                if curr_pipeline:
+                    span.set_data("ai.pipeline.name", curr_pipeline)
+                    return f(*args, **kwargs)
+                else:
+                    _ai_pipeline_name.set(description)
+                    try:
+                        res = f(*args, **kwargs)
+                    except Exception as e:
+                        event, hint = sentry_sdk.utils.event_from_exception(
+                            e,
+                            client_options=sentry_sdk.get_client().options,
+                            mechanism={"type": "ai_monitoring", "handled": False},
+                        )
+                        sentry_sdk.capture_event(event, hint=hint)
+                        raise e from None
+                    finally:
+                        _ai_pipeline_name.set(None)
+                    return res
+
+        return wrapped
+
+    return decorator
+
+
+def record_token_usage(
+    span, prompt_tokens=None, completion_tokens=None, total_tokens=None
+):
+    # type: (Span, Optional[int], Optional[int], Optional[int]) -> None
+    ai_pipeline_name = get_ai_pipeline_name()
+    if ai_pipeline_name:
+        span.set_data("ai.pipeline.name", ai_pipeline_name)
+    if prompt_tokens is not None:
+        span.set_measurement("ai_prompt_tokens_used", value=prompt_tokens)
+    if completion_tokens is not None:
+        span.set_measurement("ai_completion_tokens_used", value=completion_tokens)
+    if (
+        total_tokens is None
+        and prompt_tokens is not None
+        and completion_tokens is not None
+    ):
+        total_tokens = prompt_tokens + completion_tokens
+    if total_tokens is not None:
+        span.set_measurement("ai_total_tokens_used", total_tokens)
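A minimal usage sketch for the new `ai_track` decorator (the pipeline and step names below are hypothetical, and tracing must be enabled for spans to be recorded): the outermost decorated call becomes an `ai.pipeline` span, while decorated calls nested inside it become `ai.run` spans carrying `ai.pipeline.name`.

    import sentry_sdk
    from sentry_sdk.ai.monitoring import ai_track

    sentry_sdk.init(dsn="...", traces_sample_rate=1.0)

    @ai_track("Summarize step")
    def summarize(text):
        # nested call: recorded as an "ai.run" span with
        # ai.pipeline.name = "My AI pipeline"
        return text[:100]

    @ai_track("My AI pipeline")
    def my_pipeline():
        # outermost call: recorded as an "ai.pipeline" span
        return summarize("some long input")

    with sentry_sdk.start_transaction(name="ai-demo"):
        my_pipeline()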
diff --git a/sentry_sdk/ai/utils.py b/sentry_sdk/ai/utils.py
new file mode 100644
index 0000000000..42d46304e4
--- /dev/null
+++ b/sentry_sdk/ai/utils.py
@@ -0,0 +1,32 @@
+from sentry_sdk._types import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Any
+
+from sentry_sdk.tracing import Span
+from sentry_sdk.utils import logger
+
+
+def _normalize_data(data):
+    # type: (Any) -> Any
+
+    # convert pydantic data (e.g. OpenAI v1+) to json compatible format
+    if hasattr(data, "model_dump"):
+        try:
+            return data.model_dump()
+        except Exception as e:
+            logger.warning("Could not convert pydantic data to JSON: %s", e)
+            return data
+    if isinstance(data, list):
+        if len(data) == 1:
+            return _normalize_data(data[0])  # remove empty dimensions
+        return list(_normalize_data(x) for x in data)
+    if isinstance(data, dict):
+        return {k: _normalize_data(v) for (k, v) in data.items()}
+    return data
+
+
+def set_data_normalized(span, key, value):
+    # type: (Span, str, Any) -> None
+    normalized = _normalize_data(value)
+    span.set_data(key, normalized)
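A quick behavior sketch of the helper (illustrative values): pydantic objects are converted via `model_dump()`, lists and dicts are walked recursively, and single-element lists are unwrapped before the result is attached to the span.

    from sentry_sdk.ai.utils import set_data_normalized, _normalize_data

    _normalize_data({"a": [1, 2]})       # -> {"a": [1, 2]}
    _normalize_data([{"role": "user"}])  # -> {"role": "user"}  (single element unwrapped)
    # set_data_normalized(span, key, value) stores the normalized value on the span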
diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py
index b72701daed..19595ed7fa 100644
--- a/sentry_sdk/consts.py
+++ b/sentry_sdk/consts.py
@@ -91,6 +91,85 @@ class SPANDATA:
     See: https://develop.sentry.dev/sdk/performance/span-data-conventions/
     """

+    AI_INPUT_MESSAGES = "ai.input_messages"
+    """
+    The input messages to an LLM call.
+    Example: [{"role": "user", "message": "hello"}]
+    """
+
+    AI_MODEL_ID = "ai.model_id"
+    """
+    The unique descriptor of the model being executed.
+    Example: gpt-4
+    """
+
+    AI_METADATA = "ai.metadata"
+    """
+    Extra metadata passed to an AI pipeline step.
+    Example: {"executed_function": "add_integers"}
+    """
+
+    AI_TAGS = "ai.tags"
+    """
+    Tags that describe an AI pipeline step.
+    Example: {"executed_function": "add_integers"}
+    """
+
+    AI_STREAMING = "ai.streaming"
+    """
+    Whether or not the AI model call's response was streamed back asynchronously.
+    Example: true
+    """
+
+    AI_TEMPERATURE = "ai.temperature"
+    """
+    For an AI model call, the temperature parameter. Temperature essentially controls how random the output will be.
+    Example: 0.5
+    """
+
+    AI_TOP_P = "ai.top_p"
+    """
+    For an AI model call, the top_p parameter. Top_p essentially controls how random the output will be.
+    Example: 0.5
+    """
+
+    AI_TOP_K = "ai.top_k"
+    """
+    For an AI model call, the top_k parameter. Top_k essentially controls how random the output will be.
+    Example: 35
+    """
+
+    AI_FUNCTION_CALL = "ai.function_call"
+    """
+    For an AI model call, the function that was called. This is deprecated for OpenAI, and replaced by tool_calls
+    """
+
+    AI_TOOL_CALLS = "ai.tool_calls"
+    """
+    For an AI model call, the tool calls that were made. For OpenAI this replaces the deprecated function_call
+    """
+
+    AI_TOOLS = "ai.tools"
+    """
+    For an AI model call, the functions that are available
+    """
+
+    AI_RESPONSE_FORMAT = "ai.response_format"
+    """
+    For an AI model call, the format of the response
+    """
+
+    AI_LOGIT_BIAS = "ai.logit_bias"
+    """
+    For an AI model call, the logit bias
+    """
+
+    AI_RESPONSES = "ai.responses"
+    """
+    The responses to an AI model call. Always as a list.
+    Example: ["hello", "world"]
+    """
+
     DB_NAME = "db.name"
     """
     The name of the database being accessed. For commands that switch the database, this should be set to the target database (even if the command fails).
@@ -245,6 +324,11 @@ class OP:
     MIDDLEWARE_STARLITE_SEND = "middleware.starlite.send"
     OPENAI_CHAT_COMPLETIONS_CREATE = "ai.chat_completions.create.openai"
     OPENAI_EMBEDDINGS_CREATE = "ai.embeddings.create.openai"
+    LANGCHAIN_PIPELINE = "ai.pipeline.langchain"
+    LANGCHAIN_RUN = "ai.run.langchain"
+    LANGCHAIN_TOOL = "ai.tool.langchain"
+    LANGCHAIN_AGENT = "ai.agent.langchain"
+    LANGCHAIN_CHAT_COMPLETIONS_CREATE = "ai.chat_completions.create.langchain"
     QUEUE_SUBMIT_ARQ = "queue.submit.arq"
     QUEUE_TASK_ARQ = "queue.task.arq"
     QUEUE_SUBMIT_CELERY = "queue.submit.celery"
diff --git a/sentry_sdk/integrations/__init__.py b/sentry_sdk/integrations/__init__.py
index b0ec5e2d3e..f692e88294 100644
--- a/sentry_sdk/integrations/__init__.py
+++ b/sentry_sdk/integrations/__init__.py
@@ -85,6 +85,7 @@ def iter_default_integrations(with_auto_enabling_integrations):
     "sentry_sdk.integrations.graphene.GrapheneIntegration",
     "sentry_sdk.integrations.httpx.HttpxIntegration",
     "sentry_sdk.integrations.huey.HueyIntegration",
+    "sentry_sdk.integrations.langchain.LangchainIntegration",
     "sentry_sdk.integrations.loguru.LoguruIntegration",
     "sentry_sdk.integrations.openai.OpenAIIntegration",
     "sentry_sdk.integrations.pymongo.PyMongoIntegration",
tiktoken to count tokens") +except ImportError: + logger.info( + "The Sentry Python SDK requires 'tiktoken' in order to measure token usage from streaming langchain calls." + "Please install 'tiktoken' if you aren't receiving accurate token usage in Sentry." + "See https://docs.sentry.io/platforms/python/integrations/langchain/ for more information." + ) + + def count_tokens(s): + # type: (str) -> int + return 1 + + +DATA_FIELDS = { + "temperature": SPANDATA.AI_TEMPERATURE, + "top_p": SPANDATA.AI_TOP_P, + "top_k": SPANDATA.AI_TOP_K, + "function_call": SPANDATA.AI_FUNCTION_CALL, + "tool_calls": SPANDATA.AI_TOOL_CALLS, + "tools": SPANDATA.AI_TOOLS, + "response_format": SPANDATA.AI_RESPONSE_FORMAT, + "logit_bias": SPANDATA.AI_LOGIT_BIAS, + "tags": SPANDATA.AI_TAGS, +} + +# To avoid double collecting tokens, we do *not* measure +# token counts for models for which we have an explicit integration +NO_COLLECT_TOKEN_MODELS = ["openai-chat"] + + +class LangchainIntegration(Integration): + identifier = "langchain" + + # The most number of spans (e.g., LLM calls) that can be processed at the same time. + max_spans = 1024 + + def __init__(self, include_prompts=True, max_spans=1024): + # type: (LangchainIntegration, bool, int) -> None + self.include_prompts = include_prompts + self.max_spans = max_spans + + @staticmethod + def setup_once(): + # type: () -> None + manager._configure = _wrap_configure(manager._configure) + + +class WatchedSpan: + span = None # type: Span + num_completion_tokens = 0 # type: int + num_prompt_tokens = 0 # type: int + no_collect_tokens = False # type: bool + children = [] # type: List[WatchedSpan] + is_pipeline = False # type: bool + + def __init__(self, span): + # type: (Span) -> None + self.span = span + + +class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc] + """Base callback handler that can be used to handle callbacks from langchain.""" + + span_map = OrderedDict() # type: OrderedDict[UUID, WatchedSpan] + + max_span_map_size = 0 + + def __init__(self, max_span_map_size, include_prompts): + # type: (int, bool) -> None + self.max_span_map_size = max_span_map_size + self.include_prompts = include_prompts + + def gc_span_map(self): + # type: () -> None + + while len(self.span_map) > self.max_span_map_size: + run_id, watched_span = self.span_map.popitem(last=False) + self._exit_span(watched_span, run_id) + + def _handle_error(self, run_id, error): + # type: (UUID, Any) -> None + if not run_id or run_id not in self.span_map: + return + + span_data = self.span_map[run_id] + if not span_data: + return + sentry_sdk.capture_exception(error, span_data.span.scope) + span_data.span.__exit__(None, None, None) + del self.span_map[run_id] + + def _normalize_langchain_message(self, message): + # type: (BaseMessage) -> Any + parsed = {"content": message.content, "role": message.type} + parsed.update(message.additional_kwargs) + return parsed + + def _create_span(self, run_id, parent_id, **kwargs): + # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan + + watched_span = None # type: Optional[WatchedSpan] + if parent_id: + parent_span = self.span_map[parent_id] # type: Optional[WatchedSpan] + if parent_span: + watched_span = WatchedSpan(parent_span.span.start_child(**kwargs)) + parent_span.children.append(watched_span) + if watched_span is None: + watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs)) + + if kwargs.get("op", "").startswith("ai.pipeline."): + if kwargs.get("description"): + set_ai_pipeline_name(kwargs.get("description")) + 
+    def on_llm_start(
+        self,
+        serialized,
+        prompts,
+        *,
+        run_id,
+        tags=None,
+        parent_run_id=None,
+        metadata=None,
+        **kwargs,
+    ):
+        # type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
+        """Run when LLM starts running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            all_params = kwargs.get("invocation_params", {})
+            all_params.update(serialized.get("kwargs", {}))
+            watched_span = self._create_span(
+                run_id,
+                parent_run_id,
+                op=OP.LANGCHAIN_RUN,
+                description=kwargs.get("name") or "Langchain LLM call",
+            )
+            span = watched_span.span
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompts)
+            for k, v in DATA_FIELDS.items():
+                if k in all_params:
+                    set_data_normalized(span, v, all_params[k])
+
+    def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Any) -> Any
+        """Run when Chat Model starts running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            all_params = kwargs.get("invocation_params", {})
+            all_params.update(serialized.get("kwargs", {}))
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=OP.LANGCHAIN_CHAT_COMPLETIONS_CREATE,
+                description=kwargs.get("name") or "Langchain Chat Model",
+            )
+            span = watched_span.span
+            model = all_params.get(
+                "model", all_params.get("model_name", all_params.get("model_id"))
+            )
+            watched_span.no_collect_tokens = any(
+                x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
+            )
+            if not model and "anthropic" in all_params.get("_type", ""):
+                model = "claude-2"
+            if model:
+                span.set_data(SPANDATA.AI_MODEL_ID, model)
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    span,
+                    SPANDATA.AI_INPUT_MESSAGES,
+                    [
+                        [self._normalize_langchain_message(x) for x in list_]
+                        for list_ in messages
+                    ],
+                )
+            for k, v in DATA_FIELDS.items():
+                if k in all_params:
+                    set_data_normalized(span, v, all_params[k])
+            if not watched_span.no_collect_tokens:
+                for list_ in messages:
+                    for message in list_:
+                        self.span_map[run_id].num_prompt_tokens += count_tokens(
+                            message.content
+                        ) + count_tokens(message.type)
+
+    def on_llm_new_token(self, token, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, str, UUID, Any) -> Any
+        """Run on new LLM token. Only available when streaming is enabled."""
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+            span_data = self.span_map[run_id]
+            if not span_data or span_data.no_collect_tokens:
+                return
+            span_data.num_completion_tokens += count_tokens(token)
+    def on_llm_end(self, response, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
+        """Run when LLM ends running."""
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+
+            token_usage = (
+                response.llm_output.get("token_usage") if response.llm_output else None
+            )
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    span_data.span,
+                    SPANDATA.AI_RESPONSES,
+                    [[x.text for x in list_] for list_ in response.generations],
+                )
+
+            if not span_data.no_collect_tokens:
+                if token_usage:
+                    record_token_usage(
+                        span_data.span,
+                        token_usage.get("prompt_tokens"),
+                        token_usage.get("completion_tokens"),
+                        token_usage.get("total_tokens"),
+                    )
+                else:
+                    record_token_usage(
+                        span_data.span,
+                        span_data.num_prompt_tokens,
+                        span_data.num_completion_tokens,
+                    )
+
+            self._exit_span(span_data, run_id)
+
+    def on_llm_error(self, error, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+        """Run when LLM errors."""
+        with capture_internal_exceptions():
+            self._handle_error(run_id, error)
+
+    def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any
+        """Run when chain starts running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=(
+                    OP.LANGCHAIN_RUN
+                    if kwargs.get("parent_run_id") is not None
+                    else OP.LANGCHAIN_PIPELINE
+                ),
+                description=kwargs.get("name") or "Chain execution",
+            )
+            metadata = kwargs.get("metadata")
+            if metadata:
+                set_data_normalized(watched_span.span, SPANDATA.AI_METADATA, metadata)
+
+    def on_chain_end(self, outputs, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Dict[str, Any], UUID, Any) -> Any
+        """Run when chain ends running."""
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+            self._exit_span(span_data, run_id)
+
+    def on_chain_error(self, error, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+        """Run when chain errors."""
+        self._handle_error(run_id, error)
+
+    def on_agent_action(self, action, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, AgentAction, UUID, Any) -> Any
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=OP.LANGCHAIN_AGENT,
+                description=action.tool or "AI tool usage",
+            )
+            if action.tool_input and should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    watched_span.span, SPANDATA.AI_INPUT_MESSAGES, action.tool_input
+                )
+
+    def on_agent_finish(self, finish, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    span_data.span, SPANDATA.AI_RESPONSES, finish.return_values.items()
+                )
+            self._exit_span(span_data, run_id)
+    def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Any) -> Any
+        """Run when tool starts running."""
+        with capture_internal_exceptions():
+            if not run_id:
+                return
+            watched_span = self._create_span(
+                run_id,
+                kwargs.get("parent_run_id"),
+                op=OP.LANGCHAIN_TOOL,
+                description=serialized.get("name")
+                or kwargs.get("name")
+                or "AI tool usage",
+            )
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(
+                    watched_span.span,
+                    SPANDATA.AI_INPUT_MESSAGES,
+                    kwargs.get("inputs", [input_str]),
+                )
+                if kwargs.get("metadata"):
+                    set_data_normalized(
+                        watched_span.span, SPANDATA.AI_METADATA, kwargs.get("metadata")
+                    )
+
+    def on_tool_end(self, output, *, run_id, **kwargs):
+        # type: (SentryLangchainCallback, str, UUID, Any) -> Any
+        """Run when tool ends running."""
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
+
+            span_data = self.span_map[run_id]
+            if not span_data:
+                return
+            if should_send_default_pii() and self.include_prompts:
+                set_data_normalized(span_data.span, SPANDATA.AI_RESPONSES, output)
+            self._exit_span(span_data, run_id)
+
+    def on_tool_error(self, error, *args, run_id, **kwargs):
+        # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
+        """Run when tool errors."""
+        self._handle_error(run_id, error)
+
+
+def _wrap_configure(f):
+    # type: (Callable[..., Any]) -> Callable[..., Any]
+
+    @wraps(f)
+    def new_configure(*args, **kwargs):
+        # type: (Any, Any) -> Any
+
+        integration = sentry_sdk.get_client().get_integration(LangchainIntegration)
+
+        with capture_internal_exceptions():
+            new_callbacks = []  # type: List[BaseCallbackHandler]
+            if "local_callbacks" in kwargs:
+                existing_callbacks = kwargs["local_callbacks"]
+                kwargs["local_callbacks"] = new_callbacks
+            elif len(args) > 2:
+                existing_callbacks = args[2]
+                args = (
+                    args[0],
+                    args[1],
+                    new_callbacks,
+                ) + args[3:]
+            else:
+                existing_callbacks = []
+
+            if existing_callbacks:
+                if isinstance(existing_callbacks, list):
+                    for cb in existing_callbacks:
+                        new_callbacks.append(cb)
+                elif isinstance(existing_callbacks, BaseCallbackHandler):
+                    new_callbacks.append(existing_callbacks)
+                else:
+                    logger.warning("Unknown callback type: %s", existing_callbacks)
+
+            already_added = False
+            for callback in new_callbacks:
+                if isinstance(callback, SentryLangchainCallback):
+                    already_added = True
+
+            if not already_added:
+                new_callbacks.append(
+                    SentryLangchainCallback(
+                        integration.max_spans, integration.include_prompts
+                    )
+                )
+        return f(*args, **kwargs)
+
+    return new_configure
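A configuration sketch for enabling the integration explicitly (hypothetical DSN). Note that `include_prompts` only takes effect when `send_default_pii` is on, matching the `should_send_default_pii() and self.include_prompts` checks above:

    import sentry_sdk
    from sentry_sdk.integrations.langchain import LangchainIntegration

    sentry_sdk.init(
        dsn="https://<key>@<org>.ingest.sentry.io/<project>",
        traces_sample_rate=1.0,
        send_default_pii=True,  # required for prompt/response capture
        integrations=[LangchainIntegration(include_prompts=True, max_spans=512)],
    )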
"openai" @@ -77,35 +75,13 @@ def _capture_exception(exc): sentry_sdk.capture_event(event, hint=hint) -def _normalize_data(data): - # type: (Any) -> Any - - # convert pydantic data (e.g. OpenAI v1+) to json compatible format - if hasattr(data, "model_dump"): - try: - return data.model_dump() - except Exception as e: - logger.warning("Could not convert pydantic data to JSON: %s", e) - return data - if isinstance(data, list): - return list(_normalize_data(x) for x in data) - if isinstance(data, dict): - return {k: _normalize_data(v) for (k, v) in data.items()} - return data - - -def set_data_normalized(span, key, value): - # type: (Span, str, Any) -> None - span.set_data(key, _normalize_data(value)) - - def _calculate_chat_completion_usage( messages, response, span, streaming_message_responses=None ): # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]]) -> None - completion_tokens = 0 - prompt_tokens = 0 - total_tokens = 0 + completion_tokens = 0 # type: Optional[int] + prompt_tokens = 0 # type: Optional[int] + total_tokens = 0 # type: Optional[int] if hasattr(response, "usage"): if hasattr(response.usage, "completion_tokens") and isinstance( response.usage.completion_tokens, int @@ -134,15 +110,13 @@ def _calculate_chat_completion_usage( if hasattr(choice, "message"): completion_tokens += count_tokens(choice.message) + if prompt_tokens == 0: + prompt_tokens = None + if completion_tokens == 0: + completion_tokens = None if total_tokens == 0: - total_tokens = prompt_tokens + completion_tokens - - if completion_tokens != 0: - set_data_normalized(span, COMPLETION_TOKENS_USED, completion_tokens) - if prompt_tokens != 0: - set_data_normalized(span, PROMPT_TOKENS_USED, prompt_tokens) - if total_tokens != 0: - set_data_normalized(span, TOTAL_TOKENS_USED, total_tokens) + total_tokens = None + record_token_usage(span, prompt_tokens, completion_tokens, total_tokens) def _wrap_chat_completion_create(f): @@ -167,7 +141,8 @@ def new_chat_completion(*args, **kwargs): streaming = kwargs.get("stream") span = sentry_sdk.start_span( - op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE, description="Chat Completion" + op=consts.OP.OPENAI_CHAT_COMPLETIONS_CREATE, + description="Chat Completion", ) span.__enter__() try: @@ -181,10 +156,10 @@ def new_chat_completion(*args, **kwargs): with capture_internal_exceptions(): if should_send_default_pii() and integration.include_prompts: - set_data_normalized(span, "ai.input_messages", messages) + set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, messages) - set_data_normalized(span, "ai.model_id", model) - set_data_normalized(span, "ai.streaming", streaming) + set_data_normalized(span, SPANDATA.AI_MODEL_ID, model) + set_data_normalized(span, SPANDATA.AI_STREAMING, streaming) if hasattr(res, "choices"): if should_send_default_pii() and integration.include_prompts: @@ -224,7 +199,9 @@ def new_iterator(): should_send_default_pii() and integration.include_prompts ): - set_data_normalized(span, "ai.responses", all_responses) + set_data_normalized( + span, SPANDATA.AI_RESPONSES, all_responses + ) _calculate_chat_completion_usage( messages, res, span, all_responses ) @@ -285,11 +262,7 @@ def new_embeddings_create(*args, **kwargs): if prompt_tokens == 0: prompt_tokens = count_tokens(kwargs["input"] or "") - if total_tokens == 0: - total_tokens = prompt_tokens - - set_data_normalized(span, PROMPT_TOKENS_USED, prompt_tokens) - set_data_normalized(span, TOTAL_TOKENS_USED, total_tokens) + record_token_usage(span, prompt_tokens, None, total_tokens or 
diff --git a/setup.py b/setup.py
index 037a621ddf..bef9842119 100644
--- a/setup.py
+++ b/setup.py
@@ -59,6 +59,7 @@ def get_file_text(file_name):
     "grpcio": ["grpcio>=1.21.1"],
     "httpx": ["httpx>=0.16.0"],
     "huey": ["huey>=2"],
+    "langchain": ["langchain>=0.0.210"],
     "loguru": ["loguru>=0.5"],
     "openai": ["openai>=1.0.0", "tiktoken>=0.3.0"],
     "opentelemetry": ["opentelemetry-distro>=0.35b0"],
diff --git a/tests/integrations/langchain/__init__.py b/tests/integrations/langchain/__init__.py
new file mode 100644
index 0000000000..a286454a56
--- /dev/null
+++ b/tests/integrations/langchain/__init__.py
@@ -0,0 +1,3 @@
+import pytest
+
+pytest.importorskip("langchain_core")
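With the new extra in `setup.py`, the integration's minimum dependency can be pulled in via pip; `tiktoken` stays optional and is only needed for accurate token counts on streaming calls, per the fallback above:

    pip install --upgrade "sentry-sdk[langchain]" tiktoken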
diff --git a/tests/integrations/langchain/test_langchain.py b/tests/integrations/langchain/test_langchain.py
new file mode 100644
index 0000000000..6498cefbaf
--- /dev/null
+++ b/tests/integrations/langchain/test_langchain.py
@@ -0,0 +1,223 @@
+from typing import List, Optional, Any, Iterator
+from unittest.mock import Mock
+
+import pytest
+from langchain_community.chat_models import ChatOpenAI
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.messages import BaseMessage, AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
+
+from sentry_sdk import start_transaction
+from sentry_sdk.integrations.langchain import LangchainIntegration
+from langchain.agents import tool, AgentExecutor, create_openai_tools_agent
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+
+@tool
+def get_word_length(word: str) -> int:
+    """Returns the length of a word."""
+    return len(word)
+
+
+global stream_result_mock  # type: Mock
+global llm_type  # type: str
+
+
+class MockOpenAI(ChatOpenAI):
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        for x in stream_result_mock():
+            yield x
+
+    @property
+    def _llm_type(self) -> str:
+        return llm_type
+
+
+@pytest.mark.parametrize(
+    "send_default_pii, include_prompts, use_unknown_llm_type",
+    [
+        (True, True, False),
+        (True, False, False),
+        (False, True, False),
+        (False, False, True),
+    ],
+)
+def test_langchain_agent(
+    sentry_init, capture_events, send_default_pii, include_prompts, use_unknown_llm_type
+):
+    global llm_type
+    llm_type = "acme-llm" if use_unknown_llm_type else "openai-chat"
+
+    sentry_init(
+        integrations=[LangchainIntegration(include_prompts=include_prompts)],
+        traces_sample_rate=1.0,
+        send_default_pii=send_default_pii,
+    )
+    events = capture_events()
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                "You are very powerful assistant, but don't know current events",
+            ),
+            ("user", "{input}"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+    global stream_result_mock
+    stream_result_mock = Mock(
+        side_effect=[
+            [
+                ChatGenerationChunk(
+                    type="ChatGenerationChunk",
+                    message=AIMessageChunk(
+                        content="",
+                        additional_kwargs={
+                            "tool_calls": [
+                                {
+                                    "index": 0,
+                                    "id": "call_BbeyNhCKa6kYLYzrD40NGm3b",
+                                    "function": {
+                                        "arguments": "",
+                                        "name": "get_word_length",
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        },
+                    ),
+                ),
+                ChatGenerationChunk(
+                    type="ChatGenerationChunk",
+                    message=AIMessageChunk(
+                        content="",
+                        additional_kwargs={
+                            "tool_calls": [
+                                {
+                                    "index": 0,
+                                    "id": None,
+                                    "function": {
+                                        "arguments": '{"word": "eudca"}',
+                                        "name": None,
+                                    },
+                                    "type": None,
+                                }
+                            ]
+                        },
+                    ),
+                ),
+                ChatGenerationChunk(
+                    type="ChatGenerationChunk",
+                    message=AIMessageChunk(content="5"),
+                    generation_info={"finish_reason": "function_call"},
+                ),
+            ],
+            [
+                ChatGenerationChunk(
+                    text="The word eudca has 5 letters.",
+                    type="ChatGenerationChunk",
+                    message=AIMessageChunk(content="The word eudca has 5 letters."),
+                ),
+                ChatGenerationChunk(
+                    type="ChatGenerationChunk",
+                    generation_info={"finish_reason": "stop"},
+                    message=AIMessageChunk(content=""),
+                ),
+            ],
+        ]
+    )
+    llm = MockOpenAI(
+        model_name="gpt-3.5-turbo",
+        temperature=0,
+        openai_api_key="badkey",
+    )
+    agent = create_openai_tools_agent(llm, [get_word_length], prompt)
+
+    agent_executor = AgentExecutor(agent=agent, tools=[get_word_length], verbose=True)
+
+    with start_transaction():
+        list(agent_executor.stream({"input": "How many letters in the word eudca"}))
+
+    tx = events[0]
+    assert tx["type"] == "transaction"
+    chat_spans = list(
+        x for x in tx["spans"] if x["op"] == "ai.chat_completions.create.langchain"
+    )
+    tool_exec_span = next(x for x in tx["spans"] if x["op"] == "ai.tool.langchain")
+
+    assert len(chat_spans) == 2
+
+    # We can't guarantee anything about the "shape" of the langchain execution graph
+    assert len(list(x for x in tx["spans"] if x["op"] == "ai.run.langchain")) > 0
+
+    if use_unknown_llm_type:
+        assert "ai_prompt_tokens_used" in chat_spans[0]["measurements"]
+        assert "ai_total_tokens_used" in chat_spans[0]["measurements"]
+    else:
+        # important: to avoid double counting, we do *not* measure
+        # tokens used if we have an explicit integration (e.g. OpenAI)
+        assert "measurements" not in chat_spans[0]
+
+    if send_default_pii and include_prompts:
+        assert (
+            "You are very powerful"
+            in chat_spans[0]["data"]["ai.input_messages"][0]["content"]
+        )
+        assert "5" in chat_spans[0]["data"]["ai.responses"]
+        assert "word" in tool_exec_span["data"]["ai.input_messages"]
+        assert 5 == int(tool_exec_span["data"]["ai.responses"])
+        assert (
+            "You are very powerful"
+            in chat_spans[1]["data"]["ai.input_messages"][0]["content"]
+        )
+        assert "5" in chat_spans[1]["data"]["ai.responses"]
+    else:
+        assert "ai.input_messages" not in chat_spans[0].get("data", {})
+        assert "ai.responses" not in chat_spans[0].get("data", {})
+        assert "ai.input_messages" not in chat_spans[1].get("data", {})
+        assert "ai.responses" not in chat_spans[1].get("data", {})
+        assert "ai.input_messages" not in tool_exec_span.get("data", {})
+        assert "ai.responses" not in tool_exec_span.get("data", {})
+
+
+def test_langchain_error(sentry_init, capture_events):
+    sentry_init(
+        integrations=[LangchainIntegration(include_prompts=True)],
+        traces_sample_rate=1.0,
+        send_default_pii=True,
+    )
+    events = capture_events()
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                "You are very powerful assistant, but don't know current events",
+            ),
+            ("user", "{input}"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+    global stream_result_mock
+    stream_result_mock = Mock(side_effect=Exception("API rate limit error"))
+    llm = MockOpenAI(
+        model_name="gpt-3.5-turbo",
+        temperature=0,
+        openai_api_key="badkey",
+    )
+    agent = create_openai_tools_agent(llm, [get_word_length], prompt)
+
+    agent_executor = AgentExecutor(agent=agent, tools=[get_word_length], verbose=True)
+
+    with start_transaction(), pytest.raises(Exception):
+        list(agent_executor.stream({"input": "How many letters in the word eudca"}))
+
+    error = events[0]
+    assert error["level"] == "error"
diff --git a/tests/integrations/openai/test_openai.py b/tests/integrations/openai/test_openai.py
index 074d859274..f14ae82333 100644
--- a/tests/integrations/openai/test_openai.py
+++ b/tests/integrations/openai/test_openai.py
@@ -7,12 +7,7 @@
 from openai.types.create_embedding_response import Usage as EmbeddingTokenUsage

 from sentry_sdk import start_transaction
-from sentry_sdk.integrations.openai import (
-    OpenAIIntegration,
-    COMPLETION_TOKENS_USED,
-    PROMPT_TOKENS_USED,
-    TOTAL_TOKENS_USED,
-)
+from sentry_sdk.integrations.openai import OpenAIIntegration

 from unittest import mock  # python 3.3 and above
@@ -72,15 +67,15 @@ def test_nonstreaming_chat_completion(
     assert span["op"] == "ai.chat_completions.create.openai"

     if send_default_pii and include_prompts:
-        assert "hello" in span["data"]["ai.input_messages"][0]["content"]
-        assert "the model response" in span["data"]["ai.responses"][0]["content"]
+        assert "hello" in span["data"]["ai.input_messages"]["content"]
+        assert "the model response" in span["data"]["ai.responses"]["content"]
     else:
         assert "ai.input_messages" not in span["data"]
         assert "ai.responses" not in span["data"]

-    assert span["data"][COMPLETION_TOKENS_USED] == 10
-    assert span["data"][PROMPT_TOKENS_USED] == 20
-    assert span["data"][TOTAL_TOKENS_USED] == 30
+    assert span["measurements"]["ai_completion_tokens_used"]["value"] == 10
+    assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 20
+    assert span["measurements"]["ai_total_tokens_used"]["value"] == 30


 # noinspection PyTypeChecker
@@ -151,8 +146,8 @@ def test_streaming_chat_completion(
     assert span["op"] == "ai.chat_completions.create.openai"

     if send_default_pii and include_prompts:
-        assert "hello" in span["data"]["ai.input_messages"][0]["content"]
-        assert "hello world" in span["data"]["ai.responses"][0]
+        assert "hello" in span["data"]["ai.input_messages"]["content"]
+        assert "hello world" in span["data"]["ai.responses"]
     else:
         assert "ai.input_messages" not in span["data"]
         assert "ai.responses" not in span["data"]
@@ -160,9 +155,9 @@
     try:
         import tiktoken  # type: ignore # noqa # pylint: disable=unused-import

-        assert span["data"][COMPLETION_TOKENS_USED] == 2
-        assert span["data"][PROMPT_TOKENS_USED] == 1
-        assert span["data"][TOTAL_TOKENS_USED] == 3
+        assert span["measurements"]["ai_completion_tokens_used"]["value"] == 2
+        assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 1
+        assert span["measurements"]["ai_total_tokens_used"]["value"] == 3
     except ImportError:
         pass  # if tiktoken is not installed, we can't guarantee token usage will be calculated properly
@@ -223,9 +218,9 @@ def test_embeddings_create(
     span = tx["spans"][0]
     assert span["op"] == "ai.embeddings.create.openai"
     if send_default_pii and include_prompts:
-        assert "hello" in span["data"]["ai.input_messages"][0]
+        assert "hello" in span["data"]["ai.input_messages"]
     else:
         assert "ai.input_messages" not in span["data"]

-    assert span["data"][PROMPT_TOKENS_USED] == 20
-    assert span["data"][TOTAL_TOKENS_USED] == 30
+    assert span["measurements"]["ai_prompt_tokens_used"]["value"] == 20
+    assert span["measurements"]["ai_total_tokens_used"]["value"] == 30
diff --git a/tox.ini b/tox.ini
index 409e8d70b0..e373589736 100644
--- a/tox.ini
+++ b/tox.ini
@@ -140,6 +140,11 @@ envlist =
     {py3.6,py3.11,py3.12}-huey-v{2.0}
     {py3.6,py3.11,py3.12}-huey-latest

+    # Langchain
+    {py3.9,py3.11,py3.12}-langchain-0.1
+    {py3.9,py3.11,py3.12}-langchain-latest
+    {py3.9,py3.11,py3.12}-langchain-notiktoken
+
     # Loguru
     {py3.6,py3.11,py3.12}-loguru-v{0.5}
     {py3.6,py3.11,py3.12}-loguru-latest
@@ -149,11 +154,6 @@ envlist =
     {py3.9,py3.11,py3.12}-openai-latest
     {py3.9,py3.11,py3.12}-openai-notiktoken

-    # OpenAI
-    {py3.9,py3.11,py3.12}-openai-v1
-    {py3.9,py3.11,py3.12}-openai-latest
-    {py3.9,py3.11,py3.12}-openai-notiktoken
-
     # OpenTelemetry (OTel)
     {py3.7,py3.9,py3.11,py3.12}-opentelemetry

@@ -437,6 +437,14 @@ deps =
     huey-v2.0: huey~=2.0.0
     huey-latest: huey

+    # Langchain
+    langchain: openai~=1.0.0
+    langchain-0.1: langchain~=0.1.11
+    langchain-0.1: tiktoken~=0.6.0
+    langchain-latest: langchain
+    langchain-latest: tiktoken~=0.6.0
+    langchain-notiktoken: langchain
+
     # Loguru
     loguru-v0.5: loguru~=0.5.0
     loguru-latest: loguru

@@ -604,6 +612,7 @@ setenv =
     graphene: TESTPATH=tests/integrations/graphene
     httpx: TESTPATH=tests/integrations/httpx
     huey: TESTPATH=tests/integrations/huey
+    langchain: TESTPATH=tests/integrations/langchain
     loguru: TESTPATH=tests/integrations/loguru
     openai: TESTPATH=tests/integrations/openai
     opentelemetry: TESTPATH=tests/integrations/opentelemetry
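To exercise the new matrix locally, one would target a langchain environment through tox, or mirror CI via the runtox wrapper used in the workflow above (environment names follow the envlist entries added here):

    tox -e py3.11-langchain-latest
    # or, as CI does:
    ./scripts/runtox.sh "py3.11-langchain-latest"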