Skip to content

Commit

Permalink
Fix ignore_index in jaccard index (#1386)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Feb 27, 2023
1 parent 2850524 commit 18aa8da
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -95,6 +95,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed classification metrics for `byte` input ([#1521](https://github.com/Lightning-AI/metrics/pull/1474))


- Fixed the use of `ignore_index` in `MulticlassJaccardIndex` ([#1386](https://github.com/Lightning-AI/metrics/pull/1386))


## [0.11.1] - 2023-01-30

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/jaccard.py
Expand Up @@ -176,7 +176,7 @@ def __init__(

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average=self.average)
return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index)


class MultilabelJaccardIndex(MultilabelConfusionMatrix):
Expand Down
15 changes: 11 additions & 4 deletions src/torchmetrics/functional/classification/jaccard.py
Expand Up @@ -38,6 +38,7 @@
def _jaccard_index_reduce(
confmat: Tensor,
average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]],
ignore_index: Optional[int] = None,
) -> Tensor:
"""Perform reduction of an un-normalized confusion matrix into jaccard score.
Expand All @@ -53,6 +54,9 @@ def _jaccard_index_reduce(
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
"""
allowed_average = ["binary", "micro", "macro", "weighted", "none", None]
if average not in allowed_average:
Expand All @@ -61,6 +65,7 @@ def _jaccard_index_reduce(
if average == "binary":
return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])

ignore_index_cond = ignore_index is not None and 0 <= ignore_index <= confmat.shape[0]
if confmat.ndim == 3: # multilabel
num = confmat[:, 1, 1]
denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0]
Expand All @@ -70,16 +75,18 @@ def _jaccard_index_reduce(

if average == "micro":
num = num.sum()
denom = denom.sum()
denom = denom.sum() - (denom[ignore_index] if ignore_index_cond else 0.0)

jaccard = _safe_divide(num, denom)

if average is None or average == "none":
if average is None or average == "none" or average == "micro":
return jaccard
if average == "weighted":
weights = confmat[:, 1, 1] + confmat[:, 1, 0] if confmat.ndim == 3 else confmat.sum(1)
else:
weights = torch.ones_like(jaccard)
if ignore_index_cond:
weights[ignore_index] = 0.0
return ((weights * jaccard) / weights.sum()).sum()


Expand Down Expand Up @@ -217,7 +224,7 @@ def multiclass_jaccard_index(
_multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index)
confmat = _multiclass_confusion_matrix_update(preds, target, num_classes)
return _jaccard_index_reduce(confmat, average=average)
return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index)


def _multilabel_jaccard_index_arg_validation(
Expand Down Expand Up @@ -297,7 +304,7 @@ def multilabel_jaccard_index(
_multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index)
preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index)
confmat = _multilabel_confusion_matrix_update(preds, target, num_labels)
return _jaccard_index_reduce(confmat, average=average)
return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index)


def jaccard_index(
Expand Down
8 changes: 6 additions & 2 deletions tests/unittests/classification/test_jaccard.py
Expand Up @@ -126,6 +126,10 @@ def _sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average=
preds = preds.flatten()
target = target.flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
if ignore_index is not None and 0 <= ignore_index <= NUM_CLASSES:
labels = [i for i in range(NUM_CLASSES) if i != ignore_index]
res = sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=labels)
return np.insert(res, ignore_index, 0.0) if average is None else res
return sk_jaccard_index(y_true=target, y_pred=preds, average=average)


Expand Down Expand Up @@ -242,7 +246,7 @@ class TestMultilabelJaccardIndex(MetricTester):
"""Test class for `MultilabelJaccardIndex` metric."""

@pytest.mark.parametrize("average", ["macro", "micro", "weighted", None])
@pytest.mark.parametrize("ignore_index", [None]) # , -1, 0])
@pytest.mark.parametrize("ignore_index", [None, -1])
@pytest.mark.parametrize("ddp", [True, False])
def test_multilabel_jaccard_index(self, input, ddp, ignore_index, average):
preds, target = input
Expand All @@ -262,7 +266,7 @@ def test_multilabel_jaccard_index(self, input, ddp, ignore_index, average):
)

@pytest.mark.parametrize("average", ["macro", "micro", "weighted", None])
@pytest.mark.parametrize("ignore_index", [None, -1, 0])
@pytest.mark.parametrize("ignore_index", [None, -1])
def test_multilabel_jaccard_index_functional(self, input, ignore_index, average):
preds, target = input
if ignore_index is not None:
Expand Down

0 comments on commit 18aa8da

Please sign in to comment.