Comparing changes

base repository: caikit/caikit-nlp
base: v0.5.7
head repository: caikit/caikit-nlp
compare: v0.5.8
  • 7 commits
  • 3 files changed
  • 3 contributors

Commits on Oct 8, 2024

  1. Enable using kwargs for selecting pad-to-max-length strategy for tokenizer in embeddings
    
    Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
    kcirred committed Oct 8, 2024

    4f8a821

Commits on Oct 15, 2024

  1. Added mocking to test tokenizer changes, directly pass padding strategy

    Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
    kcirred committed Oct 15, 2024

    f79d65b

Commits on Oct 16, 2024

  1. Merge pull request #393 from kcirred/main

    Enable using kwargs for selecting pad-to-max-length strategy for tokenizer in embeddings
    gkumbhat authored Oct 16, 2024
    0219d50

Commits on Oct 17, 2024

  1. Update torch requirement from <2.5.0,>=2.3.1 to >=2.3.1,<2.6.0

    Updates the requirements on [torch](https://github.com/pytorch/pytorch) to permit the latest version.
    - [Release notes](https://github.com/pytorch/pytorch/releases)
    - [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md)
    - [Commits](pytorch/pytorch@v2.3.1...v2.5.0)
    
    ---
    updated-dependencies:
    - dependency-name: torch
      dependency-type: direct:production
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored Oct 17, 2024
    e471a2a

Commits on Oct 28, 2024

  1. [embeddings] extend kwargs to high level functions

    Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
    kcirred committed Oct 28, 2024
    874982b

Commits on Oct 30, 2024

  1. Merge pull request #398 from caikit/dependabot/pip/torch-gte-2.3.1-and-lt-2.6.0
    
    Update torch requirement from <2.5.0,>=2.3.1 to >=2.3.1,<2.6.0
    gkumbhat authored Oct 30, 2024
    bfd3d4d

Commits on Nov 7, 2024

  1. Merge pull request #400 from kcirred/main

    [embeddings] extend kwargs to functions that call _encode_with_retry
    gkumbhat authored Nov 7, 2024
    56b7e18

Showing with 90 additions and 8 deletions.
  1. +19 −7 caikit_nlp/modules/text_embedding/embedding.py
  2. +1 −1 pyproject.toml
  3. +70 −0 tests/modules/text_embedding/test_embedding.py
26 changes: 19 additions & 7 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -412,9 +412,7 @@ def run_embedding(

@EmbeddingTasks.taskmethod()
def run_embeddings(
self,
texts: List[str],
truncate_input_tokens: Optional[int] = 0,
self, texts: List[str], truncate_input_tokens: Optional[int] = 0, **kwargs
) -> EmbeddingResults:
"""Get embedding vectors for texts.
Args:
@@ -440,6 +438,7 @@ def run_embeddings(
texts,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
**kwargs,
)
vectors = [Vector1D.from_vector(e) for e in embeddings]

@@ -455,6 +454,7 @@ def run_sentence_similarity(
source_sentence: str,
sentences: List[str],
truncate_input_tokens: Optional[int] = 0,
**kwargs,
) -> SentenceSimilarityResult:
"""Get similarity scores for each of sentences compared to the source_sentence.
Args:
@@ -476,11 +476,13 @@ def run_sentence_similarity(
source_sentence,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
**kwargs,
)
embeddings, sentences_token_count = self._encode_with_retry(
sentences,
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
**kwargs,
)

input_token_count = source_token_count + sentences_token_count
@@ -547,6 +549,7 @@ def run_rerank_query(
return_documents: bool = True,
return_query: bool = True,
return_text: bool = True,
**kwargs,
) -> RerankResult:
"""Rerank the documents returning the most relevant top_n in order for this query.
Args:
@@ -598,6 +601,7 @@ def run_rerank_query(
return_documents=return_documents,
return_queries=return_query,
return_text=return_text,
**kwargs,
)

if results.results:
@@ -626,6 +630,7 @@ def run_rerank_queries(
return_documents: bool = True,
return_queries: bool = True,
return_text: bool = True,
**kwargs,
) -> RerankResults:
"""Rerank the documents returning the most relevant top_n in order for each of the queries.
Args:
@@ -690,6 +695,7 @@ def get_text(doc):
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
convert_to_tensor=True,
**kwargs,
)
doc_embeddings = normalize(doc_embeddings.to(self.model.device))

@@ -698,6 +704,7 @@ def get_text(doc):
truncate_input_tokens=truncate_input_tokens,
return_token_count=True,
convert_to_tensor=True,
**kwargs,
)
query_embeddings = normalize(query_embeddings.to(self.model.device))

@@ -976,6 +983,7 @@ def _tokenize_plus(
truncate_input_tokens: int,
texts: List[str],
implicit_truncation_errors: bool = True,
**kwargs,
) -> TruncatedTokensTuple:
"""Tokenize with support for truncation handling and returning the token count
Args:
@@ -1015,21 +1023,21 @@ def _tokenize_plus(
texts = [str(s).strip() for s in texts]

# Call tokenizer with the same truncation parameters every time
tokenized = self._get_tokenized(texts)
tokenized = self._get_tokenized(texts, **kwargs)

# Custom truncation and/or error raise if needed
truncation_needed = self._truncation_needed(tokenized, max_length, texts)
if truncation_needed and okay_to_truncate:
# Truncate texts in place
_truncate_texts(texts, tokenized, max_length, truncation_needed)
# Re-tokenize the truncated texts
tokenized = self._get_tokenized(texts)
tokenized = self._get_tokenized(texts, **kwargs)
truncation_needed = [] # truncation accomplished

input_token_count = sum_token_count(tokenized)
return TruncatedTokensTuple(tokenized, input_token_count, truncation_needed)

def _get_tokenized(self, texts):
def _get_tokenized(self, texts, **kwargs):
"""Intentionally always call tokenizer the same way to avoid thread issues.
Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID).
@@ -1039,6 +1047,8 @@ def _get_tokenized(self, texts):
the fast tokenizer with different truncation settings.
"""

padding_strategy = kwargs.pop("padding_strategy", True)

# Keep copies of tokenizer per thread (in each wrapped model instance)
thread_id = threading.get_ident()
tokenizer = (
@@ -1056,7 +1066,7 @@ def _get_tokenized(self, texts):
return_length=False,
return_tensors="pt",
truncation=True, # DO NOT CHANGE else "Already borrowed" errors
padding=True, # DO NOT CHANGE else "Already borrowed" errors
padding=padding_strategy, # DO NOT CHANGE else "Already borrowed" errors
max_length=self.max_seq_length, # DO NOT CHANGE else "Already borrowed" errors
)

@@ -1077,6 +1087,7 @@ def encode(
return_token_count: bool = False,
implicit_truncation_errors: bool = True,
autocast: bool = False,
**kwargs,
) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
"""
Computes sentence embeddings
@@ -1161,6 +1172,7 @@ def encode(
truncate_input_tokens,
sentences_batch,
implicit_truncation_errors=implicit_truncation_errors,
**kwargs,
)

if truncation_needed: # truncation was needed and was not done/not allowed
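
The net effect of the embedding.py changes, as a minimal usage sketch: extra keyword arguments passed to the task methods are forwarded down the chain run_embeddings -> _encode_with_retry -> encode -> _tokenize_plus -> _get_tokenized, where "padding_strategy" is popped (default True) and handed to the tokenizer as its padding argument. The import path, load call, and model path below are illustrative assumptions, not something established by this diff.

# Hedged sketch of the new kwargs path; names marked "illustrative" are assumptions.
from caikit_nlp.modules.text_embedding import EmbeddingModule  # assumed export path

model = EmbeddingModule.load("path/to/embedding/model")  # illustrative model path

# Unchanged default: padding_strategy falls back to True, i.e. pad to the
# longest sequence in the batch.
default_result = model.run_embeddings(texts=["first text", "second, longer text"])

# New: padding_strategy rides along in **kwargs and reaches the tokenizer as
# padding="max_length", padding every sequence out to the model's max_seq_length.
padded_result = model.run_embeddings(
    texts=["first text", "second, longer text"],
    padding_strategy="max_length",
)
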
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
"scipy>=1.8.1",
"sentence-transformers>=3.0.0,<3.1.0",
"tokenizers>=0.13.3",
"torch>=2.3.1,<2.5.0",
"torch>=2.3.1,<2.6.0",
"tqdm>=4.65.0",
"transformers>=4.38.0,<4.44.0",
"peft==0.6.0",
70 changes: 70 additions & 0 deletions tests/modules/text_embedding/test_embedding.py
@@ -2,12 +2,14 @@

# Standard
from typing import List, Tuple
from unittest.mock import patch
import os
import tempfile

# Third Party
from pytest import approx
from torch.backends import mps
from transformers import BatchEncoding
import numpy as np
import pytest
import torch
@@ -1143,3 +1145,71 @@ def test_same_same(loaded_model: EmbeddingModule, truncate_input_tokens):
assert not np.allclose(
separate_vectors[1], separate_vectors[2], rtol=1e-05, atol=1e-08
)


def custom_sum_token_count(
tokenized: BatchEncoding,
) -> int:
"""Returns total number of tokens regardless of attention_mask value"""

token_count = 0
for encoding in tokenized.encodings:
token_count += len(encoding.attention_mask)

return token_count


@pytest.mark.parametrize("padding_strategy", [True, "max_length"])
def test_pad_to_max_length(padding_strategy, loaded_model):
"""Tests for tokenization kwargs max_length will modify tokenizer"""
model_max = loaded_model.model.max_seq_length

tokenizer_kwargs = {"padding_strategy": padding_strategy}
max_seq = "x " * (model_max - 2) # Subtract 2 for begin/end tokens
max_seq_minus_one = "x " * (
model_max - 3
) # 1 token length shorter than max_seq_length
single = "x "

if padding_strategy is True:
normal_result = loaded_model._encode_with_retry(
[max_seq_minus_one], return_token_count=True
)
padded_result = loaded_model._encode_with_retry(
[max_seq_minus_one],
return_token_count=True,
**tokenizer_kwargs,
)
assert np.all(normal_result.embedding == padded_result.embedding)
elif padding_strategy == "max_length":
with patch(
"caikit_nlp.modules.text_embedding.embedding.sum_token_count"
) as mock_sum_token_count:
mock_sum_token_count.side_effect = custom_sum_token_count
normal_result = loaded_model._encode_with_retry(
[max_seq_minus_one], return_token_count=True
)
padded_result = loaded_model._encode_with_retry(
[max_seq_minus_one],
return_token_count=True,
**tokenizer_kwargs,
)
assert normal_result.input_token_count != padded_result.input_token_count
assert padded_result.input_token_count == model_max
assert not np.all(normal_result.embedding == padded_result.embedding)
normal_result = loaded_model._encode_with_retry(
[max_seq], return_token_count=True
)
padded_result = loaded_model._encode_with_retry(
[max_seq], return_token_count=True, **tokenizer_kwargs
)
assert normal_result.input_token_count == padded_result.input_token_count
assert np.all(normal_result.embedding == padded_result.embedding)
normal_result = loaded_model._encode_with_retry(
[single], return_token_count=True
)
padded_result = loaded_model._encode_with_retry(
[single], return_token_count=True, **tokenizer_kwargs
)
assert normal_result.input_token_count != padded_result.input_token_count
assert not np.all(normal_result.embedding == padded_result.embedding)
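
For context on what the test above exercises, here is a standalone illustration of padding=True versus padding="max_length" on a plain Hugging Face tokenizer; the model name is only an example and is not pinned by this repository.

# Hedged sketch: how the two padding settings differ at the tokenizer level.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")  # example model

batch_longest = tok(
    ["short text"],
    padding=True,  # pad to the longest sequence in the batch
    truncation=True,
    max_length=tok.model_max_length,
    return_tensors="pt",
)
batch_max = tok(
    ["short text"],
    padding="max_length",  # pad out to max_length regardless of input length
    truncation=True,
    max_length=tok.model_max_length,
    return_tensors="pt",
)

# For a short input, "max_length" padding yields a longer input_ids tensor with a
# mostly-zero attention_mask. With the test's mocked token counter (which counts
# every position), that is why input_token_count grows to the model max and the
# embeddings differ, while at exactly max length the two settings coincide.
print(batch_longest["input_ids"].shape, batch_max["input_ids"].shape)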