forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: infinity embedding local option (langchain-ai#17671)
**drop-in-replacement for sentence-transformers inference.** langchain-ai#17670 tldr from the discussion above -> around a 4x-22x speedup over using SentenceTransformers / huggingface embeddings. For more info: https://github.com/michaelfeil/infinity (pure-python dependency) --------- Co-authored-by: Erick Friis <erick@langchain.dev>
- Loading branch information
Showing
5 changed files
with
316 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
156 changes: 156 additions & 0 deletions
156
libs/community/langchain_community/embeddings/infinity_local.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
"""written under MIT Licence, Michael Feil 2023.""" | ||
|
||
import asyncio | ||
from logging import getLogger | ||
from typing import Any, Dict, List, Optional | ||
|
||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator | ||
|
||
# Public API: the only name exported by a star-import of this module. | ||
__all__ = ["InfinityEmbeddingsLocal"] | ||
|
||
# Module-level logger, named after this module. | ||
logger = getLogger(__name__) | ||
|
||
|
||
class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Embedding models https://github.com/michaelfeil/infinity

    This class deploys a local Infinity instance to embed text.
    The class requires async usage: the engine is started and stopped through
    the async context-manager protocol, and the synchronous ``embed_*``
    methods are thin ``asyncio.run`` wrappers around the async ones.

    Infinity is a class to interact with Embedding Models on
    https://github.com/michaelfeil/infinity

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
    """

    model: str
    """Underlying model id from huggingface, e.g. BAAI/bge-small-en-v1.5"""

    revision: Optional[str] = None
    """Model version, the commit hash from huggingface"""

    batch_size: int = 32
    """Internal batch size for inference, e.g. 32"""

    device: str = "auto"
    """Device to use for inference, e.g. 'cpu' or 'cuda', or 'mps'"""

    backend: str = "torch"
    """Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia)
    or 'optimum' for onnx/tensorrt"""

    model_warmup: bool = True
    """Warmup the model with the max batch size."""

    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    # skip_on_failure: if a field (e.g. the required `model`) already failed
    # validation it is absent from `values`; without this flag the lookups
    # below would raise a bare KeyError instead of a pydantic ValidationError.
    @root_validator(allow_reuse=True, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that `infinity_emb` is installed and build the engine.

        Args:
            values: The validated field values of the model.

        Returns:
            The same mapping with ``engine`` set to a new
            ``AsyncEmbeddingEngine`` configured from the other fields.

        Raises:
            ImportError: If the ``infinity_emb`` package cannot be imported.
        """
        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            ) from e
        logger.debug("Using InfinityEmbeddingsLocal with kwargs %s", values)

        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> "InfinityEmbeddingsLocal":
        """Start the background worker.

        Recommended usage is with the ``async with`` statement:

        .. code-block:: python

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])

        Returns:
            This instance, so that the ``as`` target of ``async with`` is the
            embedder itself rather than ``None``.
        """
        await self.engine.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Stop the background worker.

        Required to free references to the pytorch model.

        Args:
            *args: The exception triple passed by ``async with``; forwarded
                unchanged to the engine's ``__aexit__``.
        """
        await self.engine.__aexit__(*args)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        If the engine is not running it is started for this single call and
        stopped again afterwards; a warning is logged in that case.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        if not texts:
            # Nothing to embed: avoid starting/stopping the engine for no work.
            return []

        if self.engine.running:
            embeddings, _ = await self.engine.embed(texts)
            return embeddings

        logger.warning(
            "Starting Infinity engine on the fly. This is not recommended. "
            "Please start the engine before using it."
        )
        async with self:
            # spawning threadpool for multithreaded encode, tokenization
            embeddings, _ = await self.engine.embed(texts)
        # threadpool is stopped once the `async with` block has exited
        logger.warning("Stopped infinity engine after usage.")
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Synchronous wrapper around ``aembed_documents``.

        This class is async only; this method logs a warning and drives the
        coroutine with ``asyncio.run``, which cannot be called while another
        event loop is running in the same thread.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Synchronous wrapper around ``aembed_query``.

        This class is async only; this method logs a warning and drives the
        coroutine with ``asyncio.run``, which cannot be called while another
        event loop is running in the same thread.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        logger.warning(
            "This method is async only."
            " Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.