-
Notifications
You must be signed in to change notification settings - Fork 13.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core]: add tool calls message #18947
Changes from 80 commits
d68a71c
2319f42
f6048cc
67e3d1e
35189ad
0c14c48
684e107
b6dc899
2b373cd
28bc99e
ea74d44
6dd3ff7
2692239
1849f93
10b2c7b
1f637b9
3605678
9540adf
16bdda4
8698a86
c7d00cc
808cf82
0d6118a
6c0e3ea
508e3bc
189600f
360f9e2
ca59616
bda88a6
ac4a0da
0b515c8
6d81051
fec6db2
c5ec7fd
33127d3
e1fb61b
b1e9235
75d11dc
9d0176f
e225577
c27c9e8
5765f48
4eb4c34
9089029
a8b2733
a29c6b1
3ffcf0a
1b53ef6
ed71599
48d9355
d140449
5ea8bb4
2060f37
ff07346
71764a9
cbf66ec
f45caa0
a49f23e
9a05cba
bc85987
2ea5d66
6103cd4
9ff7ae9
e4ca284
e012414
53138b9
a79a980
bfe8fe3
ce684dd
13222fa
e845536
1f902d2
b490c57
34e629d
94bdf1c
a1feb7d
f8b3b82
52b7531
bb37dd7
225f519
fa3a8c0
3d243d4
0cc3142
c3f2805
f6d864c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,24 @@ | ||
from typing import Any, List, Literal | ||
import warnings | ||
from json import JSONDecodeError | ||
from typing import Any, List, Literal, Optional, Union | ||
|
||
from langchain_core.messages.base import ( | ||
BaseMessage, | ||
BaseMessageChunk, | ||
merge_content, | ||
) | ||
from langchain_core.utils._merge import merge_dicts | ||
from langchain_core.messages.tool import ( | ||
InvalidToolCall, | ||
ToolCall, | ||
ToolCallChunk, | ||
default_tool_chunk_parser, | ||
default_tool_parser, | ||
) | ||
from langchain_core.pydantic_v1 import root_validator | ||
from langchain_core.utils._merge import merge_dicts, merge_lists | ||
from langchain_core.utils.json import ( | ||
parse_partial_json, | ||
) | ||
|
||
|
||
class AIMessage(BaseMessage): | ||
|
@@ -16,13 +29,36 @@ class AIMessage(BaseMessage): | |
conversation. | ||
""" | ||
|
||
tool_calls: Optional[List[Union[ToolCall, InvalidToolCall]]] = None | ||
"""If provided, tool calls associated with the message.""" | ||
|
||
type: Literal["ai"] = "ai" | ||
|
||
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace of this langchain object as a list of parts."""
    namespace = ["langchain", "schema", "messages"]
    return namespace
|
||
@root_validator
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
    """Populate structured tool calls from legacy ``additional_kwargs``.

    When ``additional_kwargs["tool_calls"]`` holds raw tool calls but neither
    ``tool_calls`` nor ``tool_call_chunks`` was supplied, a warning is emitted
    and the raw entries are parsed on a best-effort basis.

    Args:
        values: Field values collected for the message being validated.

    Returns:
        The same ``values`` dict, possibly with ``tool_calls`` (or, for
        ``AIMessageChunk``, ``tool_call_chunks``) filled in.
    """
    raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
    tool_calls = values.get("tool_calls") or values.get("tool_call_chunks")
    if raw_tool_calls and not tool_calls:
        warnings.warn(
            "New langchain packages are available that more efficiently handle "
            "tool calling. Please upgrade your packages to versions that set "
            "message tool calls. e.g., `pip install --upgrade langchain-anthropic"
            "`, `pip install --upgrade langchain-openai`, etc."
        )
        try:
            if issubclass(cls, AIMessageChunk):  # type: ignore
                # AIMessageChunk derives tool_calls from tool_call_chunks and
                # its init_tool_calls validator raises ValueError if
                # tool_calls is set directly, so fill in the chunks instead.
                values["tool_call_chunks"] = default_tool_chunk_parser(
                    raw_tool_calls
                )
            else:
                values["tool_calls"] = default_tool_parser(raw_tool_calls)
        except Exception:
            # Deliberate best effort: raw tool calls that cannot be parsed
            # leave the structured fields unset rather than failing
            # construction of the message.
            pass
    return values
