-
Notifications
You must be signed in to change notification settings - Fork 13.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core]: add tool calls message #18947
Changes from all commits
d68a71c
2319f42
f6048cc
67e3d1e
35189ad
0c14c48
684e107
b6dc899
2b373cd
28bc99e
ea74d44
6dd3ff7
2692239
1849f93
10b2c7b
1f637b9
3605678
9540adf
16bdda4
8698a86
c7d00cc
808cf82
0d6118a
6c0e3ea
508e3bc
189600f
360f9e2
ca59616
bda88a6
ac4a0da
0b515c8
6d81051
fec6db2
c5ec7fd
33127d3
e1fb61b
b1e9235
75d11dc
9d0176f
e225577
c27c9e8
5765f48
4eb4c34
9089029
a8b2733
a29c6b1
3ffcf0a
1b53ef6
ed71599
48d9355
d140449
5ea8bb4
2060f37
ff07346
71764a9
cbf66ec
f45caa0
a49f23e
9a05cba
bc85987
2ea5d66
6103cd4
9ff7ae9
e4ca284
e012414
53138b9
a79a980
bfe8fe3
ce684dd
13222fa
e845536
1f902d2
b490c57
34e629d
94bdf1c
a1feb7d
f8b3b82
52b7531
bb37dd7
225f519
fa3a8c0
3d243d4
0cc3142
c3f2805
f6d864c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,23 @@ | ||
import warnings | ||
from typing import Any, List, Literal | ||
|
||
from langchain_core.messages.base import ( | ||
BaseMessage, | ||
BaseMessageChunk, | ||
merge_content, | ||
) | ||
from langchain_core.utils._merge import merge_dicts | ||
from langchain_core.messages.tool import ( | ||
InvalidToolCall, | ||
ToolCall, | ||
ToolCallChunk, | ||
default_tool_chunk_parser, | ||
default_tool_parser, | ||
) | ||
from langchain_core.pydantic_v1 import root_validator | ||
from langchain_core.utils._merge import merge_dicts, merge_lists | ||
from langchain_core.utils.json import ( | ||
parse_partial_json, | ||
) | ||
|
||
|
||
class AIMessage(BaseMessage): | ||
|
@@ -16,13 +28,46 @@ class AIMessage(BaseMessage): | |
conversation. | ||
""" | ||
|
||
tool_calls: List[ToolCall] = [] | ||
"""If provided, tool calls associated with the message.""" | ||
invalid_tool_calls: List[InvalidToolCall] = [] | ||
"""If provided, tool calls with parsing errors associated with the message.""" | ||
|
||
type: Literal["ai"] = "ai" | ||
|
||
@classmethod | ||
def get_lc_namespace(cls) -> List[str]: | ||
"""Get the namespace of the langchain object.""" | ||
return ["langchain", "schema", "messages"] | ||
|
||
@root_validator | ||
def _backwards_compat_tool_calls(cls, values: dict) -> dict: | ||
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls") | ||
tool_calls = ( | ||
values.get("tool_calls") | ||
or values.get("invalid_tool_calls") | ||
or values.get("tool_call_chunks") | ||
) | ||
if raw_tool_calls and not tool_calls: | ||
warnings.warn( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How often will this get logged? Potentially multiple times per agent run? Not a bad thing necessarily |
||
"New langchain packages are available that more efficiently handle " | ||
"tool calling. Please upgrade your packages to versions that set " | ||
"message tool calls. e.g., `pip install --upgrade langchain-anthropic" | ||
"`, pip install--upgrade langchain-openai`, etc." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing space in string |
||
) | ||
try: | ||
if issubclass(cls, AIMessageChunk): # type: ignore | ||
values["tool_call_chunks"] = default_tool_chunk_parser( | ||
raw_tool_calls | ||
) | ||
else: | ||
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls) | ||
values["tool_calls"] = tool_calls | ||
values["invalid_tool_calls"] = invalid_tool_calls | ||
except Exception: | ||
pass | ||
return values | ||
|
||
|
||
AIMessage.update_forward_refs() | ||
|
||
|
@@ -35,27 +80,90 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): | |
# non-chunk variant. | ||
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501 | ||
|
||
tool_call_chunks: List[ToolCallChunk] = [] | ||
"""If provided, tool call chunks associated with the message.""" | ||
|
||
@classmethod | ||
def get_lc_namespace(cls) -> List[str]: | ||
"""Get the namespace of the langchain object.""" | ||
return ["langchain", "schema", "messages"] | ||
|
||
@root_validator() | ||
def init_tool_calls(cls, values: dict) -> dict: | ||
if not values["tool_call_chunks"]: | ||
ccurme marked this conversation as resolved.
Show resolved
Hide resolved
|
||
values["tool_calls"] = [] | ||
values["invalid_tool_calls"] = [] | ||
return values | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small nit - what if |
||
tool_calls = [] | ||
invalid_tool_calls = [] | ||
for chunk in values["tool_call_chunks"]: | ||
try: | ||
args_ = parse_partial_json(chunk["args"]) | ||
if isinstance(args_, dict): | ||
tool_calls.append( | ||
ToolCall( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't you need |
||
name=chunk["name"] or "", | ||
args=args_, | ||
id=chunk["id"], | ||
) | ||
) | ||
else: | ||
raise ValueError("Malformed args.") | ||
except Exception: | ||
invalid_tool_calls.append( | ||
InvalidToolCall( | ||
name=chunk["name"], | ||
args=chunk["args"], | ||
id=chunk["id"], | ||
error="Malformed args.", | ||
) | ||
) | ||
values["tool_calls"] = tool_calls | ||
values["invalid_tool_calls"] = invalid_tool_calls | ||
return values | ||
|
||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore | ||
if isinstance(other, AIMessageChunk): | ||
if self.example != other.example: | ||
raise ValueError( | ||
"Cannot concatenate AIMessageChunks with different example values." | ||
) | ||
|
||
content = merge_content(self.content, other.content) | ||
additional_kwargs = merge_dicts( | ||
self.additional_kwargs, other.additional_kwargs | ||
) | ||
response_metadata = merge_dicts( | ||
self.response_metadata, other.response_metadata | ||
) | ||
|
||
# Merge tool call chunks | ||
if self.tool_call_chunks or other.tool_call_chunks: | ||
raw_tool_calls = merge_lists( | ||
self.tool_call_chunks, | ||
other.tool_call_chunks, | ||
) | ||
if raw_tool_calls: | ||
tool_call_chunks = [ | ||
ToolCallChunk( | ||
name=rtc.get("name"), | ||
args=rtc.get("args"), | ||
index=rtc.get("index"), | ||
id=rtc.get("id"), | ||
) | ||
for rtc in raw_tool_calls | ||
] | ||
else: | ||
tool_call_chunks = [] | ||
else: | ||
tool_call_chunks = [] | ||
|
||
return self.__class__( | ||
example=self.example, | ||
content=merge_content(self.content, other.content), | ||
additional_kwargs=merge_dicts( | ||
self.additional_kwargs, other.additional_kwargs | ||
), | ||
response_metadata=merge_dicts( | ||
self.response_metadata, other.response_metadata | ||
), | ||
content=content, | ||
additional_kwargs=additional_kwargs, | ||
tool_call_chunks=tool_call_chunks, | ||
response_metadata=response_metadata, | ||
id=self.id, | ||
) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming here is confusing since it's shadowed below