Skip to content

Commit

Permalink
cohere[patch]: Fix retriever (langchain-ai#19771)
Browse files Browse the repository at this point in the history
* Replace `source_documents` with `documents`
* Pass `documents` as a named arg vs keyword
* Make `parsed_docs` more robust
* Fix edge case of doc page_content being `None`
  • Loading branch information
giannis2two authored and marlenezw committed Apr 2, 2024
1 parent 811f86b commit 65f7868
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
19 changes: 14 additions & 5 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
Expand Down Expand Up @@ -73,7 +74,7 @@ def get_role(message: BaseMessage) -> str:
def get_cohere_chat_request(
messages: List[BaseMessage],
*,
documents: Optional[List[Dict[str, str]]] = None,
documents: Optional[List[Document]] = None,
connectors: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
Expand All @@ -95,17 +96,25 @@ def get_cohere_chat_request(
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
)

parsed_docs: Optional[List[Document]] = None
if "documents" in additional_kwargs:
parsed_docs = (
additional_kwargs["documents"]
if len(additional_kwargs["documents"]) > 0
else None
)
elif documents is not None and len(documents) > 0:
parsed_docs = documents

formatted_docs: Optional[List[Dict[str, Any]]] = None
if additional_kwargs.get("documents"):
if parsed_docs is not None:
formatted_docs = [
{
"text": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(additional_kwargs.get("documents", []))
for i, doc in enumerate(parsed_docs)
]
elif documents:
formatted_docs = documents

# by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality
Expand Down
35 changes: 22 additions & 13 deletions libs/partners/cohere/langchain_cohere/rag_retrievers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
Expand All @@ -17,15 +17,16 @@


def _get_docs(response: Any) -> List[Document]:
docs = (
[]
if "documents" not in response.generation_info
or len(response.generation_info["documents"]) == 0
else [
Document(page_content=doc["snippet"], metadata=doc)
for doc in response.generation_info["documents"]
]
)
docs = []
if (
"documents" in response.generation_info
and len(response.generation_info["documents"]) > 0
):
for doc in response.generation_info["documents"]:
content = doc.get("snippet", None) or doc.get("text", None)
if content is not None:
docs.append(Document(page_content=content, metadata=doc))

docs.append(
Document(
page_content=response.message.content,
Expand Down Expand Up @@ -63,12 +64,18 @@ class Config:
"""Allow arbitrary types."""

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
documents: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = self.llm.generate(
messages,
connectors=self.connectors,
connectors=self.connectors if documents is None else None,
documents=documents,
callbacks=run_manager.get_child(),
**kwargs,
).generations[0][0]
Expand All @@ -79,13 +86,15 @@ async def _aget_relevant_documents(
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
documents: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = (
await self.llm.agenerate(
messages,
connectors=self.connectors,
connectors=self.connectors if documents is None else None,
documents=documents,
callbacks=run_manager.get_child(),
**kwargs,
)
Expand Down

0 comments on commit 65f7868

Please sign in to comment.