Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: include tool_calls in ai msg chunk serialization #20291

Merged
merged 5 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 18 additions & 2 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, List, Literal
from typing import Any, Dict, List, Literal

from langchain_core.messages.base import (
BaseMessage,
Expand Down Expand Up @@ -40,7 +40,15 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@root_validator
@property
def lc_attributes(self) -> Dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,
"invalid_tool_calls": self.invalid_tool_calls,
}

@root_validator()
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
Expand Down Expand Up @@ -88,6 +96,14 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@property
def lc_attributes(self) -> Dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are tool_call_chunks serialized properly without being in here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default behavior is to serialize (only) attributes that are passed in as init args

"invalid_tool_calls": self.invalid_tool_calls,
}

@root_validator()
def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]:
Expand Down
67 changes: 67 additions & 0 deletions libs/core/tests/unit_tests/messages/test_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)


def test_serdes_message() -> None:
msg = AIMessage(
[{"text": "blah", "type": "text"}],
tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
invalid_tool_calls=[
InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
],
)
expected = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessage"],
"kwargs": {
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
{"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
],
},
}
actual = dumpd(msg)
assert actual == expected
assert load(actual) == msg


def test_serdes_message_chunk() -> None:
chunk = AIMessageChunk(
[{"text": "blah", "type": "text"}],
tool_call_chunks=[
ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
],
)
expected = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessageChunk"],
"kwargs": {
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
{
"name": "foobad",
"args": "blah",
"id": "booz",
"error": "Malformed args.",
}
],
"tool_call_chunks": [
{"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
{"name": "foobad", "args": "blah", "id": "booz", "index": 1},
],
},
}
actual = dumpd(chunk)
assert actual == expected
assert load(actual) == chunk