Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exact match multiclass #1343

Merged
merged 16 commits into from Nov 21, 2022
Merged
4 changes: 3 additions & 1 deletion CHANGELOG.md
Expand Up @@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `normalize` argument to `Inception`, `FID`, `KID` metrics ([#1246](https://github.com/Lightning-AI/metrics/pull/1246))


- Added `MulticlassExactMatch` to classification metrics ([#1343](https://github.com/Lightning-AI/metrics/pull/1343))

### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
Expand All @@ -52,7 +54,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
-


## [0.10.3] - 2022-11-16
Expand Down
25 changes: 25 additions & 0 deletions docs/source/classification/exact_match.rst
Expand Up @@ -10,15 +10,40 @@ Exact Match
Module Interface
________________

ExactMatch
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.ExactMatch
:noindex:

MulticlassExactMatch
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.MulticlassExactMatch
:noindex:

MultilabelExactMatch
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.MultilabelExactMatch
:noindex:


Functional Interface
____________________

exact_match
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.classification.exact_match
:noindex:

multiclass_exact_match
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.classification.multiclass_exact_match
:noindex:

multilabel_exact_match
^^^^^^^^^^^^^^^^^^^^^^

Expand Down
110 changes: 105 additions & 5 deletions src/torchmetrics/classification/exact_match.py
Expand Up @@ -18,10 +18,14 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.exact_match import (
_multilabel_exact_scores_compute,
_multilabel_exact_scores_update,
_exact_match_reduce,
_multiclass_exact_match_update,
_multilabel_exact_match_update,
)
from torchmetrics.functional.classification.stat_scores import (
_multiclass_stat_scores_arg_validation,
_multiclass_stat_scores_format,
_multiclass_stat_scores_tensor_validation,
_multilabel_stat_scores_arg_validation,
_multilabel_stat_scores_format,
_multilabel_stat_scores_tensor_validation,
Expand All @@ -30,6 +34,71 @@
from torchmetrics.utilities.data import dim_zero_cat


class MulticlassExactMatch(Metric):
    r"""Computes Exact match (also known as subset accuracy) for multiclass tasks. Exact Match is a stricter version
    of accuracy where all labels have to match exactly for the sample to be correctly classified.

    Accepts the following input tensors:

    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
      we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
      an int tensor.
    - ``target`` (int tensor): ``(N, ...)``

    The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average`
    argument.

    Args:
        num_classes: Number of classes in the problem (forwarded to the stat-score validation/format helpers).
        multidim_average: ``"global"`` accumulates a single score over all samples; ``"samplewise"`` keeps one
            score per sample along the first dimension.
        ignore_index: Forwarded to the stat-score validation helpers; presumably a target value to exclude —
            confirm against ``_multiclass_stat_scores_tensor_validation``.
        validate_args: If ``True``, validate arguments (in ``__init__``) and inputs (in ``update``).
        kwargs: Additional keyword arguments forwarded to the ``Metric`` base class.
    """
    # Metric base-class configuration flags.
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        multidim_average: Literal["global", "samplewise"] = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Exact match reuses the multiclass stat-score validation with fixed top_k=1 and no averaging.
        top_k, average = 1, None
        if validate_args:
            _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
        self.num_classes = num_classes
        self.multidim_average = multidim_average
        self.ignore_index = ignore_index
        self.validate_args = validate_args

        # "global": scalar accumulators summed across processes.
        # "samplewise": `correct` is a list of per-batch tensors concatenated on sync.
        self.add_state(
            "correct",
            torch.zeros(1, dtype=torch.long) if self.multidim_average == "global" else [],
            dist_reduce_fx="sum" if self.multidim_average == "global" else "cat",
        )
        self.add_state(
            "total",
            torch.zeros(1, dtype=torch.long),
            dist_reduce_fx="sum" if self.multidim_average == "global" else "mean",
        )

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets for one batch."""
        if self.validate_args:
            _multiclass_stat_scores_tensor_validation(
                preds, target, self.num_classes, self.multidim_average, self.ignore_index
            )
        preds, target = _multiclass_stat_scores_format(preds, target, 1)
        correct, total = _multiclass_exact_match_update(preds, target, self.multidim_average)
        if self.multidim_average == "samplewise":
            self.correct.append(correct)
            # In samplewise mode the update helper returns total == 1, so overwriting keeps it at 1.
            self.total = total
        else:
            self.correct += correct
            self.total += total

    def compute(self) -> Tensor:
        """Compute the final exact-match score from the accumulated state."""
        # NOTE(review): when `correct` is a list (samplewise, pre-sync) it is forwarded unreduced to the
        # division helper — looks like this relies on the state having been cat'ed elsewhere; confirm
        # single-process samplewise behavior.
        correct = dim_zero_cat(self.correct) if isinstance(self.correct, Tensor) else self.correct
        return _exact_match_reduce(correct, self.total)


class MultilabelExactMatch(Metric):
r"""Computes Exact match (also known as subset accuracy) for multilabel tasks. Exact Match is a stricter version
of accuracy where all labels have to match exactly for the sample to be correctly classified.
Expand Down Expand Up @@ -151,7 +220,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds, target = _multilabel_stat_scores_format(
preds, target, self.num_labels, self.threshold, self.ignore_index
)
correct, total = _multilabel_exact_scores_update(preds, target, self.num_labels, self.multidim_average)
correct, total = _multilabel_exact_match_update(preds, target, self.num_labels, self.multidim_average)
if self.multidim_average == "samplewise":
self.correct.append(correct)
self.total = total
Expand All @@ -160,5 +229,36 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.total += total

def compute(self) -> Tensor:
correct = dim_zero_cat(self.correct)
return _multilabel_exact_scores_compute(correct, self.total)
correct = dim_zero_cat(self.correct) if isinstance(self.correct, Tensor) else self.correct
return _exact_match_reduce(correct, self.total)


class ExactMatch:
    r"""Computes Exact match (also known as subset accuracy). Exact Match is a stricter version of accuracy where
    all labels have to match exactly for the sample to be correctly classified.

    This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :class:`MulticlassExactMatch` and :class:`MultilabelExactMatch` for the specific details of
    each argument influence and examples.
    """

    def __new__(
        cls,
        task: Literal["binary", "multiclass", "multilabel"],
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: Literal["global", "samplewise"] = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> Metric:
        """Dispatch to the task-specific metric class and return its instance.

        Raises:
            ValueError: If ``task`` is neither ``'multiclass'`` nor ``'multilabel'``.
        """
        # Arguments shared by both task-specific implementations.
        kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args))
        if task == "multiclass":
            assert isinstance(num_classes, int)
            return MulticlassExactMatch(num_classes, **kwargs)
        if task == "multilabel":
            assert isinstance(num_labels, int)
            return MultilabelExactMatch(num_labels, threshold, **kwargs)
        raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}")
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.functional.classification.cohen_kappa import cohen_kappa
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix
from torchmetrics.functional.classification.dice import dice, dice_score
from torchmetrics.functional.classification.exact_match import exact_match
from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score
from torchmetrics.functional.classification.hamming import hamming_distance
from torchmetrics.functional.classification.hinge import hinge_loss
Expand Down Expand Up @@ -114,6 +115,7 @@
"dice_score",
"dice",
"error_relative_global_dimensionless_synthesis",
"exact_match",
"explained_variance",
"extended_edit_distance",
"f1_score",
Expand Down
6 changes: 5 additions & 1 deletion src/torchmetrics/functional/classification/__init__.py
Expand Up @@ -46,7 +46,11 @@
multilabel_confusion_matrix,
)
from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401
from torchmetrics.functional.classification.exact_match import multilabel_exact_match # noqa: F401
from torchmetrics.functional.classification.exact_match import ( # noqa: F401
exact_match,
multiclass_exact_match,
multilabel_exact_match,
)
from torchmetrics.functional.classification.f_beta import ( # noqa: F401
binary_f1_score,
binary_fbeta_score,
Expand Down
87 changes: 76 additions & 11 deletions src/torchmetrics/functional/classification/exact_match.py
Expand Up @@ -18,14 +18,57 @@
from typing_extensions import Literal

from torchmetrics.functional.classification.stat_scores import (
_multiclass_stat_scores_arg_validation,
_multiclass_stat_scores_format,
_multiclass_stat_scores_tensor_validation,
_multilabel_stat_scores_arg_validation,
_multilabel_stat_scores_format,
_multilabel_stat_scores_tensor_validation,
)
from torchmetrics.utilities.compute import _safe_divide


def _multilabel_exact_scores_update(
def _exact_match_reduce(
    correct: Tensor,
    total: Tensor,
) -> Tensor:
    """Final reduction for exact match: the ratio ``correct / total``.

    ``_safe_divide`` presumably guards against a zero denominator — see its definition.
    """
    return _safe_divide(correct, total)


def _multiclass_exact_match_update(
preds: Tensor,
target: Tensor,
multidim_average: Literal["global", "samplewise"] = "global",
) -> Tuple[Tensor, Tensor]:
"""Computes the statistics."""
correct = (preds == target).sum(1) == preds.shape[1]
correct = correct if multidim_average == "samplewise" else correct.sum()
total = torch.tensor(preds.shape[0] if multidim_average == "global" else 1, device=correct.device)
return correct, total


def multiclass_exact_match(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Computes Exact match (also known as subset accuracy) for multiclass tasks. Exact Match is a stricter
    version of accuracy where all labels have to match exactly for the sample to be correctly classified.

    Args:
        preds: ``(N, ...)`` int tensor or ``(N, C, ...)`` float tensor of predictions.
        target: ``(N, ...)`` int tensor of ground-truth labels.
        num_classes: Number of classes in the problem.
        multidim_average: ``"global"`` returns a single score; ``"samplewise"`` one score per sample.
        ignore_index: Forwarded to the stat-score helpers; presumably a target value to exclude — confirm
            against ``_multiclass_stat_scores_tensor_validation``.
        validate_args: If ``True``, validate arguments and inputs before computing.

    Returns:
        The exact-match score as a tensor.
    """
    # Exact match reuses the multiclass stat-score machinery with fixed top_k=1 and no averaging.
    top_k, average = 1, None
    if validate_args:
        _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
        _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
    preds, target = _multiclass_stat_scores_format(preds, target, num_classes, top_k)
    correct, total = _multiclass_exact_match_update(preds, target, multidim_average)
    return _exact_match_reduce(correct, total)


def _multilabel_exact_match_update(
preds: Tensor, target: Tensor, num_labels: int, multidim_average: Literal["global", "samplewise"] = "global"
) -> Tuple[Tensor, Tensor]:
"""Computes the statistics."""
Expand All @@ -38,14 +81,6 @@ def _multilabel_exact_scores_update(
return correct, total


def _multilabel_exact_scores_compute(
correct: Tensor,
total: Tensor,
) -> Tensor:
"""Final reduction for exact match."""
return _safe_divide(correct, total)


def multilabel_exact_match(
preds: Tensor,
target: Tensor,
Expand Down Expand Up @@ -129,5 +164,35 @@ def multilabel_exact_match(
_multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
_multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
correct, total = _multilabel_exact_scores_update(preds, target, num_labels, multidim_average)
return _multilabel_exact_scores_compute(correct, total)
correct, total = _multilabel_exact_match_update(preds, target, num_labels, multidim_average)
return _exact_match_reduce(correct, total)


def exact_match(
    preds: Tensor,
    target: Tensor,
    task: Literal["multiclass", "multilabel"],
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor:
    r"""Computes Exact match (also known as subset accuracy). Exact Match is a stricter version of accuracy where
    all classes/labels have to match exactly for the sample to be correctly classified.

    This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the
    ``task`` argument to either ``'multiclass'`` or ``'multilabel'``. See the documentation of
    :func:`multiclass_exact_match` and :func:`multilabel_exact_match` for the specific details of
    each argument influence and examples.

    Raises:
        ValueError: If ``task`` is neither ``'multiclass'`` nor ``'multilabel'``.
    """
    if task == "multiclass":
        assert num_classes is not None
        return multiclass_exact_match(preds, target, num_classes, multidim_average, ignore_index, validate_args)
    # Bug fix: this previously compared against the misspelled string "multilalbe", which made the
    # multilabel branch unreachable so every multilabel call fell through to the ValueError below.
    if task == "multilabel":
        assert num_labels is not None
        return multilabel_exact_match(
            preds, target, num_labels, threshold, multidim_average, ignore_index, validate_args
        )
    raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}")