Skip to content

Commit

Permalink
cohere[patch]: Add multihop tool agent (#19919)
Browse files Browse the repository at this point in the history
**Description**: Adds an agent that uses Cohere with multiple hops and
multiple tools.

This PR is a continuation of
#19650 - which was
previously approved. Conceptually nothing has changed, but this PR has
extra fixes, documentation and testing.

---------

Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com>
Co-authored-by: Erick Friis <erickfriis@gmail.com>
  • Loading branch information
3 people committed Apr 2, 2024
1 parent 22dbcc9 commit e2b83c8
Show file tree
Hide file tree
Showing 36 changed files with 2,494 additions and 33 deletions.
318 changes: 318 additions & 0 deletions libs/partners/cohere/docs/multi_hop_agent.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion libs/partners/cohere/langchain_cohere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.embeddings import CohereEmbeddings
from langchain_cohere.rag_retrievers import CohereRagRetriever
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent
from langchain_cohere.rerank import CohereRerank

__all__ = [
"ChatCohere",
"CohereVectorStore",
"CohereEmbeddings",
"CohereRagRetriever",
"CohereRerank",
"create_cohere_tools_agent",
"create_cohere_react_agent",
]
18 changes: 14 additions & 4 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def get_cohere_chat_request(
*,
documents: Optional[List[Document]] = None,
connectors: Optional[List[Dict[str, str]]] = None,
stop_sequences: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Get the request for the Cohere chat API.
Expand Down Expand Up @@ -130,6 +131,7 @@ def get_cohere_chat_request(
"documents": formatted_docs,
"connectors": connectors,
"prompt_truncation": prompt_truncation,
"stop_sequences": stop_sequences,
**kwargs,
}

Expand Down Expand Up @@ -226,7 +228,9 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
request = get_cohere_chat_request(
messages, stop_sequences=stop, **self._default_params, **kwargs
)

if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
Expand Down Expand Up @@ -256,7 +260,9 @@ async def _astream(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
request = get_cohere_chat_request(
messages, stop_sequences=stop, **self._default_params, **kwargs
)

if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = self.async_client.chat_stream(**request)
Expand Down Expand Up @@ -312,7 +318,9 @@ def _generate(
)
return generate_from_stream(stream_iter)

request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
request = get_cohere_chat_request(
messages, stop_sequences=stop, **self._default_params, **kwargs
)
response = self.client.chat(**request)

generation_info = self._get_generation_info(response)
Expand All @@ -336,7 +344,9 @@ async def _agenerate(
)
return await agenerate_from_stream(stream_iter)

request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
request = get_cohere_chat_request(
messages, stop_sequences=stop, **self._default_params, **kwargs
)
response = self.client.chat(**request)

generation_info = self._get_generation_info(response)
Expand Down
29 changes: 23 additions & 6 deletions libs/partners/cohere/langchain_cohere/cohere_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableLambda
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.function_calling import (
convert_to_openai_function,
)

from langchain_cohere.utils import (
JSON_TO_PYTHON_TYPES,
_remove_signature_from_tool_description,
)


def create_cohere_tools_agent(
Expand Down Expand Up @@ -99,11 +106,17 @@ def _convert_to_cohere_tool(
if isinstance(tool, BaseTool):
return Tool(
name=tool.name,
description=tool.description,
description=_remove_signature_from_tool_description(
tool.name, tool.description
),
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
description=param_definition.get("description")
if "description" in param_definition
else "",
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required="default" not in param_definition,
)
for param_name, param_definition in tool.args.items()
Expand All @@ -120,7 +133,9 @@ def _convert_to_cohere_tool(
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required="default" not in param_definition,
)
for param_name, param_definition in tool.get("properties", {}).items()
Expand All @@ -140,7 +155,9 @@ def _convert_to_cohere_tool(
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required=param_name in parameters.get("required", []),
)
for param_name, param_definition in properties.items()
Expand Down
11 changes: 6 additions & 5 deletions libs/partners/cohere/langchain_cohere/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ class BaseCohere(Serializable):
streaming: bool = Field(default=False)
"""Whether to stream the results."""

timeout: Optional[float] = 60
"""Timeout in seconds for the Cohere API request."""

user_agent: str = "langchain"
"""Identifier for the application making the request."""

timeout_seconds: Optional[float] = 300
"""Timeout in seconds for the Cohere API request."""

base_url: Optional[str] = None
"""Override the default Cohere API URL."""

Expand All @@ -82,16 +82,17 @@ def validate_environment(cls, values: Dict) -> Dict:
get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY")
)
client_name = values["user_agent"]
timeout_seconds = values.get("timeout_seconds")
values["client"] = cohere.Client(
api_key=values["cohere_api_key"].get_secret_value(),
timeout=timeout_seconds,
client_name=client_name,
timeout=values["timeout"],
base_url=values["base_url"],
)
values["async_client"] = cohere.AsyncClient(
api_key=values["cohere_api_key"].get_secret_value(),
client_name=client_name,
timeout=values["timeout"],
timeout=timeout_seconds,
base_url=values["base_url"],
)
return values
Expand Down
Empty file.
74 changes: 74 additions & 0 deletions libs/partners/cohere/langchain_cohere/react_multi_hop/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Cohere multi-hop agent enables multiple tools to be used in sequence to complete a
task.
This agent uses a multi hop prompt by Cohere, which is experimental and subject
to change. The latest prompt can be used by upgrading the langchain-cohere package.
"""
from typing import Sequence

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool

from langchain_cohere.react_multi_hop.parsing import (
CohereToolsReactAgentOutputParser,
)
from langchain_cohere.react_multi_hop.prompt import (
multi_hop_prompt,
)


def create_cohere_react_agent(
    llm: BaseLanguageModel,
    tools: Sequence[BaseTool],
    prompt: ChatPromptTemplate,
) -> Runnable:
    """
    Create an agent that enables multiple tools to be used in sequence to
    complete a task.

    Args:
        llm: The ChatCohere LLM instance to use.
        tools: Tools this agent has access to.
        prompt: The prompt to use.

    Returns:
        A Runnable sequence representing an agent. It takes as input all the
        same input variables as the prompt passed in does and returns an
        AgentAction or AgentFinish.

    Example:
        .. code-block:: python

            from langchain.agents import AgentExecutor
            from langchain.prompts import ChatPromptTemplate
            from langchain_cohere import ChatCohere, create_cohere_react_agent

            prompt = ChatPromptTemplate.from_template("{input}")
            tools = []  # Populate this with a list of tools you would like to use.
            llm = ChatCohere()

            agent = create_cohere_react_agent(llm, tools, prompt)
            agent_executor = AgentExecutor(agent=agent, tools=tools)

            agent_executor.invoke({
                "input": "In what year was the company that was founded as Sound of Music added to the S&P 500?",
            })
    """  # noqa: E501
    # agent_scratchpad isn't consumed by this chain; it's populated with an
    # empty list purely for interoperability with chains that expect it.
    populate_scratchpad = RunnablePassthrough.assign(agent_scratchpad=lambda _: [])

    # Renders the Cohere multi-hop prompt from the user's prompt and the tools.
    render_prompt = multi_hop_prompt(tools=tools, prompt=prompt)

    # raw_prompting passes the rendered prompt through verbatim; generation is
    # stopped before the model fabricates its own "Observation:" step.
    model = llm.bind(stop=["\nObservation:"], raw_prompting=True)

    return (
        populate_scratchpad
        | render_prompt
        | model
        | CohereToolsReactAgentOutputParser()
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from enum import Enum


class _SpecialToken(str, Enum):
    """Special control tokens used when building a raw Cohere prompt.

    Subclasses ``str`` so members can be concatenated directly into prompt
    strings without calling ``.value``.
    """

    # beginning-of-sequence marker
    bos = "<BOS_TOKEN>"
    # delimit one conversation turn
    start_turn = "<|START_OF_TURN_TOKEN|>"
    end_turn = "<|END_OF_TURN_TOKEN|>"
    # role markers identifying who is speaking within a turn
    role_system = "<|SYSTEM_TOKEN|>"
    role_chatbot = "<|CHATBOT_TOKEN|>"
    role_user = "<|USER_TOKEN|>"


# Default text for the "basic rules" preamble section of the multi-hop prompt.
# NOTE(review): these default_* strings are runtime prompt text sent to the
# model — do not reword or reformat them; even whitespace changes alter model
# behavior.
default_basic_rules = "You are a powerful language agent trained by Cohere to help people. You are capable of complex reasoning and augmented with a number of tools. Your job is to plan and reason about how you will use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see an instruction informing you what kind of response to generate. You will construct a plan and then perform a number of reasoning and action steps to solve the problem. When you have determined the answer to the user's request, you will cite your sources in your answers, according the instructions" # noqa: E501

# Default "task context" section. Contains a "{now}" placeholder — presumably
# filled in with the current date via str.format by the prompt builder; confirm
# against the consumer in prompt.py.
default_task_context = "You use your advanced complex reasoning capabilities to help people by answering their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You may need to use multiple tools in parallel or sequentially to complete your task. You should focus on serving the user's needs as best you can, which will be wide-ranging. The current date is {now}" # noqa: E501

# Default "style guide" section: answer style the model should use.
default_style_guide = "Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling" # noqa: E501

# Default "safety rules" section; states that it overrides the task/style text.
default_safety_rules = "The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral" # noqa: E501

# Default step-by-step instruction driving the multi-hop ReAct loop: the model
# is told to emit "Plan:", then repeated "Action:"/"Observation:"/"Reflection:"
# steps, then "Relevant Documents:", "Cited Documents:", "Answer:" and finally
# "Grounded answer:" with <co: N>…</co: N> citation markup.
default_multi_hop_instruction = """Carefully perform the following instructions, in order, starting each with a new line.
Firstly, You may need to use complex and advanced reasoning to complete your task and answer the question. Think about how you can use the provided tools to answer the question and come up with a high level plan you will execute.
Write 'Plan:' followed by an initial high level plan of how you will solve the problem including the tools and steps required.
Secondly, Carry out your plan by repeatedly using actions, reasoning over the results, and re-evaluating your plan. Perform Action, Observation, Reflection steps with the following format. Write 'Action:' followed by a json formatted action containing the "tool_name" and "parameters"
Next you will analyze the 'Observation:', this is the result of the action.
After that you should always think about what to do next. Write 'Reflection:' followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next including if you know the answer to the question.
... (this Action/Observation/Reflection can repeat N times)
Thirdly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.
Fourthly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.
Fifthly, Write 'Answer:' followed by a response to the user's last input in high quality natural english. Use the retrieved documents to help you. Do not insert any citations or grounding markup.
Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols <co: doc> and </co: doc> to indicate when a fact comes from a document in the search result, e.g <co: 4>my fact</co: 4> for a fact from document 4.""" # noqa: E501

0 comments on commit e2b83c8

Please sign in to comment.