Skip to content

Commit

Permalink
Total Variation Loss (#978)
Browse files Browse the repository at this point in the history
* Create total_variation.py

* Initialize the PR.

* Missed importing Metric.

* added functional and modular metric.

* Adding copyright notice to top of file.

* Apply first suggestions from code review

* docs

* changelog

* init file

* tv

* add tests

* init and requirements

* working implementation

* skip some testing

* typing

* vars

* fix type

* remove renaming

* fix tests

* add none option

* fix doctest

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ragavv <ragavv@nvidia.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
6 people committed Oct 5, 2022
1 parent 4aeb6cb commit eecfa7d
Show file tree
Hide file tree
Showing 11 changed files with 319 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `TotalVariation` to image package ([#978](https://github.com/PyTorchLightning/metrics/pull/978))


- Added a new NLP metric `InfoLM` ([#915](https://github.com/PyTorchLightning/metrics/pull/915))
- Added `Perplexity` metric ([#922](https://github.com/PyTorchLightning/metrics/pull/922))
- Added `ConcordanceCorrCoef` metric to regression package ([#1201](https://github.com/Lightning-AI/metrics/pull/1201))
Expand Down
22 changes: 22 additions & 0 deletions docs/source/image/total_variation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Total Variation (TV)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

####################
Total Variation (TV)
####################

Module Interface
________________

.. autoclass:: torchmetrics.TotalVariation
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.total_variation
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
.. _MER: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
.. _WIL: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
.. _WIP: https://infoscience.epfl.ch/record/82766
.. _TV: https://en.wikipedia.org/wiki/Total_variation_denoising
.. _InfoLM: https://arxiv.org/pdf/2112.01589.pdf
.. _alpha divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf
.. _beta divergence: https://www.sciencedirect.com/science/article/pii/S0047259X08000456
Expand Down
1 change: 1 addition & 0 deletions requirements/image_test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
scikit-image>0.17.1
kornia
pytorch-msssim==0.2.1
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
SpectralAngleMapper,
SpectralDistortionIndex,
StructuralSimilarityIndexMeasure,
TotalVariation,
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
Expand Down Expand Up @@ -191,6 +192,7 @@
"StatScores",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TotalVariation",
"TranslationEditRate",
"UniversalImageQualityIndex",
"WeightedMeanAbsolutePercentageError",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
Expand Down Expand Up @@ -166,6 +167,7 @@
"structural_similarity_index_measure",
"stat_scores",
"symmetric_mean_absolute_percentage_error",
"total_variation",
"translation_edit_rate",
"universal_image_quality_index",
"spectral_angle_mapper",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.tv import total_variation # noqa: F401
from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401
78 changes: 78 additions & 0 deletions src/torchmetrics/functional/image/tv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

from torch import Tensor
from typing_extensions import Literal


def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]:
"""Computes total variation statistics on current batch."""
if img.ndim != 4:
raise RuntimeError(f"Expected input `img` to be an 4D tensor, but got {img.shape}")
diff1 = img[..., 1:, :] - img[..., :-1, :]
diff2 = img[..., :, 1:] - img[..., :, :-1]

res1 = diff1.abs().sum([1, 2, 3])
res2 = diff2.abs().sum([1, 2, 3])
score = res1 + res2
return score, img.shape[0]


def _total_variation_compute(
score: Tensor, num_elements: int, reduction: Literal["mean", "sum", "none", None]
) -> Tensor:
"""Compute final total variation score."""
if reduction == "mean":
return score.sum() / num_elements
elif reduction == "sum":
return score.sum()
elif reduction is None or reduction == "none":
return score
else:
raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None")


def total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor:
"""Computes total variation loss.
Args:
img: A `Tensor` of shape `(N, C, H, W)` consisting of images
reduction: a method to reduce metric score over samples.
- ``'mean'``: takes the mean over samples
- ``'sum'``: takes the sum over samples
- ``None`` or ``'none'``: return the score per sample
Returns:
A loss scalar value containing the total variation
Raises:
ValueError:
If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None``
RuntimeError:
If ``img`` is not 4D tensor
Example:
>>> import torch
>>> from torchmetrics.functional import total_variation
>>> _ = torch.manual_seed(42)
>>> img = torch.rand(5, 3, 28, 28)
>>> total_variation(img)
tensor(7546.8018)
"""
# code adapted from:
# from kornia.losses import total_variation as kornia_total_variation
score, num_elements = _total_variation_update(img)
return _total_variation_compute(score, num_elements, reduction)
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@

if _LPIPS_AVAILABLE:
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity # noqa: F401

from torchmetrics.image.tv import TotalVariation # noqa: F401
86 changes: 86 additions & 0 deletions src/torchmetrics/image/tv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat


class TotalVariation(Metric):
"""Computes Total Variation loss (`TV`_).
Args:
reduction: a method to reduce metric score over samples
- ``'mean'``: takes the mean over samples
- ``'sum'``: takes the sum over samples
- ``None`` or ``'none'``: return the score per sample
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Raises:
ValueError:
If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None``
Example:
>>> import torch
>>> from torchmetrics import TotalVariation
>>> _ = torch.manual_seed(42)
>>> tv = TotalVariation()
>>> img = torch.rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)
"""

full_state_update: bool = False
is_differentiable: bool = True
higher_is_better: bool = False

def __init__(self, reduction: Literal["mean", "sum", "none", None] = "sum", **kwargs: Any) -> None:
super().__init__(**kwargs)
if reduction is not None and reduction not in ("sum", "mean", "none"):
raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None")
self.reduction = reduction

if self.reduction is None or self.reduction == "none":
self.add_state("score", default=[], dist_reduce_fx="cat")
else:
self.add_state("score", default=tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("num_elements", default=tensor(0, dtype=torch.int), dist_reduce_fx="sum")

def update(self, img: Tensor) -> None: # type: ignore
"""Update current score with batch of input images.
Args:
img: A `Tensor` of shape `(N, C, H, W)` consisting of images
"""
score, num_elements = _total_variation_update(img)
if self.reduction is None or self.reduction == "none":
self.score.append(score)
else:
self.score += score.sum()
self.num_elements += num_elements

def compute(self) -> Tensor:
"""Compute final total variation."""
if self.reduction is None or self.reduction == "none":
score = dim_zero_cat(self.score)
else:
score = self.score
return _total_variation_compute(score, self.num_elements, self.reduction)
121 changes: 121 additions & 0 deletions tests/unittests/image/test_tv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from kornia.losses import total_variation as kornia_total_variation

from torchmetrics.functional.image.tv import total_variation
from torchmetrics.image.tv import TotalVariation
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester

seed_all(42)


# add extra argument to make the metric and reference fit into the testing framework
class TotalVariationTester(TotalVariation):
def update(self, img, *args):
super().update(img=img)


def total_variaion_tester(preds, target, reduction="mean"):
return total_variation(preds, reduction)


def total_variation_kornia_tester(preds, target, reduction):
score = kornia_total_variation(preds).sum(-1)
if reduction == "sum":
return score.sum()
elif reduction == "mean":
return score.mean()
return score


# define inputs
Input = namedtuple("Input", ["preds", "target"])

_inputs = []
for size, channel, dtype in [
(12, 3, torch.float),
(13, 3, torch.float32),
(14, 3, torch.double),
(15, 3, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(Input(preds=preds, target=target))


@pytest.mark.parametrize(
"preds, target",
[(i.preds, i.target) for i in _inputs],
)
@pytest.mark.parametrize("reduction", ["sum", "mean", None])
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="Kornia used as reference requires min PT version")
class TestTotalVariation(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_total_variation(self, preds, target, reduction, ddp, dist_sync_on_step):
"""Test modular implementation."""
if reduction is None and ddp:
pytest.skip("reduction=None and ddp=True runs out of memory on CI hardware, but it does work")
self.run_class_metric_test(
ddp,
preds,
target,
TotalVariationTester,
partial(total_variation_kornia_tester, reduction=reduction),
dist_sync_on_step,
metric_args={"reduction": reduction},
)

def test_total_variation_functional(self, preds, target, reduction):
"""Test for functional implementation."""
self.run_functional_metric_test(
preds,
target,
total_variaion_tester,
partial(total_variation_kornia_tester, reduction=reduction),
metric_args={"reduction": reduction},
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6"
)
def test_sam_half_cpu(self, preds, target, reduction):
"""Test for half precision on CPU."""
self.run_precision_test_cpu(
preds,
target,
TotalVariationTester,
total_variaion_tester,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_sam_half_gpu(self, preds, target, reduction):
"""Test for half precision on GPU."""
self.run_precision_test_gpu(preds, target, TotalVariationTester, total_variaion_tester)


def test_correct_args():
"""that that arguments have the right type and sizes."""
with pytest.raises(ValueError, match="Expected argument `reduction`.*"):
_ = TotalVariation(reduction="diff")

with pytest.raises(RuntimeError, match="Expected input `img` to.*"):
_ = total_variation(torch.randn(1, 2, 3))

0 comments on commit eecfa7d

Please sign in to comment.