Skip to content

Commit

Permalink
Fix broken reset method in MultioutputWrapper (#1460)
Browse files Browse the repository at this point in the history
* code fix
* tests
* changelog

(cherry picked from commit f3c82e2)
  • Loading branch information
SkafteNicki authored and Borda committed Jan 30, 2023
1 parent b45213e commit 09131f2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-



### Changed

-
Expand All @@ -34,8 +35,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed type checking on the `maximize` parameter at the initialization of `MetricTracker` ([#1428](https://github.com/Lightning-AI/metrics/issues/1428))


- Fixed wrongly reset method in `MultioutputWrapper` ([#1460](https://github.com/Lightning-AI/metrics/issues/1460))


- Fix dtype checking in `PrecisionRecallCurve` for `target` tensor ([#1457](https://github.com/Lightning-AI/metrics/pull/1457))


## [0.11.0] - 2022-11-30

### Added
Expand Down
11 changes: 10 additions & 1 deletion src/torchmetrics/wrappers/multioutput.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, List, Tuple
from typing import Any, Callable, List, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -132,3 +132,12 @@ def reset(self) -> None:
"""Reset all underlying metrics."""
for metric in self.metrics:
metric.reset()
super().reset()

def _wrap_update(self, update: Callable) -> Callable:
"""Overwrite to do nothing."""
return update

def _wrap_compute(self, compute: Callable) -> Callable:
"""Overwrite to do nothing."""
return compute
18 changes: 16 additions & 2 deletions tests/unittests/wrappers/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import torch
from sklearn.metrics import accuracy_score
from sklearn.metrics import r2_score as sk_r2score
from torch import Tensor
from torch import Tensor, tensor

from torchmetrics import Metric
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.classification import ConfusionMatrix, MulticlassAccuracy
from torchmetrics.regression import R2Score
from torchmetrics.wrappers.multioutput import MultioutputWrapper
from unittests.helpers import seed_all
Expand Down Expand Up @@ -122,3 +122,17 @@ def test_multioutput_wrapper(
dist_sync_on_step,
metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class, **metric_kwargs),
)


def test_reset_called_correctly():
"""Check that underlying metric is being correctly reset when calling forward."""
base_metric = ConfusionMatrix(task="multiclass", num_classes=2)
cf = MultioutputWrapper(base_metric, num_outputs=2)

res = cf(tensor([[0, 0]]), tensor([[0, 0]]))
assert torch.allclose(res[0], tensor([[1, 0], [0, 0]]))
assert torch.allclose(res[1], tensor([[1, 0], [0, 0]]))
cf.reset()
res = cf(tensor([[1, 1]]), tensor([[0, 0]]))
assert torch.allclose(res[0], tensor([[0, 1], [0, 0]]))
assert torch.allclose(res[1], tensor([[0, 1], [0, 0]]))

0 comments on commit 09131f2

Please sign in to comment.