Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core]: add tool calls message #18947

Merged
merged 85 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
85 commits
Select commit Hold shift + click to select a range
d68a71c
rfc: tool calls msg
baskaryan Mar 11, 2024
2319f42
fmt
baskaryan Mar 11, 2024
f6048cc
fmt
baskaryan Mar 12, 2024
67e3d1e
fmt
baskaryan Mar 12, 2024
35189ad
merge master
ccurme Mar 29, 2024
0c14c48
ToolMessage --> ToolCallsMessage in openai
ccurme Mar 29, 2024
684e107
lint
ccurme Mar 29, 2024
b6dc899
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 1, 2024
2b373cd
factor out parse_partial_json and parse_tool_calls
ccurme Apr 1, 2024
28bc99e
add ToolCallsMessageChunk
ccurme Apr 1, 2024
ea74d44
add to openai
ccurme Apr 1, 2024
6dd3ff7
format
ccurme Apr 1, 2024
2692239
lint
ccurme Apr 1, 2024
1849f93
accept None
ccurme Apr 1, 2024
10b2c7b
update test_imports
ccurme Apr 1, 2024
1f637b9
ToolMessage --> ToolOutputMessage in tests
ccurme Apr 1, 2024
3605678
update snapshots
ccurme Apr 1, 2024
9540adf
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 1, 2024
16bdda4
add ToolCall to module init
ccurme Apr 1, 2024
8698a86
add to integration test
ccurme Apr 1, 2024
c7d00cc
lint
ccurme Apr 1, 2024
808cf82
update serializable mapping
ccurme Apr 1, 2024
0d6118a
update mistral
ccurme Apr 1, 2024
6c0e3ea
update test_imports
ccurme Apr 1, 2024
508e3bc
ToolOutputMessage --> ToolMessage
ccurme Apr 1, 2024
189600f
update snapshots
ccurme Apr 1, 2024
360f9e2
lint
ccurme Apr 1, 2024
ca59616
clean up ToolMessage
ccurme Apr 2, 2024
bda88a6
update JsonOutputToolsParser
ccurme Apr 2, 2024
ac4a0da
lint
ccurme Apr 2, 2024
0b515c8
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 2, 2024
6d81051
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 3, 2024
fec6db2
cr
ccurme Apr 3, 2024
c5ec7fd
update ToolCallsMessageChunk.__add__
ccurme Apr 3, 2024
33127d3
move tool calls msgs to ai
ccurme Apr 3, 2024
e1fb61b
update AIMessageChunk.__add__
ccurme Apr 3, 2024
b1e9235
move parse_tool_calls
ccurme Apr 3, 2024
75d11dc
fix bug
ccurme Apr 3, 2024
9d0176f
lint
ccurme Apr 3, 2024
e225577
rename ToolCallsMessage
ccurme Apr 3, 2024
c27c9e8
rename ToolCallsMessageChunk
ccurme Apr 3, 2024
5765f48
lint
ccurme Apr 3, 2024
4eb4c34
merge tool calls
ccurme Apr 3, 2024
9089029
fmt (#19968)
baskaryan Apr 3, 2024
a8b2733
fix bug
ccurme Apr 3, 2024
a29c6b1
catch KeyError
ccurme Apr 3, 2024
3ffcf0a
update mistral streaming
ccurme Apr 3, 2024
1b53ef6
cr
ccurme Apr 3, 2024
ed71599
update mistral test
ccurme Apr 3, 2024
48d9355
remove check
ccurme Apr 3, 2024
d140449
update docstring
ccurme Apr 4, 2024
5ea8bb4
add tests
ccurme Apr 4, 2024
2060f37
add tests
ccurme Apr 4, 2024
ff07346
update fireworks
ccurme Apr 4, 2024
71764a9
update cohere and add test
ccurme Apr 4, 2024
cbf66ec
add to openai tests
ccurme Apr 4, 2024
f45caa0
fix bug
ccurme Apr 4, 2024
a49f23e
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 4, 2024
9a05cba
update groq
ccurme Apr 4, 2024
bc85987
update agents
ccurme Apr 4, 2024
2ea5d66
Merge branch 'master' into bagatur/tool_calls_msg
baskaryan Apr 4, 2024
6103cd4
Merge branch 'bagatur/tool_calls_msg' of github.com:langchain-ai/lang…
ccurme Apr 4, 2024
9ff7ae9
Revert "update agents"
ccurme Apr 4, 2024
e4ca284
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 4, 2024
e012414
update anthropic
ccurme Apr 4, 2024
53138b9
cr
ccurme Apr 4, 2024
a79a980
use tool call msgs (#20051)
baskaryan Apr 5, 2024
bfe8fe3
update docstrings
ccurme Apr 5, 2024
ce684dd
spell check
ccurme Apr 5, 2024
13222fa
export json output parsers in langchain
ccurme Apr 5, 2024
e845536
undo stray import changes
ccurme Apr 5, 2024
1f902d2
tweak docstring
ccurme Apr 5, 2024
b490c57
move tool calls to AIMessage (#20090)
ccurme Apr 5, 2024
34e629d
merge (delete cohere)
ccurme Apr 5, 2024
94bdf1c
update
ccurme Apr 5, 2024
a1feb7d
merge
ccurme Apr 6, 2024
f8b3b82
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 6, 2024
52b7531
best effort parsing + handle parsing errors (#20111)
ccurme Apr 6, 2024
bb37dd7
fix anthropic
ccurme Apr 6, 2024
225f519
cr
ccurme Apr 8, 2024
fa3a8c0
update tool calls to typeddict (#20208)
ccurme Apr 9, 2024
3d243d4
fix docstring
ccurme Apr 9, 2024
0cc3142
update (#20215)
ccurme Apr 9, 2024
c3f2805
fix bug
ccurme Apr 9, 2024
f6d864c
fix bug
ccurme Apr 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
423 changes: 423 additions & 0 deletions cookbook/tool_call_messages.ipynb

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion libs/core/langchain_core/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

""" # noqa: E501

from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.ai import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
Expand Down Expand Up @@ -50,9 +56,12 @@
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"InvalidToolCall",
"MessageLikeRepresentation",
"SystemMessage",
"SystemMessageChunk",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
"_message_from_dict",
Expand Down
124 changes: 116 additions & 8 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import warnings
from typing import Any, List, Literal

from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
from langchain_core.messages.tool import (
InvalidToolCall,
ToolCall,
ToolCallChunk,
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import (
parse_partial_json,
)


class AIMessage(BaseMessage):
Expand All @@ -16,13 +28,46 @@ class AIMessage(BaseMessage):
conversation.
"""

tool_calls: List[ToolCall] = []
"""If provided, tool calls associated with the message."""
invalid_tool_calls: List[InvalidToolCall] = []
"""If provided, tool calls with parsing errors associated with the message."""

type: Literal["ai"] = "ai"

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@root_validator
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming here is confusing since it's shadowed below

values.get("tool_calls")
or values.get("invalid_tool_calls")
or values.get("tool_call_chunks")
)
if raw_tool_calls and not tool_calls:
warnings.warn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How often will this get logged? Potentially multiple times per agent run?

Not a bad thing necessarily

"New langchain packages are available that more efficiently handle "
"tool calling. Please upgrade your packages to versions that set "
"message tool calls. e.g., `pip install --upgrade langchain-anthropic"
"`, pip install--upgrade langchain-openai`, etc."
Copy link
Contributor

@jacoblee93 jacoblee93 Apr 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing space in string

)
try:
if issubclass(cls, AIMessageChunk): # type: ignore
values["tool_call_chunks"] = default_tool_chunk_parser(
raw_tool_calls
)
else:
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
except Exception:
pass
return values


AIMessage.update_forward_refs()

Expand All @@ -35,27 +80,90 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501

tool_call_chunks: List[ToolCallChunk] = []
"""If provided, tool call chunks associated with the message."""

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@root_validator()
def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]:
ccurme marked this conversation as resolved.
Show resolved Hide resolved
values["tool_calls"] = []
values["invalid_tool_calls"] = []
return values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit - what if values["tool_calls"] is explicitly passed?

tool_calls = []
invalid_tool_calls = []
for chunk in values["tool_call_chunks"]:
try:
args_ = parse_partial_json(chunk["args"])
if isinstance(args_, dict):
tool_calls.append(
ToolCall(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you need index?

name=chunk["name"] or "",
args=args_,
id=chunk["id"],
)
)
else:
raise ValueError("Malformed args.")
except Exception:
invalid_tool_calls.append(
InvalidToolCall(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
error="Malformed args.",
)
)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
return values

def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)

content = merge_content(self.content, other.content)
additional_kwargs = merge_dicts(
self.additional_kwargs, other.additional_kwargs
)
response_metadata = merge_dicts(
self.response_metadata, other.response_metadata
)

# Merge tool call chunks
if self.tool_call_chunks or other.tool_call_chunks:
raw_tool_calls = merge_lists(
self.tool_call_chunks,
other.tool_call_chunks,
)
if raw_tool_calls:
tool_call_chunks = [
ToolCallChunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),
id=rtc.get("id"),
)
for rtc in raw_tool_calls
]
else:
tool_call_chunks = []
else:
tool_call_chunks = []

