Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert #15326 #15546

Merged
merged 4 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tracers that record execution of LangChain runs."""

from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import (
ConsoleCallbackHandler,
FunctionCallbackHandler,
Expand All @@ -12,5 +13,6 @@
"ConsoleCallbackHandler",
"FunctionCallbackHandler",
"LangChainTracer",
"LangChainTracerV1",
"WandbTracer",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiohttp import ClientSession
from langchain_core.callbacks.manager import atrace_as_chain_group, trace_as_chain_group
from langchain_core.prompts import PromptTemplate
from langchain_core.tracers.context import tracing_v2_enabled
from langchain_core.tracers.context import tracing_enabled, tracing_v2_enabled

from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import OpenAI
Expand Down Expand Up @@ -76,6 +76,63 @@ async def test_tracing_concurrent() -> None:
await aiosession.close()


async def test_tracing_concurrent_bw_compat_environ() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools

os.environ["LANGCHAIN_HANDLER"] = "langchain"
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
aiosession = ClientSession()
llm = OpenAI(temperature=0)
async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
tasks = [agent.arun(q) for q in questions[:3]]
await asyncio.gather(*tasks)
await aiosession.close()
if "LANGCHAIN_HANDLER" in os.environ:
del os.environ["LANGCHAIN_HANDLER"]


def test_tracing_context_manager() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools

llm = OpenAI(temperature=0)
tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
with tracing_enabled() as session:
assert session
agent.run(questions[0]) # this should be traced

agent.run(questions[0]) # this should not be traced


async def test_tracing_context_manager_async() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools

llm = OpenAI(temperature=0)
async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]

# start a background task
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
with tracing_enabled() as session:
assert session
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
await asyncio.gather(*tasks)

await task


async def test_tracing_v2_environment_variable() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools

Expand Down
22 changes: 20 additions & 2 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,7 @@ def _configure(
_configure_hooks,
_get_tracer_project,
_tracing_v2_is_enabled,
tracing_callback_var,
tracing_v2_callback_var,
)

Expand Down Expand Up @@ -1885,20 +1886,27 @@ def _configure(
callback_manager.add_metadata(inheritable_metadata or {})
callback_manager.add_metadata(local_metadata or {}, False)

tracer = tracing_callback_var.get()
tracing_enabled_ = (
env_var_is_set("LANGCHAIN_TRACING")
or tracer is not None
or env_var_is_set("LANGCHAIN_HANDLER")
)

tracer_v2 = tracing_v2_callback_var.get()
tracing_v2_enabled_ = _tracing_v2_is_enabled()
tracer_project = _get_tracer_project()
debug = _get_debug()
if verbose or debug or tracing_v2_enabled_:
if verbose or debug or tracing_enabled_ or tracing_v2_enabled_:
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import ConsoleCallbackHandler

if verbose and not any(
isinstance(handler, StdOutCallbackHandler)
for handler in callback_manager.handlers
):
if debug:
# We will use ConsoleCallbackHandler instead of StdOutCallbackHandler
pass
else:
callback_manager.add_handler(StdOutCallbackHandler(), False)
Expand All @@ -1907,6 +1915,16 @@ def _configure(
for handler in callback_manager.handlers
):
callback_manager.add_handler(ConsoleCallbackHandler(), True)
if tracing_enabled_ and not any(
isinstance(handler, LangChainTracerV1)
for handler in callback_manager.handlers
):
if tracer:
callback_manager.add_handler(tracer, True)
else:
handler = LangChainTracerV1()
handler.load_session(tracer_project)
callback_manager.add_handler(handler, True)
if tracing_v2_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
Expand Down
33 changes: 32 additions & 1 deletion libs/core/langchain_core/tracers/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from langsmith.run_helpers import get_run_tree_context

from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from langchain_core.tracers.schemas import TracerSessionV1
from langchain_core.utils.env import env_var_is_set

if TYPE_CHECKING:
Expand All @@ -28,6 +30,10 @@
from langchain_core.callbacks.base import BaseCallbackHandler, Callbacks
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager

tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501
"tracing_callback", default=None
)

tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
"tracing_callback_v2", default=None
)
Expand All @@ -36,6 +42,32 @@
)


@contextmanager
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSessionV1, None, None]:
"""Get the Deprecated LangChainTracer in a context manager.

