diff --git a/CHANGELOG.md b/CHANGELOG.md index 7315c4baf7f..5a5186f5b2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `TotalVariation` to image package ([#978](https://github.com/PyTorchLightning/metrics/pull/978)) + + - Added a new NLP metric `InfoLM` ([#915](https://github.com/PyTorchLightning/metrics/pull/915)) - Added `Perplexity` metric ([#922](https://github.com/PyTorchLightning/metrics/pull/922)) - Added `ConcordanceCorrCoef` metric to regression package ([#1201](https://github.com/Lightning-AI/metrics/pull/1201)) diff --git a/docs/source/image/total_variation.rst b/docs/source/image/total_variation.rst new file mode 100644 index 00000000000..0f0e7398d9b --- /dev/null +++ b/docs/source/image/total_variation.rst @@ -0,0 +1,22 @@ +.. customcarditem:: + :header: Total Variation (TV) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Image + +.. include:: ../links.rst + +#################### +Total Variation (TV) +#################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.TotalVariation + :noindex: + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.total_variation + :noindex: diff --git a/docs/source/links.rst b/docs/source/links.rst index fd61d1e6863..2cb9be7918f 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -86,6 +86,7 @@ .. _MER: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf .. _WIL: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf .. _WIP: https://infoscience.epfl.ch/record/82766 +.. _TV: https://en.wikipedia.org/wiki/Total_variation_denoising .. _InfoLM: https://arxiv.org/pdf/2112.01589.pdf .. _alpha divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf .. 
_beta divergence: https://www.sciencedirect.com/science/article/pii/S0047259X08000456 diff --git a/requirements/image_test.txt b/requirements/image_test.txt index b27f18a46c5..c0525cdf814 100644 --- a/requirements/image_test.txt +++ b/requirements/image_test.txt @@ -1,2 +1,3 @@ scikit-image>0.17.1 +kornia pytorch-msssim==0.2.1 diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 9143f4b4486..8aa046a9938 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -56,6 +56,7 @@ SpectralAngleMapper, SpectralDistortionIndex, StructuralSimilarityIndexMeasure, + TotalVariation, UniversalImageQualityIndex, ) from torchmetrics.metric import Metric # noqa: E402 @@ -191,6 +192,7 @@ "StatScores", "SumMetric", "SymmetricMeanAbsolutePercentageError", + "TotalVariation", "TranslationEditRate", "UniversalImageQualityIndex", "WeightedMeanAbsolutePercentageError", diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 164e20fd88e..44245a2d6cd 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -46,6 +46,7 @@ multiscale_structural_similarity_index_measure, structural_similarity_index_measure, ) +from torchmetrics.functional.image.tv import total_variation from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance @@ -166,6 +167,7 @@ "structural_similarity_index_measure", "stat_scores", "symmetric_mean_absolute_percentage_error", + "total_variation", "translation_edit_rate", "universal_image_quality_index", "spectral_angle_mapper", diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index a438cbe940b..c3c53501018 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ 
b/src/torchmetrics/functional/image/__init__.py @@ -20,4 +20,5 @@ multiscale_structural_similarity_index_measure, structural_similarity_index_measure, ) +from torchmetrics.functional.image.tv import total_variation # noqa: F401 from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401 diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py new file mode 100644 index 00000000000..d3432ebdabb --- /dev/null +++ b/src/torchmetrics/functional/image/tv.py @@ -0,0 +1,78 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Tuple, Union

from torch import Tensor
from typing_extensions import Literal


def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]:
    """Compute the per-sample total variation of one batch.

    Args:
        img: A 4D tensor; the public entry point documents it as ``(N, C, H, W)``.

    Returns:
        A tuple ``(score, num_elements)`` where ``score`` has one entry per sample
        (shape ``(N,)``) and ``num_elements`` is the batch size ``N``.

    Raises:
        RuntimeError:
            If ``img`` is not a 4D tensor.
    """
    if img.ndim != 4:
        raise RuntimeError(f"Expected input `img` to be an 4D tensor, but got {img.shape}")

    # Differences between neighbouring entries along the last two axes.
    diff_rows = img[..., 1:, :] - img[..., :-1, :]
    diff_cols = img[..., :, 1:] - img[..., :, :-1]

    # Reduce every axis except the batch axis so one score per sample remains.
    score = diff_rows.abs().sum([1, 2, 3]) + diff_cols.abs().sum([1, 2, 3])
    return score, img.shape[0]


