Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langchain-mongodb: Add MongoDB LLM Cache #17470

Merged
merged 22 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
207 changes: 207 additions & 0 deletions libs/community/langchain_community/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import inspect
import json
import logging
import time
import uuid
import warnings
from abc import ABC
Expand Down Expand Up @@ -68,6 +69,7 @@
SetupMode,
_AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
from langchain_community.vectorstores.redis import Redis as RedisVectorstore

logger = logging.getLogger(__file__)
Expand Down Expand Up @@ -1836,10 +1838,215 @@ async def adelete_by_document_id(self, document_id: str) -> None:

def clear(self, **kwargs: Any) -> None:
    """Clear the *whole* semantic cache."""
    # Ensure the (possibly lazily-created) collection exists before clearing,
    # mirroring the async twin `aclear`. The legacy
    # `astra_db.truncate_collection(...)` call is superseded by
    # `collection.clear()` and must not run as well — it would clear the
    # collection a second time through a different client path.
    self.astra_env.ensure_db_setup()
    self.collection.clear()

async def aclear(self, **kwargs: Any) -> None:
    """Clear the *whole* semantic cache."""
    # Make sure the async database/collection are initialized before
    # clearing; setup may not have happened yet if no cache call ran.
    await self.astra_env.aensure_db_setup()
    await self.async_collection.clear()

def _generate_mongo_client(connection_string: str):
try:
from importlib.metadata import version

from pymongo import MongoClient
from pymongo.driver_info import DriverInfo
except ImportError:
raise ImportError(
"Could not import pymongo, please install it with " "`pip install pymongo`."
)

return MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain")),
)


def _wait_until(predicate, success_description, timeout=10):
"""Wait up to 10 seconds (by default) for predicate to be true.

E.g.:

wait_until(lambda: client.primary == ('a', 1),
'connect to the primary')

If the lambda-expression isn't true after 10 seconds, we raise
AssertionError("Didn't ever connect to the primary").

Returns the predicate's first true value.
"""
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
retval = predicate()
if retval:
return retval

if time.time() - start > timeout:
raise AssertionError("Didn't ever %s" % success_description)

time.sleep(interval)


