Return info / formatting on Classification docs (#1433)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: SkafteNicki <skaftenicki@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
5 people committed Jan 28, 2023
1 parent 3f93c72 commit 8d17f0b
Showing 30 changed files with 1,141 additions and 1,080 deletions.
2 changes: 1 addition & 1 deletion docs/source/links.rst
@@ -42,7 +42,7 @@
 .. _Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR): https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173
 .. _Tweedie Deviance Score: https://en.wikipedia.org/wiki/Tweedie_distribution#The_Tweedie_deviance
 .. _Permutation Invariant Training of Deep Models: https://ieeexplore.ieee.org/document/7952154
-.. _Computes the Top-label Calibration Error: https://arxiv.org/pdf/1909.10155.pdf
+.. _Top-label Calibration Error: https://arxiv.org/pdf/1909.10155.pdf
 .. _Gradient Computation of Image: https://en.wikipedia.org/wiki/Image_gradient
 .. _R2 Score_Coefficient Determination: https://en.wikipedia.org/wiki/Coefficient_of_determination
 .. _Rank of element tensor: https://github.com/scipy/scipy/blob/v1.6.2/scipy/stats/stats.py#L4140-L4303
127 changes: 62 additions & 65 deletions src/torchmetrics/classification/accuracy.py
@@ -39,15 +39,18 @@ class BinaryAccuracy(BinaryStatScores):
     Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
 
-    Accepts the following input tensors:
+    As input to ``forward`` and ``update`` the metric accepts the following input:
 
-    - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
-      [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally,
-      we convert to int tensor with thresholding using the value in ``threshold``.
-    - ``target`` (int tensor): ``(N, ...)``
+    - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating
+      point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid
+      per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
 
     The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average`
     argument.
 
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``ba`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, the metric returns a scalar value.
+      If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
+      value per sample.
 
     Args:
         threshold: Threshold for transforming probability to binary {0,1} predictions
@@ -63,10 +66,6 @@ class BinaryAccuracy(BinaryStatScores):
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
 
-    Returns:
-        If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average``
-        is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample.
-
     Example (preds is int tensor):
         >>> from torch import tensor
         >>> from torchmetrics.classification import BinaryAccuracy
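As an aside (not part of the diff): a minimal sketch of the input handling the updated BinaryAccuracy docstring describes, assuming a torchmetrics release with the new classification API. Float predictions outside [0, 1] are treated as logits, sigmoid is applied per element, and the result is binarized with ``threshold`` before being compared against ``target``; the numbers below are made up for illustration.

import torch
from torchmetrics.classification import BinaryAccuracy

# Values outside [0, 1], so they are treated as logits: sigmoid, then threshold at 0.5.
preds = torch.tensor([-2.0, 1.5, 0.3, -0.4])   # sigmoid gives ~[0.12, 0.82, 0.57, 0.40] -> [0, 1, 1, 0]
target = torch.tensor([0, 1, 0, 1])

metric = BinaryAccuracy(threshold=0.5)
metric(preds, target)   # tensor(0.5000): two of the four thresholded predictions match the target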
@@ -160,15 +159,27 @@ class MulticlassAccuracy(MulticlassStatScores):
     Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
 
-    Accepts the following input tensors:
+    As input to ``forward`` and ``update`` the metric accepts the following input:
 
-    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
-      we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
-      an int tensor.
-    - ``target`` (int tensor): ``(N, ...)``
+    - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``.
+      If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert
+      probabilities/logits into an int tensor.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
 
     The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average`
     argument.
 
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``mca`` (:class:`~torch.Tensor`): A tensor with the accuracy score whose returned shape depends on the
+      ``average`` and ``multidim_average`` arguments:
+
+        - If ``multidim_average`` is set to ``global``:
+
+            - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
+            - If ``average=None/'none'``, the shape will be ``(C,)``
+
+        - If ``multidim_average`` is set to ``samplewise``:
+
+            - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
+            - If ``average=None/'none'``, the shape will be ``(N, C)``
 
     Args:
         num_classes: Integer specifing the number of classes
@@ -195,19 +206,6 @@ class MulticlassAccuracy(MulticlassStatScores):
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
 
-    Returns:
-        The returned shape depends on the ``average`` and ``multidim_average`` arguments:
-
-        - If ``multidim_average`` is set to ``global``:
-
-            - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
-            - If ``average=None/'none'``, the shape will be ``(C,)``
-
-        - If ``multidim_average`` is set to ``samplewise``:
-
-            - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
-            - If ``average=None/'none'``, the shape will be ``(N, C)``
-
     Example (preds is int tensor):
         >>> from torch import tensor
         >>> from torchmetrics.classification import MulticlassAccuracy
@@ -216,8 +214,8 @@ class MulticlassAccuracy(MulticlassStatScores):
         >>> metric = MulticlassAccuracy(num_classes=3)
         >>> metric(preds, target)
         tensor(0.8333)
-        >>> metric = MulticlassAccuracy(num_classes=3, average=None)
-        >>> metric(preds, target)
+        >>> mca = MulticlassAccuracy(num_classes=3, average=None)
+        >>> mca(preds, target)
         tensor([0.5000, 1.0000, 1.0000])
 
     Example (preds is float tensor):
@@ -230,8 +228,8 @@ class MulticlassAccuracy(MulticlassStatScores):
         >>> metric = MulticlassAccuracy(num_classes=3)
         >>> metric(preds, target)
         tensor(0.8333)
-        >>> metric = MulticlassAccuracy(num_classes=3, average=None)
-        >>> metric(preds, target)
+        >>> mca = MulticlassAccuracy(num_classes=3, average=None)
+        >>> mca(preds, target)
         tensor([0.5000, 1.0000, 1.0000])
 
     Example (multidim tensors):
@@ -241,8 +239,8 @@ class MulticlassAccuracy(MulticlassStatScores):
         >>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise')
         >>> metric(preds, target)
         tensor([0.5000, 0.2778])
-        >>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise', average=None)
-        >>> metric(preds, target)
+        >>> mca = MulticlassAccuracy(num_classes=3, multidim_average='samplewise', average=None)
+        >>> mca(preds, target)
         tensor([[1.0000, 0.0000, 0.5000],
                 [0.0000, 0.3333, 0.5000]])
     """
@@ -313,15 +311,27 @@ class MultilabelAccuracy(MultilabelStatScores):
     Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
 
-    Accepts the following input tensors:
+    As input to ``forward`` and ``update`` the metric accepts the following input:
 
-    - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
-      [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally,
-      we convert to int tensor with thresholding using the value in ``threshold``.
-    - ``target`` (int tensor): ``(N, C, ...)``
+    - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, C, ...)``. If preds is a floating
+      point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per
+      element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``
 
     The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average`
     argument.
 
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``mla`` (:class:`~torch.Tensor`): A tensor with the accuracy score whose returned shape depends on the
+      ``average`` and ``multidim_average`` arguments:
+
+        - If ``multidim_average`` is set to ``global``:
+
+            - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
+            - If ``average=None/'none'``, the shape will be ``(C,)``
+
+        - If ``multidim_average`` is set to ``samplewise``:
+
+            - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
+            - If ``average=None/'none'``, the shape will be ``(N, C)``
 
     Args:
         num_labels: Integer specifing the number of labels
@@ -346,19 +356,6 @@ class MultilabelAccuracy(MultilabelStatScores):
         validate_args: bool indicating if input arguments and tensors should be validated for correctness.
             Set to ``False`` for faster computations.
 
-    Returns:
-        The returned shape depends on the ``average`` and ``multidim_average`` arguments:
-
-        - If ``multidim_average`` is set to ``global``:
-
-            - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor
-            - If ``average=None/'none'``, the shape will be ``(C,)``
-
-        - If ``multidim_average`` is set to ``samplewise``:
-
-            - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
-            - If ``average=None/'none'``, the shape will be ``(N, C)``
-
     Example (preds is int tensor):
         >>> from torch import tensor
         >>> from torchmetrics.classification import MultilabelAccuracy
@@ -367,8 +364,8 @@ class MultilabelAccuracy(MultilabelStatScores):
         >>> metric = MultilabelAccuracy(num_labels=3)
         >>> metric(preds, target)
         tensor(0.6667)
-        >>> metric = MultilabelAccuracy(num_labels=3, average=None)
-        >>> metric(preds, target)
+        >>> mla = MultilabelAccuracy(num_labels=3, average=None)
+        >>> mla(preds, target)
         tensor([1.0000, 0.5000, 0.5000])
 
     Example (preds is float tensor):
@@ -378,8 +375,8 @@ class MultilabelAccuracy(MultilabelStatScores):
         >>> metric = MultilabelAccuracy(num_labels=3)
         >>> metric(preds, target)
         tensor(0.6667)
-        >>> metric = MultilabelAccuracy(num_labels=3, average=None)
-        >>> metric(preds, target)
+        >>> mla = MultilabelAccuracy(num_labels=3, average=None)
+        >>> mla(preds, target)
         tensor([1.0000, 0.5000, 0.5000])
 
     Example (multidim tensors):
@@ -391,11 +388,11 @@ class MultilabelAccuracy(MultilabelStatScores):
         ...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
         ...     ]
         ... )
-        >>> metric = MultilabelAccuracy(num_labels=3, multidim_average='samplewise')
-        >>> metric(preds, target)
+        >>> mla = MultilabelAccuracy(num_labels=3, multidim_average='samplewise')
+        >>> mla(preds, target)
         tensor([0.3333, 0.1667])
-        >>> metric = MultilabelAccuracy(num_labels=3, multidim_average='samplewise', average=None)
-        >>> metric(preds, target)
+        >>> mla = MultilabelAccuracy(num_labels=3, multidim_average='samplewise', average=None)
+        >>> mla(preds, target)
         tensor([[0.5000, 0.5000, 0.0000],
                 [0.0000, 0.0000, 0.5000]])
     """
