-
Notifications
You must be signed in to change notification settings - Fork 13.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update agents to use tool call messages (#20074)
```python from langchain.agents import AgentExecutor, create_tool_calling_agent, tool from langchain_anthropic import ChatAnthropic from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder prompt = ChatPromptTemplate.from_messages( [ ("system", "You are a helpful assistant"), MessagesPlaceholder("chat_history", optional=True), ("human", "{input}"), MessagesPlaceholder("agent_scratchpad"), ] ) model = ChatAnthropic(model="claude-3-opus-20240229") @tool def magic_function(input: int) -> int: """Applies a magic function to an input.""" return input + 2 tools = [magic_function] agent = create_tool_calling_agent(model, tools, prompt) agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) agent_executor.invoke({"input": "what is the value of magic_function(3)?"}) ``` ``` > Entering new AgentExecutor chain... Invoking: `magic_function` with `{'input': 3}` responded: [{'text': '<thinking>\nThe user has asked for the value of magic_function applied to the input 3. Looking at the available tools, magic_function is the relevant one to use here, as it takes an integer input and returns an integer output.\n\nThe magic_function has one required parameter:\n- input (integer)\n\nThe user has directly provided the value 3 for the input parameter. Since the required parameter is present, we can proceed with calling the function.\n</thinking>', 'type': 'text'}, {'id': 'toolu_01HsTheJPA5mcipuFDBbJ1CW', 'input': {'input': 3}, 'name': 'magic_function', 'type': 'tool_use'}] 5 Therefore, the value of magic_function(3) is 5. > Finished chain. {'input': 'what is the value of magic_function(3)?', 'output': 'Therefore, the value of magic_function(3) is 5.'} ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
- Loading branch information
1 parent
9eb6f53
commit 21c1ce0
Showing
14 changed files
with
541 additions
and
310 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 3 additions & 57 deletions
60
libs/langchain/langchain/agents/format_scratchpad/openai_tools.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,5 @@ | ||
import json | ||
from typing import List, Sequence, Tuple | ||
|
||
from langchain_core.agents import AgentAction | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
ToolMessage, | ||
from langchain.agents.format_scratchpad.tools import ( | ||
format_to_tool_messages as format_to_openai_tool_messages, | ||
) | ||
|
||
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction | ||
|
||
|
||
def _create_tool_message( | ||
agent_action: OpenAIToolAgentAction, observation: str | ||
) -> ToolMessage: | ||
"""Convert agent action and observation into a function message. | ||
Args: | ||
agent_action: the tool invocation request from the agent | ||
observation: the result of the tool invocation | ||
Returns: | ||
FunctionMessage that corresponds to the original tool invocation | ||
""" | ||
if not isinstance(observation, str): | ||
try: | ||
content = json.dumps(observation, ensure_ascii=False) | ||
except Exception: | ||
content = str(observation) | ||
else: | ||
content = observation | ||
return ToolMessage( | ||
tool_call_id=agent_action.tool_call_id, | ||
content=content, | ||
additional_kwargs={"name": agent_action.tool}, | ||
) | ||
|
||
|
||
def format_to_openai_tool_messages( | ||
intermediate_steps: Sequence[Tuple[AgentAction, str]], | ||
) -> List[BaseMessage]: | ||
"""Convert (AgentAction, tool output) tuples into FunctionMessages. | ||
Args: | ||
intermediate_steps: Steps the LLM has taken to date, along with observations | ||
Returns: | ||
list of messages to send to the LLM for the next prediction | ||
""" | ||
messages = [] | ||
for agent_action, observation in intermediate_steps: | ||
if isinstance(agent_action, OpenAIToolAgentAction): | ||
new_messages = list(agent_action.message_log) + [ | ||
_create_tool_message(agent_action, observation) | ||
] | ||
messages.extend([new for new in new_messages if new not in messages]) | ||
else: | ||
messages.append(AIMessage(content=agent_action.log)) | ||
return messages | ||
__all__ = ["format_to_openai_tool_messages"] |
59 changes: 59 additions & 0 deletions
59
libs/langchain/langchain/agents/format_scratchpad/tools.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import json | ||
from typing import List, Sequence, Tuple | ||
|
||
from langchain_core.agents import AgentAction | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
ToolMessage, | ||
) | ||
|
||
from langchain.agents.output_parsers.tools import ToolAgentAction | ||
|
||
|
||
def _create_tool_message( | ||
agent_action: ToolAgentAction, observation: str | ||
) -> ToolMessage: | ||
"""Convert agent action and observation into a function message. | ||
Args: | ||
agent_action: the tool invocation request from the agent | ||
observation: the result of the tool invocation | ||
Returns: | ||
FunctionMessage that corresponds to the original tool invocation | ||
""" | ||
if not isinstance(observation, str): | ||
try: | ||
content = json.dumps(observation, ensure_ascii=False) | ||
except Exception: | ||
content = str(observation) | ||
else: | ||
content = observation | ||
return ToolMessage( | ||
tool_call_id=agent_action.tool_call_id, | ||
content=content, | ||
additional_kwargs={"name": agent_action.tool}, | ||
) | ||
|
||
|
||
def format_to_tool_messages( | ||
intermediate_steps: Sequence[Tuple[AgentAction, str]], | ||
) -> List[BaseMessage]: | ||
"""Convert (AgentAction, tool output) tuples into FunctionMessages. | ||
Args: | ||
intermediate_steps: Steps the LLM has taken to date, along with observations | ||
Returns: | ||
list of messages to send to the LLM for the next prediction | ||
""" | ||
messages = [] | ||
for agent_action, observation in intermediate_steps: | ||
if isinstance(agent_action, ToolAgentAction): | ||
new_messages = list(agent_action.message_log) + [ | ||
_create_tool_message(agent_action, observation) | ||
] | ||
messages.extend([new for new in new_messages if new not in messages]) | ||
else: | ||
messages.append(AIMessage(content=agent_action.log)) | ||
return messages |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 23 additions & 53 deletions
76
libs/langchain/langchain/agents/output_parsers/openai_tools.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
102 changes: 102 additions & 0 deletions
102
libs/langchain/langchain/agents/output_parsers/tools.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import json | ||
from json import JSONDecodeError | ||
from typing import List, Union | ||
|
||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish | ||
from langchain_core.exceptions import OutputParserException | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
ToolCall, | ||
) | ||
from langchain_core.outputs import ChatGeneration, Generation | ||
|
||
from langchain.agents.agent import MultiActionAgentOutputParser | ||
|
||
|
||
class ToolAgentAction(AgentActionMessageLog): | ||
tool_call_id: str | ||
"""Tool call that this message is responding to.""" | ||
|
||
|
||
def parse_ai_message_to_tool_action( | ||
message: BaseMessage, | ||
) -> Union[List[AgentAction], AgentFinish]: | ||
"""Parse an AI message potentially containing tool_calls.""" | ||
if not isinstance(message, AIMessage): | ||
raise TypeError(f"Expected an AI message got {type(message)}") | ||
|
||
actions: List = [] | ||
if message.tool_calls: | ||
tool_calls = message.tool_calls | ||
else: | ||
if not message.additional_kwargs.get("tool_calls"): | ||
return AgentFinish( | ||
return_values={"output": message.content}, log=str(message.content) | ||
) | ||
# Best-effort parsing | ||
tool_calls = [] | ||
for tool_call in message.additional_kwargs["tool_calls"]: | ||
function = tool_call["function"] | ||
function_name = function["name"] | ||
try: | ||
args = json.loads(function["arguments"] or "{}") | ||
tool_calls.append( | ||
ToolCall(name=function_name, args=args, id=tool_call["id"]) | ||
) | ||
except JSONDecodeError: | ||
raise OutputParserException( | ||
f"Could not parse tool input: {function} because " | ||
f"the `arguments` is not valid JSON." | ||
) | ||
for tool_call in tool_calls: | ||
# HACK HACK HACK: | ||
# The code that encodes tool input into Open AI uses a special variable | ||
# name called `__arg1` to handle old style tools that do not expose a | ||
# schema and expect a single string argument as an input. | ||
# We unpack the argument here if it exists. | ||
# Open AI does not support passing in a JSON array as an argument. | ||
function_name = tool_call["name"] | ||
_tool_input = tool_call["args"] | ||
if "__arg1" in _tool_input: | ||
tool_input = _tool_input["__arg1"] | ||
else: | ||
tool_input = _tool_input | ||
|
||
content_msg = f"responded: {message.content}\n" if message.content else "\n" | ||
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n" | ||
actions.append( | ||
ToolAgentAction( | ||
tool=function_name, | ||
tool_input=tool_input, | ||
log=log, | ||
message_log=[message], | ||
tool_call_id=tool_call["id"], | ||
) | ||
) | ||
return actions | ||
|
||
|
||
class ToolsAgentOutputParser(MultiActionAgentOutputParser): | ||
"""Parses a message into agent actions/finish. | ||
If a tool_calls parameter is passed, then that is used to get | ||
the tool names and tool inputs. | ||
If one is not passed, then the AIMessage is assumed to be the final output. | ||
""" | ||
|
||
@property | ||
def _type(self) -> str: | ||
return "tools-agent-output-parser" | ||
|
||
def parse_result( | ||
self, result: List[Generation], *, partial: bool = False | ||
) -> Union[List[AgentAction], AgentFinish]: | ||
if not isinstance(result[0], ChatGeneration): | ||
raise ValueError("This output parser only works on ChatGeneration output") | ||
message = result[0].message | ||
return parse_ai_message_to_tool_action(message) | ||
|
||
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: | ||
raise ValueError("Can only parse messages") |
Empty file.
Oops, something went wrong.