Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mistral: add IDs to tool calls #20299

Merged
merged 5 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore


class ToolCall(TypedDict):
"""A call to a tool.
"""Represents a request to call a tool.

Attributes:
name: (str) the name of the tool to be called
Expand Down
9 changes: 5 additions & 4 deletions libs/core/langchain_core/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
Expand All @@ -17,7 +18,7 @@ def parse_tool_call(
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> Optional[Dict[str, Any]]:
) -> Optional[ToolCall]:
"""Parse a single tool call."""
if "function" not in raw_tool_call:
return None
Expand All @@ -44,7 +45,7 @@ def parse_tool_call(
"args": function_args or {},
}
if return_id:
parsed["id"] = raw_tool_call["id"]
parsed["id"] = raw_tool_call.get("id")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is breaking if one expects a KeyError when using JsonOutputToolsParser(return_id=True) and there are no IDs.

return parsed


Expand All @@ -67,9 +68,9 @@ def parse_tool_calls(
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> List[dict]:
) -> List[ToolCall]:
"""Parse a list of tool calls."""
final_tools = []
final_tools: List[ToolCall] = []
exceptions = []
for tool_call in raw_tool_calls:
try:
Expand Down
52 changes: 32 additions & 20 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import uuid
from operator import itemgetter
from typing import (
Any,
Expand Down Expand Up @@ -91,14 +92,18 @@ def _convert_mistral_chat_message_to_message(
for raw_tool_call in raw_tool_calls:
try:
parsed: dict = cast(
dict, parse_tool_call(raw_tool_call, return_id=False)
)
tool_calls.append(
{
**parsed,
**{"id": None},
},
dict, parse_tool_call(raw_tool_call, return_id=True)
)
if not parsed["id"]:
tool_call_id = uuid.uuid4().hex[:]
tool_calls.append(
{
**parsed,
**{"id": tool_call_id},
},
)
else:
tool_calls.append(parsed)
except Exception as e:
invalid_tool_calls.append(
dict(make_invalid_tool_call(raw_tool_call, str(e)))
Expand Down Expand Up @@ -160,15 +165,20 @@ def _convert_delta_to_message_chunk(
if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc.get("index"),
}
for rtc in raw_tool_calls
]
tool_call_chunks = []
for raw_tool_call in raw_tool_calls:
if not raw_tool_call.get("index") and not raw_tool_call.get("id"):
tool_call_id = uuid.uuid4().hex[:]
else:
tool_call_id = raw_tool_call.get("id")
tool_call_chunks.append(
{
"name": raw_tool_call["function"].get("name"),
"args": raw_tool_call["function"].get("arguments"),
"id": tool_call_id,
"index": raw_tool_call.get("index"),
}
)
except KeyError:
pass
else:
Expand All @@ -195,15 +205,17 @@ def _convert_message_to_mistral_chat_message(
return dict(role="user", content=message.content)
elif isinstance(message, AIMessage):
if "tool_calls" in message.additional_kwargs:
tool_calls = [
{
tool_calls = []
for tc in message.additional_kwargs["tool_calls"]:
chunk = {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
}
for tc in message.additional_kwargs["tool_calls"]
]
if _id := tc.get("id"):
chunk["id"] = _id
tool_calls.append(chunk)
else:
tool_calls = []
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
AIMessage,
AIMessageChunk,
HumanMessage,
ToolCall,
ToolCallChunk,
)
from langchain_core.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -168,9 +166,10 @@ class Person(BaseModel):

result = tool_llm.invoke("Erick, 27 years old")
assert isinstance(result, AIMessage)
assert result.tool_calls == [
ToolCall(name="Person", args={"name": "Erick", "age": 27}, id=None)
]
assert len(result.tool_calls) == 1
tool_call = result.tool_calls[0]
assert tool_call["name"] == "Person"
assert tool_call["args"] == {"name": "Erick", "age": 27}


def test_streaming_tool_call() -> None:
Expand Down Expand Up @@ -201,11 +200,10 @@ class Person(BaseModel):
}

assert isinstance(chunk, AIMessageChunk)
assert chunk.tool_call_chunks == [
ToolCallChunk(
name="Person", args='{"name": "Erick", "age": 27}', id=None, index=None
)
]
assert len(chunk.tool_call_chunks) == 1
tool_call_chunk = chunk.tool_call_chunks[0]
assert tool_call_chunk["name"] == "Person"
assert tool_call_chunk["args"] == '{"name": "Erick", "age": 27}'

# where it doesn't call the tool
strm = tool_llm.stream("What is 2+2?")
Expand Down
9 changes: 6 additions & 3 deletions libs/partners/mistralai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ async def test_astream_with_callback() -> None:

def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
Expand All @@ -142,7 +143,7 @@ def test__convert_dict_to_message_tool_call() -> None:
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id=None,
id="abc123",
)
],
)
Expand All @@ -152,12 +153,14 @@ def test__convert_dict_to_message_tool_call() -> None:
# Test malformed tool call
raw_tool_calls = [
{
"id": "abc123",
"function": {
"arguments": "oops",
"name": "GenerateUsername",
},
},
{
"id": "def456",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
Expand All @@ -174,14 +177,14 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args="oops",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
id=None,
id="abc123",
),
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id=None,
id="def456",
),
],
)
Expand Down