
Comparing changes

base repository: caikit/caikit-nlp
base: v0.5.6
head repository: caikit/caikit-nlp
compare: v0.5.7
  • 8 commits
  • 3 files changed
  • 2 contributors

Commits on Sep 11, 2024

  1. CrossEncoderModule with rerank API

    This module is closely related to EmbeddingModule.
    
    Cross-encoder models take Q and A pairs and are trained to return a relevance score for rank().
    The existing rerank APIs in EmbeddingModule had to encode Q and A separately and use
    cosine similarity as the score. The API is the same, but the results are expected to be
    better (at the cost of speed).
    
    Cross-encoder models do not support returning embedding vectors or sentence-similarity.
    
    Support for the existing tokenization and model_info endpoints was also added.
    
    Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
    markstur committed Sep 11, 2024
    SHA: 5b0989f
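
    For orientation, a minimal usage sketch (not from the PR itself) of the rerank API this
    module exposes; the model path is hypothetical, while the call signature and result shape
    follow the CrossEncoderModule code added in this PR:

        # Load a saved cross-encoder model and rerank documents for a query.
        from caikit_nlp.modules.text_embedding import CrossEncoderModule

        model = CrossEncoderModule.load("path/to/saved/model")  # hypothetical path
        result = model.run_rerank_query(
            query="What is foo bar?",
            documents=[{"text": "foo"}, {"_text": "bar"}, {"text": "foo and bar"}],
            top_n=2,
        )
        for score in result.result.scores:  # RerankResult -> RerankScores -> RerankScore
            print(score.index, score.score)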

Commits on Sep 12, 2024

  1. Cross-encoder improvements from code review

    * Mostly removing unnecessary code
    * Some improvements for clarity
    
    Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
    markstur committed Sep 12, 2024
    SHA: 7146ffe
  2. Cross-encoder docstring fix

    * The "already borrowed" errors are fixed by using a tokenizer copy per thread,
      so some comments warning against changing params for truncation were misleading
      (we do change them for cross-encoder truncation).
    
    Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
    markstur committed Sep 12, 2024
    SHA: ac46993
  3. Cross-encoder: use configurable batch size

    The default is 32. It can be overridden with the embedding batch_size config setting
    or the EMBEDDING_BATCH_SIZE env var (a short resolution sketch follows this commit list).
    
    Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
    markstur committed Sep 12, 2024
    SHA: 4e9c5aa
  4. Cross-encoder: Move truncation check and add tests

    * Moved the truncation check to a place that can determine
      the proper index for the error message (with batching).

    * Added a test to validate some results after truncation.
      This uses a tiny model, but it works as a sanity check.
    
    Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
    markstur committed Sep 12, 2024
    SHA: 211668a
  5. Cross-encoder: fix truncation test

    The part that actually tests that a token is truncated was wrong.

    * It was backwards, yet passing, because the scores are sorted by rank
    * Now the index is used to get scores in the order of the inputs
    * Now correctly asserts xx != xy but xy == xyz (z was truncated)
    
    Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
    markstur committed Sep 12, 2024
    SHA: 2cb6183
  6. Cross-encoder: remove some unused code and tidy up some comments

    Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
    markstur committed Sep 12, 2024
    SHA: 8fa67cc
  7. Merge pull request #389 from markstur/crossencoder

    CrossEncoderModule with rerank API
    evaline-ju authored Sep 12, 2024
    SHA: 1695c3b
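
A rough sketch of how the configurable batch size from commit 3 is resolved (mirroring the
__init__ in the CrossEncoderModule diff below; the mapping of the EMBEDDING_BATCH_SIZE env var
onto the embedding.batch_size config setting is as described in that commit message):

    from caikit import get_config

    embedding_cfg = get_config().get("embedding", {})
    batch_size = embedding_cfg.get("batch_size", 32)  # default is 32
    if batch_size <= 0:
        batch_size = 32  # zero or negative values fall back to the default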
Showing with 1,225 additions and 0 deletions.
  1. +1 −0 caikit_nlp/modules/text_embedding/__init__.py
  2. +706 −0 caikit_nlp/modules/text_embedding/crossencoder.py
  3. +518 −0 tests/modules/text_embedding/test_crossencoder.py
1 change: 1 addition & 0 deletions caikit_nlp/modules/text_embedding/__init__.py
@@ -29,4 +29,5 @@
"""

# Local
from .crossencoder import CrossEncoderModule
from .embedding import EmbeddingModule
706 changes: 706 additions & 0 deletions caikit_nlp/modules/text_embedding/crossencoder.py
@@ -0,0 +1,706 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from copy import deepcopy
from functools import partial
from typing import Any, Dict, List, NamedTuple, Optional, Union
import os
import threading

# Third Party
from sentence_transformers import CrossEncoder
from torch.utils.data import DataLoader
import numpy as np
import torch

# First Party
from caikit import get_config
from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module
from caikit.core.data_model.json_dict import JsonDict
from caikit.core.exceptions import error_handler
from caikit.interfaces.nlp.data_model import (
RerankResult,
RerankResults,
RerankScore,
RerankScores,
Token,
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import RerankTask, RerankTasks, TokenizationTask
import alog

# Local
from caikit_nlp.modules.text_embedding.utils import env_val_to_bool

logger = alog.use_channel("CROSS_ENCODER")
error = error_handler.get(logger)


class RerankResultTuple(NamedTuple):
"""Output of modified rank()"""

scores: list
input_token_count: int


class PredictResultTuple(NamedTuple):
"""Output of modified predict()"""

scores: np.ndarray
input_token_count: int


# pylint: disable=too-many-lines disable=duplicate-code
@module(
"1673f8f2-726f-48cb-93a1-540c81f0f3c9",
"CrossEncoderModule",
"0.0.1",
tasks=[
RerankTask,
RerankTasks,
TokenizationTask,
],
)
class CrossEncoderModule(ModuleBase):

_ARTIFACTS_PATH_KEY = "artifacts_path"
_ARTIFACTS_PATH_DEFAULT = "artifacts"

def __init__(
self,
model: "CrossEncoderWithTruncate",
):
super().__init__()
self.model = model

# model_max_length attribute availability might(?) vary by model/tokenizer
self.model_max_length = getattr(model.tokenizer, "model_max_length", None)

# Read config/env settings that are needed at run_* time.
embedding_cfg = get_config().get("embedding", {})

self.batch_size = embedding_cfg.get("batch_size", 32)
error.type_check("<NLP83501588E>", int, EMBEDDING_BATCH_SIZE=self.batch_size)
if self.batch_size <= 0:
self.batch_size = 32 # 0 or negative, use the default.

@classmethod
def load(
cls, model_path: Union[str, ModuleConfig], *args, **kwargs
) -> "CrossEncoderModule":
"""Load model
Args:
model_path (Union[str, ModuleConfig]): Path to saved model or
in-memory ModuleConfig
Returns:
CrossEncoderModule
Instance of this class built from the model.
"""

config = ModuleConfig.load(model_path)
error.dir_check("<NLP13823362E>", config.model_path)

artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY)
error.value_check(
"<NLP20896115E>",
artifacts_path,
f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'",
)

artifacts_path = os.path.abspath(
os.path.join(config.model_path, artifacts_path)
)
error.dir_check("<NLP33193321E>", artifacts_path)

# Read config/env settings that are needed at load time.
embedding_cfg = get_config().get("embedding", {})

trust_remote_code = env_val_to_bool(embedding_cfg.get("trust_remote_code"))

model = CrossEncoderWithTruncate(
model_name=artifacts_path,
trust_remote_code=trust_remote_code,
)
model.model.eval()
model.model.to(model._target_device)

return cls(model)

@property
def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument
"""Helper property to return public metadata about a specific Model. This
function is separate from `metadata` as that contains the entire ModelConfig
which might not want to be shared/exposed.
Returns:
Dict[str, str]: A dictionary of this model's public metadata
"""

return (
{"max_seq_length": cls.model_max_length}
if cls.model_max_length is not None
else {}
)

@TokenizationTask.taskmethod()
def run_tokenizer(
self,
text: str,
) -> TokenizationResults:
"""Run tokenization task against the model
Args:
text: str
Text to tokenize
Returns:
TokenizationResults
The token count
"""
result = self.model.get_tokenized([text], return_offsets_mapping=True)

mapping = [
interv for interv in result.offset_mapping[0] if (interv[1] - interv[0]) > 0
]
tokens = [Token(start=i[0], end=i[1], text=text[i[0] : i[1]]) for i in mapping]

return TokenizationResults(token_count=len(result.input_ids[0]), results=tokens)

@RerankTask.taskmethod()
def run_rerank_query(
self,
query: str,
documents: List[JsonDict],
top_n: Optional[int] = None,
truncate_input_tokens: Optional[int] = 0,
return_documents: bool = True,
return_query: bool = True,
return_text: bool = True,
) -> RerankResult:
"""Rerank the documents returning the most relevant top_n in order for this query.
Args:
query: str
Query is the source string to be compared to the text of the documents.
documents: List[JsonDict]
Each document is a dict. The text value is used for comparison to the query.
If there is no text key, then _text is used and finally default is "".
top_n: Optional[int]
Results for the top n most relevant documents will be returned.
If top_n is not provided or (not > 0), then all are returned.
truncate_input_tokens: int
Truncation length for input tokens.
If less than zero, this is disabled (returns texts without processing).
If zero or greater than the model's maximum, then this is a test
to see if truncation is needed. If needed, an exception is thrown.
Otherwise, we take this usable truncation limit to truncate the tokens and then
decode them to return truncated strings that can be used with this model.
return_documents: bool
Default True
Setting to False will disable returning of the input document (index is returned).
return_query: bool
Default True
Setting to False will disable returning of the query (results are in query order)
return_text: bool
Default True
Setting to False will disable returning of document text string that was used.
Returns:
RerankResult
Returns the (top_n) scores in relevance order (most relevant first).
The results always include a score and index which may be used to find the document
in the original documents list. Optionally, the results also contain the entire
document with its score (for use in chaining) and for convenience the query and
text used for comparison may be returned.
"""

error.type_check(
"<NLP61983803E>",
int,
allow_none=True,
top_n=top_n,
)

error.type_check(
"<NLP05323654E>",
str,
query=query,
)

results = self.run_rerank_queries(
queries=[query],
documents=documents,
top_n=top_n,
truncate_input_tokens=truncate_input_tokens,
return_documents=return_documents,
return_queries=return_query,
return_text=return_text,
)

return RerankResult(
result=results.results[0],
producer_id=self.PRODUCER_ID,
input_token_count=results.input_token_count,
)

@RerankTasks.taskmethod()
def run_rerank_queries(
self,
queries: List[str],
documents: List[JsonDict],
top_n: Optional[int] = None,
truncate_input_tokens: Optional[int] = 0,
return_documents: bool = True,
return_queries: bool = True,
return_text: bool = True,
) -> RerankResults:
"""Rerank the documents returning the most relevant top_n in order for each of the queries.
Args:
queries: List[str]
Each of the queries will be compared to the text of each of the documents.
documents: List[JsonDict]
Each document is a dict. The text value is used for comparison to the query.
If there is no text key, then _text is used and finally default is "".
top_n: Optional[int]
Results for the top n most relevant documents will be returned.
If top_n is not provided or (not > 0), then all are returned.
truncate_input_tokens: int
Truncation length for input tokens.
If less than zero, this is disabled (returns texts without processing).
If zero or greater than the model's maximum, then this is a test
to see if truncation is needed. If needed, an exception is thrown.
Otherwise, we take this usable truncation limit to truncate the tokens and then
decode them to return truncated strings that can be used with this model.
return_documents: bool
Default True
Setting to False will disable returning of the input document (index is returned).
return_queries: bool
Default True
Setting to False will disable returning of the query (results are in query order)
return_text: bool
Default True
Setting to False will disable returning of document text string that was used.
Returns:
RerankResults
For each query in queries (in the original order)...
Returns the (top_n) scores in relevance order (most relevant first).
The results always include a score and index which may be used to find the document
in the original documents list. Optionally, the results also contain the entire
document with its score (for use in chaining) and for convenience the query and
text used for comparison may be returned.
"""

error.type_check(
"<NLP09038249E>",
list,
queries=queries,
documents=documents,
)

error.value_check(
"<NLP24788937E>",
queries and documents,
"Cannot rerank without a query and at least one document",
)

if top_n is None or top_n < 1:
top_n = len(documents)

# Using input document dicts so get "text" else "_text" else default to ""
def get_text(doc):
return doc.get("text") or doc.get("_text", "")

doc_texts = [get_text(doc) for doc in documents]

input_token_count = 0
results = []
for query in queries:
scores, token_count = self.model.rank(
query=query,
documents=doc_texts,
top_k=top_n,
return_documents=False,
batch_size=self.batch_size,
convert_to_numpy=True,
truncate_input_tokens=truncate_input_tokens,
)
results.append(scores)
input_token_count += token_count

# Fixup result dicts
for r in results:
for x in r:
x["score"] = float(x["score"].item())
# Renaming corpus_id to index
corpus_id = x.pop("corpus_id")
x["index"] = corpus_id
# Optionally adding the original document and/or just the text that was used
if return_documents:
x["document"] = documents[corpus_id]
if return_text:
x["text"] = get_text(documents[corpus_id])

def add_query(q):
return queries[q] if return_queries else None

results = [
RerankScores(
query=add_query(q),
scores=[RerankScore(**x) for x in r],
)
for q, r in enumerate(results)
]

return RerankResults(
results=results,
producer_id=self.PRODUCER_ID,
input_token_count=input_token_count,
)

@classmethod
def bootstrap(cls, *args, **kwargs) -> "CrossEncoderModule":
"""Bootstrap a cross-encoder model
Args:
args/kwargs are passed to CrossEncoder
"""

# Add ability to bootstrap with trust_remote_code using env var.
if "trust_remote_code" not in kwargs:
# Read config/env settings that are needed at bootstrap time.
embedding_cfg = get_config().get("embedding", {})
kwargs["trust_remote_code"] = env_val_to_bool(
embedding_cfg.get("trust_remote_code")
)

return cls(model=CrossEncoder(*args, **kwargs))

def save(self, model_path: str, *args, **kwargs):
"""Save model using config in model_path
Args:
model_path: str
Path to model config
"""

error.type_check("<NLP82314992E>", str, model_path=model_path)
model_config_path = model_path.strip()
error.value_check(
"<NLP40145207E>",
model_config_path,
f"model_path '{model_config_path}' is invalid",
)

model_config_path = os.path.abspath(
model_config_path.strip()
) # No leading/trailing spaces sneaky weirdness

# Only allow new dirs because there are not enough controls to safely update in-place
os.makedirs(model_config_path, exist_ok=False)

saver = ModuleSaver(
module=self,
model_path=model_config_path,
)
artifacts_path = self._ARTIFACTS_PATH_DEFAULT
saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path})

# Save the model
self.model.save(os.path.join(model_config_path, artifacts_path))

# Save the config
ModuleConfig(saver.config).save(model_config_path)


class CrossEncoderWithTruncate(CrossEncoder):
def __init__(
self,
model_name: str,
num_labels: int = None,
max_length: int = None,
device: str = None,
tokenizer_args: Dict = None,
automodel_args: Dict = None,
trust_remote_code: bool = False,
revision: Optional[str] = None,
local_files_only: bool = False,
default_activation_function=None,
classifier_dropout: float = None,
):
super().__init__(
model_name,
num_labels,
max_length,
device,
tokenizer_args,
automodel_args,
trust_remote_code,
revision,
local_files_only,
default_activation_function,
classifier_dropout,
)
self.tokenizers = {}

def _get_tokenizer_per_thread(self):
"""Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID)."""

# Keep copies of tokenizer per thread (in each wrapped model instance)
thread_id = threading.get_ident()
tokenizer = (
self.tokenizers[thread_id]
if thread_id in self.tokenizers
else self.tokenizers.setdefault(thread_id, deepcopy(self.tokenizer))
)

return tokenizer

def get_tokenized(self, texts, **kwargs):
"""Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID)"""

max_len = kwargs.get("truncate_input_tokens", self.tokenizer.model_max_length)
max_len = min(max_len, self.tokenizer.model_max_length)
if max_len <= 0:
max_len = None # Use the default
elif max_len < 5:
# 1, 2, 3 don't really work (4 might but...)
# Bare minimum is [CLS] token [SEP] token [SEP]
max_len = 5

tokenizer = self._get_tokenizer_per_thread()
tokenized = tokenizer(
*texts,
return_attention_mask=True, # Used for determining token count
return_token_type_ids=False, # Needed for cross-encoders
return_overflowing_tokens=False, # DO NOT USE: overflow tokens break sentence batches
return_offsets_mapping=True, # Used for truncation needed error
return_length=False,
return_tensors="pt",
truncation=True,
padding=True,
max_length=max_len,
)
return tokenized

def _truncation_needed(self, encoding, texts):
"""Check for truncation needed to meet max_length token limit
Returns:
True if the text was truncated, False otherwise
"""

input_tokens = sum(encoding.attention_mask)
if input_tokens < self.tokenizer.model_max_length:
return False

# At model limit, including start/end...
# This may or may not have already been truncated at the model limit.
# Check the strlen and last offset.
# We need to know this for the default behavior of raising an error.
offsets = encoding.offsets
type_ids = encoding.type_ids
attn_mask = encoding.attention_mask

# Find the last offset by counting attn masks
# and keeping the last non-zero offset end.
index = 0 # index of longest
type_id = 0 # track type_id of longest

for n, attn in enumerate(attn_mask):
if attn == 1:
end = offsets[n][1] # Index to end character from offset
if end > index: # Grab last non-zero end index (ensures increasing too)
type_id = type_ids[n]
index = end
end_index = index # longest last char index
end_typeid = type_id # longest type (query or text)

# If last token offset is before the last char, then it was truncated
return end_index < len(texts[end_typeid].strip())

def smart_batching_collate_text_only(
self, batch, truncate_input_tokens: Optional[int] = 0
):
texts = [[] for _ in range(len(batch[0]))]

for example in batch:
for idx, text in enumerate(example):
texts[idx].append(text.strip())

tokenized = self.get_tokenized(
texts, truncate_input_tokens=truncate_input_tokens
)

return tokenized

@staticmethod
def raise_truncation_error(max_len, truncation_needed_indexes):

indexes = f"{', '.join(str(i) for i in truncation_needed_indexes)}."
index_hint = (
" for text at "
f"{'index' if len(truncation_needed_indexes) == 1 else 'indexes'}: {indexes}"
)
error.log_raise(
"<NLP08391926E>",
ValueError(
f"Token sequence length (+3 for separators) exceeds the "
f"maximum sequence length for this model ({max_len})"
f"{index_hint}"
),
)

def predict(
self,
sentences: List[List[str]],
batch_size: int = 32,
show_progress_bar: bool = None,
num_workers: int = 0,
activation_fct=None,
apply_softmax=False,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
truncate_input_tokens: Optional[int] = 0,
) -> PredictResultTuple:
"""
Performs predictions with the CrossEncoder on the given sentence pairs.
Args:
See the overridden method for details.
truncate_input_tokens: Optional[int] = 0 added for truncation
Returns:
Uses PredictResultTuple to add input_token_count
"""
input_was_string = False
if isinstance(
sentences[0], str
): # Cast an individual sentence to a list with length 1
sentences = [sentences]
input_was_string = True

collate_fn = partial(
self.smart_batching_collate_text_only,
truncate_input_tokens=truncate_input_tokens,
)
iterator = DataLoader(
sentences,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=num_workers,
shuffle=False,
)

if activation_fct is None:
activation_fct = self.default_activation_function

max_len = self.tokenizer.model_max_length
pred_scores = []
input_token_count = 0
row = -1
truncation_needed_indexes = []
with torch.no_grad():
for features in iterator:
# Sum the length of all encodings for all samples
for encoding in features.encodings:
row += 1

input_token_count += sum(encoding.attention_mask)

if truncate_input_tokens == 0 or truncate_input_tokens > max_len:
# default (for zero or over max) is to error on truncation
if self._truncation_needed(encoding, sentences[row]):
truncation_needed_indexes.append(row)

if truncation_needed_indexes:
self.raise_truncation_error(max_len, truncation_needed_indexes)

# We cannot send offset_mapping to the model with features,
# but we needed offset_mapping for other uses.
if "offset_mapping" in features:
del features["offset_mapping"]

for name in features:
features[name] = features[name].to(self._target_device)

model_predictions = self.model(**features, return_dict=True)
logits = activation_fct(model_predictions.logits)

if apply_softmax and len(logits[0]) > 1:
logits = torch.nn.functional.softmax(logits, dim=1)
pred_scores.extend(logits)

if self.config.num_labels == 1:
pred_scores = [score[0] for score in pred_scores]

if convert_to_tensor:
pred_scores = torch.stack(pred_scores)
elif convert_to_numpy:
pred_scores = np.asarray(
[score.cpu().detach().float().item() for score in pred_scores]
)

if input_was_string:
pred_scores = pred_scores[0]

return PredictResultTuple(pred_scores, input_token_count)

def rank(
self,
query: str,
documents: List[str],
top_k: Optional[int] = None,
return_documents: bool = False,
batch_size: int = 32,
show_progress_bar: bool = None,
num_workers: int = 0,
activation_fct=None,
apply_softmax=False,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
truncate_input_tokens: Optional[int] = 0,
) -> RerankResultTuple:
"""
Performs ranking with the CrossEncoder on the given query and documents.
Returns a sorted list with the document indices and scores.
Args:
See overridden method for argument description.
truncate_input_tokens (int, optional): Added to support truncation.
Returns:
RerankResultTuple: Adds input_token_count to result
"""
query_doc_pairs = [[query, doc] for doc in documents]
scores, input_token_count = self.predict(
query_doc_pairs,
batch_size=batch_size,
show_progress_bar=show_progress_bar,
num_workers=num_workers,
activation_fct=activation_fct,
apply_softmax=apply_softmax,
convert_to_numpy=convert_to_numpy,
convert_to_tensor=convert_to_tensor,
truncate_input_tokens=truncate_input_tokens,
)
results = []
for i, score in enumerate(scores):
if return_documents:
results.append({"corpus_id": i, "score": score, "text": documents[i]})
else:
results.append({"corpus_id": i, "score": score})

results = sorted(results, key=lambda x: x["score"], reverse=True)
return RerankResultTuple(results[:top_k], input_token_count)
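
To summarize the truncate_input_tokens behavior implemented above, a short sketch based on
the docstrings and the tests that follow ("model" stands for a loaded CrossEncoderModule;
the texts are placeholders):

    queries = ["q"]
    docs = [{"text": "x " * 1000}]  # longer than the model's maximum sequence length

    # 0 (the default), or any value above the model maximum: raise a ValueError
    # if truncation would be needed, reporting the offending input indexes.
    # model.run_rerank_queries(queries=queries, documents=docs, truncate_input_tokens=0)

    # Negative values: let the underlying model truncate silently at its own maximum.
    model.run_rerank_queries(queries=queries, documents=docs, truncate_input_tokens=-1)

    # 1..model_max: truncate to that limit before scoring.
    model.run_rerank_queries(queries=queries, documents=docs, truncate_input_tokens=99)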
518 changes: 518 additions & 0 deletions tests/modules/text_embedding/test_crossencoder.py
@@ -0,0 +1,518 @@
"""Tests for CrossEncoderModule"""

# Standard
from typing import List
import os
import tempfile

# Third Party
from pytest import approx
import numpy as np
import pytest

# First Party
from caikit.interfaces.nlp.data_model import (
RerankResult,
RerankResults,
RerankScore,
RerankScores,
Token,
TokenizationResults,
)

# Local
from caikit_nlp.modules.text_embedding import CrossEncoderModule
from tests.fixtures import SEQ_CLASS_MODEL

## Setup ########################################################################

# Bootstrapped sequence classification model for reuse across tests
# .bootstrap is tested separately in the first test
# This model needs a tweak (num_labels = 1) to behave like a cross-encoder.
BOOTSTRAPPED_MODEL = CrossEncoderModule.bootstrap(SEQ_CLASS_MODEL)

# Token counts:
# All expected token counts were calculated with reference to the
# `BertForSequenceClassification` model. Each model's tokenizer behaves differently
# which can lead to the expected token counts being invalid.

INPUT = "The quick brown fox jumps over the lazy dog."
INPUT_TOKEN_COUNT = 36 + 2 # [CLS] Thequickbrownfoxjumpsoverthelazydog. [SEP]

QUERY = "What is foo bar?"
QUERY_TOKEN_COUNT = 13 + 2 # [CLS] Whatisfoobar? [SEP]

QUERIES: List[str] = [
"Who is foo?",
"Where is the bar?",
]
QUERIES_TOKEN_COUNT = (9 + 2) + (
14 + 2
) # [CLS] Whoisfoo? [SEP], [CLS] Whereisthebar? [SEP]

DOCS = [
{
"text": "foo",
"title": "title or whatever",
"str_test": "test string",
"int_test": 1,
"float_test": 1.234,
"score": 99999,
"nested_dict_test": {"deep1": 1, "deep string": "just testing"},
},
{
"_text": "bar",
"title": "title 2",
},
{
"text": "foo and bar",
},
{
"_text": "Where is the bar",
"another": "something else",
},
]

# The `text` and `_text` keys are extracted from DOCS as input to the tokenizer
# [CLS] foo [SEP], [CLS] bar [SEP], [CLS] fooandbar [SEP], [CLS] Whereisthebar [SEP]
DOCS_TOKEN_COUNT = (3 + 2) + (3 + 2) + (9 + 2) + (13 + 2)

# [CLS] query [SEP] text [SEP] for each text in DOCS.
# Subtract one from QUERY_TOKEN_COUNT to avoid counting
# an extra [SEP].
QUERY_DOCS_TOKENS = (QUERY_TOKEN_COUNT - 1) * len(DOCS) + DOCS_TOKEN_COUNT

# [CLS] query [SEP] text [SEP] for each QUERY for each text in DOCS.
# Subtract len(QUERIES) from QUERY_TOKEN_COUNT to avoid counting
# an extra [SEP].
QUERIES_DOCS_TOKENS = (QUERIES_TOKEN_COUNT - len(QUERIES)) * len(DOCS) + (
DOCS_TOKEN_COUNT * len(QUERIES)
)


## Tests ########################################################################


@pytest.fixture(scope="module", name="loaded_model")
def fixture_loaded_model(tmp_path_factory):
models_dir = tmp_path_factory.mktemp("models")
model_path = str(models_dir / "model_id")
BOOTSTRAPPED_MODEL.save(model_path)
model = CrossEncoderModule.load(model_path)
# Make our tiny test model act more like a cross-encoder model with 1 label
model.model.config.num_labels = 1
return model


def _assert_is_expected_scores(rerank_scores):
# Just testing a few values for readability
assert isinstance(rerank_scores, RerankScores)
scores = rerank_scores.scores
assert approx(scores[0].score) == -0.015608355402946472
assert approx(scores[1].score) == -0.015612606890499592
assert approx(scores[2].score) == -0.015648163855075836


def _assert_is_expected_rerank_result(actual):
assert isinstance(actual, RerankResult)
scores = actual.result
_assert_is_expected_scores(scores)


def _assert_is_expected_rerank_results(actual):
assert isinstance(actual, RerankResults)


def test_bootstrap():
assert isinstance(
CrossEncoderModule.bootstrap(SEQ_CLASS_MODEL), CrossEncoderModule
), "bootstrap error"


def _assert_valid_scores(scores):
for score in scores:
assert isinstance(score, RerankScore)
assert isinstance(score.score, float)
assert isinstance(score.index, int)
assert isinstance(score.text, str)

document = score.document
assert isinstance(document, dict)
assert document == DOCS[score.index]

# Test that a document key named score (None or 99999) is independent of the result score
assert score.score != document.get(
"score"
), "unexpected passthru score same as result score"


def test_bootstrap_model(loaded_model):
assert isinstance(BOOTSTRAPPED_MODEL, CrossEncoderModule), "bootstrap model type"
assert (
BOOTSTRAPPED_MODEL.model.__class__.__name__ == "CrossEncoder"
), "bootstrap model class name"
# worth noting that bootstrap does not wrap, but load does
assert (
loaded_model.model.__class__.__name__ == "CrossEncoderWithTruncate"
), "loaded model class name"


def test_save_load_and_run():
"""Check if we can load and run a saved model successfully"""
model_id = "model_id"
with tempfile.TemporaryDirectory(suffix="-xe-1st") as model_dir:
model_path = os.path.join(model_dir, model_id)
BOOTSTRAPPED_MODEL.save(model_path)
new_model = CrossEncoderModule.load(model_path)

assert isinstance(new_model, CrossEncoderModule), "save and load error"
assert new_model != BOOTSTRAPPED_MODEL, "did not load a new model"

# Make our tiny test model act more like a cross-encoder model
new_model.model.config.num_labels = 1

# Use run_rerank_query just to make sure this new model is usable
top_n = 3
rerank_result = new_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n)

assert isinstance(rerank_result, RerankResult)

result = rerank_result.result
assert isinstance(result, RerankScores)
scores = result.scores
assert isinstance(scores, list)
assert len(scores) == top_n

_assert_valid_scores(scores)

assert rerank_result.input_token_count == QUERY_DOCS_TOKENS
_assert_is_expected_rerank_result(rerank_result)
rerank_results = new_model.run_rerank_queries(
queries=QUERIES, documents=DOCS, top_n=1
)
_assert_is_expected_rerank_results(rerank_results)


def test_public_model_info():
"""Check if we can get model info successfully"""
model_id = "model_id"
with tempfile.TemporaryDirectory(suffix="-xe-mi") as model_dir:
model_path = os.path.join(model_dir, model_id)
BOOTSTRAPPED_MODEL.save(model_path)
new_model = CrossEncoderModule.load(model_path)

result = new_model.public_model_info
assert "max_seq_length" in result
assert type(result["max_seq_length"]) is int
assert new_model.model.tokenizer.model_max_length == 512
assert result["max_seq_length"] == new_model.model.tokenizer.model_max_length

# We only have the following key(s) in model_info right now for cross-encoders...
assert list(result.keys()) == ["max_seq_length"]


def test_run_tokenization(loaded_model):
res = loaded_model.run_tokenizer(text=INPUT)
assert isinstance(res, TokenizationResults)
assert isinstance(res.results, list)
assert isinstance(res.results[0], Token)
assert res.token_count == INPUT_TOKEN_COUNT


@pytest.mark.parametrize(
"query,docs,top_n",
[
(["test list"], DOCS, None),
(None, DOCS, 1234),
(False, DOCS, 1234),
(QUERY, {"testdict": "not list"}, 1234),
(QUERY, DOCS, "topN string is not an integer or None"),
],
)
def test_run_rerank_query_type_error(query, docs, top_n, loaded_model):
"""test for type checks matching task/run signature"""
match = r"type check failed"
with pytest.raises(TypeError, match=match):
loaded_model.run_rerank_query(query=query, documents=docs, top_n=top_n)
pytest.fail("Should not reach here.")


@pytest.mark.parametrize("top_n", [1, 99, None])
def test_run_rerank_query_no_type_error(loaded_model, top_n):
"""no type error with list of string queries and list of dict documents"""
res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n)

# [CLS] query [SEP] text [SEP] for each text in DOCS.
# Subtract one from QUERY_TOKEN_COUNT to avoid counting
# an extra [SEP].
q_tokens = (QUERY_TOKEN_COUNT - 1) * len(DOCS)
expected = q_tokens + DOCS_TOKEN_COUNT
assert res.input_token_count == expected


@pytest.mark.parametrize(
"top_n, expected",
[
(1, 1),
(2, 2),
(None, len(DOCS)),
(-1, len(DOCS)),
(0, len(DOCS)),
(9999, len(DOCS)),
],
)
def test_run_rerank_query_top_n(top_n, expected, loaded_model):
res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n)
assert isinstance(res, RerankResult)
assert len(res.result.scores) == expected
assert res.input_token_count == QUERY_DOCS_TOKENS


def test_run_rerank_query_no_query(loaded_model):
with pytest.raises(TypeError):
loaded_model.run_rerank_query(query=None, documents=DOCS, top_n=99)


def test_run_rerank_query_zero_docs(loaded_model):
"""No empty doc list therefore result is zero result scores"""
with pytest.raises(ValueError):
loaded_model.run_rerank_query(query=QUERY, documents=[], top_n=99)


def test_run_rerank_query(loaded_model):
res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS)
assert isinstance(res, RerankResult)

scores = res.result.scores
assert isinstance(scores, list)
assert len(scores) == len(DOCS)

_assert_valid_scores(scores)
assert res.input_token_count == QUERY_DOCS_TOKENS


@pytest.mark.parametrize(
"queries,docs", [("test string", DOCS), (QUERIES, {"testdict": "not list"})]
)
def test_run_rerank_queries_type_error(queries, docs, loaded_model):
"""type error check ensures params are lists and not just 1 string or just one doc (for example)"""
with pytest.raises(TypeError):
loaded_model.run_rerank_queries(queries=queries, documents=docs)
pytest.fail("Should not reach here.")


def test_run_rerank_queries_no_type_error(loaded_model):
"""no type error with list of string queries and list of dict documents"""
res = loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99)

assert res.input_token_count == QUERIES_DOCS_TOKENS


@pytest.mark.parametrize(
"top_n, expected",
[
(1, 1),
(2, 2),
(None, len(DOCS)),
(-1, len(DOCS)),
(0, len(DOCS)),
(9999, len(DOCS)),
],
)
def test_run_rerank_queries_top_n(top_n, expected, loaded_model):
"""no type error with list of string queries and list of dict documents"""
res = loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=top_n)
assert isinstance(res, RerankResults)
assert len(res.results) == len(QUERIES)
for result in res.results:
assert len(result.scores) == expected
assert res.input_token_count == QUERIES_DOCS_TOKENS


@pytest.mark.parametrize(
"queries, docs",
[
([], DOCS),
(QUERIES, []),
([], []),
],
ids=["no queries", "no docs", "no queries and no docs"],
)
def test_run_rerank_queries_no_queries_or_no_docs(queries, docs, loaded_model):
"""No queries and/or no docs therefore result is zero results"""

with pytest.raises(ValueError):
loaded_model.run_rerank_queries(queries=queries, documents=docs, top_n=9)


def test_run_rerank_queries(loaded_model):
top_n = 2
rerank_result = loaded_model.run_rerank_queries(
queries=QUERIES, documents=DOCS, top_n=top_n
)
assert isinstance(rerank_result, RerankResults)

results = rerank_result.results
assert isinstance(results, list)
assert len(results) == 2 == len(QUERIES) # 2 queries yields 2 result(s)

for result in results:
assert isinstance(result, RerankScores)
scores = result.scores
assert isinstance(scores, list)
assert len(scores) == top_n
_assert_valid_scores(scores)

assert rerank_result.input_token_count == QUERIES_DOCS_TOKENS


@pytest.mark.parametrize("truncate_input_tokens", [-1, 512])
def test_truncate_input_tokens_default(truncate_input_tokens, loaded_model):
"""Test truncation using model max.
-1 means let the model truncate at its model max
512 is more explicitly the same thing (this model's max)
"""
model_max = loaded_model.model.tokenizer.model_max_length

too_long = "x " * (model_max - 3) # 3 for tokens (no room for a query token)
just_barely = "x " * (model_max - 4) # 3 for tokens plus room for a query token
queries = ["x"]
docs = [{"text": t} for t in ["x", too_long, just_barely, too_long, just_barely]]

# Just testing for no errors raised for now
_res = loaded_model.run_rerank_queries(
queries=queries, documents=docs, truncate_input_tokens=truncate_input_tokens
)


@pytest.mark.parametrize("truncate_input_tokens", [0, 513])
def test_truncate_input_tokens_errors(truncate_input_tokens, loaded_model):
"""Test that we get truncation errors.
0 (the default) means we return errors when truncation would happen.
513+ (any number above the max) is treated the same way.
"""
model_max = loaded_model.model.tokenizer.model_max_length

too_long = "a " * (model_max - 3) # 3 for tokens (no room for a query token)
just_barely = "a " * (model_max - 4) # 3 for tokens plus room for a query token
queries = ["q"]

# Add 50 of these little ones to get past the first batch(es)
# to verify that this error message index is for the input
# position and not just an index into some internal batch.
docs = [{"text": "a"}] * 50
docs.extend([{"text": t} for t in [too_long, just_barely, too_long, just_barely]])

match1 = rf"exceeds the maximum sequence length for this model \({model_max}\) for text at indexes: 50, 52."
with pytest.raises(ValueError, match=match1):
loaded_model.run_rerank_queries(
queries=queries, documents=docs, truncate_input_tokens=truncate_input_tokens
)


@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 510, 511, 512])
def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model):
"""truncate_input_tokens prevents these endpoints from raising an error when too many tokens.
Test with -1 which lets the model do truncation instead of raising an error.
Test with 99 (< 512 -2) which causes our code to do the truncation instead of raising an error.
Test with 510 (512 -2) which causes our code to do the truncation instead of raising an error.
511 and 512 also behave like 510. The value is allowed, but begin/end tokens will take space.
"""

model_max = loaded_model.model.tokenizer.model_max_length

ok = "x " * (model_max - 2) # Subtract 2 for begin/end tokens
too_long = "x " * (model_max - 1) # This will go over

# reranker test both query and document text
loaded_model.run_rerank_query(
query=too_long,
documents=[{"text": ok}],
truncate_input_tokens=truncate_input_tokens,
)
loaded_model.run_rerank_query(
query=ok,
documents=[{"text": too_long}],
truncate_input_tokens=truncate_input_tokens,
)

loaded_model.run_rerank_queries(
queries=[too_long],
documents=[{"text": ok}],
truncate_input_tokens=truncate_input_tokens,
)
loaded_model.run_rerank_queries(
queries=[ok],
documents=[{"text": too_long}],
truncate_input_tokens=truncate_input_tokens,
)


@pytest.mark.parametrize(
"truncate_input_tokens", [1, 2, 3, 4, 5, 6, 99, 100, 101, 510, 511, 512, -1]
)
def test_truncation(truncate_input_tokens, loaded_model):
"""verify that results are as expected with truncation"""

max_len = loaded_model.model.tokenizer.model_max_length

if truncate_input_tokens is None or truncate_input_tokens < 0:
# For -1 we don't truncate, but model will
repeat = max_len
else:
repeat = min(
truncate_input_tokens, max_len
) # max_len is used when we need -4 for begin/"q"/sep/end

# Build a text like "x x x.. x " with room for one more token
repeat = repeat - 4 # room for separators and a single-token query
repeat = repeat - 1 # space for the final x or y token to show difference

base = ""
if repeat > 0:
base = "x " * repeat # A bunch of "x" tokens
x = base + "x" # One last "x" that will not get truncated
y = base + "y" # A different last character "y" not truncated
z = y + " z" # Add token "z" after "y". This should get truncated.

# Multiple queries to test query-loop vs queries
# Query for the significant added chars to affect score.
queries = ["y", "z"]
docs = [{"text": t} for t in [x, y, z]]
res = loaded_model.run_rerank_queries(
queries=queries,
documents=docs,
truncate_input_tokens=truncate_input_tokens,
)
queries_results = res.results

# Compare with results from individual embedding calls in a loop
query_results = []
for query in queries:
r = loaded_model.run_rerank_query(
query=query,
documents=docs,
truncate_input_tokens=truncate_input_tokens,
)
query_results.append(r.result)

assert len(queries_results) == len(
query_results
), "expected the same length results"

# compare the scores (queries call vs query call in a loop)
# order is the same
for i, r in enumerate(queries_results):
queries_scores = [x.score for x in r.scores]
query_scores = [x.score for x in query_results[i].scores]
assert np.array_equal(queries_scores, query_scores)

# To compare scores based on the inputs, we need to use the index too
indexed_query_scores = {s.index: s.score for s in query_results[i].scores}

# Make sure the x...xx, x...xy are not a match (we kept the significant last token)
assert indexed_query_scores[0] != indexed_query_scores[1]

# x...xy is the same as x...xyz because we truncated the z token -- it worked!
assert indexed_query_scores[1] == indexed_query_scores[2]