Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Total Variation Loss #978

Merged
merged 55 commits into from Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
9858b0f
Create total_variation.py
ragavvenkatesan Apr 22, 2022
2aed1d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2022
0c4dd17
Missed importing Metric.
Apr 22, 2022
5b7b4db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2022
e7df5d5
Merge branch 'master' into patch-2
ragavvenkatesan May 5, 2022
7c8bbca
Merge branch 'master' into patch-2
SkafteNicki May 5, 2022
172230d
Merge branch 'master' into patch-2
SkafteNicki May 7, 2022
e7537ea
Merge branch 'master' into patch-2
ragavvenkatesan May 9, 2022
41de209
added functional and modular metric.
May 10, 2022
545d176
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
94cc783
Merge branch 'master' into patch-2
SkafteNicki May 10, 2022
a5e64b0
Adding copyright notice to top of file.
May 10, 2022
7fc1a26
Merge branch 'patch-2' of https://github.com/ragavvenkatesan/metrics …
May 10, 2022
971be9c
Apply first suggestions from code review
ragavvenkatesan May 10, 2022
4d50c91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2022
386efa2
Merge branch 'master' into patch-2
SkafteNicki May 16, 2022
995c72a
docs
SkafteNicki May 16, 2022
082d312
changelog
SkafteNicki May 16, 2022
2b0b455
init file
SkafteNicki May 16, 2022
7222a06
Merge branch 'master' into patch-2
SkafteNicki May 17, 2022
b73c4bc
Merge branch 'master' into patch-2
Borda Jun 27, 2022
cdfdc69
tv
Borda Jun 27, 2022
5a1c71d
Merge branch 'master' into patch-2
SkafteNicki Jun 28, 2022
83147c2
Merge branch 'master' into patch-2
Borda Jul 11, 2022
79a9751
Merge branch 'master' into patch-2
SkafteNicki Jul 12, 2022
7cdfb09
add tests
SkafteNicki Jul 12, 2022
0173303
init and requirements
SkafteNicki Jul 12, 2022
cde3a0c
working implementation
SkafteNicki Jul 12, 2022
852ea9a
skip some testing
SkafteNicki Jul 12, 2022
ed3e57d
Merge branch 'master' into patch-2
Borda Jul 12, 2022
28636a8
typing
Borda Jul 12, 2022
a9dd4e7
Merge branch 'master' into patch-2
SkafteNicki Jul 13, 2022
2e15d26
Merge branch 'master' into patch-2
Borda Jul 20, 2022
1fa0902
vars
Borda Jul 20, 2022
aea9a8c
fix type
SkafteNicki Jul 20, 2022
e2ab475
Merge branch 'master' into patch-2
Borda Jul 26, 2022
2701f51
Merge branch 'master' into patch-2
SkafteNicki Aug 30, 2022
7b68233
Merge branch 'master' into patch-2
SkafteNicki Sep 29, 2022
771c9d5
Merge branch 'master' into patch-2
SkafteNicki Sep 29, 2022
9c19004
remove renaming
SkafteNicki Sep 29, 2022
d0ae387
fix tests
SkafteNicki Sep 29, 2022
1ce004e
add none option
SkafteNicki Sep 29, 2022
e5d64e0
fix doctest
SkafteNicki Sep 29, 2022
d1a1e35
fix doctests
SkafteNicki Sep 29, 2022
d73f8ac
try again
SkafteNicki Sep 30, 2022
0486a00
Merge branch 'patch-2' of https://github.com/ragavvenkatesan/metrics …
SkafteNicki Sep 30, 2022
113cb10
another fix
SkafteNicki Sep 30, 2022
096c652
fix docbuild
SkafteNicki Sep 30, 2022
e1a6141
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
89a82fe
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
8270a39
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
5b65ce2
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
993ec33
Merge branch 'master' into patch-2
mergify[bot] Oct 4, 2022
6b1e01e
try fixing tests
SkafteNicki Oct 5, 2022
49c32b3
fix flake
SkafteNicki Oct 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions torchmetrics/functional/image/__init__.py
Expand Up @@ -20,4 +20,5 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.total_variation import total_variation # noqa: F401
from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401
24 changes: 24 additions & 0 deletions torchmetrics/functional/image/total_variation.py
@@ -0,0 +1,24 @@
import torch
ragavvenkatesan marked this conversation as resolved.
Show resolved Hide resolved


def total_variation(img: torch.Tensor) -> torch.Tensor:
"""Computes total variation loss.

Adapted from https://github.com/jxgu1016/Total_Variation_Loss.pytorch
Args:
img (torch.Tensor): A NCHW image batch.

Returns:
A loss scalar value.
"""

def tensor_size(t):
return t.size()[1] * t.size()[2] * t.size()[3]
ragavvenkatesan marked this conversation as resolved.
Show resolved Hide resolved

_height = img.size()[2]
_width = img.size()[3]
_count_height = tensor_size(img[:, :, 1:, :])
_count_width = tensor_size(img[:, :, :, 1:])
ragavvenkatesan marked this conversation as resolved.
Show resolved Hide resolved
_height_tv = torch.pow((img[:, :, 1:, :] - img[:, :, : _height - 1, :]), 2).sum()
_width_tv = torch.pow((img[:, :, :, 1:] - img[:, :, :, : _width - 1]), 2).sum()
return (2 * (_height_tv / _count_height + _width_tv / _count_width)) / img.size()[0]
ragavvenkatesan marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions torchmetrics/image/__init__.py
Expand Up @@ -29,3 +29,5 @@

if _LPIPS_AVAILABLE:
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity # noqa: F401

from torchmetrics.image.total_variation import TotalVariation # noqa: F401
50 changes: 50 additions & 0 deletions torchmetrics/image/total_variation.py
@@ -0,0 +1,50 @@
import torch
ragavvenkatesan marked this conversation as resolved.
Show resolved Hide resolved

from torchmetrics.metric import Metric


class TotalVariation(Metric):
"""Computes Total Variation loss.
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

Adapted from: https://github.com/jxgu1016/Total_Variation_Loss.pytorch
Args:
dist_sync_on_step: Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
compute_on_step: Forward only calls ``update()`` and returns None if this is set to
False.
"""

is_differentiable = True
higher_is_better = False
current: torch.Tensor
total: torch.Tensor

def __init__(self, dist_sync_on_step: bool = False, compute_on_step: bool = True):
super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=compute_on_step)
self.add_state("current", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum")

def update(self, img: torch.Tensor) -> None:
"""Update method for TV Loss.

Args:
img (torch.Tensor): A NCHW image batch.

Returns:
A loss scalar value.
"""
_height = img.size()[2]
_width = img.size()[3]
_count_height = self.tensor_size(img[:, :, 1:, :])
_count_width = self.tensor_size(img[:, :, :, 1:])
_height_tv = torch.pow((img[:, :, 1:, :] - img[:, :, : _height - 1, :]), 2).sum()
_width_tv = torch.pow((img[:, :, :, 1:] - img[:, :, :, : _width - 1]), 2).sum()
self.current += 2 * (_height_tv / _count_height + _width_tv / _count_width)
self.total += img.numel()

def compute(self):
return self.current.float() / self.total

@staticmethod
def tensor_size(t):
return t.size()[1] * t.size()[2] * t.size()[3]