diff --git a/CHANGELOG.md b/CHANGELOG.md index e26b18ccc2b..df6488490b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added option to pass `distributed_available_fn` to metrics to allow checks for custom communication backend for making `dist_sync_fn` actually useful ([#1301](https://github.com/Lightning-AI/metrics/pull/1301)) +- Added `KendallRankCorrCoef` to regression package ([#1271](https://github.com/Lightning-AI/metrics/pull/1271)) + + ### Changed - Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259)) diff --git a/docs/source/conf.py b/docs/source/conf.py index b0c002c025a..2d0968d395a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -385,3 +385,7 @@ def find_source(): """ coverage_skip_undoc_in_source = True + +linkcheck_ignore = [ + "https://www.jstor.org/stable/2332303" # jstor cannot be accessed from python, but link work fine in a local doc +] diff --git a/docs/source/links.rst b/docs/source/links.rst index 2cb9be7918f..b305ac49fd8 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -93,3 +93,5 @@ .. _AB divergence: https://pdfs.semanticscholar.org/744b/1166de34cb099100f151f3b1459f141ae25b.pdf .. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf .. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric +.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient +.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303 diff --git a/docs/source/regression/kendall_rank_corr_coef.rst b/docs/source/regression/kendall_rank_corr_coef.rst new file mode 100644 index 00000000000..66ccc774193 --- /dev/null +++ b/docs/source/regression/kendall_rank_corr_coef.rst @@ -0,0 +1,22 @@ +.. 
customcarditem:: + :header: Kendall Rank Correlation Coefficient + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Regression + +.. include:: ../links.rst + +####################### +Kendal Rank Corr. Coef. +####################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.KendallRankCorrCoef + :noindex: + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.kendall_rank_corrcoef + :noindex: diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 514b7365ab3..0b5310975a7 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -57,6 +57,7 @@ ConcordanceCorrCoef, CosineSimilarity, ExplainedVariance, + KendallRankCorrCoef, KLDivergence, MeanAbsoluteError, MeanAbsolutePercentageError, @@ -129,6 +130,7 @@ "HammingDistance", "HingeLoss", "JaccardIndex", + "KendallRankCorrCoef", "KLDivergence", "MatchErrorRate", "MatthewsCorrCoef", diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index ca83c1c5f75..233731d06d8 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -49,6 +49,7 @@ from torchmetrics.functional.regression.concordance import concordance_corrcoef from torchmetrics.functional.regression.cosine_similarity import cosine_similarity from torchmetrics.functional.regression.explained_variance import explained_variance +from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef from torchmetrics.functional.regression.kl_divergence import kl_divergence from torchmetrics.functional.regression.log_mse import mean_squared_log_error from torchmetrics.functional.regression.mae import mean_absolute_error @@ -112,6 +113,7 @@ "hinge_loss", "image_gradients", "jaccard_index", + "kendall_rank_corrcoef", "kl_divergence", "match_error_rate", "matthews_corrcoef", diff --git 
a/src/torchmetrics/functional/regression/__init__.py b/src/torchmetrics/functional/regression/__init__.py index 43dead1a32f..852db0bf196 100644 --- a/src/torchmetrics/functional/regression/__init__.py +++ b/src/torchmetrics/functional/regression/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.functional.regression.concordance import concordance_corrcoef # noqa: F401 from torchmetrics.functional.regression.cosine_similarity import cosine_similarity # noqa: F401 from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401 +from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef # noqa: F401 from torchmetrics.functional.regression.kl_divergence import kl_divergence # noqa: F401 from torchmetrics.functional.regression.log_mse import mean_squared_log_error # noqa: F401 from torchmetrics.functional.regression.mae import mean_absolute_error # noqa: F401 diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py new file mode 100644 index 00000000000..698cb54ec73 --- /dev/null +++ b/src/torchmetrics/functional/regression/kendall.py @@ -0,0 +1,428 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.data import _bincount, dim_zero_cat +from torchmetrics.utilities.enums import EnumStr + + +class _MetricVariant(EnumStr): + """Enumerate for metric variants.""" + + A = "a" + B = "b" + C = "c" + + @classmethod + def from_str(cls, value: Literal["a", "b", "c"]) -> "_MetricVariant": # type: ignore[override] + """Return the metric variant matching ``value``. + + Raises: + ValueError: + If required metric variant is not among the supported options. + """ + _allowed_variants = [im.lower() for im in _MetricVariant._member_names_] + + enum_key = super().from_str(value) + if enum_key is not None and enum_key in _allowed_variants: + return enum_key # type: ignore[return-value] # use override + # Report the caller's input: `enum_key` has already failed validation here and does not show what was passed + raise ValueError(f"Invalid metric variant. Expected one of {_allowed_variants}, but got {value}.") + + +class _TestAlternative(EnumStr): + """Enumerate for test alternative options.""" + + TWO_SIDED = "two-sided" + LESS = "less" + GREATER = "greater" + + @classmethod + def from_str(cls, value: Literal["two-sided", "less", "greater"]) -> "_TestAlternative": # type: ignore[override] + """Return the test alternative matching ``value``. + + Raises: + ValueError: + If required test alternative is not among the supported options. + """ + _allowed_alternatives = [im.lower().replace("_", "-") for im in _TestAlternative._member_names_] + + enum_key = super().from_str(value.replace("-", "_")) + if enum_key is not None and enum_key in _allowed_alternatives: + return enum_key # type: ignore[return-value] # use override + raise ValueError(f"Invalid test alternative. 
Expected one of {_allowed_alternatives}, but got {value}.") + + +def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + """Sort sequences in an ascending order according to the sequence ``x``.""" + # We need to clone `y` tensor not to change an object in memory + y = torch.clone(y) + x, y = x.T, y.T + x, perm = x.sort() + for i in range(x.shape[0]): + y[i] = y[i][perm[i]] + return x.T, y.T + + +def _concordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: + """Count a total number of concordant pairs in a single sequence.""" + return torch.logical_and(x[i] < x[(i + 1) :], y[i] < y[(i + 1) :]).sum(0).unsqueeze(0) + + +def _count_concordant_pairs(preds: Tensor, target: Tensor) -> Tensor: + """Count a total number of concordant pairs in given sequences.""" + return torch.cat([_concordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0) + + +def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: + """Count a total number of discordant pairs in a single sequence.""" + return ( + torch.logical_or( + torch.logical_and(x[i] > x[(i + 1) :], y[i] < y[(i + 1) :]), + torch.logical_and(x[i] < x[(i + 1) :], y[i] > y[(i + 1) :]), + ) + .sum(0) + .unsqueeze(0) + ) + + +def _count_discordant_pairs(preds: Tensor, target: Tensor) -> Tensor: + """Count a total number of discordant pairs in given sequences.""" + return torch.cat([_discordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0) + + +def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor: + """Convert a sequence to the rank tensor.""" + # Sort if a sequence has not been sorted before + if sort: + x = x.sort(dim=0).values + # The first row always gets rank 0; every change between neighbouring rows bumps the running rank by one + _zeros = torch.zeros(1, x.shape[1], dtype=torch.int32, device=x.device) + return torch.cat([_zeros, (x[1:] != x[:-1]).int()], dim=0).cumsum(0) + + +def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """Get a total number of ties and statistics for p-value calculation for a given sequence.""" 
+ ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) + ties_p1 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) + ties_p2 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) + for dim in range(x.shape[1]): + n_ties = _bincount(x[:, dim]) + n_ties = n_ties[n_ties > 1] + ties[dim] = (n_ties * (n_ties - 1) // 2).sum() + ties_p1[dim] = (n_ties * (n_ties - 1.0) * (n_ties - 2)).sum() + ties_p2[dim] = (n_ties * (n_ties - 1.0) * (2 * n_ties + 5)).sum() + + return ties, ties_p1, ties_p2 + + +def _get_metric_metadata( + preds: Tensor, target: Tensor, variant: _MetricVariant +) -> Tuple[ + Tensor, + Tensor, + Optional[Tensor], + Optional[Tensor], + Optional[Tensor], + Optional[Tensor], + Optional[Tensor], + Optional[Tensor], + Tensor, +]: + """Obtain statistics to calculate metric value.""" + preds, target = _sort_on_first_sequence(preds, target) + + concordant_pairs = _count_concordant_pairs(preds, target) + discordant_pairs = _count_discordant_pairs(preds, target) + + n_total = torch.tensor(preds.shape[0], device=preds.device) + preds_ties = target_ties = None + preds_ties_p1 = preds_ties_p2 = target_ties_p1 = target_ties_p2 = None + if variant != _MetricVariant.A: + preds = _convert_sequence_to_dense_rank(preds) + target = _convert_sequence_to_dense_rank(target, sort=True) + preds_ties, preds_ties_p1, preds_ties_p2 = _get_ties(preds) + target_ties, target_ties_p1, target_ties_p2 = _get_ties(target) + return ( + concordant_pairs, + discordant_pairs, + preds_ties, + preds_ties_p1, + preds_ties_p2, + target_ties, + target_ties_p1, + target_ties_p2, + n_total, + ) + + +def _calculate_tau( + preds: Tensor, + target: Tensor, + concordant_pairs: Tensor, + discordant_pairs: Tensor, + con_min_dis_pairs: Tensor, + n_total: Tensor, + preds_ties: Optional[Tensor], + target_ties: Optional[Tensor], + variant: _MetricVariant, +) -> Tensor: + """Calculate Kendall's tau from metric metadata.""" + if variant == _MetricVariant.A: + tau = con_min_dis_pairs / 
(concordant_pairs + discordant_pairs) + elif variant == _MetricVariant.B: + total_combinations: Tensor = n_total * (n_total - 1) // 2 + denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) + tau = con_min_dis_pairs / torch.sqrt(denominator) + else: + preds_unique = torch.tensor([len(p.unique()) for p in preds.T], dtype=preds.dtype, device=preds.device) + target_unique = torch.tensor([len(t.unique()) for t in target.T], dtype=target.dtype, device=target.device) + min_classes = torch.minimum(preds_unique, target_unique) + tau = 2 * con_min_dis_pairs / ((min_classes - 1) / min_classes * n_total**2) + + return tau + + +def _get_p_value_for_t_value_from_dist(t_value: Tensor) -> Tensor: + """Obtain p-value for a given Tensor of t-values. Handle ``nan`` which cannot be passed into torch + distributions. + + When t-value is ``nan``, a resulted p-value should be alson ``nan``. + """ + device = t_value + normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]).to(device), torch.tensor([1.0]).to(device)) + + is_nan = t_value.isnan() + t_value = t_value.nan_to_num() + p_value = normal_dist.cdf(t_value) + return p_value.where(~is_nan, torch.tensor(float("nan"), dtype=p_value.dtype, device=p_value.device)) + + +def _calculate_p_value( + con_min_dis_pairs: Tensor, + n_total: Tensor, + preds_ties: Optional[Tensor], + preds_ties_p1: Optional[Tensor], + preds_ties_p2: Optional[Tensor], + target_ties: Optional[Tensor], + target_ties_p1: Optional[Tensor], + target_ties_p2: Optional[Tensor], + variant: _MetricVariant, + alternative: Optional[_TestAlternative], +) -> Tensor: + """Calculate p-value for Kendall's tau from metric metadata.""" + t_value_denominator_base = n_total * (n_total - 1) * (2 * n_total + 5) + if variant == _MetricVariant.A: + t_value = 3 * con_min_dis_pairs / torch.sqrt(t_value_denominator_base / 2) + else: + m = n_total * (n_total - 1) + t_value_denominator: Tensor = (t_value_denominator_base - preds_ties_p2 - 
target_ties_p2) / 18 + t_value_denominator += (2 * preds_ties * target_ties) / m # type: ignore + t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) # type: ignore + t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator) + + if alternative == _TestAlternative.TWO_SIDED: + t_value = torch.abs(t_value) + if alternative in [_TestAlternative.TWO_SIDED, _TestAlternative.GREATER]: + t_value *= -1 + p_value = _get_p_value_for_t_value_from_dist(t_value) + if alternative == _TestAlternative.TWO_SIDED: + p_value *= 2 + return p_value + + +def _kendall_corrcoef_update( + preds: Tensor, + target: Tensor, + concat_preds: Optional[List[Tensor]] = None, + concat_target: Optional[List[Tensor]] = None, + num_outputs: int = 1, +) -> Tuple[List[Tensor], List[Tensor]]: + """Update variables required to compute Kendall rank correlation coefficient. + + Args: + preds: Sequence of data + target: Sequence of data + concat_preds: List of batches of preds sequence to be concatenated; a new list is started when ``None`` + concat_target: List of batches of target sequence to be concatenated; a new list is started when ``None`` + num_outputs: Number of outputs in multioutput setting + + Raises: + RuntimeError: If ``preds`` and ``target`` do not have the same shape + ValueError: If the inputs have more than two dimensions or do not match ``num_outputs`` + """ + # Data checking + _check_same_shape(preds, target) + _check_data_shape_for_corr_coef(preds, target, num_outputs) + + if num_outputs == 1: + preds = preds.unsqueeze(1) + target = target.unsqueeze(1) + + # Build the lists per call: a mutable default would be shared between all calls that omit the argument + concat_preds = [] if concat_preds is None else concat_preds + concat_target = [] if concat_target is None else concat_target + concat_preds.append(preds) + concat_target.append(target) + + return concat_preds, concat_target + + +def _kendall_corrcoef_compute( + preds: Tensor, + target: Tensor, + variant: _MetricVariant, + alternative: Optional[_TestAlternative] = None, +) -> Tuple[Tensor, Optional[Tensor]]: + """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test. + + Args: + preds: Sequence of data + target: Sequence of data + variant: Indication of which variant of Kendall's tau to be used + alternative: Alternative hypothesis for t-test. 
Possible values: + - 'two-sided': the rank correlation is nonzero + - 'less': the rank correlation is negative (less than zero) + - 'greater': the rank correlation is positive (greater than zero) + """ + ( + concordant_pairs, + discordant_pairs, + preds_ties, + preds_ties_p1, + preds_ties_p2, + target_ties, + target_ties_p1, + target_ties_p2, + n_total, + ) = _get_metric_metadata(preds, target, variant) + con_min_dis_pairs = concordant_pairs - discordant_pairs + + tau = _calculate_tau( + preds, target, concordant_pairs, discordant_pairs, con_min_dis_pairs, n_total, preds_ties, target_ties, variant + ) + p_value = ( + _calculate_p_value( + con_min_dis_pairs, + n_total, + preds_ties, + preds_ties_p1, + preds_ties_p2, + target_ties, + target_ties_p1, + target_ties_p2, + variant, + alternative, + ) + if alternative + else None + ) + + # Squeeze tensor if num_outputs=1 + if tau.shape[0] == 1: + tau = tau.squeeze() + p_value = p_value.squeeze() if p_value is not None else None + + return tau.clamp(-1, 1), p_value + + +def kendall_rank_corrcoef( + preds: Tensor, + target: Tensor, + variant: Literal["a", "b", "c"] = "b", + t_test: bool = False, + alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided", +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + r"""Computes `Kendall Rank Correlation Coefficient`_. + + .. math:: + tau_a = \frac{C - D}{C + D} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs. + + .. math:: + tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents + a total number of ties. + + .. math:: + tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a total number + of observations and :math:`m` is a ``min`` of unique values in ``preds`` and ``target`` sequence. 
+ + Definitions according to Definition according to `The Treatment of Ties in Ranking Problems`_. + + Args: + preds: Sequence of data of either shape ``(N,)`` or ``(N,d)`` + target: Sequence of data of either shape ``(N,)`` or ``(N,d)`` + variant: Indication of which variant of Kendall's tau to be used + t_test: Indication whether to run t-test + alternative: Alternative hypothesis for t-test. Possible values: + - 'two-sided': the rank correlation is nonzero + - 'less': the rank correlation is negative (less than zero) + - 'greater': the rank correlation is positive (greater than zero) + + Return: + Correlation tau statistic + (Optional) p-value of corresponding statistical test (asymptotic) + + Raises: + ValueError: If ``t_test`` is not of a type bool + ValueError: If ``t_test=True`` and ``alternative=None`` + + Example (single output regression): + >>> from torchmetrics.functional.regression import kendall_rank_corrcoef + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = torch.tensor([3, -0.5, 2, 1]) + >>> kendall_rank_corrcoef(preds, target) + tensor(0.3333) + + Example (multi output regression): + >>> from torchmetrics.functional.regression import kendall_rank_corrcoef + >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) + >>> kendall_rank_corrcoef(preds, target) + tensor([1., 1.]) + + Example (single output regression with t-test) + >>> from torchmetrics.functional.regression import kendall_rank_corrcoef + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = torch.tensor([3, -0.5, 2, 1]) + >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided') + (tensor(0.3333), tensor(0.4969)) + + Example (multi output regression with t-test): + >>> from torchmetrics.functional.regression import kendall_rank_corrcoef + >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) + >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided') + 
(tensor([1., 1.]), tensor([nan, nan])) + """ + if not isinstance(t_test, bool): + raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") + if t_test and alternative is None: + raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.") + + _variant = _MetricVariant.from_str(variant) + _alternative = _TestAlternative.from_str(alternative) if t_test and alternative else None + + _preds, _target = _kendall_corrcoef_update( + preds, target, [], [], num_outputs=1 if preds.ndim == 1 else preds.shape[-1] + ) + tau, p_value = _kendall_corrcoef_compute(dim_zero_cat(_preds), dim_zero_cat(_target), _variant, _alternative) + + if p_value is not None: + return tau, p_value + return tau diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index 5273bbb8388..3ac970c28d3 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -16,6 +16,7 @@ import torch from torch import Tensor +from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef from torchmetrics.utilities.checks import _check_same_shape @@ -44,16 +45,7 @@ def _pearson_corrcoef_update( """ # Data checking _check_same_shape(preds, target) - if preds.ndim > 2 or target.ndim > 2: - raise ValueError( - f"Expected both predictions and target to be either 1- or 2-dimensional tensors," - f" but got {target.ndim} and {preds.ndim}." - ) - if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[-1]): - raise ValueError( - f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" - f" and {preds.ndim}." 
- ) + _check_data_shape_for_corr_coef(preds, target, num_outputs) n_obs = preds.shape[0] mx_new = (n_prior * mean_x + preds.mean(0) * n_obs) / (n_prior + n_obs) diff --git a/src/torchmetrics/functional/regression/spearman.py b/src/torchmetrics/functional/regression/spearman.py index 61c0fb25548..d20a2d01038 100644 --- a/src/torchmetrics/functional/regression/spearman.py +++ b/src/torchmetrics/functional/regression/spearman.py @@ -16,6 +16,7 @@ import torch from torch import Tensor +from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef from torchmetrics.utilities.checks import _check_same_shape @@ -67,16 +68,7 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) - "Expected `preds` and `target` both to be floating point tensors, but got {pred.dtype} and {target.dtype}" ) _check_same_shape(preds, target) - if preds.ndim > 2 or target.ndim > 2: - raise ValueError( - f"Expected both predictions and target to be either 1- or 2-dimensional tensors," - f" but got {target.ndim} and {preds.ndim}." - ) - if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[-1]): - raise ValueError( - f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" - f" and {preds.ndim}." - ) + _check_data_shape_for_corr_coef(preds, target, num_outputs) return preds, target diff --git a/src/torchmetrics/functional/regression/utils.py b/src/torchmetrics/functional/regression/utils.py new file mode 100644 index 00000000000..b9c80987ddc --- /dev/null +++ b/src/torchmetrics/functional/regression/utils.py @@ -0,0 +1,28 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torch import Tensor + + +def _check_data_shape_for_corr_coef(preds: Tensor, target: Tensor, num_outputs: int) -> None: + """Check that predictions and target have the correct shape, else raise error. + + Args: + preds: Predictions, expected to be a 1- or 2-dimensional tensor + target: Targets, expected to be a 1- or 2-dimensional tensor + num_outputs: Number of outputs the caller expects; ``1`` requires 1-dimensional ``preds`` + + Raises: + ValueError: If ``preds`` or ``target`` has more than two dimensions + ValueError: If ``num_outputs`` does not match the dimensionality of ``preds`` + """ + if preds.ndim > 2 or target.ndim > 2: + raise ValueError( + f"Expected both predictions and target to be either 1- or 2-dimensional tensors," + f" but got {preds.ndim} and {target.ndim}." + ) + # Inputs without a second dimension count as a single output, so a 1-D `preds` combined with + # `num_outputs > 1` raises the intended ValueError instead of an IndexError from `preds.shape[1]` + n_second_dim = preds.shape[1] if preds.ndim == 2 else 1 + if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != n_second_dim): + raise ValueError( + f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" + f" and {n_second_dim}." 
+ ) diff --git a/src/torchmetrics/regression/__init__.py b/src/torchmetrics/regression/__init__.py index 71ded88a5db..eecb56a452f 100644 --- a/src/torchmetrics/regression/__init__.py +++ b/src/torchmetrics/regression/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.regression.concordance import ConcordanceCorrCoef # noqa: F401 from torchmetrics.regression.cosine_similarity import CosineSimilarity # noqa: F401 from torchmetrics.regression.explained_variance import ExplainedVariance # noqa: F401 +from torchmetrics.regression.kendall import KendallRankCorrCoef # noqa: F401 from torchmetrics.regression.kl_divergence import KLDivergence # noqa: F401 from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401 from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401 diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py new file mode 100644 index 00000000000..43abfbbac83 --- /dev/null +++ b/src/torchmetrics/regression/kendall.py @@ -0,0 +1,163 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, List, Optional, Tuple, Union + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.kendall import ( + _kendall_corrcoef_compute, + _kendall_corrcoef_update, + _MetricVariant, + _TestAlternative, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat + + +class KendallRankCorrCoef(Metric): + r"""Computes `Kendall Rank Correlation Coefficient`_: + + .. math:: + tau_a = \frac{C - D}{C + D} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs. + + .. math:: + tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents + a total number of ties. + + .. math:: + tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a total number + of observations and :math:`m` is a ``min`` of unique values in ``preds`` and ``target`` sequence. + + Definitions according to Definition according to `The Treatment of Ties in Ranking Problems`_. + + Forward accepts + + - ``preds`` (float tensor): Sequence of data of either shape ``(N,)`` or ``(N,d)`` + - ``target`` (float tensor): Sequence of data of either shape ``(N,)`` or ``(N,d)`` + + Args: + variant: Indication of which variant of Kendall's tau to be used + t_test: Indication whether to run t-test + alternative: Alternative hypothesis for t-test. Possible values: + - 'two-sided': the rank correlation is nonzero + - 'less': the rank correlation is negative (less than zero) + - 'greater': the rank correlation is positive (greater than zero) + num_outputs: Number of outputs in multioutput setting + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ + Raises: + ValueError: If ``t_test`` is not of a type bool + ValueError: If ``t_test=True`` and ``alternative=None`` + + Example (single output regression): + >>> import torch + >>> from torchmetrics.regression import KendallRankCorrCoef + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = torch.tensor([3, -0.5, 2, 1]) + >>> kendall = KendallRankCorrCoef() + >>> kendall(preds, target) + tensor(0.3333) + + Example (multi output regression): + >>> import torch + >>> from torchmetrics.regression import KendallRankCorrCoef + >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) + >>> kendall = KendallRankCorrCoef(num_outputs=2) + >>> kendall(preds, target) + tensor([1., 1.]) + + Example (single output regression with t-test): + >>> import torch + >>> from torchmetrics.regression import KendallRankCorrCoef + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = torch.tensor([3, -0.5, 2, 1]) + >>> kendall = KendallRankCorrCoef(t_test=True, alternative='two-sided') + >>> kendall(preds, target) + (tensor(0.3333), tensor(0.4969)) + + Example (multi output regression with t-test): + >>> import torch + >>> from torchmetrics.regression import KendallRankCorrCoef + >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) + >>> kendall = KendallRankCorrCoef(t_test=True, alternative='two-sided', num_outputs=2) + >>> kendall(preds, target) + (tensor([1., 1.]), tensor([nan, nan])) + """ + + is_differentiable = False + higher_is_better = None + full_state_update = True + preds: List[Tensor] + target: List[Tensor] + + def __init__( + self, + variant: Literal["a", "b", "c"] = "b", + t_test: bool = False, + alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided", + num_outputs: int = 1, + **kwargs: Any, + ): + super().__init__(**kwargs) + if not isinstance(t_test, bool): + raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") + 
if t_test and alternative is None: + raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.") + + self.variant = _MetricVariant.from_str(variant) + self.alternative = _TestAlternative.from_str(alternative) if t_test and alternative else None + self.num_outputs = num_outputs + + self.add_state("preds", [], dist_reduce_fx="cat") + self.add_state("target", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update variables required to compute Kendall rank correlation coefficient. + + Args: + preds: Sequence of data of either shape ``(N,)`` or ``(N,d)`` + target: Sequence of data of either shape ``(N,)`` or ``(N,d)`` + """ + self.preds, self.target = _kendall_corrcoef_update( + preds, + target, + self.preds, + self.target, + num_outputs=self.num_outputs, + ) + + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test. + + Return: + Correlation tau statistic + (Optional) p-value of corresponding statistical test (asymptotic) + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + tau, p_value = _kendall_corrcoef_compute(preds, target, self.variant, self.alternative) + + if p_value is not None: + return tau, p_value + return tau diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py new file mode 100644 index 00000000000..3bd6c9379e2 --- /dev/null +++ b/tests/unittests/regression/test_kendall.py @@ -0,0 +1,124 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from collections import namedtuple
from functools import partial

import pytest
import torch
from scipy.stats import kendalltau

from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef
from torchmetrics.regression.kendall import KendallRankCorrCoef
from torchmetrics.utilities.imports import _compare_version
from unittests.helpers import seed_all
from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester

# Single source of truth for the scipy feature gate used by the helpers below.
_SCIPY_GREATER_EQUAL_1_8 = _compare_version("scipy", operator.ge, "1.8.0")

# Seed before the module-level random inputs are drawn so they are reproducible.
seed_all(42)

Input = namedtuple("Input", ["preds", "target"])
_single_inputs1 = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE))
_single_inputs2 = Input(preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randn(NUM_BATCHES, BATCH_SIZE))
# Integer draws from a small range -- presumably to exercise tie handling.
_single_inputs3 = Input(
    preds=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE)), target=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE))
)
_multi_inputs1 = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)
)
_multi_inputs2 = Input(
    preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)
)
_multi_inputs3 = Input(
    preds=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
    target=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
)


