Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: update tool calling cookbook #20290

Merged
merged 2 commits into from
Apr 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
270 changes: 23 additions & 247 deletions cookbook/tool_call_messages.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,10 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "c48812ed-35bd-4fbe-9a2c-6c7335e5645e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/core/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `bind_tools` is in beta. It is actively being worked on, so the API may change.\n",
" warn_beta(\n"
]
}
],
"outputs": [],
"source": [
"from langchain_anthropic import ChatAnthropic\n",
"from langchain_core.runnables import ConfigurableField\n",
Expand Down Expand Up @@ -49,221 +40,6 @@
")"
]
},
{
"cell_type": "markdown",
"id": "4719ebdb-ad50-468e-9b30-fb5fb086e140",
"metadata": {},
"source": [
"# AgentExecutor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b98feaa5-8c2d-4125-9519-67114a1fef31",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Tuple, Union\n",
"\n",
"from langchain.agents import AgentExecutor\n",
"from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction\n",
"from langchain_core.agents import AgentFinish, _convert_agent_action_to_messages\n",
"from langchain_core.messages import (\n",
" AIMessage,\n",
" BaseMessage,\n",
" ToolMessage,\n",
")\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"\n",
"def actions_observations_to_messages(\n",
" steps: List[Tuple[OpenAIToolAgentAction, str]],\n",
") -> List[BaseMessage]:\n",
" messages = []\n",
" for action, observation in steps:\n",
" messages.extend([m for m in action.message_log if m not in messages])\n",
" messages.append(ToolMessage(observation, tool_call_id=action.tool_call_id))\n",
" return messages\n",
"\n",
"\n",
"def messages_to_action(\n",
" msg: AIMessage,\n",
") -> Union[List[OpenAIToolAgentAction], AgentFinish]:\n",
" if isinstance(msg, AIMessage) and msg.tool_calls is not None:\n",
" actions = []\n",
" for tool_call in msg.tool_calls:\n",
" actions.append(\n",
" OpenAIToolAgentAction(\n",
" tool=tool_call.name,\n",
" tool_input=tool_call.args,\n",
" tool_call_id=tool_call.id,\n",
" message_log=[msg],\n",
" log=\"\",\n",
" )\n",
" )\n",
" return actions\n",
" else:\n",
" return AgentFinish(return_values={\"output\": msg.content}, log=\"\")\n",
"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You're a helpful assistant with access to tools\"),\n",
" (\"human\", \"{input}\"),\n",
" (\"placeholder\", \"{agent_scratchpad}\"),\n",
" ]\n",
")\n",
"\n",
"agent = (\n",
" RunnablePassthrough.assign(\n",
" agent_scratchpad=lambda x: actions_observations_to_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" )\n",
" | prompt\n",
" | llm_with_tools\n",
" | messages_to_action\n",
")\n",
"\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b4c0fc7a-80bb-4bb8-a87b-7388291ae8b6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m300.03770462067547\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n",
" 'output': 'The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.'}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\n",
" {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "41a3a3c8-185d-4861-b6f0-7592668feb62",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m82.65606421491815\u001b[0m"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m85.65606421491815\u001b[0m"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n",
" 'output': 'Therefore, 17.24 - 918.1241 = -900.8841'}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor = AgentExecutor(\n",
" agent=agent.with_config(configurable={\"llm\": \"claude3\"}), tools=tools, verbose=True\n",
")\n",
"agent_executor.invoke(\n",
" {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9c186263-1b98-4cb2-b6d1-71f65eb0d811",
Expand All @@ -274,15 +50,15 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"id": "28fc2c60-7dbc-428a-8983-1a6a15ea30d2",
"metadata": {},
"outputs": [],
"source": [
"import operator\n",
"from typing import Annotated, Sequence, TypedDict\n",
"\n",
"from langchain_core.messages import AIMessage, BaseMessage, HumanMessage\n",
"from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage\n",
"from langchain_core.runnables import RunnableLambda\n",
"from langgraph.graph import END, StateGraph\n",
"\n",
Expand All @@ -292,16 +68,16 @@
"\n",
"\n",
"def should_continue(state):\n",
" return \"continue\" if state[\"messages\"][-1].tool_calls is not None else \"end\"\n",
" return \"continue\" if state[\"messages\"][-1].tool_calls else \"end\"\n",
"\n",
"\n",
"def call_model(state, config):\n",
" return {\"messages\": [llm_with_tools.invoke(state[\"messages\"], config=config)]}\n",
"\n",
"\n",
"def _invoke_tool(tool_call):\n",
" tool = {tool.name: tool for tool in tools}[tool_call.name]\n",
" return ToolMessage(tool.invoke(tool_call.args), tool_call_id=tool_call.id)\n",
" tool = {tool.name: tool for tool in tools}[tool_call[\"name\"]]\n",
" return ToolMessage(tool.invoke(tool_call[\"args\"]), tool_call_id=tool_call[\"id\"])\n",
"\n",
"\n",
"tool_executor = RunnableLambda(_invoke_tool)\n",
Expand Down Expand Up @@ -330,21 +106,21 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "24463798-74e6-4c39-8092-7a1524d83225",
"execution_count": 4,
"id": "3710e724-2595-4625-ba3a-effb81e66e4a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n",
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_kbBUUeqK75fZZqDTvu8aif7Z', 'function': {'arguments': '{\"x\": 8, \"y\": 2.743}', 'name': 'exponentiate'}, 'type': 'function'}, {'id': 'call_pBD8daSyXidXnrIyG4vG5C9O', 'function': {'arguments': '{\"x\": 17.24, \"y\": -918.1241}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 58, 'prompt_tokens': 168, 'total_tokens': 226}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8e1d9687-611c-4c8e-9fcd-ef6e48bd22a6-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 8, 'y': 2.743}, id='call_kbBUUeqK75fZZqDTvu8aif7Z'), ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='call_pBD8daSyXidXnrIyG4vG5C9O')]),\n",
" ToolMessage(content='300.03770462067547', tool_call_id='call_kbBUUeqK75fZZqDTvu8aif7Z'),\n",
" ToolMessage(content='-900.8841', tool_call_id='call_pBD8daSyXidXnrIyG4vG5C9O'),\n",
" AIMessage(content='The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.', response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 251, 'total_tokens': 295}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-47fe5cbc-3f25-44c3-85b2-6540c3054a77-0')]}"
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_6yMU2WsS4Bqgi1WxFHxtfJRc', 'function': {'arguments': '{\"x\": 8, \"y\": 2.743}', 'name': 'exponentiate'}, 'type': 'function'}, {'id': 'call_GAL3dQiKFF9XEV0RrRLPTvVp', 'function': {'arguments': '{\"x\": 17.24, \"y\": -918.1241}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 58, 'prompt_tokens': 168, 'total_tokens': 226}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-528302fc-7acf-4c11-82c4-119ccf40c573-0', tool_calls=[{'name': 'exponentiate', 'args': {'x': 8, 'y': 2.743}, 'id': 'call_6yMU2WsS4Bqgi1WxFHxtfJRc'}, {'name': 'add', 'args': {'x': 17.24, 'y': -918.1241}, 'id': 'call_GAL3dQiKFF9XEV0RrRLPTvVp'}]),\n",
" ToolMessage(content='300.03770462067547', tool_call_id='call_6yMU2WsS4Bqgi1WxFHxtfJRc'),\n",
" ToolMessage(content='-900.8841', tool_call_id='call_GAL3dQiKFF9XEV0RrRLPTvVp'),\n",
" AIMessage(content='The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.', response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 251, 'total_tokens': 295}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-d1161669-ed09-4b18-94bd-6d8530df5aa8-0')]}"
]
},
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -363,24 +139,24 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "073c074e-d722-42e0-85ec-c62c079207e4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n",
" AIMessage(content=[{'text': \"Okay, let's break this down into steps:\", 'type': 'text'}, {'id': 'toolu_01DJkSDpB8ztmJx2DLNbc3eW', 'input': {'x': 5, 'y': 2.743}, 'name': 'exponentiate', 'type': 'tool_use'}], response_metadata={'id': 'msg_01KuVNohyJr24cPhJkY3XVtt', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 450, 'output_tokens': 84}}, id='run-336cdfb6-0fe4-4d7a-9946-9f01c2eb41ae-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 5, 'y': 2.743}, id='toolu_01DJkSDpB8ztmJx2DLNbc3eW', index=1)]),\n",
" ToolMessage(content='82.65606421491815', tool_call_id='toolu_01DJkSDpB8ztmJx2DLNbc3eW'),\n",
" AIMessage(content=[{'text': 'To get 5 raised to the 2.743 power.', 'type': 'text'}, {'id': 'toolu_01MKQqnDw5CtyuKjQP8YG1FX', 'input': {'x': 3, 'y': 82.65606421491815}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01UBsKkvA4StUR4NEvoFFFep', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 552, 'output_tokens': 91}}, id='run-9d25b7bd-58aa-47dd-933f-15459b24b2c2-0', tool_calls=[ToolCall(name='add', args={'x': 3, 'y': 82.65606421491815}, id='toolu_01MKQqnDw5CtyuKjQP8YG1FX', index=1)]),\n",
" ToolMessage(content='85.65606421491815', tool_call_id='toolu_01MKQqnDw5CtyuKjQP8YG1FX'),\n",
" AIMessage(content=[{'text': 'So 3 plus 5 raised to the 2.743 power is 85.656.\\n\\nFor the second part:', 'type': 'text'}, {'id': 'toolu_019Wb2zPouCR3dw2bSKvCRUL', 'input': {'x': 17.24, 'y': -918.1241}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01Y2H2L8FWcDtVkCtuosie2P', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 661, 'output_tokens': 105}}, id='run-e553c1e3-24ba-4e1b-93ba-6f1985932db4-0', tool_calls=[ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='toolu_019Wb2zPouCR3dw2bSKvCRUL', index=1)]),\n",
" ToolMessage(content='-900.8841', tool_call_id='toolu_019Wb2zPouCR3dw2bSKvCRUL'),\n",
" AIMessage(content='Therefore, 17.24 - 918.1241 = -900.8841', response_metadata={'id': 'msg_01Q14dqvaCD2eA4zwrUvxTcF', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 782, 'output_tokens': 24}}, id='run-f6b6e525-2df6-4617-9bb3-b39d5cc963a9-0')]}"
" AIMessage(content=[{'text': \"Okay, let's break this down into two parts:\", 'type': 'text'}, {'id': 'toolu_01DEhqcXkXTtzJAiZ7uMBeDC', 'input': {'x': 3, 'y': 5}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01AkLGH8sxMHaH15yewmjwkF', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 450, 'output_tokens': 81}}, id='run-f35bfae8-8ded-4f8a-831b-0940d6ad16b6-0', tool_calls=[{'name': 'add', 'args': {'x': 3, 'y': 5}, 'id': 'toolu_01DEhqcXkXTtzJAiZ7uMBeDC'}]),\n",
" ToolMessage(content='8.0', tool_call_id='toolu_01DEhqcXkXTtzJAiZ7uMBeDC'),\n",
" AIMessage(content=[{'id': 'toolu_013DyMLrvnrto33peAKMGMr1', 'input': {'x': 8.0, 'y': 2.743}, 'name': 'exponentiate', 'type': 'tool_use'}], response_metadata={'id': 'msg_015Fmp8aztwYcce2JDAFfce3', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 545, 'output_tokens': 75}}, id='run-48aaeeeb-a1e5-48fd-a57a-6c3da2907b47-0', tool_calls=[{'name': 'exponentiate', 'args': {'x': 8.0, 'y': 2.743}, 'id': 'toolu_013DyMLrvnrto33peAKMGMr1'}]),\n",
" ToolMessage(content='300.03770462067547', tool_call_id='toolu_013DyMLrvnrto33peAKMGMr1'),\n",
" AIMessage(content=[{'text': 'So 3 plus 5 raised to the 2.743 power is 300.04.\\n\\nFor the second part:', 'type': 'text'}, {'id': 'toolu_01UTmMrGTmLpPrPCF1rShN46', 'input': {'x': 17.24, 'y': -918.1241}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_015TkhfRBENPib2RWAxkieH6', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 638, 'output_tokens': 105}}, id='run-45fb62e3-d102-4159-881d-241c5dbadeed-0', tool_calls=[{'name': 'add', 'args': {'x': 17.24, 'y': -918.1241}, 'id': 'toolu_01UTmMrGTmLpPrPCF1rShN46'}]),\n",
" ToolMessage(content='-900.8841', tool_call_id='toolu_01UTmMrGTmLpPrPCF1rShN46'),\n",
" AIMessage(content='Therefore, 17.24 - 918.1241 = -900.8841', response_metadata={'id': 'msg_01LgKnRuUcSyADCpxv9tPoYD', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 759, 'output_tokens': 24}}, id='run-1008254e-ccd1-497c-8312-9550dd77bd08-0')]}"
]
},
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand Down