Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update agents to use tool call messages #20074

Merged
merged 101 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 99 commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
d68a71c
rfc: tool calls msg
baskaryan Mar 11, 2024
2319f42
fmt
baskaryan Mar 11, 2024
f6048cc
fmt
baskaryan Mar 12, 2024
67e3d1e
fmt
baskaryan Mar 12, 2024
35189ad
merge master
ccurme Mar 29, 2024
0c14c48
ToolMessage --> ToolCallsMessage in openai
ccurme Mar 29, 2024
684e107
lint
ccurme Mar 29, 2024
b6dc899
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 1, 2024
2b373cd
factor out parse_partial_json and parse_tool_calls
ccurme Apr 1, 2024
28bc99e
add ToolCallsMessageChunk
ccurme Apr 1, 2024
ea74d44
add to openai
ccurme Apr 1, 2024
6dd3ff7
format
ccurme Apr 1, 2024
2692239
lint
ccurme Apr 1, 2024
1849f93
accept None
ccurme Apr 1, 2024
10b2c7b
update test_imports
ccurme Apr 1, 2024
1f637b9
ToolMessage --> ToolOutputMessage in tests
ccurme Apr 1, 2024
3605678
update snapshots
ccurme Apr 1, 2024
9540adf
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 1, 2024
16bdda4
add ToolCall to module init
ccurme Apr 1, 2024
8698a86
add to integration test
ccurme Apr 1, 2024
c7d00cc
lint
ccurme Apr 1, 2024
808cf82
update serializable mapping
ccurme Apr 1, 2024
0d6118a
update mistral
ccurme Apr 1, 2024
6c0e3ea
update test_imports
ccurme Apr 1, 2024
508e3bc
ToolOutputMessage --> ToolMessage
ccurme Apr 1, 2024
189600f
update snapshots
ccurme Apr 1, 2024
360f9e2
lint
ccurme Apr 1, 2024
ca59616
clean up ToolMessage
ccurme Apr 2, 2024
bda88a6
update JsonOutputToolsParser
ccurme Apr 2, 2024
ac4a0da
lint
ccurme Apr 2, 2024
0b515c8
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 2, 2024
6d81051
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 3, 2024
fec6db2
cr
ccurme Apr 3, 2024
c5ec7fd
update ToolCallsMessageChunk.__add__
ccurme Apr 3, 2024
33127d3
move tool calls msgs to ai
ccurme Apr 3, 2024
e1fb61b
update AIMessageChunk.__add__
ccurme Apr 3, 2024
b1e9235
move parse_tool_calls
ccurme Apr 3, 2024
75d11dc
fix bug
ccurme Apr 3, 2024
9d0176f
lint
ccurme Apr 3, 2024
e225577
rename ToolCallsMessage
ccurme Apr 3, 2024
c27c9e8
rename ToolCallsMessageChunk
ccurme Apr 3, 2024
5765f48
lint
ccurme Apr 3, 2024
4eb4c34
merge tool calls
ccurme Apr 3, 2024
9089029
fmt (#19968)
baskaryan Apr 3, 2024
a8b2733
fix bug
ccurme Apr 3, 2024
a29c6b1
catch KeyError
ccurme Apr 3, 2024
3ffcf0a
update mistral streaming
ccurme Apr 3, 2024
1b53ef6
cr
ccurme Apr 3, 2024
ed71599
update mistral test
ccurme Apr 3, 2024
48d9355
remove check
ccurme Apr 3, 2024
d140449
update docstring
ccurme Apr 4, 2024
5ea8bb4
add tests
ccurme Apr 4, 2024
2060f37
add tests
ccurme Apr 4, 2024
ff07346
update fireworks
ccurme Apr 4, 2024
71764a9
update cohere and add test
ccurme Apr 4, 2024
cbf66ec
add to openai tests
ccurme Apr 4, 2024
f45caa0
fix bug
ccurme Apr 4, 2024
a49f23e
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 4, 2024
9a05cba
update groq
ccurme Apr 4, 2024
bc85987
update agents
ccurme Apr 4, 2024
2ea5d66
Merge branch 'master' into bagatur/tool_calls_msg
baskaryan Apr 4, 2024
6103cd4
Merge branch 'bagatur/tool_calls_msg' of github.com:langchain-ai/lang…
ccurme Apr 4, 2024
9ff7ae9
Revert "update agents"
ccurme Apr 4, 2024
e4ca284
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 4, 2024
e012414
update anthropic
ccurme Apr 4, 2024
53138b9
cr
ccurme Apr 4, 2024
a79a980
use tool call msgs (#20051)
baskaryan Apr 5, 2024
bfe8fe3
update docstrings
ccurme Apr 5, 2024
ce684dd
spell check
ccurme Apr 5, 2024
13222fa
export json output parsers in langchain
ccurme Apr 5, 2024
e845536
undo stray import changes
ccurme Apr 5, 2024
1f902d2
tweak docstring
ccurme Apr 5, 2024
62c8f6e
update agent tools parser
ccurme Apr 5, 2024
46c9be9
update agents
ccurme Apr 5, 2024
7fcff8b
fix test
ccurme Apr 5, 2024
b490c57
move tool calls to AIMessage (#20090)
ccurme Apr 5, 2024
34e629d
merge (delete cohere)
ccurme Apr 5, 2024
94bdf1c
update
ccurme Apr 5, 2024
a1feb7d
merge
ccurme Apr 6, 2024
f8b3b82
Merge branch 'master' into bagatur/tool_calls_msg
ccurme Apr 6, 2024
52b7531
best effort parsing + handle parsing errors (#20111)
ccurme Apr 6, 2024
bb37dd7
fix anthropic
ccurme Apr 6, 2024
bcffbc4
Merge branch 'bagatur/tool_calls_msg' into cc/tool_calls_agents
ccurme Apr 7, 2024
d5f4f53
update agents
ccurme Apr 7, 2024
a809ef4
add create_tools_agent
ccurme Apr 7, 2024
d3caec1
add type hint
ccurme Apr 7, 2024
893e9a9
add type hint
ccurme Apr 7, 2024
f330a22
Merge branch 'cc/tool_calls_agents' of github.com:langchain-ai/langch…
ccurme Apr 7, 2024
12955b2
type hint
ccurme Apr 7, 2024
225f519
cr
ccurme Apr 8, 2024
fa3a8c0
update tool calls to typeddict (#20208)
ccurme Apr 9, 2024
3d243d4
fix docstring
ccurme Apr 9, 2024
0cc3142
update (#20215)
ccurme Apr 9, 2024
c3f2805
fix bug
ccurme Apr 9, 2024
bbe9c0f
Merge branch 'bagatur/tool_calls_msg' into cc/tool_calls_agents
ccurme Apr 9, 2024
55affdf
update
ccurme Apr 9, 2024
1d937b7
fmt
baskaryan Apr 10, 2024
aa8cfa0
fmt
baskaryan Apr 10, 2024
1e0aff8
fmt
baskaryan Apr 10, 2024
e621293
tools_agent --> tool_calling_agent
ccurme Apr 10, 2024
3f4015b
Merge branch 'cc/tool_calls_agents' of github.com:langchain-ai/langch…
ccurme Apr 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 10 additions & 3 deletions libs/core/langchain_core/load/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@
"agents",
"AgentActionMessageLog",
),
("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
("langchain", "schema", "agent", "ToolAgentAction"): (
"langchain",
"agents",
"output_parsers",
"openai_tools",
"OpenAIToolAgentAction",
"tools",
"ToolAgentAction",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think it matters too much in this particular case since actions aren't really serialized but still technically breaking

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea probably shouldn't change namespace

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or should support both

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

),
("langchain", "prompts", "chat", "BaseMessagePromptTemplate"): (
"langchain_core",
Expand Down Expand Up @@ -528,6 +528,13 @@
"image",
"ImagePromptTemplate",
),
("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
"langchain",
"agents",
"output_parsers",
"openai_tools",
"OpenAIToolAgentAction",
),
}

# Needed for backwards compatibility for a few versions where we serialized
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
create_structured_chat_agent,
)
from langchain.agents.tools import Tool, tool
from langchain.agents.tools_agent.base import create_tools_agent
from langchain.agents.xml.base import XMLAgent, create_xml_agent

DEPRECATED_CODE = [
Expand Down Expand Up @@ -154,4 +155,5 @@ def __getattr__(name: str) -> Any:
"create_self_ask_with_search_agent",
"create_json_chat_agent",
"create_structured_chat_agent",
"create_tools_agent",
]
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/format_scratchpad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
format_to_openai_function_messages,
format_to_openai_functions,
)
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
from langchain.agents.format_scratchpad.xml import format_xml

__all__ = [
"format_xml",
"format_to_openai_function_messages",
"format_to_openai_functions",
"format_to_tool_messages",
"format_log_to_str",
"format_log_to_messages",
]
60 changes: 3 additions & 57 deletions libs/langchain/langchain/agents/format_scratchpad/openai_tools.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,5 @@
import json
from typing import List, Sequence, Tuple

from langchain_core.agents import AgentAction
from langchain_core.messages import (
AIMessage,
BaseMessage,
ToolMessage,
from langchain.agents.format_scratchpad.tools import (
format_to_tool_messages as format_to_openai_tool_messages,
)

from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction


def _create_tool_message(
agent_action: OpenAIToolAgentAction, observation: str
) -> ToolMessage:
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
Returns:
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return ToolMessage(
tool_call_id=agent_action.tool_call_id,
content=content,
additional_kwargs={"name": agent_action.tool},
)


def format_to_openai_tool_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into FunctionMessages.

Args:
intermediate_steps: Steps the LLM has taken to date, along with observations

Returns:
list of messages to send to the LLM for the next prediction

"""
messages = []
for agent_action, observation in intermediate_steps:
if isinstance(agent_action, OpenAIToolAgentAction):
new_messages = list(agent_action.message_log) + [
_create_tool_message(agent_action, observation)
]
messages.extend([new for new in new_messages if new not in messages])
else:
messages.append(AIMessage(content=agent_action.log))
return messages
__all__ = ["format_to_openai_tool_messages"]
59 changes: 59 additions & 0 deletions libs/langchain/langchain/agents/format_scratchpad/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json
from typing import List, Sequence, Tuple

from langchain_core.agents import AgentAction
from langchain_core.messages import (
AIMessage,
BaseMessage,
ToolMessage,
)

from langchain.agents.output_parsers.tools import ToolAgentAction


def _create_tool_message(
agent_action: ToolAgentAction, observation: str
) -> ToolMessage:
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
Returns:
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return ToolMessage(
tool_call_id=agent_action.tool_call_id,
content=content,
additional_kwargs={"name": agent_action.tool},
)


def format_to_tool_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into FunctionMessages.

Args:
intermediate_steps: Steps the LLM has taken to date, along with observations

Returns:
list of messages to send to the LLM for the next prediction

"""
messages = []
for agent_action, observation in intermediate_steps:
if isinstance(agent_action, ToolAgentAction):
new_messages = list(agent_action.message_log) + [
_create_tool_message(agent_action, observation)
]
messages.extend([new for new in new_messages if new not in messages])
else:
messages.append(AIMessage(content=agent_action.log))
return messages
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/output_parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
ReActSingleInputOutputParser,
)
from langchain.agents.output_parsers.self_ask import SelfAskOutputParser
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain.agents.output_parsers.xml import XMLAgentOutputParser

__all__ = [
"ReActSingleInputOutputParser",
"SelfAskOutputParser",
"ToolsAgentOutputParser",
"ReActJsonSingleInputOutputParser",
"OpenAIFunctionsAgentOutputParser",
"XMLAgentOutputParser",
Expand Down
76 changes: 23 additions & 53 deletions libs/langchain/langchain/agents/output_parsers/openai_tools.py
Copy link
Collaborator

@baskaryan baskaryan Apr 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we rename things to not be openai-specific? ToolCallAgentAction, parse_tool_call_message_to_agent_action, ToolCallAgentOutputParser

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could perhaps even move to core

Original file line number Diff line number Diff line change
@@ -1,70 +1,40 @@
import json
from json import JSONDecodeError
from typing import List, Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import (
AIMessage,
BaseMessage,
)
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, Generation

from langchain.agents.agent import MultiActionAgentOutputParser
from langchain.agents.output_parsers.tools import (
ToolAgentAction,
parse_ai_message_to_tool_action,
)


class OpenAIToolAgentAction(AgentActionMessageLog):
tool_call_id: str
"""Tool call that this message is responding to."""
OpenAIToolAgentAction = ToolAgentAction


def parse_ai_message_to_openai_tool_action(
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")

if not message.additional_kwargs.get("tool_calls"):
return AgentFinish(
return_values={"output": message.content}, log=str(message.content)
)

actions: List = []
for tool_call in message.additional_kwargs["tool_calls"]:
function = tool_call["function"]
function_name = function["name"]
try:
_tool_input = json.loads(function["arguments"] or "{}")
except JSONDecodeError:
raise OutputParserException(
f"Could not parse tool input: {function} because "
f"the `arguments` is not valid JSON."
tool_actions = parse_ai_message_to_tool_action(message)
if isinstance(tool_actions, AgentFinish):
return tool_actions
final_actions: List[AgentAction] = []
for action in tool_actions:
if isinstance(action, ToolAgentAction):
final_actions.append(
OpenAIToolAgentAction(
tool=action.tool,
tool_input=action.tool_input,
log=action.log,
message_log=action.message_log,
tool_call_id=action.tool_call_id,
)
)

# HACK HACK HACK:
# The code that encodes tool input into Open AI uses a special variable
# name called `__arg1` to handle old style tools that do not expose a
# schema and expect a single string argument as an input.
# We unpack the argument here if it exists.
# Open AI does not support passing in a JSON array as an argument.
if "__arg1" in _tool_input:
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input

content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
actions.append(
OpenAIToolAgentAction(
tool=function_name,
tool_input=tool_input,
log=log,
message_log=[message],
tool_call_id=tool_call["id"],
)
)
return actions
final_actions.append(action)
return final_actions


class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
Expand Down
102 changes: 102 additions & 0 deletions libs/langchain/langchain/agents/output_parsers/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import json
from json import JSONDecodeError
from typing import List, Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import (
AIMessage,
BaseMessage,
ToolCall,
)
from langchain_core.outputs import ChatGeneration, Generation

from langchain.agents.agent import MultiActionAgentOutputParser


class ToolAgentAction(AgentActionMessageLog):
tool_call_id: str
"""Tool call that this message is responding to."""


def parse_ai_message_to_tool_action(
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")

actions: List = []
if message.tool_calls:
tool_calls = message.tool_calls
else:
if not message.additional_kwargs.get("tool_calls"):
return AgentFinish(
return_values={"output": message.content}, log=str(message.content)
)
# Best-effort parsing
tool_calls = []
for tool_call in message.additional_kwargs["tool_calls"]:
function = tool_call["function"]
function_name = function["name"]
try:
args = json.loads(function["arguments"] or "{}")
tool_calls.append(
ToolCall(name=function_name, args=args, id=tool_call["id"])
)
except JSONDecodeError:
raise OutputParserException(
f"Could not parse tool input: {function} because "
f"the `arguments` is not valid JSON."
)
for tool_call in tool_calls:
# HACK HACK HACK:
# The code that encodes tool input into Open AI uses a special variable
# name called `__arg1` to handle old style tools that do not expose a
# schema and expect a single string argument as an input.
# We unpack the argument here if it exists.
# Open AI does not support passing in a JSON array as an argument.
function_name = tool_call["name"]
_tool_input = tool_call["args"]
if "__arg1" in _tool_input:
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input

content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
actions.append(
ToolAgentAction(
tool=function_name,
tool_input=tool_input,
log=log,
message_log=[message],
tool_call_id=tool_call["id"],
)
)
return actions


class ToolsAgentOutputParser(MultiActionAgentOutputParser):
"""Parses a message into agent actions/finish.

If a tool_calls parameter is passed, then that is used to get
the tool names and tool inputs.

If one is not passed, then the AIMessage is assumed to be the final output.
"""

@property
def _type(self) -> str:
return "tools-agent-output-parser"

def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return parse_ai_message_to_tool_action(message)

def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")
Empty file.