Skip to content

Commit

Permalink
cohere: Improve integration test stability, fix documents bug (#19929)
Browse files Browse the repository at this point in the history
**Description**: Improves the stability of all Cohere partner package
integration tests. Fixes a bug with document parsing (both dicts and
Documents are handled).
  • Loading branch information
harry-cohere committed Apr 2, 2024
1 parent 37fc1c5 commit beab9ad
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
22 changes: 13 additions & 9 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def get_cohere_chat_request(
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
)

parsed_docs: Optional[List[Document]] = None
parsed_docs: Optional[Union[List[Document], List[Dict]]] = None
if "documents" in additional_kwargs:
parsed_docs = (
additional_kwargs["documents"]
Expand All @@ -108,14 +108,18 @@ def get_cohere_chat_request(
parsed_docs = documents

formatted_docs: Optional[List[Dict[str, Any]]] = None
if parsed_docs is not None:
formatted_docs = [
{
"text": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(parsed_docs)
]
if parsed_docs:
formatted_docs = []
for i, parsed_doc in enumerate(parsed_docs):
if isinstance(parsed_doc, Document):
formatted_docs.append(
{
"text": parsed_doc.page_content,
"id": parsed_doc.metadata.get("id") or f"doc-{str(i)}",
}
)
elif isinstance(parsed_doc, dict):
formatted_docs.append(parsed_doc)

# by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def test_streaming_tool_call() -> None:
llm = ChatCohere(temperature=0)

class Person(BaseModel):
name: str
age: int
name: str = Field(type=str, description="The name of the person")
age: int = Field(type=int, description="The age of the person")

tool_llm = llm.bind_tools([Person])

Expand Down Expand Up @@ -129,8 +129,8 @@ def test_streaming_tool_call_no_tool_calls() -> None:
llm = ChatCohere(temperature=0)

class Person(BaseModel):
name: str
age: int
name: str = Field(type=str, description="The name of the person")
age: int = Field(type=int, description="The age of the person")

tool_llm = llm.bind_tools([Person])

Expand Down

0 comments on commit beab9ad

Please sign in to comment.