Skip to content

Commit

Permalink
langchain[patch]: Use async memory in Chain when needed (langchain-ai…
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored and Dave Bechberger committed Mar 29, 2024
1 parent 6872023 commit e7898c5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
26 changes: 25 additions & 1 deletion libs/langchain/langchain/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ async def ainvoke(
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)

inputs = self.prep_inputs(input)
inputs = await self.aprep_inputs(input)
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
Expand Down Expand Up @@ -482,6 +482,30 @@ def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
inputs = dict(inputs, **external_context)
return inputs

async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.
Args:
inputs: Dictionary of raw inputs, or single input if chain expects
only one param. Should contain all inputs specified in
`Chain.input_keys` except for inputs that will be set by the chain's
memory.
Returns:
A dictionary of all inputs, including those added by the chain's memory.
"""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = await self.memory.aload_memory_variables(inputs)
inputs = dict(inputs, **external_context)
return inputs

@property
def _run_output_key(self) -> str:
if len(self.output_keys) != 1:
Expand Down
36 changes: 33 additions & 3 deletions libs/langchain/tests/unit_tests/chains/test_conversation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Test conversation chain and memory."""
from typing import Any, List, Optional

import pytest
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.memory import BaseMemory
from langchain_core.prompts.prompt import PromptTemplate

Expand All @@ -10,6 +14,27 @@
from tests.unit_tests.llms.fake_llm import FakeLLM


class DummyLLM(LLM):
last_prompt = ""

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)

@property
def _llm_type(self) -> str:
return "dummy"

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
self.last_prompt = prompt
return "dummy"


def test_memory_ai_prefix() -> None:
"""Test that ai_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
Expand All @@ -32,13 +57,18 @@ async def test_memory_async() -> None:
}


def test_conversation_chain_works() -> None:
async def test_conversation_chain_works() -> None:
"""Test that conversation chain works in basic setting."""
llm = FakeLLM()
llm = DummyLLM()
prompt = PromptTemplate(input_variables=["foo", "bar"], template="{foo} {bar}")
memory = ConversationBufferMemory(memory_key="foo")
chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, input_key="bar")
chain.run("foo")
chain.run("aaa")
assert llm.last_prompt == " aaa"
chain.run("bbb")
assert llm.last_prompt == "Human: aaa\nAI: dummy bbb"
await chain.arun("ccc")
assert llm.last_prompt == "Human: aaa\nAI: dummy\nHuman: bbb\nAI: dummy ccc"


def test_conversation_chain_errors_bad_prompt() -> None:
Expand Down

0 comments on commit e7898c5

Please sign in to comment.