From d01f1fd2cfc63a023229ac9f792e130af7254591 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sun, 5 Mar 2023 21:40:54 +0100 Subject: [PATCH] Fix cornercase for `PearsonCorrCoef` (#1587) * fix * changelog --- CHANGELOG.md | 3 +++ src/torchmetrics/regression/pearson.py | 3 ++- tests/unittests/regression/test_pearson.py | 11 ++++++++++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1a55c77ce7..9f6dc80500e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576)) +- Fixed corner case for `PearsonCorrCoef` when running in ddp mode but only on single device ([#1587](https://github.com/Lightning-AI/metrics/pull/1587)) + + ## [0.11.2] - 2023-02-21 ### Fixed diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index c77efb69673..8004c4f3b54 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -32,7 +32,8 @@ def _final_aggregation( Formula taken from here: `Aggregate the statistics from multiple devices`_ """ - # assert len(means_x) > 1 and len(means_y) > 1 and len(vars_x) > 1 and len(vars_y) > 1 and len(corrs_xy) > 1 + if len(means_x) == 1: + return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0] mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0] for i in range(1, len(means_x)): mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i] diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index 79e420fce7b..dc695782f53 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -19,7 +19,7 @@ from scipy.stats import pearsonr from torchmetrics.functional.regression.pearson import pearson_corrcoef -from torchmetrics.regression.pearson import PearsonCorrCoef +from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -121,3 +121,12 @@ def test_error_on_different_shape(): metric = PearsonCorrCoef(num_outputs=2) with pytest.raises(ValueError, match="Expected argument `num_outputs` to match the second dimension of input.*"): metric(torch.randn(100, 5), torch.randn(100, 5)) + + +@pytest.mark.parametrize("shapes", [(5,), (1, 5), (2, 5)]) +def test_final_aggregation_function(shapes): + """Test that final aggregation function can take various shapes of input.""" + input_fn = lambda: torch.rand(shapes) + output = _final_aggregation(input_fn(), input_fn(), input_fn(), input_fn(), input_fn(), torch.randint(10, shapes)) + assert all(isinstance(out, torch.Tensor) for out in output) + assert all(out.ndim == input_fn().ndim - 1 for out in output)