|
||
|
||
AIMessage.update_forward_refs() | ||
|
||
|
@@ -35,27 +71,73 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): | |
# non-chunk variant. | ||
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501 | ||
|
||
tool_call_chunks: Optional[List[ToolCallChunk]] = None | ||
"""If provided, tool call chunks associated with the message.""" | ||
|
||
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace of this langchain object as a list of parts."""
    namespace = ["langchain", "schema", "messages"]
    return namespace
|
||
@root_validator()
def init_tool_calls(cls, values: dict) -> dict:
    """Derive ``tool_calls`` from ``tool_call_chunks``.

    Each chunk's ``args`` string is parsed with ``parse_partial_json``; when
    parsing fails, or the result is not a dict, the arguments fall back to
    an empty dict.

    Args:
        values: Field values collected for the chunk being validated.

    Returns:
        The same ``values`` dict with ``tool_calls`` set.

    Raises:
        ValueError: If ``tool_calls`` was supplied explicitly.
    """
    # Use .get(): a field that failed its own validation may be missing from
    # ``values``, and direct indexing would surface as a bare KeyError.
    if values.get("tool_calls") is not None:
        raise ValueError(
            "tool_calls cannot be set on AIMessageChunk, it is derived "
            "from tool_call_chunks."
        )
    chunks = values.get("tool_call_chunks")
    if not chunks:
        # Nothing to derive; mirror the (empty or None) chunk value.
        values["tool_calls"] = chunks
        return values
    tool_calls = []
    for chunk in chunks:
        try:
            args_ = parse_partial_json(chunk.args)
            args_ = args_ if isinstance(args_, dict) else {}
        except (JSONDecodeError, TypeError):  # None args raise TypeError
            args_ = {}
        tool_calls.append(
            ToolCall(
                name=chunk.name or "", args=args_, index=chunk.index, id=chunk.id
            )
        )
    values["tool_calls"] = tool_calls
    return values
|
||
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore | ||
if isinstance(other, AIMessageChunk): | ||
if self.example != other.example: | ||
raise ValueError( | ||
"Cannot concatenate AIMessageChunks with different example values." | ||
) | ||
|
||
content = merge_content(self.content, other.content) | ||
additional_kwargs = merge_dicts( | ||
self.additional_kwargs, other.additional_kwargs | ||
) | ||
response_metadata = merge_dicts( | ||
self.response_metadata, other.response_metadata | ||
) | ||
|
||
# Merge tool call chunks | ||
if self.tool_call_chunks or other.tool_call_chunks: | ||
raw_tool_calls = merge_lists( | ||
[tc.dict() for tc in self.tool_call_chunks or []], | ||
[tc.dict() for tc in other.tool_call_chunks or []], | ||
) | ||
if raw_tool_calls: | ||
tool_call_chunks = [ToolCallChunk(**rtc) for rtc in raw_tool_calls] | ||
else: | ||
tool_call_chunks = None | ||
else: | ||
tool_call_chunks = None | ||
|
||
return self.__class__( | ||
example=self.example, | ||
content=merge_content(self.content, other.content), | ||
additional_kwargs=merge_dicts( | ||
self.additional_kwargs, other.additional_kwargs | ||
), | ||
response_metadata=merge_dicts( | ||
self.response_metadata, other.response_metadata | ||
), | ||
content=content, | ||
additional_kwargs=additional_kwargs, | ||
tool_call_chunks=tool_call_chunks, | ||
response_metadata=response_metadata, | ||
id=self.id, | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from typing import Any, List, Literal | ||
import json | ||
from typing import Any, Dict, List, Literal, Optional | ||
|
||
from langchain_core.load import Serializable | ||
from langchain_core.messages.base import ( | ||
BaseMessage, | ||
BaseMessageChunk, | ||
|
@@ -61,3 +63,105 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore | |
) | ||
|
||
return super().__add__(other) | ||
|
||
|
||
class ToolCall(Serializable):
    """A request to invoke a tool.

    Attributes:
        name: (str) the name of the tool to be called
        args: (dict) the arguments to the tool call
        id: (str) if provided, an identifier associated with the tool call
        index: (int) if provided, the index of the tool call in a sequence
            of content
    """

    # NOTE(review): a reviewer asked how instances of this class get
    # serialized; the answer is not visible here -- confirm against
    # Serializable.
    name: str
    args: Dict[str, Any]
    id: Optional[str] = None
    index: Optional[int] = None
|
||
|
||
class ToolCallChunk(Serializable):
    """A fragment of a tool call (e.g., one piece of a stream).

    When ToolCallChunks are merged (e.g., via AIMessageChunk.__add__),
    all string attributes are concatenated. Chunks are only merged if their
    values of `index` are equal and not None.

    Example:

    .. code-block:: python

        left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
        right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
        (
            AIMessageChunk(content="", tool_call_chunks=left_chunks)
            + AIMessageChunk(content="", tool_call_chunks=right_chunks)
        ).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]

    Attributes:
        name: (str) if provided, a substring of the name of the tool to be called
        args: (str) if provided, a JSON substring of the arguments to the tool call
        id: (str) if provided, a substring of an identifier for the tool call
        index: (int) if provided, the index of the tool call in a sequence
    """

    name: Optional[str] = None
    args: Optional[str] = None
    id: Optional[str] = None
    index: Optional[int] = None
|
||
|
||
class InvalidToolCall(Serializable):
    """A tool call that the LLM produced incorrectly.

    Carries the same optional fields as a tool call chunk, plus an `error`
    field used to surface errors made during generation (e.g., invalid
    JSON arguments).
    """

    name: Optional[str] = None
    # Kept as a string rather than a dict -- presumably the raw, unparsed
    # arguments; confirm against callers.
    args: Optional[str] = None
    id: Optional[str] = None
    index: Optional[int] = None
    error: Optional[str] = None
|
||
|
||
def default_tool_parser(raw_tool_calls: List[dict]) -> List[ToolCall]:
    """Convert raw tool-call dicts into ``ToolCall`` objects, best effort.

    An entry with a ``"function"`` key contributes its ``"name"`` and its
    JSON-decoded ``"arguments"``. An entry without one yields an empty name
    and empty args. ``"id"`` and ``"index"`` are copied when present.

    Args:
        raw_tool_calls: Raw tool-call dicts to convert.

    Returns:
        One ``ToolCall`` per input entry, in input order.

    Raises:
        Whatever ``json.loads`` raises for an ``"arguments"`` value it cannot
        decode; nothing is caught here.
    """
    results = []
    for raw_call in raw_tool_calls:
        name = None
        arguments = None
        if "function" in raw_call:
            function = raw_call["function"]
            arguments = json.loads(function["arguments"])
            name = function["name"]
        results.append(
            ToolCall(
                name=name or "",
                args=arguments or {},
                id=raw_call.get("id"),
                index=raw_call.get("index"),
            )
        )
    return results
|
||
|
||
def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
    """Convert raw tool-call dicts into ``ToolCallChunk`` objects, best effort.

    Unlike ``default_tool_parser``, the ``"arguments"`` value is passed
    through as-is and not JSON-decoded. An entry without a ``"function"``
    key yields a chunk whose name and args are None.

    Args:
        raw_tool_calls: Raw tool-call dicts to convert.

    Returns:
        One ``ToolCallChunk`` per input entry, in input order.
    """
    chunks = []
    for raw_call in raw_tool_calls:
        name = None
        arguments = None
        if "function" in raw_call:
            function = raw_call["function"]
            arguments = function["arguments"]
            name = function["name"]
        chunks.append(
            ToolCallChunk(
                name=name,
                args=arguments,
                id=raw_call.get("id"),
                index=raw_call.get("index"),
            )
        )
    return chunks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How often will this get logged? Potentially multiple times per agent run?
Not a bad thing necessarily