class MongoDBAtlasCache(BaseCache):
    """MongoDB Atlas cache.

    A cache that uses MongoDB Atlas as a backend: each cache entry is a
    single document keyed by the exact prompt and the serialized LLM
    configuration string.
    """

    # Document field names used to key cache entries.
    PROMPT = "prompt"
    LLM = "llm"

    def __init__(
        self,
        collection_name: str = "default",
        connection_string: str = "default",
        database_name: str = "default",
        **kwargs: Any,
    ) -> None:
        """
        Initialize Atlas Cache. Creates collection on instantiation

        Args:
            collection_name (str): Name of collection for cache to live.
                Defaults to "default".
            connection_string (str): Connection URI to MongoDB Atlas.
                Defaults to "default".
            database_name (str): Name of database for cache to live.
                Defaults to "default".
        """
        self.client = _generate_mongo_client(connection_string)

        self.__database_name = database_name
        self.__collection_name = collection_name

        if self.__collection_name not in self.database.list_collection_names():
            self.database.create_collection(self.__collection_name)
            # Create an index on key and llm_string so lookups avoid a
            # collection scan.
            self.collection.create_index([self.PROMPT, self.LLM])

    @property
    def database(self):
        """Returns the database used to store cache values."""
        return self.client[self.__database_name]

    @property
    def collection(self):
        """Returns the collection used to store cache values."""
        return self.database[self.__collection_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns the cached generations, or None on a cache miss.
        """
        return_doc = (
            self.collection.find_one(self._generate_keys(prompt, llm_string)) or {}
        )
        # Fetch once instead of calling .get() twice.
        return_val = return_doc.get("return_val")
        if return_val:
            return _loads_generations(return_val)
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Upserts, so an existing entry for the same (prompt, llm) pair is
        overwritten.
        """
        self.collection.update_one(
            self._generate_keys(prompt, llm_string),
            {"$set": {"return_val": _dumps_generations(return_val)}},
            upsert=True,
        )

    def _generate_keys(self, prompt: str, llm_string: str) -> "dict[str, str]":
        """Create keyed fields for caching layer."""
        # Annotation is quoted: an unquoted `dict[str, str]` is evaluated at
        # function-definition time and raises TypeError on Python < 3.9.
        return {self.PROMPT: prompt, self.LLM: llm_string}

    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments.
        Any additional arguments will propagate as filtration criteria for
        what gets deleted.

        E.g.
            # Delete only entries that have the "llm" field as "fake-model"
            self.clear(llm="fake-model")
        """
        # With no kwargs this deletes every cached entry.
        self.collection.delete_many({**kwargs})


class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
    """MongoDB Atlas Semantic cache.

    A Cache backed by a MongoDB Atlas server with vector-store support.
    Prompts are embedded and retrieved by vector similarity, so semantically
    similar prompts can share a cached result.
    """

    # Metadata field names stored alongside each embedded prompt.
    LLM = "llm_string"
    RETURN_VAL = "return_val"

    def __init__(
        self,
        connection_string: str,
        embedding: Embeddings,
        collection_name: str = "default",
        database_name: str = "default",
        wait_until_ready: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Initialize Atlas VectorSearch Cache.
        Assumes collection exists before instantiation

        Args:
            connection_string (str): MongoDB URI to connect to MongoDB Atlas cluster.
            embedding (Embeddings): Text embedding model to use.
            collection_name (str): MongoDB Collection to add the texts to.
                Defaults to "default".
            database_name (str): MongoDB Database where to store texts.
                Defaults to "default".
            wait_until_ready (bool): Block until MongoDB Atlas finishes indexing
                the stored text. Hard timeout of 10 seconds. Defaults to False.
        """
        client = _generate_mongo_client(connection_string)
        self.collection = client[database_name][collection_name]
        # Instance-level default for update(); overridable per call.
        self._wait_until_ready = wait_until_ready
        super().__init__(self.collection, embedding, **kwargs)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns the cached generations for the nearest semantically-matching
        prompt with the same llm_string, or None on a miss.
        """
        search_response = self.similarity_search_with_score(
            prompt, 1, pre_filter={self.LLM: {"$eq": llm_string}}
        )
        if search_response:
            return_val = search_response[0][0].metadata.get("return_val")
            # Fall back to the raw stored value if deserialization yields
            # nothing (e.g. a legacy or hand-written entry).
            return _loads_generations(return_val) or return_val
        return None

    def update(
        self,
        prompt: str,
        llm_string: str,
        return_val: RETURN_VAL_TYPE,
        wait_until_ready: Optional[bool] = None,
    ) -> None:
        """Update cache based on prompt and llm_string.

        Args:
            prompt: Prompt text to embed and store.
            llm_string: Serialized LLM configuration, stored as a filter key.
            return_val: Generations to cache.
            wait_until_ready: If not None, overrides the instance-level
                ``wait_until_ready`` setting for this call.
        """
        self.add_texts(
            [prompt],
            [
                {
                    self.LLM: llm_string,
                    self.RETURN_VAL: _dumps_generations(return_val),
                }
            ],
        )
        wait = self._wait_until_ready if wait_until_ready is None else wait_until_ready

        def is_value_stored() -> bool:
            return self.lookup(prompt, llm_string) == return_val

        if wait:
            # Pass a human-readable description — not the raw return value —
            # so a timeout produces a sensible error message.
            _wait_until(is_value_stored, "store the value in the semantic cache")

    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments.
        Any additional arguments will propagate as filtration criteria for
        what gets deleted.

        E.g.
            # Delete only entries that have llm_string as "fake-model"
            self.clear(llm_string="fake-model")
        """
        # With no kwargs this deletes every cached entry.
        self.collection.delete_many({**kwargs})
4 changes: 4 additions & 0 deletions libs/langchain/langchain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
GPTCache,
InMemoryCache,
MomentoCache,
MongoDBAtlasCache,
MongoDBAtlasSemanticCache,
RedisCache,
RedisSemanticCache,
SQLAlchemyCache,
Expand All @@ -26,6 +28,8 @@
"RedisSemanticCache",
"GPTCache",
"MomentoCache",
"MongoDBAtlasCache",
"MongoDBAtlasSemanticCache",
"CassandraCache",
"CassandraSemanticCache",
"FullMd5LLMCache",
Expand Down