Skip to content

Commit

Permalink
Add AstraDBStore to langchain-astradb package
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Feb 20, 2024
1 parent 865cabf commit 54df796
Show file tree
Hide file tree
Showing 6 changed files with 488 additions and 0 deletions.
11 changes: 11 additions & 0 deletions libs/community/langchain_community/storage/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TypeVar,
)

from langchain_core._api.deprecation import deprecated
from langchain_core.stores import BaseStore, ByteStore

from langchain_community.utilities.astradb import (
Expand Down Expand Up @@ -124,6 +125,11 @@ async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[st
yield key


@deprecated(
since="0.1.24",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBStore",
)
class AstraDBStore(AstraDBBaseStore[Any]):
"""BaseStore implementation using DataStax AstraDB as the underlying store.
Expand All @@ -143,6 +149,11 @@ def encode_value(self, value: Any) -> Any:
return value


@deprecated(
since="0.1.24",
removal="0.2.0",
alternative_import="langchain_astradb.AstraDBByteStore",
)
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
"""ByteStore implementation using DataStax AstraDB as the underlying store.
Expand Down
3 changes: 3 additions & 0 deletions libs/partners/astradb/langchain_astradb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.vectorstores import AstraDBVectorStore

__all__ = [
"AstraDBByteStore",
"AstraDBStore",
"AstraDBVectorStore",
]
154 changes: 154 additions & 0 deletions libs/partners/astradb/langchain_astradb/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from __future__ import annotations

import base64
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)

from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.stores import BaseStore, ByteStore

from langchain_astradb.utils.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)

V = TypeVar("V")


class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
"""Base class for the DataStax AstraDB data store."""

def __init__(
self,
collection_name: str,
*,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection

@abstractmethod
def decode_value(self, value: Any) -> Optional[V]:
"""Decodes value from Astra DB"""

@abstractmethod
def encode_value(self, value: Optional[V]) -> Any:
"""Encodes value for Astra DB"""

def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
self.astra_env.ensure_db_setup()
docs_dict = {}
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]

async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
await self.astra_env.aensure_db_setup()
docs_dict = {}
async for doc in self.async_collection.paginated_find(
filter={"_id": {"$in": list(keys)}}
):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]

def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
self.astra_env.ensure_db_setup()
for k, v in key_value_pairs:
self.collection.upsert({"_id": k, "value": self.encode_value(v)})

async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
await self.astra_env.aensure_db_setup()
for k, v in key_value_pairs:
await self.async_collection.upsert(
{"_id": k, "value": self.encode_value(v)}
)

def mdelete(self, keys: Sequence[str]) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"_id": {"$in": list(keys)}})

async def amdelete(self, keys: Sequence[str]) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})

def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
self.astra_env.ensure_db_setup()
docs = self.collection.paginated_find()
for doc in docs:
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key

async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
await self.astra_env.aensure_db_setup()
async for doc in self.async_collection.paginated_find():
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key


class AstraDBStore(AstraDBBaseStore[Any]):
"""BaseStore implementation using DataStax AstraDB as the underlying store.
The value type can be any type serializable by json.dumps.
Can be used to store embeddings with the CacheBackedEmbeddings.
Documents in the AstraDB collection will have the format
{
"_id": "<key>",
"value": <value>
}
"""

def decode_value(self, value: Any) -> Any:
return value

def encode_value(self, value: Any) -> Any:
return value


class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
"""ByteStore implementation using DataStax AstraDB as the underlying store.
The bytes values are converted to base64 encoded strings
Documents in the AstraDB collection will have the format
{
"_id": "<key>",
"value": "<byte64 string value>"
}
"""

def decode_value(self, value: Any) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value)

def encode_value(self, value: Optional[bytes]) -> Any:
if value is None:
return None
return base64.b64encode(value).decode("ascii")
142 changes: 142 additions & 0 deletions libs/partners/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

import asyncio
import inspect
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import Awaitable, Optional, Union

from astrapy.db import AstraDB, AsyncAstraDB


class SetupMode(Enum):
SYNC = 1
ASYNC = 2
OFF = 3


class _AstraDBEnvironment:
def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
) -> None:
self.token = token
self.api_endpoint = api_endpoint
astra_db = astra_db_client
async_astra_db = async_astra_db_client
self.namespace = namespace

# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDBEnvironment if passing 'token' and 'api_endpoint'."
)

if token and api_endpoint:
astra_db = AstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=self.namespace,
)
async_astra_db = AsyncAstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=self.namespace,
)

if astra_db:
self.astra_db = astra_db
if async_astra_db:
self.async_astra_db = async_astra_db
else:
self.async_astra_db = self.astra_db.to_async()
elif async_astra_db:
self.async_astra_db = async_astra_db
self.astra_db = self.async_astra_db.to_sync()
else:
raise ValueError(
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
"'token' and 'api_endpoint'"
)


class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding_dimension: Union[int, Awaitable[int], None] = None,
metric: Optional[str] = None,
) -> None:
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

super().__init__(
token, api_endpoint, astra_db_client, async_astra_db_client, namespace
)
self.collection_name = collection_name
self.collection = AstraDBCollection(
collection_name=collection_name,
astra_db=self.astra_db,
)

self.async_collection = AsyncAstraDBCollection(
collection_name=collection_name,
astra_db=self.async_astra_db,
)

self.async_setup_db_task: Optional[Task] = None
if setup_mode == SetupMode.ASYNC:
async_astra_db = self.async_astra_db

async def _setup_db() -> None:
if pre_delete_collection:
await async_astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
dimension = await embedding_dimension
else:
dimension = embedding_dimension
await async_astra_db.create_collection(
collection_name, dimension=dimension, metric=metric
)

self.async_setup_db_task = asyncio.create_task(_setup_db())
elif setup_mode == SetupMode.SYNC:
if pre_delete_collection:
self.astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
raise ValueError(
"Cannot use an awaitable embedding_dimension with async_setup "
"set to False"
)
self.astra_db.create_collection(
collection_name,
dimension=embedding_dimension, # type: ignore[arg-type]
metric=metric,
)

def ensure_db_setup(self) -> None:
if self.async_setup_db_task:
try:
self.async_setup_db_task.result()
except InvalidStateError:
raise ValueError(
"Asynchronous setup of the DB not finished. "
"NB: AstraDB components sync methods shouldn't be called from the "
"event loop. Consider using their async equivalents."
)

async def aensure_db_setup(self) -> None:
if self.async_setup_db_task:
await self.async_setup_db_task

0 comments on commit 54df796

Please sign in to comment.