Skip to content

Commit

Permalink
mistral[patch]: add IDs to tool calls (#20299)
Browse files Browse the repository at this point in the history
Mistral gives us one ID per response, no individual IDs for tool calls.

```python
from langchain.agents import AgentExecutor, create_tool_calling_agent, tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_mistralai import ChatMistralAI


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant"),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
model = ChatMistralAI(model="mistral-large-latest", temperature=0)

@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2

tools = [magic_function]

agent = create_tool_calling_agent(model, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

agent_executor.invoke({"input": "what is the value of magic_function(3)?"})
```

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
  • Loading branch information
ccurme and eyurtsev committed Apr 11, 2024
1 parent 22fd844 commit 795c728
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 37 deletions.
2 changes: 1 addition & 1 deletion libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore


class ToolCall(TypedDict):
"""A call to a tool.
"""Represents a request to call a tool.
Attributes:
name: (str) the name of the tool to be called
Expand Down
6 changes: 3 additions & 3 deletions libs/core/langchain_core/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def parse_tool_call(
"args": function_args or {},
}
if return_id:
parsed["id"] = raw_tool_call["id"]
parsed["id"] = raw_tool_call.get("id")
return parsed


Expand All @@ -67,9 +67,9 @@ def parse_tool_calls(
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> List[dict]:
) -> List[Dict[str, Any]]:
"""Parse a list of tool calls."""
final_tools = []
final_tools: List[Dict[str, Any]] = []
exceptions = []
for tool_call in raw_tool_calls:
try:
Expand Down
52 changes: 32 additions & 20 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import uuid
from operator import itemgetter
from typing import (
Any,
Expand Down Expand Up @@ -91,14 +92,18 @@ def _convert_mistral_chat_message_to_message(
for raw_tool_call in raw_tool_calls:
try:
parsed: dict = cast(
dict, parse_tool_call(raw_tool_call, return_id=False)
)
tool_calls.append(
{
**parsed,
**{"id": None},
},
dict, parse_tool_call(raw_tool_call, return_id=True)
)
if not parsed["id"]:
tool_call_id = uuid.uuid4().hex[:]
tool_calls.append(
{
**parsed,
**{"id": tool_call_id},
},
)
else:
tool_calls.append(parsed)
except Exception as e:
invalid_tool_calls.append(
dict(make_invalid_tool_call(raw_tool_call, str(e)))
Expand Down Expand Up @@ -160,15 +165,20 @@ def _convert_delta_to_message_chunk(
if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc.get("index"),
}
for rtc in raw_tool_calls
]
tool_call_chunks = []
for raw_tool_call in raw_tool_calls:
if not raw_tool_call.get("index") and not raw_tool_call.get("id"):
tool_call_id = uuid.uuid4().hex[:]
else:
tool_call_id = raw_tool_call.get("id")
tool_call_chunks.append(
{
"name": raw_tool_call["function"].get("name"),
"args": raw_tool_call["function"].get("arguments"),
"id": tool_call_id,
"index": raw_tool_call.get("index"),
}
)
except KeyError:
pass
else:
Expand All @@ -195,15 +205,17 @@ def _convert_message_to_mistral_chat_message(
return dict(role="user", content=message.content)
elif isinstance(message, AIMessage):
if "tool_calls" in message.additional_kwargs:
tool_calls = [
{
tool_calls = []
for tc in message.additional_kwargs["tool_calls"]:
chunk = {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
}
for tc in message.additional_kwargs["tool_calls"]
]
if _id := tc.get("id"):
chunk["id"] = _id
tool_calls.append(chunk)
else:
tool_calls = []
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
AIMessage,
AIMessageChunk,
HumanMessage,
ToolCall,
ToolCallChunk,
)
from langchain_core.pydantic_v1 import BaseModel

Expand Down Expand Up @@ -168,9 +166,10 @@ class Person(BaseModel):

result = tool_llm.invoke("Erick, 27 years old")
assert isinstance(result, AIMessage)
assert result.tool_calls == [
ToolCall(name="Person", args={"name": "Erick", "age": 27}, id=None)
]
assert len(result.tool_calls) == 1
tool_call = result.tool_calls[0]
assert tool_call["name"] == "Person"
assert tool_call["args"] == {"name": "Erick", "age": 27}


def test_streaming_tool_call() -> None:
Expand Down Expand Up @@ -201,11 +200,10 @@ class Person(BaseModel):
}

assert isinstance(chunk, AIMessageChunk)
assert chunk.tool_call_chunks == [
ToolCallChunk(
name="Person", args='{"name": "Erick", "age": 27}', id=None, index=None
)
]
assert len(chunk.tool_call_chunks) == 1
tool_call_chunk = chunk.tool_call_chunks[0]
assert tool_call_chunk["name"] == "Person"
assert tool_call_chunk["args"] == '{"name": "Erick", "age": 27}'

# where it doesn't call the tool
strm = tool_llm.stream("What is 2+2?")
Expand Down
9 changes: 6 additions & 3 deletions libs/partners/mistralai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ async def test_astream_with_callback() -> None:

def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
Expand All @@ -142,7 +143,7 @@ def test__convert_dict_to_message_tool_call() -> None:
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id=None,
id="abc123",
)
],
)
Expand All @@ -152,12 +153,14 @@ def test__convert_dict_to_message_tool_call() -> None:
# Test malformed tool call
raw_tool_calls = [
{
"id": "abc123",
"function": {
"arguments": "oops",
"name": "GenerateUsername",
},
},
{
"id": "def456",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
Expand All @@ -174,14 +177,14 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args="oops",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
id=None,
id="abc123",
),
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id=None,
id="def456",
),
],
)
Expand Down

0 comments on commit 795c728

Please sign in to comment.