def _total_variation_compute(
    score: Tensor, num_elements: Union[int, Tensor], reduction: Literal["mean", "sum", "none", None]
) -> Tensor:
    """Reduce accumulated total variation scores into the final result.

    Args:
        score: Accumulated score, either per sample or already summed.
        num_elements: Number of samples seen. Accepts a plain ``int`` (functional path) or a
            tensor counter (the module metric passes its tensor state here).
        reduction: ``'mean'``, ``'sum'``, ``'none'`` or ``None``.

    Returns:
        ``score.sum() / num_elements`` for ``'mean'``, ``score.sum()`` for ``'sum'`` and ``score``
        unchanged for ``'none'`` / ``None``.

    Raises:
        ValueError:
            If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None``
    """
    if reduction == "mean":
        return score.sum() / num_elements
    if reduction == "sum":
        return score.sum()
    if reduction is None or reduction == "none":
        return score
    raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None")


def total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor:
    """Computes total variation loss.

    Args:
        img: A `Tensor` of shape `(N, C, H, W)` consisting of images
        reduction: a method to reduce metric score over samples.

            - ``'mean'``: takes the mean over samples
            - ``'sum'``: takes the sum over samples
            - ``None`` or ``'none'``: return the score per sample

    Returns:
        A tensor containing the total variation: a scalar for ``'sum'`` and ``'mean'``,
        or a tensor of shape ``(N,)`` with one score per sample for ``None`` / ``'none'``

    Raises:
        ValueError:
            If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None``
        RuntimeError:
            If ``img`` is not 4D tensor

    Example:
        >>> import torch
        >>> from torchmetrics.functional import total_variation
        >>> _ = torch.manual_seed(42)
        >>> img = torch.rand(5, 3, 28, 28)
        >>> total_variation(img)
        tensor(7546.8018)
    """
    # Implementation adapted from `kornia.losses.total_variation`.
    score, num_elements = _total_variation_update(img)
    return _total_variation_compute(score, num_elements, reduction)
from typing import Any

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat


class TotalVariation(Metric):
    """Computes Total Variation loss (`TV`_).

    Args:
        reduction: a method to reduce metric score over samples

            - ``'mean'``: takes the mean over samples
            - ``'sum'``: takes the sum over samples
            - ``None`` or ``'none'``: return the score per sample

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None``

    Example:
        >>> import torch
        >>> from torchmetrics import TotalVariation
        >>> _ = torch.manual_seed(42)
        >>> tv = TotalVariation()
        >>> img = torch.rand(5, 3, 28, 28)
        >>> tv(img)
        tensor(7546.8018)
    """

    full_state_update: bool = False
    is_differentiable: bool = True
    higher_is_better: bool = False

    def __init__(self, reduction: Literal["mean", "sum", "none", None] = "sum", **kwargs: Any) -> None:
        super().__init__(**kwargs)
        if reduction not in (None, "sum", "mean", "none"):
            raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None")
        self.reduction = reduction

        if self.reduction in (None, "none"):
            # Per-sample output: keep every batch's score vector and concatenate at compute time.
            self.add_state("score", default=[], dist_reduce_fx="cat")
        else:
            # Reduced output: a single running sum is enough.
            self.add_state("score", default=tensor(0, dtype=torch.float), dist_reduce_fx="sum")
        # The sample counter is registered for every reduction mode; `compute` always reads it.
        self.add_state("num_elements", default=tensor(0, dtype=torch.int), dist_reduce_fx="sum")

    def update(self, img: Tensor) -> None:  # type: ignore
        """Accumulate the total variation of a new batch of images.

        Args:
            img: A `Tensor` of shape `(N, C, H, W)` consisting of images
        """
        batch_score, batch_size = _total_variation_update(img)
        self.num_elements += batch_size
        if self.reduction in (None, "none"):
            self.score.append(batch_score)
            return
        self.score += batch_score.sum()

    def compute(self) -> Tensor:
        """Return the total variation over everything seen so far, reduced as configured."""
        keep_per_sample = self.reduction in (None, "none")
        accumulated = dim_zero_cat(self.score) if keep_per_sample else self.score
        return _total_variation_compute(accumulated, self.num_elements, self.reduction)
