Skip to content

Commit

Permalink
core[patch]: include tool_calls in ai msg chunk serialization (#20291)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored and hinthornw committed Apr 26, 2024
1 parent 9b9d64c commit e819774
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 6 deletions.
20 changes: 18 additions & 2 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, List, Literal
from typing import Any, Dict, List, Literal

from langchain_core.messages.base import (
BaseMessage,
Expand Down Expand Up @@ -40,7 +40,15 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@root_validator
@property
def lc_attributes(self) -> Dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,
"invalid_tool_calls": self.invalid_tool_calls,
}

@root_validator()
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
Expand Down Expand Up @@ -88,6 +96,14 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@property
def lc_attributes(self) -> Dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,
"invalid_tool_calls": self.invalid_tool_calls,
}

@root_validator()
def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]:
Expand Down
67 changes: 67 additions & 0 deletions libs/core/tests/unit_tests/messages/test_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)


def test_serdes_message() -> None:
msg = AIMessage(
content=[{"text": "blah", "type": "text"}],
tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
invalid_tool_calls=[
InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
],
)
expected = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessage"],
"kwargs": {
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
{"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
],
},
}
actual = dumpd(msg)
assert actual == expected
assert load(actual) == msg


def test_serdes_message_chunk() -> None:
chunk = AIMessageChunk(
content=[{"text": "blah", "type": "text"}],
tool_call_chunks=[
ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
],
)
expected = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessageChunk"],
"kwargs": {
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
{
"name": "foobad",
"args": "blah",
"id": "booz",
"error": "Malformed args.",
}
],
"tool_call_chunks": [
{"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
{"name": "foobad", "args": "blah", "id": "booz", "index": 1},
],
},
}
actual = dumpd(chunk)
assert actual == expected
assert load(actual) == chunk

0 comments on commit e819774

Please sign in to comment.