-
Notifications
You must be signed in to change notification settings - Fork 13.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
update agents to use tool call messages #20074
Changes from 99 commits
d68a71c
2319f42
f6048cc
67e3d1e
35189ad
0c14c48
684e107
b6dc899
2b373cd
28bc99e
ea74d44
6dd3ff7
2692239
1849f93
10b2c7b
1f637b9
3605678
9540adf
16bdda4
8698a86
c7d00cc
808cf82
0d6118a
6c0e3ea
508e3bc
189600f
360f9e2
ca59616
bda88a6
ac4a0da
0b515c8
6d81051
fec6db2
c5ec7fd
33127d3
e1fb61b
b1e9235
75d11dc
9d0176f
e225577
c27c9e8
5765f48
4eb4c34
9089029
a8b2733
a29c6b1
3ffcf0a
1b53ef6
ed71599
48d9355
d140449
5ea8bb4
2060f37
ff07346
71764a9
cbf66ec
f45caa0
a49f23e
9a05cba
bc85987
2ea5d66
6103cd4
9ff7ae9
e4ca284
e012414
53138b9
a79a980
bfe8fe3
ce684dd
13222fa
e845536
1f902d2
62c8f6e
46c9be9
7fcff8b
b490c57
34e629d
94bdf1c
a1feb7d
f8b3b82
52b7531
bb37dd7
bcffbc4
d5f4f53
a809ef4
d3caec1
893e9a9
f330a22
12955b2
225f519
fa3a8c0
3d243d4
0cc3142
c3f2805
bbe9c0f
55affdf
1d937b7
aa8cfa0
1e0aff8
e621293
3f4015b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,5 @@ | ||
import json | ||
from typing import List, Sequence, Tuple | ||
|
||
from langchain_core.agents import AgentAction | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
ToolMessage, | ||
from langchain.agents.format_scratchpad.tools import ( | ||
format_to_tool_messages as format_to_openai_tool_messages, | ||
) | ||
|
||
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction | ||
|
||
|
||
def _create_tool_message( | ||
agent_action: OpenAIToolAgentAction, observation: str | ||
) -> ToolMessage: | ||
"""Convert agent action and observation into a function message. | ||
Args: | ||
agent_action: the tool invocation request from the agent | ||
observation: the result of the tool invocation | ||
Returns: | ||
FunctionMessage that corresponds to the original tool invocation | ||
""" | ||
if not isinstance(observation, str): | ||
try: | ||
content = json.dumps(observation, ensure_ascii=False) | ||
except Exception: | ||
content = str(observation) | ||
else: | ||
content = observation | ||
return ToolMessage( | ||
tool_call_id=agent_action.tool_call_id, | ||
content=content, | ||
additional_kwargs={"name": agent_action.tool}, | ||
) | ||
|
||
|
||
def format_to_openai_tool_messages( | ||
intermediate_steps: Sequence[Tuple[AgentAction, str]], | ||
) -> List[BaseMessage]: | ||
"""Convert (AgentAction, tool output) tuples into FunctionMessages. | ||
|
||
Args: | ||
intermediate_steps: Steps the LLM has taken to date, along with observations | ||
|
||
Returns: | ||
list of messages to send to the LLM for the next prediction | ||
|
||
""" | ||
messages = [] | ||
for agent_action, observation in intermediate_steps: | ||
if isinstance(agent_action, OpenAIToolAgentAction): | ||
new_messages = list(agent_action.message_log) + [ | ||
_create_tool_message(agent_action, observation) | ||
] | ||
messages.extend([new for new in new_messages if new not in messages]) | ||
else: | ||
messages.append(AIMessage(content=agent_action.log)) | ||
return messages | ||
__all__ = ["format_to_openai_tool_messages"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import json | ||
from typing import List, Sequence, Tuple | ||
|
||
from langchain_core.agents import AgentAction | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
ToolMessage, | ||
) | ||
|
||
from langchain.agents.output_parsers.tools import ToolAgentAction | ||
|
||
|
||
def _create_tool_message( | ||
agent_action: ToolAgentAction, observation: str | ||
) -> ToolMessage: | ||
"""Convert agent action and observation into a function message. | ||
Args: | ||
agent_action: the tool invocation request from the agent | ||
observation: the result of the tool invocation | ||
Returns: | ||
FunctionMessage that corresponds to the original tool invocation | ||
""" | ||
if not isinstance(observation, str): | ||
try: | ||
content = json.dumps(observation, ensure_ascii=False) | ||
except Exception: | ||
content = str(observation) | ||
else: | ||
content = observation | ||
return ToolMessage( | ||
tool_call_id=agent_action.tool_call_id, | ||
content=content, | ||
additional_kwargs={"name": agent_action.tool}, | ||
) | ||
|
||
|
||
def format_to_tool_messages( | ||
intermediate_steps: Sequence[Tuple[AgentAction, str]], | ||
) -> List[BaseMessage]: | ||
"""Convert (AgentAction, tool output) tuples into FunctionMessages. | ||
|
||
Args: | ||
intermediate_steps: Steps the LLM has taken to date, along with observations | ||
|
||
Returns: | ||
list of messages to send to the LLM for the next prediction | ||
|
||
""" | ||
messages = [] | ||
for agent_action, observation in intermediate_steps: | ||
if isinstance(agent_action, ToolAgentAction): | ||
new_messages = list(agent_action.message_log) + [ | ||
_create_tool_message(agent_action, observation) | ||
] | ||
messages.extend([new for new in new_messages if new not in messages]) | ||
else: | ||
messages.append(AIMessage(content=agent_action.log)) | ||
return messages |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we rename things to not be openai-specific? ToolCallAgentAction, parse_tool_call_message_to_agent_action, ToolCallAgentOutputParser There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could perhaps even move to core |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import json | ||
from json import JSONDecodeError | ||
from typing import List, Union | ||
|
||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish | ||
from langchain_core.exceptions import OutputParserException | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
ToolCall, | ||
) | ||
from langchain_core.outputs import ChatGeneration, Generation | ||
|
||
from langchain.agents.agent import MultiActionAgentOutputParser | ||
|
||
|
||
class ToolAgentAction(AgentActionMessageLog): | ||
tool_call_id: str | ||
"""Tool call that this message is responding to.""" | ||
|
||
|
||
def parse_ai_message_to_tool_action( | ||
message: BaseMessage, | ||
) -> Union[List[AgentAction], AgentFinish]: | ||
"""Parse an AI message potentially containing tool_calls.""" | ||
if not isinstance(message, AIMessage): | ||
raise TypeError(f"Expected an AI message got {type(message)}") | ||
|
||
actions: List = [] | ||
if message.tool_calls: | ||
tool_calls = message.tool_calls | ||
else: | ||
if not message.additional_kwargs.get("tool_calls"): | ||
return AgentFinish( | ||
return_values={"output": message.content}, log=str(message.content) | ||
) | ||
# Best-effort parsing | ||
tool_calls = [] | ||
for tool_call in message.additional_kwargs["tool_calls"]: | ||
function = tool_call["function"] | ||
function_name = function["name"] | ||
try: | ||
args = json.loads(function["arguments"] or "{}") | ||
tool_calls.append( | ||
ToolCall(name=function_name, args=args, id=tool_call["id"]) | ||
) | ||
except JSONDecodeError: | ||
raise OutputParserException( | ||
f"Could not parse tool input: {function} because " | ||
f"the `arguments` is not valid JSON." | ||
) | ||
for tool_call in tool_calls: | ||
# HACK HACK HACK: | ||
# The code that encodes tool input into Open AI uses a special variable | ||
# name called `__arg1` to handle old style tools that do not expose a | ||
# schema and expect a single string argument as an input. | ||
# We unpack the argument here if it exists. | ||
# Open AI does not support passing in a JSON array as an argument. | ||
function_name = tool_call["name"] | ||
_tool_input = tool_call["args"] | ||
if "__arg1" in _tool_input: | ||
tool_input = _tool_input["__arg1"] | ||
else: | ||
tool_input = _tool_input | ||
|
||
content_msg = f"responded: {message.content}\n" if message.content else "\n" | ||
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n" | ||
actions.append( | ||
ToolAgentAction( | ||
tool=function_name, | ||
tool_input=tool_input, | ||
log=log, | ||
message_log=[message], | ||
tool_call_id=tool_call["id"], | ||
) | ||
) | ||
return actions | ||
|
||
|
||
class ToolsAgentOutputParser(MultiActionAgentOutputParser): | ||
"""Parses a message into agent actions/finish. | ||
|
||
If a tool_calls parameter is passed, then that is used to get | ||
the tool names and tool inputs. | ||
|
||
If one is not passed, then the AIMessage is assumed to be the final output. | ||
""" | ||
|
||
@property | ||
def _type(self) -> str: | ||
return "tools-agent-output-parser" | ||
|
||
def parse_result( | ||
self, result: List[Generation], *, partial: bool = False | ||
) -> Union[List[AgentAction], AgentFinish]: | ||
if not isinstance(result[0], ChatGeneration): | ||
raise ValueError("This output parser only works on ChatGeneration output") | ||
message = result[0].message | ||
return parse_ai_message_to_tool_action(message) | ||
|
||
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: | ||
raise ValueError("Can only parse messages") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't think it matters too much in this particular case since actions aren't really serialized but still technically breaking
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea probably shouldn't change namespace
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or should support both
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated