Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add retriever self_query and score_threshold in DingoDB #18106

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api_reference/guide_imports.json

Large diffs are not rendered by default.

496 changes: 496 additions & 0 deletions docs/docs/integrations/retrievers/self_query/dingo.ipynb

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions libs/community/langchain_community/vectorstores/dingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def similarity_search(
List of Documents most similar to the query and score for each
"""
docs_and_scores = self.similarity_search_with_score(
query, k=k, search_params=search_params
query, k=k, search_params=search_params, **kwargs
)
return [doc for doc, _ in docs_and_scores]

Expand Down Expand Up @@ -177,9 +177,15 @@ def similarity_search_with_score(
return []

for res in results[0]["vectorWithDistances"]:
score = res["distance"]
if (
"score_threshold" in kwargs
and kwargs.get("score_threshold") is not None
):
if score > kwargs.get("score_threshold"):
continue
metadatas = res["scalarData"]
id = res["id"]
score = res["distance"]
text = metadatas[self._text_key]["fields"][0]["data"]
metadata = {"id": id, "text": text, "score": score}
for meta_key in metadatas.keys():
Expand Down
3 changes: 3 additions & 0 deletions libs/langchain/langchain/retrievers/self_query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Chroma,
DashVector,
DeepLake,
Dingo,
ElasticsearchStore,
Milvus,
MongoDBAtlasVectorSearch,
Expand Down Expand Up @@ -39,6 +40,7 @@
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.dingo import DingoDBTranslator
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
from langchain.retrievers.self_query.milvus import MilvusTranslator
from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator
Expand All @@ -65,6 +67,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
Pinecone: PineconeTranslator,
Chroma: ChromaTranslator,
DashVector: DashvectorTranslator,
Dingo: DingoDBTranslator,
Weaviate: WeaviateTranslator,
Vectara: VectaraTranslator,
Qdrant: QdrantTranslator,
Expand Down
49 changes: 49 additions & 0 deletions libs/langchain/langchain/retrievers/self_query/dingo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Tuple, Union

from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)


class DingoDBTranslator(Visitor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

coudl we add some unit tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course, I added some unit tests

"""Translate `DingoDB` internal query language elements to valid filters."""

allowed_comparators = (
Comparator.EQ,
Comparator.NE,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
)
"""Subset of allowed logical comparators."""
allowed_operators = (Operator.AND, Operator.OR)
"""Subset of allowed logical operators."""

def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"

def visit_operation(self, operation: Operation) -> Operation:
return operation

def visit_comparison(self, comparison: Comparison) -> Comparison:
return comparison

def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {
"search_params": {
"langchain_expr": structured_query.filter.accept(self)
}
}
return structured_query.query, kwargs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Dict, Tuple

from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
)
from langchain.retrievers.self_query.dingo import DingoDBTranslator

DEFAULT_TRANSLATOR = DingoDBTranslator()


def test_visit_comparison() -> None:
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
expected = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
assert expected == actual


def test_visit_operation() -> None:
op = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
],
)
expected = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
],
)
actual = DEFAULT_TRANSLATOR.visit_operation(op)
assert expected == actual


def test_visit_structured_query() -> None:
query = "What is the capital of France?"

structured_query = StructuredQuery(
query=query,
filter=None,
)
expected: Tuple[str, Dict] = (query, {})
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual

comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
structured_query = StructuredQuery(
query=query,
filter=comp,
)
expected = (
query,
{
"search_params": {
"langchain_expr": Comparison(
comparator=Comparator.LT, attribute="foo", value=["1", "2"]
)
}
},
)
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual

op = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
],
)
structured_query = StructuredQuery(
query=query,
filter=op,
)
expected = (
query,
{
"search_params": {
"langchain_expr": Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
Comparison(
comparator=Comparator.EQ, attribute="bar", value="baz"
),
],
)
}
},
)

actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual