Skip to content

Commit

Permalink
openai[patch]: use tool_calls in request (langchain-ai#20272)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored and junkeon committed Apr 16, 2024
1 parent b55451a commit 5bb2cda
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 12 deletions.
48 changes: 40 additions & 8 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import json
import logging
import os
import sys
Expand Down Expand Up @@ -50,8 +51,10 @@
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
ToolMessageChunk,
)
Expand Down Expand Up @@ -169,20 +172,25 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict["role"] = "assistant"
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "tool_calls" in message.additional_kwargs:
if message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_openai_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None

tool_call_supported_props = {"id", "type", "function"}
message_dict["tool_calls"] = [
{k: v for k, v in tool_call.items() if k in tool_call_supported_props}
for tool_call in message_dict["tool_calls"]
]
else:
pass
# If tool calls present, content null value should be None not empty string.
if "function_call" in message_dict or "tool_calls" in message_dict:
message_dict["content"] = message_dict["content"] or None
elif isinstance(message, SystemMessage):
message_dict["role"] = "system"
elif isinstance(message, FunctionMessage):
Expand Down Expand Up @@ -1067,3 +1075,27 @@ class AnswerWithJustification(BaseModel):

def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)


def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}


def _lc_invalid_tool_call_to_openai_tool_call(
invalid_tool_call: InvalidToolCall,
) -> dict:
return {
"type": "function",
"id": invalid_tool_call["id"],
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.outputs import (
Expand Down Expand Up @@ -519,6 +520,49 @@ def test_tool_use() -> None:
llm_with_tool.invoke(msgs)


def test_manual_tool_call_msg() -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [
HumanMessage("Sally has green hair, what would her username be?"),
AIMessage(
content="",
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="foo",
)
],
),
ToolMessage("sally_green_hair", tool_call_id="foo"),
]
output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs))
assert output.content
# Should not have called the tool again.
assert not output.tool_calls and not output.invalid_tool_calls

# OpenAI should error when tool call id doesn't match across AIMessage and
# ToolMessage
msgs = [
HumanMessage("Sally has green hair, what would her username be?"),
AIMessage(
content="",
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="bar",
)
],
),
ToolMessage("sally_green_hair", tool_call_id="foo"),
]
with pytest.raises(Exception):
llm_with_tool.invoke(msgs)


def test_openai_structured_output() -> None:
class MyModel(BaseModel):
"""A Person"""
Expand Down
13 changes: 9 additions & 4 deletions libs/partners/openai/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"arguments": '{"name": "Sally", "hair_color": "green"}',
"name": "GenerateUsername",
},
"type": "function",
Expand All @@ -126,7 +126,7 @@ def test__convert_dict_to_message_tool_call() -> None:
assert _convert_message_to_dict(expected_output) == message

# Test malformed tool call
raw_tool_calls = [
raw_tool_calls: list = [
{
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
Expand All @@ -138,12 +138,13 @@ def test__convert_dict_to_message_tool_call() -> None:
{
"id": "call_abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"arguments": '{"name": "Sally", "hair_color": "green"}',
"name": "GenerateUsername",
},
"type": "function",
},
]
raw_tool_calls = list(sorted(raw_tool_calls, key=lambda x: x["id"]))
message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
Expand All @@ -166,7 +167,11 @@ def test__convert_dict_to_message_tool_call() -> None:
],
)
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
reverted_message_dict = _convert_message_to_dict(expected_output)
reverted_message_dict["tool_calls"] = list(
sorted(reverted_message_dict["tool_calls"], key=lambda x: x["id"])
)
assert reverted_message_dict == message


@pytest.fixture
Expand Down

0 comments on commit 5bb2cda

Please sign in to comment.