Add CLIP score #1314

Merged: 44 commits, Nov 18, 2022

Changes from 18 commits

Commits (44)
f9c3fbe
first steps
SkafteNicki Nov 5, 2022
aaa3265
further updates
SkafteNicki Nov 5, 2022
7295fef
add some testing
SkafteNicki Nov 5, 2022
f199d7e
changelog
SkafteNicki Nov 5, 2022
7814d2b
docstring
SkafteNicki Nov 5, 2022
7ca7854
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2022
7df3cbc
add to index
SkafteNicki Nov 8, 2022
b9db500
add docstrings
SkafteNicki Nov 8, 2022
ff3e62a
update
SkafteNicki Nov 9, 2022
7e1c8d1
fix tests
SkafteNicki Nov 10, 2022
2839bae
Merge branch 'metric/clip' of https://github.com/PyTorchLightning/met…
SkafteNicki Nov 10, 2022
4a7dc1d
Merge branch 'master' into metric/clip
SkafteNicki Nov 10, 2022
b1c8b27
add requirement
SkafteNicki Nov 11, 2022
c354fe0
try fixing mypy and docs
SkafteNicki Nov 11, 2022
711e343
fix
SkafteNicki Nov 11, 2022
95fbff7
skip on no transformer
SkafteNicki Nov 11, 2022
2111d09
fix typing
SkafteNicki Nov 11, 2022
9a3f256
Merge branch 'master' into metric/clip
Borda Nov 11, 2022
3702610
Apply suggestions from code review
SkafteNicki Nov 13, 2022
cd1f50e
Merge branch 'master' into metric/clip
SkafteNicki Nov 14, 2022
1df18be
add functional and refactor
SkafteNicki Nov 14, 2022
3205b91
change variable name
SkafteNicki Nov 14, 2022
6490a03
Merge branch 'master' into metric/clip
SkafteNicki Nov 14, 2022
61cedeb
fix testing
SkafteNicki Nov 14, 2022
c0cec12
try fixing typing
SkafteNicki Nov 14, 2022
61cb3cc
Merge branch 'master' into metric/clip
SkafteNicki Nov 16, 2022
0cbc15d
8g
Borda Nov 16, 2022
0a734db
fix requirement + testing
SkafteNicki Nov 16, 2022
02c5234
Merge branch 'metric/clip' of https://github.com/PyTorchLightning/met…
SkafteNicki Nov 16, 2022
8a876c6
more requirements
SkafteNicki Nov 16, 2022
ae79845
fix
SkafteNicki Nov 16, 2022
89813ce
Merge branch 'master' into metric/clip
SkafteNicki Nov 16, 2022
f091e62
fix doctests
SkafteNicki Nov 16, 2022
fcf268d
Merge branch 'metric/clip' of https://github.com/PyTorchLightning/met…
SkafteNicki Nov 16, 2022
53ec80d
fix
SkafteNicki Nov 16, 2022
e3a9117
remove back
SkafteNicki Nov 17, 2022
2851923
Merge branch 'master' into metric/clip
Borda Nov 17, 2022
ea4b11a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
f1595a2
move section in index
SkafteNicki Nov 17, 2022
ae925b9
set min version of transformers
SkafteNicki Nov 17, 2022
8feb2cb
fix flake
SkafteNicki Nov 17, 2022
95bd30b
simple
Borda Nov 17, 2022
0debc25
Apply suggestions from code review
Borda Nov 17, 2022
56dc6f7
avail
Borda Nov 17, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `LogCoshError` to regression package ([#1316](https://github.com/Lightning-AI/metrics/pull/1316))


- Added `CLIPScore` to new multimodal package ([#1314](https://github.com/Lightning-AI/metrics/pull/1314))

### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
8 changes: 8 additions & 0 deletions docs/source/index.rst
@@ -158,6 +158,14 @@ Or directly from conda

image/*

.. toctree::
   :maxdepth: 2
   :name: multimodal
   :caption: Multimodal
   :glob:

   multimodal/*

.. toctree::
   :maxdepth: 2
   :name: detection
2 changes: 2 additions & 0 deletions docs/source/links.rst
@@ -96,3 +96,5 @@
.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient
.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303
.. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf
.. _CLIP score: https://arxiv.org/pdf/2104.08718.pdf
.. _Huggingface OpenAI: https://huggingface.co/openai
14 changes: 14 additions & 0 deletions docs/source/multimodal/clip_score.rst
@@ -0,0 +1,14 @@
.. customcarditem::
   :header: CLIP Score
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
   :tags: Image

############################################
CLIP Score
############################################

Module Interface
________________

.. autoclass:: torchmetrics.multimodal.clip_score.CLIPScore
    :noindex:
1 change: 1 addition & 0 deletions requirements/devel.txt
@@ -9,6 +9,7 @@
-r text.txt
# -r detection.txt # version collision with min version of PyTorch
-r audio.txt
-r multimodal.txt

# add extra testing
-r image_test.txt
1 change: 1 addition & 0 deletions requirements/multimodal.txt
@@ -0,0 +1 @@
transformers>=4.4.0
17 changes: 17 additions & 0 deletions src/torchmetrics/multimodal/__init__.py
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
    from torchmetrics.multimodal.clip_score import CLIPScore  # noqa: F401
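
Since the re-export above is conditional, `from torchmetrics.multimodal import CLIPScore` raises an ImportError outright when `transformers` is missing, while importing the class from `torchmetrics.multimodal.clip_score` directly defers the failure to instantiation (the ModuleNotFoundError raised in `CLIPScore.__init__` below). A minimal sketch of how a downstream caller might guard for this (hypothetical consumer code, not part of this PR):

# hypothetical consumer code, not part of this PR
try:
    from torchmetrics.multimodal import CLIPScore
except ImportError:
    CLIPScore = None  # transformers is not installed, so the metric is unavailable

if CLIPScore is not None:
    metric = CLIPScore()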
140 changes: 140 additions & 0 deletions src/torchmetrics/multimodal/clip_score.py
@@ -0,0 +1,140 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
    from transformers import CLIPModel as _CLIPModel
    from transformers import CLIPProcessor as _CLIPProcessor
else:
    __doctest_skip__ = ["CLIPScore"]

from torchmetrics import Metric


class CLIPScore(Metric):
    r"""`CLIP Score`_ is a reference-free metric that can be used to evaluate the correlation between a generated
    caption for an image and the actual content of the image. It has been found to be highly correlated with human
    judgement. The metric is defined as:

    .. math::
        \text{CLIPScore}(I, C) = \max(100 * \cos(E_I, E_C), 0)

    which corresponds to the cosine similarity between the visual CLIP embedding :math:`E_I` for an image :math:`I`
    and the textual CLIP embedding :math:`E_C` for a caption :math:`C`. The score is bound between 0 and 100, and
    the closer to 100 the better.

    .. note:: Metric is not scriptable

    Args:
        version: string indicating the version of the CLIP model to use. Available models are
            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
            and `"openai/clip-vit-large-patch14"`.

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ModuleNotFoundError:
            If the `transformers` package is not installed

    Example:
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> from torchmetrics.multimodal import CLIPScore
        >>> metric = CLIPScore()
        >>> metric(torch.randint(255, (3, 224, 224)), "a photo of a cat")
        tensor(19.4135, grad_fn=<SqueezeBackward0>)
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False
    score: Tensor
    n_samples: Tensor

    def __init__(
        self,
        version: Literal[
            "openai/clip-vit-base-patch16",
            "openai/clip-vit-base-patch32",
            "openai/clip-vit-large-patch14-336",
            "openai/clip-vit-large-patch14",
        ] = "openai/clip-vit-large-patch14",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if _TRANSFORMERS_AVAILABLE:
            self.model = _CLIPModel.from_pretrained(version)
            self.processor = _CLIPProcessor.from_pretrained(version)
        else:
            raise ModuleNotFoundError(
                "`CLIPScore` metric requires the `transformers` package to be installed."
                " Either install with `pip install transformers>=4.4.0` or `pip install torchmetrics[multimodal]`."
            )
        self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")

    def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, List[str]]) -> None:
        """Updates CLIP score on a batch of images and text.

        Args:
            images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
            text: Either a single caption or a list of captions

        Raises:
            ValueError:
                If not all images have format [C, H, W]
            ValueError:
                If the number of images and captions does not match
        """
        if not isinstance(images, List):
            if images.ndim == 3:
                images = [images]
            else:  # unwrap into list
                images = list(images)

        if not all(i.ndim == 3 for i in images):
            raise ValueError(
                "Expected all images to be 3d [C, H, W] tensors, but found an image with a different number"
                " of dimensions"
            )

        if not isinstance(text, List):
            text = [text]

        if len(text) != len(images):
            raise ValueError(
                f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
            )

        processed_input = self.processor(
            text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True
        )

        img_features = self.model.get_image_features(processed_input["pixel_values"].to(self.device))
        img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)

        txt_features = self.model.get_text_features(
            processed_input["input_ids"].to(self.device), processed_input["attention_mask"].to(self.device)
        )
        txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity between normalized feature vectors
        score = (img_features * txt_features).sum(dim=-1)
        self.score += 100 * score.sum(0)
        self.n_samples += img_features.shape[0]

    def compute(self) -> Tensor:
        """Computes accumulated clip score."""
        return torch.max(self.score / self.n_samples, torch.zeros_like(self.score))
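
For orientation, a minimal sketch of how the accumulating update/compute flow above behaves across batches (assumes the `transformers` package is installed; the shapes and captions here are made up for illustration):

import torch
from torchmetrics.multimodal import CLIPScore

metric = CLIPScore(version="openai/clip-vit-base-patch32")

# a batch of two [C, H, W] images with two matching captions
metric.update(torch.randint(255, (2, 3, 224, 224)), ["a photo of a cat", "a photo of a dog"])
# a single [C, H, W] image with a single caption also works
metric.update(torch.randint(255, (3, 224, 224)), "a photo of a bird")

# averages the accumulated 100 * cosine similarities over all 3 samples
# and clamps the result at 0 from below
print(metric.compute())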
4 changes: 3 additions & 1 deletion tests/unittests/helpers/testers.py
@@ -237,9 +237,11 @@ def _class_test(

     if isinstance(preds, Tensor):
         total_preds = torch.cat([preds[i] for i in range(num_batches)]).cpu()
-        total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
     else:
         total_preds = [item for sublist in preds for item in sublist]
+    if isinstance(target, Tensor):
+        total_target = torch.cat([target[i] for i in range(num_batches)]).cpu()
+    else:
         total_target = [item for sublist in target for item in sublist]

     total_kwargs_update = {
Empty file.
88 changes: 88 additions & 0 deletions tests/unittests/multimodal/test_clip_score.py
@@ -0,0 +1,88 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

from torchmetrics.multimodal.clip_score import CLIPScore
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)


Input = namedtuple("Input", ["images", "captions"])


captions = [
    "28-year-old chef found dead in San Francisco mall",
    "A 28-year-old chef who recently moved to San Francisco was "
    "found dead in the staircase of a local shopping center.",
    "The victim's brother said he cannot imagine anyone who would want to harm him,\"Finally, it went uphill again at "
    'him."',
    "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto .",
]

_random_input = Input(images=torch.randint(255, (2, 2, 3, 224, 224)), captions=[captions[0:2], captions[2:]])


def _compare_fn(preds, target, version):
    processor = _CLIPProcessor.from_pretrained(version)
    model = _CLIPModel.from_pretrained(version)
    inputs = processor(text=target, images=[p.cpu() for p in preds], return_tensors="pt", padding=True)
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    return logits_per_image.diag().mean().detach()


@pytest.mark.parametrize("version", ["openai/clip-vit-base-patch32"])
@pytest.mark.parametrize(
"input",
[
_random_input,
],
)
class TestCLIPScore(MetricTester):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
atol = 1e-5

@pytest.mark.parametrize("ddp", [True, False])
def test_clip_score(self, input, version, ddp):
# images are preds and targets are captions
preds, target = input
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=CLIPScore,
sk_metric=partial(_compare_fn, version=version),
metric_args={"version": version},
check_scriptable=False,
check_state_dict=False,
)

SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
def test_error_on_not_same_amount_of_input(self, input, version):
"""Test that an error is raised if the number of images and text examples does not match."""
metric = CLIPScore(version=version)
with pytest.raises(ValueError):
metric(torch.randint(255, (2, 3, 224, 224)), "28-year-old chef found dead in San Francisco mall")

def test_error_on_wrong_image_format(self, input, version):
"""Test that an error is raised if not all images are [c, h, w] format."""
metric = CLIPScore(version=version)
with pytest.raises(ValueError):
metric(torch.randint(255, (224, 224)), "28-year-old chef found dead in San Francisco mall")
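
For context on `_compare_fn` above: `CLIPModel` returns `logits_per_image`, the logit-scaled image-text cosine-similarity matrix, so its diagonal holds the scores of the matched image-caption pairs. A minimal sketch of the equivalence the test relies on (assumes `transformers` is installed; the released OpenAI checkpoints carry a logit scale of effectively 100, matching the fixed factor of 100 in `CLIPScore`):

import torch
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

images = [torch.randint(255, (3, 224, 224)) for _ in range(2)]
texts = ["a photo of a cat", "a photo of a dog"]

inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
outputs = model(**inputs)

# logits_per_image[i, j] is roughly logit_scale * cos(E_I[i], E_C[j]);
# the diagonal pairs image i with caption i, which is what CLIPScore averages
per_pair = outputs.logits_per_image.diag()
print(per_pair.mean())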