Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

partners: AI21 Labs Batch Support in Embeddings #18633

Merged
merged 35 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
9976ca2
feat: Contextual answers in langchain
asafgardin Feb 28, 2024
ae411b6
docs: Updated readme
asafgardin Feb 28, 2024
f0e6aa0
Merge branch 'langchain-ai:master' into master
Josephasafg Feb 28, 2024
3eebce9
Merge branch 'master' into Josephasafg/master
efriis Feb 28, 2024
ccb7561
deps
efriis Feb 28, 2024
c054713
docs: Added ca examples to llm docs
asafgardin Feb 29, 2024
931ff88
fix: readme
asafgardin Feb 29, 2024
9b1799d
feat: Added chunking to embedding
asafgardin Feb 29, 2024
0b0ede6
fix: Used default chunk size
asafgardin Feb 29, 2024
57eb9dd
fix: default chunk
asafgardin Feb 29, 2024
4669ef9
refactor: use itertools
asafgardin Mar 3, 2024
a94f423
fix: Moved function outside of class and added a test
asafgardin Mar 3, 2024
7a03e63
fix: test name
asafgardin Mar 3, 2024
a6f1d91
fix: lint
asafgardin Mar 3, 2024
383f522
fix: rename
asafgardin Mar 3, 2024
c5d1d43
Merge pull request #1 from Josephasafg/add_chunking_in_embeddings
Josephasafg Mar 3, 2024
5f4d507
revert: embeddings
asafgardin Mar 4, 2024
16db69a
Merge pull request #2 from Josephasafg/revert_embeddings
Josephasafg Mar 4, 2024
570eea0
Merge branch 'langchain-ai:master' into master
Josephasafg Mar 4, 2024
bf29ec4
feat: Added batching to embeddings
asafgardin Mar 4, 2024
7546b9b
Merge branch 'master' into master
Josephasafg Mar 5, 2024
67eee9b
fix: notebook
asafgardin Mar 5, 2024
85dc73b
Merge pull request #4 from Josephasafg/notebook_format
Josephasafg Mar 5, 2024
2a439b7
Merge branch 'langchain-ai:master' into master
Josephasafg Mar 5, 2024
dcc74d1
Merge branch 'master' into master
efriis Mar 5, 2024
65ff240
Merge branch 'master' into Josephasafg/master
efriis Mar 5, 2024
6f788a3
cr
efriis Mar 5, 2024
721eb15
Merge pull request #3 from Josephasafg/embeddings_batch
Josephasafg Mar 6, 2024
84bd20e
Merge branch 'langchain-ai:master' into master
Josephasafg Mar 6, 2024
5a90e5e
fix: Changed model name in llm integration tests
asafgardin Mar 10, 2024
e645ce9
Merge branch 'master' into master
Josephasafg Mar 10, 2024
9daf6dc
Merge branch 'master' into master
efriis Mar 14, 2024
ad0bd9b
cr
efriis Mar 14, 2024
e77bda2
cr
efriis Mar 14, 2024
4fc25ae
release 0.1.1
efriis Mar 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 47 additions & 11 deletions libs/partners/ai21/langchain_ai21/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from typing import Any, List
from itertools import islice
from typing import Any, Iterator, List, Optional

from ai21.models import EmbedType
from langchain_core.embeddings import Embeddings

from langchain_ai21.ai21_base import AI21Base

_DEFAULT_CHUNK_SIZE = 128


def chunked_text_generator(texts: List[str], chunk_size: int) -> Iterator[List[str]]:
    """Yield consecutive batches of at most ``chunk_size`` items from ``texts``.

    The final batch may be shorter than ``chunk_size``; an empty input
    yields nothing.
    """
    source = iter(texts)
    while True:
        batch = list(islice(source, chunk_size))
        if not batch:
            return
        yield batch


class AI21Embeddings(Embeddings, AI21Base):
"""AI21 Embeddings embedding model.
Expand All @@ -20,22 +28,50 @@ class AI21Embeddings(Embeddings, AI21Base):
query_result = embeddings.embed_query("Hello embeddings world!")
"""

