Skip to content

Commit

Permalink
Exact match multiclass (#1343)
Browse files Browse the repository at this point in the history
* initial work
* more code changes
* changelog
* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 21, 2022
1 parent 96862e0 commit 68a6990
Show file tree
Hide file tree
Showing 9 changed files with 447 additions and 41 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `CLIPScore` to new multimodal package ([#1314](https://github.com/Lightning-AI/metrics/pull/1314))


- Added `MulticlassExactMatch` to classification metrics ([#1343](https://github.com/Lightning-AI/metrics/pull/1343))



### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
Expand Down
25 changes: 25 additions & 0 deletions docs/source/classification/exact_match.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,40 @@ Exact Match
Module Interface
________________

ExactMatch
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.ExactMatch
:noindex:

MulticlassExactMatch
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.MulticlassExactMatch
:noindex:

MultilabelExactMatch
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.MultilabelExactMatch
:noindex:


Functional Interface
____________________

exact_match
^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.classification.exact_match
   :noindex:

multiclass_exact_match
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.classification.multiclass_exact_match
:noindex:

multilabel_exact_match
^^^^^^^^^^^^^^^^^^^^^^

Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
CohenKappa,
ConfusionMatrix,
Dice,
ExactMatch,
F1Score,
FBetaScore,
HammingDistance,
Expand Down Expand Up @@ -126,6 +127,7 @@
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
"ExactMatch",
"ExplainedVariance",
"ExtendedEditDistance",
"F1Score",
Expand Down
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa
from torchmetrics.classification.dice import Dice
from torchmetrics.classification.exact_match import MultilabelExactMatch
from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch
from torchmetrics.classification.f_beta import (
BinaryF1Score,
BinaryFBetaScore,
Expand Down Expand Up @@ -138,6 +138,8 @@
"CohenKappa",
"MulticlassCohenKappa",
"Dice",
"ExactMatch",
"MulticlassExactMatch",
"MultilabelExactMatch",
"BinaryF1Score",
"BinaryFBetaScore",
Expand Down
172 changes: 157 additions & 15 deletions src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.exact_match import (
_multilabel_exact_scores_compute,
_multilabel_exact_scores_update,
_exact_match_reduce,
_multiclass_exact_match_update,
_multilabel_exact_match_update,
)
from torchmetrics.functional.classification.stat_scores import (
_multiclass_stat_scores_arg_validation,
_multiclass_stat_scores_format,
_multiclass_stat_scores_tensor_validation,
_multilabel_stat_scores_arg_validation,
_multilabel_stat_scores_format,
_multilabel_stat_scores_tensor_validation,
Expand All @@ -30,6 +34,107 @@
from torchmetrics.utilities.data import dim_zero_cat


class MulticlassExactMatch(Metric):
    r"""Computes Exact match (also known as subset accuracy) for multiclass tasks.

    Exact Match is a stricter version of accuracy where all labels have to match exactly for the
    sample to be correctly classified.

    Accepts the following input tensors:

    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
      we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
      an int tensor.
    - ``target`` (int tensor): ``(N, ...)``

    The influence of the additional dimension ``...`` (if present) will be determined by the
    ``multidim_average`` argument.

    Args:
        num_classes: Integer specifing the number of labels
        multidim_average:
            Defines how additionally dimensions ``...`` should be handled. Should be one of the following:

            - ``global``: Additional dimensions are flatted along the batch dimension
            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
              The statistics in this case are calculated over the additional dimensions.

        ignore_index:
            Specifies a target value that is ignored and does not contribute to the metric calculation
        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
            Set to ``False`` for faster computations.

    Returns:
        The returned shape depends on the ``multidim_average`` argument:

        - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
        - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``

    Example (multidim tensors):
        >>> from torchmetrics.classification import MulticlassExactMatch
        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> metric = MulticlassExactMatch(num_classes=3, multidim_average='global')
        >>> metric(preds, target)
        tensor(0.5000)

    Example (multidim tensors):
        >>> from torchmetrics.classification import MulticlassExactMatch
        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> metric = MulticlassExactMatch(num_classes=3, multidim_average='samplewise')
        >>> metric(preds, target)
        tensor([1., 0.])
    """
    # Metric contract flags: exact match is not differentiable, higher scores are better,
    # and ``update`` does not need the full accumulated state.
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        multidim_average: Literal["global", "samplewise"] = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Exact match has no top-k or averaging notion; the shared stat-scores argument
        # validation nevertheless expects both, so they are pinned to neutral values here.
        top_k, average = 1, None
        if validate_args:
            _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
        self.num_classes = num_classes
        self.multidim_average = multidim_average
        self.ignore_index = ignore_index
        self.validate_args = validate_args

        # ``correct``: running sum in ``global`` mode, or a list of per-batch tensors in
        # ``samplewise`` mode (concatenated across processes via ``cat``).
        self.add_state(
            "correct",
            torch.zeros(1, dtype=torch.long) if self.multidim_average == "global" else [],
            dist_reduce_fx="sum" if self.multidim_average == "global" else "cat",
        )
        # ``total``: summed across processes in ``global`` mode, averaged in ``samplewise`` mode.
        self.add_state(
            "total",
            torch.zeros(1, dtype=torch.long),
            dist_reduce_fx="sum" if self.multidim_average == "global" else "mean",
        )

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets."""
        if self.validate_args:
            _multiclass_stat_scores_tensor_validation(
                preds, target, self.num_classes, self.multidim_average, self.ignore_index
            )
        # Formats inputs to a common shape and converts float preds (probs/logits) to int labels.
        preds, target = _multiclass_stat_scores_format(preds, target, 1)
        correct, total = _multiclass_exact_match_update(preds, target, self.multidim_average)
        if self.multidim_average == "samplewise":
            self.correct.append(correct)
            # NOTE(review): ``total`` is overwritten rather than accumulated here — presumably
            # it is the same value on every ``update`` call in samplewise mode; confirm against
            # ``_multiclass_exact_match_update``.
            self.total = total
        else:
            self.correct += correct
            self.total += total

    def compute(self) -> Tensor:
        """Compute the final exact-match score from the accumulated state."""
        # In samplewise mode ``correct`` is a list of per-batch tensors; concatenate first.
        correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
        return _exact_match_reduce(correct, self.total)


class MultilabelExactMatch(Metric):
r"""Computes Exact match (also known as subset accuracy) for multilabel tasks. Exact Match is a stricter version
of accuracy where all labels have to match exactly for the sample to be correctly classified.
Expand Down Expand Up @@ -60,17 +165,10 @@ class MultilabelExactMatch(Metric):
Set to ``False`` for faster computations.
Returns:
The returned shape depends on the ``average`` and ``multidim_average`` arguments:
- If ``multidim_average`` is set to ``global``:
- If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
- If ``average=None/'none'``, the shape will be ``(C,)``
The returned shape depends on the ``multidim_average`` argument:
- If ``multidim_average`` is set to ``samplewise``:
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
- If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
- If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``
Example (preds is int tensor):
>>> from torchmetrics.classification import MultilabelExactMatch
Expand Down Expand Up @@ -151,7 +249,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds, target = _multilabel_stat_scores_format(
preds, target, self.num_labels, self.threshold, self.ignore_index
)
correct, total = _multilabel_exact_scores_update(preds, target, self.num_labels, self.multidim_average)
correct, total = _multilabel_exact_match_update(preds, target, self.num_labels, self.multidim_average)
if self.multidim_average == "samplewise":
self.correct.append(correct)
self.total = total
Expand All @@ -160,5 +258,49 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.total += total

def compute(self) -> Tensor:
correct = dim_zero_cat(self.correct)
return _multilabel_exact_scores_compute(correct, self.total)
correct = dim_zero_cat(self.correct) if isinstance(self.correct, list) else self.correct
return _exact_match_reduce(correct, self.total)


class ExactMatch:
    r"""Computes Exact match (also known as subset accuracy).

    Exact Match is a stricter version of accuracy where all labels have to match exactly for the
    sample to be correctly classified.

    This module is a simple wrapper to get the task specific versions of this metric, which is done
    by setting the ``task`` argument to either ``'multiclass'`` or ``'multilabel'``. See the
    documentation of :class:`MulticlassExactMatch` and :class:`MultilabelExactMatch` for the
    specific details of each argument influence and examples.

    Legacy Example:
        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='global')
        >>> metric(preds, target)
        tensor(0.5000)

        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
        >>> preds = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
        >>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='samplewise')
        >>> metric(preds, target)
        tensor([1., 0.])
    """

    def __new__(
        cls,
        task: Literal["binary", "multiclass", "multilabel"],
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: Literal["global", "samplewise"] = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Dispatch to the task-specific metric class.

        Raises:
            ValueError: If ``task`` is not ``'multiclass'`` or ``'multilabel'``, or if the
                required ``num_classes``/``num_labels`` argument for the chosen task is missing.
        """
        # Arguments accepted by both task-specific classes are forwarded together.
        kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args))
        if task == "multiclass":
            # Explicit validation instead of ``assert`` so the check survives ``python -O``.
            if not isinstance(num_classes, int):
                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
            return MulticlassExactMatch(num_classes, **kwargs)
        if task == "multilabel":
            if not isinstance(num_labels, int):
                raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
            return MultilabelExactMatch(num_labels, threshold, **kwargs)
        raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}")
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.functional.classification.cohen_kappa import cohen_kappa
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix
from torchmetrics.functional.classification.dice import dice, dice_score
from torchmetrics.functional.classification.exact_match import exact_match
from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score
from torchmetrics.functional.classification.hamming import hamming_distance
from torchmetrics.functional.classification.hinge import hinge_loss
Expand Down Expand Up @@ -115,6 +116,7 @@
"dice_score",
"dice",
"error_relative_global_dimensionless_synthesis",
"exact_match",
"explained_variance",
"extended_edit_distance",
"f1_score",
Expand Down
6 changes: 5 additions & 1 deletion src/torchmetrics/functional/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@
multilabel_confusion_matrix,
)
from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401
from torchmetrics.functional.classification.exact_match import multilabel_exact_match # noqa: F401
from torchmetrics.functional.classification.exact_match import ( # noqa: F401
exact_match,
multiclass_exact_match,
multilabel_exact_match,
)
from torchmetrics.functional.classification.f_beta import ( # noqa: F401
binary_f1_score,
binary_fbeta_score,
Expand Down

0 comments on commit 68a6990

Please sign in to comment.