Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix interaction between top_k and ignore_index #1589

Merged
merged 7 commits into from Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -103,6 +103,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))


- Fixed bug related to `top_k>1` and `ignore_index!=None` in `StatScores` based metrics ([#1589](https://github.com/Lightning-AI/metrics/pull/1589))


- Fixed corner case for `PearsonCorrCoef` when running in ddp mode but only on single device ([#1587](https://github.com/Lightning-AI/metrics/pull/1587))


Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/classification/stat_scores.py
Expand Up @@ -358,8 +358,9 @@ def _multiclass_stat_scores_update(
preds = preds.clone()
target = target.clone()
idx = target == ignore_index
preds[idx] = num_classes
target[idx] = num_classes
idx = idx.unsqueeze(1).repeat(1, num_classes, 1) if preds.ndim > target.ndim else idx
preds[idx] = num_classes

if top_k > 1:
preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1)
Expand All @@ -374,7 +375,7 @@ def _multiclass_stat_scores_update(
if 0 <= ignore_index <= num_classes - 1:
target_oh[target == ignore_index, :] = -1
else:
preds_oh = preds_oh[..., :-1]
preds_oh = preds_oh[..., :-1] if top_k == 1 else preds_oh
target_oh = target_oh[..., :-1]
target_oh[target == num_classes, :] = -1
sum_dim = [0, 1] if multidim_average == "global" else [1]
Expand Down
15 changes: 15 additions & 0 deletions tests/unittests/classification/test_stat_scores.py
Expand Up @@ -343,6 +343,21 @@ def test_top_k_multiclass(k, preds, target, average, expected):
)


def test_top_k_ignore_index_multiclass():
"""Test that top_k argument works together with ignore_index."""
preds_without = torch.randn(10, 3).softmax(dim=-1)
target_without = torch.randint(3, (10,))
preds_with = torch.cat([preds_without, torch.randn(10, 3).softmax(dim=-1)], 0)
target_with = torch.cat([target_without, -100 * torch.ones(10)], 0).long()

res_without = multiclass_stat_scores(preds_without, target_without, num_classes=3, average="micro", top_k=2)
res_with = multiclass_stat_scores(
preds_with, target_with, num_classes=3, average="micro", top_k=2, ignore_index=-100
)

assert torch.allclose(res_without, res_with)


def test_multiclass_overflow():
"""Test that multiclass computations does not overflow even on byte input."""
preds = torch.randint(20, (100,)).byte()
Expand Down