Args:
session_name (str, optional): The name of the session.
Defaults to "default".

Returns:
TracerSessionV1: The LangChainTracer session.

Example:
>>> with tracing_enabled() as session:
... # Use the LangChainTracer session
"""
cb = LangChainTracerV1()
session = cast(TracerSessionV1, cb.load_session(session_name))
try:
tracing_callback_var.set(cb)
yield session
finally:
tracing_callback_var.set(None)


@contextmanager
def tracing_v2_enabled(
project_name: Optional[str] = None,
Expand Down Expand Up @@ -136,7 +168,6 @@ def _tracing_v2_is_enabled() -> bool:
env_var_is_set("LANGCHAIN_TRACING_V2")
or tracing_v2_callback_var.get() is not None
or get_run_tree_context() is not None
or env_var_is_set("LANGCHAIN_TRACING")
)


Expand Down
185 changes: 185 additions & 0 deletions libs/core/langchain_core/tracers/langchain_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from __future__ import annotations

import logging
import os
from typing import Any, Dict, Optional, Union

import requests

from langchain_core.messages import get_buffer_string
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import (
ChainRun,
LLMRun,
Run,
ToolRun,
TracerSession,
TracerSessionV1,
TracerSessionV1Base,
)
from langchain_core.utils import raise_for_status_with_text

logger = logging.getLogger(__name__)


def get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
return headers


def _get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")


class LangChainTracerV1(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""

def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self.session: Optional[TracerSessionV1] = None
self._endpoint = _get_endpoint()
self._headers = get_headers()

def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
session = self.session or self.load_default_session()
if not isinstance(session, TracerSessionV1):
raise ValueError(
"LangChainTracerV1 is not compatible with"
f" session of type {type(session)}"
)

if run.run_type == "llm":
if "prompts" in run.inputs:
prompts = run.inputs["prompts"]
elif "messages" in run.inputs:
prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]]
else:
raise ValueError("No prompts found in LLM run inputs")
return LLMRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
error=run.error,
prompts=prompts,
response=run.outputs if run.outputs else None,
)
if run.run_type == "chain":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ChainRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
inputs=run.inputs,
outputs=run.outputs,
error=run.error,
extra=run.extra,
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
)
if run.run_type == "tool":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ToolRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
action=str(run.serialized),
tool_input=run.inputs.get("input", ""),
output=None if run.outputs is None else run.outputs.get("output"),
error=run.error,
extra=run.extra,
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
)
raise ValueError(f"Unknown run type: {run.run_type}")

def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, Run):
v1_run = self._convert_to_v1_run(run)
else:
v1_run = run
if isinstance(v1_run, LLMRun):
endpoint = f"{self._endpoint}/llm-runs"
elif isinstance(v1_run, ChainRun):
endpoint = f"{self._endpoint}/chain-runs"
else:
endpoint = f"{self._endpoint}/tool-runs"

try:
response = requests.post(
endpoint,
data=v1_run.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
except Exception as e:
logger.warning(f"Failed to persist run: {e}")

def _persist_session(
self, session_create: TracerSessionV1Base
) -> Union[TracerSessionV1, TracerSession]:
"""Persist a session."""
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSessionV1(id=r.json()["id"], **session_create.dict())
except Exception as e:
logger.warning(f"Failed to create session, using default session: {e}")
session = TracerSessionV1(id=1, **session_create.dict())
return session

def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1:
"""Load a session from the tracer."""
try:
url = f"{self._endpoint}/sessions"
if session_name:
url += f"?name={session_name}"
r = requests.get(url, headers=self._headers)

tracer_session = TracerSessionV1(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logger.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSessionV1(id=1)

self.session = tracer_session
return tracer_session

def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)

def load_default_session(self) -> Union[TracerSessionV1, TracerSession]:
"""Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default")