openai[patch]: use tool_calls in request #20272

Merged Apr 11, 2024 · 5 commits

Changes from all commits
48 changes: 40 additions & 8 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import logging
 import os
 import sys
@@ -50,8 +51,10 @@
     FunctionMessageChunk,
     HumanMessage,
     HumanMessageChunk,
+    InvalidToolCall,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
     ToolMessage,
     ToolMessageChunk,
 )
@@ -169,20 +172,25 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict["role"] = "assistant"
         if "function_call" in message.additional_kwargs:
             message_dict["function_call"] = message.additional_kwargs["function_call"]
-            # If function call only, content is None not empty string
-            if message_dict["content"] == "":
-                message_dict["content"] = None
-        if "tool_calls" in message.additional_kwargs:
+        if message.tool_calls or message.invalid_tool_calls:
+            message_dict["tool_calls"] = [
+                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+            ] + [
+                _lc_invalid_tool_call_to_openai_tool_call(tc)
+                for tc in message.invalid_tool_calls
+            ]
+        elif "tool_calls" in message.additional_kwargs:
             message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
-            # If tool calls only, content is None not empty string
-            if message_dict["content"] == "":
-                message_dict["content"] = None
-
             tool_call_supported_props = {"id", "type", "function"}
             message_dict["tool_calls"] = [
                 {k: v for k, v in tool_call.items() if k in tool_call_supported_props}
                 for tool_call in message_dict["tool_calls"]
             ]
+        else:
+            pass
+        # If tool calls present, content null value should be None not empty string.
+        if "function_call" in message_dict or "tool_calls" in message_dict:
+            message_dict["content"] = message_dict["content"] or None
     elif isinstance(message, SystemMessage):
         message_dict["role"] = "system"
     elif isinstance(message, FunctionMessage):
@@ -1067,3 +1075,27 @@ class AnswerWithJustification(BaseModel):
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
+
+
+def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+    return {
+        "type": "function",
+        "id": tool_call["id"],
+        "function": {
+            "name": tool_call["name"],
+            "arguments": json.dumps(tool_call["args"]),
+        },
+    }
+
+
+def _lc_invalid_tool_call_to_openai_tool_call(
+    invalid_tool_call: InvalidToolCall,
+) -> dict:
+    return {
+        "type": "function",
+        "id": invalid_tool_call["id"],
+        "function": {
+            "name": invalid_tool_call["name"],
+            "arguments": invalid_tool_call["args"],
+        },
+    }
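
Note (an illustrative sketch, not text or code from the PR): with the change above, an AIMessage that carries structured tool_calls or invalid_tool_calls has them serialized into the OpenAI tool_calls request field via the new helpers, with args encoded by json.dumps, instead of relying on additional_kwargs["tool_calls"]. A minimal sketch of the resulting request entry, assuming only langchain_core and the helper behavior shown in this diff:

# Illustrative sketch, not code from this PR.
import json

from langchain_core.messages import AIMessage, ToolCall

msg = AIMessage(
    content="",
    tool_calls=[
        ToolCall(
            name="GenerateUsername",
            args={"name": "Sally", "hair_color": "green"},
            id="foo",
        )
    ],
)

# Mirrors _lc_tool_call_to_openai_tool_call plus the content-normalization step.
request_entry = {
    "role": "assistant",
    # Empty content is sent as None whenever tool calls are present.
    "content": None,
    "tool_calls": [
        {
            "type": "function",
            "id": tc["id"],
            "function": {"name": tc["name"], "arguments": json.dumps(tc["args"])},
        }
        for tc in msg.tool_calls
    ],
}
print(request_entry)
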
44 changes: 44 additions & 0 deletions libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -10,6 +10,7 @@
     BaseMessageChunk,
     HumanMessage,
     SystemMessage,
+    ToolCall,
     ToolMessage,
 )
 from langchain_core.outputs import (
@@ -519,6 +520,49 @@ def test_tool_use() -> None:
     llm_with_tool.invoke(msgs)
 
 
+def test_manual_tool_call_msg() -> None:
+    """Test passing in manually construct tool call message."""
+    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+    llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
+    msgs: List = [
+        HumanMessage("Sally has green hair, what would her username be?"),
+        AIMessage(
+            content="",
+            tool_calls=[
+                ToolCall(
+                    name="GenerateUsername",
+                    args={"name": "Sally", "hair_color": "green"},
+                    id="foo",
+                )
+            ],
+        ),
+        ToolMessage("sally_green_hair", tool_call_id="foo"),
+    ]
+    output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs))
+    assert output.content
+    # Should not have called the tool again.
+    assert not output.tool_calls and not output.invalid_tool_calls
+
+    # OpenAI should error when tool call id doesn't match across AIMessage and
+    # ToolMessage
+    msgs = [
+        HumanMessage("Sally has green hair, what would her username be?"),
+        AIMessage(
+            content="",
+            tool_calls=[
+                ToolCall(
+                    name="GenerateUsername",
+                    args={"name": "Sally", "hair_color": "green"},
+                    id="bar",
+                )
+            ],
+        ),
+        ToolMessage("sally_green_hair", tool_call_id="foo"),
+    ]
+    with pytest.raises(Exception):
+        llm_with_tool.invoke(msgs)
+
+
 def test_openai_structured_output() -> None:
     class MyModel(BaseModel):
         """A Person"""
13 changes: 9 additions & 4 deletions libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -104,7 +104,7 @@ def test__convert_dict_to_message_tool_call() -> None:
     raw_tool_call = {
         "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
         "function": {
-            "arguments": '{"name":"Sally","hair_color":"green"}',
+            "arguments": '{"name": "Sally", "hair_color": "green"}',
             "name": "GenerateUsername",
         },
         "type": "function",
@@ -126,7 +126,7 @@ def test__convert_dict_to_message_tool_call() -> None:
     assert _convert_message_to_dict(expected_output) == message
 
     # Test malformed tool call
-    raw_tool_calls = [
+    raw_tool_calls: list = [
         {
             "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
             "function": {
@@ -138,12 +138,13 @@
         {
             "id": "call_abc123",
             "function": {
-                "arguments": '{"name":"Sally","hair_color":"green"}',
+                "arguments": '{"name": "Sally", "hair_color": "green"}',
                 "name": "GenerateUsername",
             },
             "type": "function",
         },
     ]
+    raw_tool_calls = list(sorted(raw_tool_calls, key=lambda x: x["id"]))
     message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
     result = _convert_dict_to_message(message)
     expected_output = AIMessage(
@@ -166,7 +167,11 @@ def test__convert_dict_to_message_tool_call() -> None:
         ],
     )
    assert result == expected_output
-    assert _convert_message_to_dict(expected_output) == message
+    reverted_message_dict = _convert_message_to_dict(expected_output)
+    reverted_message_dict["tool_calls"] = list(
+        sorted(reverted_message_dict["tool_calls"], key=lambda x: x["id"])
+    )
+    assert reverted_message_dict == message


@pytest.fixture
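
Note on why the unit-test expectations above changed (a sketch, not text from the PR): tool-call arguments now round-trip through json.dumps, whose default separators add spaces, and valid tool calls are serialized before invalid ones, so the round-trip comparison sorts both sides by id. A minimal sketch of both effects:

# Minimal sketch of the two effects the updated assertions account for.
import json

# json.dumps defaults to ", " and ": " separators, so a compact arguments
# string gains spaces after being parsed and re-serialized.
args = json.loads('{"name":"Sally","hair_color":"green"}')
assert json.dumps(args) == '{"name": "Sally", "hair_color": "green"}'

# Valid tool calls are emitted before invalid ones, which can reorder the
# list relative to the raw input, hence sorting by "id" before comparing.
raw_ids = ["call_wm0JY6CdwOMZ4eTxHWUThDNz", "call_abc123"]
assert sorted(raw_ids) == ["call_abc123", "call_wm0JY6CdwOMZ4eTxHWUThDNz"]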