return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
id=self.id,
)

Expand Down
114 changes: 113 additions & 1 deletion libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, List, Literal
import json
from typing import Any, Dict, List, Literal, Optional, Tuple

from typing_extensions import TypedDict

from langchain_core.messages.base import (
BaseMessage,
Expand Down Expand Up @@ -61,3 +64,112 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
)

return super().__add__(other)


class ToolCall(TypedDict):
"""A call to a tool.

Attributes:
name: (str) the name of the tool to be called
args: (dict) the arguments to the tool call
id: (str) if provided, an identifier associated with the tool call
"""

name: str
args: Dict[str, Any]
id: Optional[str]


class ToolCallChunk(TypedDict):
"""A chunk of a tool call (e.g., as part of a stream).

When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.

Example:

.. code-block:: python
ccurme marked this conversation as resolved.
Show resolved Hide resolved

left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
(
AIMessageChunk(content="", tool_call_chunks=left_chunks)
+ AIMessageChunk(content="", tool_call_chunks=right_chunks)
).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]

Attributes:
name: (str) if provided, a substring of the name of the tool to be called
args: (str) if provided, a JSON substring of the arguments to the tool call
id: (str) if provided, a substring of an identifier for the tool call
index: (int) if provided, the index of the tool call in a sequence
"""

