Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LogCosh Error #1316

Merged
merged 8 commits into from Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `KendallRankCorrCoef` to regression package ([#1271](https://github.com/Lightning-AI/metrics/pull/1271))


- Added `LogCoshError` to regression package ([#1316](https://github.com/Lightning-AI/metrics/pull/1316))


### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Expand Up @@ -95,3 +95,4 @@
.. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric
.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient
.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303
.. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf
22 changes: 22 additions & 0 deletions docs/source/regression/log_cosh_error.rst
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Log Cosh Error
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Regression

.. include:: ../links.rst

##############
Log Cosh Error
##############

Module Interface
________________

.. autoclass:: torchmetrics.LogCoshError
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.log_cosh_error
:noindex:
2 changes: 2 additions & 0 deletions src/torchmetrics/__init__.py
Expand Up @@ -59,6 +59,7 @@
ExplainedVariance,
KendallRankCorrCoef,
KLDivergence,
LogCoshError,
MeanAbsoluteError,
MeanAbsolutePercentageError,
MeanSquaredError,
Expand Down Expand Up @@ -132,6 +133,7 @@
"JaccardIndex",
"KendallRankCorrCoef",
"KLDivergence",
"LogCoshError",
"MatchErrorRate",
"MatthewsCorrCoef",
"MaxMetric",
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/__init__.py
Expand Up @@ -51,6 +51,7 @@
from torchmetrics.functional.regression.explained_variance import explained_variance
from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef
from torchmetrics.functional.regression.kl_divergence import kl_divergence
from torchmetrics.functional.regression.log_cosh import log_cosh_error
from torchmetrics.functional.regression.log_mse import mean_squared_log_error
from torchmetrics.functional.regression.mae import mean_absolute_error
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error
Expand Down Expand Up @@ -115,6 +116,7 @@
"jaccard_index",
"kendall_rank_corrcoef",
"kl_divergence",
"log_cosh_error",
"match_error_rate",
"matthews_corrcoef",
"mean_absolute_error",
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/functional/regression/__init__.py
Expand Up @@ -16,6 +16,7 @@
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef # noqa: F401
from torchmetrics.functional.regression.kl_divergence import kl_divergence # noqa: F401
from torchmetrics.functional.regression.log_cosh import log_cosh_error # noqa: F401
from torchmetrics.functional.regression.log_mse import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.mae import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error # noqa: F401
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/regression/kendall.py
Expand Up @@ -17,7 +17,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.data import _bincount, dim_zero_cat
from torchmetrics.utilities.enums import EnumStr
Expand Down Expand Up @@ -265,7 +265,7 @@ def _kendall_corrcoef_update(
"""
# Data checking
_check_same_shape(preds, target)
_check_data_shape_for_corr_coef(preds, target, num_outputs)
_check_data_shape_to_num_outputs(preds, target, num_outputs)

if num_outputs == 1:
preds = preds.unsqueeze(1)
Expand Down
92 changes: 92 additions & 0 deletions src/torchmetrics/functional/regression/log_cosh.py
@@ -0,0 +1,92 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape


def _unsqueeze_tensors(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
if preds.ndim == 2:
return preds, target
return preds.unsqueeze(1), target.unsqueeze(1)


def _log_cosh_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]:
    """Updates and returns variables required to compute LogCosh error.

    Checks for same shape of input tensors and that their shape matches ``num_outputs``.

    Args:
        preds: Predicted tensor with shape ``(batch_size,)`` or ``(batch_size, num_outputs)``
        target: Ground truth tensor with the same shape as ``preds``
        num_outputs: Number of outputs in multioutput setting

    Return:
        Sum of LogCosh error over examples, and total number of examples
    """
    _check_same_shape(preds, target)
    _check_data_shape_to_num_outputs(preds, target, num_outputs)

    preds, target = _unsqueeze_tensors(preds, target)
    diff = preds - target
    # Numerically stable log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2).
    # The naive log((exp(x) + exp(-x)) / 2) overflows to inf for |x| >~ 88 in float32.
    abs_diff = diff.abs()
    sum_log_cosh_error = (abs_diff + torch.log1p(torch.exp(-2.0 * abs_diff)) - math.log(2.0)).sum(0).squeeze()
    n_obs = torch.tensor(target.shape[0], device=preds.device)
    return sum_log_cosh_error, n_obs


def _log_cosh_error_compute(sum_log_cosh_error: Tensor, n_obs: Tensor) -> Tensor:
"""Computes Mean Squared Error.

Args:
sum_squared_error: Sum of LogCosh errors over all observations
n_obs: Number of predictions or observations
"""
return (sum_log_cosh_error / n_obs).squeeze()


def log_cosh_error(preds: Tensor, target: Tensor) -> Tensor:
    r"""Compute the `LogCosh Error`_.

    .. math:: \text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(y - \hat{y})}{2}\right)

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    Args:
        preds: estimated labels with shape ``(batch_size,)`` or ``(batch_size, num_outputs)``
        target: ground truth labels with shape ``(batch_size,)`` or ``(batch_size, num_outputs)``

    Return:
        Tensor with LogCosh error

    Example (single output regression)::
        >>> from torchmetrics.functional import log_cosh_error
        >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
        >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
        >>> log_cosh_error(preds, target)
        tensor(0.3523)

    Example (multi output regression)::
        >>> from torchmetrics.functional import log_cosh_error
        >>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
        >>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
        >>> log_cosh_error(preds, target)
        tensor([0.9176, 0.4277, 0.2194])
    """
    # A 1d input is a single-output problem; otherwise one output per trailing column.
    num_outputs = 1 if preds.ndim == 1 else preds.shape[-1]
    sum_log_cosh_error, n_obs = _log_cosh_error_update(preds, target, num_outputs=num_outputs)
    return _log_cosh_error_compute(sum_log_cosh_error, n_obs)
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/regression/pearson.py
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import Tensor

from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape


Expand Down Expand Up @@ -45,7 +45,7 @@ def _pearson_corrcoef_update(
"""
# Data checking
_check_same_shape(preds, target)
_check_data_shape_for_corr_coef(preds, target, num_outputs)
_check_data_shape_to_num_outputs(preds, target, num_outputs)

n_obs = preds.shape[0]
mx_new = (n_prior * mean_x + preds.mean(0) * n_obs) / (n_prior + n_obs)
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/regression/spearman.py
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import Tensor

from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape


Expand Down Expand Up @@ -68,7 +68,7 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -
"Expected `preds` and `target` both to be floating point tensors, but got {pred.dtype} and {target.dtype}"
)
_check_same_shape(preds, target)
_check_data_shape_for_corr_coef(preds, target, num_outputs)
_check_data_shape_to_num_outputs(preds, target, num_outputs)

return preds, target

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/regression/utils.py
Expand Up @@ -14,7 +14,7 @@
from torch import Tensor


def _check_data_shape_for_corr_coef(preds: Tensor, target: Tensor, num_outputs: int) -> None:
def _check_data_shape_to_num_outputs(preds: Tensor, target: Tensor, num_outputs: int) -> None:
"""Check that predictions and target have the correct shape, else raise error."""
if preds.ndim > 2 or target.ndim > 2:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/regression/__init__.py
Expand Up @@ -16,6 +16,7 @@
from torchmetrics.regression.explained_variance import ExplainedVariance # noqa: F401
from torchmetrics.regression.kendall import KendallRankCorrCoef # noqa: F401
from torchmetrics.regression.kl_divergence import KLDivergence # noqa: F401
from torchmetrics.regression.log_cosh import LogCoshError # noqa: F401
from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mape import MeanAbsolutePercentageError # noqa: F401
Expand Down
83 changes: 83 additions & 0 deletions src/torchmetrics/regression/log_cosh.py
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
from torch import Tensor

from torchmetrics.functional.regression.log_cosh import _log_cosh_error_compute, _log_cosh_error_update
from torchmetrics.metric import Metric


class LogCoshError(Metric):
    r"""Compute the `LogCosh Error`_.

    .. math:: \text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(y - \hat{y})}{2}\right)

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    Args:
        num_outputs: Number of outputs in multioutput setting
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``num_outputs`` is not a positive integer

    Example (single output regression)::
        >>> from torchmetrics import LogCoshError
        >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
        >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
        >>> log_cosh_error = LogCoshError()
        >>> log_cosh_error(preds, target)
        tensor(0.3523)

    Example (multi output regression)::
        >>> from torchmetrics import LogCoshError
        >>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
        >>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
        >>> log_cosh_error = LogCoshError(num_outputs=3)
        >>> log_cosh_error(preds, target)
        tensor([0.9176, 0.4277, 0.2194])
    """

    is_differentiable = True
    higher_is_better = False
    full_state_update = False
    # Running sum of per-output LogCosh errors and total observation count.
    sum_log_cosh_error: Tensor
    total: Tensor

    def __init__(self, num_outputs: int = 1, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        # ``or`` (not ``and``): reject values that are not ints OR are non-positive.
        # With ``and`` an integer < 1 would silently pass validation.
        if not isinstance(num_outputs, int) or num_outputs < 1:
            raise ValueError(f"Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}")
        self.num_outputs = num_outputs
        self.add_state("sum_log_cosh_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Args:
            preds: estimated labels with shape ``(batch_size,)`` or ``(batch_size, num_outputs)``
            target: ground truth labels with shape ``(batch_size,)`` or ``(batch_size, num_outputs)``

        Raises:
            ValueError:
                If ``preds`` or ``target`` has multiple outputs when ``num_outputs=1``
        """
        sum_log_cosh_error, n_obs = _log_cosh_error_update(preds, target, self.num_outputs)
        self.sum_log_cosh_error += sum_log_cosh_error
        self.total += n_obs

    def compute(self) -> Tensor:
        """Compute LogCosh error over state."""
        return _log_cosh_error_compute(self.sum_log_cosh_error, self.total)