Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exact match multiclass #1343

Merged
merged 16 commits into from Nov 21, 2022
Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -34,6 +34,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `CLIPScore` to new multimodal package ([#1314](https://github.com/Lightning-AI/metrics/pull/1314))


- Added `MulticlassExactMatch` to classification metrics ([#1343](https://github.com/Lightning-AI/metrics/pull/1343))



### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
Expand Down
25 changes: 25 additions & 0 deletions docs/source/classification/exact_match.rst
Expand Up @@ -10,15 +10,40 @@ Exact Match
Module Interface
________________

ExactMatch
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.ExactMatch
:noindex:

MulticlassExactMatch
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.MulticlassExactMatch
:noindex:

MultilabelExactMatch
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.MultilabelExactMatch
:noindex:


Functional Interface
____________________

exact_match
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.exact_match
:noindex:

multiclass_exact_match
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.classification.multiclass_exact_match
:noindex:

multilabel_exact_match
^^^^^^^^^^^^^^^^^^^^^^

Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Expand Up @@ -29,6 +29,7 @@
CohenKappa,
ConfusionMatrix,
Dice,
ExactMatch,
F1Score,
FBetaScore,
HammingDistance,
Expand Down Expand Up @@ -126,6 +127,7 @@
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
"ExactMatch",
"ExplainedVariance",
"ExtendedEditDistance",
"F1Score",
Expand Down
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/__init__.py
Expand Up @@ -44,7 +44,7 @@
)
from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa
from torchmetrics.classification.dice import Dice
from torchmetrics.classification.exact_match import MultilabelExactMatch
from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch
from torchmetrics.classification.f_beta import (
BinaryF1Score,
BinaryFBetaScore,
Expand Down Expand Up @@ -138,6 +138,8 @@
"CohenKappa",
"MulticlassCohenKappa",
"Dice",
"ExactMatch",
"MulticlassExactMatch",
"MultilabelExactMatch",
"BinaryF1Score",
"BinaryFBetaScore",
Expand Down
172 changes: 157 additions & 15 deletions src/torchmetrics/classification/exact_match.py
Expand Up @@ -18,10 +18,14 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.exact_match import (
_multilabel_exact_scores_compute,
_multilabel_exact_scores_update,
_exact_match_reduce,
_multiclass_exact_match_update,
_multilabel_exact_match_update,
)
from torchmetrics.functional.classification.stat_scores import (
_multiclass_stat_scores_arg_validation,
_multiclass_stat_scores_format,
_multiclass_stat_scores_tensor_validation,
_multilabel_stat_scores_arg_validation,
_multilabel_stat_scores_format,
_multilabel_stat_scores_tensor_validation,
Expand All @@ -30,6 +34,107 @@
from torchmetrics.utilities.data import dim_zero_cat


class MulticlassExactMatch(Metric):
    r"""Computes Exact match (also known as subset accuracy) for multiclass tasks. Exact Match is a stricter version
    of accuracy where all labels have to match exactly for the sample to be correctly classified.

    Accepts the following input tensors:

    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
      we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
      an int tensor.
    - ``target`` (int tensor): ``(N, ...)``

    The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average`
    argument.

    Args:
        num_classes: Integer specifying the number of labels
        multidim_average:
            Defines how additionally dimensions ``...`` should be handled. Should be one of the following:

            - ``global``: Additional dimensions are flatted along the batch dimension
            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
              The statistics in this case are calculated over the additional dimensions.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        The returned shape depends on the ``multidim_average`` argument:

        - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
        - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``

    Example (multidim tensors):
        >>> from torchmetrics.classification import MulticlassExactMatch
        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> metric = MulticlassExactMatch(num_classes=3, multidim_average='global')
        >>> metric(preds, target)
        tensor(0.5000)

    Example (multidim tensors):
        >>> from torchmetrics.classification import MulticlassExactMatch
        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> metric = MulticlassExactMatch(num_classes=3, multidim_average='samplewise')
        >>> metric(preds, target)
        tensor([1., 0.])
    """
    # Exact match is a hard (non-differentiable) comparison of integer labels.
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        multidim_average: Literal["global", "samplewise"] = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Argument validation is delegated to the multiclass stat-scores validator,
        # which also needs `top_k` and `average`; exact match is always top-1 and un-averaged.
        top_k, average = 1, None
        if validate_args:
            _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
        self.num_classes = num_classes
        self.multidim_average = multidim_average
        self.ignore_index = ignore_index
        self.validate_args = validate_args

        # `correct`: global mode keeps one running counter (summed across processes);
        # samplewise mode keeps a list of per-batch tensors concatenated across processes.
        self.add_state(
            "correct",
            torch.zeros(1, dtype=torch.long) if self.multidim_average == "global" else [],
            dist_reduce_fx="sum" if self.multidim_average == "global" else "cat",
        )
        # NOTE(review): in samplewise mode `total` is overwritten on every `update`
        # (see below) and reduced with "mean" across processes — presumably it is the
        # same per-sample element count on every process; confirm against the
        # functional `_multiclass_exact_match_update`.
        self.add_state(
            "total",
            torch.zeros(1, dtype=torch.long),
            dist_reduce_fx="sum" if self.multidim_average == "global" else "mean",
        )

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets."""
        if self.validate_args:
            _multiclass_stat_scores_tensor_validation(
                preds, target, self.num_classes, self.multidim_average, self.ignore_index
            )
        # Reuses the stat-scores formatting (e.g. argmax of floating-point preds) with top_k=1.
        preds, target = _multiclass_stat_scores_format(preds, target, 1)
        correct, total = _multiclass_exact_match_update(preds, target, self.multidim_average)
        if self.multidim_average == "samplewise":
            # Per-sample counts are concatenated at compute time; `total` is replaced, not accumulated.
            self.correct.append(correct)
            self.total = total
        else:
            self.correct += correct
            self.total += total

    def compute(self) -> Tensor:
        """Compute the final exact-match score from the accumulated state."""
        correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
        return _exact_match_reduce(correct, self.total)


class MultilabelExactMatch(Metric):
r"""Computes Exact match (also known as subset accuracy) for multilabel tasks. Exact Match is a stricter version
of accuracy where all labels have to match exactly for the sample to be correctly classified.
Expand Down Expand Up @@ -60,17 +165,10 @@ class MultilabelExactMatch(Metric):
Set to ``False`` for faster computations.

Returns:
The returned shape depends on the ``average`` and ``multidim_average`` arguments:

- If ``multidim_average`` is set to ``global``:
The returned shape depends on the ``multidim_average`` argument:

- If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
- If ``average=None/'none'``, the shape will be ``(C,)``

- If ``multidim_average`` is set to ``samplewise``:

- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
- If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
- If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``

Example (preds is int tensor):
>>> from torchmetrics.classification import MultilabelExactMatch
Expand Down Expand Up @@ -151,7 +249,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds, target = _multilabel_stat_scores_format(
preds, target, self.num_labels, self.threshold, self.ignore_index
)
correct, total = _multilabel_exact_scores_update(preds, target, self.num_labels, self.multidim_average)
correct, total = _multilabel_exact_match_update(preds, target, self.num_labels, self.multidim_average)
if self.multidim_average == "samplewise":
self.correct.append(correct)
self.total = total
Expand All @@ -160,5 +258,49 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.total += total

def compute(self) -> Tensor:
correct = dim_zero_cat(self.correct)
return _multilabel_exact_scores_compute(correct, self.total)
correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
return _exact_match_reduce(correct, self.total)


class ExactMatch:
    r"""Computes Exact match (also known as subset accuracy). Exact Match is a stricter version of accuracy where
    all labels have to match exactly for the sample to be correctly classified.

    This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :class:`MulticlassExactMatch` and :class:`MultilabelExactMatch` for the specific details of
    each argument influence and examples.

    Legacy Example:
        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='global')
        >>> metric(preds, target)
        tensor(0.5000)

        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='samplewise')
        >>> metric(preds, target)
        tensor([1., 0.])
    """

    def __new__(
        cls,
        task: Literal["binary", "multiclass", "multilabel"],
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: Literal["global", "samplewise"] = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Dispatch to the task-specific metric class; returns an instance of that class, not of ``ExactMatch``."""
        kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args))
        if task == "multiclass":
            # Raise instead of `assert`: assertions are stripped under `python -O`,
            # so they must not be relied on for input validation.
            if not isinstance(num_classes, int):
                raise ValueError(f"Expected argument `num_classes` to be an integer, but got {num_classes}")
            return MulticlassExactMatch(num_classes, **kwargs)
        if task == "multilabel":
            if not isinstance(num_labels, int):
                raise ValueError(f"Expected argument `num_labels` to be an integer, but got {num_labels}")
            return MultilabelExactMatch(num_labels, threshold, **kwargs)
        raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}")
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.functional.classification.cohen_kappa import cohen_kappa
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix
from torchmetrics.functional.classification.dice import dice, dice_score
from torchmetrics.functional.classification.exact_match import exact_match
from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score
from torchmetrics.functional.classification.hamming import hamming_distance
from torchmetrics.functional.classification.hinge import hinge_loss
Expand Down Expand Up @@ -114,6 +115,7 @@
"dice_score",
"dice",
"error_relative_global_dimensionless_synthesis",
"exact_match",
"explained_variance",
"extended_edit_distance",
"f1_score",
Expand Down
6 changes: 5 additions & 1 deletion src/torchmetrics/functional/classification/__init__.py
Expand Up @@ -46,7 +46,11 @@
multilabel_confusion_matrix,
)
from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401
from torchmetrics.functional.classification.exact_match import multilabel_exact_match # noqa: F401
from torchmetrics.functional.classification.exact_match import ( # noqa: F401
exact_match,
multiclass_exact_match,
multilabel_exact_match,
)
from torchmetrics.functional.classification.f_beta import ( # noqa: F401
binary_f1_score,
binary_fbeta_score,
Expand Down