Skip to content

Commit

Permalink
[langchain] fix OpenAIAssistantRunnable.create_assistant (langchain-a…
Browse files Browse the repository at this point in the history
…i#19081)

- **Description:** OpenAI assistants support some pre-built tools (e.g.,
`"retrieval"` and `"code_interpreter"`) and expect these as `{"type":
"code_interpreter"}`. This may have been upset by
langchain-ai#18935
- **Issue:** langchain-ai#19057
  • Loading branch information
ccurme authored and gkorland committed Mar 30, 2024
1 parent e103303 commit f9e8c4b
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 4 deletions.
45 changes: 41 additions & 4 deletions libs/langchain/langchain/agents/openai_assistant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,23 @@
import json
from json import JSONDecodeError
from time import sleep
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import CallbackManager
from langchain_core.load import dumpd
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
Expand Down Expand Up @@ -76,6 +87,32 @@ def _get_openai_async_client() -> openai.AsyncOpenAI:
) from e


def _is_assistants_builtin_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> bool:
"""Determine if tool corresponds to OpenAI Assistants built-in."""
assistants_builtin_tools = ("code_interpreter", "retrieval")
return (
isinstance(tool, dict)
and ("type" in tool)
and (tool["type"] in assistants_builtin_tools)
)


def _get_assistants_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> Dict[str, Any]:
"""Convert a raw function/class to an OpenAI tool.
Note that OpenAI assistants supports several built-in tools,
such as "code_interpreter" and "retrieval."
"""
if _is_assistants_builtin_tool(tool):
return tool # type: ignore
else:
return convert_to_openai_tool(tool)


OutputType = Union[
List[OpenAIAssistantAction],
OpenAIAssistantFinish,
Expand Down Expand Up @@ -210,7 +247,7 @@ def create_assistant(
assistant = client.beta.assistants.create(
name=name,
instructions=instructions,
tools=[convert_to_openai_tool(tool) for tool in tools], # type: ignore
tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore
model=model,
file_ids=kwargs.get("file_ids"),
)
Expand Down Expand Up @@ -328,7 +365,7 @@ async def acreate_assistant(
AsyncOpenAIAssistantRunnable configured to run using the created assistant.
"""
async_client = async_client or _get_openai_async_client()
openai_tools = [convert_to_openai_tool(tool) for tool in tools]
openai_tools = [_get_assistants_tool(tool) for tool in tools]
assistant = await async_client.beta.assistants.create(
name=name,
instructions=instructions,
Expand Down
43 changes: 43 additions & 0 deletions libs/langchain/tests/unit_tests/agents/test_openai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from functools import partial
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from langchain.agents.openai_assistant import OpenAIAssistantRunnable


def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
client = AsyncMock() if use_async else MagicMock()
mock_assistant = MagicMock()
mock_assistant.id = "abc123"
client.beta.assistants.create.return_value = mock_assistant # type: ignore
return client


@pytest.mark.requires("openai")
def test_user_supplied_client() -> None:
import openai
Expand All @@ -19,3 +31,34 @@ def test_user_supplied_client() -> None:
)

assert assistant.client == client


@pytest.mark.requires("openai")
@patch(
"langchain.agents.openai_assistant.base._get_openai_client",
new=partial(_create_mock_client, use_async=False),
)
def test_create_assistant() -> None:
assistant = OpenAIAssistantRunnable.create_assistant(
name="name",
instructions="instructions",
tools=[{"type": "code_interpreter"}],
model="",
)
assert isinstance(assistant, OpenAIAssistantRunnable)


@pytest.mark.requires("openai")
@patch(
"langchain.agents.openai_assistant.base._get_openai_async_client",
new=partial(_create_mock_client, use_async=True),
)
async def test_acreate_assistant() -> None:
assistant = await OpenAIAssistantRunnable.acreate_assistant(
name="name",
instructions="instructions",
tools=[{"type": "code_interpreter"}],
model="",
client=_create_mock_client(),
)
assert isinstance(assistant, OpenAIAssistantRunnable)

0 comments on commit f9e8c4b

Please sign in to comment.