diff --git a/.azure/gpu-pipeline.yml b/.azure/gpu-pipeline.yml index dfa1d53b916..320b3000fc4 100644 --- a/.azure/gpu-pipeline.yml +++ b/.azure/gpu-pipeline.yml @@ -37,7 +37,7 @@ jobs: container: image: "$(docker-image)" - options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --name ci-container -v /usr/bin/docker:/tmp/docker:ro" + options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --shm-size=8g --name ci-container -v /usr/bin/docker:/tmp/docker:ro" workspace: clean: all diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index 7fd14a61cdb..fb2611973f0 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -8,7 +8,7 @@ jobs: test-docs: runs-on: ubuntu-20.04 - timeout-minutes: 15 + timeout-minutes: 20 steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 diff --git a/CHANGELOG.md b/CHANGELOG.md index 89f43ff42be..df9b3ede696 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `normalize` argument to `Inception`, `FID`, `KID` metrics ([#1246](https://github.com/Lightning-AI/metrics/pull/1246)) +- Added `CLIPScore` to new multimodal package ([#1314](https://github.com/Lightning-AI/metrics/pull/1314)) + + ### Changed - Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259)) @@ -52,7 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- ## [0.10.3] - 2022-11-16 diff --git a/docs/source/index.rst b/docs/source/index.rst index c54b21d3d39..896fbfb69bd 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -174,6 +174,14 @@ Or directly from conda image/* +.. toctree:: + :maxdepth: 2 + :name: multimodal + :caption: Multimodal + :glob: + + multimodal/* + .. toctree:: :maxdepth: 2 :name: nominal diff --git a/docs/source/links.rst b/docs/source/links.rst index b6aa3e8cb61..69d05991183 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -99,3 +99,5 @@ .. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf .. _Tschuprow's T: https://en.wikipedia.org/wiki/Tschuprow%27s_T .. _Pearson's Contingency Coefficient: https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/pearcont.htm +.. _CLIP score: https://arxiv.org/pdf/2104.08718.pdf +.. _Huggingface OpenAI: https://huggingface.co/openai diff --git a/docs/source/multimodal/clip_score.rst b/docs/source/multimodal/clip_score.rst new file mode 100644 index 00000000000..3f45ed441b6 --- /dev/null +++ b/docs/source/multimodal/clip_score.rst @@ -0,0 +1,20 @@ +.. customcarditem:: + :header: CLIP Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Image + +############################################ +CLIP Score +############################################ + +Module Interface +________________ + +.. autoclass:: torchmetrics.multimodal.clip_score.CLIPScore + :noindex: + +Functional Interface +____________________ + +.. 
autofunction:: torchmetrics.functional.multimodal.clip_score.clip_score + :noindex: diff --git a/requirements/devel.txt b/requirements/devel.txt index 0c256bc83c8..4b1b325dece 100644 --- a/requirements/devel.txt +++ b/requirements/devel.txt @@ -9,6 +9,7 @@ -r text.txt # -r detection.txt # version collision with min versio of PyTorch -r audio.txt +-r multimodal.txt # add extra testing -r image_test.txt diff --git a/requirements/multimodal.txt b/requirements/multimodal.txt new file mode 100644 index 00000000000..89029eeeb33 --- /dev/null +++ b/requirements/multimodal.txt @@ -0,0 +1 @@ +transformers>=4.10.0 diff --git a/requirements/text_test.txt b/requirements/text_test.txt index a4fff8a013a..2046de7363c 100644 --- a/requirements/text_test.txt +++ b/requirements/text_test.txt @@ -1,6 +1,6 @@ jiwer>=2.3.0 rouge-score>=0.0.4 bert_score==0.3.10 -transformers>=4.4.0 +transformers>4.4.0 huggingface-hub<0.7 # hotfix, failing SDR for latest PT 1.11 sacrebleu>=2.0.0 diff --git a/src/torchmetrics/functional/multimodal/__init__.py b/src/torchmetrics/functional/multimodal/__init__.py new file mode 100644 index 00000000000..42b82326afe --- /dev/null +++ b/src/torchmetrics/functional/multimodal/__init__.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE + +if _TRANSFORMERS_AVAILABLE: + from torchmetrics.functional.multimodal.clip_score import clip_score # noqa: F401 diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py new file mode 100644 index 00000000000..b97d7c1ddf3 --- /dev/null +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -0,0 +1,140 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Tuple, Union
+
+import torch
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
+
+if _TRANSFORMERS_AVAILABLE:
+    from transformers import CLIPModel as _CLIPModel
+    from transformers import CLIPProcessor as _CLIPProcessor
+else:
+    __doctest_skip__ = ["clip_score"]
+    _CLIPModel = None  # type:ignore
+    _CLIPProcessor = None  # type:ignore
+
+
+def _clip_score_update(
+    images: Union[Tensor, List[Tensor]],
+    text: Union[str, List[str]],
+    model: _CLIPModel,
+    processor: _CLIPProcessor,
+) -> Tuple[Tensor, int]:
+    if not isinstance(images, list):
+        if images.ndim == 3:
+            images = [images]
+        else:  # unwrap into list
+            images = [i for i in images]
+
+    if not all(i.ndim == 3 for i in images):
+        raise ValueError("Expected all images to be 3d but found image that has either more or less dimensions")
+
+    if not isinstance(text, list):
+        text = [text]
+
+    if len(text) != len(images):
+        raise ValueError(
+            f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
+        )
+    device = images[0].device
+    processed_input = processor(
+        text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True
+    )  # type:ignore
+
+    img_features = model.get_image_features(processed_input["pixel_values"].to(device))
+    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
+
+    txt_features = model.get_text_features(
+        processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
+    )
+    txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
+
+    # cosine similarity between feature vectors
+    score = 100 * (img_features * txt_features).sum(dim=-1)
+    return score, len(text)
+
+
+def _get_model_and_processor(
+    model_name_or_path: Literal[
+        "openai/clip-vit-base-patch16",
+        "openai/clip-vit-base-patch32",
+        "openai/clip-vit-large-patch14-336",
+        "openai/clip-vit-large-patch14",
+    ] = "openai/clip-vit-large-patch14",
+) -> Tuple[_CLIPModel, _CLIPProcessor]:
+    if _TRANSFORMERS_AVAILABLE:
+        model = _CLIPModel.from_pretrained(model_name_or_path)
+        processor = _CLIPProcessor.from_pretrained(model_name_or_path)
+        return model, processor
+    else:
+        raise ModuleNotFoundError(
+            "`clip_score` metric requires the `transformers` package to be installed."
+            " Either install with `pip install transformers>=4.10.0` or `pip install torchmetrics[multimodal]`."
+        )
+
+
+def clip_score(
+    images: Union[Tensor, List[Tensor]],
+    text: Union[str, List[str]],
+    model_name_or_path: Literal[
+        "openai/clip-vit-base-patch16",
+        "openai/clip-vit-base-patch32",
+        "openai/clip-vit-large-patch14-336",
+        "openai/clip-vit-large-patch14",
+    ] = "openai/clip-vit-large-patch14",
+) -> Tensor:
+    r"""`CLIP Score`_ is a reference-free metric that can be used to evaluate the correlation between a generated
+    caption for an image and the actual content of the image. It has been found to be highly correlated with human
+    judgement. The metric is defined as:
+
+    .. math::
+        \text{CLIPScore}(I, C) = \max(100 * \cos(E_I, E_C), 0)
+
+    which corresponds to the cosine similarity between visual CLIP embedding :math:`E_I` for an image :math:`I` and
+    textual CLIP embedding :math:`E_C` for a caption :math:`C`. The score is bound between 0 and 100 and the closer
+    to 100 the better.
+
+    .. note:: Metric is not scriptable
+
+    Args:
+        images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
+        text: Either a single caption or a list of captions
+        model_name_or_path: string indicating the version of the CLIP model to use. Available models are
+            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
+            and `"openai/clip-vit-large-patch14"`.
+
+    Raises:
+        ModuleNotFoundError:
+            If transformers package is not installed or version is lower than 4.10.0
+        ValueError:
+            If not all images have format [C, H, W]
+        ValueError:
+            If the number of images and captions do not match
+
+    Example:
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> from torchmetrics.functional.multimodal import clip_score
+        >>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
+        >>> print(score.detach())
+        tensor(24.4255)
+    """
+    model, processor = _get_model_and_processor(model_name_or_path)
+    device = images.device if isinstance(images, Tensor) else images[0].device
+    score, _ = _clip_score_update(images, text, model.to(device), processor)
+    score = score.mean(0)
+    return torch.max(score, torch.zeros_like(score))
diff --git a/src/torchmetrics/multimodal/__init__.py b/src/torchmetrics/multimodal/__init__.py
new file mode 100644
index 00000000000..e5b4ad56ce2
--- /dev/null
+++ b/src/torchmetrics/multimodal/__init__.py
@@ -0,0 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
+
+if _TRANSFORMERS_AVAILABLE:
+    from torchmetrics.multimodal.clip_score import CLIPScore  # noqa: F401
diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py
new file mode 100644
index 00000000000..234874b931b
--- /dev/null
+++ b/src/torchmetrics/multimodal/clip_score.py
@@ -0,0 +1,105 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List, Union
+
+import torch
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.functional.multimodal.clip_score import _clip_score_update, _get_model_and_processor
+from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
+
+if not _TRANSFORMERS_AVAILABLE:
+    __doctest_skip__ = ["CLIPScore"]
+
+from torchmetrics import Metric
+
+
+class CLIPScore(Metric):
+    r"""`CLIP Score`_ is a reference-free metric that can be used to evaluate the correlation between a generated
+    caption for an image and the actual content of the image. It has been found to be highly correlated with human
+    judgement. The metric is defined as:
+
+    .. math::
+        \text{CLIPScore}(I, C) = \max(100 * \cos(E_I, E_C), 0)
+
+    which corresponds to the cosine similarity between visual CLIP embedding :math:`E_I` for an image :math:`I` and
+    textual CLIP embedding :math:`E_C` for a caption :math:`C`. The score is bound between 0 and 100 and the closer
+    to 100 the better.
+
+    .. note:: Metric is not scriptable
+
+    Args:
+        model_name_or_path: string indicating the version of the CLIP model to use. Available models are
+            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
+            and `"openai/clip-vit-large-patch14"`.
+
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Raises:
+        ModuleNotFoundError:
+            If transformers package is not installed or version is lower than 4.10.0
+
+    Example:
+        >>> import torch
+        >>> _ = torch.manual_seed(42)
+        >>> from torchmetrics.multimodal import CLIPScore
+        >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
+        >>> score = metric(torch.randint(255, (3, 224, 224)), "a photo of a cat")
+        >>> print(score.detach())
+        tensor(25.0936)
+    """
+
+    is_differentiable: bool = False
+    higher_is_better: bool = True
+    full_state_update: bool = True
+    score: Tensor
+    n_samples: Tensor
+
+    def __init__(
+        self,
+        model_name_or_path: Literal[
+            "openai/clip-vit-base-patch16",
+            "openai/clip-vit-base-patch32",
+            "openai/clip-vit-large-patch14-336",
+            "openai/clip-vit-large-patch14",
+        ] = "openai/clip-vit-large-patch14",
+        **kwargs: Any,
+    ) -> None:
+
+        super().__init__(**kwargs)
+        self.model, self.processor = _get_model_and_processor(model_name_or_path)
+        self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
+
+    def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, List[str]]) -> None:
+        """Updates CLIP score on a batch of images and text.
+ + Args: + images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors + text: Either a single caption or a list of captions + + Raises: + ValueError: + If not all images have format [C, H, W] + ValueError: + If the number of images and captions do not match + """ + score, n_samples = _clip_score_update(images, text, self.model, self.processor) + self.score += score.sum(0) + self.n_samples += n_samples + + def compute(self) -> Tensor: + """Computes accumulated clip score.""" + return torch.max(self.score / self.n_samples, torch.zeros_like(self.score)) diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index 167a78f868a..ddf320c8260 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -26,7 +26,7 @@ from torchmetrics import Metric from torchmetrics.detection.mean_ap import MAPMetricResults -from torchmetrics.utilities.data import apply_to_collection +from torchmetrics.utilities.data import _flatten, apply_to_collection try: set_start_method("spawn") @@ -112,8 +112,8 @@ def _assert_requires_grad(metric: Metric, pl_result: Any, key: Optional[str] = N def _class_test( rank: int, worldsize: int, - preds: Union[Tensor, List[Dict[str, Tensor]]], - target: Union[Tensor, List[Dict[str, Tensor]]], + preds: Union[Tensor, list, List[Dict[str, Tensor]]], + target: Union[Tensor, list, List[Dict[str, Tensor]]], metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, @@ -189,15 +189,16 @@ def _class_test( if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0: if isinstance(preds, Tensor): ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]).cpu() + else: + ddp_preds = _flatten([preds[i + r] for r in range(worldsize)]) + if isinstance(target, Tensor): ddp_target = torch.cat([target[i + r] for r in range(worldsize)]).cpu() else: - ddp_preds = [preds[i + r] for r in range(worldsize)] - ddp_target = [target[i + r] for r in range(worldsize)] + ddp_target = _flatten([target[i + r] for r in range(worldsize)]) ddp_kwargs_upd = { k: torch.cat([v[i + r] for r in range(worldsize)]).cpu() if isinstance(v, Tensor) else v for k, v in (kwargs_update if fragment_kwargs else batch_kwargs_update).items() } - sk_batch_result = sk_metric(ddp_preds, ddp_target, **ddp_kwargs_upd) if isinstance(batch_result, dict): for key in batch_result: @@ -237,9 +238,11 @@ def _class_test( if isinstance(preds, Tensor): total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu() - total_target = torch.cat([target[i] for i in range(num_batches)]).cpu() else: total_preds = [item for sublist in preds for item in sublist] + if isinstance(target, Tensor): + total_target = torch.cat([target[i] for i in range(num_batches)]).cpu() + else: total_target = [item for sublist in target for item in sublist] total_kwargs_update = { @@ -257,8 +260,8 @@ def _class_test( def _functional_test( - preds: Tensor, - target: Tensor, + preds: Union[Tensor, list], + target: Union[Tensor, list], metric_functional: Callable, sk_metric: Callable, metric_args: dict = None, @@ -280,8 +283,10 @@ def _functional_test( kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. 
""" - assert preds.shape[0] == target.shape[0] - num_batches = preds.shape[0] + p_size = preds.shape[0] if isinstance(preds, Tensor) else len(preds) + t_size = target.shape[0] if isinstance(target, Tensor) else len(target) + assert p_size == t_size + num_batches = p_size if not metric_args: metric_args = {} @@ -289,8 +294,10 @@ def _functional_test( metric = partial(metric_functional, **metric_args) # move to device - preds = preds.to(device) - target = target.to(device) + if isinstance(preds, Tensor): + preds = preds.to(device) + if isinstance(target, Tensor): + target = target.to(device) kwargs_update = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} for i in range(num_batches): @@ -300,7 +307,11 @@ def _functional_test( k: v.cpu() if isinstance(v, Tensor) else v for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items() } - sk_result = sk_metric(preds[i].cpu(), target[i].cpu(), **extra_kwargs) + sk_result = sk_metric( + preds[i].cpu() if isinstance(preds, Tensor) else preds[i], + target[i].cpu() if isinstance(target, Tensor) else target[i], + **extra_kwargs, + ) # assert its the same _assert_allclose(tm_result, sk_result, atol=atol) diff --git a/tests/unittests/multimodal/__init__.py b/tests/unittests/multimodal/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py new file mode 100644 index 00000000000..5ee389ab0dc --- /dev/null +++ b/tests/unittests/multimodal/test_clip_score.py @@ -0,0 +1,113 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import namedtuple +from functools import partial + +import pytest +import torch +from transformers import CLIPModel as _CLIPModel +from transformers import CLIPProcessor as _CLIPProcessor + +from torchmetrics.functional.multimodal.clip_score import clip_score +from torchmetrics.multimodal.clip_score import CLIPScore +from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester +from unittests.text.helpers import skip_on_connection_issues + +seed_all(42) + + +Input = namedtuple("Input", ["images", "captions"]) + + +captions = [ + "28-year-old chef found dead in San Francisco mall", + "A 28-year-old chef who recently moved to San Francisco was " + "found dead in the staircase of a local shopping center.", + "The victim's brother said he cannot imagine anyone who would want to harm him,\"Finally, it went uphill again at " + 'him."', + "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto .", +] + +_random_input = Input(images=torch.randint(255, (2, 2, 3, 224, 224)), captions=[captions[0:2], captions[2:]]) + + +def _compare_fn(preds, target, model_name_or_path): + processor = _CLIPProcessor.from_pretrained(model_name_or_path) + model = _CLIPModel.from_pretrained(model_name_or_path) + inputs = processor(text=target, images=[p.cpu() for p in preds], return_tensors="pt", padding=True) + outputs = model(**inputs) + logits_per_image = outputs.logits_per_image + return logits_per_image.diag().mean().detach() + + +@pytest.mark.parametrize("model_name_or_path", ["openai/clip-vit-base-patch32"]) +@pytest.mark.parametrize("input", [_random_input]) +@pytest.mark.skipif(not _TRANSFORMERS_AVAILABLE, reason="test requires bert_score") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") +class TestCLIPScore(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @skip_on_connection_issues() + def test_clip_score(self, input, model_name_or_path, ddp, dist_sync_on_step): + # images are preds and targets are captions + preds, target = input + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=CLIPScore, + sk_metric=partial(_compare_fn, model_name_or_path=model_name_or_path), + dist_sync_on_step=dist_sync_on_step, + metric_args={"model_name_or_path": model_name_or_path}, + check_scriptable=False, + check_state_dict=False, + ) + + @skip_on_connection_issues() + def test_clip_score_functional(self, input, model_name_or_path): + preds, target = input + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=clip_score, + sk_metric=partial(_compare_fn, model_name_or_path=model_name_or_path), + metric_args={"model_name_or_path": model_name_or_path}, + ) + + @skip_on_connection_issues() + def test_clip_score_differentiability(self, input, model_name_or_path): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=CLIPScore, + metric_functional=clip_score, + metric_args={"model_name_or_path": model_name_or_path}, + ) + + @skip_on_connection_issues() + def test_error_on_not_same_amount_of_input(self, input, model_name_or_path): + """Test that an error is raised if the number of images and text examples does not match.""" + metric = CLIPScore(model_name_or_path=model_name_or_path) + with pytest.raises(ValueError): + metric(torch.randint(255, (2, 3, 224, 224)), 
"28-year-old chef found dead in San Francisco mall") + + @skip_on_connection_issues() + def test_error_on_wrong_image_format(self, input, model_name_or_path): + """Test that an error is raised if not all images are [c, h, w] format.""" + metric = CLIPScore(model_name_or_path=model_name_or_path) + with pytest.raises(ValueError): + metric(torch.randint(255, (224, 224)), "28-year-old chef found dead in San Francisco mall")