Skip to content

Commit

Permalink
core[minor], ...: add tool calls message (#18947)
Browse files Browse the repository at this point in the history
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]

```python
class ToolCall(TypedDict):
    name: str
    args: Dict[str, Any]
    id: Optional[str]

class InvalidToolCall(TypedDict):
    name: Optional[str]
    args: Optional[str]
    id: Optional[str]
    error: Optional[str]

class ToolCallChunk(TypedDict):
    name: Optional[str]
    args: Optional[str]
    id: Optional[str]
    index: Optional[int]


class AIMessage(BaseMessage):
    ...
    tool_calls: List[ToolCall] = []
    invalid_tool_calls: List[InvalidToolCall] = []
    ...


class AIMessageChunk(AIMessage, BaseMessageChunk):
    ...
    tool_call_chunks: Optional[List[ToolCallChunk]] = None
    ...
```
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit
type checking;~
- ~Langsmith rendering will need to be updated:
langchain-ai/langchainplus#3561
- ~Langserve will need to be updated~
- Adding chunks:
- ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has
non-null .tool_calls.~
- Tool call chunks are appended, merging when having equal values of
`index`.
  - additional_kwargs accumulate the normal way.
- During streaming:
- ~Messages can change types (e.g., from AIMessageChunk to
AIToolCallsMessageChunk)~
- Output parsers parse additional_kwargs (during .invoke they read off
tool calls).

Packages outside of `partners/`:
- langchain-ai/langchain-cohere#7
- https://github.com/langchain-ai/langchain-google/pull/123/files

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
  • Loading branch information
2 people authored and hinthornw committed Apr 26, 2024
1 parent b8e6641 commit ada9f40
Show file tree
Hide file tree
Showing 31 changed files with 2,347 additions and 389 deletions.
423 changes: 423 additions & 0 deletions cookbook/tool_call_messages.ipynb

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion libs/core/langchain_core/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
""" # noqa: E501

from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.ai import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
Expand Down Expand Up @@ -50,9 +56,12 @@
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"InvalidToolCall",
"MessageLikeRepresentation",
"SystemMessage",
"SystemMessageChunk",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
"_message_from_dict",
Expand Down
124 changes: 116 additions & 8 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import warnings
from typing import Any, List, Literal

from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
from langchain_core.messages.tool import (
InvalidToolCall,
ToolCall,
ToolCallChunk,
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import (
parse_partial_json,
)


class AIMessage(BaseMessage):
Expand All @@ -16,13 +28,46 @@ class AIMessage(BaseMessage):
conversation.
"""

tool_calls: List[ToolCall] = []
"""If provided, tool calls associated with the message."""
invalid_tool_calls: List[InvalidToolCall] = []
"""If provided, tool calls with parsing errors associated with the message."""

type: Literal["ai"] = "ai"

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@root_validator
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
values.get("tool_calls")
or values.get("invalid_tool_calls")
or values.get("tool_call_chunks")
)
if raw_tool_calls and not tool_calls:
warnings.warn(
"New langchain packages are available that more efficiently handle "
"tool calling. Please upgrade your packages to versions that set "
"message tool calls. e.g., `pip install --upgrade langchain-anthropic"
"`, pip install--upgrade langchain-openai`, etc."
)
try:
if issubclass(cls, AIMessageChunk): # type: ignore
values["tool_call_chunks"] = default_tool_chunk_parser(
raw_tool_calls
)
else:
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
except Exception:
pass
return values


AIMessage.update_forward_refs()

Expand All @@ -35,27 +80,90 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501

tool_call_chunks: List[ToolCallChunk] = []
"""If provided, tool call chunks associated with the message."""

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@root_validator()
def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]:
values["tool_calls"] = []
values["invalid_tool_calls"] = []
return values
tool_calls = []
invalid_tool_calls = []
for chunk in values["tool_call_chunks"]:
try:
args_ = parse_partial_json(chunk["args"])
if isinstance(args_, dict):
tool_calls.append(
ToolCall(
name=chunk["name"] or "",
args=args_,
id=chunk["id"],
)
)
else:
raise ValueError("Malformed args.")
except Exception:
invalid_tool_calls.append(
InvalidToolCall(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
error="Malformed args.",
)
)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
return values

