diff --git a/CHANGELOG.md b/CHANGELOG.md index f5a5ef0fa38..f76996612ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -95,6 +95,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed classification metrics for `byte` input ([#1521](https://github.com/Lightning-AI/metrics/pull/1474)) +- Fixed the use of `ignore_index` in `MulticlassJaccardIndex` ([#1386](https://github.com/Lightning-AI/metrics/pull/1386)) + + ## [0.11.1] - 2023-01-30 ### Fixed diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index f4451492dc4..36cbaa6e4e8 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -176,7 +176,7 @@ def __init__( def compute(self) -> Tensor: """Compute metric.""" - return _jaccard_index_reduce(self.confmat, average=self.average) + return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index) class MultilabelJaccardIndex(MultilabelConfusionMatrix): diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 986292b61cf..0a83d8de464 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -38,6 +38,7 @@ def _jaccard_index_reduce( confmat: Tensor, average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]], + ignore_index: Optional[int] = None, ) -> Tensor: """Perform reduction of an un-normalized confusion matrix into jaccard score. @@ -53,6 +54,9 @@ def _jaccard_index_reduce( metrics across classes, weighting each class by its support (``tp + fn``). - ``'none'`` or ``None``: Calculate the metric for each class separately, and return the metric for every class. 
+ + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation """ allowed_average = ["binary", "micro", "macro", "weighted", "none", None] if average not in allowed_average: @@ -61,6 +65,7 @@ if average == "binary": return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]) + ignore_index_cond = ignore_index is not None and 0 <= ignore_index < confmat.shape[0] if confmat.ndim == 3: # multilabel num = confmat[:, 1, 1] denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0] @@ -70,16 +75,18 @@ if average == "micro": num = num.sum() - denom = denom.sum() + denom = denom.sum() - (denom[ignore_index] if ignore_index_cond else 0.0) jaccard = _safe_divide(num, denom) - if average is None or average == "none": + if average is None or average == "none" or average == "micro": return jaccard if average == "weighted": weights = confmat[:, 1, 1] + confmat[:, 1, 0] if confmat.ndim == 3 else confmat.sum(1) else: weights = torch.ones_like(jaccard) + if ignore_index_cond: + weights[ignore_index] = 0.0 return ((weights * jaccard) / weights.sum()).sum() @@ -217,7 +224,7 @@ def multiclass_jaccard_index( _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) - return _jaccard_index_reduce(confmat, average=average) + return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index) def _multilabel_jaccard_index_arg_validation( @@ -297,7 +304,7 @@ def multilabel_jaccard_index( _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) - return 
_jaccard_index_reduce(confmat, average=average) + return _jaccard_index_reduce(confmat, average=average, ignore_index=ignore_index) def jaccard_index( diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 408ab6a31d8..ae77ad3f9ab 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -126,6 +126,10 @@ def _sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average= preds = preds.flatten() target = target.flatten() target, preds = remove_ignore_index(target, preds, ignore_index) + if ignore_index is not None and 0 <= ignore_index < NUM_CLASSES: + labels = [i for i in range(NUM_CLASSES) if i != ignore_index] + res = sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=labels) + return np.insert(res, ignore_index, 0.0) if average is None else res return sk_jaccard_index(y_true=target, y_pred=preds, average=average) @@ -242,7 +246,7 @@ class TestMultilabelJaccardIndex(MetricTester): """Test class for `MultilabelJaccardIndex` metric.""" @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) - @pytest.mark.parametrize("ignore_index", [None]) # , -1, 0]) + @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [True, False]) def test_multilabel_jaccard_index(self, input, ddp, ignore_index, average): preds, target = input @@ -262,7 +266,7 @@ def test_multilabel_jaccard_index(self, input, ignore_index, average): ) @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) - @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ignore_index", [None, -1]) def test_multilabel_jaccard_index_functional(self, input, ignore_index, average): preds, target = input if ignore_index is not None: