From 9858b0ffc137ee5db1c6562844125ede253afcaa Mon Sep 17 00:00:00 2001 From: Ragav Venkatesan Date: Thu, 21 Apr 2022 20:12:49 -0700 Subject: [PATCH 01/30] Create total_variation.py Initialize the PR. --- torchmetrics/image/total_variation.py | 65 +++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 torchmetrics/image/total_variation.py diff --git a/torchmetrics/image/total_variation.py b/torchmetrics/image/total_variation.py new file mode 100644 index 00000000000..1228e203311 --- /dev/null +++ b/torchmetrics/image/total_variation.py @@ -0,0 +1,65 @@ +# Reference code: https://github.com/jxgu1016/Total_Variation_Loss.pytorch +import torch + + +class TotalVariation(Metric): + """ + A method to calculate total variation loss. + + .. note:: + + Because this loss uses sums, the value will be large. Use a weighting + of order e-5 to control it. Ensure to train with half-precision at least. + + :param dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + :type dist_sync_on_step: bool + :param compute_on_step: Forward only calls ``update()`` and returns None if this is set to + False. + :type compute_on_step: bool + """ + + is_differentiable = True + higher_is_better = False + current: torch.Tensor + total: torch.Tensor + + def __init__(self, dist_sync_on_step: bool = False, compute_on_step: bool = True): + super().__init__( + dist_sync_on_step=dist_sync_on_step, compute_on_step=compute_on_step + ) + self.add_state( + "current", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum" + ) + self.add_state( + "total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum" + ) + + def update(self, sample: torch.Tensor) -> None: + """ + Update method for TV Loss. + :param sample: A NCHW image batch. + :type sample: torch.Tensor + + :returns: A loss scalar. + :rtype: torch.Tensor + """ + _height = sample.size()[2] + _width = sample.size()[3] + _count_height = self.tensor_size(sample[:, :, 1:, :]) + _count_width = self.tensor_size(sample[:, :, :, 1:]) + _height_tv = torch.pow( + (sample[:, :, 1:, :] - sample[:, :, : _height - 1, :]), 2 + ).sum() + _width_tv = torch.pow( + (sample[:, :, :, 1:] - sample[:, :, :, : _width - 1]), 2 + ).sum() + self.current += 2 * (_height_tv / _count_height + _width_tv / _count_width) + self.total += sample.numel() + + def compute(self): + return self.current.float() / self.total + + @staticmethod + def tensor_size(t): + return t.size()[1] * t.size()[2] * t.size()[3] From 2aed1d1bde5e1d7b376924f3b096c258fe39750a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Apr 2022 03:14:44 +0000 Subject: [PATCH 02/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/image/total_variation.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/torchmetrics/image/total_variation.py b/torchmetrics/image/total_variation.py index 1228e203311..9a20352954f 100644 --- a/torchmetrics/image/total_variation.py +++ b/torchmetrics/image/total_variation.py @@ -3,8 +3,7 @@ class TotalVariation(Metric): - """ - A method to calculate total variation loss. + """A method to calculate total variation loss. .. note:: @@ -25,19 +24,13 @@ class TotalVariation(Metric): total: torch.Tensor def __init__(self, dist_sync_on_step: bool = False, compute_on_step: bool = True): - super().__init__( - dist_sync_on_step=dist_sync_on_step, compute_on_step=compute_on_step - ) - self.add_state( - "current", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum" - ) - self.add_state( - "total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum" - ) + super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=compute_on_step) + self.add_state("current", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") def update(self, sample: torch.Tensor) -> None: - """ - Update method for TV Loss. + """Update method for TV Loss. + :param sample: A NCHW image batch. :type sample: torch.Tensor @@ -48,12 +41,8 @@ def update(self, sample: torch.Tensor) -> None: _width = sample.size()[3] _count_height = self.tensor_size(sample[:, :, 1:, :]) _count_width = self.tensor_size(sample[:, :, :, 1:]) - _height_tv = torch.pow( - (sample[:, :, 1:, :] - sample[:, :, : _height - 1, :]), 2 - ).sum() - _width_tv = torch.pow( - (sample[:, :, :, 1:] - sample[:, :, :, : _width - 1]), 2 - ).sum() + _height_tv = torch.pow((sample[:, :, 1:, :] - sample[:, :, : _height - 1, :]), 2).sum() + _width_tv = torch.pow((sample[:, :, :, 1:] - sample[:, :, :, : _width - 1]), 2).sum() self.current += 2 * (_height_tv / _count_height + _width_tv / _count_width) self.total += sample.numel() From 0c4dd174e4713b3cedcbaa1d4c1b74992eebfd2b Mon Sep 17 00:00:00 2001 From: ragavv Date: Thu, 21 Apr 2022 20:18:30 -0700 Subject: [PATCH 03/30] Missed importing Metric. --- torchmetrics/image/total_variation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/image/total_variation.py b/torchmetrics/image/total_variation.py index 9a20352954f..d4d598ea0d3 100644 --- a/torchmetrics/image/total_variation.py +++ b/torchmetrics/image/total_variation.py @@ -1,5 +1,6 @@ # Reference code: https://github.com/jxgu1016/Total_Variation_Loss.pytorch import torch +from torchmetrics.metric import Metric class TotalVariation(Metric): From 5b7b4dba8e24c3c359463ab99269532c4c04a13c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Apr 2022 03:22:14 +0000 Subject: [PATCH 04/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/image/total_variation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/image/total_variation.py b/torchmetrics/image/total_variation.py index d4d598ea0d3..ffe7e8e79ac 100644 --- a/torchmetrics/image/total_variation.py +++ b/torchmetrics/image/total_variation.py @@ -1,5 +1,6 @@ # Reference code: https://github.com/jxgu1016/Total_Variation_Loss.pytorch import torch + from torchmetrics.metric import Metric From 41de2093dd20888543a9e600d7c33c0d9fc23490 Mon Sep 17 00:00:00 2001 From: ragavv Date: Tue, 10 May 2022 01:43:41 -0700 Subject: [PATCH 05/30] added functional and modular metric. --- torchmetrics/functional/image/__init__.py | 2 + .../functional/image/total_variation.py | 25 ++++++++++ torchmetrics/image/__init__.py | 2 + torchmetrics/image/total_variation.py | 46 ++++++++----------- 4 files changed, 49 insertions(+), 26 deletions(-) create mode 100644 torchmetrics/functional/image/total_variation.py diff --git a/torchmetrics/functional/image/__init__.py b/torchmetrics/functional/image/__init__.py index a438cbe940b..cb1d4285b63 100644 --- a/torchmetrics/functional/image/__init__.py +++ b/torchmetrics/functional/image/__init__.py @@ -21,3 +21,5 @@ structural_similarity_index_measure, ) from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401 +from torchmetrics.functional.image.total_variation import total_variation # noqa: F401 + diff --git a/torchmetrics/functional/image/total_variation.py b/torchmetrics/functional/image/total_variation.py new file mode 100644 index 00000000000..daecfe977dc --- /dev/null +++ b/torchmetrics/functional/image/total_variation.py @@ -0,0 +1,25 @@ +import torch + + +def total_variation(img: torch.Tensor) -> torch.Tensor: + """ + Computes total variation loss. + + Adapted from https://github.com/jxgu1016/Total_Variation_Loss.pytorch + Args: + img (torch.Tensor): A NCHW image batch. + + Returns: + A loss scalar value. + """ + + def tensor_size(t): + return t.size()[1] * t.size()[2] * t.size()[3] + + _height = img.size()[2] + _width = img.size()[3] + _count_height = tensor_size(img[:, :, 1:, :]) + _count_width = tensor_size(img[:, :, :, 1:]) + _height_tv = torch.pow((img[:, :, 1:, :] - img[:, :, : _height - 1, :]), 2).sum() + _width_tv = torch.pow((img[:, :, :, 1:] - img[:, :, :, : _width - 1]), 2).sum() + return (2 * (_height_tv / _count_height + _width_tv / _count_width)) / img.size()[0] diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py index 05fda7d29ac..5438a3e4c77 100644 --- a/torchmetrics/image/__init__.py +++ b/torchmetrics/image/__init__.py @@ -29,3 +29,5 @@ if _LPIPS_AVAILABLE: from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity # noqa: F401 + +from torchmetrics.image.total_variation import TotalVariation # noqa: F401 diff --git a/torchmetrics/image/total_variation.py b/torchmetrics/image/total_variation.py index ffe7e8e79ac..7284008be0e 100644 --- a/torchmetrics/image/total_variation.py +++ b/torchmetrics/image/total_variation.py @@ -1,23 +1,17 @@ -# Reference code: https://github.com/jxgu1016/Total_Variation_Loss.pytorch import torch from torchmetrics.metric import Metric class TotalVariation(Metric): - """A method to calculate total variation loss. - - .. note:: - - Because this loss uses sums, the value will be large. Use a weighting - of order e-5 to control it. Ensure to train with half-precision at least. - - :param dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - :type dist_sync_on_step: bool - :param compute_on_step: Forward only calls ``update()`` and returns None if this is set to - False. - :type compute_on_step: bool + """Computes Total Variation loss. + + Adapted from: https://github.com/jxgu1016/Total_Variation_Loss.pytorch + Args: + dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + compute_on_step: Forward only calls ``update()`` and returns None if this is set to + False. """ is_differentiable = True @@ -30,23 +24,23 @@ def __init__(self, dist_sync_on_step: bool = False, compute_on_step: bool = True self.add_state("current", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") - def update(self, sample: torch.Tensor) -> None: + def update(self, img: torch.Tensor) -> None: """Update method for TV Loss. - :param sample: A NCHW image batch. - :type sample: torch.Tensor + Args: + img (torch.Tensor): A NCHW image batch. - :returns: A loss scalar. - :rtype: torch.Tensor + Returns: + A loss scalar value. """ - _height = sample.size()[2] - _width = sample.size()[3] - _count_height = self.tensor_size(sample[:, :, 1:, :]) - _count_width = self.tensor_size(sample[:, :, :, 1:]) - _height_tv = torch.pow((sample[:, :, 1:, :] - sample[:, :, : _height - 1, :]), 2).sum() - _width_tv = torch.pow((sample[:, :, :, 1:] - sample[:, :, :, : _width - 1]), 2).sum() + _height = img.size()[2] + _width = img.size()[3] + _count_height = self.tensor_size(img[:, :, 1:, :]) + _count_width = self.tensor_size(img[:, :, :, 1:]) + _height_tv = torch.pow((img[:, :, 1:, :] - img[:, :, : _height - 1, :]), 2).sum() + _width_tv = torch.pow((img[:, :, :, 1:] - img[:, :, :, : _width - 1]), 2).sum() self.current += 2 * (_height_tv / _count_height + _width_tv / _count_width) - self.total += sample.numel() + self.total += img.numel() def compute(self): return self.current.float() / self.total From 545d17605cc4315f822b8f3d96376d89e18d6ad9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 May 2022 08:45:25 +0000 Subject: [PATCH 06/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/image/__init__.py | 3 +-- torchmetrics/functional/image/total_variation.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torchmetrics/functional/image/__init__.py b/torchmetrics/functional/image/__init__.py index cb1d4285b63..da2f9c48647 100644 --- a/torchmetrics/functional/image/__init__.py +++ b/torchmetrics/functional/image/__init__.py @@ -20,6 +20,5 @@ multiscale_structural_similarity_index_measure, structural_similarity_index_measure, ) -from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401 from torchmetrics.functional.image.total_variation import total_variation # noqa: F401 - +from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401 diff --git a/torchmetrics/functional/image/total_variation.py b/torchmetrics/functional/image/total_variation.py index daecfe977dc..b4a62e19b38 100644 --- a/torchmetrics/functional/image/total_variation.py +++ b/torchmetrics/functional/image/total_variation.py @@ -2,8 +2,7 @@ def total_variation(img: torch.Tensor) -> torch.Tensor: - """ - Computes total variation loss. + """Computes total variation loss. Adapted from https://github.com/jxgu1016/Total_Variation_Loss.pytorch Args: From a5e64b055cb6c898fb8cdbfd0e7817b8f6709aa3 Mon Sep 17 00:00:00 2001 From: ragavv Date: Tue, 10 May 2022 10:14:51 -0700 Subject: [PATCH 07/30] Adding copyright notice to top of file. --- torchmetrics/functional/image/total_variation.py | 14 ++++++++++++++ torchmetrics/image/total_variation.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/torchmetrics/functional/image/total_variation.py b/torchmetrics/functional/image/total_variation.py index daecfe977dc..04a044f9c9c 100644 --- a/torchmetrics/functional/image/total_variation.py +++ b/torchmetrics/functional/image/total_variation.py @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch diff --git a/torchmetrics/image/total_variation.py b/torchmetrics/image/total_variation.py index 7284008be0e..a7ed845eb29 100644 --- a/torchmetrics/image/total_variation.py +++ b/torchmetrics/image/total_variation.py @@ -1,3 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from torchmetrics.metric import Metric From 971be9cff92c97b0ceb9155858e78ccc7015c4f9 Mon Sep 17 00:00:00 2001 From: Ragav Venkatesan Date: Tue, 10 May 2022 10:22:34 -0700 Subject: [PATCH 08/30] Apply first suggestions from code review Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/functional/image/total_variation.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torchmetrics/functional/image/total_variation.py b/torchmetrics/functional/image/total_variation.py index 8fadcf622cd..f20d8d39481 100644 --- a/torchmetrics/functional/image/total_variation.py +++ b/torchmetrics/functional/image/total_variation.py @@ -26,13 +26,10 @@ def total_variation(img: torch.Tensor) -> torch.Tensor: A loss scalar value. """ - def tensor_size(t): - return t.size()[1] * t.size()[2] * t.size()[3] - _height = img.size()[2] - _width = img.size()[3] - _count_height = tensor_size(img[:, :, 1:, :]) - _count_width = tensor_size(img[:, :, :, 1:]) + _batchsize, _channels, _height, _width = img.shape[1:] + _count_height = _channels * (_height - 1) * _width + _count_width = _channels * _height * (_width - 1) _height_tv = torch.pow((img[:, :, 1:, :] - img[:, :, : _height - 1, :]), 2).sum() _width_tv = torch.pow((img[:, :, :, 1:] - img[:, :, :, : _width - 1]), 2).sum() - return (2 * (_height_tv / _count_height + _width_tv / _count_width)) / img.size()[0] + return (2 * (_height_tv / _count_height + _width_tv / _count_width)) / _batchsize From 4d50c910ee953d4a082dc60da9e12935a7011737 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 May 2022 17:23:06 +0000 Subject: [PATCH 09/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/image/total_variation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmetrics/functional/image/total_variation.py b/torchmetrics/functional/image/total_variation.py index f20d8d39481..703b109fb13 100644 --- a/torchmetrics/functional/image/total_variation.py +++ b/torchmetrics/functional/image/total_variation.py @@ -26,7 +26,6 @@ def total_variation(img: torch.Tensor) -> torch.Tensor: A loss scalar value. """ - _batchsize, _channels, _height, _width = img.shape[1:] _count_height = _channels * (_height - 1) * _width _count_width = _channels * _height * (_width - 1) From 995c72a6ee3090e285eb2d41c7882ecb9e83cbfc Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 16 May 2022 13:59:33 +0200 Subject: [PATCH 10/30] docs --- docs/source/image/total_variation.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 docs/source/image/total_variation.rst diff --git a/docs/source/image/total_variation.rst b/docs/source/image/total_variation.rst new file mode 100644 index 00000000000..0f0e7398d9b --- /dev/null +++ b/docs/source/image/total_variation.rst @@ -0,0 +1,22 @@ +.. customcarditem:: + :header: Total Variation (TV) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Image + +.. include:: ../links.rst + +#################### +Total Variation (TV) +#################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.TotalVariation + :noindex: + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.total_variation + :noindex: From 082d312fe2812373d857e4d8c7b7690b6be79d1c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 16 May 2022 14:01:21 +0200 Subject: [PATCH 11/30] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45b335f9481..f746c5c3e82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `Dice` to classification package ([#1021](https://github.com/PyTorchLightning/metrics/pull/1021)) +- Added `TotalVariation` to image package ([#978](https://github.com/PyTorchLightning/metrics/pull/978)) + + ### Changed - From 2b0b4550f2ac06300d3b88eef7d8e108298bbfac Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 16 May 2022 14:03:05 +0200 Subject: [PATCH 12/30] init file --- torchmetrics/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 549367ce4da..63b36764d6d 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -57,6 +57,7 @@ SpectralAngleMapper, SpectralDistortionIndex, StructuralSimilarityIndexMeasure, + TotalVariation, UniversalImageQualityIndex, ) from torchmetrics.metric import Metric # noqa: E402 @@ -187,6 +188,7 @@ "StatScores", "SumMetric", "SymmetricMeanAbsolutePercentageError", + "TotalVariation", "TranslationEditRate", "UniversalImageQualityIndex", "WeightedMeanAbsolutePercentageError", From cdfdc691b72a0917b1e1c76a8e4153abdbc52f22 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 28 Jun 2022 00:41:35 +0200 Subject: [PATCH 13/30] tv --- src/torchmetrics/functional/image/__init__.py | 2 +- src/torchmetrics/image/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index da2f9c48647..c3c53501018 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -20,5 +20,5 @@ multiscale_structural_similarity_index_measure, structural_similarity_index_measure, ) -from torchmetrics.functional.image.total_variation import total_variation # noqa: F401 +from torchmetrics.functional.image.tv import total_variation # noqa: F401 from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401 diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index 5438a3e4c77..781c28c8ab4 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -30,4 +30,4 @@ if _LPIPS_AVAILABLE: from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity # noqa: F401 -from torchmetrics.image.total_variation import TotalVariation # noqa: F401 +from torchmetrics.image.tv import TotalVariation # noqa: F401 From 7cdfb094c71aefea21d21978b17cce3cfb674041 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 12 Jul 2022 15:10:15 +0200 Subject: [PATCH 14/30] add tests --- tests/unittests/image/test_tv.py | 110 +++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 tests/unittests/image/test_tv.py diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py new file mode 100644 index 00000000000..19e13b44107 --- /dev/null +++ b/tests/unittests/image/test_tv.py @@ -0,0 +1,110 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from functools import partial + +import pytest +import torch +from kornia.losses import total_variation as kornia_total_variation + +from torchmetrics.functional.image.tv import total_variation +from torchmetrics.image.tv import TotalVariation +from unittests.helpers import seed_all +from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester + +seed_all(42) + + +# add extra argument to make the metric and reference fit into the testing framework +class TotalVariationTester(TotalVariation): + def update(self, img, *args): + super().update(img=img) + + +def total_variaion_tester(preds, target, reduction="mean"): + return total_variation(preds, reduction) + + +def total_variation_kornia_tester(preds, target, reduction): + score = kornia_total_variation(preds) + return score.sum() if reduction == "sum" else score.mean() + + +# define inputs +Input = namedtuple("Input", ["preds", "target"]) + +_inputs = [] +for size, channel, dtype in [ + (12, 3, torch.float), + (13, 3, torch.float32), + (14, 3, torch.double), + (15, 3, torch.float64), +]: + preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) + target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) + _inputs.append(Input(preds=preds, target=target)) + + +@pytest.mark.parametrize( + "preds, target", + [(i.preds, i.target) for i in _inputs], +) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +class TestTotalVariation(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_total_variation(self, preds, target, reduction, ddp, dist_sync_on_step): + """Test modular implementation.""" + self.run_class_metric_test( + ddp, + preds, + target, + TotalVariationTester, + partial(total_variation_kornia_tester, reduction=reduction), + dist_sync_on_step, + metric_args={"reduction": reduction}, + ) + + def test_total_variation_functional(self, preds, target, reduction): + """Test for functional implementation.""" + self.run_functional_metric_test( + preds, + target, + total_variaion_tester, + partial(total_variation_kornia_tester, reduction=reduction), + metric_args={"reduction": reduction}, + ) + + def test_sam_half_cpu(self, preds, target, reduction): + """Test for half precision on CPU.""" + self.run_precision_test_cpu( + preds, + target, + TotalVariationTester, + total_variaion_tester, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + def test_sam_half_gpu(self, preds, target, reduction): + """Test for half precision on GPU.""" + self.run_precision_test_gpu(preds, target, TotalVariationTester, total_variaion_tester) + + +def test_correct_args(): + """that that arguments have the right type and sizes.""" + with pytest.raises(ValueError, match="Expected argument `reduction`.*"): + _ = TotalVariation(reduction="diff") + + with pytest.raises(RuntimeError, match="Expected input `img` to.*"): + _ = total_variation(torch.randn(1, 2, 3)) From 017330366f31fe70bf70d82e3a89680c5af35032 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 12 Jul 2022 15:10:37 +0200 Subject: [PATCH 15/30] init and requirements --- requirements/image_test.txt | 1 + src/torchmetrics/functional/__init__.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/requirements/image_test.txt b/requirements/image_test.txt index 5ef3abb3bde..642ab04ea01 100644 --- a/requirements/image_test.txt +++ b/requirements/image_test.txt @@ -1,2 +1,3 @@ scikit-image>0.17.1 pytorch_msssim +kornia diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 3782433bb2e..4e45c1cb357 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -47,6 +47,7 @@ multiscale_structural_similarity_index_measure, structural_similarity_index_measure, ) +from torchmetrics.functional.image.tv import total_variation from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance @@ -163,6 +164,7 @@ "structural_similarity_index_measure", "stat_scores", "symmetric_mean_absolute_percentage_error", + "total_variation", "translation_edit_rate", "universal_image_quality_index", "spectral_angle_mapper", From cde3a0c2bb1e45f19c453c987fab4c08c4b66621 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 12 Jul 2022 15:11:32 +0200 Subject: [PATCH 16/30] working implementation --- docs/source/links.rst | 1 + src/torchmetrics/functional/image/tv.py | 58 ++++++++++++++++----- src/torchmetrics/image/tv.py | 69 ++++++++++++++----------- 3 files changed, 86 insertions(+), 42 deletions(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index 9c34dadd216..453a5e841d2 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -86,3 +86,4 @@ .. _MER: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf .. _WIL: https://www.isca-speech.org/archive_v0/archive_papers/interspeech_2004/i04_2765.pdf .. _WIP: https://infoscience.epfl.ch/record/82766 +.. _TV: https://en.wikipedia.org/wiki/Total_variation_denoising diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index 703b109fb13..00b16b7ef1e 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -11,24 +11,58 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple -import torch +from torch import Tensor -def total_variation(img: torch.Tensor) -> torch.Tensor: +def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]: + """Computes total variation statistics on current batch.""" + if img.ndim != 4: + raise RuntimeError(f"Expected input `img` to be an 4D tensor, but got {img.shape}") + diff1 = img[..., 1:, :] - img[..., :-1, :] + diff2 = img[..., :, 1:] - img[..., :, :-1] + + res1 = diff1.abs().sum([1, 2, 3]) + res2 = diff2.abs().sum([1, 2, 3]) + score = (res1 + res2).sum() + return score, img.shape[0] + + +def _total_variation_compute(score: Tensor, num_elements: int, reduction: str) -> Tensor: + """Compute final total variation score.""" + return score if reduction == "sum" else score / num_elements + + +def total_variation(img: Tensor, reduction: str = "sum") -> Tensor: """Computes total variation loss. - Adapted from https://github.com/jxgu1016/Total_Variation_Loss.pytorch + Adapted from: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/total_variation.html + Args: - img (torch.Tensor): A NCHW image batch. + img: A `torch.Tensor` of shape `(N, C, H, W)` consisting of images + reduction: a method to reduce metric score over samples. + - ``'mean'``: takes the mean (default) + - ``'sum'``: takes the sum Returns: - A loss scalar value. - """ + A loss scalar value containing the total variation - _batchsize, _channels, _height, _width = img.shape[1:] - _count_height = _channels * (_height - 1) * _width - _count_width = _channels * _height * (_width - 1) - _height_tv = torch.pow((img[:, :, 1:, :] - img[:, :, : _height - 1, :]), 2).sum() - _width_tv = torch.pow((img[:, :, :, 1:] - img[:, :, :, : _width - 1]), 2).sum() - return (2 * (_height_tv / _count_height + _width_tv / _count_width)) / _batchsize + Raises: + ValueError: + If ``reduction`` is not one of ``'sum'`` or ``'mean'`` + RuntimeError: + If ``img`` is not 4D tensor + + Example: + >>> import torch + >>> from torchmetrics.functional import total_variation + >>> _ = torch.manual_seed(42) + >>> img = torch.rand(5, 3, 28, 28) + >>> total_variation(img) + tensor(7546.8018) + """ + if reduction not in ("sum", "mean"): + raise ValueError("Expected argument `reduction` to either be 'sum' or 'mean'") + score, num_elements = _total_variation_update(img) + return _total_variation_compute(score, num_elements, reduction) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index a7ed845eb29..ddf694e692f 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -14,51 +14,60 @@ import torch +from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update from torchmetrics.metric import Metric class TotalVariation(Metric): - """Computes Total Variation loss. + """Computes Total Variation loss (`TV`_). + + Adapted from: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/total_variation.html - Adapted from: https://github.com/jxgu1016/Total_Variation_Loss.pytorch Args: - dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - compute_on_step: Forward only calls ``update()`` and returns None if this is set to - False. + reduction: a method to reduce metric score over samples. + - ``'mean'``: takes the mean (default) + - ``'sum'``: takes the sum + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Raises: + ValueError: + If ``reduction`` is not one of ``'sum'`` or ``'mean'`` + + Example: + >>> import torch + >>> from torchmetrics import TotalVariation + >>> _ = torch.manual_seed(42) + >>> tv = TotalVariation() + >>> img = torch.rand(5, 3, 28, 28) + >>> tv(img) + tensor(7546.8018) """ - is_differentiable = True - higher_is_better = False + full_state_update: bool = False + is_differentiable: bool = True + higher_is_better: bool = False current: torch.Tensor total: torch.Tensor - def __init__(self, dist_sync_on_step: bool = False, compute_on_step: bool = True): - super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=compute_on_step) - self.add_state("current", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") + def __init__(self, reduction: str = "sum", **kwargs): + super().__init__(**kwargs) + if reduction not in ("sum", "mean"): + raise ValueError("Expected argument `reduction` to either be 'sum' or 'mean'") + self.reduction = reduction + + self.add_state("score", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("num_elements", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") def update(self, img: torch.Tensor) -> None: - """Update method for TV Loss. + """Update current score with batch of input images. Args: - img (torch.Tensor): A NCHW image batch. - - Returns: - A loss scalar value. + img: A `torch.Tensor` of shape `(N, C, H, W)` consisting of images """ - _height = img.size()[2] - _width = img.size()[3] - _count_height = self.tensor_size(img[:, :, 1:, :]) - _count_width = self.tensor_size(img[:, :, :, 1:]) - _height_tv = torch.pow((img[:, :, 1:, :] - img[:, :, : _height - 1, :]), 2).sum() - _width_tv = torch.pow((img[:, :, :, 1:] - img[:, :, :, : _width - 1]), 2).sum() - self.current += 2 * (_height_tv / _count_height + _width_tv / _count_width) - self.total += img.numel() + score, num_elements = _total_variation_update(img) + self.score += score + self.num_elements += num_elements def compute(self): - return self.current.float() / self.total - - @staticmethod - def tensor_size(t): - return t.size()[1] * t.size()[2] * t.size()[3] + """Compute final total variation.""" + return _total_variation_compute(self.score, self.num_elements, self.reduction) From 852ea9a054172b34f053a00f873e074652175d60 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 12 Jul 2022 16:00:59 +0200 Subject: [PATCH 17/30] skip some testing --- tests/unittests/image/test_tv.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 19e13b44107..0b01f7664db 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -20,6 +20,7 @@ from torchmetrics.functional.image.tv import total_variation from torchmetrics.image.tv import TotalVariation +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 from unittests.helpers import seed_all from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester @@ -86,6 +87,9 @@ def test_total_variation_functional(self, preds, target, reduction): metric_args={"reduction": reduction}, ) + @pytest.mark.skipif( + not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6" + ) def test_sam_half_cpu(self, preds, target, reduction): """Test for half precision on CPU.""" self.run_precision_test_cpu( From 28636a8864b2cf623fa6bdb7e9219c04eabd4e1d Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 12 Jul 2022 19:47:09 +0200 Subject: [PATCH 18/30] typing --- src/torchmetrics/functional/image/tv.py | 2 +- src/torchmetrics/image/tv.py | 18 ++++++++++-------- ...est_hamming_distance.py => test_hamming.py} | 0 tests/unittests/image/test_tv.py | 3 ++- 4 files changed, 13 insertions(+), 10 deletions(-) rename tests/unittests/classification/{test_hamming_distance.py => test_hamming.py} (100%) diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index 00b16b7ef1e..e421a38b528 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -40,7 +40,7 @@ def total_variation(img: Tensor, reduction: str = "sum") -> Tensor: Adapted from: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/total_variation.html Args: - img: A `torch.Tensor` of shape `(N, C, H, W)` consisting of images + img: A `Tensor` of shape `(N, C, H, W)` consisting of images reduction: a method to reduce metric score over samples. - ``'mean'``: takes the mean (default) - ``'sum'``: takes the sum diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index ddf694e692f..167712edf98 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any import torch +from torch import Tensor from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update from torchmetrics.metric import Metric @@ -46,28 +48,28 @@ class TotalVariation(Metric): full_state_update: bool = False is_differentiable: bool = True higher_is_better: bool = False - current: torch.Tensor - total: torch.Tensor + current: Tensor + total: Tensor - def __init__(self, reduction: str = "sum", **kwargs): + def __init__(self, reduction: str = "sum", **kwargs: Any) -> None: super().__init__(**kwargs) if reduction not in ("sum", "mean"): raise ValueError("Expected argument `reduction` to either be 'sum' or 'mean'") self.reduction = reduction - self.add_state("score", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("num_elements", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") + self.add_state("score", default=Tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("num_elements", default=Tensor(0, dtype=torch.int), dist_reduce_fx="sum") - def update(self, img: torch.Tensor) -> None: + def update(self, img: Tensor) -> None: # type: ignore """Update current score with batch of input images. Args: - img: A `torch.Tensor` of shape `(N, C, H, W)` consisting of images + img: A `Tensor` of shape `(N, C, H, W)` consisting of images """ score, num_elements = _total_variation_update(img) self.score += score self.num_elements += num_elements - def compute(self): + def compute(self) -> Tensor: """Compute final total variation.""" return _total_variation_compute(self.score, self.num_elements, self.reduction) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming.py similarity index 100% rename from tests/unittests/classification/test_hamming_distance.py rename to tests/unittests/classification/test_hamming.py diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 0b01f7664db..93ca8249e2d 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -20,7 +20,7 @@ from torchmetrics.functional.image.tv import total_variation from torchmetrics.image.tv import TotalVariation -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8 from unittests.helpers import seed_all from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester @@ -62,6 +62,7 @@ def total_variation_kornia_tester(preds, target, reduction): [(i.preds, i.target) for i in _inputs], ) @pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="Kornia used as reference requires min PT version") class TestTotalVariation(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) From 1fa0902ac171ac93fb0984a1677f0467e15150b2 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 20 Jul 2022 11:53:22 +0200 Subject: [PATCH 19/30] vars --- src/torchmetrics/image/tv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index 167712edf98..04bee742abe 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -48,8 +48,8 @@ class TotalVariation(Metric): full_state_update: bool = False is_differentiable: bool = True higher_is_better: bool = False - current: Tensor - total: Tensor + score: Tensor + num_elements: Tensor def __init__(self, reduction: str = "sum", **kwargs: Any) -> None: super().__init__(**kwargs) From aea9a8c2cf4329c7bd1e6a7bad65c9f08adb4962 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 20 Jul 2022 12:01:16 +0200 Subject: [PATCH 20/30] fix type --- src/torchmetrics/image/tv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index 04bee742abe..a7673803559 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -14,7 +14,7 @@ from typing import Any import torch -from torch import Tensor +from torch import Tensor, tensor from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update from torchmetrics.metric import Metric @@ -57,8 +57,8 @@ def __init__(self, reduction: str = "sum", **kwargs: Any) -> None: raise ValueError("Expected argument `reduction` to either be 'sum' or 'mean'") self.reduction = reduction - self.add_state("score", default=Tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("num_elements", default=Tensor(0, dtype=torch.int), dist_reduce_fx="sum") + self.add_state("score", default=tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("num_elements", default=tensor(0, dtype=torch.int), dist_reduce_fx="sum") def update(self, img: Tensor) -> None: # type: ignore """Update current score with batch of input images. From 9c190041e2bad3469165b4d0989a2a8561908690 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 29 Sep 2022 16:44:23 +0200 Subject: [PATCH 21/30] remove renaming --- .../classification/{test_hamming.py => test_hamming_distance.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/unittests/classification/{test_hamming.py => test_hamming_distance.py} (100%) diff --git a/tests/unittests/classification/test_hamming.py b/tests/unittests/classification/test_hamming_distance.py similarity index 100% rename from tests/unittests/classification/test_hamming.py rename to tests/unittests/classification/test_hamming_distance.py From d0ae3878bb6bb3d6af8b5d4f164349e214964698 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 29 Sep 2022 16:50:24 +0200 Subject: [PATCH 22/30] fix tests --- tests/unittests/image/test_tv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 93ca8249e2d..39740fb0f96 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -38,7 +38,7 @@ def total_variaion_tester(preds, target, reduction="mean"): def total_variation_kornia_tester(preds, target, reduction): - score = kornia_total_variation(preds) + score = kornia_total_variation(preds).sum(-1) return score.sum() if reduction == "sum" else score.mean() From 1ce004e4892658f84726d6cd9075571ec0532005 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 29 Sep 2022 17:08:51 +0200 Subject: [PATCH 23/30] add none option --- src/torchmetrics/functional/image/tv.py | 33 +++++++++++++++-------- src/torchmetrics/image/tv.py | 35 ++++++++++++++++--------- tests/unittests/image/test_tv.py | 8 ++++-- 3 files changed, 50 insertions(+), 26 deletions(-) diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index e421a38b528..d24f4dac338 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -14,6 +14,7 @@ from typing import Tuple from torch import Tensor +from typing_extensions import Literal def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]: @@ -25,32 +26,40 @@ def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]: res1 = diff1.abs().sum([1, 2, 3]) res2 = diff2.abs().sum([1, 2, 3]) - score = (res1 + res2).sum() + score = res1 + res2 return score, img.shape[0] -def _total_variation_compute(score: Tensor, num_elements: int, reduction: str) -> Tensor: +def _total_variation_compute( + score: Tensor, num_elements: int, reduction: Literal["mean", "sum", "none", None] +) -> Tensor: """Compute final total variation score.""" - return score if reduction == "sum" else score / num_elements + if reduction == "mean": + return score.sum() / num_elements + elif reduction == "sum": + return score.sum() + elif reduction is None or reduction == "none": + return score + else: + raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None") -def total_variation(img: Tensor, reduction: str = "sum") -> Tensor: +def total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor: """Computes total variation loss. - Adapted from: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/total_variation.html - Args: img: A `Tensor` of shape `(N, C, H, W)` consisting of images reduction: a method to reduce metric score over samples. - - ``'mean'``: takes the mean (default) - - ``'sum'``: takes the sum + - ``'mean'``: takes the mean over samples + - ``'sum'``: takes the sum over samples + - ``None`` or ``'none': return the score per sample Returns: A loss scalar value containing the total variation Raises: ValueError: - If ``reduction`` is not one of ``'sum'`` or ``'mean'`` + If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None`` RuntimeError: If ``img`` is not 4D tensor @@ -61,8 +70,10 @@ def total_variation(img: Tensor, reduction: str = "sum") -> Tensor: >>> img = torch.rand(5, 3, 28, 28) >>> total_variation(img) tensor(7546.8018) + >>> total_variation(img, reduction=None) + tensor([7546.8018, 7546.8018, 7546.8018, 7546.8018, 7546.8018]) """ - if reduction not in ("sum", "mean"): - raise ValueError("Expected argument `reduction` to either be 'sum' or 'mean'") + # code adapted from: + # from kornia.losses import total_variation as kornia_total_variation score, num_elements = _total_variation_update(img) return _total_variation_compute(score, num_elements, reduction) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index a7673803559..9cb26e76969 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -15,25 +15,26 @@ import torch from torch import Tensor, tensor +from typing_extensions import Literal from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat class TotalVariation(Metric): """Computes Total Variation loss (`TV`_). - Adapted from: https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/total_variation.html - Args: reduction: a method to reduce metric score over samples. - - ``'mean'``: takes the mean (default) - - ``'sum'``: takes the sum + - ``'mean'``: takes the mean over samples + - ``'sum'``: takes the sum over samples + - ``None`` or ``'none': return the score per sample kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: ValueError: - If ``reduction`` is not one of ``'sum'`` or ``'mean'`` + If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None`` Example: >>> import torch @@ -48,16 +49,17 @@ class TotalVariation(Metric): full_state_update: bool = False is_differentiable: bool = True higher_is_better: bool = False - score: Tensor - num_elements: Tensor - def __init__(self, reduction: str = "sum", **kwargs: Any) -> None: + def __init__(self, reduction: Literal["mean", "sum", "none", None] = "sum", **kwargs: Any) -> None: super().__init__(**kwargs) - if reduction not in ("sum", "mean"): - raise ValueError("Expected argument `reduction` to either be 'sum' or 'mean'") + if reduction is not None and reduction not in ("sum", "mean", "none"): + raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None") self.reduction = reduction - self.add_state("score", default=tensor(0, dtype=torch.float), dist_reduce_fx="sum") + if self.reduction is None or self.reduction == "none": + self.add_state("score", default=[], dist_reduce_fx="cat") + else: + self.add_state("score", default=tensor(0, dtype=torch.float), dist_reduce_fx="sum") self.add_state("num_elements", default=tensor(0, dtype=torch.int), dist_reduce_fx="sum") def update(self, img: Tensor) -> None: # type: ignore @@ -67,9 +69,16 @@ def update(self, img: Tensor) -> None: # type: ignore img: A `Tensor` of shape `(N, C, H, W)` consisting of images """ score, num_elements = _total_variation_update(img) - self.score += score + if self.reduction is None or self.reduction == "none": + self.score.append(score) + else: + self.score += score.sum() self.num_elements += num_elements def compute(self) -> Tensor: """Compute final total variation.""" - return _total_variation_compute(self.score, self.num_elements, self.reduction) + if self.reduction is None or self.reduction == "none": + score = dim_zero_cat(self.score) + else: + score = self.score + return _total_variation_compute(score, self.num_elements, self.reduction) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 39740fb0f96..6bbd7388d0a 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -39,7 +39,11 @@ def total_variaion_tester(preds, target, reduction="mean"): def total_variation_kornia_tester(preds, target, reduction): score = kornia_total_variation(preds).sum(-1) - return score.sum() if reduction == "sum" else score.mean() + if reduction == "sum": + return score.sum() + elif reduction == "mean": + return score.mean() + return score # define inputs @@ -61,7 +65,7 @@ def total_variation_kornia_tester(preds, target, reduction): "preds, target", [(i.preds, i.target) for i in _inputs], ) -@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize("reduction", ["sum", "mean", None]) @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="Kornia used as reference requires min PT version") class TestTotalVariation(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) From e5d64e01596a66b9fd6d2c5170921dfafbbe540c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 29 Sep 2022 17:13:24 +0200 Subject: [PATCH 24/30] fix doctest --- src/torchmetrics/functional/image/tv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index d24f4dac338..330a0cf7aab 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -71,7 +71,7 @@ def total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] >>> total_variation(img) tensor(7546.8018) >>> total_variation(img, reduction=None) - tensor([7546.8018, 7546.8018, 7546.8018, 7546.8018, 7546.8018]) + tensor([1475.1860, 1530.8628, 1538.2046, 1494.1165, 1508.4319]) """ # code adapted from: # from kornia.losses import total_variation as kornia_total_variation From d1a1e35724dfeddfc230caa2a41070508d366ca4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 29 Sep 2022 20:21:43 +0200 Subject: [PATCH 25/30] fix doctests --- src/torchmetrics/functional/image/tv.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index 330a0cf7aab..e3b0e9815b4 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -70,8 +70,6 @@ def total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] >>> img = torch.rand(5, 3, 28, 28) >>> total_variation(img) tensor(7546.8018) - >>> total_variation(img, reduction=None) - tensor([1475.1860, 1530.8628, 1538.2046, 1494.1165, 1508.4319]) """ # code adapted from: # from kornia.losses import total_variation as kornia_total_variation From d73f8ac97a9a1fb937c9020af6b8004f6a6e0866 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 30 Sep 2022 08:31:20 +0200 Subject: [PATCH 26/30] try again --- src/torchmetrics/image/tv.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index 9cb26e76969..ec65b5ac787 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -26,10 +26,12 @@ class TotalVariation(Metric): """Computes Total Variation loss (`TV`_). Args: - reduction: a method to reduce metric score over samples. + reduction: a method to reduce metric score over samples + - ``'mean'``: takes the mean over samples - ``'sum'``: takes the sum over samples - ``None`` or ``'none': return the score per sample + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: From 113cb102db06e206b444f0bedcfe153f318d0aba Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 30 Sep 2022 08:44:30 +0200 Subject: [PATCH 27/30] another fix --- src/torchmetrics/image/tv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index ec65b5ac787..aef5e8b97df 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -30,7 +30,7 @@ class TotalVariation(Metric): - ``'mean'``: takes the mean over samples - ``'sum'``: takes the sum over samples - - ``None`` or ``'none': return the score per sample + - ``None`` or ``'none'``: return the score per sample kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. From 096c65240400a81fbd4d2fa7c4a10b226223a0b4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 30 Sep 2022 10:10:50 +0200 Subject: [PATCH 28/30] fix docbuild --- src/torchmetrics/functional/image/tv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/tv.py b/src/torchmetrics/functional/image/tv.py index e3b0e9815b4..d3432ebdabb 100644 --- a/src/torchmetrics/functional/image/tv.py +++ b/src/torchmetrics/functional/image/tv.py @@ -50,9 +50,10 @@ def total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] Args: img: A `Tensor` of shape `(N, C, H, W)` consisting of images reduction: a method to reduce metric score over samples. + - ``'mean'``: takes the mean over samples - ``'sum'``: takes the sum over samples - - ``None`` or ``'none': return the score per sample + - ``None`` or ``'none'``: return the score per sample Returns: A loss scalar value containing the total variation From 6b1e01ea77f64de6fe226c38329cdd1492e43439 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 5 Oct 2022 19:26:19 +0200 Subject: [PATCH 29/30] try fixing tests --- tests/unittests/image/test_tv.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 6bbd7388d0a..03bb736a449 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -72,6 +72,8 @@ class TestTotalVariation(MetricTester): @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_total_variation(self, preds, target, reduction, ddp, dist_sync_on_step): """Test modular implementation.""" + if reduction == None and ddp: + pytest.skip("reduction=None and ddp=True runs out of memory on CI hardware, but it does work") self.run_class_metric_test( ddp, preds, From 49c32b3d5ab78a137b154d759d0cae66a32f3aa9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 5 Oct 2022 19:50:15 +0200 Subject: [PATCH 30/30] fix flake --- tests/unittests/image/test_tv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 03bb736a449..c97f5ae2e6f 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -72,7 +72,7 @@ class TestTotalVariation(MetricTester): @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_total_variation(self, preds, target, reduction, ddp, dist_sync_on_step): """Test modular implementation.""" - if reduction == None and ddp: + if reduction is None and ddp: pytest.skip("reduction=None and ddp=True runs out of memory on CI hardware, but it does work") self.run_class_metric_test( ddp,