Skip to content

Commit

Permalink
Fix autocast in SSIM metric (#1454)
Browse files Browse the repository at this point in the history
(cherry picked from commit 3d0c392)
  • Loading branch information
SkafteNicki authored and Borda committed Jan 30, 2023
1 parent fca8d74 commit 1797bf4
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed type checking on the `maximize` parameter at the initialization of `MetricTracker` ([#1428](https://github.com/Lightning-AI/metrics/issues/1428))

- Fixed mixed precision autocast for `SSIM` metric ([#1454](https://github.com/Lightning-AI/metrics/pull/1454))

- Fixed wrongly reset method in `MultioutputWrapper` ([#1460](https://github.com/Lightning-AI/metrics/issues/1460))

Expand Down
5 changes: 1 addition & 4 deletions src/torchmetrics/functional/image/ssim.py
Expand Up @@ -33,10 +33,7 @@ def _ssim_check_inputs(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""

if preds.dtype != target.dtype:
raise TypeError(
"Expected `preds` and `target` to have the same data type."
f" Got preds: {preds.dtype} and target: {target.dtype}."
)
target = target.to(preds.dtype)
_check_same_shape(preds, target)
if len(preds.shape) not in (4, 5):
raise ValueError(
Expand Down
7 changes: 2 additions & 5 deletions tests/unittests/image/test_ssim.py
Expand Up @@ -231,11 +231,8 @@ def test_ssim_half_gpu(self, preds, target, sigma):
],
)
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
pred_t = torch.rand(pred, dtype=torch.float32)
target_t = torch.rand(target, dtype=torch.float64)
with pytest.raises(TypeError):
structural_similarity_index_measure(pred_t, target_t)

"""Test that an value errors are raised if input sizes are different, kernel length and sigma does not match
size or invalid values are provided."""
pred = torch.rand(pred)
target = torch.rand(target)
with pytest.raises(ValueError):
Expand Down

0 comments on commit 1797bf4

Please sign in to comment.