community: llamafile embeddings support #17976

Merged · 17 commits · Mar 1, 2024
157 changes: 157 additions & 0 deletions docs/docs/integrations/text_embedding/llamafile.ipynb
@@ -0,0 +1,157 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "278b6c63",
"metadata": {},
"source": [
"# llamafile\n",
"\n",
"Let's load the [llamafile](https://github.com/Mozilla-Ocho/llamafile) Embeddings class.\n",
"\n",
"## Setup\n",
"\n",
"First, the are 3 setup steps:\n",
"\n",
"1. Download a llamafile. In this notebook, we use `TinyLlama-1.1B-Chat-v1.0.Q5_K_M` but there are many others available on [HuggingFace](https://huggingface.co/models?other=llamafile).\n",
"2. Make the llamafile executable.\n",
"3. Start the llamafile in server mode.\n",
"\n",
"You can run the following bash script to do all this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43ef6dfa-9cc4-4552-8a53-5df523afae7c",
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"# llamafile setup\n",
"\n",
"# Step 1: Download a llamafile. The download may take several minutes.\n",
"wget -nv -nc https://huggingface.co/jartine/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
"\n",
"# Step 2: Make the llamafile executable. Note: if you're on Windows, just append '.exe' to the filename.\n",
"chmod +x TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
"\n",
"# Step 3: Start llamafile server in background. All the server logs will be written to 'tinyllama.log'.\n",
"# Alternatively, you can just open a separate terminal outside this notebook and run: \n",
"# ./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding\n",
"./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding > tinyllama.log 2>&1 &\n",
"pid=$!\n",
"echo \"${pid}\" > .llamafile_pid # write the process pid to a file so we can terminate the server later"
]
},
{
"cell_type": "markdown",
"id": "3188b22f-879f-47b3-9a27-24412f6fad5f",
"metadata": {},
"source": [
"## Embedding texts using LlamafileEmbeddings\n",
"\n",
"Now, we can use the `LlamafileEmbeddings` class to interact with the llamafile server that's currently serving our TinyLlama model at http://localhost:8080."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0be1af71",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings import LlamafileEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c66e5da",
"metadata": {},
"outputs": [],
"source": [
"embedder = LlamafileEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01370375",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "markdown",
"id": "a42e4035",
"metadata": {},
"source": [
"To generate embeddings, you can either query an invidivual text, or you can query a list of texts."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91bc875d-829b-4c3d-8e6f-fc2dda30a3bd",
"metadata": {},
"outputs": [],
"source": [
"query_result = embedder.embed_query(text)\n",
"query_result[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4b0d49e-0c73-44b6-aed5-5b426564e085",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embedder.embed_documents([text])\n",
"doc_result[0][:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ccc78fc-03ae-411d-ae73-74a4ee91c725",
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"# cleanup: kill the llamafile server process\n",
"kill $(cat .llamafile_pid)\n",
"rm .llamafile_pid"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
},
"vscode": {
"interpreter": {
"hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
@@ -56,6 +56,7 @@
from langchain_community.embeddings.jina import JinaEmbeddings
from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
from langchain_community.embeddings.llamafile import LlamafileEmbeddings
from langchain_community.embeddings.llm_rails import LLMRailsEmbeddings
from langchain_community.embeddings.localai import LocalAIEmbeddings
from langchain_community.embeddings.minimax import MiniMaxEmbeddings
@@ -110,6 +111,7 @@
"GradientEmbeddings",
"JinaEmbeddings",
"LlamaCppEmbeddings",
"LlamafileEmbeddings",
"LLMRailsEmbeddings",
"HuggingFaceHubEmbeddings",
"MlflowEmbeddings",
119 changes: 119 additions & 0 deletions libs/community/langchain_community/embeddings/llamafile.py
@@ -0,0 +1,119 @@
import logging
from typing import List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel

logger = logging.getLogger(__name__)


class LlamafileEmbeddings(BaseModel, Embeddings):
"""Llamafile lets you distribute and run large language models with a
single file.

To get started, see: https://github.com/Mozilla-Ocho/llamafile

To use this class, you will need to first:

1. Download a llamafile.
2. Make the downloaded file executable: `chmod +x path/to/model.llamafile`
3. Start the llamafile in server mode with embeddings enabled:

`./path/to/model.llamafile --server --nobrowser --embedding`

Example:
.. code-block:: python

from langchain_community.embeddings import LlamafileEmbeddings
embedder = LlamafileEmbeddings()
doc_embeddings = embedder.embed_documents(
[
"Alpha is the first letter of the Greek alphabet",
"Beta is the second letter of the Greek alphabet",
]
)
query_embedding = embedder.embed_query(
"What is the second letter of the Greek alphabet"
)

"""

base_url: str = "http://localhost:8080"
"""Base url where the llamafile server is listening."""

request_timeout: Optional[int] = None
"""Timeout for server requests"""

def _embed(self, text: str) -> List[float]:
try:
response = requests.post(
url=f"{self.base_url}/embedding",
headers={
"Content-Type": "application/json",
},
json={
"content": text,
},
timeout=self.request_timeout,
)
except requests.exceptions.ConnectionError:
raise requests.exceptions.ConnectionError(
f"Could not connect to Llamafile server. Please make sure "
f"that a server is running at {self.base_url}."
)

# Raise exception if we got a bad (non-200) response status code
response.raise_for_status()

contents = response.json()
if "embedding" not in contents:
raise KeyError(
"Unexpected output from /embedding endpoint, output dict "
"missing 'embedding' key."
)

embedding = contents["embedding"]

# Sanity check the embedding vector:
# Prior to llamafile v0.6.2, if the server was not started with the
# `--embedding` option, the embedding endpoint would always return a
# 0-vector. See issue:
# https://github.com/Mozilla-Ocho/llamafile/issues/243
# So here we raise an exception if the vector sums to exactly 0.
if sum(embedding) == 0.0:
raise ValueError(
"Embedding sums to 0, did you start the llamafile server with "
"the `--embedding` option enabled?"
)

return embedding

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed documents using a llamafile server running at `self.base_url`.
llamafile server should be started in a separate process before invoking
this method.

Args:
texts: The list of texts to embed.

Returns:
List of embeddings, one for each text.
"""
doc_embeddings = []
for text in texts:
doc_embeddings.append(self._embed(text))
return doc_embeddings

def embed_query(self, text: str) -> List[float]:
"""Embed a query using a llamafile server running at `self.base_url`.
llamafile server should be started in a separate process before invoking
this method.

Args:
text: The text to embed.

Returns:
Embeddings for the text.
"""
return self._embed(text)
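Taken together, the fields and methods added above suggest the following minimal usage sketch (not part of the diff; it assumes a llamafile server is already running with the `--embedding` option, and the timeout value is an arbitrary example):

from langchain_community.embeddings import LlamafileEmbeddings

# Point the embedder at a running llamafile server. base_url defaults to
# http://localhost:8080; request_timeout is the per-request timeout in seconds.
embedder = LlamafileEmbeddings(
    base_url="http://localhost:8080",
    request_timeout=30,  # example value, not a default from the PR
)

# Embed a batch of documents and a single query.
doc_vectors = embedder.embed_documents(
    ["Alpha is the first letter of the Greek alphabet"]
)
query_vector = embedder.embed_query("What is the first letter of the Greek alphabet?")
print(len(doc_vectors), len(query_vector))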
6 changes: 6 additions & 0 deletions libs/community/langchain_community/llms/__init__.py
@@ -295,6 +295,12 @@ def _import_llamacpp() -> Type[BaseLLM]:
return LlamaCpp


def _import_llamafile() -> Type[BaseLLM]:
from langchain_community.llms.llamafile import Llamafile

return Llamafile


def _import_manifest() -> Type[BaseLLM]:
from langchain_community.llms.manifest import ManifestWrapper

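The `_import_llamafile` helper above follows the same lazy-import pattern as `_import_llamacpp` in `langchain_community/llms/__init__.py`. Assuming it is registered the same way as the other loaders (not shown in this truncated diff), the companion LLM class should be importable roughly like this hedged sketch:

from langchain_community.llms import Llamafile

# Assumes a llamafile server is already running at the default
# http://localhost:8080 (started as in the notebook's setup cell).
llm = Llamafile()
print(llm.invoke("Tell me a joke."))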
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
@@ -16,6 +16,7 @@
"GradientEmbeddings",
"JinaEmbeddings",
"LlamaCppEmbeddings",
"LlamafileEmbeddings",
"LLMRailsEmbeddings",
"HuggingFaceHubEmbeddings",
"MlflowAIGatewayEmbeddings",
67 changes: 67 additions & 0 deletions libs/community/tests/unit_tests/embeddings/test_llamafile.py
@@ -0,0 +1,67 @@
import json

import numpy as np
import requests
from pytest import MonkeyPatch

from langchain_community.embeddings import LlamafileEmbeddings


def mock_response() -> requests.Response:
contents = json.dumps({"embedding": np.random.randn(512).tolist()})
response = requests.Response()
response.status_code = 200
response._content = str.encode(contents)
return response


def test_embed_documents(monkeypatch: MonkeyPatch) -> None:
"""
Test basic functionality of the `embed_documents` method
"""
embedder = LlamafileEmbeddings(
base_url="http://llamafile-host:8080",
)

def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def]
assert url == "http://llamafile-host:8080/embedding"
assert headers == {
"Content-Type": "application/json",
}
        # request body should contain the text to embed
        assert json == {"content": "Test text"}
        # no timeout was configured, so None should be passed through
        assert timeout is None
return mock_response()

monkeypatch.setattr(requests, "post", mock_post)
out = embedder.embed_documents(["Test text", "Test text"])
assert isinstance(out, list)
assert len(out) == 2
for vec in out:
assert len(vec) == 512


def test_embed_query(monkeypatch: MonkeyPatch) -> None:
"""
Test basic functionality of the `embed_query` method
"""
embedder = LlamafileEmbeddings(
base_url="http://llamafile-host:8080",
)

def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def]
assert url == "http://llamafile-host:8080/embedding"
assert headers == {
"Content-Type": "application/json",
}
        # request body should contain the text to embed
        assert json == {"content": "Test text"}
        # no timeout was configured, so None should be passed through
        assert timeout is None
return mock_response()

monkeypatch.setattr(requests, "post", mock_post)
out = embedder.embed_query("Test text")
assert isinstance(out, list)
assert len(out) == 512
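One case the `_embed` sanity check guards against, but the tests above do not cover, is the all-zeros embedding returned by older llamafile servers started without `--embedding`. A hedged sketch of such a test, reusing the same monkeypatching approach (the names below are illustrative and not part of the PR):

import json

import pytest
import requests
from pytest import MonkeyPatch

from langchain_community.embeddings import LlamafileEmbeddings


def mock_zero_vector_response() -> requests.Response:
    # Simulate a server that returns a 0-vector, as older llamafile versions
    # did when started without the `--embedding` option.
    response = requests.Response()
    response.status_code = 200
    response._content = str.encode(json.dumps({"embedding": [0.0] * 512}))
    return response


def test_embed_query_zero_vector_raises(monkeypatch: MonkeyPatch) -> None:
    embedder = LlamafileEmbeddings(base_url="http://llamafile-host:8080")
    # _embed calls requests.post with keyword arguments only, so a **kwargs
    # lambda is sufficient here.
    monkeypatch.setattr(requests, "post", lambda **kwargs: mock_zero_vector_response())
    with pytest.raises(ValueError):
        embedder.embed_query("Test text")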