Skip to content

Commit

Permalink
fix: Add retriever self_query and score_threshold in DingoDB
Browse files Browse the repository at this point in the history
  • Loading branch information
HeChangHaoGary committed Feb 26, 2024
1 parent a2d5fa7 commit 2019c47
Show file tree
Hide file tree
Showing 5 changed files with 581 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/api_reference/guide_imports.json

Large diffs are not rendered by default.

496 changes: 496 additions & 0 deletions docs/docs/integrations/retrievers/self_query/dingo.ipynb

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions libs/community/langchain_community/vectorstores/dingo.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def similarity_search(
List of Documents most similar to the query and score for each
"""
docs_and_scores = self.similarity_search_with_score(
query, k=k, search_params=search_params
query, k=k, search_params=search_params, **kwargs
)
return [doc for doc, _ in docs_and_scores]

Expand Down Expand Up @@ -177,9 +177,15 @@ def similarity_search_with_score(
return []

for res in results[0]["vectorWithDistances"]:
score = res["distance"]
if (
"score_threshold" in kwargs
and kwargs.get("score_threshold") is not None
):
if score > kwargs.get("score_threshold"):
continue
metadatas = res["scalarData"]
id = res["id"]
score = res["distance"]
text = metadatas[self._text_key]["fields"][0]["data"]
metadata = {"id": id, "text": text, "score": score}
for meta_key in metadatas.keys():
Expand Down
52 changes: 27 additions & 25 deletions libs/langchain/langchain/retrievers/self_query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,37 @@
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union

from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.base import load_query_constructor_runnable
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.astradb import AstraDBTranslator
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.dingo import DingoDBTranslator
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
from langchain.retrievers.self_query.milvus import MilvusTranslator
from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.retrievers.self_query.opensearch import OpenSearchTranslator
from langchain.retrievers.self_query.pgvector import PGVectorTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.redis import RedisTranslator
from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator
from langchain.retrievers.self_query.timescalevector import TimescaleVectorTranslator
from langchain.retrievers.self_query.vectara import VectaraTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain_community.vectorstores import (
AstraDB,
Chroma,
DashVector,
DeepLake,
Dingo,
ElasticsearchStore,
Milvus,
MongoDBAtlasVectorSearch,
Expand All @@ -28,31 +54,6 @@
from langchain_core.runnables import Runnable
from langchain_core.vectorstores import VectorStore

from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.base import load_query_constructor_runnable
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.astradb import AstraDBTranslator
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
from langchain.retrievers.self_query.milvus import MilvusTranslator
from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.retrievers.self_query.opensearch import OpenSearchTranslator
from langchain.retrievers.self_query.pgvector import PGVectorTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.redis import RedisTranslator
from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator
from langchain.retrievers.self_query.timescalevector import TimescaleVectorTranslator
from langchain.retrievers.self_query.vectara import VectaraTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator

logger = logging.getLogger(__name__)

Check failure on line 57 in libs/langchain/langchain/retrievers/self_query/base.py

View workflow job for this annotation

GitHub Actions / cd libs/langchain / make lint #3.8

Ruff (I001)

langchain/retrievers/self_query/base.py:2:1: I001 Import block is un-sorted or un-formatted


Expand All @@ -64,6 +65,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
Pinecone: PineconeTranslator,
Chroma: ChromaTranslator,
DashVector: DashvectorTranslator,
Dingo: DingoDBTranslator,
Weaviate: WeaviateTranslator,
Vectara: VectaraTranslator,
Qdrant: QdrantTranslator,
Expand Down
49 changes: 49 additions & 0 deletions libs/langchain/langchain/retrievers/self_query/dingo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Dict, Tuple, Union

from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)


class DingoDBTranslator(Visitor):
"""Translate `DingoDB` internal query language elements to valid filters."""

allowed_comparators = (
Comparator.EQ,
Comparator.NE,
Comparator.LT,
Comparator.LTE,
Comparator.GT,
Comparator.GTE,
)
"""Subset of allowed logical comparators."""
allowed_operators = (Operator.AND, Operator.OR)
"""Subset of allowed logical operators."""

def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"

def visit_operation(self, operation: Operation) -> Dict:
return operation

def visit_comparison(self, comparison: Comparison) -> Dict:
return comparison

def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {
"search_params": {
"langchain_expr": structured_query.filter.accept(self)
}
}
return structured_query.query, kwargs

0 comments on commit 2019c47

Please sign in to comment.