Skip to content

Commit

Permalink
update agents to use tool call messages (langchain-ai#20074)
Browse files Browse the repository at this point in the history
```python
from langchain.agents import AgentExecutor, create_tool_calling_agent, tool
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant"),
        MessagesPlaceholder("chat_history", optional=True),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
model = ChatAnthropic(model="claude-3-opus-20240229")

@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2

tools = [magic_function]

agent = create_tool_calling_agent(model, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

agent_executor.invoke({"input": "what is the value of magic_function(3)?"})
```
```
> Entering new AgentExecutor chain...

Invoking: `magic_function` with `{'input': 3}`
responded: [{'text': '<thinking>\nThe user has asked for the value of magic_function applied to the input 3. Looking at the available tools, magic_function is the relevant one to use here, as it takes an integer input and returns an integer output.\n\nThe magic_function has one required parameter:\n- input (integer)\n\nThe user has directly provided the value 3 for the input parameter. Since the required parameter is present, we can proceed with calling the function.\n</thinking>', 'type': 'text'}, {'id': 'toolu_01HsTheJPA5mcipuFDBbJ1CW', 'input': {'input': 3}, 'name': 'magic_function', 'type': 'tool_use'}]

5
Therefore, the value of magic_function(3) is 5.

> Finished chain.
{'input': 'what is the value of magic_function(3)?',
 'output': 'Therefore, the value of magic_function(3) is 5.'}
```

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
  • Loading branch information
3 people authored and junkeon committed Apr 16, 2024
1 parent 8a14b23 commit f6949df
Show file tree
Hide file tree
Showing 14 changed files with 541 additions and 310 deletions.
13 changes: 10 additions & 3 deletions libs/core/langchain_core/load/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@
"agents",
"AgentActionMessageLog",
),
("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
("langchain", "schema", "agent", "ToolAgentAction"): (
"langchain",
"agents",
"output_parsers",
"openai_tools",
"OpenAIToolAgentAction",
"tools",
"ToolAgentAction",
),
("langchain", "prompts", "chat", "BaseMessagePromptTemplate"): (
"langchain_core",
Expand Down Expand Up @@ -528,6 +528,13 @@
"image",
"ImagePromptTemplate",
),
("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
"langchain",
"agents",
"output_parsers",
"openai_tools",
"OpenAIToolAgentAction",
),
}

# Needed for backwards compatibility for a few versions where we serialized
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
StructuredChatAgent,
create_structured_chat_agent,
)
from langchain.agents.tool_calling_agent.base import create_tool_calling_agent
from langchain.agents.tools import Tool, tool
from langchain.agents.xml.base import XMLAgent, create_xml_agent

Expand Down Expand Up @@ -154,4 +155,5 @@ def __getattr__(name: str) -> Any:
"create_self_ask_with_search_agent",
"create_json_chat_agent",
"create_structured_chat_agent",
"create_tool_calling_agent",
]
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/format_scratchpad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
format_to_openai_function_messages,
format_to_openai_functions,
)
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
from langchain.agents.format_scratchpad.xml import format_xml

__all__ = [
"format_xml",
"format_to_openai_function_messages",
"format_to_openai_functions",
"format_to_tool_messages",
"format_log_to_str",
"format_log_to_messages",
]
60 changes: 3 additions & 57 deletions libs/langchain/langchain/agents/format_scratchpad/openai_tools.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,5 @@
import json
from typing import List, Sequence, Tuple

from langchain_core.agents import AgentAction
from langchain_core.messages import (
AIMessage,
BaseMessage,
ToolMessage,
from langchain.agents.format_scratchpad.tools import (
format_to_tool_messages as format_to_openai_tool_messages,
)

from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction


def _create_tool_message(
agent_action: OpenAIToolAgentAction, observation: str
) -> ToolMessage:
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
Returns:
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return ToolMessage(
tool_call_id=agent_action.tool_call_id,
content=content,
additional_kwargs={"name": agent_action.tool},
)


def format_to_openai_tool_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for agent_action, observation in intermediate_steps:
if isinstance(agent_action, OpenAIToolAgentAction):
new_messages = list(agent_action.message_log) + [
_create_tool_message(agent_action, observation)
]
messages.extend([new for new in new_messages if new not in messages])
else:
messages.append(AIMessage(content=agent_action.log))
return messages
__all__ = ["format_to_openai_tool_messages"]
59 changes: 59 additions & 0 deletions libs/langchain/langchain/agents/format_scratchpad/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json
from typing import List, Sequence, Tuple

from langchain_core.agents import AgentAction
from langchain_core.messages import (
AIMessage,
BaseMessage,
ToolMessage,
)

from langchain.agents.output_parsers.tools import ToolAgentAction


def _create_tool_message(
agent_action: ToolAgentAction, observation: str
) -> ToolMessage:
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
Returns:
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return ToolMessage(
tool_call_id=agent_action.tool_call_id,
content=content,
additional_kwargs={"name": agent_action.tool},
)


def format_to_tool_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for agent_action, observation in intermediate_steps:
if isinstance(agent_action, ToolAgentAction):
new_messages = list(agent_action.message_log) + [
_create_tool_message(agent_action, observation)
]
messages.extend([new for new in new_messages if new not in messages])
else:
messages.append(AIMessage(content=agent_action.log))
return messages
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/output_parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
ReActSingleInputOutputParser,
)
from langchain.agents.output_parsers.self_ask import SelfAskOutputParser
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain.agents.output_parsers.xml import XMLAgentOutputParser

__all__ = [
"ReActSingleInputOutputParser",
"SelfAskOutputParser",
"ToolsAgentOutputParser",
"ReActJsonSingleInputOutputParser",
"OpenAIFunctionsAgentOutputParser",
"XMLAgentOutputParser",
Expand Down
76 changes: 23 additions & 53 deletions libs/langchain/langchain/agents/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,40 @@
import json
from json import JSONDecodeError
from typing import List, Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import (
AIMessage,
BaseMessage,
)
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, Generation

from langchain.agents.agent import MultiActionAgentOutputParser
from langchain.agents.output_parsers.tools import (
ToolAgentAction,
parse_ai_message_to_tool_action,
)


class OpenAIToolAgentAction(AgentActionMessageLog):
tool_call_id: str
"""Tool call that this message is responding to."""
OpenAIToolAgentAction = ToolAgentAction


def parse_ai_message_to_openai_tool_action(
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")

if not message.additional_kwargs.get("tool_calls"):
return AgentFinish(
return_values={"output": message.content}, log=str(message.content)
)

actions: List = []
for tool_call in message.additional_kwargs["tool_calls"]:
function = tool_call["function"]
function_name = function["name"]
try:
_tool_input = json.loads(function["arguments"] or "{}")
except JSONDecodeError:
raise OutputParserException(
f"Could not parse tool input: {function} because "
f"the `arguments` is not valid JSON."
tool_actions = parse_ai_message_to_tool_action(message)
if isinstance(tool_actions, AgentFinish):
return tool_actions
final_actions: List[AgentAction] = []
for action in tool_actions:
if isinstance(action, ToolAgentAction):
final_actions.append(
OpenAIToolAgentAction(
tool=action.tool,
tool_input=action.tool_input,
log=action.log,
message_log=action.message_log,
tool_call_id=action.tool_call_id,
)
)

# HACK HACK HACK:
# The code that encodes tool input into Open AI uses a special variable
# name called `__arg1` to handle old style tools that do not expose a
# schema and expect a single string argument as an input.
# We unpack the argument here if it exists.
# Open AI does not support passing in a JSON array as an argument.
if "__arg1" in _tool_input:
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input

content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
actions.append(
OpenAIToolAgentAction(
tool=function_name,
tool_input=tool_input,
log=log,
message_log=[message],
tool_call_id=tool_call["id"],
)
)
return actions
final_actions.append(action)
return final_actions


class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
Expand Down
102 changes: 102 additions & 0 deletions libs/langchain/langchain/agents/output_parsers/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import json
from json import JSONDecodeError
from typing import List, Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import (
AIMessage,
BaseMessage,
ToolCall,
)
from langchain_core.outputs import ChatGeneration, Generation

from langchain.agents.agent import MultiActionAgentOutputParser


class ToolAgentAction(AgentActionMessageLog):
tool_call_id: str
"""Tool call that this message is responding to."""


def parse_ai_message_to_tool_action(
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")

actions: List = []
if message.tool_calls:
tool_calls = message.tool_calls
else:
if not message.additional_kwargs.get("tool_calls"):
return AgentFinish(
return_values={"output": message.content}, log=str(message.content)
)
# Best-effort parsing
tool_calls = []
for tool_call in message.additional_kwargs["tool_calls"]:
function = tool_call["function"]
function_name = function["name"]
try:
args = json.loads(function["arguments"] or "{}")
tool_calls.append(
ToolCall(name=function_name, args=args, id=tool_call["id"])
)
except JSONDecodeError:
raise OutputParserException(
f"Could not parse tool input: {function} because "
f"the `arguments` is not valid JSON."
)
for tool_call in tool_calls:
# HACK HACK HACK:
# The code that encodes tool input into Open AI uses a special variable
# name called `__arg1` to handle old style tools that do not expose a
# schema and expect a single string argument as an input.
# We unpack the argument here if it exists.
# Open AI does not support passing in a JSON array as an argument.
function_name = tool_call["name"]
_tool_input = tool_call["args"]
if "__arg1" in _tool_input:
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input

content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
actions.append(
ToolAgentAction(
tool=function_name,
tool_input=tool_input,
log=log,
message_log=[message],
tool_call_id=tool_call["id"],
)
)
return actions


class ToolsAgentOutputParser(MultiActionAgentOutputParser):
"""Parses a message into agent actions/finish.
If a tool_calls parameter is passed, then that is used to get
the tool names and tool inputs.
If one is not passed, then the AIMessage is assumed to be the final output.
"""

@property
def _type(self) -> str:
return "tools-agent-output-parser"

def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return parse_ai_message_to_tool_action(message)

def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")
Empty file.

0 comments on commit f6949df

Please sign in to comment.