Skip to content

Commit

Permalink
langchain[patch]: Add async methods to VectorStoreRetrieverMemory (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Mar 22, 2024
1 parent ef6d3d6 commit 1b813fe
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions libs/langchain/langchain/memory/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,34 @@ def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key

def load_memory_variables(
self, inputs: Dict[str, Any]
def _documents_to_memory_variables(
self, docs: List[Document]
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query)
result: Union[List[Document], str]
if not self.return_docs:
result = "\n".join([doc.page_content for doc in docs])
else:
result = docs
return {self.memory_key: result}

def load_memory_variables(
self, inputs: Dict[str, Any]
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = self.retriever.get_relevant_documents(query)
return self._documents_to_memory_variables(docs)

async def aload_memory_variables(
self, inputs: Dict[str, Any]
) -> Dict[str, Union[List[Document], str]]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
docs = await self.retriever.aget_relevant_documents(query)
return self._documents_to_memory_variables(docs)

def _form_documents(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> List[Document]:
Expand All @@ -73,5 +87,15 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
documents = self._form_documents(inputs, outputs)
self.retriever.add_documents(documents)

async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
documents = self._form_documents(inputs, outputs)
await self.retriever.aadd_documents(documents)

def clear(self) -> None:
"""Nothing to clear."""

async def aclear(self) -> None:
"""Nothing to clear."""

0 comments on commit 1b813fe

Please sign in to comment.