Add docstring to AstraDBStore
cbornet committed Feb 20, 2024
1 parent 865cabf commit 950a415
Showing 1 changed file with 108 additions and 49 deletions.
libs/community/langchain_community/storage/astradb.py: 157 changes (108 additions & 49 deletions)
@@ -31,28 +31,8 @@
class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
"""Base class for the DataStax AstraDB data store."""

def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
*,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection

@@ -65,15 +45,13 @@ def encode_value(self, value: Optional[V]) -> Any:
"""Encodes value for Astra DB"""

def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
"""Get the values associated with the given keys."""
self.astra_env.ensure_db_setup()
docs_dict = {}
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]

async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
"""Get the values associated with the given keys."""
await self.astra_env.aensure_db_setup()
docs_dict = {}
async for doc in self.async_collection.paginated_find(
@@ -83,31 +61,26 @@ async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
return [self.decode_value(docs_dict.get(key)) for key in keys]

def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
"""Set the given key-value pairs."""
self.astra_env.ensure_db_setup()
for k, v in key_value_pairs:
self.collection.upsert({"_id": k, "value": self.encode_value(v)})

async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
"""Set the given key-value pairs."""
await self.astra_env.aensure_db_setup()
for k, v in key_value_pairs:
await self.async_collection.upsert(
{"_id": k, "value": self.encode_value(v)}
)

def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys."""
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"_id": {"$in": list(keys)}})

async def amdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys."""
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})

def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
"""Yield keys in the store."""
self.astra_env.ensure_db_setup()
docs = self.collection.paginated_find()
for doc in docs:
@@ -116,7 +89,6 @@ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
yield key

async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
"""Yield keys in the store."""
await self.astra_env.aensure_db_setup()
async for doc in self.async_collection.paginated_find():
key = doc["_id"]
@@ -125,16 +97,60 @@ async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:


class AstraDBStore(AstraDBBaseStore[Any]):
"""BaseStore implementation using DataStax AstraDB as the underlying store.
The value type can be any type serializable by json.dumps.
Can be used to store embeddings with the CacheBackedEmbeddings.
Documents in the AstraDB collection will have the format
{
"_id": "<key>",
"value": <value>
}
"""
def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
*,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""BaseStore implementation using DataStax AstraDB as the underlying store.
The value type can be any type serializable by json.dumps.
Can be used to store embeddings with CacheBackedEmbeddings.
Documents in the AstraDB collection will have the format:
.. code-block:: json
{
"_id": "<key>",
"value": <value>
}
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
# Constructor doc is not inherited so we have to override it.
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)

def decode_value(self, value: Any) -> Any:
return value
@@ -144,15 +160,58 @@ def encode_value(self, value: Any) -> Any:


class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
"""ByteStore implementation using DataStax AstraDB as the underlying store.
The bytes values are converted to base64 encoded strings
Documents in the AstraDB collection will have the format
{
"_id": "<key>",
"value": "<byte64 string value>"
}
"""
def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
*,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""ByteStore implementation using DataStax AstraDB as the underlying store.
The bytes values are converted to base64-encoded strings.
Documents in the AstraDB collection will have the format:
.. code-block:: json
{
"_id": "<key>",
"value": "<byte64 string value>"
}
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
# Constructor doc is not inherited so we have to override it.
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)

def decode_value(self, value: Any) -> Optional[bytes]:
if value is None:
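
For reference, a minimal usage sketch of the synchronous AstraDBStore API documented in the new docstring above. The collection name, token, and API endpoint are placeholders, not values from this commit.

from langchain_community.storage.astradb import AstraDBStore

# Placeholder credentials -- substitute your own Astra DB token and API endpoint.
store = AstraDBStore(
    collection_name="my_kv_store",
    token="AstraCS:...",
    api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
)

# Per the docstring, values can be anything json.dumps can serialize.
store.mset([("key1", "value1"), ("key2", {"nested": [1, 2, 3]})])
print(store.mget(["key1", "key2"]))          # -> ['value1', {'nested': [1, 2, 3]}]
print(list(store.yield_keys(prefix="key")))  # keys starting with "key"
store.mdelete(["key1", "key2"])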

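The base class in this diff also defines async counterparts (amset, amget, amdelete, ayield_keys). A sketch of their use with the same placeholder credentials; the docstring notes a SetupMode.ASYNC option for collection creation, but since the SetupMode import is not shown in this hunk, setup_mode is left at its default here.

import asyncio

from langchain_community.storage.astradb import AstraDBStore


async def main() -> None:
    # Placeholder credentials -- substitute your own Astra DB token and API endpoint.
    store = AstraDBStore(
        collection_name="my_kv_store",
        token="AstraCS:...",
        api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
    )
    await store.amset([("key1", "value1")])
    print(await store.amget(["key1"]))  # -> ['value1']
    await store.amdelete(["key1"])


asyncio.run(main())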

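The docstrings mention that the store can back CacheBackedEmbeddings; AstraDBByteStore is the ByteStore variant suited to that. A sketch assuming CacheBackedEmbeddings and FakeEmbeddings are importable from the langchain and langchain_community packages of this period (neither is part of this diff); the collection name, namespace string, and credentials are placeholders, and FakeEmbeddings merely stands in for a real Embeddings implementation.

from langchain.embeddings import CacheBackedEmbeddings
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.storage.astradb import AstraDBByteStore

# Placeholder credentials -- substitute your own Astra DB token and API endpoint.
byte_store = AstraDBByteStore(
    collection_name="embedding_cache",
    token="AstraCS:...",
    api_endpoint="https://<DB-ID>-us-east1.apps.astra.datastax.com",
)

# Wrap any Embeddings implementation with the byte store as its cache.
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    FakeEmbeddings(size=1536),
    byte_store,
    namespace="fake-embeddings",
)

vectors = cached_embedder.embed_documents(["hello astra"])
# A repeat call for the same text is served from the Astra DB-backed cache.
vectors_again = cached_embedder.embed_documents(["hello astra"])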