-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: Add Dria retriever (#17098)
[Dria](https://dria.co/) is a hub of public RAG models for developers to both contribute and utilize a shared embedding lake. This PR adds a retriever that can retrieve documents from Dria.
- Loading branch information
1 parent
0b0a551
commit 4384fa8
Showing
8 changed files
with
418 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "UYyFIEKEkmHb" | ||
}, | ||
"source": [ | ||
"# Dria\n", | ||
"\n", | ||
"Dria is a hub of public RAG models for developers to both contribute and utilize a shared embedding lake. This notebook demonstrates how to use the Dria API for data retrieval tasks." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "VNTFUgK9kmHd" | ||
}, | ||
"source": [ | ||
"# Installation\n", | ||
"\n", | ||
"Ensure you have the `dria` package installed. You can install it using pip:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "X--1A8EEkmHd" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install --upgrade --quiet dria" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "xRbRL0SgkmHe" | ||
}, | ||
"source": [ | ||
"# Configure API Key\n", | ||
"\n", | ||
"Set up your Dria API key for access." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"id": "hGqOByNMkmHe" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"os.environ[\"DRIA_API_KEY\"] = \"DRIA_API_KEY\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "nDfAEqQtkmHe" | ||
}, | ||
"source": [ | ||
"# Initialize Dria Retriever\n", | ||
"\n", | ||
"Create an instance of `DriaRetriever`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "vlyorgCckmHe" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.retrievers import DriaRetriever\n", | ||
"\n", | ||
"api_key = os.getenv(\"DRIA_API_KEY\")\n", | ||
"retriever = DriaRetriever(api_key=api_key)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "j7WUY5jBOLQd" | ||
}, | ||
"source": [ | ||
"# **Create Knowledge Base**\n", | ||
"\n", | ||
"Create a knowledge on [Dria's Knowledge Hub](https://dria.co/knowledge)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "L5ER81eWOKnt" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"contract_id = retriever.create_knowledge_base(\n", | ||
" name=\"France's AI Development\",\n", | ||
" embedding=DriaRetriever.models.jina_embeddings_v2_base_en.value,\n", | ||
" category=\"Artificial Intelligence\",\n", | ||
" description=\"Explore the growth and contributions of France in the field of Artificial Intelligence.\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "9VCTzSFpkmHe" | ||
}, | ||
"source": [ | ||
"# Add Data\n", | ||
"\n", | ||
"Load data into your Dria knowledge base." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "xeTMafIekmHf" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"texts = [\n", | ||
" \"The first text to add to Dria.\",\n", | ||
" \"Another piece of information to store.\",\n", | ||
" \"More data to include in the Dria knowledge base.\",\n", | ||
"]\n", | ||
"\n", | ||
"ids = retriever.add_texts(texts)\n", | ||
"print(\"Data added with IDs:\", ids)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "dy1UlvLCkmHf" | ||
}, | ||
"source": [ | ||
"# Retrieve Data\n", | ||
"\n", | ||
"Use the retriever to find relevant documents given a query." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "9y3msv9tkmHf" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"query = \"Find information about Dria.\"\n", | ||
"result = retriever.get_relevant_documents(query)\n", | ||
"for doc in result:\n", | ||
" print(doc)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.x" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
libs/community/langchain_community/retrievers/dria_index.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Wrapper around Dria Retriever.""" | ||
|
||
from typing import Any, List, Optional | ||
|
||
from langchain_core.callbacks import CallbackManagerForRetrieverRun | ||
from langchain_core.documents import Document | ||
from langchain_core.retrievers import BaseRetriever | ||
|
||
from langchain_community.utilities import DriaAPIWrapper | ||
|
||
|
||
class DriaRetriever(BaseRetriever): | ||
"""`Dria` retriever using the DriaAPIWrapper.""" | ||
|
||
api_wrapper: DriaAPIWrapper | ||
|
||
def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any): | ||
""" | ||
Initialize the DriaRetriever with a DriaAPIWrapper instance. | ||
Args: | ||
api_key: The API key for Dria. | ||
contract_id: The contract ID of the knowledge base to interact with. | ||
""" | ||
api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id) | ||
super().__init__(api_wrapper=api_wrapper, **kwargs) | ||
|
||
def create_knowledge_base( | ||
self, | ||
name: str, | ||
description: str, | ||
category: str = "Unspecified", | ||
embedding: str = "jina", | ||
) -> str: | ||
"""Create a new knowledge base in Dria. | ||
Args: | ||
name: The name of the knowledge base. | ||
description: The description of the knowledge base. | ||
category: The category of the knowledge base. | ||
embedding: The embedding model to use for the knowledge base. | ||
Returns: | ||
The ID of the created knowledge base. | ||
""" | ||
response = self.api_wrapper.create_knowledge_base( | ||
name, description, category, embedding | ||
) | ||
return response | ||
|
||
def add_texts( | ||
self, | ||
texts: List, | ||
) -> None: | ||
"""Add texts to the Dria knowledge base. | ||
Args: | ||
texts: An iterable of texts and metadatas to add to the knowledge base. | ||
Returns: | ||
List of IDs representing the added texts. | ||
""" | ||
data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts] | ||
self.api_wrapper.insert_data(data) | ||
|
||
def _get_relevant_documents( | ||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | ||
) -> List[Document]: | ||
"""Retrieve relevant documents from Dria based on a query. | ||
Args: | ||
query: The query string to search for in the knowledge base. | ||
run_manager: Callback manager for the retriever run. | ||
Returns: | ||
A list of Documents containing the search results. | ||
""" | ||
results = self.api_wrapper.search(query) | ||
docs = [ | ||
Document( | ||
page_content=result["metadata"], | ||
metadata={"id": result["id"], "score": result["score"]}, | ||
) | ||
for result in results | ||
] | ||
return docs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
libs/community/langchain_community/utilities/dria_index.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import logging | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class DriaAPIWrapper: | ||
"""Wrapper around Dria API. | ||
This wrapper facilitates interactions with Dria's vector search | ||
and retrieval services, including creating knowledge bases, inserting data, | ||
and fetching search results. | ||
Attributes: | ||
api_key: Your API key for accessing Dria. | ||
contract_id: The contract ID of the knowledge base to interact with. | ||
top_n: Number of top results to fetch for a search. | ||
""" | ||
|
||
def __init__( | ||
self, api_key: str, contract_id: Optional[str] = None, top_n: int = 10 | ||
): | ||
try: | ||
from dria import Dria, Models | ||
except ImportError: | ||
logger.error( | ||
"""Dria is not installed. Please install Dria to use this wrapper. | ||
You can install Dria using the following command: | ||
pip install dria | ||
""" | ||
) | ||
return | ||
|
||
self.api_key = api_key | ||
self.models = Models | ||
self.contract_id = contract_id | ||
self.top_n = top_n | ||
self.dria_client = Dria(api_key=self.api_key) | ||
if self.contract_id: | ||
self.dria_client.set_contract(self.contract_id) | ||
|
||
def create_knowledge_base( | ||
self, | ||
name: str, | ||
description: str, | ||
category: str, | ||
embedding: str, | ||
) -> str: | ||
"""Create a new knowledge base.""" | ||
contract_id = self.dria_client.create( | ||
name=name, embedding=embedding, category=category, description=description | ||
) | ||
logger.info(f"Knowledge base created with ID: {contract_id}") | ||
self.contract_id = contract_id | ||
return contract_id | ||
|
||
def insert_data(self, data: List[Dict[str, Any]]) -> str: | ||
"""Insert data into the knowledge base.""" | ||
response = self.dria_client.insert_text(data) | ||
logger.info(f"Data inserted: {response}") | ||
return response | ||
|
||
def search(self, query: str) -> List[Dict[str, Any]]: | ||
"""Perform a text-based search.""" | ||
results = self.dria_client.search(query, top_n=self.top_n) | ||
logger.info(f"Search results: {results}") | ||
return results | ||
|
||
def query_with_vector(self, vector: List[float]) -> List[Dict[str, Any]]: | ||
"""Perform a vector-based query.""" | ||
vector_query_results = self.dria_client.query(vector, top_n=self.top_n) | ||
logger.info(f"Vector query results: {vector_query_results}") | ||
return vector_query_results | ||
|
||
def run(self, query: Union[str, List[float]]) -> Optional[List[Dict[str, Any]]]: | ||
"""Method to handle both text-based searches and vector-based queries. | ||
Args: | ||
query: A string for text-based search or a list of floats for | ||
vector-based query. | ||
Returns: | ||
The search or query results from Dria. | ||
""" | ||
if isinstance(query, str): | ||
return self.search(query) | ||
elif isinstance(query, list) and all(isinstance(item, float) for item in query): | ||
return self.query_with_vector(query) | ||
else: | ||
logger.error( | ||
"""Invalid query type. Please provide a string for text search or a | ||
list of floats for vector query.""" | ||
) | ||
return None |
Oops, something went wrong.