name: Optional[str]
args: Optional[str]
id: Optional[str]
index: Optional[int]


class InvalidToolCall(TypedDict):
"""Allowance for errors made by LLM.

Here we add an `error` key to surface errors made during generation
(e.g., invalid JSON arguments.)
"""

name: Optional[str]
args: Optional[str]
id: Optional[str]
error: Optional[str]


def default_tool_parser(
raw_tool_calls: List[dict],
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
"""Best-effort parsing of tools."""
tool_calls = []
invalid_tool_calls = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
continue
else:
function_name = tool_call["function"]["name"]
try:
function_args = json.loads(tool_call["function"]["arguments"])
parsed = ToolCall(
name=function_name or "",
args=function_args or {},
id=tool_call.get("id"),
)
tool_calls.append(parsed)
except json.JSONDecodeError:
invalid_tool_calls.append(
InvalidToolCall(
name=function_name,
args=tool_call["function"]["arguments"],
id=tool_call.get("id"),
error="Malformed args.",
)
)
return tool_calls, invalid_tool_calls


def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
"""Best-effort parsing of tool chunks."""
tool_call_chunks = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
function_args = None
function_name = None
else:
function_args = tool_call["function"]["arguments"]
function_name = tool_call["function"]["name"]
parsed = ToolCallChunk(
name=function_name,
args=function_args,
id=tool_call.get("id"),
index=tool_call.get("index"),
)
tool_call_chunks.append(parsed)
return tool_call_chunks
10 changes: 8 additions & 2 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.ai import (
AIMessage,
AIMessageChunk,
)
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
Expand Down Expand Up @@ -119,8 +122,11 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
if not isinstance(chunk, BaseMessageChunk):
return chunk
# chunk classes always have the equivalent non-chunk class as their first parent
ignore_keys = ["type"]
if isinstance(chunk, AIMessageChunk):
ignore_keys.append("tool_call_chunks")
return chunk.__class__.__mro__[1](
**{k: v for k, v in chunk.__dict__.items() if k != "type"}
**{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys}
)


Expand Down