Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Add Dria retriever #17098

Merged
merged 13 commits into from
Apr 1, 2024
191 changes: 191 additions & 0 deletions docs/docs/integrations/retrievers/dria_index.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "UYyFIEKEkmHb"
},
"source": [
"# Dria\n",
"\n",
"Dria is a hub of public RAG models for developers to both contribute to and utilize a shared embedding lake. This notebook demonstrates how to use the Dria API for data retrieval tasks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VNTFUgK9kmHd"
},
"source": [
"# Installation\n",
"\n",
"Ensure you have the `dria` package installed. You can install it using pip:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X--1A8EEkmHd"
},
"outputs": [],
"source": [
"%pip install --upgrade --quiet dria"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xRbRL0SgkmHe"
},
"source": [
"# Configure API Key\n",
"\n",
"Set up your Dria API key for access."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "hGqOByNMkmHe"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"DRIA_API_KEY\"] = \"DRIA_API_KEY\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nDfAEqQtkmHe"
},
"source": [
"# Initialize Dria Retriever\n",
"\n",
"Create an instance of `DriaRetriever`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vlyorgCckmHe"
},
"outputs": [],
"source": [
"from langchain.retrievers import DriaRetriever\n",
"\n",
"api_key = os.getenv(\"DRIA_API_KEY\")\n",
"retriever = DriaRetriever(api_key=api_key)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j7WUY5jBOLQd"
},
"source": [
"# **Create Knowledge Base**\n",
"\n",
"Create a knowledge base on [Dria's Knowledge Hub](https://dria.co/knowledge)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L5ER81eWOKnt"
},
"outputs": [],
"source": [
"contract_id = retriever.create_knowledge_base(\n",
" name=\"France's AI Development\",\n",
" embedding=DriaRetriever.models.jina_embeddings_v2_base_en.value,\n",
" category=\"Artificial Intelligence\",\n",
" description=\"Explore the growth and contributions of France in the field of Artificial Intelligence.\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9VCTzSFpkmHe"
},
"source": [
"# Add Data\n",
"\n",
"Load data into your Dria knowledge base."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xeTMafIekmHf"
},
"outputs": [],
"source": [
"texts = [\n",
" \"The first text to add to Dria.\",\n",
" \"Another piece of information to store.\",\n",
" \"More data to include in the Dria knowledge base.\",\n",
"]\n",
"\n",
"ids = retriever.add_texts(texts)\n",
"print(\"Data added with IDs:\", ids)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dy1UlvLCkmHf"
},
"source": [
"# Retrieve Data\n",
"\n",
"Use the retriever to find relevant documents given a query."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9y3msv9tkmHf"
},
"outputs": [],
"source": [
"query = \"Find information about Dria.\"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.x"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
2 changes: 2 additions & 0 deletions libs/community/langchain_community/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from langchain_community.retrievers.cohere_rag_retriever import CohereRagRetriever
from langchain_community.retrievers.docarray import DocArrayRetriever
from langchain_community.retrievers.dria_index import DriaRetriever
from langchain_community.retrievers.elastic_search_bm25 import (
ElasticSearchBM25Retriever,
)
Expand Down Expand Up @@ -85,6 +86,7 @@
"ChatGPTPluginRetriever",
"ChaindeskRetriever",
"CohereRagRetriever",
"DriaRetriever",
"ElasticSearchBM25Retriever",
"EmbedchainRetriever",
"GoogleDocumentAIWarehouseRetriever",
Expand Down
87 changes: 87 additions & 0 deletions libs/community/langchain_community/retrievers/dria_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Wrapper around Dria Retriever."""

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_community.utilities import DriaAPIWrapper


class DriaRetriever(BaseRetriever):
    """`Dria` retriever using the DriaAPIWrapper.

    Every operation (creating a knowledge base, inserting texts, searching)
    is delegated to the wrapped ``DriaAPIWrapper`` instance.
    """

    # Client wrapper that performs all calls to Dria; built in ``__init__``.
    api_wrapper: DriaAPIWrapper

    def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
        """
        Initialize the DriaRetriever with a DriaAPIWrapper instance.

        Args:
            api_key: The API key for Dria.
            contract_id: The contract ID of the knowledge base to interact with.
            **kwargs: Extra keyword arguments forwarded to ``BaseRetriever``.
        """
        api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
        super().__init__(api_wrapper=api_wrapper, **kwargs)

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str = "Unspecified",
        embedding: str = "jina",
    ) -> str:
        """Create a new knowledge base in Dria.

        Args:
            name: The name of the knowledge base.
            description: The description of the knowledge base.
            category: The category of the knowledge base.
            embedding: The embedding model to use for the knowledge base.

        Returns:
            The ID of the created knowledge base, as returned by the wrapper.
        """
        # Arguments are passed positionally, in the order the wrapper is
        # called with in the original implementation.
        return self.api_wrapper.create_knowledge_base(
            name, description, category, embedding
        )

    def add_texts(
        self,
        texts: List,
    ) -> Any:
        """Add texts to the Dria knowledge base.

        Args:
            texts: The entries to add. Each entry is either a mapping with
                ``"text"`` and ``"metadata"`` keys, or a plain string, which
                is stored with empty metadata.

        Returns:
            The response of ``DriaAPIWrapper.insert_data``, passed through
            unchanged (presumably the IDs of the inserted entries -- confirm
            against the wrapper).

        Raises:
            KeyError: If a mapping entry lacks the ``"text"`` or
                ``"metadata"`` key.
        """
        data = []
        for text in texts:
            if isinstance(text, str):
                # Plain strings previously failed with TypeError on
                # ``text["text"]``; wrap them in the expected shape.
                # NOTE(review): assumes Dria accepts empty metadata -- confirm.
                data.append({"text": text, "metadata": {}})
            else:
                data.append({"text": text["text"], "metadata": text["metadata"]})
        # Hand the wrapper's response back instead of discarding it, so that
        # callers can inspect the outcome of the insertion.
        return self.api_wrapper.insert_data(data)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents from Dria based on a query.

        Args:
            query: The query string to search for in the knowledge base.
            run_manager: Callback manager for the retriever run (unused here).

        Returns:
            A list of Documents, one per search result, whose metadata holds
            the result's ``id`` and ``score``.
        """
        results = self.api_wrapper.search(query)
        # NOTE(review): the result's "metadata" field is used as the page
        # content, so it is presumably the stored text -- confirm against the
        # wrapper's search response.
        return [
            Document(
                page_content=result["metadata"],
                metadata={"id": result["id"], "score": result["score"]},
            )
            for result in results
        ]
11 changes: 11 additions & 0 deletions libs/community/langchain_community/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def _import_duckduckgo_search() -> Any:
return DuckDuckGoSearchAPIWrapper


def _import_dria_index() -> Any:
    """Import ``DriaAPIWrapper`` on demand and return the class."""
    # Function-scope import: the module is only loaded when this is called.
    from langchain_community.utilities.dria_index import DriaAPIWrapper

    return DriaAPIWrapper


def _import_golden_query() -> Any:
from langchain_community.utilities.golden_query import GoldenQueryAPIWrapper

Expand Down Expand Up @@ -391,6 +399,8 @@ def __getattr__(name: str) -> Any:
return _import_wolfram_alpha()
elif name == "ZapierNLAWrapper":
return _import_zapier()
elif name == "DriaAPIWrapper":
return _import_dria_index()
else:
raise AttributeError(f"Could not find: {name}")

Expand All @@ -404,6 +414,7 @@ def __getattr__(name: str) -> Any:
"BingSearchAPIWrapper",
"BraveSearchWrapper",
"DuckDuckGoSearchAPIWrapper",
"DriaAPIWrapper",
"GoldenQueryAPIWrapper",
"GoogleFinanceAPIWrapper",
"GoogleLensAPIWrapper",
Expand Down