Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Total Variation Loss #978

Merged
merged 55 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
9858b0f
Create total_variation.py
ragavvenkatesan Apr 22, 2022
2aed1d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2022
0c4dd17
Missed importing Metric.
Apr 22, 2022
5b7b4db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2022
e7df5d5
Merge branch 'master' into patch-2
ragavvenkatesan May 5, 2022
7c8bbca
Merge branch 'master' into patch-2
SkafteNicki May 5, 2022
172230d
Merge branch 'master' into patch-2
SkafteNicki May 7, 2022
e7537ea
Merge branch 'master' into patch-2
ragavvenkatesan May 9, 2022
41de209
added functional and modular metric.
May 10, 2022
545d176
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
94cc783
Merge branch 'master' into patch-2
SkafteNicki May 10, 2022
a5e64b0
Adding copyright notice to top of file.
May 10, 2022
7fc1a26
Merge branch 'patch-2' of https://github.com/ragavvenkatesan/metrics …
May 10, 2022
971be9c
Apply first suggestions from code review
ragavvenkatesan May 10, 2022
4d50c91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
386efa2
Merge branch 'master' into patch-2
SkafteNicki May 16, 2022
995c72a
docs
SkafteNicki May 16, 2022
082d312
changelog
SkafteNicki May 16, 2022
2b0b455
init file
SkafteNicki May 16, 2022
7222a06
Merge branch 'master' into patch-2
SkafteNicki May 17, 2022
b73c4bc
Merge branch 'master' into patch-2
Borda Jun 27, 2022
cdfdc69
tv
Borda Jun 27, 2022
5a1c71d
Merge branch 'master' into patch-2
SkafteNicki Jun 28, 2022
83147c2
Merge branch 'master' into patch-2
Borda Jul 11, 2022
79a9751
Merge branch 'master' into patch-2
SkafteNicki Jul 12, 2022
7cdfb09
add tests
SkafteNicki Jul 12, 2022
0173303
init and requirements
SkafteNicki Jul 12, 2022
cde3a0c
working implementation
SkafteNicki Jul 12, 2022
852ea9a
skip some testing
SkafteNicki Jul 12, 2022
ed3e57d
Merge branch 'master' into patch-2
Borda Jul 12, 2022
28636a8
typing
Borda Jul 12, 2022
a9dd4e7
Merge branch 'master' into patch-2
SkafteNicki Jul 13, 2022
2e15d26
Merge branch 'master' into patch-2
Borda Jul 20, 2022
1fa0902
vars
Borda Jul 20, 2022
aea9a8c
fix type
SkafteNicki Jul 20, 2022
e2ab475
Merge branch 'master' into patch-2
Borda Jul 26, 2022
2701f51
Merge branch 'master' into patch-2
SkafteNicki Aug 30, 2022
7b68233
Merge branch 'master' into patch-2
SkafteNicki Sep 29, 2022
771c9d5
Merge branch 'master' into patch-2
SkafteNicki Sep 29, 2022
9c19004
remove renaming
SkafteNicki Sep 29, 2022
d0ae387
fix tests
SkafteNicki Sep 29, 2022
1ce004e
add none option
SkafteNicki Sep 29, 2022
e5d64e0
fix doctest
SkafteNicki Sep 29, 2022
d1a1e35
fix doctests
SkafteNicki Sep 29, 2022
d73f8ac
try again
SkafteNicki Sep 30, 2022
0486a00
Merge branch 'patch-2' of https://github.com/ragavvenkatesan/metrics …
SkafteNicki Sep 30, 2022
113cb10
another fix
SkafteNicki Sep 30, 2022
096c652
fix docbuild
SkafteNicki Sep 30, 2022
e1a6141
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
89a82fe
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
8270a39
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
5b65ce2
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
993ec33
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
6b1e01e
try fixing tests
SkafteNicki Oct 5, 2022
49c32b3
fix flake
SkafteNicki Oct 5, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `TotalVariation` to image package ([#978](https://github.com/PyTorchLightning/metrics/pull/978))


- Added a new NLP metric `InfoLM` ([#915](https://github.com/PyTorchLightning/metrics/pull/915))

Expand Down
22 changes: 22 additions & 0 deletions docs/source/image/total_variation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Total Variation (TV)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

####################
Total Variation (TV)
####################

Module Interface
________________

.. autoclass:: torchmetrics.TotalVariation
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.total_variation
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
.. _MER: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
.. _WIL: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf
.. _WIP: https://infoscience.epfl.ch/record/82766
.. _TV: https://en.wikipedia.org/wiki/Total_variation_denoising
.. _InfoLM: https://arxiv.org/pdf/2112.01589.pdf
.. _alpha divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf
.. _beta divergence: https://www.sciencedirect.com/science/article/pii/S0047259X08000456
Expand Down
1 change: 1 addition & 0 deletions requirements/image_test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
scikit-image>0.17.1
pytorch_msssim
kornia
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
SpectralAngleMapper,
SpectralDistortionIndex,
StructuralSimilarityIndexMeasure,
TotalVariation,
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
Expand Down Expand Up @@ -189,6 +190,7 @@
"StatScores",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TotalVariation",
"TranslationEditRate",
"UniversalImageQualityIndex",
"WeightedMeanAbsolutePercentageError",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
Expand Down Expand Up @@ -164,6 +165,7 @@
"structural_similarity_index_measure",
"stat_scores",
"symmetric_mean_absolute_percentage_error",
"total_variation",
"translation_edit_rate",
"universal_image_quality_index",
"spectral_angle_mapper",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.tv import total_variation # noqa: F401
from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401
68 changes: 68 additions & 0 deletions src/torchmetrics/functional/image/tv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

from torch import Tensor


def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]:
"""Computes total variation statistics on current batch."""
if img.ndim != 4:
raise RuntimeError(f"Expected input `img` to be an 4D tensor, but got {img.shape}")
diff1 = img[..., 1:, :] - img[..., :-1, :]
diff2 = img[..., :, 1:] - img[..., :, :-1]

res1 = diff1.abs().sum([1, 2, 3])
res2 = diff2.abs().sum([1, 2, 3])
score = (res1 + res2).sum()
return score, img.shape[0]


def _total_variation_compute(score: Tensor, num_elements: int, reduction: str) -> Tensor:
"""Compute final total variation score."""
return score if reduction == "sum" else score / num_elements


def total_variation(img: Tensor, reduction: str = "sum") -> Tensor:
"""Computes total variation loss.

Adapted from: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/total_variation.html

Args:
img: A `Tensor` of shape `(N, C, H, W)` consisting of images
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
reduction: a method to reduce metric score over samples.
- ``'mean'``: takes the mean (default)
- ``'sum'``: takes the sum

Returns:
A loss scalar value containing the total variation

Raises:
ValueError:
If ``reduction`` is not one of ``'sum'`` or ``'mean'``
RuntimeError:
If ``img`` is not 4D tensor

Example:
>>> import torch
>>> from torchmetrics.functional import total_variation
>>> _ = torch.manual_seed(42)
>>> img = torch.rand(5, 3, 28, 28)
>>> total_variation(img)
tensor(7546.8018)
"""
if reduction not in ("sum", "mean"):
raise ValueError("Expected argument `reduction` to either be 'sum' or 'mean'")
score, num_elements = _total_variation_update(img)
return _total_variation_compute(score, num_elements, reduction)
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@

if _LPIPS_AVAILABLE:
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity # noqa: F401

from torchmetrics.image.tv import TotalVariation # noqa: F401
75 changes: 75 additions & 0 deletions src/torchmetrics/image/tv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
from torch import Tensor

from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update
from torchmetrics.metric import Metric


class TotalVariation(Metric):
"""Computes Total Variation loss (`TV`_).

Adapted from: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/total_variation.html

Args:
reduction: a method to reduce metric score over samples.
- ``'mean'``: takes the mean (default)
- ``'sum'``: takes the sum
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``reduction`` is not one of ``'sum'`` or ``'mean'``

Example:
>>> import torch
>>> from torchmetrics import TotalVariation
>>> _ = torch.manual_seed(42)
>>> tv = TotalVariation()
>>> img = torch.rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)
"""

full_state_update: bool = False
is_differentiable: bool = True
higher_is_better: bool = False
current: Tensor
total: Tensor

def __init__(self, reduction: str = "sum", **kwargs: Any) -> None:
super().__init__(**kwargs)
if reduction not in ("sum", "mean"):
raise ValueError("Expected argument `reduction` to either be 'sum' or 'mean'")
self.reduction = reduction

self.add_state("score", default=Tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("num_elements", default=Tensor(0, dtype=torch.int), dist_reduce_fx="sum")

def update(self, img: Tensor) -> None: # type: ignore
"""Update current score with batch of input images.

Args:
img: A `Tensor` of shape `(N, C, H, W)` consisting of images
"""
score, num_elements = _total_variation_update(img)
self.score += score
self.num_elements += num_elements

def compute(self) -> Tensor:
"""Compute final total variation."""
return _total_variation_compute(self.score, self.num_elements, self.reduction)
115 changes: 115 additions & 0 deletions tests/unittests/image/test_tv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from kornia.losses import total_variation as kornia_total_variation

from torchmetrics.functional.image.tv import total_variation
from torchmetrics.image.tv import TotalVariation
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester

seed_all(42)


# add extra argument to make the metric and reference fit into the testing framework
class TotalVariationTester(TotalVariation):
def update(self, img, *args):
super().update(img=img)


def total_variaion_tester(preds, target, reduction="mean"):
return total_variation(preds, reduction)


def total_variation_kornia_tester(preds, target, reduction):
score = kornia_total_variation(preds)
return score.sum() if reduction == "sum" else score.mean()


# define inputs
Input = namedtuple("Input", ["preds", "target"])

_inputs = []
for size, channel, dtype in [
(12, 3, torch.float),
(13, 3, torch.float32),
(14, 3, torch.double),
(15, 3, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(Input(preds=preds, target=target))


@pytest.mark.parametrize(
"preds, target",
[(i.preds, i.target) for i in _inputs],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="Kornia used as reference requires min PT version")
class TestTotalVariation(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_total_variation(self, preds, target, reduction, ddp, dist_sync_on_step):
"""Test modular implementation."""
self.run_class_metric_test(
ddp,
preds,
target,
TotalVariationTester,
partial(total_variation_kornia_tester, reduction=reduction),
dist_sync_on_step,
metric_args={"reduction": reduction},
)

def test_total_variation_functional(self, preds, target, reduction):
"""Test for functional implementation."""
self.run_functional_metric_test(
preds,
target,
total_variaion_tester,
partial(total_variation_kornia_tester, reduction=reduction),
metric_args={"reduction": reduction},
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6"
)
def test_sam_half_cpu(self, preds, target, reduction):
"""Test for half precision on CPU."""
self.run_precision_test_cpu(
preds,
target,
TotalVariationTester,
total_variaion_tester,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_sam_half_gpu(self, preds, target, reduction):
"""Test for half precision on GPU."""
self.run_precision_test_gpu(preds, target, TotalVariationTester, total_variaion_tester)


def test_correct_args():
"""that that arguments have the right type and sizes."""
with pytest.raises(ValueError, match="Expected argument `reduction`.*"):
_ = TotalVariation(reduction="diff")

with pytest.raises(RuntimeError, match="Expected input `img` to.*"):
_ = total_variation(torch.randn(1, 2, 3))