def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)

content = merge_content(self.content, other.content)
additional_kwargs = merge_dicts(
self.additional_kwargs, other.additional_kwargs
)
response_metadata = merge_dicts(
self.response_metadata, other.response_metadata
)

# Merge tool call chunks
if self.tool_call_chunks or other.tool_call_chunks:
raw_tool_calls = merge_lists(
self.tool_call_chunks,
other.tool_call_chunks,
)
if raw_tool_calls:
tool_call_chunks = [
ToolCallChunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),
id=rtc.get("id"),
)
for rtc in raw_tool_calls
]
else:
tool_call_chunks = []
else:
tool_call_chunks = []

return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
id=self.id,
)

Expand Down
114 changes: 113 additions & 1 deletion libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, List, Literal
import json
from typing import Any, Dict, List, Literal, Optional, Tuple

from typing_extensions import TypedDict

from langchain_core.messages.base import (
BaseMessage,
Expand Down Expand Up @@ -61,3 +64,112 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
)

return super().__add__(other)


class ToolCall(TypedDict):
"""A call to a tool.
Attributes:
name: (str) the name of the tool to be called
args: (dict) the arguments to the tool call
id: (str) if provided, an identifier associated with the tool call
"""

name: str
args: Dict[str, Any]
id: Optional[str]


class ToolCallChunk(TypedDict):
"""A chunk of a tool call (e.g., as part of a stream).
When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.
Example:
.. code-block:: python
left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
(
AIMessageChunk(content="", tool_call_chunks=left_chunks)
+ AIMessageChunk(content="", tool_call_chunks=right_chunks)
).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]
Attributes:
name: (str) if provided, a substring of the name of the tool to be called
args: (str) if provided, a JSON substring of the arguments to the tool call
id: (str) if provided, a substring of an identifier for the tool call
index: (int) if provided, the index of the tool call in a sequence
"""

name: Optional[str]
args: Optional[str]
id: Optional[str]
index: Optional[int]


class InvalidToolCall(TypedDict):
"""Allowance for errors made by LLM.
Here we add an `error` key to surface errors made during generation
(e.g., invalid JSON arguments.)
"""

name: Optional[str]
args: Optional[str]
id: Optional[str]
error: Optional[str]


def default_tool_parser(
raw_tool_calls: List[dict],
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
"""Best-effort parsing of tools."""
tool_calls = []
invalid_tool_calls = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
continue
else:
function_name = tool_call["function"]["name"]
try:
function_args = json.loads(tool_call["function"]["arguments"])
parsed = ToolCall(
name=function_name or "",
args=function_args or {},
id=tool_call.get("id"),
)
tool_calls.append(parsed)
except json.JSONDecodeError:
invalid_tool_calls.append(
InvalidToolCall(
name=function_name,
args=tool_call["function"]["arguments"],
id=tool_call.get("id"),
error="Malformed args.",
)
)
return tool_calls, invalid_tool_calls


def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
"""Best-effort parsing of tool chunks."""
tool_call_chunks = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
function_args = None
function_name = None
else:
function_args = tool_call["function"]["arguments"]
function_name = tool_call["function"]["name"]
parsed = ToolCallChunk(
name=function_name,
args=function_args,
id=tool_call.get("id"),
index=tool_call.get("index"),
)
tool_call_chunks.append(parsed)
return tool_call_chunks
10 changes: 8 additions & 2 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.ai import (
AIMessage,
AIMessageChunk,
)
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
Expand Down Expand Up @@ -119,8 +122,11 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
if not isinstance(chunk, BaseMessageChunk):
return chunk
# chunk classes always have the equivalent non-chunk class as their first parent
ignore_keys = ["type"]
if isinstance(chunk, AIMessageChunk):
ignore_keys.append("tool_call_chunks")
return chunk.__class__.__mro__[1](
**{k: v for k, v in chunk.__dict__.items() if k != "type"}
**{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys}
)


Expand Down

0 comments on commit ada9f40

Please sign in to comment.