From f3c82e2cda9448948b0a691f54a1e1a306686090 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 25 Jan 2023 07:31:35 +0100 Subject: [PATCH] Fix broken `reset` method in `MultioutputWrapper` (#1460) * code fix * tests * changelog --- CHANGELOG.md | 6 ++++++ src/torchmetrics/wrappers/multioutput.py | 11 ++++++++++- tests/unittests/wrappers/test_multioutput.py | 18 ++++++++++++++++-- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b0b00933dd..1edf4b27f75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328)) + - Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419)) + ### Changed - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370)) @@ -38,8 +40,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed type checking on the `maximize` parameter at the initialization of `MetricTracker` ([#1428](https://github.com/Lightning-AI/metrics/issues/1428)) +- Fixed wrongly reset method in `MultioutputWrapper` ([#1460](https://github.com/Lightning-AI/metrics/issues/1460)) + + - Fix dtype checking in `PrecisionRecallCurve` for `target` tensor ([#1457](https://github.com/Lightning-AI/metrics/pull/1457)) + ## [0.11.0] - 2022-11-30 ### Added diff --git a/src/torchmetrics/wrappers/multioutput.py b/src/torchmetrics/wrappers/multioutput.py index ffc3d9ead3c..a8c292c24e7 100644 --- a/src/torchmetrics/wrappers/multioutput.py +++ b/src/torchmetrics/wrappers/multioutput.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, List, Tuple +from typing import Any, Callable, List, Tuple import torch from torch import Tensor @@ -132,3 +132,12 @@ def reset(self) -> None: """Reset all underlying metrics.""" for metric in self.metrics: metric.reset() + super().reset() + + def _wrap_update(self, update: Callable) -> Callable: + """Overwrite to do nothing.""" + return update + + def _wrap_compute(self, compute: Callable) -> Callable: + """Overwrite to do nothing.""" + return compute diff --git a/tests/unittests/wrappers/test_multioutput.py b/tests/unittests/wrappers/test_multioutput.py index 3435cb73069..19e1fda8793 100644 --- a/tests/unittests/wrappers/test_multioutput.py +++ b/tests/unittests/wrappers/test_multioutput.py @@ -5,10 +5,10 @@ import torch from sklearn.metrics import accuracy_score from sklearn.metrics import r2_score as sk_r2score -from torch import Tensor +from torch import Tensor, tensor from torchmetrics import Metric -from torchmetrics.classification import MulticlassAccuracy +from torchmetrics.classification import ConfusionMatrix, MulticlassAccuracy from torchmetrics.regression import R2Score from torchmetrics.wrappers.multioutput import MultioutputWrapper from unittests.helpers import seed_all @@ -120,3 +120,17 @@ def test_multioutput_wrapper( dist_sync_on_step, metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class), ) + + +def test_reset_called_correctly(): + """Check that underlying metric is being correctly reset when calling forward.""" + base_metric = ConfusionMatrix(task="multiclass", num_classes=2) + cf = MultioutputWrapper(base_metric, num_outputs=2) + + res = cf(tensor([[0, 0]]), tensor([[0, 0]])) + assert torch.allclose(res[0], tensor([[1, 0], [0, 0]])) + assert torch.allclose(res[1], tensor([[1, 0], [0, 0]])) + cf.reset() + res = cf(tensor([[1, 1]]), tensor([[0, 0]])) + assert torch.allclose(res[0], tensor([[0, 1], [0, 0]])) + assert torch.allclose(res[1], tensor([[0, 1], [0, 0]]))