diff --git a/CHANGELOG.md b/CHANGELOG.md index a270f02c526..d1be285e7a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576)) +- Fixed corner case for `PearsonCorrCoef` when running in ddp mode but only on single device ([#1587](https://github.com/Lightning-AI/metrics/pull/1587)) + + ## [0.11.2] - 2023-02-21 ### Fixed diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 6a6554a2469..4d2afb09ed7 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -32,7 +32,8 @@ def _final_aggregation( Formula taken from here: `Aggregate the statistics from multiple devices`_ """ - # assert len(means_x) > 1 and len(means_y) > 1 and len(vars_x) > 1 and len(vars_y) > 1 and len(corrs_xy) > 1 + if len(means_x) == 1: + return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0] mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0] for i in range(1, len(means_x)): mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i] diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index 7364f653740..7c3883b18b9 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -19,7 +19,7 @@ from scipy.stats import pearsonr from torchmetrics.functional.regression.pearson import pearson_corrcoef -from torchmetrics.regression.pearson import PearsonCorrCoef +from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation from unittests.helpers import seed_all from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester @@ -120,3 +120,12 @@ def test_error_on_different_shape(): metric = PearsonCorrCoef(num_outputs=2) with pytest.raises(ValueError, match="Expected argument `num_outputs` to match the second dimension of input.*"): metric(torch.randn(100, 5), torch.randn(100, 5)) + + +@pytest.mark.parametrize("shapes", [(5,), (1, 5), (2, 5)]) +def test_final_aggregation_function(shapes): + """Test that final aggregation function can take various shapes of input.""" + input_fn = lambda: torch.rand(shapes) + output = _final_aggregation(input_fn(), input_fn(), input_fn(), input_fn(), input_fn(), torch.randint(10, shapes)) + assert all(isinstance(out, torch.Tensor) for out in output) + assert all(out.ndim == input_fn().ndim - 1 for out in output)