Skip to content

Commit

Permalink
Revert #15326
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Jan 4, 2024
1 parent 7a93356 commit 51bf0fc
Show file tree
Hide file tree
Showing 21 changed files with 1,032 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tracers that record execution of LangChain runs."""

from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import (
ConsoleCallbackHandler,
FunctionCallbackHandler,
Expand All @@ -12,5 +13,6 @@
"ConsoleCallbackHandler",
"FunctionCallbackHandler",
"LangChainTracer",
"LangChainTracerV1",
"WandbTracer",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiohttp import ClientSession
from langchain_core.callbacks.manager import atrace_as_chain_group, trace_as_chain_group
from langchain_core.prompts import PromptTemplate
from langchain_core.tracers.context import tracing_v2_enabled
from langchain_core.tracers.context import tracing_enabled, tracing_v2_enabled

from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import OpenAI
Expand Down Expand Up @@ -76,6 +76,63 @@ async def test_tracing_concurrent() -> None:
await aiosession.close()


async def test_tracing_concurrent_bw_compat_environ() -> None:
    """Trace concurrent agent runs via the legacy ``LANGCHAIN_HANDLER`` env var.

    Sets ``LANGCHAIN_HANDLER`` (the v1 tracing opt-in) while ensuring the
    newer ``LANGCHAIN_TRACING`` flag is unset, then runs several agent
    calls concurrently.  The HTTP session and the environment mutation are
    cleaned up in a ``finally`` block so a failing agent call cannot leak
    the handler override (or an open session) into later tests.
    """
    from langchain.agents import AgentType, initialize_agent, load_tools

    os.environ["LANGCHAIN_HANDLER"] = "langchain"
    # Make sure the v2 flag does not shadow the legacy handler.
    os.environ.pop("LANGCHAIN_TRACING", None)
    aiosession = ClientSession()
    try:
        llm = OpenAI(temperature=0)
        async_tools = load_tools(
            ["llm-math", "serpapi"], llm=llm, aiosession=aiosession
        )
        agent = initialize_agent(
            async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
        )
        tasks = [agent.arun(q) for q in questions[:3]]
        await asyncio.gather(*tasks)
    finally:
        # Always close the session and restore the env, even on failure.
        await aiosession.close()
        os.environ.pop("LANGCHAIN_HANDLER", None)


def test_tracing_context_manager() -> None:
    """Only runs started inside ``tracing_enabled()`` should be traced."""
    from langchain.agents import AgentType, initialize_agent, load_tools

    model = OpenAI(temperature=0)
    toolkit = load_tools(["llm-math", "serpapi"], llm=model)
    executor = initialize_agent(
        toolkit, model, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
    )
    # Clear the env flag so only the context manager controls tracing.
    os.environ.pop("LANGCHAIN_TRACING", None)
    with tracing_enabled() as session:
        assert session
        executor.run(questions[0])  # this should be traced

    executor.run(questions[0])  # this should not be traced


async def test_tracing_context_manager_async() -> None:
    """Async variant: tasks created inside ``tracing_enabled()`` are traced."""
    from langchain.agents import AgentType, initialize_agent, load_tools

    model = OpenAI(temperature=0)
    toolkit = load_tools(["llm-math", "serpapi"], llm=model)
    executor = initialize_agent(
        toolkit, model, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
    )
    # Clear the env flag so only the context manager controls tracing.
    os.environ.pop("LANGCHAIN_TRACING", None)

    # start a background task
    background = asyncio.create_task(executor.arun(questions[0]))  # this should not be traced
    with tracing_enabled() as session:
        assert session
        traced = [executor.arun(q) for q in questions[1:4]]  # these should be traced
        await asyncio.gather(*traced)

    await background


async def test_tracing_v2_environment_variable() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
def test_public_api() -> None:
    """Test for changes in the public API."""
    # Every name the tracer schemas module is expected to export,
    # normalized to alphabetical order for a stable comparison.
    expected_exports = sorted(
        [
            "BaseRun",
            "ChainRun",
            "LLMRun",
            "Run",
            "RunTypeEnum",
            "ToolRun",
            "TracerSession",
            "TracerSessionBase",
            "TracerSessionV1",
            "TracerSessionV1Base",
            "TracerSessionV1Create",
        ]
    )

    assert sorted(schemas_all) == expected_exports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pytest
from freezegun import freeze_time

from langchain_core.callbacks import CallbackManager
from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult
Expand Down Expand Up @@ -36,12 +35,12 @@ def _compare_run_with_error(run: Run, expected_run: Run) -> None:
assert len(expected_run.child_runs) == len(run.child_runs)
for received, expected in zip(run.child_runs, expected_run.child_runs):
_compare_run_with_error(received, expected)
received_dict = run.dict(exclude={"child_runs"})
received_err = received_dict.pop("error")
expected_dict = expected_run.dict(exclude={"child_runs"})
expected_err = expected_dict.pop("error")
received = run.dict(exclude={"child_runs"})
received_err = received.pop("error")
expected = expected_run.dict(exclude={"child_runs"})
expected_err = expected.pop("error")

assert received_dict == expected_dict
assert received == expected
if expected_err is not None:
assert received_err is not None
assert expected_err in received_err
Expand Down Expand Up @@ -406,7 +405,6 @@ def _on_llm_error(self, run: Run) -> None:
tracer = FakeTracerWithLlmErrorCallback()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.error_run is not None
_compare_run_with_error(tracer.error_run, compare_run)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from uuid import UUID

import pytest
from langsmith import Client

from langchain_core.outputs import LLMResult
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.schemas import Run
from langsmith import Client


def test_example_id_assignment_threadsafe() -> None:
Expand Down

0 comments on commit 51bf0fc

Please sign in to comment.