Skip to content

Commit

Permalink
Fix cornercase for PearsonCorrCoef (#1587)
Browse files Browse the repository at this point in the history
* fix
* changelog
  • Loading branch information
SkafteNicki committed Mar 5, 2023
1 parent 7f01332 commit d01f1fd
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -102,6 +102,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))


- Fixed corner case for `PearsonCorrCoef` when running in ddp mode but only on single device ([#1587](https://github.com/Lightning-AI/metrics/pull/1587))


## [0.11.2] - 2023-02-21

### Fixed
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/regression/pearson.py
Expand Up @@ -32,7 +32,8 @@ def _final_aggregation(
Formula taken from here: `Aggregate the statistics from multiple devices`_
"""
# assert len(means_x) > 1 and len(means_y) > 1 and len(vars_x) > 1 and len(vars_y) > 1 and len(corrs_xy) > 1
if len(means_x) == 1:
return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
for i in range(1, len(means_x)):
mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
Expand Down
11 changes: 10 additions & 1 deletion tests/unittests/regression/test_pearson.py
Expand Up @@ -19,7 +19,7 @@
from scipy.stats import pearsonr

from torchmetrics.functional.regression.pearson import pearson_corrcoef
from torchmetrics.regression.pearson import PearsonCorrCoef
from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation
from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
Expand Down Expand Up @@ -121,3 +121,12 @@ def test_error_on_different_shape():
metric = PearsonCorrCoef(num_outputs=2)
with pytest.raises(ValueError, match="Expected argument `num_outputs` to match the second dimension of input.*"):
metric(torch.randn(100, 5), torch.randn(100, 5))


@pytest.mark.parametrize("shapes", [(5,), (1, 5), (2, 5)])
def test_final_aggregation_function(shapes):
"""Test that final aggregation function can take various shapes of input."""
input_fn = lambda: torch.rand(shapes)
output = _final_aggregation(input_fn(), input_fn(), input_fn(), input_fn(), input_fn(), torch.randint(10, shapes))
assert all(isinstance(out, torch.Tensor) for out in output)
assert all(out.ndim == input_fn().ndim - 1 for out in output)

0 comments on commit d01f1fd

Please sign in to comment.