community: Implement delete method and all async methods in opensearch_vector_search #17321

Merged
11 commits, merged Apr 3, 2024
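This PR wires an AsyncOpenSearch client into OpenSearchVectorSearch and adds async counterparts for ingestion plus new sync/async delete methods. A minimal usage sketch of the new surface, assuming a local OpenSearch node at http://localhost:9200; the URL, index name, and FakeEmbeddings stand-in are illustrative, not part of this PR:

import asyncio

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch

async def main() -> None:
    store = OpenSearchVectorSearch(
        opensearch_url="http://localhost:9200",   # hypothetical local node
        index_name="demo-index",                  # hypothetical index
        embedding_function=FakeEmbeddings(size=8),
    )
    # New in this PR: async ingestion and async deletion.
    ids = await store.aadd_texts(["foo", "bar"])
    deleted = await store.adelete(ids=ids)
    print(deleted)  # True when every bulk delete action succeeds

asyncio.run(main())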
@@ -15,6 +15,10 @@
IMPORT_OPENSEARCH_PY_ERROR = (
"Could not import OpenSearch. Please install it with `pip install opensearch-py`."
)
IMPORT_ASYNC_OPENSEARCH_PY_ERROR = """
Could not import AsyncOpenSearch.
Please install it with `pip install opensearch-py`."""

SCRIPT_SCORING_SEARCH = "script_scoring"
PAINLESS_SCRIPTING_SEARCH = "painless_scripting"
MATCH_ALL_QUERY = {"match_all": {}} # type: Dict
@@ -29,6 +33,15 @@ def _import_opensearch() -> Any:
return OpenSearch


def _import_async_opensearch() -> Any:
"""Import AsyncOpenSearch if available, otherwise raise error."""
try:
from opensearchpy import AsyncOpenSearch
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
return AsyncOpenSearch


def _import_bulk() -> Any:
"""Import bulk if available, otherwise raise error."""
try:
@@ -38,6 +51,15 @@ def _import_bulk() -> Any:
return bulk


def _import_async_bulk() -> Any:
"""Import async_bulk if available, otherwise raise error."""
try:
from opensearchpy.helpers import async_bulk
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
return async_bulk

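For reference, a minimal sketch of awaiting the helper resolved above; opensearch-py's async_bulk returns a (success_count, errors) pair. The client and actions arguments here are hypothetical placeholders:

async def demo_async_bulk(client, actions):
    # Resolve the helper lazily, mirroring how this module imports it.
    async_bulk = _import_async_bulk()
    # Returns (number of successful actions, list of per-item errors).
    return await async_bulk(client, actions)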

def _import_not_found_error() -> Any:
"""Import not found error if available, otherwise raise error."""
try:
@@ -60,6 +82,19 @@ def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
return client


def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
"""Get AsyncOpenSearch client from the opensearch_url, otherwise raise error."""
try:
async_opensearch = _import_async_opensearch()
client = async_opensearch(opensearch_url, **kwargs)
except ValueError as e:
raise ImportError(
f"AsyncOpenSearch client string provided is not in proper format. "
f"Got error: {e} "
)
return client

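A minimal sketch of building a client through this helper; the endpoint and credentials are hypothetical, and any AsyncOpenSearch connection kwargs pass straight through:

async_client = _get_async_opensearch_client(
    "https://localhost:9200",        # hypothetical endpoint
    http_auth=("admin", "admin"),    # hypothetical credentials
    use_ssl=True,
    verify_certs=False,
)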

def _validate_embeddings_and_bulk_size(embeddings_length: int, bulk_size: int) -> None:
"""Validate Embeddings Length and Bulk Size."""
if embeddings_length == 0:
@@ -141,6 +176,57 @@ def _bulk_ingest_embeddings(
return return_ids


async def _abulk_ingest_embeddings(
client: Any,
index_name: str,
embeddings: List[List[float]],
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
vector_field: str = "vector_field",
text_field: str = "text",
mapping: Optional[Dict] = None,
max_chunk_bytes: Optional[int] = 1 * 1024 * 1024,
is_aoss: bool = False,
) -> List[str]:
"""Bulk Ingest Embeddings into given index asynchronously using AsyncOpenSearch."""
if not mapping:
mapping = dict()

async_bulk = _import_async_bulk()
not_found_error = _import_not_found_error()
requests = []
return_ids = []

try:
await client.indices.get(index=index_name)
except not_found_error:
await client.indices.create(index=index_name, body=mapping)

for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
_id = ids[i] if ids else str(uuid.uuid4())
request = {
"_op_type": "index",
"_index": index_name,
vector_field: embeddings[i],
text_field: text,
"metadata": metadata,
}
if is_aoss:
request["id"] = _id
else:
request["_id"] = _id
requests.append(request)
return_ids.append(_id)

await async_bulk(client, requests, max_chunk_bytes=max_chunk_bytes)
if not is_aoss:
await client.indices.refresh(index=index_name)

return return_ids

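A minimal sketch of driving this coroutine directly, assuming an async client from the helper above. The index name, toy vectors, and metadata are illustrative; note that without an explicit mapping the index is created with an empty body, so real k-NN search needs a vector mapping:

async def demo_ingest(async_client):
    ids = await _abulk_ingest_embeddings(
        async_client,
        index_name="demo-index",               # hypothetical index
        embeddings=[[0.1, 0.2], [0.3, 0.4]],   # toy 2-d vectors
        texts=["foo", "bar"],
        metadatas=[{"source": "a"}, {"source": "b"}],
    )
    return ids  # generated UUIDs, one per document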

def _default_scripting_text_mapping(
dim: int,
vector_field: str = "vector_field",
@@ -334,6 +420,7 @@ def __init__(
http_auth = kwargs.get("http_auth")
self.is_aoss = _is_aoss_enabled(http_auth=http_auth)
self.client = _get_opensearch_client(opensearch_url, **kwargs)
self.async_client = _get_async_opensearch_client(opensearch_url, **kwargs)
self.engine = kwargs.get("engine")

@property
@@ -381,6 +468,47 @@ def __add(
is_aoss=self.is_aoss,
)

async def __aadd(
self,
texts: Iterable[str],
embeddings: List[List[float]],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
bulk_size: int = 500,
**kwargs: Any,
) -> List[str]:
_validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
index_name = kwargs.get("index_name", self.index_name)
text_field = kwargs.get("text_field", "text")
dim = len(embeddings[0])
engine = kwargs.get("engine", "nmslib")
space_type = kwargs.get("space_type", "l2")
ef_search = kwargs.get("ef_search", 512)
ef_construction = kwargs.get("ef_construction", 512)
m = kwargs.get("m", 16)
vector_field = kwargs.get("vector_field", "vector_field")
max_chunk_bytes = kwargs.get("max_chunk_bytes", 1 * 1024 * 1024)

_validate_aoss_with_engines(self.is_aoss, engine)

mapping = _default_text_mapping(
dim, engine, space_type, ef_search, ef_construction, m, vector_field
)

return await _abulk_ingest_embeddings(
self.async_client,
index_name,
embeddings,
texts,
metadatas=metadatas,
ids=ids,
vector_field=vector_field,
text_field=text_field,
mapping=mapping,
max_chunk_bytes=max_chunk_bytes,
is_aoss=self.is_aoss,
)

def add_texts(
self,
texts: Iterable[str],
@@ -417,6 +545,28 @@ def add_texts(
**kwargs,
)

async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
bulk_size: int = 500,
**kwargs: Any,
) -> List[str]:
"""
Asynchronously run more texts through the embeddings
and add to the vectorstore.
"""
embeddings = await self.embedding_function.aembed_documents(list(texts))
return await self.__aadd(
texts,
embeddings,
metadatas=metadatas,
ids=ids,
bulk_size=bulk_size,
**kwargs,
)

def add_embeddings(
self,
text_embeddings: Iterable[Tuple[str, List[float]]],
@@ -454,6 +604,49 @@ def add_embeddings(
**kwargs,
)

def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by vector ID or other criteria.

Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.

Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
if ids is None:
raise ValueError("No ids provided to delete.")

actions = [{"delete": {"_index": self.index_name, "_id": id_}} for id_ in ids]
response = self.client.bulk(actions, **kwargs)

return not any(
item.get("delete", {}).get("error") for item in response["items"]
)

async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
) -> Optional[bool]:
"""Asynchronously delete by vector ID or other criteria.

Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.

Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
if ids is None:
raise ValueError("No ids provided to delete.")

actions = [{"delete": {"_index": self.index_name, "_id": id_}} for id_ in ids]
response = await self.async_client.bulk(body=actions, **kwargs)
return not any(
item.get("delete", {}).get("error") for item in response["items"]
)

def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
@@ -544,9 +737,11 @@ def similarity_search_with_score(
(
Document(
page_content=hit["_source"][text_field],
metadata=hit["_source"]
if metadata_field == "*" or metadata_field not in hit["_source"]
else hit["_source"][metadata_field],
metadata=(
hit["_source"]
if metadata_field == "*" or metadata_field not in hit["_source"]
else hit["_source"][metadata_field]
),
),
hit["_score"],
)
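With the new delete method, a sync round-trip looks like this sketch; store is assumed to be an OpenSearchVectorSearch instance constructed as in the earlier example, and the text is a placeholder:

ids = store.add_texts(["to be removed"])
assert store.delete(ids=ids)  # True: no "error" field in any bulk item
# Note: the bulk API reports a missing id as "not_found" rather than as an
# "error", so deleting already-removed ids still returns True.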
The accompanying test change adds OpenSearchVectorSearch to the documented set of compatible vector stores:

@@ -81,5 +81,6 @@ def check_compatibility(vector_store: VectorStore) -> bool:
"ZepVectorStore",
"Zilliz",
"Lantern",
"OpenSearchVectorSearch",
}
assert compatible == documented