Skip to content

Commit

Permalink
elasticsearch: add ElasticsearchRetriever (langchain-ai#18587)
Browse files Browse the repository at this point in the history
Implement
[Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/)
interface for Elasticsearch.

I opted to only expose the `body`, which gives you full flexibility, and
none of the other 68 arguments of the [search
method](https://elasticsearch-py.readthedocs.io/en/v8.12.1/api/elasticsearch.html#elasticsearch.Elasticsearch.search).

Added a user agent header for usage tracking in Elastic Cloud.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
  • Loading branch information
2 people authored and thebhulawat committed Mar 6, 2024
1 parent c7518c1 commit 37bd801
Show file tree
Hide file tree
Showing 11 changed files with 493 additions and 187 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import List, Union

import numpy as np
from elasticsearch import Elasticsearch
from langchain_core import __version__ as langchain_version

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]

Expand All @@ -17,6 +19,12 @@ class DistanceStrategy(str, Enum):
COSINE = "COSINE"


def with_user_agent_header(client: Elasticsearch, header_prefix: str) -> Elasticsearch:
    """Attach a LangChain ``user-agent`` header to an Elasticsearch client.

    The header value is ``"<header_prefix>/<langchain_core version>"``. The
    client's current headers are copied first, so every other header is kept
    and only ``user-agent`` is set or overwritten.

    Args:
        client: Client whose headers serve as the starting point.
        header_prefix: Product tag placed before the version, e.g.
            ``"langchain-py-ms"``.

    Returns:
        Whatever ``client.options(headers=...)`` returns for the merged headers.
    """
    merged_headers = {
        **client._headers,
        "user-agent": f"{header_prefix}/{langchain_version}",
    }
    return client.options(headers=merged_headers)


def maximal_marginal_relevance(
query_embedding: np.ndarray,
embedding_list: list,
Expand Down
83 changes: 18 additions & 65 deletions libs/partners/elasticsearch/langchain_elasticsearch/chat_history.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
Expand All @@ -10,6 +10,9 @@
messages_from_dict,
)

from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client

if TYPE_CHECKING:
from elasticsearch import Elasticsearch

Expand Down Expand Up @@ -51,23 +54,27 @@ def __init__(

# Initialize Elasticsearch client from passed client arg or connection info
if es_connection is not None:
self.client = es_connection.options(
headers={"user-agent": self.get_user_agent()}
)
self.client = es_connection
elif es_url is not None or es_cloud_id is not None:
self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
es_url=es_url,
username=es_user,
password=es_password,
cloud_id=es_cloud_id,
api_key=es_api_key,
)
try:
self.client = create_elasticsearch_client(
url=es_url,
username=es_user,
password=es_password,
cloud_id=es_cloud_id,
api_key=es_api_key,
)
except Exception as err:
logger.error(f"Error connecting to Elasticsearch: {err}")
raise err
else:
raise ValueError(
"""Either provide a pre-existing Elasticsearch connection, \
or valid credentials for creating a new connection."""
)

self.client = with_user_agent_header(self.client, "langchain-py-ms")

if self.client.indices.exists(index=index):
logger.debug(
f"Chat history index {index} already exists, skipping creation."
Expand All @@ -86,60 +93,6 @@ def __init__(
},
)

@staticmethod
def get_user_agent() -> str:
    """Return the ``user-agent`` value ``langchain-py-ms/<langchain_core version>``."""
    # Imported lazily, so langchain_core is only loaded when the header is built.
    from langchain_core import __version__ as core_version

    return f"langchain-py-ms/{core_version}"

@staticmethod
def connect_to_elasticsearch(
    *,
    es_url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> "Elasticsearch":
    """Create an Elasticsearch client and verify that it can reach the cluster.

    Args:
        es_url: URL of the Elasticsearch host; mutually exclusive with
            ``cloud_id``.
        cloud_id: Cloud deployment id; mutually exclusive with ``es_url``.
        api_key: API key. When truthy it takes precedence over
            ``username``/``password``.
        username: Basic-auth user; only used together with ``password``.
        password: Basic-auth password; only used together with ``username``.

    Returns:
        The connected client, created with the chat-history user-agent header.

    Raises:
        ImportError: If the ``elasticsearch`` package is not installed.
        ValueError: If both or neither of ``es_url`` and ``cloud_id`` are given.
        Exception: Whatever ``info()`` raises is logged and re-raised.
    """
    try:
        import elasticsearch
    except ImportError:
        raise ImportError(
            "Could not import elasticsearch python package. "
            "Please install it with `pip install elasticsearch`."
        )

    # Exactly one endpoint must be supplied.
    if es_url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )
    if not es_url and not cloud_id:
        raise ValueError("Please provide either elasticsearch_url or cloud_id.")

    client_kwargs: Dict[str, Any] = (
        {"hosts": [es_url]} if es_url else {"cloud_id": cloud_id}
    )

    # An API key wins over basic auth; basic auth needs both halves.
    if api_key:
        client_kwargs["api_key"] = api_key
    elif username and password:
        client_kwargs["basic_auth"] = (username, password)

    client = elasticsearch.Elasticsearch(
        **client_kwargs,
        headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
    )

    # Round-trip to the cluster so connection problems surface immediately.
    try:
        client.info()
    except Exception as err:
        logger.error(f"Error connecting to Elasticsearch: {err}")
        raise err

    return client

@property
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""Retrieve the messages from Elasticsearch"""
Expand Down
40 changes: 40 additions & 0 deletions libs/partners/elasticsearch/langchain_elasticsearch/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any, Dict, Optional

from elasticsearch import Elasticsearch


def create_elasticsearch_client(
    url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
) -> Elasticsearch:
    """Build an Elasticsearch client from connection details and test it.

    Args:
        url: URL of the Elasticsearch host; mutually exclusive with
            ``cloud_id``.
        cloud_id: Cloud deployment id; mutually exclusive with ``url``.
        api_key: API key. When truthy it takes precedence over
            ``username``/``password``.
        username: Basic-auth user; only used together with ``password``.
        password: Basic-auth password; only used together with ``username``.
        params: Extra keyword arguments for the ``Elasticsearch`` constructor.
            They are applied last, so they override the values derived above.

    Returns:
        The client, after a successful ``info()`` call.

    Raises:
        ValueError: If both or neither of ``url`` and ``cloud_id`` are given.
        Exception: Anything raised while constructing the client or by the
            ``info()`` connection test propagates unchanged.
    """
    # Exactly one endpoint must be supplied.
    if url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )
    if not url and not cloud_id:
        raise ValueError("Please provide either elasticsearch_url or cloud_id.")

    client_kwargs: Dict[str, Any] = (
        {"hosts": [url]} if url else {"cloud_id": cloud_id}
    )

    # An API key wins over basic auth; basic auth needs both halves.
    if api_key:
        client_kwargs["api_key"] = api_key
    elif username and password:
        client_kwargs["basic_auth"] = (username, password)

    if params is not None:
        client_kwargs.update(params)

    client = Elasticsearch(**client_kwargs)
    client.info()  # test connection
    return client
97 changes: 97 additions & 0 deletions libs/partners/elasticsearch/langchain_elasticsearch/retrievers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import logging
from typing import Any, Callable, Dict, List, Optional

from elasticsearch import Elasticsearch
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_elasticsearch._utilities import with_user_agent_header
from langchain_elasticsearch.client import create_elasticsearch_client

logger = logging.getLogger(__name__)


class ElasticsearchRetriever(BaseRetriever):
    """Retriever that runs a caller-built query body against an Elasticsearch index.

    Exactly one of ``content_field`` and ``document_mapper`` must be provided.

    Args:
        es_client: Elasticsearch client connection. Alternatively you can use the
            `from_es_params` method with parameters to initialize the client.
        index_name: The name of the index to query.
        body_func: Function to create an Elasticsearch DSL query body from a search
            string. All parameters (including for example the `size` parameter to
            limit the number of results) must also be set in the body.
        content_field: The document field name that contains the page content.
        document_mapper: Function to map Elasticsearch hits to LangChain Documents.
    """

    es_client: Elasticsearch
    index_name: str
    body_func: Callable[[str], Dict]
    content_field: Optional[str] = None
    document_mapper: Optional[Callable[[Dict], Document]] = None

    def __init__(self, **kwargs: Any) -> None:
        """Validate the mapping configuration and tag the client's user agent.

        Raises:
            ValueError: If neither or both of ``content_field`` and
                ``document_mapper`` are set.
        """
        super().__init__(**kwargs)

        if self.content_field is None and self.document_mapper is None:
            raise ValueError("One of content_field or document_mapper must be defined.")
        if self.content_field is not None and self.document_mapper is not None:
            raise ValueError(
                "Both content_field and document_mapper are defined. "
                "Please provide only one."
            )

        # With only content_field given, fall back to the built-in field mapper.
        self.document_mapper = self.document_mapper or self._field_mapper
        self.es_client = with_user_agent_header(self.es_client, "langchain-py-r")

    @classmethod
    def from_es_params(
        cls,
        index_name: str,
        body_func: Callable[[str], Dict],
        content_field: Optional[str] = None,
        document_mapper: Optional[Callable[[Dict], Document]] = None,
        url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> "ElasticsearchRetriever":
        """Create a retriever from connection parameters instead of a client.

        The connection arguments (``url``, ``cloud_id``, ``api_key``,
        ``username``, ``password``, ``params``) are forwarded to
        ``create_elasticsearch_client``; the remaining arguments are passed to
        the retriever constructor.

        Returns:
            A new instance of the class this method is called on, so
            subclasses get an instance of their own type.

        Raises:
            Exception: Any error raised while creating the client is logged
                and re-raised unchanged.
        """
        try:
            client = create_elasticsearch_client(
                url=url,
                cloud_id=cloud_id,
                api_key=api_key,
                username=username,
                password=password,
                params=params,
            )
        except Exception as err:
            logger.error(f"Error connecting to Elasticsearch: {err}")
            raise

        return cls(
            es_client=client,
            index_name=index_name,
            body_func=body_func,
            content_field=content_field,
            document_mapper=document_mapper,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search the index with the body built for ``query`` and map each hit.

        Raises:
            ValueError: If the client or the document mapper is missing.
        """
        if not self.es_client or not self.document_mapper:
            raise ValueError("faulty configuration")  # should not happen

        body = self.body_func(query)
        results = self.es_client.search(index=self.index_name, body=body)
        return [self.document_mapper(hit) for hit in results["hits"]["hits"]]

    def _field_mapper(self, hit: Dict[str, Any]) -> Document:
        """Map a hit to a Document using ``content_field`` as the page content.

        The content field is popped from ``hit["_source"]``, so the hit that
        becomes the document metadata no longer contains it.
        """
        content = hit["_source"].pop(self.content_field)
        return Document(page_content=content, metadata=hit)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from langchain_elasticsearch._utilities import (
DistanceStrategy,
maximal_marginal_relevance,
with_user_agent_header,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -526,9 +527,7 @@ def __init__(
self.strategy = strategy

if es_connection is not None:
headers = dict(es_connection._headers)
headers.update({"user-agent": self.get_user_agent()})
self.client = es_connection.options(headers=headers)
self.client = es_connection
elif es_url is not None or es_cloud_id is not None:
self.client = ElasticsearchStore.connect_to_elasticsearch(
es_url=es_url,
Expand All @@ -544,11 +543,7 @@ def __init__(
or valid credentials for creating a new connection."""
)

@staticmethod
def get_user_agent() -> str:
    """Return the ``user-agent`` value ``langchain-py-vs/<langchain_core version>``."""
    # Imported lazily, so langchain_core is only loaded when the header is built.
    from langchain_core import __version__ as core_version

    return f"langchain-py-vs/{core_version}"
self.client = with_user_agent_header(self.client, "langchain-py-vs")

@staticmethod
def connect_to_elasticsearch(
Expand Down Expand Up @@ -582,10 +577,7 @@ def connect_to_elasticsearch(
if es_params is not None:
connection_params.update(es_params)

es_client = Elasticsearch(
**connection_params,
headers={"user-agent": ElasticsearchStore.get_user_agent()},
)
es_client = Elasticsearch(**connection_params)
try:
es_client.info()
except Exception as e:
Expand Down

0 comments on commit 37bd801

Please sign in to comment.