Skip to content

Commit

Permalink
Fix broken reset method in MultioutputWrapper (#1460)
Browse files Browse the repository at this point in the history
* code fix
* tests
* changelog
  • Loading branch information
SkafteNicki committed Jan 25, 2023
1 parent ef13ca1 commit f3c82e2
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Expand Up @@ -13,8 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328))


- Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419))


### Changed

- Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370))
Expand All @@ -38,8 +40,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed type checking on the `maximize` parameter at the initialization of `MetricTracker` ([#1428](https://github.com/Lightning-AI/metrics/issues/1428))


- Fixed wrongly reset method in `MultioutputWrapper` ([#1460](https://github.com/Lightning-AI/metrics/issues/1460))


- Fix dtype checking in `PrecisionRecallCurve` for `target` tensor ([#1457](https://github.com/Lightning-AI/metrics/pull/1457))


## [0.11.0] - 2022-11-30

### Added
Expand Down
11 changes: 10 additions & 1 deletion src/torchmetrics/wrappers/multioutput.py
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, List, Tuple
from typing import Any, Callable, List, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -132,3 +132,12 @@ def reset(self) -> None:
"""Reset all underlying metrics."""
for metric in self.metrics:
metric.reset()
super().reset()

def _wrap_update(self, update: Callable) -> Callable:
"""Overwrite to do nothing."""
return update

def _wrap_compute(self, compute: Callable) -> Callable:
"""Overwrite to do nothing."""
return compute
18 changes: 16 additions & 2 deletions tests/unittests/wrappers/test_multioutput.py
Expand Up @@ -5,10 +5,10 @@
import torch
from sklearn.metrics import accuracy_score
from sklearn.metrics import r2_score as sk_r2score
from torch import Tensor
from torch import Tensor, tensor

from torchmetrics import Metric
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.classification import ConfusionMatrix, MulticlassAccuracy
from torchmetrics.regression import R2Score
from torchmetrics.wrappers.multioutput import MultioutputWrapper
from unittests.helpers import seed_all
Expand Down Expand Up @@ -120,3 +120,17 @@ def test_multioutput_wrapper(
dist_sync_on_step,
metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class),
)


def test_reset_called_correctly():
"""Check that underlying metric is being correctly reset when calling forward."""
base_metric = ConfusionMatrix(task="multiclass", num_classes=2)
cf = MultioutputWrapper(base_metric, num_outputs=2)

res = cf(tensor([[0, 0]]), tensor([[0, 0]]))
assert torch.allclose(res[0], tensor([[1, 0], [0, 0]]))
assert torch.allclose(res[1], tensor([[1, 0], [0, 0]]))
cf.reset()
res = cf(tensor([[1, 1]]), tensor([[0, 0]]))
assert torch.allclose(res[0], tensor([[0, 1], [0, 0]]))
assert torch.allclose(res[1], tensor([[0, 1], [0, 0]]))

0 comments on commit f3c82e2

Please sign in to comment.