Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

partners/astradb: Add AstraDBChatMessageHistory to langchain-astradb package #17732

Merged
merged 5 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/docs/integrations/providers/astradb.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Learn more in the [example notebook](/docs/integrations/vectorstores/astradb).
## Chat message history

```python
from langchain_community.chat_message_histories import AstraDBChatMessageHistory
from langchain_astradb import AstraDBChatMessageHistory
message_history = AstraDBChatMessageHistory(
session_id="test-session",
api_endpoint="...",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
if TYPE_CHECKING:
from astrapy.db import AstraDB, AsyncAstraDB

from langchain_core._api.deprecation import deprecated
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
Expand All @@ -23,6 +24,11 @@
DEFAULT_COLLECTION_NAME = "langchain_message_store"


@deprecated(
since="0.0.25",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBChatMessageHistory",
)
class AstraDBChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
Expand Down
13 changes: 12 additions & 1 deletion libs/partners/astradb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pip install langchain-astradb
### Vector Store

```python
from langchain_astradb.vectorstores import AstraDBVectorStore
from langchain_astradb import AstraDBVectorStore

my_store = AstraDBVectorStore(
embedding=my_embeddings,
Expand All @@ -30,6 +30,17 @@ my_store = AstraDBVectorStore(
)
```

### Chat message history

```python
from langchain_astradb import AstraDBChatMessageHistory
message_history = AstraDBChatMessageHistory(
session_id="test-session",
api_endpoint="...",
token="...",
)
```

### Store

```python
Expand Down
2 changes: 2 additions & 0 deletions libs/partners/astradb/langchain_astradb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.vectorstores import AstraDBVectorStore

__all__ = [
"AstraDBByteStore",
"AstraDBStore",
"AstraDBChatMessageHistory",
"AstraDBVectorStore",
]
148 changes: 148 additions & 0 deletions libs/partners/astradb/langchain_astradb/chat_message_histories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Astra DB - based chat message history, based on astrapy."""
from __future__ import annotations

import json
import time
from typing import List, Optional, Sequence

from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)

from langchain_astradb.utils.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)

DEFAULT_COLLECTION_NAME = "langchain_message_store"


class AstraDBChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
*,
session_id: str,
collection_name: str = DEFAULT_COLLECTION_NAME,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
) -> None:
"""Chat message history that stores history in Astra DB.

Args:
session_id: arbitrary key that is used to store the messages
of a single chat session.
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
"""
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)

self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection

self.session_id = session_id
self.collection_name = collection_name

@property
def messages(self) -> List[BaseMessage]:
"""Retrieve all session messages from DB"""
self.astra_env.ensure_db_setup()
message_blobs = [
doc["body_blob"]
for doc in sorted(
self.collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
),
key=lambda _doc: _doc["timestamp"],
)
]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages

@messages.setter
def messages(self, messages: List[BaseMessage]) -> None:
raise NotImplementedError("Use add_messages instead")

async def aget_messages(self) -> List[BaseMessage]:
await self.astra_env.aensure_db_setup()
docs = self.async_collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
)
sorted_docs = sorted(
[doc async for doc in docs],
key=lambda _doc: _doc["timestamp"],
)
message_blobs = [doc["body_blob"] for doc in sorted_docs]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
self.astra_env.ensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
self.collection.chunked_insert_many(docs)

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
await self.astra_env.aensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
await self.async_collection.chunked_insert_many(docs)

def clear(self) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"session_id": self.session_id})

async def aclear(self) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"session_id": self.session_id})