diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d25019e173..8fd7ca7494e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `CLIPScore` to new multimodal package ([#1314](https://github.com/Lightning-AI/metrics/pull/1314)) +- Added `MulticlassExactMatch` to classification metrics ([#1343](https://github.com/Lightning-AI/metrics/pull/1343)) + + + ### Changed - Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259)) diff --git a/docs/source/classification/exact_match.rst b/docs/source/classification/exact_match.rst index c3a9000d4c5..c2cf947fb5e 100644 --- a/docs/source/classification/exact_match.rst +++ b/docs/source/classification/exact_match.rst @@ -10,15 +10,40 @@ Exact Match Module Interface ________________ +ExactMatch +^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.ExactMatch + :noindex: + +MulticlassExactMatch +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassExactMatch + :noindex: + MultilabelExactMatch ^^^^^^^^^^^^^^^^^^^^ .. autoclass:: torchmetrics.classification.MultilabelExactMatch :noindex: + Functional Interface ____________________ +exact_match +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.exact_match + :noindex: + +multiclass_exact_match +^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autofunction:: torchmetrics.functional.classification.multiclass_exact_match + :noindex: + multilabel_exact_match ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 75dd96d4a66..235a7078f39 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -29,6 +29,7 @@ CohenKappa, ConfusionMatrix, Dice, + ExactMatch, F1Score, FBetaScore, HammingDistance, @@ -126,6 +127,7 @@ "Dice", "TweedieDevianceScore", "ErrorRelativeGlobalDimensionlessSynthesis", + "ExactMatch", "ExplainedVariance", "ExtendedEditDistance", "F1Score", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 29185666b5c..ad9c1e39e36 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -44,7 +44,7 @@ ) from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa from torchmetrics.classification.dice import Dice -from torchmetrics.classification.exact_match import MultilabelExactMatch +from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch from torchmetrics.classification.f_beta import ( BinaryF1Score, BinaryFBetaScore, @@ -138,6 +138,8 @@ "CohenKappa", "MulticlassCohenKappa", "Dice", + "ExactMatch", + "MulticlassExactMatch", "MultilabelExactMatch", "BinaryF1Score", "BinaryFBetaScore", diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index b6ba08bf950..40e7f8060ab 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -18,10 +18,14 @@ from typing_extensions import Literal from torchmetrics.functional.classification.exact_match import ( - _multilabel_exact_scores_compute, - _multilabel_exact_scores_update, + _exact_match_reduce, + _multiclass_exact_match_update, + _multilabel_exact_match_update, ) from 
torchmetrics.functional.classification.stat_scores import ( + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, _multilabel_stat_scores_arg_validation, _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, @@ -30,6 +34,107 @@ from torchmetrics.utilities.data import dim_zero_cat +class MulticlassExactMatch(Metric): + r"""Computes Exact match (also known as subset accuracy) for multiclass tasks. Exact Match is a stricter version + of accuracy where all labels have to match exactly for the sample to be correctly classified. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ...)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the ``multidim_average`` + argument. + + Args: + num_classes: Integer specifying the number of classes + multidim_average: + Defines how additional dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flattened along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
+ + Returns: + The returned shape depends on the ``multidim_average`` argument: + + - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor + - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassExactMatch + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassExactMatch(num_classes=3, multidim_average='global') + >>> metric(preds, target) + tensor(0.5000) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassExactMatch + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassExactMatch(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([1., 0.]) + """ + is_differentiable = False + higher_is_better = True + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + top_k, average = 1, None + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + self.num_classes = num_classes + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + self.add_state( + "correct", + torch.zeros(1, dtype=torch.long) if self.multidim_average == "global" else [], + dist_reduce_fx="sum" if self.multidim_average == "global" else "cat", + ) + self.add_state( + "total", + torch.zeros(1, dtype=torch.long), + dist_reduce_fx="sum" if self.multidim_average == "global" else "mean", + 
) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_stat_scores_tensor_validation( + preds, target, self.num_classes, self.multidim_average, self.ignore_index + ) + preds, target = _multiclass_stat_scores_format(preds, target, 1) # NOTE(review): ignore_index not passed on + correct, total = _multiclass_exact_match_update(preds, target, self.multidim_average) + if self.multidim_average == "samplewise": + self.correct.append(correct) + self.total = total + else: + self.correct += correct + self.total += total + + def compute(self) -> Tensor: + correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct + return _exact_match_reduce(correct, self.total) + + class MultilabelExactMatch(Metric): r"""Computes Exact match (also known as subset accuracy) for multilabel tasks. Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified. @@ -60,17 +165,10 @@ class MultilabelExactMatch(Metric): Set to ``False`` for faster computations. 
Returns: - The returned shape depends on the ``average`` and ``multidim_average`` arguments: - - - If ``multidim_average`` is set to ``global``: - - - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor - - If ``average=None/'none'``, the shape will be ``(C,)`` + The returned shape depends on the ``multidim_average`` argument: - - If ``multidim_average`` is set to ``samplewise``: - - - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - - If ``average=None/'none'``, the shape will be ``(N, C)`` + - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor + - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` Example (preds is int tensor): >>> from torchmetrics.classification import MultilabelExactMatch @@ -151,7 +249,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds, target = _multilabel_stat_scores_format( preds, target, self.num_labels, self.threshold, self.ignore_index ) - correct, total = _multilabel_exact_scores_update(preds, target, self.num_labels, self.multidim_average) + correct, total = _multilabel_exact_match_update(preds, target, self.num_labels, self.multidim_average) if self.multidim_average == "samplewise": self.correct.append(correct) self.total = total @@ -160,5 +258,49 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.total += total def compute(self) -> Tensor: - correct = dim_zero_cat(self.correct) - return _multilabel_exact_scores_compute(correct, self.total) + correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct + return _exact_match_reduce(correct, self.total) + + +class ExactMatch: + r"""Computes Exact match (also known as subset accuracy). Exact Match is a stricter version of accuracy where + all labels have to match exactly for the sample to be correctly classified. 
+ + This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'multiclass'`` or ``'multilabel'``. See the documentation of + :class:`MulticlassExactMatch` and :class:`MultilabelExactMatch` for the specific details of + each argument influence and examples. + + Legacy Example: + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='global') + >>> metric(preds, target) + tensor(0.5000) + + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([1., 0.]) + """ + + def __new__( + cls, + task: Literal["multiclass", "multilabel"], + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassExactMatch(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelExactMatch(num_labels, threshold, **kwargs) + raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 4ac31cd0d20..228b76a4162 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ 
-21,6 +21,7 @@ from torchmetrics.functional.classification.cohen_kappa import cohen_kappa from torchmetrics.functional.classification.confusion_matrix import confusion_matrix from torchmetrics.functional.classification.dice import dice, dice_score +from torchmetrics.functional.classification.exact_match import exact_match from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score from torchmetrics.functional.classification.hamming import hamming_distance from torchmetrics.functional.classification.hinge import hinge_loss @@ -115,6 +116,7 @@ "dice_score", "dice", "error_relative_global_dimensionless_synthesis", + "exact_match", "explained_variance", "extended_edit_distance", "f1_score", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index e772f93aa80..c5625f063b5 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -46,7 +46,11 @@ multilabel_confusion_matrix, ) from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401 -from torchmetrics.functional.classification.exact_match import multilabel_exact_match # noqa: F401 +from torchmetrics.functional.classification.exact_match import ( # noqa: F401 + exact_match, + multiclass_exact_match, + multilabel_exact_match, +) from torchmetrics.functional.classification.f_beta import ( # noqa: F401 binary_f1_score, binary_fbeta_score, diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index 82de77cacd1..6993c3b5ed7 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -18,6 +18,9 @@ from typing_extensions import Literal from torchmetrics.functional.classification.stat_scores import ( + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + 
_multiclass_stat_scores_tensor_validation, _multilabel_stat_scores_arg_validation, _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, @@ -25,7 +28,93 @@ from torchmetrics.utilities.compute import _safe_divide -def _multilabel_exact_scores_update( +def _exact_match_reduce( + correct: Tensor, + total: Tensor, +) -> Tensor: + """Final reduction for exact match.""" + return _safe_divide(correct, total) + + +def _multiclass_exact_match_update( + preds: Tensor, + target: Tensor, + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tuple[Tensor, Tensor]: + """Computes per-sample all-match flags (summed for ``global``) and the normalisation count.""" + correct = (preds == target).sum(1) == preds.shape[1] # NOTE(review): no ignore_index handling - confirm + correct = correct if multidim_average == "samplewise" else correct.sum() + total = torch.tensor(preds.shape[0] if multidim_average == "global" else 1, device=correct.device) + return correct, total + + +def multiclass_exact_match( + preds: Tensor, + target: Tensor, + num_classes: int, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes Exact match (also known as subset accuracy) for multiclass tasks. Exact Match is a stricter version + of accuracy where all labels have to match exactly for the sample to be correctly classified. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ...)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the ``multidim_average`` + argument. 
+ + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifying the number of classes + multidim_average: + Defines how additional dimensions ``...`` should be handled. 
Should be one of the following: + + - ``global``: Additional dimensions are flattened along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``multidim_average`` argument: + + - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor + - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_exact_match + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='global') + tensor(0.5000) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_exact_match + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='samplewise') + tensor([1., 0.]) + """ + top_k, average = 1, None + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + correct, total = _multiclass_exact_match_update(preds, target, multidim_average) + return 
_exact_match_reduce(correct, total) + + +def _multilabel_exact_match_update( preds: Tensor, target: Tensor, num_labels: int, multidim_average: Literal["global", "samplewise"] = "global" ) -> Tuple[Tensor, Tensor]: """Computes the statistics.""" @@ -38,14 +127,6 @@ def _multilabel_exact_scores_update( return correct, total -def _multilabel_exact_scores_compute( - correct: Tensor, - total: Tensor, -) -> Tensor: - """Final reduction for exact match.""" - return _safe_divide(correct, total) - - def multilabel_exact_match( preds: Tensor, target: Tensor, @@ -86,17 +167,10 @@ def multilabel_exact_match( Set to ``False`` for faster computations. Returns: - The returned shape depends on the ``average`` and ``multidim_average`` arguments: - - - If ``multidim_average`` is set to ``global``: + The returned shape depends on the ``multidim_average`` argument: - - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor - - If ``average=None/'none'``, the shape will be ``(C,)`` - - - If ``multidim_average`` is set to ``samplewise``: - - - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - - If ``average=None/'none'``, the shape will be ``(N, C)`` + - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor + - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` Example (preds is int tensor): >>> from torchmetrics.functional.classification import multilabel_exact_match @@ -129,5 +203,45 @@ def multilabel_exact_match( _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) - correct, total = _multilabel_exact_scores_update(preds, target, num_labels, multidim_average) - return _multilabel_exact_scores_compute(correct, total) + 
correct, total = _multilabel_exact_match_update(preds, target, num_labels, multidim_average) + return _exact_match_reduce(correct, total) + + +def exact_match( + preds: Tensor, + target: Tensor, + task: Literal["multiclass", "multilabel"], + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes Exact match (also known as subset accuracy). Exact Match is a stricter version of accuracy where + all classes/labels have to match exactly for the sample to be correctly classified. + + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'multiclass'`` or ``'multilabel'``. See the documentation of + :func:`multiclass_exact_match` and :func:`multilabel_exact_match` for the specific details of + each argument influence and examples. 
+ Legacy Example: + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]) + >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='global') + tensor(0.5000) + + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]) + >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise') + tensor([1., 0.]) + """ + if task == "multiclass": + assert num_classes is not None + return multiclass_exact_match(preds, target, num_classes, multidim_average, ignore_index, validate_args) + if task == "multilabel": + assert num_labels is not None + return multilabel_exact_match( + preds, target, num_labels, threshold, multidim_average, ignore_index, validate_args + ) + raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}") diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index 66e39f797b9..d70f6c88079 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -18,15 +18,126 @@ import torch from scipy.special import expit as sigmoid -from torchmetrics.classification.exact_match import MultilabelExactMatch -from torchmetrics.functional.classification.exact_match import multilabel_exact_match -from unittests.classification.inputs import _multilabel_cases +from torchmetrics.classification.exact_match import MulticlassExactMatch, MultilabelExactMatch +from torchmetrics.functional.classification.exact_match import multiclass_exact_match, multilabel_exact_match +from unittests.classification.inputs import _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, 
MetricTester, inject_ignore_index seed_all(42) +def _sk_exact_match_multiclass(preds, target, ignore_index, multidim_average): + if preds.ndim == target.ndim + 1: # one extra dim on preds: reduce it with argmax + preds = torch.argmax(preds, 1) + preds = preds.numpy() + target = target.numpy() + + if ignore_index is not None: + target = np.copy(target) + target[target == ignore_index] = -1 # NOTE(review): ignored entries then count as mismatches? + + correct = (preds == target).sum(-1) == preds.shape[1] # a sample is correct only if every entry matches + correct = correct.sum() if multidim_average == "global" else correct + total = len(preds) if multidim_average == "global" else 1 + return correct / total + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassExactMatch(MetricTester): + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_exact_match(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if target.ndim < 3: + pytest.skip("non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassExactMatch, + sk_metric=partial( + _sk_exact_match_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + ), + metric_args={ + "ignore_index": ignore_index, + "num_classes": NUM_CLASSES, + "multidim_average": multidim_average, + }, + ) + + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_exact_match_functional(self, input, ignore_index, multidim_average): + preds, target = input + if 
ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if target.ndim < 3: + pytest.skip("non-multidim arrays are not valid") + + 
self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_exact_match, + sk_metric=partial( + _sk_exact_match_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + ), + metric_args={ + "ignore_index": ignore_index, + "num_classes": NUM_CLASSES, + "multidim_average": multidim_average, + }, + ) + + def test_multiclass_exact_match_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassExactMatch, + metric_functional=multiclass_exact_match, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_exact_match_half_cpu(self, input, dtype): + preds, target = input + + if (preds < 0).any() and dtype == torch.half: # NOTE(review): reason cites sigmoid; multiclass uses argmax? + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassExactMatch, + metric_functional=multiclass_exact_match, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_exact_match_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassExactMatch, + metric_functional=multiclass_exact_match, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + def _sk_exact_match_multilabel(preds, target, ignore_index, multidim_average): preds = preds.numpy() target = target.numpy()