Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* first steps * further updates * add some testing * changelog * add docstrings * Apply suggestions from code review * add functional and refactor * fix requirement + testing * more requirements Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
- Loading branch information
1 parent
e4014f4
commit 0f6f861
Showing
16 changed files
with
456 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
.. customcarditem:: | ||
:header: CLIP Score | ||
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg | ||
:tags: Image | ||
|
||
############################################ | ||
CLIP Score | ||
############################################ | ||
|
||
Module Interface | ||
________________ | ||
|
||
.. autoclass:: torchmetrics.multimodal.clip_score.CLIPScore | ||
:noindex: | ||
|
||
Functional Interface | ||
____________________ | ||
|
||
.. autofunction:: torchmetrics.functional.multimodal.clip_score.clip_score | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
transformers>=4.10.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
jiwer>=2.3.0 | ||
rouge-score>=0.0.4 | ||
bert_score==0.3.10 | ||
transformers>=4.4.0 | ||
transformers>4.4.0 | ||
huggingface-hub<0.7 # hotfix, failing SDR for latest PT 1.11 | ||
sacrebleu>=2.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE | ||
|
||
if _TRANSFORMERS_AVAILABLE: | ||
from torchmetrics.functional.multimodal.clip_score import clip_score # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import List, Tuple, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
from typing_extensions import Literal | ||
|
||
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE | ||
|
||
if _TRANSFORMERS_AVAILABLE: | ||
from transformers import CLIPModel as _CLIPModel | ||
from transformers import CLIPProcessor as _CLIPProcessor | ||
else: | ||
__doctest_skip__ = ["clip_score"] | ||
_CLIPModel = None # type:ignore | ||
_CLIPProcessor = None # type:ignore | ||
|
||
|
||
def _clip_score_update(
    images: Union[Tensor, List[Tensor]],
    text: Union[str, List[str]],
    model: _CLIPModel,
    processor: _CLIPProcessor,
) -> Tuple[Tensor, int]:
    """Compute the raw (unclipped) per-sample CLIP scores for a batch of images and captions.

    Args:
        images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
        text: Either a single caption or a list of captions
        model: CLIP model used to embed images and text
        processor: CLIP processor used to tokenize text and preprocess images

    Returns:
        Tuple of the per-sample score tensor and the number of samples processed.

    Raises:
        ValueError:
            If not all images have format [C, H, W]
        ValueError:
            If the number of images and captions do not match
    """
    # Normalize `images` to a list of [C, H, W] tensors.
    if not isinstance(images, list):
        if images.ndim == 3:
            images = [images]
        else:  # unwrap batched [N, C, H, W] tensor into list
            images = [i for i in images]

    if not all(i.ndim == 3 for i in images):
        raise ValueError("Expected all images to be 3d but found image that has either more or fewer dimensions")

    # Normalize `text` to a list of captions.
    if not isinstance(text, list):
        text = [text]

    if len(text) != len(images):
        raise ValueError(
            f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
        )
    device = images[0].device
    # The processor runs on CPU; tensors are moved back to the source device below.
    processed_input = processor(
        text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True
    )  # type:ignore

    img_features = model.get_image_features(processed_input["pixel_values"].to(device))
    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)

    txt_features = model.get_text_features(
        processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
    )
    txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)

    # Cosine similarity between L2-normalized feature vectors, scaled to the 0-100 range.
    score = 100 * (img_features * txt_features).sum(dim=-1)
    return score, len(text)
|
||
|
||
def _get_model_and_processor(
    model_name_or_path: Literal[
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14-336",
        "openai/clip-vit-large-patch14",
    ] = "openai/clip-vit-large-patch14",
) -> Tuple[_CLIPModel, _CLIPProcessor]:
    """Load the pretrained CLIP model and its matching processor for ``model_name_or_path``.

    Raises:
        ModuleNotFoundError:
            If the ``transformers`` package is not available.
    """
    # Guard clause: fail fast when transformers is missing instead of nesting the happy path.
    if not _TRANSFORMERS_AVAILABLE:
        raise ModuleNotFoundError(
            "`clip_score` metric requires `transformers` package be installed."
            " Either install with `pip install transformers>=4.0` or `pip install torchmetrics[multimodal]`."
        )
    model = _CLIPModel.from_pretrained(model_name_or_path)
    processor = _CLIPProcessor.from_pretrained(model_name_or_path)
    return model, processor