from collections import namedtuple
from functools import partial

import pytest
import torch
from kornia.losses import total_variation as kornia_total_variation

from torchmetrics.functional.image.tv import total_variation
from torchmetrics.image.tv import TotalVariation
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester

seed_all(42)


# add extra argument to make the metric and reference fit into the testing framework
class TotalVariationTester(TotalVariation):
    """Module metric wrapper that accepts (and ignores) the extra ``target`` argument."""

    def update(self, img, *args):
        super().update(img=img)


def total_variation_tester(preds, target, reduction="mean"):
    """Functional metric wrapper that accepts (and ignores) the extra ``target`` argument."""
    return total_variation(preds, reduction)


def total_variation_kornia_tester(preds, target, reduction):
    """Reference implementation built on kornia, reduced the same way as the metric under test."""
    score = kornia_total_variation(preds).sum(-1)
    if reduction == "sum":
        return score.sum()
    elif reduction == "mean":
        return score.mean()
    return score


# define inputs
Input = namedtuple("Input", ["preds", "target"])

_inputs = []
for size, channel, dtype in [
    (12, 3, torch.float),
    (13, 3, torch.float32),
    (14, 3, torch.double),
    (15, 3, torch.float64),
]:
    preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
    target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
    _inputs.append(Input(preds=preds, target=target))


@pytest.mark.parametrize(
    "preds, target",
    [(i.preds, i.target) for i in _inputs],
)
@pytest.mark.parametrize("reduction", ["sum", "mean", None])
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="Kornia used as reference requires min PT version")
class TestTotalVariation(MetricTester):
    """Compare the total variation metric against the kornia reference."""

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_total_variation(self, preds, target, reduction, ddp, dist_sync_on_step):
        """Test modular implementation."""
        if reduction is None and ddp:
            pytest.skip("reduction=None and ddp=True runs out of memory on CI hardware, but it does work")
        self.run_class_metric_test(
            ddp,
            preds,
            target,
            TotalVariationTester,
            partial(total_variation_kornia_tester, reduction=reduction),
            dist_sync_on_step,
            metric_args={"reduction": reduction},
        )

    def test_total_variation_functional(self, preds, target, reduction):
        """Test for functional implementation."""
        self.run_functional_metric_test(
            preds,
            target,
            total_variation_tester,
            partial(total_variation_kornia_tester, reduction=reduction),
            metric_args={"reduction": reduction},
        )

    @pytest.mark.skipif(
        not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6"
    )
    def test_total_variation_half_cpu(self, preds, target, reduction):
        """Test for half precision on CPU."""
        self.run_precision_test_cpu(
            preds,
            target,
            TotalVariationTester,
            total_variation_tester,
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
    def test_total_variation_half_gpu(self, preds, target, reduction):
        """Test for half precision on GPU."""
        self.run_precision_test_gpu(preds, target, TotalVariationTester, total_variation_tester)


def test_correct_args():
    """Test that arguments have the right type and sizes."""
    with pytest.raises(ValueError, match="Expected argument `reduction`.*"):
        _ = TotalVariation(reduction="diff")

    with pytest.raises(RuntimeError, match="Expected input `img` to.*"):
        _ = total_variation(torch.randn(1, 2, 3))