community[minor]: Add Dria retriever (#17098)
[Dria](https://dria.co/) is a hub of public RAG models for developers to
both contribute and utilize a shared embedding lake. This PR adds a
retriever that can retrieve documents from Dria.
anilaltuner committed Apr 1, 2024
1 parent 0b0a551 commit 4384fa8
Showing 8 changed files with 418 additions and 0 deletions.
191 changes: 191 additions & 0 deletions docs/docs/integrations/retrievers/dria_index.ipynb
@@ -0,0 +1,191 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "UYyFIEKEkmHb"
},
"source": [
"# Dria\n",
"\n",
"Dria is a hub of public RAG models for developers to both contribute and utilize a shared embedding lake. This notebook demonstrates how to use the Dria API for data retrieval tasks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VNTFUgK9kmHd"
},
"source": [
"# Installation\n",
"\n",
"Ensure you have the `dria` package installed. You can install it using pip:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X--1A8EEkmHd"
},
"outputs": [],
"source": [
"%pip install --upgrade --quiet dria"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xRbRL0SgkmHe"
},
"source": [
"# Configure API Key\n",
"\n",
"Set up your Dria API key for access."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "hGqOByNMkmHe"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"DRIA_API_KEY\"] = \"DRIA_API_KEY\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nDfAEqQtkmHe"
},
"source": [
"# Initialize Dria Retriever\n",
"\n",
"Create an instance of `DriaRetriever`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vlyorgCckmHe"
},
"outputs": [],
"source": [
"from langchain_community.retrievers import DriaRetriever\n",
"\n",
"api_key = os.getenv(\"DRIA_API_KEY\")\n",
"retriever = DriaRetriever(api_key=api_key)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j7WUY5jBOLQd"
},
"source": [
"# Create Knowledge Base\n",
"\n",
"Create a knowledge base on [Dria's Knowledge Hub](https://dria.co/knowledge)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L5ER81eWOKnt"
},
"outputs": [],
"source": [
"contract_id = retriever.create_knowledge_base(\n",
" name=\"France's AI Development\",\n",
" embedding=retriever.api_wrapper.models.jina_embeddings_v2_base_en.value,\n",
" category=\"Artificial Intelligence\",\n",
" description=\"Explore the growth and contributions of France in the field of Artificial Intelligence.\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9VCTzSFpkmHe"
},
"source": [
"# Add Data\n",
"\n",
"Load data into your Dria knowledge base."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xeTMafIekmHf"
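},
"outputs": [],
"source": [
"# add_texts expects dicts with \"text\" and \"metadata\" keys; the metadata values are placeholders.\n",
"texts = [\n",
" {\"text\": \"The first text to add to Dria.\", \"metadata\": {\"id\": 1}},\n",
" {\"text\": \"Another piece of information to store.\", \"metadata\": {\"id\": 2}},\n",
" {\"text\": \"More data to include in the Dria knowledge base.\", \"metadata\": {\"id\": 3}},\n",
"]\n",
"\n",
"retriever.add_texts(texts)\n",
"print(\"Data added to the knowledge base.\")"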
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dy1UlvLCkmHf"
},
"source": [
"# Retrieve Data\n",
"\n",
"Use the retriever to find relevant documents given a query."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9y3msv9tkmHf"
},
"outputs": [],
"source": [
"query = \"Find information about Dria.\"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.x"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
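Note on the notebook's "Add Data" step: `DriaRetriever.add_texts` in this commit reads `text["text"]` and `text["metadata"]` from every item, so it needs a list of dicts rather than bare strings. The following is a minimal sketch of a conversion helper under that assumption. The name `to_dria_records` is hypothetical and is not part of this commit or of the `dria` package.

```python
from typing import Any, Dict, List, Optional


def to_dria_records(
    texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None
) -> List[Dict[str, Any]]:
    """Pair each text with a metadata dict (hypothetical helper, not in the commit)."""
    if metadatas is None:
        # Default to one empty metadata dict per text.
        metadatas = [{} for _ in texts]
    if len(metadatas) != len(texts):
        raise ValueError("texts and metadatas must have the same length")
    return [{"text": t, "metadata": m} for t, m in zip(texts, metadatas)]


records = to_dria_records(
    ["The first text to add to Dria.", "Another piece of information to store."],
    [{"id": 1}, {"id": 2}],
)
print(records[0])
```

The resulting `records` list has the shape that `add_texts` indexes into. Whether the Dria service accepts an empty metadata dict is not confirmed by this commit.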
1 change: 1 addition & 0 deletions libs/community/langchain_community/retrievers/__init__.py
@@ -33,6 +33,7 @@
"ChatGPTPluginRetriever": "langchain_community.retrievers.chatgpt_plugin_retriever",
"CohereRagRetriever": "langchain_community.retrievers.cohere_rag_retriever",
"DocArrayRetriever": "langchain_community.retrievers.docarray",
"DriaRetriever": "langchain_community.retrievers.dria_index",
"ElasticSearchBM25Retriever": "langchain_community.retrievers.elastic_search_bm25",
"EmbedchainRetriever": "langchain_community.retrievers.embedchain",
"GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers.google_vertex_ai_search", # noqa: E501
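The hunk above only adds one name-to-module entry. How that mapping is consumed is not visible in this diff. A common pattern, assumed here, is a module-level `__getattr__` that imports the target module on first access. The sketch below uses standard-library modules as stand-ins so it runs anywhere; `_module_lookup` and `lazy_getattr` are illustrative names.

```python
import importlib
from typing import Any, Dict

# Stand-in mapping: the real one would point at langchain_community submodules.
_module_lookup: Dict[str, str] = {
    "OrderedDict": "collections",
    "JSONDecoder": "json",
}


def lazy_getattr(name: str) -> Any:
    """Import the module registered for `name` on demand and return the attribute."""
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"no attribute {name!r} registered")


cls = lazy_getattr("OrderedDict")
print(cls.__name__)
```

With this design, an optional dependency is imported only when the corresponding class is first requested.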
87 changes: 87 additions & 0 deletions libs/community/langchain_community/retrievers/dria_index.py
@@ -0,0 +1,87 @@
"""Wrapper around Dria Retriever."""

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_community.utilities import DriaAPIWrapper


class DriaRetriever(BaseRetriever):
    """`Dria` retriever using the DriaAPIWrapper."""

    api_wrapper: DriaAPIWrapper

    def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
        """
        Initialize the DriaRetriever with a DriaAPIWrapper instance.

        Args:
            api_key: The API key for Dria.
            contract_id: The contract ID of the knowledge base to interact with.
        """
        api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
        super().__init__(api_wrapper=api_wrapper, **kwargs)

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str = "Unspecified",
        embedding: str = "jina",
    ) -> str:
        """Create a new knowledge base in Dria.

        Args:
            name: The name of the knowledge base.
            description: The description of the knowledge base.
            category: The category of the knowledge base.
            embedding: The embedding model to use for the knowledge base.

        Returns:
            The ID of the created knowledge base.
        """
        response = self.api_wrapper.create_knowledge_base(
            name, description, category, embedding
        )
        return response

    def add_texts(
        self,
        texts: List[Dict[str, Any]],
    ) -> None:
        """Add texts to the Dria knowledge base.

        Args:
            texts: A list of dicts, each with a "text" key and a "metadata" key,
                to add to the knowledge base.
        """
        # Keep only the two keys that are forwarded to the wrapper.
        data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts]
        self.api_wrapper.insert_data(data)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents from Dria based on a query.

        Args:
            query: The query string to search for in the knowledge base.
            run_manager: Callback manager for the retriever run.

        Returns:
            A list of Documents containing the search results.
        """
        results = self.api_wrapper.search(query)
        docs = [
            Document(
                page_content=result["metadata"],
                metadata={"id": result["id"], "score": result["score"]},
            )
            for result in results
        ]
        return docs
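The mapping in `_get_relevant_documents` can be tried without a Dria account. The sketch below assumes each search result is a dict with `id`, `metadata`, and `score` keys, which is inferred only from the keys the method reads. `SimpleDoc` is a stand-in for `langchain_core.documents.Document`, and the sample results are made up for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class SimpleDoc:
    """Stand-in for langchain_core.documents.Document (assumption for this sketch)."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def results_to_docs(results: List[Dict[str, Any]]) -> List[SimpleDoc]:
    """Mirror the retriever's mapping: text goes to page_content, id and score to metadata."""
    return [
        SimpleDoc(
            # str() guards against a non-string payload; the commit passes the value through as-is.
            page_content=str(r["metadata"]),
            metadata={"id": r["id"], "score": r["score"]},
        )
        for r in results
    ]


# Made-up results in the assumed shape.
fake_results = [
    {"id": 7, "metadata": "France hosts several AI research labs.", "score": 0.91},
    {"id": 3, "metadata": "Dria stores embeddings for public knowledge.", "score": 0.84},
]
docs = results_to_docs(fake_results)
print(docs[0].page_content, docs[0].metadata)
```

The `str()` coercion is an addition in this sketch. If the `metadata` field returned by Dria is ever not a plain string, the retriever in the commit would pass it to `Document` unchanged.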
1 change: 1 addition & 0 deletions libs/community/langchain_community/utilities/__init__.py
@@ -15,6 +15,7 @@
"BibtexparserWrapper": "langchain_community.utilities.bibtex",
"BingSearchAPIWrapper": "langchain_community.utilities.bing_search",
"BraveSearchWrapper": "langchain_community.utilities.brave_search",
"DriaAPIWrapper": "langchain_community.utilities.dria_index",
"DuckDuckGoSearchAPIWrapper": "langchain_community.utilities.duckduckgo_search",
"GoldenQueryAPIWrapper": "langchain_community.utilities.golden_query",
"GoogleFinanceAPIWrapper": "langchain_community.utilities.google_finance",
95 changes: 95 additions & 0 deletions libs/community/langchain_community/utilities/dria_index.py
@@ -0,0 +1,95 @@
import logging
from typing import Any, Dict, List, Optional, Union

logger = logging.getLogger(__name__)


class DriaAPIWrapper:
    """Wrapper around Dria API.

    This wrapper facilitates interactions with Dria's vector search
    and retrieval services, including creating knowledge bases, inserting data,
    and fetching search results.

    Attributes:
        api_key: Your API key for accessing Dria.
        contract_id: The contract ID of the knowledge base to interact with.
        top_n: Number of top results to fetch for a search.
    """

    def __init__(
        self, api_key: str, contract_id: Optional[str] = None, top_n: int = 10
    ):
        try:
            from dria import Dria, Models
        except ImportError as e:
            # Raise instead of logging and returning: a bare return would leave
            # the instance without `dria_client` and fail later with AttributeError.
            raise ImportError(
                "Dria is not installed. Please install Dria to use this wrapper. "
                "You can install Dria using the following command: pip install dria"
            ) from e

        self.api_key = api_key
        self.models = Models
        self.contract_id = contract_id
        self.top_n = top_n
        self.dria_client = Dria(api_key=self.api_key)
        if self.contract_id:
            self.dria_client.set_contract(self.contract_id)

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str,
        embedding: str,
    ) -> str:
        """Create a new knowledge base."""
        contract_id = self.dria_client.create(
            name=name, embedding=embedding, category=category, description=description
        )
        logger.info(f"Knowledge base created with ID: {contract_id}")
        self.contract_id = contract_id
        return contract_id

    def insert_data(self, data: List[Dict[str, Any]]) -> str:
        """Insert data into the knowledge base."""
        response = self.dria_client.insert_text(data)
        logger.info(f"Data inserted: {response}")
        return response

    def search(self, query: str) -> List[Dict[str, Any]]:
        """Perform a text-based search."""
        results = self.dria_client.search(query, top_n=self.top_n)
        logger.info(f"Search results: {results}")
        return results

    def query_with_vector(self, vector: List[float]) -> List[Dict[str, Any]]:
        """Perform a vector-based query."""
        vector_query_results = self.dria_client.query(vector, top_n=self.top_n)
        logger.info(f"Vector query results: {vector_query_results}")
        return vector_query_results

    def run(self, query: Union[str, List[float]]) -> Optional[List[Dict[str, Any]]]:
        """Method to handle both text-based searches and vector-based queries.

        Args:
            query: A string for text-based search or a list of floats for
                vector-based query.

        Returns:
            The search or query results from Dria, or None if the query type
            is not supported.
        """
        if isinstance(query, str):
            return self.search(query)
        elif isinstance(query, list) and all(isinstance(item, float) for item in query):
            return self.query_with_vector(query)
        else:
            logger.error(
                "Invalid query type. Please provide a string for text search or a "
                "list of floats for vector query."
            )
            return None
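The type dispatch in `run` can be exercised offline with a stub in place of the Dria client. In the sketch below, `FakeClient` and the free function `run` are hypothetical stand-ins that copy only the branching logic. The method names `search` and `query` and the `top_n` keyword are taken from the calls in the wrapper above.

```python
from typing import Any, Dict, List, Optional, Union


class FakeClient:
    """Hypothetical stand-in for the Dria client; records which method was called."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def search(self, query: str, top_n: int) -> List[Dict[str, Any]]:
        self.calls.append("search")
        return [{"id": 0, "metadata": query, "score": 1.0}]

    def query(self, vector: List[float], top_n: int) -> List[Dict[str, Any]]:
        self.calls.append("query")
        return [{"id": 0, "metadata": "vector hit", "score": 1.0}]


def run(
    client: FakeClient, query: Union[str, List[float]], top_n: int = 10
) -> Optional[List[Dict[str, Any]]]:
    """Same branching as DriaAPIWrapper.run: str -> search, list of floats -> query."""
    if isinstance(query, str):
        return client.search(query, top_n=top_n)
    if isinstance(query, list) and all(isinstance(x, float) for x in query):
        return client.query(query, top_n=top_n)
    return None


client = FakeClient()
text_hits = run(client, "Find information about Dria.")
vector_hits = run(client, [0.1, 0.2, 0.3])
int_hits = run(client, [1, 2, 3])  # ints are not floats, so this falls through to None
print(client.calls, int_hits)
```

One consequence of the strict `isinstance(item, float)` check is that a vector of Python ints is rejected. Callers holding integer-valued vectors would need to convert them to floats first.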
