Cohere: Add multihop tool agent #19919

Merged · 55 commits · Apr 2, 2024

Commits
7914525  implement multihop agent (BeatrixCohere, Mar 26, 2024)
c047868  remove (BeatrixCohere, Mar 26, 2024)
8b0c86a  Update prompt (BeatrixCohere, Mar 27, 2024)
157fcd0  Remove new line (BeatrixCohere, Mar 27, 2024)
db7be4d  Fix notebok (BeatrixCohere, Mar 27, 2024)
6853755  Lint (BeatrixCohere, Mar 27, 2024)
bb4762d  Spelling (BeatrixCohere, Mar 27, 2024)
768dd37  Fix test (BeatrixCohere, Mar 27, 2024)
eeb46f1  Merge conflicts (BeatrixCohere, Mar 28, 2024)
a549a11  Add chat history (BeatrixCohere, Mar 28, 2024)
068549b  Fix prompt (BeatrixCohere, Mar 28, 2024)
4bde450  Update libs/partners/cohere/langchain_cohere/multi_hop/agent.py (efriis, Mar 28, 2024)
fc98316  Remove new lines (BeatrixCohere, Mar 28, 2024)
e705454  Merge branch 'beatrix/MultiHopAgent' of github.com:BeatrixCohere/lang… (BeatrixCohere, Mar 28, 2024)
74b55e3  Update the naming and notebook (BeatrixCohere, Mar 28, 2024)
fed0bad  Test (BeatrixCohere, Mar 29, 2024)
e00a3f1  Add premable override (BeatrixCohere, Mar 29, 2024)
a9e6a4a  Fix text repsonse (BeatrixCohere, Mar 29, 2024)
e3cf615  Format (BeatrixCohere, Mar 29, 2024)
a24ffc3  Delete (BeatrixCohere, Mar 29, 2024)
59c8c39  Fix formatting (BeatrixCohere, Mar 29, 2024)
e46b382  Increase default timeout (harry-cohere, Mar 29, 2024)
c31c157  Fix stop sequences (harry-cohere, Mar 29, 2024)
b3b9d3f  convert parameter types (harry-cohere, Mar 29, 2024)
ec58586  add resilience to action parsing (harry-cohere, Mar 29, 2024)
403f78a  prompt rendering changes (harry-cohere, Mar 29, 2024)
8f8d8f0  prompt rendering changes (harry-cohere, Mar 29, 2024)
a6f445d  fix types (harry-cohere, Mar 29, 2024)
610d9c4  don't nest prompt function (harry-cohere, Mar 29, 2024)
10a3212  WIP - improve prompt rendering (harry-cohere, Mar 30, 2024)
06c3a7c  work in progress - improve prompt rendering (harry-cohere, Mar 31, 2024)
d3afde2  work in progress - increase test coverage (harry-cohere, Mar 31, 2024)
ad2b779  Tweak document rendering, remove uneeded test (harry-cohere, Mar 31, 2024)
a0d988d  fix for py10 (harry-cohere, Mar 31, 2024)
63be314  Accept strings in observation parsing (harry-cohere, Mar 31, 2024)
b48cd17  Accept Iterable of strings in observation parsing (harry-cohere, Mar 31, 2024)
97c576a  Better observation parsing (harry-cohere, Mar 31, 2024)
e9dc623  add tests and fix linting (harry-cohere, Mar 31, 2024)
62b6cf8  directly_answer is no longer added by users (harry-cohere, Mar 31, 2024)
5e748e3  available_tools -> tools (harry-cohere, Mar 31, 2024)
57e1358  Add space to prompt (harry-cohere, Mar 31, 2024)
615808b  refactor (move) unit tests (harry-cohere, Mar 31, 2024)
d344ea3  refactor (extract) and add comments (harry-cohere, Mar 31, 2024)
0453b1b  work in progress - more tests (harry-cohere, Apr 2, 2024)
f4b0246  Merge branch 'master' into harry/patch-multihop-agent (harry-cohere, Apr 2, 2024)
943438f  lint (harry-cohere, Apr 2, 2024)
b64fd85  fix tests (harry-cohere, Apr 2, 2024)
61d8bda  add docblocks (harry-cohere, Apr 2, 2024)
f47d563  fix types and imports (harry-cohere, Apr 2, 2024)
a4781f2  Remove extra Args: section from description (harry-cohere, Apr 2, 2024)
3ac8390  fix types, imports, add more unit tests to parsing (harry-cohere, Apr 2, 2024)
a7aab19  Add integration test (harry-cohere, Apr 2, 2024)
7208f43  fix lint (harry-cohere, Apr 2, 2024)
e6a1448  improve integration tests (harry-cohere, Apr 2, 2024)
d1dfdcb  Add unit test for chat history (harry-cohere, Apr 2, 2024)
318 changes: 318 additions & 0 deletions libs/partners/cohere/docs/multi_hop_agent.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion libs/partners/cohere/langchain_cohere/__init__.py
@@ -2,13 +2,14 @@
from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.embeddings import CohereEmbeddings
from langchain_cohere.rag_retrievers import CohereRagRetriever
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent
from langchain_cohere.rerank import CohereRerank

__all__ = [
"ChatCohere",
"CohereVectorStore",
"CohereEmbeddings",
"CohereRagRetriever",
"CohereRerank",
"create_cohere_tools_agent",
"create_cohere_react_agent",
]
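
With the new export in place, the agent constructor is importable from the package root alongside the existing entry points. A minimal sketch (assuming a langchain-cohere build that includes this PR):

# Minimal sketch: the top-level import surface after this change.
from langchain_cohere import ChatCohere, create_cohere_react_agent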
18 changes: 14 additions & 4 deletions libs/partners/cohere/langchain_cohere/chat_models.py
@@ -76,6 +76,7 @@ def get_cohere_chat_request(
*,
documents: Optional[List[Document]] = None,
connectors: Optional[List[Dict[str, str]]] = None,
+ stop_sequences: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Get the request for the Cohere chat API.
@@ -130,6 +131,7 @@
"documents": formatted_docs,
"connectors": connectors,
"prompt_truncation": prompt_truncation,
"stop_sequences": stop_sequences,
**kwargs,
}

@@ -226,7 +228,9 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
- request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
+ request = get_cohere_chat_request(
+ messages, stop_sequences=stop, **self._default_params, **kwargs
+ )

if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
@@ -256,7 +260,9 @@ async def _astream(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
- request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
+ request = get_cohere_chat_request(
+ messages, stop_sequences=stop, **self._default_params, **kwargs
+ )

if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = self.async_client.chat_stream(**request)
@@ -312,7 +318,9 @@ def _generate(
)
return generate_from_stream(stream_iter)

- request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
+ request = get_cohere_chat_request(
+ messages, stop_sequences=stop, **self._default_params, **kwargs
+ )
response = self.client.chat(**request)

generation_info = self._get_generation_info(response)
@@ -336,7 +344,9 @@ async def _agenerate(
)
return await agenerate_from_stream(stream_iter)

- request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
+ request = get_cohere_chat_request(
+ messages, stop_sequences=stop, **self._default_params, **kwargs
+ )
response = self.client.chat(**request)

generation_info = self._get_generation_info(response)
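
The chat model changes above thread the standard stop argument through to Cohere's stop_sequences request field for streaming and non-streaming calls alike; the new ReAct agent relies on this, binding stop=["\nObservation:"]. A hedged sketch of the behaviour (the prompt and stop value are illustrative, and COHERE_API_KEY is assumed to be set):

# Illustrative only: stop values passed at call time are forwarded to the
# Cohere API as stop_sequences by get_cohere_chat_request.
from langchain_cohere import ChatCohere

llm = ChatCohere()  # assumes COHERE_API_KEY is set in the environment
message = llm.invoke(
    "List three colours, one per line.",
    stop=["\n"],  # generation should halt at the first newline
)
print(message.content)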
29 changes: 23 additions & 6 deletions libs/partners/cohere/langchain_cohere/cohere_agent.py
@@ -17,7 +17,14 @@
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableLambda
from langchain_core.tools import BaseTool
- from langchain_core.utils.function_calling import convert_to_openai_function
+ from langchain_core.utils.function_calling import (
+ convert_to_openai_function,
+ )
+
+ from langchain_cohere.utils import (
+ JSON_TO_PYTHON_TYPES,
+ _remove_signature_from_tool_description,
+ )


def create_cohere_tools_agent(
@@ -99,11 +106,17 @@ def _convert_to_cohere_tool(
if isinstance(tool, BaseTool):
return Tool(
name=tool.name,
- description=tool.description,
+ description=_remove_signature_from_tool_description(
+ tool.name, tool.description
+ ),
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
description=param_definition.get("description")
if "description" in param_definition
else "",
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required="default" not in param_definition,
)
for param_name, param_definition in tool.args.items()
@@ -120,7 +133,9 @@ def _convert_to_cohere_tool(
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required="default" not in param_definition,
)
for param_name, param_definition in tool.get("properties", {}).items()
@@ -140,7 +155,9 @@ def _convert_to_cohere_tool(
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
type=JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
required=param_name in parameters.get("required", []),
)
for param_name, param_definition in properties.items()
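
The tool conversion above now strips the parameter signature that can be appended to a tool's description and maps JSON Schema type names onto the Python type names Cohere's tool API expects, using helpers imported from langchain_cohere.utils. A rough sketch of the idea behind that mapping (the values below are illustrative; the authoritative table is JSON_TO_PYTHON_TYPES in langchain_cohere/utils.py):

# Illustrative approximation of the JSON Schema -> Python type-name mapping;
# the real table ships as JSON_TO_PYTHON_TYPES in langchain_cohere/utils.py.
json_to_python_types = {
    "string": "str",
    "integer": "int",
    "number": "float",
    "boolean": "bool",
    "array": "List",
    "object": "Dict",
}

def to_python_type(json_type: str) -> str:
    # Unknown types pass through unchanged, mirroring the .get(type, type)
    # fallback used in the diff above.
    return json_to_python_types.get(json_type, json_type)

print(to_python_type("integer"))  # -> "int"
print(to_python_type("str"))      # already a Python name, passes through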
11 changes: 6 additions & 5 deletions libs/partners/cohere/langchain_cohere/llms.py
@@ -66,12 +66,12 @@ class BaseCohere(Serializable):
streaming: bool = Field(default=False)
"""Whether to stream the results."""

- timeout: Optional[float] = 60
- """Timeout in seconds for the Cohere API request."""

user_agent: str = "langchain"
"""Identifier for the application making the request."""

+ timeout_seconds: Optional[float] = 300
+ """Timeout in seconds for the Cohere API request."""

base_url: Optional[str] = None
"""Override the default Cohere API URL."""

@@ -82,16 +82,17 @@ def validate_environment(cls, values: Dict) -> Dict:
get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY")
)
client_name = values["user_agent"]
+ timeout_seconds = values.get("timeout_seconds")
values["client"] = cohere.Client(
api_key=values["cohere_api_key"].get_secret_value(),
+ timeout=timeout_seconds,
client_name=client_name,
- timeout=values["timeout"],
base_url=values["base_url"],
)
values["async_client"] = cohere.AsyncClient(
api_key=values["cohere_api_key"].get_secret_value(),
client_name=client_name,
- timeout=values["timeout"],
+ timeout=timeout_seconds,
base_url=values["base_url"],
)
return values
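
The client timeout option is renamed from timeout to timeout_seconds, its default is raised from 60 to 300 seconds, and the value is passed to both the synchronous and asynchronous Cohere clients. A hedged usage sketch (120 is an arbitrary override):

# Sketch: overriding the renamed option on a model built on BaseCohere.
from langchain_cohere import ChatCohere

llm = ChatCohere(timeout_seconds=120)  # formerly timeout; default is now 300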
Empty file.
74 changes: 74 additions & 0 deletions libs/partners/cohere/langchain_cohere/react_multi_hop/agent.py
@@ -0,0 +1,74 @@
"""
Cohere multi-hop agent enables multiple tools to be used in sequence to complete a
task.

This agent uses a multi hop prompt by Cohere, which is experimental and subject
to change. The latest prompt can be used by upgrading the langchain-cohere package.
"""
from typing import Sequence

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool

from langchain_cohere.react_multi_hop.parsing import (
CohereToolsReactAgentOutputParser,
)
from langchain_cohere.react_multi_hop.prompt import (
multi_hop_prompt,
)


def create_cohere_react_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
) -> Runnable:
"""
Create an agent that enables multiple tools to be used in sequence to complete
a task.

Args:
llm: The ChatCohere LLM instance to use.
tools: Tools this agent has access to.
prompt: The prompt to use.

Returns:
A Runnable sequence representing an agent. It takes as input all the same input
variables as the prompt passed in does and returns an AgentAction or
AgentFinish.

Example:
.. code-block:: python
from langchain.agents import AgentExecutor
from langchain.prompts import ChatPromptTemplate
from langchain_cohere import ChatCohere, create_cohere_react_agent

prompt = ChatPromptTemplate.from_template("{input}")
tools = [] # Populate this with a list of tools you would like to use.

llm = ChatCohere()

agent = create_cohere_react_agent(
llm,
tools,
prompt
)
agent_executor = AgentExecutor(agent=agent, tools=tools)

agent_executor.invoke({
"input": "In what year was the company that was founded as Sound of Music added to the S&P 500?",
})
""" # noqa: E501
agent = (
RunnablePassthrough.assign(
# agent_scratchpad isn't used in this chain, but added here for
# interoperability with other chains that may require it.
agent_scratchpad=lambda _: [],
)
| multi_hop_prompt(tools=tools, prompt=prompt)
| llm.bind(stop=["\nObservation:"], raw_prompting=True)
| CohereToolsReactAgentOutputParser()
)
return agent
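
For readers who want to run the docstring example end to end, here is a slightly more concrete sketch with one stand-in tool; the tool body below is a placeholder and not part of this PR:

# Sketch only: company_facts is a hypothetical placeholder tool used to show
# the wiring; substitute real retrievers or search tools in practice.
from langchain.agents import AgentExecutor
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool

from langchain_cohere import ChatCohere, create_cohere_react_agent


@tool
def company_facts(query: str) -> str:
    """Looks up basic facts about a company."""
    return f"(placeholder search result for: {query})"


tools = [company_facts]
prompt = ChatPromptTemplate.from_template("{input}")
llm = ChatCohere()

agent = create_cohere_react_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

result = agent_executor.invoke(
    {
        "input": "In what year was the company that was founded as"
        " Sound of Music added to the S&P 500?"
    }
)
print(result["output"])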
@@ -0,0 +1,31 @@
from enum import Enum


class _SpecialToken(str, Enum):
bos = "<BOS_TOKEN>"
start_turn = "<|START_OF_TURN_TOKEN|>"
end_turn = "<|END_OF_TURN_TOKEN|>"
role_system = "<|SYSTEM_TOKEN|>"
role_chatbot = "<|CHATBOT_TOKEN|>"
role_user = "<|USER_TOKEN|>"


default_basic_rules = "You are a powerful language agent trained by Cohere to help people. You are capable of complex reasoning and augmented with a number of tools. Your job is to plan and reason about how you will use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see an instruction informing you what kind of response to generate. You will construct a plan and then perform a number of reasoning and action steps to solve the problem. When you have determined the answer to the user's request, you will cite your sources in your answers, according the instructions" # noqa: E501

default_task_context = "You use your advanced complex reasoning capabilities to help people by answering their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You may need to use multiple tools in parallel or sequentially to complete your task. You should focus on serving the user's needs as best you can, which will be wide-ranging. The current date is {now}" # noqa: E501

default_style_guide = "Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling" # noqa: E501

default_safety_rules = "The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral" # noqa: E501

default_multi_hop_instruction = """Carefully perform the following instructions, in order, starting each with a new line.
Firstly, You may need to use complex and advanced reasoning to complete your task and answer the question. Think about how you can use the provided tools to answer the question and come up with a high level plan you will execute.
Write 'Plan:' followed by an initial high level plan of how you will solve the problem including the tools and steps required.
Secondly, Carry out your plan by repeatedly using actions, reasoning over the results, and re-evaluating your plan. Perform Action, Observation, Reflection steps with the following format. Write 'Action:' followed by a json formatted action containing the "tool_name" and "parameters"
Next you will analyze the 'Observation:', this is the result of the action.
After that you should always think about what to do next. Write 'Reflection:' followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next including if you know the answer to the question.
... (this Action/Observation/Reflection can repeat N times)
Thirdly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.
Fourthly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.
Fifthly, Write 'Answer:' followed by a response to the user's last input in high quality natural english. Use the retrieved documents to help you. Do not insert any citations or grounding markup.
Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols <co: doc> and </co: doc> to indicate when a fact comes from a document in the search result, e.g <co: 4>my fact</co: 4> for a fact from document 4.""" # noqa: E501
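
These tokens correspond to the raw prompt format of Cohere's Command models (the agent binds raw_prompting=True): each turn is wrapped in start-of-turn and end-of-turn tokens, tagged with a role token, and the full prompt begins with the BOS token. The real rendering lives in the react_multi_hop prompt module; the sketch below is only an approximation of how a single system turn might be assembled from the enum above:

# Approximate illustration, not the package's actual renderer: whitespace and
# section ordering in the real multi-hop prompt may differ.
def render_system_turn(content: str) -> str:
    return (
        "<|START_OF_TURN_TOKEN|>"  # _SpecialToken.start_turn
        "<|SYSTEM_TOKEN|>"         # _SpecialToken.role_system
        + content
        + "<|END_OF_TURN_TOKEN|>"  # _SpecialToken.end_turn
    )


prompt = "<BOS_TOKEN>" + render_system_turn(
    "Don't answer questions that are harmful or immoral"  # e.g. the safety rules
)
print(prompt)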