Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: add possibility to search by vector in OpenSearchVectorSearch #17878

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,15 @@ def similarity_search(
docs_with_scores = self.similarity_search_with_score(query, k, **kwargs)
return [doc[0] for doc in docs_with_scores]

def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to the embedding vector."""
docs_with_scores = self.similarity_search_with_score_by_vector(
embedding, k, **kwargs
)
return [doc[0] for doc in docs_with_scores]

def similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
Expand All @@ -534,46 +543,69 @@ def similarity_search_with_score(
Optional Args:
same as `similarity_search`
"""
embedding = self.embedding_function.embed_query(query)
return self.similarity_search_with_score_by_vector(embedding, k, **kwargs)

def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Return docs and it's scores most similar to the embedding vector.

By default, supports Approximate Search.
Also supports Script Scoring and Painless Scripting.

Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.

Returns:
List of Documents along with its scores most similar to the query.

Optional Args:
same as `similarity_search`
"""
text_field = kwargs.get("text_field", "text")
metadata_field = kwargs.get("metadata_field", "metadata")

hits = self._raw_similarity_search_with_score(query=query, k=k, **kwargs)
hits = self._raw_similarity_search_with_score_by_vector(
embedding=embedding, k=k, **kwargs
)

documents_with_scores = [
(
Document(
page_content=hit["_source"][text_field],
metadata=hit["_source"]
if metadata_field == "*" or metadata_field not in hit["_source"]
else hit["_source"][metadata_field],
metadata=(
hit["_source"]
if metadata_field == "*" or metadata_field not in hit["_source"]
else hit["_source"][metadata_field]
),
),
hit["_score"],
)
for hit in hits
]
return documents_with_scores

def _raw_similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
def _raw_similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[dict]:
"""Return raw opensearch documents (dict) including vectors,
scores most similar to query.
scores most similar to the embedding vector.

By default, supports Approximate Search.
Also supports Script Scoring and Painless Scripting.

Args:
query: Text to look up documents similar to.
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.

Returns:
List of dict with its scores most similar to the query.
List of dict with its scores most similar to the embedding.

Optional Args:
same as `similarity_search`
"""
embedding = self.embedding_function.embed_query(query)
search_type = kwargs.get("search_type", "approximate_search")
vector_field = kwargs.get("vector_field", "vector_field")
index_name = kwargs.get("index_name", self.index_name)
Expand Down Expand Up @@ -702,7 +734,9 @@ def max_marginal_relevance_search(
embedding = self.embedding_function.embed_query(query)

# Do ANN/KNN search to get top fetch_k results where fetch_k >= k
results = self._raw_similarity_search_with_score(query, fetch_k, **kwargs)
results = self._raw_similarity_search_with_score_by_vector(
embedding, fetch_k, **kwargs
)

embeddings = [result["_source"][vector_field] for result in results]

Expand Down