-
Notifications
You must be signed in to change notification settings - Fork 13.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AstraDBChatMessageHistory to langchain-astradb package
- Loading branch information
Showing
8 changed files
with
1,224 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory | ||
from langchain_astradb.vectorstores import AstraDBVectorStore | ||
|
||
__all__ = [ | ||
"AstraDBChatMessageHistory", | ||
"AstraDBVectorStore", | ||
] |
148 changes: 148 additions & 0 deletions
148
libs/partners/astradb/langchain_astradb/chat_message_histories.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
"""Astra DB - based chat message history, based on astrapy.""" | ||
from __future__ import annotations | ||
|
||
import json | ||
import time | ||
from typing import List, Optional, Sequence | ||
|
||
from astrapy.db import AstraDB, AsyncAstraDB | ||
from langchain_core.chat_history import BaseChatMessageHistory | ||
from langchain_core.messages import ( | ||
BaseMessage, | ||
message_to_dict, | ||
messages_from_dict, | ||
) | ||
|
||
from langchain_astradb.utils.astradb import ( | ||
SetupMode, | ||
_AstraDBCollectionEnvironment, | ||
) | ||
|
||
DEFAULT_COLLECTION_NAME = "langchain_message_store" | ||
|
||
|
||
class AstraDBChatMessageHistory(BaseChatMessageHistory): | ||
def __init__( | ||
self, | ||
*, | ||
session_id: str, | ||
collection_name: str = DEFAULT_COLLECTION_NAME, | ||
token: Optional[str] = None, | ||
api_endpoint: Optional[str] = None, | ||
astra_db_client: Optional[AstraDB] = None, | ||
async_astra_db_client: Optional[AsyncAstraDB] = None, | ||
namespace: Optional[str] = None, | ||
setup_mode: SetupMode = SetupMode.SYNC, | ||
pre_delete_collection: bool = False, | ||
) -> None: | ||
"""Chat message history that stores history in Astra DB. | ||
Args: | ||
session_id: arbitrary key that is used to store the messages | ||
of a single chat session. | ||
collection_name: name of the Astra DB collection to create/use. | ||
token: API token for Astra DB usage. | ||
api_endpoint: full URL to the API endpoint, | ||
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com". | ||
astra_db_client: *alternative to token+api_endpoint*, | ||
you can pass an already-created 'astrapy.db.AstraDB' instance. | ||
async_astra_db_client: *alternative to token+api_endpoint*, | ||
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. | ||
namespace: namespace (aka keyspace) where the | ||
collection is created. Defaults to the database's "default namespace". | ||
""" | ||
self.astra_env = _AstraDBCollectionEnvironment( | ||
collection_name=collection_name, | ||
token=token, | ||
api_endpoint=api_endpoint, | ||
astra_db_client=astra_db_client, | ||
async_astra_db_client=async_astra_db_client, | ||
namespace=namespace, | ||
setup_mode=setup_mode, | ||
pre_delete_collection=pre_delete_collection, | ||
) | ||
|
||
self.collection = self.astra_env.collection | ||
self.async_collection = self.astra_env.async_collection | ||
|
||
self.session_id = session_id | ||
self.collection_name = collection_name | ||
|
||
@property | ||
def messages(self) -> List[BaseMessage]: | ||
"""Retrieve all session messages from DB""" | ||
self.astra_env.ensure_db_setup() | ||
message_blobs = [ | ||
doc["body_blob"] | ||
for doc in sorted( | ||
self.collection.paginated_find( | ||
filter={ | ||
"session_id": self.session_id, | ||
}, | ||
projection={ | ||
"timestamp": 1, | ||
"body_blob": 1, | ||
}, | ||
), | ||
key=lambda _doc: _doc["timestamp"], | ||
) | ||
] | ||
items = [json.loads(message_blob) for message_blob in message_blobs] | ||
messages = messages_from_dict(items) | ||
return messages | ||
|
||
@messages.setter | ||
def messages(self, messages: List[BaseMessage]) -> None: | ||
raise NotImplementedError("Use add_messages instead") | ||
|
||
async def aget_messages(self) -> List[BaseMessage]: | ||
await self.astra_env.aensure_db_setup() | ||
docs = self.async_collection.paginated_find( | ||
filter={ | ||
"session_id": self.session_id, | ||
}, | ||
projection={ | ||
"timestamp": 1, | ||
"body_blob": 1, | ||
}, | ||
) | ||
sorted_docs = sorted( | ||
[doc async for doc in docs], | ||
key=lambda _doc: _doc["timestamp"], | ||
) | ||
message_blobs = [doc["body_blob"] for doc in sorted_docs] | ||
items = [json.loads(message_blob) for message_blob in message_blobs] | ||
messages = messages_from_dict(items) | ||
return messages | ||
|
||
def add_messages(self, messages: Sequence[BaseMessage]) -> None: | ||
self.astra_env.ensure_db_setup() | ||
docs = [ | ||
{ | ||
"timestamp": time.time(), | ||
"session_id": self.session_id, | ||
"body_blob": json.dumps(message_to_dict(message)), | ||
} | ||
for message in messages | ||
] | ||
self.collection.chunked_insert_many(docs) | ||
|
||
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: | ||
await self.astra_env.aensure_db_setup() | ||
docs = [ | ||
{ | ||
"timestamp": time.time(), | ||
"session_id": self.session_id, | ||
"body_blob": json.dumps(message_to_dict(message)), | ||
} | ||
for message in messages | ||
] | ||
await self.async_collection.chunked_insert_many(docs) | ||
|
||
def clear(self) -> None: | ||
self.astra_env.ensure_db_setup() | ||
self.collection.delete_many(filter={"session_id": self.session_id}) | ||
|
||
async def aclear(self) -> None: | ||
await self.astra_env.aensure_db_setup() | ||
await self.async_collection.delete_many(filter={"session_id": self.session_id}) |
142 changes: 142 additions & 0 deletions
142
libs/partners/astradb/langchain_astradb/utils/astradb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import inspect | ||
from asyncio import InvalidStateError, Task | ||
from enum import Enum | ||
from typing import Awaitable, Optional, Union | ||
|
||
from astrapy.db import AstraDB, AsyncAstraDB | ||
|
||
|
||
class SetupMode(Enum): | ||
SYNC = 1 | ||
ASYNC = 2 | ||
OFF = 3 | ||
|
||
|
||
class _AstraDBEnvironment: | ||
def __init__( | ||
self, | ||
token: Optional[str] = None, | ||
api_endpoint: Optional[str] = None, | ||
astra_db_client: Optional[AstraDB] = None, | ||
async_astra_db_client: Optional[AsyncAstraDB] = None, | ||
namespace: Optional[str] = None, | ||
) -> None: | ||
self.token = token | ||
self.api_endpoint = api_endpoint | ||
astra_db = astra_db_client | ||
async_astra_db = async_astra_db_client | ||
self.namespace = namespace | ||
|
||
# Conflicting-arg checks: | ||
if astra_db_client is not None or async_astra_db_client is not None: | ||
if token is not None or api_endpoint is not None: | ||
raise ValueError( | ||
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to " | ||
"AstraDBEnvironment if passing 'token' and 'api_endpoint'." | ||
) | ||
|
||
if token and api_endpoint: | ||
astra_db = AstraDB( | ||
token=token, | ||
api_endpoint=api_endpoint, | ||
namespace=self.namespace, | ||
) | ||
async_astra_db = AsyncAstraDB( | ||
token=token, | ||
api_endpoint=api_endpoint, | ||
namespace=self.namespace, | ||
) | ||
|
||
if astra_db: | ||
self.astra_db = astra_db | ||
if async_astra_db: | ||
self.async_astra_db = async_astra_db | ||
else: | ||
self.async_astra_db = self.astra_db.to_async() | ||
elif async_astra_db: | ||
self.async_astra_db = async_astra_db | ||
self.astra_db = self.async_astra_db.to_sync() | ||
else: | ||
raise ValueError( | ||
"Must provide 'astra_db_client' or 'async_astra_db_client' or " | ||
"'token' and 'api_endpoint'" | ||
) | ||
|
||
|
||
class _AstraDBCollectionEnvironment(_AstraDBEnvironment): | ||
def __init__( | ||
self, | ||
collection_name: str, | ||
token: Optional[str] = None, | ||
api_endpoint: Optional[str] = None, | ||
astra_db_client: Optional[AstraDB] = None, | ||
async_astra_db_client: Optional[AsyncAstraDB] = None, | ||
namespace: Optional[str] = None, | ||
setup_mode: SetupMode = SetupMode.SYNC, | ||
pre_delete_collection: bool = False, | ||
embedding_dimension: Union[int, Awaitable[int], None] = None, | ||
metric: Optional[str] = None, | ||
) -> None: | ||
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection | ||
|
||
super().__init__( | ||
token, api_endpoint, astra_db_client, async_astra_db_client, namespace | ||
) | ||
self.collection_name = collection_name | ||
self.collection = AstraDBCollection( | ||
collection_name=collection_name, | ||
astra_db=self.astra_db, | ||
) | ||
|
||
self.async_collection = AsyncAstraDBCollection( | ||
collection_name=collection_name, | ||
astra_db=self.async_astra_db, | ||
) | ||
|
||
self.async_setup_db_task: Optional[Task] = None | ||
if setup_mode == SetupMode.ASYNC: | ||
async_astra_db = self.async_astra_db | ||
|
||
async def _setup_db() -> None: | ||
if pre_delete_collection: | ||
await async_astra_db.delete_collection(collection_name) | ||
if inspect.isawaitable(embedding_dimension): | ||
dimension = await embedding_dimension | ||
else: | ||
dimension = embedding_dimension | ||
await async_astra_db.create_collection( | ||
collection_name, dimension=dimension, metric=metric | ||
) | ||
|
||
self.async_setup_db_task = asyncio.create_task(_setup_db()) | ||
elif setup_mode == SetupMode.SYNC: | ||
if pre_delete_collection: | ||
self.astra_db.delete_collection(collection_name) | ||
if inspect.isawaitable(embedding_dimension): | ||
raise ValueError( | ||
"Cannot use an awaitable embedding_dimension with async_setup " | ||
"set to False" | ||
) | ||
self.astra_db.create_collection( | ||
collection_name, | ||
dimension=embedding_dimension, # type: ignore[arg-type] | ||
metric=metric, | ||
) | ||
|
||
def ensure_db_setup(self) -> None: | ||
if self.async_setup_db_task: | ||
try: | ||
self.async_setup_db_task.result() | ||
except InvalidStateError: | ||
raise ValueError( | ||
"Asynchronous setup of the DB not finished. " | ||
"NB: AstraDB components sync methods shouldn't be called from the " | ||
"event loop. Consider using their async equivalents." | ||
) | ||
|
||
async def aensure_db_setup(self) -> None: | ||
if self.async_setup_db_task: | ||
await self.async_setup_db_task |
Oops, something went wrong.