Skip to content

Commit

Permalink
Add AstraDBChatMessageHistory to langchain-astradb package
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Feb 20, 2024
1 parent 1d2aa19 commit 1b46e4c
Show file tree
Hide file tree
Showing 8 changed files with 1,225 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
if TYPE_CHECKING:
from astrapy.db import AstraDB

from langchain_core._api.deprecation import deprecated
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
Expand All @@ -23,6 +24,11 @@
DEFAULT_COLLECTION_NAME = "langchain_message_store"


@deprecated(
since="0.1.24",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBChatMessageHistory",
)
class AstraDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in Astra DB.
Expand Down
2 changes: 2 additions & 0 deletions libs/partners/astradb/langchain_astradb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
from langchain_astradb.vectorstores import AstraDBVectorStore

__all__ = [
"AstraDBChatMessageHistory",
"AstraDBVectorStore",
]
149 changes: 149 additions & 0 deletions libs/partners/astradb/langchain_astradb/chat_message_histories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Astra DB - based chat message history, based on astrapy."""
from __future__ import annotations

import json
import time
from typing import List, Optional, Sequence

from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)

from langchain_astradb.utils.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)

DEFAULT_COLLECTION_NAME = "langchain_message_store"


class AstraDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in Astra DB.
Args:
session_id: arbitrary key that is used to store the messages
of a single chat session.
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
"""

def __init__(
self,
*,
session_id: str,
collection_name: str = DEFAULT_COLLECTION_NAME,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
) -> None:
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)

self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection

self.session_id = session_id
self.collection_name = collection_name

@property
def messages(self) -> List[BaseMessage]:
"""Retrieve all session messages from DB"""
self.astra_env.ensure_db_setup()
message_blobs = [
doc["body_blob"]
for doc in sorted(
self.collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
),
key=lambda _doc: _doc["timestamp"],
)
]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages

@messages.setter
def messages(self, messages: List[BaseMessage]) -> None:
raise NotImplementedError("Use add_messages instead")

async def aget_messages(self) -> List[BaseMessage]:
await self.astra_env.aensure_db_setup()
docs = self.async_collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
)
sorted_docs = sorted(
[doc async for doc in docs],
key=lambda _doc: _doc["timestamp"],
)
message_blobs = [doc["body_blob"] for doc in sorted_docs]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
self.astra_env.ensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
self.collection.chunked_insert_many(docs)

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
await self.astra_env.aensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
await self.async_collection.chunked_insert_many(docs)

def clear(self) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"session_id": self.session_id})

async def aclear(self) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"session_id": self.session_id})
142 changes: 142 additions & 0 deletions libs/partners/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

import asyncio
import inspect
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import Awaitable, Optional, Union

from astrapy.db import AstraDB, AsyncAstraDB


class SetupMode(Enum):
SYNC = 1
ASYNC = 2
OFF = 3


class _AstraDBEnvironment:
def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
) -> None:
self.token = token
self.api_endpoint = api_endpoint
astra_db = astra_db_client
async_astra_db = async_astra_db_client
self.namespace = namespace

# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDBEnvironment if passing 'token' and 'api_endpoint'."
)

if token and api_endpoint:
astra_db = AstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=self.namespace,
)
async_astra_db = AsyncAstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=self.namespace,
)

if astra_db:
self.astra_db = astra_db
if async_astra_db:
self.async_astra_db = async_astra_db
else:
self.async_astra_db = self.astra_db.to_async()
elif async_astra_db:
self.async_astra_db = async_astra_db
self.astra_db = self.async_astra_db.to_sync()
else:
raise ValueError(
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
"'token' and 'api_endpoint'"
)


class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding_dimension: Union[int, Awaitable[int], None] = None,
metric: Optional[str] = None,
) -> None:
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

super().__init__(
token, api_endpoint, astra_db_client, async_astra_db_client, namespace
)
self.collection_name = collection_name
self.collection = AstraDBCollection(
collection_name=collection_name,
astra_db=self.astra_db,
)

self.async_collection = AsyncAstraDBCollection(
collection_name=collection_name,
astra_db=self.async_astra_db,
)

self.async_setup_db_task: Optional[Task] = None
if setup_mode == SetupMode.ASYNC:
async_astra_db = self.async_astra_db

async def _setup_db() -> None:
if pre_delete_collection:
await async_astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
dimension = await embedding_dimension
else:
dimension = embedding_dimension
await async_astra_db.create_collection(
collection_name, dimension=dimension, metric=metric
)

self.async_setup_db_task = asyncio.create_task(_setup_db())
elif setup_mode == SetupMode.SYNC:
if pre_delete_collection:
self.astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
raise ValueError(
"Cannot use an awaitable embedding_dimension with async_setup "
"set to False"
)
self.astra_db.create_collection(
collection_name,
dimension=embedding_dimension, # type: ignore[arg-type]
metric=metric,
)

def ensure_db_setup(self) -> None:
if self.async_setup_db_task:
try:
self.async_setup_db_task.result()
except InvalidStateError:
raise ValueError(
"Asynchronous setup of the DB not finished. "
"NB: AstraDB components sync methods shouldn't be called from the "
"event loop. Consider using their async equivalents."
)

async def aensure_db_setup(self) -> None:
if self.async_setup_db_task:
await self.async_setup_db_task

0 comments on commit 1b46e4c

Please sign in to comment.