chunk_size: int = _DEFAULT_CHUNK_SIZE
"""Maximum number of texts to embed in each batch"""

def embed_documents(
    self,
    texts: List[str],
    chunk_size: Optional[int] = None,
    **kwargs: Any,
) -> List[List[float]]:
    """Embed search docs.

    Args:
        texts: The texts to embed.
        chunk_size: Number of texts sent per API request; falls back to
            ``self.chunk_size`` when not provided (or 0/None).
        **kwargs: Extra arguments forwarded to ``client.embed.create``.

    Returns:
        One embedding vector per input text, in input order.
    """
    # SEGMENT embed type is used for documents (vs QUERY for queries).
    return self._send_embeddings(
        texts=texts,
        chunk_size=chunk_size or self.chunk_size,
        embed_type=EmbedType.SEGMENT,
        **kwargs,
    )
def embed_query(self, text: str, **kwargs: Any) -> List[float]:
def embed_query(
    self,
    text: str,
    chunk_size: Optional[int] = None,
    **kwargs: Any,
) -> List[float]:
    """Embed query text.

    Args:
        text: The query to embed.
        chunk_size: Number of texts sent per API request; falls back to
            ``self.chunk_size`` when not provided (or 0/None).
        **kwargs: Extra arguments forwarded to ``client.embed.create``.

    Returns:
        The embedding vector for ``text``.
    """
    # A single query produces exactly one result; unwrap it.
    return self._send_embeddings(
        texts=[text],
        chunk_size=chunk_size or self.chunk_size,
        embed_type=EmbedType.QUERY,
        **kwargs,
    )[0]

def _send_embeddings(
    self, texts: List[str], chunk_size: int, embed_type: EmbedType, **kwargs: Any
) -> List[List[float]]:
    """Send ``texts`` to the embed API in batches and flatten the results.

    Each batch of at most ``chunk_size`` texts becomes one
    ``client.embed.create`` call; the per-text embeddings from every
    response are concatenated into a single list in input order.
    """
    chunks = chunked_text_generator(texts, chunk_size)
    responses = [
        self.client.embed.create(
            texts=chunk,
            type=embed_type,
            **kwargs,
        )
        for chunk in chunks
    ]

    return [
        result.embedding for response in responses for result in response.results
    ]
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
"""Test ChatAI21 chat model."""

from langchain_core.messages import HumanMessage
from langchain_core.outputs import ChatGeneration

from langchain_ai21.chat_models import ChatAI21

_MODEL_NAME = "j2-ultra"


def test_invoke() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model="j2-ultra")
llm = ChatAI21(model=_MODEL_NAME)

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)


def test_generation() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model="j2-ultra")
llm = ChatAI21(model=_MODEL_NAME)
message = HumanMessage(content="Hello")

result = llm.generate([[message], [message]], config=dict(tags=["foo"]))
Expand All @@ -30,7 +33,7 @@ def test_generation() -> None:

async def test_ageneration() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model="j2-ultra")
llm = ChatAI21(model=_MODEL_NAME)
message = HumanMessage(content="Hello")

result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
Expand Down
18 changes: 18 additions & 0 deletions libs/partners/ai21/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test AI21 embeddings."""

from langchain_ai21.embeddings import AI21Embeddings


Expand All @@ -17,3 +18,20 @@ def test_langchain_ai21_embedding_query() -> None:
embedding = AI21Embeddings()
output = embedding.embed_query(document)
assert len(output) > 0


def test_langchain_ai21_embedding_documents__with_explicit_chunk_size() -> None:
    """embed_documents should honor an explicit ``chunk_size`` argument."""
    texts = ["foo", "bar"]
    embedder = AI21Embeddings()
    vectors = embedder.embed_documents(texts, chunk_size=1)
    assert len(vectors) == 2
    assert len(vectors[0]) > 0


def test_langchain_ai21_embedding_query__with_explicit_chunk_size() -> None:
    """embed_query should honor an explicit ``chunk_size`` argument."""
    query = "foo bar"
    embedder = AI21Embeddings()
    vector = embedder.embed_query(query, chunk_size=1)
    assert len(vector) > 0
17 changes: 9 additions & 8 deletions libs/partners/ai21/tests/integration_tests/test_llms.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""Test AI21LLM llm."""


