langchain-mongodb: Add MongoDB LLM Cache #17470

Merged · 22 commits · Mar 5, 2024
58 changes: 58 additions & 0 deletions docs/docs/integrations/providers/mongodb_atlas.mdx
@@ -22,3 +22,61 @@ See a [usage example](/docs/integrations/vectorstores/mongodb_atlas).
from langchain_mongodb import MongoDBAtlasVectorSearch
```


## LLM Caches

### MongoDBCache
An abstraction to store a simple cache in MongoDB. This does not use semantic caching, nor does it require an index to be created on the collection before use.

To import this cache:
```python
from langchain_mongodb.cache import MongoDBCache
```

To use this cache with your LLMs:
```python
from langchain_core.globals import set_llm_cache

mongodb_atlas_uri = "<YOUR_CONNECTION_STRING>"
COLLECTION_NAME = "<YOUR_CACHE_COLLECTION_NAME>"
DATABASE_NAME = "<YOUR_DATABASE_NAME>"

set_llm_cache(MongoDBCache(
connection_string=mongodb_atlas_uri,
collection_name=COLLECTION_NAME,
database_name=DATABASE_NAME,
))
```
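
Once the cache is registered with `set_llm_cache`, exact repeats of a prompt/model pair are served from MongoDB instead of calling the model again. A minimal sketch, assuming the `langchain_openai` package is installed (`ChatOpenAI` is only a placeholder here; any LangChain chat model or LLM works):
```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")

# First call misses the cache, invokes the model, and stores the result in MongoDB.
llm.invoke("Tell me a joke about MongoDB")

# An identical prompt with identical model parameters is answered from the cache.
llm.invoke("Tell me a joke about MongoDB")
```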


### MongoDBAtlasSemanticCache
Semantic caching allows users to retrieve cached prompts based on semantic similarity between the user input and previously cached prompts. Under the hood it uses MongoDB Atlas as both a cache and a vector store.
The `MongoDBAtlasSemanticCache` inherits from `MongoDBAtlasVectorSearch` and needs an Atlas Vector Search Index defined before it will work. See the [usage example](/docs/integrations/vectorstores/mongodb_atlas) for how to set up the index.

To import this cache:
```python
from langchain_mongodb.cache import MongoDBAtlasSemanticCache
```

To use this cache with your LLMs:
```python
from langchain_core.globals import set_llm_cache

# use any embedding provider...
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

mongodb_atlas_uri = "<YOUR_CONNECTION_STRING>"
COLLECTION_NAME = "<YOUR_CACHE_COLLECTION_NAME>"
DATABASE_NAME = "<YOUR_DATABASE_NAME>"

set_llm_cache(MongoDBAtlasSemanticCache(
embedding=FakeEmbeddings(),
connection_string=mongodb_atlas_uri,
collection_name=COLLECTION_NAME,
database_name=DATABASE_NAME,
))
```
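
The semantic cache can also be exercised directly through the `lookup` and `update` methods added in this PR. A rough sketch, reusing the placeholders from the block above (`wait_until_ready=True` blocks until Atlas has indexed the new entry, with a hard 10-second timeout):
```python
from langchain_core.outputs import Generation
from langchain_mongodb.cache import MongoDBAtlasSemanticCache

semantic_cache = MongoDBAtlasSemanticCache(
    embedding=FakeEmbeddings(),
    connection_string=mongodb_atlas_uri,
    collection_name=COLLECTION_NAME,
    database_name=DATABASE_NAME,
    wait_until_ready=True,  # block until the Atlas search index contains the entry
)

# Store a generation under a (prompt, llm_string) pair...
semantic_cache.update("what is 2 + 2?", "fake-model", [Generation(text="4")])

# ...and retrieve it for a semantically similar prompt and the same llm_string.
cached = semantic_cache.lookup("What is two plus two?", "fake-model")
```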
312 changes: 312 additions & 0 deletions libs/partners/mongodb/langchain_mongodb/cache.py
@@ -0,0 +1,312 @@
"""
LangChain MongoDB Caches

Functions "_loads_generations" and "_dumps_generations"
are duplicated in this utility from modules:
- "libs/community/langchain_community/cache.py"
"""

import json
import logging
import time
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Union

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import Generation
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.driver_info import DriverInfo

from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch

logger = logging.getLogger(__file__)


def _generate_mongo_client(connection_string: str) -> MongoClient:
return MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)


def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
"""
Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`

Args:
generations (RETURN_VAL_TYPE): A list of language model generations.

Returns:
str: a single string representing a list of generations.

This function (+ its counterpart `_loads_generations`) rely on
the dumps/loads pair with Reviver, so are able to deal
with all subclasses of Generation.

Each item in the list can be `dumps`ed to a string,
then we make the whole list of strings into a json-dumped.
"""
return json.dumps([dumps(_item) for _item in generations])


def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
"""
Deserialization of a string into a generic RETURN_VAL_TYPE
(i.e. a sequence of `Generation`).

See `_dumps_generations`, the inverse of this function.

Args:
generations_str (str): A string representing a list of generations.

Compatible with the legacy cache-blob format
Does not raise exceptions for malformed entries, just logs a warning
and returns none: the caller should be prepared for such a cache miss.

Returns:
RETURN_VAL_TYPE: A list of generations.
"""
try:
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
return generations
except (json.JSONDecodeError, TypeError):
# deferring the (soft) handling to after the legacy-format attempt
pass

try:
gen_dicts = json.loads(generations_str)
# not relying on `_load_generations_from_json` (which could disappear):
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
logger.warning(
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
)
return generations
except (json.JSONDecodeError, TypeError):
logger.warning(
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
)
return None
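
# Illustrative round trip for the two helpers above (a hedged sketch, not part
# of the module's public API):
#
#     from langchain_core.outputs import Generation
#     blob = _dumps_generations([Generation(text="4")])  # JSON string
#     restored = _loads_generations(blob)                # [Generation(text="4")]
#     assert restored == [Generation(text="4")]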


def _wait_until(
    predicate: Callable, success_description: Any, timeout: float = 10.0
) -> Any:
    """Wait up to `timeout` seconds (10 by default) for predicate to be true.

    E.g.:

        wait_until(lambda: client.primary == ('a', 1),
                   'connect to the primary')

    If the predicate is still false after `timeout` seconds, raises
    TimeoutError("Didn't ever connect to the primary").

    Returns the predicate's first truthy value.
    """
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
retval = predicate()
if retval:
return retval

if time.time() - start > timeout:
raise TimeoutError("Didn't ever %s" % success_description)

time.sleep(interval)


class MongoDBCache(BaseCache):
"""MongoDB Atlas cache

A cache that uses MongoDB Atlas as a backend
"""

PROMPT = "prompt"
LLM = "llm"
RETURN_VAL = "return_val"
_local_cache: Dict[str, Any]

def __init__(
self,
connection_string: str,
collection_name: str = "default",
database_name: str = "default",
**kwargs: Dict[str, Any],
) -> None:
"""
Initialize Atlas Cache. Creates collection on instantiation

Args:
collection_name (str): Name of collection for cache to live.
Defaults to "default".
connection_string (str): Connection URI to MongoDB Atlas.
Defaults to "default".
database_name (str): Name of database for cache to live.
Defaults to "default".
"""
self.client = _generate_mongo_client(connection_string)
self.__database_name = database_name
self.__collection_name = collection_name
self._local_cache = {}

if self.__collection_name not in self.database.list_collection_names():
self.database.create_collection(self.__collection_name)
# Create an index on key and llm_string
self.collection.create_index([self.PROMPT, self.LLM])

@property
def database(self) -> Database:
"""Returns the database used to store cache values."""
return self.client[self.__database_name]

@property
def collection(self) -> Collection:
"""Returns the collection used to store cache values."""
return self.database[self.__collection_name]

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
cache_key = self._generate_local_key(prompt, llm_string)
if cache_key in self._local_cache:
return self._local_cache[cache_key]

return_doc = (
self.collection.find_one(self._generate_keys(prompt, llm_string)) or {}
)
return_val = return_doc.get(self.RETURN_VAL)
return _loads_generations(return_val) if return_val else None # type: ignore

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
cache_key = self._generate_local_key(prompt, llm_string)
self._local_cache[cache_key] = return_val

self.collection.update_one(
{**self._generate_keys(prompt, llm_string)},
{"$set": {self.RETURN_VAL: _dumps_generations(return_val)}},
upsert=True,
)

def _generate_keys(self, prompt: str, llm_string: str) -> Dict[str, str]:
"""Create keyed fields for caching layer"""
return {self.PROMPT: prompt, self.LLM: llm_string}

def _generate_local_key(self, prompt: str, llm_string: str) -> str:
"""Create keyed fields for local caching layer"""
return f"{prompt}#{llm_string}"

def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments.
Any additional arguments will propagate as filtration criteria for
what gets deleted.

E.g.
# Delete only entries that have llm_string as "fake-model"
self.clear(llm_string="fake-model")
"""
self.collection.delete_many({**kwargs})


class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
"""MongoDB Atlas Semantic cache.

A Cache backed by a MongoDB Atlas server with vector-store support
"""

LLM = "llm_string"
RETURN_VAL = "return_val"
_local_cache: Dict[str, Any]

def __init__(
self,
connection_string: str,
embedding: Embeddings,
collection_name: str = "default",
database_name: str = "default",
wait_until_ready: bool = False,
**kwargs: Dict[str, Any],
):
"""
Initialize Atlas VectorSearch Cache.
Assumes collection exists before instantiation

Args:
connection_string (str): MongoDB URI to connect to MongoDB Atlas cluster.
embedding (Embeddings): Text embedding model to use.
collection_name (str): MongoDB Collection to add the texts to.
Defaults to "default".
            database_name (str): MongoDB database in which to store the texts.
                Defaults to "default".
wait_until_ready (bool): Block until MongoDB Atlas finishes indexing
the stored text. Hard timeout of 10 seconds. Defaults to False.
"""
client = _generate_mongo_client(connection_string)
self.collection = client[database_name][collection_name]
self._wait_until_ready = wait_until_ready
super().__init__(self.collection, embedding, **kwargs) # type: ignore
self._local_cache = dict()

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
cache_key = self._generate_local_key(prompt, llm_string)
if cache_key in self._local_cache:
return self._local_cache[cache_key]
search_response = self.similarity_search_with_score(
prompt, 1, pre_filter={self.LLM: {"$eq": llm_string}}
)
if search_response:
return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
response = _loads_generations(return_val) or return_val # type: ignore
self._local_cache[cache_key] = response
return response
return None

def update(
self,
prompt: str,
llm_string: str,
return_val: RETURN_VAL_TYPE,
wait_until_ready: Optional[bool] = None,
) -> None:
"""Update cache based on prompt and llm_string."""
cache_key = self._generate_local_key(prompt, llm_string)
self._local_cache[cache_key] = return_val

self.add_texts(
[prompt],
[
{
self.LLM: llm_string,
self.RETURN_VAL: _dumps_generations(return_val),
}
],
)
wait = self._wait_until_ready if wait_until_ready is None else wait_until_ready

def is_indexed() -> bool:
return self.lookup(prompt, llm_string) == return_val

if wait:
_wait_until(is_indexed, return_val)

def _generate_local_key(self, prompt: str, llm_string: str) -> str:
"""Create keyed fields for local caching layer"""
return f"{prompt}#{llm_string}"

def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments.
Any additional arguments will propagate as filtration criteria for
what gets deleted. It will delete any locally cached content regardless

E.g.
# Delete only entries that have llm_string as "fake-model"
self.clear(llm_string="fake-model")
"""
self.collection.delete_many({**kwargs})
self._local_cache.clear()