diff --git a/CHANGELOG.md b/CHANGELOG.md index e0e4b14be10..80a96065eae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed type checking on the `maximize` parameter at the initialization of `MetricTracker` ([#1428](https://github.com/Lightning-AI/metrics/issues/1428)) +- Fixed mixed precision autocast for `SSIM` metric ([#1454](https://github.com/Lightning-AI/metrics/pull/1454)) - Fixed wrongly reset method in `MultioutputWrapper` ([#1460](https://github.com/Lightning-AI/metrics/issues/1460)) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 77c23a9872c..e257212ba7e 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -33,10 +33,7 @@ def _ssim_check_inputs(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: """ if preds.dtype != target.dtype: - raise TypeError( - "Expected `preds` and `target` to have the same data type." - f" Got preds: {preds.dtype} and target: {target.dtype}." - ) + target = target.to(preds.dtype) _check_same_shape(preds, target) if len(preds.shape) not in (4, 5): raise ValueError( diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index c4dc62d267b..1c061ab27d2 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -231,11 +231,8 @@ def test_ssim_half_gpu(self, preds, target, sigma): ], ) def test_ssim_invalid_inputs(pred, target, kernel, sigma): - pred_t = torch.rand(pred, dtype=torch.float32) - target_t = torch.rand(target, dtype=torch.float64) - with pytest.raises(TypeError): - structural_similarity_index_measure(pred_t, target_t) - + """Test that an value errors are raised if input sizes are different, kernel length and sigma does not match + size or invalid values are provided.""" pred = torch.rand(pred) target = torch.rand(target) with pytest.raises(ValueError):