|
||
|
||
def clip_score(
    images: Union[Tensor, List[Tensor]],
    text: Union[str, List[str]],
    model_name_or_path: Literal[
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14-336",
        "openai/clip-vit-large-patch14",
    ] = "openai/clip-vit-large-patch14",
) -> Tensor:
    r"""`CLIP Score`_ is a reference free metric that can be used to evaluate the correlation between a generated
    caption for an image and the actual content of the image. It has been found to be highly correlated with human
    judgement. The metric is defined as:

    .. math::
        \text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)

    which corresponds to the cosine similarity between visual CLIP embedding :math:`E_i` for an image :math:`i` and
    textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
    to 100 the better.

    .. note:: Metric is not scriptable

    Args:
        images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
        text: Either a single caption or a list of captions
        model_name_or_path: string indicating the version of the CLIP model to use. Available models are
            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
            and `"openai/clip-vit-large-patch14"`,

    Raises:
        ModuleNotFoundError:
            If transformers package is not installed or version is lower than 4.10.0
        ValueError:
            If not all images have format [C, H, W]
        ValueError:
            If the number of images and captions do not match

    Example:
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> from torchmetrics.functional.multimodal import clip_score
        >>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
        >>> print(score.detach())
        tensor(24.4255)
    """
    model, processor = _get_model_and_processor(model_name_or_path)
    # Run the model on the same device the images already live on.
    device = images.device if isinstance(images, Tensor) else images[0].device
    score, _ = _clip_score_update(images, text, model.to(device), processor)
    score = score.mean(0)
    # Clamp at zero per the CLIPScore definition: max(100 * cos, 0).
    return torch.max(score, torch.zeros_like(score))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE | ||
|
||
if _TRANSFORMERS_AVAILABLE: | ||
from torchmetrics.multimodal.clip_score import CLIPScore # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, List, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
from typing_extensions import Literal | ||
|
||
from torchmetrics.functional.multimodal.clip_score import _clip_score_update, _get_model_and_processor | ||
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE | ||
|
||
if not _TRANSFORMERS_AVAILABLE: | ||
__doctest_skip__ = ["CLIPScore"] | ||
|
||
from torchmetrics import Metric | ||
|
||
|
||
class CLIPScore(Metric):
    r"""`CLIP Score`_ is a reference free metric that can be used to evaluate the correlation between a generated
    caption for an image and the actual content of the image. It has been found to be highly correlated with human
    judgement. The metric is defined as:

    .. math::
        \text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)

    which corresponds to the cosine similarity between visual CLIP embedding :math:`E_i` for an image :math:`i` and
    textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
    to 100 the better.

    .. note:: Metric is not scriptable

    Args:
        model_name_or_path: string indicating the version of the CLIP model to use. Available models are
            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
            and `"openai/clip-vit-large-patch14"`,
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ModuleNotFoundError:
            If transformers package is not installed or version is lower than 4.10.0

    Example:
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> from torchmetrics.multimodal import CLIPScore
        >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
        >>> score = metric(torch.randint(255, (3, 224, 224)), "a photo of a cat")
        >>> print(score.detach())
        tensor(25.0936)
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = True
    score: Tensor
    n_samples: Tensor

    def __init__(
        self,
        model_name_or_path: Literal[
            "openai/clip-vit-base-patch16",
            "openai/clip-vit-base-patch32",
            "openai/clip-vit-large-patch14-336",
            "openai/clip-vit-large-patch14",
        ] = "openai/clip-vit-large-patch14",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model, self.processor = _get_model_and_processor(model_name_or_path)
        # Running sum of per-sample scores and sample count; both reduced by sum across processes.
        self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")

    def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, List[str]]) -> None:
        """Updates CLIP score on a batch of images and text.

        Args:
            images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
            text: Either a single caption or a list of captions

        Raises:
            ValueError:
                If not all images have format [C, H, W]
            ValueError:
                If the number of images and captions do not match
        """
        score, n_samples = _clip_score_update(images, text, self.model, self.processor)
        self.score += score.sum(0)
        self.n_samples += n_samples

    def compute(self) -> Tensor:
        """Computes accumulated clip score, clamped at zero per the CLIPScore definition."""
        return torch.max(self.score / self.n_samples, torch.zeros_like(self.score))
Oops, something went wrong.