def _sk_metric(preds, target, alternative, variant):
    """Compute the reference Kendall tau (and optionally its p-value) with ``scipy.stats.kendalltau``.

    Args:
        preds: Tensor of predictions; either 1-D, or 2-D with one output per column.
        target: Tensor of targets laid out like ``preds``.
        alternative: Alternative hypothesis of the test, or ``None``. It is only forwarded to scipy when
            scipy>=1.8.0 is installed; ``None`` is forwarded as ``"two-sided"``.
        variant: Tau variant forwarded to scipy.

    Returns:
        ``tau`` when ``alternative`` is ``None``, otherwise the tuple ``(tau, p_value)``.
    """
    scipy_kwargs = {"method": "asymptotic", "variant": variant}
    if _SCIPY_GREATER_EQUAL_1_8:
        scipy_kwargs["alternative"] = alternative or "two-sided"  # scipy cannot accept `None`

    if preds.ndim == 2:
        # One scipy call per output column; the scalar results are stacked into 1-D tensors.
        results = [kendalltau(p.numpy(), t.numpy(), **scipy_kwargs) for p, t in zip(preds.T, target.T)]
        tau = torch.stack([torch.tensor(res[0]) for res in results])
        p_value = torch.stack([torch.tensor(res[1]) for res in results])
    else:
        tau, p_value = kendalltau(preds.numpy(), target.numpy(), **scipy_kwargs)
        tau, p_value = torch.tensor(tau), torch.tensor(p_value)

    # The p-value is only part of the output when a test was requested.
    if alternative is not None:
        return tau, p_value
    return tau


