Skip to content

Commit

Permalink
fix interaction between top_k and ignore_index (#1589)
Browse files Browse the repository at this point in the history
* fix code
* changelog
* fix tests

(cherry picked from commit 0d28f26)
  • Loading branch information
SkafteNicki authored and Borda committed Mar 10, 2023
1 parent 7de11a7 commit ecf1cd2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))


- Fixed bug related to `top_k>1` and `ignore_index!=None` in `StatScores` based metrics ([#1589](https://github.com/Lightning-AI/metrics/pull/1589))


- Fixed corner case for `PearsonCorrCoef` when running in ddp mode but only on single device ([#1587](https://github.com/Lightning-AI/metrics/pull/1587))


Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,9 @@ def _multiclass_stat_scores_update(
preds = preds.clone()
target = target.clone()
idx = target == ignore_index
preds[idx] = num_classes
target[idx] = num_classes
idx = idx.unsqueeze(1).repeat(1, num_classes, 1) if preds.ndim > target.ndim else idx
preds[idx] = num_classes

if top_k > 1:
preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1)
Expand All @@ -382,7 +383,7 @@ def _multiclass_stat_scores_update(
if 0 <= ignore_index <= num_classes - 1:
target_oh[target == ignore_index, :] = -1
else:
preds_oh = preds_oh[..., :-1]
preds_oh = preds_oh[..., :-1] if top_k == 1 else preds_oh
target_oh = target_oh[..., :-1]
target_oh[target == num_classes, :] = -1
sum_dim = [0, 1] if multidim_average == "global" else [1]
Expand Down
15 changes: 15 additions & 0 deletions tests/unittests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,21 @@ def test_top_k_multiclass(k, preds, target, average, expected):
)


def test_top_k_ignore_index_multiclass():
"""Test that top_k argument works together with ignore_index."""
preds_without = torch.randn(10, 3).softmax(dim=-1)
target_without = torch.randint(3, (10,))
preds_with = torch.cat([preds_without, torch.randn(10, 3).softmax(dim=-1)], 0)
target_with = torch.cat([target_without, -100 * torch.ones(10)], 0).long()

res_without = multiclass_stat_scores(preds_without, target_without, num_classes=3, average="micro", top_k=2)
res_with = multiclass_stat_scores(
preds_with, target_with, num_classes=3, average="micro", top_k=2, ignore_index=-100
)

assert torch.allclose(res_without, res_with)


def test_multiclass_overflow():
"""Test that multiclass computations does not overflow even on byte input."""
preds = torch.randint(20, (100,)).byte()
Expand Down

0 comments on commit ecf1cd2

Please sign in to comment.