Fix ignore_index in jaccard index #1386

Merged
merged 11 commits on Feb 27, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -90,6 +90,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed classification metrics for `byte` input ([#1521](https://github.com/Lightning-AI/metrics/pull/1474))


- Fixed the use of `ignore_index` in `MulticlassJaccardIndex` ([#1386](https://github.com/Lightning-AI/metrics/pull/1386))


## [0.11.1] - 2023-01-30

### Fixed
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/jaccard.py
@@ -176,7 +176,7 @@ def __init__(

def compute(self) -> Tensor:
"""Compute metric."""
- return _jaccard_index_reduce(self.confmat, average=self.average)
+ return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index)


class MultilabelJaccardIndex(MultilabelConfusionMatrix):
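The one-line change above is the core of the fix for the modular API: `MulticlassJaccardIndex` already accepted and stored `ignore_index`, but `compute()` never forwarded it to the reduction step, so the ignored class still dragged down the averaged score. A minimal sketch of the fixed behavior (tensors and values are illustrative, not taken from the PR):

```python
import torch
from torchmetrics.classification import MulticlassJaccardIndex

preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 2, 0])

# Samples whose target equals ignore_index are dropped from the confusion
# matrix, and (with this fix) the ignored class is also excluded from the
# macro average instead of contributing a spurious per-class score of 0.
metric = MulticlassJaccardIndex(num_classes=3, average="macro", ignore_index=0)
score = metric(preds, target)
```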
15 changes: 11 additions & 4 deletions src/torchmetrics/functional/classification/jaccard.py
@@ -38,6 +38,7 @@
def _jaccard_index_reduce(
confmat: Tensor,
average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]],
+ ignore_index: Optional[int] = None,
) -> Tensor:
"""Perform reduction of an un-normalized confusion matrix into jaccard score.

@@ -53,6 +54,9 @@ def _jaccard_index_reduce(
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.

+ ignore_index:
+     Specifies a target value that is ignored and does not contribute to the metric calculation.
"""
allowed_average = ["binary", "micro", "macro", "weighted", "none", None]
if average not in allowed_average:
@@ -61,6 +65,7 @@
if average == "binary":
return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])
else:
+ ignore_index_cond = ignore_index is not None and 0 <= ignore_index < confmat.shape[0]
if confmat.ndim == 3: # multilabel
num = confmat[:, 1, 1]
denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0]
@@ -70,16 +75,18 @@

if average == "micro":
num = num.sum()
- denom = denom.sum()
+ denom = denom.sum() - (denom[ignore_index] if ignore_index_cond else 0.0)

jaccard = _safe_divide(num, denom)

- if average is None or average == "none":
+ if average is None or average == "none" or average == "micro":
return jaccard
if average == "weighted":
weights = confmat[:, 1, 1] + confmat[:, 1, 0] if confmat.ndim == 3 else confmat.sum(1)
else:
weights = torch.ones_like(jaccard)
+ if ignore_index_cond:
+     weights[ignore_index] = 0.0
return ((weights * jaccard) / weights.sum()).sum()


@@ -217,7 +224,7 @@ def multiclass_jaccard_index(
_multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index)
confmat = _multiclass_confusion_matrix_update(preds, target, num_classes)
- return _jaccard_index_reduce(confmat, average=average)
+ return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index)


def _multilabel_jaccard_index_arg_validation(
@@ -297,7 +304,7 @@ def multilabel_jaccard_index(
_multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index)
preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index)
confmat = _multilabel_confusion_matrix_update(preds, target, num_labels)
- return _jaccard_index_reduce(confmat, average=average)
+ return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index)


def jaccard_index(
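Taken together, the functional changes thread `ignore_index` from the public entry points down into `_jaccard_index_reduce`, which zeroes the ignored class's weight (macro/weighted) or subtracts its row from the pooled denominator (micro). A hand-rolled sketch of the macro path on a toy confusion matrix (illustrative only; the elided multiclass `num`/`denom` lines are assumed to be the usual diagonal and row-plus-column sums):

```python
import torch

# Toy 3x3 un-normalized confusion matrix (rows = target, cols = prediction).
confmat = torch.tensor([[2.0, 1.0, 0.0],
                        [0.0, 3.0, 1.0],
                        [1.0, 0.0, 2.0]])
ignore_index = 1

# Per-class intersection (tp) and union (tp + fp + fn).
num = confmat.diag()
denom = confmat.sum(0) + confmat.sum(1) - num
jaccard = num / denom

# Macro averaging: every class gets weight 1, except the ignored class,
# whose weight is zeroed so it no longer drags the mean down.
weights = torch.ones_like(jaccard)
weights[ignore_index] = 0.0
macro = ((weights * jaccard) / weights.sum()).sum()
```

Note that `"micro"` now joins `None`/`"none"` in the early return after `_safe_divide`: once the pooled numerator and denominator are formed, no further weighting applies.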
8 changes: 6 additions & 2 deletions tests/unittests/classification/test_jaccard.py
@@ -124,6 +124,10 @@ def _sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average=
preds = preds.flatten()
target = target.flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
+ if ignore_index is not None and 0 <= ignore_index < NUM_CLASSES:
+     labels = [i for i in range(NUM_CLASSES) if i != ignore_index]
+     res = sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=labels)
+     return np.insert(res, ignore_index, 0.0) if average is None else res
return sk_jaccard_index(y_true=target, y_pred=preds, average=average)


@@ -236,7 +240,7 @@ def _sklearn_jaccard_index_multilabel(preds, target, ignore_index=None, average=
@pytest.mark.parametrize("input", _multilabel_cases)
class TestMultilabelJaccardIndex(MetricTester):
@pytest.mark.parametrize("average", ["macro", "micro", "weighted", None])
@pytest.mark.parametrize("ignore_index", [None]) # , -1, 0])
@pytest.mark.parametrize("ignore_index", [None, -1])
@pytest.mark.parametrize("ddp", [True, False])
def test_multilabel_jaccard_index(self, input, ddp, ignore_index, average):
preds, target = input
@@ -256,7 +260,7 @@ def test_multilabel_jaccard_index(self, input, ignore_index, average):
)

@pytest.mark.parametrize("average", ["macro", "micro", "weighted", None])
@pytest.mark.parametrize("ignore_index", [None, -1, 0])
@pytest.mark.parametrize("ignore_index", [None, -1])
def test_multilabel_jaccard_index_functional(self, input, ignore_index, average):
preds, target = input
if ignore_index is not None:
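The reference helper change mirrors the metric fix on the sklearn side: drop the ignored samples, restrict sklearn to the remaining labels, then re-insert a `0.0` placeholder so the per-class output keeps `NUM_CLASSES` entries. A self-contained sketch of that computation (the `remove_ignore_index` helper is reimplemented here from its assumed behavior in the test suite; the arrays are made up):

```python
import numpy as np
from sklearn.metrics import jaccard_score as sk_jaccard_index

NUM_CLASSES = 3

def remove_ignore_index(target, preds, ignore_index):
    # Assumed equivalent of the unittest helper: drop every position whose
    # target equals ignore_index before handing the arrays to sklearn.
    if ignore_index is not None:
        keep = target != ignore_index
        target, preds = target[keep], preds[keep]
    return target, preds

target = np.array([0, 1, 2, 2, 0])
preds = np.array([0, 2, 2, 1, 0])
ignore_index = 0

target, preds = remove_ignore_index(target, preds, ignore_index)
labels = [i for i in range(NUM_CLASSES) if i != ignore_index]
per_class = sk_jaccard_index(y_true=target, y_pred=preds, average=None, labels=labels)
# Pad the ignored position with 0.0 so the array stays NUM_CLASSES long,
# matching what the torchmetrics implementation reports for that class.
per_class = np.insert(per_class, ignore_index, 0.0)
```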