@pytest.mark.parametrize(
    "preds, target, alternative",
    [
        (_single_inputs1.preds, _single_inputs1.target, None),
        (_single_inputs2.preds, _single_inputs2.target, "less"),
        (_single_inputs3.preds, _single_inputs3.target, "greater"),
        (_multi_inputs1.preds, _multi_inputs1.target, None),
        (_multi_inputs2.preds, _multi_inputs2.target, "two-sided"),
        (_multi_inputs3.preds, _multi_inputs3.target, "greater"),
    ],
)
@pytest.mark.parametrize("variant", ["b", "c"])
class TestKendallRankCorrCoef(MetricTester):
    """Run ``KendallRankCorrCoef`` and ``kendall_rank_corrcoef`` through the ``MetricTester`` checks.

    Every test receives ``preds``, ``target``, ``alternative`` and ``variant`` from the class-level
    parametrization; ``alternative=None`` means that no statistical test is requested.
    """

    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_kendall_rank_corrcoef(self, preds, target, alternative, variant, ddp, dist_sync_on_step):
        """Hand the module metric and the scipy-based reference to ``run_class_metric_test``."""
        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
        t_test = alternative is not None
        _sk_kendall_tau = partial(_sk_metric, alternative=alternative, variant=variant)
        # Older scipy only tests two-sided, so the metric under test is configured to match it.
        alternative = _adjust_alternative_to_scipy(alternative)

        self.run_class_metric_test(
            ddp,
            preds,
            target,
            KendallRankCorrCoef,
            _sk_kendall_tau,
            dist_sync_on_step,
            metric_args={
                "t_test": t_test,
                "alternative": alternative,
                "variant": variant,
                "num_outputs": num_outputs,
            },
        )

    def test_kendall_rank_corrcoef_functional(self, preds, target, alternative, variant):
        """Hand the functional metric and the scipy-based reference to ``run_functional_metric_test``."""
        t_test = alternative is not None
        _sk_kendall_tau = partial(_sk_metric, alternative=alternative, variant=variant)
        # Older scipy only tests two-sided, so the metric under test is configured to match it.
        alternative = _adjust_alternative_to_scipy(alternative)
        metric_args = {"t_test": t_test, "alternative": alternative, "variant": variant}
        self.run_functional_metric_test(preds, target, kendall_rank_corrcoef, _sk_kendall_tau, metric_args=metric_args)

    def test_kendall_rank_corrcoef_differentiability(self, preds, target, alternative, variant):
        """Hand the module and functional metric to ``run_differentiability_test``.

        ``alternative`` and ``variant`` are supplied by the class-level parametrization but are not used here.
        """
        num_outputs = EXTRA_DIM if preds.ndim == 3 else 1
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=partial(KendallRankCorrCoef, num_outputs=num_outputs),
            metric_functional=kendall_rank_corrcoef,
        )


def _adjust_alternative_to_scipy(alternative):
    """Map ``alternative`` to what the installed scipy can test.

    Scipy<1.8.0 supports only two-sided hypothesis testing, so any requested alternative is replaced by
    ``"two-sided"`` there. ``None`` is always returned unchanged.
    """
    if alternative is not None and not _SCIPY_GREATER_EQUAL_1_8:
        alternative = "two-sided"
    return alternative