from langchain_ai21.llms import AI21LLM

_MODEL_NAME = "j2-mid"


def _generate_llm() -> AI21LLM:
"""
Testing AI21LLm using non default parameters with the following parameters
"""
return AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
max_tokens=2, # Use less tokens for a faster response
temperature=0, # for a consistent response
epoch=1,
Expand All @@ -19,7 +20,7 @@ def _generate_llm() -> AI21LLM:
def test_stream() -> None:
"""Test streaming tokens from AI21."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

for token in llm.stream("I'm Pickle Rick"):
Expand All @@ -29,7 +30,7 @@ def test_stream() -> None:
async def test_abatch() -> None:
"""Test streaming tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
Expand All @@ -40,7 +41,7 @@ async def test_abatch() -> None:
async def test_abatch_tags() -> None:
"""Test batch tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = await llm.abatch(
Expand All @@ -53,7 +54,7 @@ async def test_abatch_tags() -> None:
def test_batch() -> None:
"""Test batch tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
Expand All @@ -64,7 +65,7 @@ def test_batch() -> None:
async def test_ainvoke() -> None:
"""Test invoke tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
Expand All @@ -74,7 +75,7 @@ async def test_ainvoke() -> None:
def test_invoke() -> None:
"""Test invoke tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
Expand Down
35 changes: 35 additions & 0 deletions libs/partners/ai21/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Test embedding model integration."""

from typing import List
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -65,3 +67,36 @@ def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
texts=texts,
type=EmbedType.SEGMENT,
)


@pytest.mark.parametrize(
    ids=[
        "empty_texts",
        "chunk_size_greater_than_texts_length",
        "chunk_size_equal_to_texts_length",
        "chunk_size_less_than_texts_length",
        "chunk_size_one_with_multiple_texts",
        # Was a duplicate of "chunk_size_greater_than_texts_length"; renamed
        # so every parametrized case has a distinct, meaningful id.
        "chunk_size_much_greater_than_texts_length",
    ],
    argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"],
    argvalues=[
        ([], 3, 0),
        (["text1", "text2", "text3"], 5, 1),
        (["text1", "text2", "text3"], 3, 1),
        (["text1", "text2", "text3", "text4", "text5"], 2, 3),
        (["text1", "text2", "text3"], 1, 3),
        (["text1", "text2", "text3"], 10, 1),
    ],
)
def test_get_len_safe_embeddings(
    mock_client_with_embeddings: Mock,
    texts: List[str],
    chunk_size: int,
    expected_internal_embeddings_calls: int,
) -> None:
    """embed_documents should issue ceil(len(texts) / chunk_size) API calls."""
    llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
    llm.embed_documents(texts=texts, chunk_size=chunk_size)
    assert (
        mock_client_with_embeddings.embed.create.call_count
        == expected_internal_embeddings_calls
    )
29 changes: 29 additions & 0 deletions libs/partners/ai21/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import List

import pytest

from langchain_ai21.embeddings import chunked_text_generator


@pytest.mark.parametrize(
    ids=[
        "when_chunk_size_is_2__should_return_3_chunks",
        "when_texts_is_empty__should_return_empty_list",
        "when_chunk_size_is_1__should_return_10_chunks",
    ],
    argnames=["input_texts", "chunk_size", "expected_output"],
    argvalues=[
        (["a", "b", "c", "d", "e"], 2, [["a", "b"], ["c", "d"], ["e"]]),
        ([], 3, []),
        (
            ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
            1,
            [["1"], ["2"], ["3"], ["4"], ["5"], ["6"], ["7"], ["8"], ["9"], ["10"]],
        ),
    ],
)
def test_chunked_text_generator(
    input_texts: List[str], chunk_size: int, expected_output: List[List[str]]
) -> None:
    """The generator splits input_texts into consecutive chunk_size batches."""
    chunks = [chunk for chunk in chunked_text_generator(input_texts, chunk_size)]
    assert chunks == expected_output