Skip to content

Commit

Permalink
Add LogCosh Error (#1316)
Browse files Browse the repository at this point in the history
* Add LogCosh Error #1315

* Update chlog

* Fix doc examples

* Fix a typo in math

* mypy

* mypy

* Specify device for a new tensor

* .
  • Loading branch information
stancld committed Nov 8, 2022
1 parent a3dc40c commit 920fe0f
Show file tree
Hide file tree
Showing 14 changed files with 304 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `KendallRankCorrCoef` to regression package ([#1271](https://github.com/Lightning-AI/metrics/pull/1271))


- Added `LogCoshError` to regression package ([#1316](https://github.com/Lightning-AI/metrics/pull/1316))


### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Expand Up @@ -95,3 +95,4 @@
.. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric
.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient
.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303
.. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf
22 changes: 22 additions & 0 deletions docs/source/regression/log_cosh_error.rst
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Log Cosh Error
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Regression

.. include:: ../links.rst

##############
Log Cosh Error
##############

Module Interface
________________

.. autoclass:: torchmetrics.LogCoshError
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.log_cosh_error
:noindex:
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Expand Up @@ -59,6 +59,7 @@
ExplainedVariance,
KendallRankCorrCoef,
KLDivergence,
LogCoshError,
MeanAbsoluteError,
MeanAbsolutePercentageError,
MeanSquaredError,
Expand Down Expand Up @@ -132,6 +133,7 @@
"JaccardIndex",
"KendallRankCorrCoef",
"KLDivergence",
"LogCoshError",
"MatchErrorRate",
"MatthewsCorrCoef",
"MaxMetric",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Expand Up @@ -51,6 +51,7 @@
from torchmetrics.functional.regression.explained_variance import explained_variance
from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef
from torchmetrics.functional.regression.kl_divergence import kl_divergence
from torchmetrics.functional.regression.log_cosh import log_cosh_error
from torchmetrics.functional.regression.log_mse import mean_squared_log_error
from torchmetrics.functional.regression.mae import mean_absolute_error
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error
Expand Down Expand Up @@ -115,6 +116,7 @@
"jaccard_index",
"kendall_rank_corrcoef",
"kl_divergence",
"log_cosh_error",
"match_error_rate",
"matthews_corrcoef",
"mean_absolute_error",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/regression/__init__.py
Expand Up @@ -16,6 +16,7 @@
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef # noqa: F401
from torchmetrics.functional.regression.kl_divergence import kl_divergence # noqa: F401
from torchmetrics.functional.regression.log_cosh import log_cosh_error # noqa: F401
from torchmetrics.functional.regression.log_mse import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.mae import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error # noqa: F401
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/regression/kendall.py
Expand Up @@ -17,7 +17,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.data import _bincount, dim_zero_cat
from torchmetrics.utilities.enums import EnumStr
Expand Down Expand Up @@ -265,7 +265,7 @@ def _kendall_corrcoef_update(
"""
# Data checking
_check_same_shape(preds, target)
_check_data_shape_for_corr_coef(preds, target, num_outputs)
_check_data_shape_to_num_outputs(preds, target, num_outputs)

if num_outputs == 1:
preds = preds.unsqueeze(1)
Expand Down
92 changes: 92 additions & 0 deletions src/torchmetrics/functional/regression/log_cosh.py
@@ -0,0 +1,92 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape


def _unsqueeze_tensors(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
if preds.ndim == 2:
return preds, target
return preds.unsqueeze(1), target.unsqueeze(1)


def _log_cosh_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute LogCosh error.
Checks for same shape of input tensors.
Args:
preds: Predicted tensor
target: Ground truth tensor
Return:
Sum of LogCosh error over examples, and total number of examples
"""
_check_same_shape(preds, target)
_check_data_shape_to_num_outputs(preds, target, num_outputs)

preds, target = _unsqueeze_tensors(preds, target)
diff = preds - target
sum_log_cosh_error = torch.log((torch.exp(diff) + torch.exp(-diff)) / 2).sum(0).squeeze()
n_obs = torch.tensor(target.shape[0], device=preds.device)
return sum_log_cosh_error, n_obs


def _log_cosh_error_compute(sum_log_cosh_error: Tensor, n_obs: Tensor) -> Tensor:
"""Computes Mean Squared Error.
Args:
sum_squared_error: Sum of LogCosh errors over all observations
n_obs: Number of predictions or observations
"""
return (sum_log_cosh_error / n_obs).squeeze()


def log_cosh_error(preds: Tensor, target: Tensor) -> Tensor:
r"""Compute the `LogCosh Error`_.
.. math:: \text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(\hat{y - y})}{2}\right)
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
Args:
preds: estimated labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)``
target: ground truth labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)``
Return:
Tensor with LogCosh error
Example (single output regression)::
>>> from torchmetrics.functional import log_cosh_error
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> log_cosh_error(preds, target)
tensor(0.3523)
Example (multi output regression)::
>>> from torchmetrics.functional import log_cosh_error
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
>>> log_cosh_error(preds, target)
tensor([0.9176, 0.4277, 0.2194])
"""
sum_log_cosh_error, n_obs = _log_cosh_error_update(
preds, target, num_outputs=1 if preds.ndim == 1 else preds.shape[-1]
)
return _log_cosh_error_compute(sum_log_cosh_error, n_obs)
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/regression/pearson.py
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import Tensor

from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape


Expand Down Expand Up @@ -45,7 +45,7 @@ def _pearson_corrcoef_update(
"""
# Data checking
_check_same_shape(preds, target)
_check_data_shape_for_corr_coef(preds, target, num_outputs)
_check_data_shape_to_num_outputs(preds, target, num_outputs)

n_obs = preds.shape[0]
mx_new = (n_prior * mean_x + preds.mean(0) * n_obs) / (n_prior + n_obs)
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/regression/spearman.py
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import Tensor

from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape


Expand Down Expand Up @@ -68,7 +68,7 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -
"Expected `preds` and `target` both to be floating point tensors, but got {pred.dtype} and {target.dtype}"
)
_check_same_shape(preds, target)
_check_data_shape_for_corr_coef(preds, target, num_outputs)
_check_data_shape_to_num_outputs(preds, target, num_outputs)

return preds, target

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/regression/utils.py
Expand Up @@ -14,7 +14,7 @@
from torch import Tensor


def _check_data_shape_for_corr_coef(preds: Tensor, target: Tensor, num_outputs: int) -> None:
def _check_data_shape_to_num_outputs(preds: Tensor, target: Tensor, num_outputs: int) -> None:
"""Check that predictions and target have the correct shape, else raise error."""
if preds.ndim > 2 or target.ndim > 2:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/regression/__init__.py
Expand Up @@ -16,6 +16,7 @@
from torchmetrics.regression.explained_variance import ExplainedVariance # noqa: F401
from torchmetrics.regression.kendall import KendallRankCorrCoef # noqa: F401
from torchmetrics.regression.kl_divergence import KLDivergence # noqa: F401
from torchmetrics.regression.log_cosh import LogCoshError # noqa: F401
from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mape import MeanAbsolutePercentageError # noqa: F401
Expand Down
83 changes: 83 additions & 0 deletions src/torchmetrics/regression/log_cosh.py
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
from torch import Tensor

from torchmetrics.functional.regression.log_cosh import _log_cosh_error_compute, _log_cosh_error_update
from torchmetrics.metric import Metric


class LogCoshError(Metric):
r"""Compute the `LogCosh Error`_.
.. math:: \text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(\hat{y - y})}{2}\right)
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
Args:
num_outputs: Number of outputs in multioutput setting
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example (single output regression)::
>>> from torchmetrics import LogCoshError
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> log_cosh_error = LogCoshError()
>>> log_cosh_error(preds, target)
tensor(0.3523)
Example (multi output regression)::
>>> from torchmetrics import LogCoshError
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
>>> log_cosh_error = LogCoshError(num_outputs=3)
>>> log_cosh_error(preds, target)
tensor([0.9176, 0.4277, 0.2194])
"""

is_differentiable = True
higher_is_better = False
full_state_update = False
sum_log_cosh_error: Tensor
total: Tensor

def __init__(self, num_outputs: int = 1, **kwargs: Any) -> None:
super().__init__(**kwargs)

if not isinstance(num_outputs, int) and num_outputs < 1:
raise ValueError(f"Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}")
self.num_outputs = num_outputs
self.add_state("sum_log_cosh_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets.
Args:
preds: estimated labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)``
target: ground truth labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)``
Raises:
ValueError:
If ``preds`` or ``target`` has multiple outputs when ``num_outputs=1``
"""
sum_log_cosh_error, n_obs = _log_cosh_error_update(preds, target, self.num_outputs)
self.sum_log_cosh_error += sum_log_cosh_error
self.total += n_obs

def compute(self) -> Tensor:
"""Compute LogCosh error over state."""
return _log_cosh_error_compute(self.sum_log_cosh_error, self.total)

0 comments on commit 920fe0f

Please sign in to comment.