From 260efcfb043b01a465fa54d91ae90d10e31a07ba Mon Sep 17 00:00:00 2001 From: stancld Date: Sat, 8 Oct 2022 17:41:37 +0200 Subject: [PATCH 01/38] wip --- CHANGELOG.md | 2 +- src/torchmetrics/functional/__init__.py | 2 + .../functional/regression/__init__.py | 1 + .../functional/regression/kendall.py | 152 ++++++++++++++++++ src/torchmetrics/regression/kendall.py | 0 5 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 src/torchmetrics/functional/regression/kendall.py create mode 100644 src/torchmetrics/regression/kendall.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 29792ba357b..c88f1b7b164 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `TotalVariation` to image package ([#978](https://github.com/Lightning-AI/metrics/pull/978)) - +- Added `KendallRankCorrCoef` to regression package ([]()) ### Changed diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 44245a2d6cd..28d59762ea0 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -55,6 +55,7 @@ from torchmetrics.functional.regression.concordance import concordance_corrcoef from torchmetrics.functional.regression.cosine_similarity import cosine_similarity from torchmetrics.functional.regression.explained_variance import explained_variance +from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef from torchmetrics.functional.regression.kl_divergence import kl_divergence from torchmetrics.functional.regression.log_mse import mean_squared_log_error from torchmetrics.functional.regression.mae import mean_absolute_error @@ -120,6 +121,7 @@ "hinge_loss", "image_gradients", "jaccard_index", + "kendall_rank_corrcoef", "kl_divergence", "label_ranking_average_precision", "label_ranking_loss", diff --git a/src/torchmetrics/functional/regression/__init__.py b/src/torchmetrics/functional/regression/__init__.py index 43dead1a32f..852db0bf196 100644 --- a/src/torchmetrics/functional/regression/__init__.py +++ b/src/torchmetrics/functional/regression/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.functional.regression.concordance import concordance_corrcoef # noqa: F401 from torchmetrics.functional.regression.cosine_similarity import cosine_similarity # noqa: F401 from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401 +from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef # noqa: F401 from torchmetrics.functional.regression.kl_divergence import kl_divergence # noqa: F401 from torchmetrics.functional.regression.log_mse import mean_squared_log_error # noqa: F401 from torchmetrics.functional.regression.mae import mean_absolute_error # noqa: F401 diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py new file mode 100644 index 00000000000..6a7e00588fc --- /dev/null +++ b/src/torchmetrics/functional/regression/kendall.py @@ -0,0 +1,152 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.data import _bincount + + +def _sort_on_first_sequence(x: Tensor, y: Tensor, stable: bool) -> Tuple[Tensor, Tensor]: + x, perm = x.sort(stable=stable) + for i in range(x.shape[0]): + y[i] = y[i][perm[i]] + return x, y + + +def _convert_sequence_to_dense_rank(x: Tensor) -> Tensor: + _ones = torch.ones(x.shape[0], 1, dtype=torch.int32, device=x.device) + return torch.cat([_ones, (x[:, :1] != x[:, -1:]).int()], dim=1).cumsum(1) + + +def _count_discordant_pairs(preds: Tensor, target: Tensor) -> Tensor: + """Count a total number of discordant pairs in given sequences.""" + pass + + +def _count_rank_ties(x: Tensor) -> Tensor: + """Count a total number of ties in a given sequence.""" + ties = _bincount(x) + ties = ties[ties > 1] + return (ties * (ties - 1) // 2).sum() + + +def _kendall_corrcoef_update( + preds: Tensor, + target: Tensor, + discordant_pairs: Tensor, + total_pairs: Tensor, + preds_ties: Tensor, + target_ties: Tensor, + joint_ties: Tensor, + variant: Literal["a", "b", "c"], + num_outputs: int, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Update variables required to compute Kendall rank correlation coefficient. + + Check for the same shape of input tensors + + Args: + preds: Ordered sequence of data + target: Ordered sequence of data + + Raises: + RuntimeError: If ``preds`` and ``target`` do not have the same shape + """ + # Data checking + _check_same_shape(preds, target) + if preds.ndim > 2 or target.ndim > 2: + raise ValueError( + f"Expected both predictions and target to be either 1- or 2-dimensional tensors," + f" but got {target.ndim} and {preds.ndim}." + ) + if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[-1]): + raise ValueError( + f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" + f" and {preds.ndim}." + ) + if num_outputs == 1: + preds = preds.unsqueeze(0) + target = target.unsqueeze(0) + + # Sort on target and convert it to dense rank + target, preds = _sort_on_first_sequence(target, preds, stable=False) + target = _convert_sequence_to_dense_rank(target) + + # Sort on preds and convert it to dense rank + preds, target = _sort_on_first_sequence(preds, target, stable=True) + preds = _convert_sequence_to_dense_rank(preds) + + discordant_pairs += _count_discordant_pairs(preds, target) + + +def _kendall_corrcoef_compute( + concordant_pairs: Tensor, + discordant_pairs: Tensor, + total_pairs: Tensor, + variant: Literal["a", "b", "c"] = "b", +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + pass + + +def kendall_rank_corrcoef( + preds: Tensor, target: Tensor, variant: Literal["a", "b", "c"] = "b" +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Compute Kendall rank correlation coefficient, commonly also known as Kendall's tau. 
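As a reference for the docstring above, here is a minimal self-contained sketch of the quantity being computed. For two sequences without ties, every pair of indices is either concordant (both sequences order the pair the same way) or discordant, and Kendall's tau is the difference between those counts divided by the total number of pairs. The toy values below are the ones used in the doctests later in this series (preds = [2.5, 0.0, 2, 8], target = [3, -0.5, 2, 1]); nothing here depends on the torchmetrics code itself.

```python
# Naive O(n^2) pair counting, independent of the implementation in this patch.
import itertools

preds = [2.5, 0.0, 2.0, 8.0]
target = [3.0, -0.5, 2.0, 1.0]

concordant = discordant = 0
for (p_i, t_i), (p_j, t_j) in itertools.combinations(zip(preds, target), 2):
    s = (p_i - p_j) * (t_i - t_j)
    if s > 0:
        concordant += 1  # both sequences rank the pair the same way
    elif s < 0:
        discordant += 1  # the sequences disagree on the pair
    # s == 0 would mean a tie; this toy sample has none

n = len(preds)
tau = (concordant - discordant) / (n * (n - 1) / 2)
print(concordant, discordant, tau)  # 4 2 0.333...
```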
+ + Args: + preds: Ordered sequence of data + target: Ordered sequence of data + variant: Indication of which variant of test to be used + + Return: + Correlation tau statistic + + Raises: + ValueError: If ``variant`` is not from ``['a', 'b', 'c']`` + + Example (single output regression): + >>> from torchmetrics.functional import kendall_corrcoef + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> kendal_rank_corrcoef(preds, target) + + Example (multi output regression): + >>> from torchmetrics.functional import kendall_corrcoef + >>> target = torch.tensor([[3, -0.5], [2, 7]]) + >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> kendal_rank_corrcoef(preds, target) + """ + if variant not in ["a", "b", "c"]: + raise ValueError(f"Argument `variant` is expected to be one of ['a', 'b', 'c'], but got {variant!r}.") + d = preds.shape[1] if preds.ndim == 2 else 1 + _temp = torch.zeros(d, dtype=preds.dtype, device=preds.device) + concordant_pairs, discordant_pairs, total_pairs = _temp.clone(), _temp.clone(), _temp.clone() + if variant == "b": + preds_tied_values, target_tied_values = _temp.clone(), _temp.clone() + else: + preds_tied_values = target_tied_values = None + + _kendall_corrcoef_update( + preds, + target, + concordant_pairs, + discordant_pairs, + total_pairs, + num_outputs=1 if preds.ndim == 1 else preds.shape[-1], + ) diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py new file mode 100644 index 00000000000..e69de29bb2d From 724e11a037b6b4fb6a4039c8aaaee19e3db26580 Mon Sep 17 00:00:00 2001 From: stancld Date: Sat, 15 Oct 2022 17:11:24 +0200 Subject: [PATCH 02/38] WIP: Overhaul --- .../functional/regression/kendall.py | 109 ++++++++++++------ 1 file changed, 75 insertions(+), 34 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 6a7e00588fc..8f078b40bfb 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch from torch import Tensor @@ -22,41 +22,63 @@ from torchmetrics.utilities.data import _bincount -def _sort_on_first_sequence(x: Tensor, y: Tensor, stable: bool) -> Tuple[Tensor, Tensor]: - x, perm = x.sort(stable=stable) +def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + x, perm = x.sort(stable=False) for i in range(x.shape[0]): y[i] = y[i][perm[i]] return x, y -def _convert_sequence_to_dense_rank(x: Tensor) -> Tensor: - _ones = torch.ones(x.shape[0], 1, dtype=torch.int32, device=x.device) - return torch.cat([_ones, (x[:, :1] != x[:, -1:]).int()], dim=1).cumsum(1) +def _count_concordant_pairs(preds: Tensor, target: Tensor) -> Tensor: + """Count a total number of concordant pairs in given sequences.""" + + def _concordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: + return torch.logical_and(x[i] < x[(i + 1) :], y[i] < y[(i + 1) :]).sum(0).unsqueeze(0) + + return torch.cat([_concordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0) def _count_discordant_pairs(preds: Tensor, target: Tensor) -> Tensor: """Count a total number of discordant pairs in given sequences.""" - pass + + def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: + return ( + torch.logical_or( + torch.logical_and(x[i] > x[(i + 1) :], y[i] < y[(i + 1) :]), + torch.logical_and(x[i] < x[(i + 1) :], y[i] > y[(i + 1) :]), + ) + .sum(0) + .unsqueeze(0) + ) + + return torch.cat([_discordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0) + + +def _convert_sequence_to_dense_rank(x: Tensor) -> Tensor: + _ones = torch.zeros(1, x.shape[1], dtype=torch.int32, device=x.device) + return torch.cat([_ones, (x[1:] != x[:-1]).int()], dim=0).cumsum(0) -def _count_rank_ties(x: Tensor) -> Tensor: - """Count a total number of ties in a given sequence.""" - ties = _bincount(x) - ties = ties[ties > 1] - return (ties * (ties - 1) // 2).sum() +def _get_ties(x: Tensor) -> Tensor: + ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) + for dim in range(x.shape[1]): + n_ties = _bincount(x[:, dim]) + n_ties = n_ties[n_ties > 1] + ties[dim] = (n_ties * (n_ties - 1) // 2).sum() + return ties def _kendall_corrcoef_update( preds: Tensor, target: Tensor, + concordant_pairs: Tensor, discordant_pairs: Tensor, - total_pairs: Tensor, - preds_ties: Tensor, - target_ties: Tensor, - joint_ties: Tensor, - variant: Literal["a", "b", "c"], + preds_ties: Optional[Tensor], + target_ties: Optional[Tensor], + total: Optional[Tensor], num_outputs: int, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + variant: Literal["a", "b", "c"] = "b", +) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: """Update variables required to compute Kendall rank correlation coefficient. 
Check for the same shape of input tensors @@ -85,23 +107,38 @@ def _kendall_corrcoef_update( target = target.unsqueeze(0) # Sort on target and convert it to dense rank - target, preds = _sort_on_first_sequence(target, preds, stable=False) - target = _convert_sequence_to_dense_rank(target) - - # Sort on preds and convert it to dense rank - preds, target = _sort_on_first_sequence(preds, target, stable=True) - preds = _convert_sequence_to_dense_rank(preds) + preds, target = _sort_on_first_sequence(preds, target) + preds, target = preds.T, target.T # [num_outputs, seq_len] + concordant_pairs += _count_concordant_pairs(preds, target) discordant_pairs += _count_discordant_pairs(preds, target) + if variant == "b": + preds = _convert_sequence_to_dense_rank(preds) + target = _convert_sequence_to_dense_rank(target) + preds_ties += _get_ties(preds) + target_ties += _get_ties(target) + total += preds.shape[0] + + return concordant_pairs, discordant_pairs, preds_ties, target_ties + def _kendall_corrcoef_compute( concordant_pairs: Tensor, discordant_pairs: Tensor, - total_pairs: Tensor, - variant: Literal["a", "b", "c"] = "b", + preds_ties: Optional[Tensor], + target_ties: Optional[Tensor], + total: Optional[Tensor], + variant: Literal["a", "b", "c"] = "a", ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - pass + con_min_dis_pairs = concordant_pairs - discordant_pairs + + if variant == "a": + return con_min_dis_pairs / (concordant_pairs + discordant_pairs) + if variant == "b": + combinations = total * (total - 1) // 2 + denominator = (combinations - preds_ties) * (combinations - target_ties) + return con_min_dis_pairs / torch.sqrt(denominator) def kendall_rank_corrcoef( @@ -121,13 +158,13 @@ def kendall_rank_corrcoef( ValueError: If ``variant`` is not from ``['a', 'b', 'c']`` Example (single output regression): - >>> from torchmetrics.functional import kendall_corrcoef + >>> from torchmetrics.functional.regression import kendal_rank_corrcoef >>> target = torch.tensor([3, -0.5, 2, 7]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) >>> kendal_rank_corrcoef(preds, target) Example (multi output regression): - >>> from torchmetrics.functional import kendall_corrcoef + >>> from torchmetrics.functional.regression import kendal_rank_corrcoef >>> target = torch.tensor([[3, -0.5], [2, 7]]) >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) >>> kendal_rank_corrcoef(preds, target) @@ -136,17 +173,21 @@ def kendall_rank_corrcoef( raise ValueError(f"Argument `variant` is expected to be one of ['a', 'b', 'c'], but got {variant!r}.") d = preds.shape[1] if preds.ndim == 2 else 1 _temp = torch.zeros(d, dtype=preds.dtype, device=preds.device) - concordant_pairs, discordant_pairs, total_pairs = _temp.clone(), _temp.clone(), _temp.clone() + concordant_pairs, discordant_pairs = _temp.clone(), _temp.clone() if variant == "b": - preds_tied_values, target_tied_values = _temp.clone(), _temp.clone() + preds_ties, target_ties, total = _temp.clone(), _temp.clone(), _temp.clone() else: - preds_tied_values = target_tied_values = None + preds_ties = target_ties = total = None - _kendall_corrcoef_update( + concordant_pairs, discordant_pairs, preds_ties, target_ties = _kendall_corrcoef_update( preds, target, concordant_pairs, discordant_pairs, - total_pairs, + preds_ties, + target_ties, + total, num_outputs=1 if preds.ndim == 1 else preds.shape[-1], + variant=variant, ) + return _kendall_corrcoef_compute(concordant_pairs, discordant_pairs, preds_ties, target_ties, total, variant) From d564fa618073bbe8a6fe0b098acf4a1cdb3fe47b Mon Sep 17 00:00:00 2001 
From: stancld Date: Sat, 15 Oct 2022 17:39:14 +0200 Subject: [PATCH 03/38] Do some polishing --- .../functional/regression/kendall.py | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 8f078b40bfb..7513cebcbda 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -23,6 +23,7 @@ def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: + """Sort sequences in an ascent order according to the sequence ``x``.""" x, perm = x.sort(stable=False) for i in range(x.shape[0]): y[i] = y[i][perm[i]] @@ -55,11 +56,13 @@ def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: def _convert_sequence_to_dense_rank(x: Tensor) -> Tensor: + """Convert a sequence to the rank tensor.""" _ones = torch.zeros(1, x.shape[1], dtype=torch.int32, device=x.device) return torch.cat([_ones, (x[1:] != x[:-1]).int()], dim=0).cumsum(0) def _get_ties(x: Tensor) -> Tensor: + """Get number of ties in a given sequence.""" ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) for dim in range(x.shape[1]): n_ties = _bincount(x[:, dim]) @@ -78,15 +81,9 @@ def _kendall_corrcoef_update( total: Optional[Tensor], num_outputs: int, variant: Literal["a", "b", "c"] = "b", -) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: +) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: """Update variables required to compute Kendall rank correlation coefficient. - Check for the same shape of input tensors - - Args: - preds: Ordered sequence of data - target: Ordered sequence of data - Raises: RuntimeError: If ``preds`` and ``target`` do not have the same shape """ @@ -120,7 +117,7 @@ def _kendall_corrcoef_update( target_ties += _get_ties(target) total += preds.shape[0] - return concordant_pairs, discordant_pairs, preds_ties, target_ties + return concordant_pairs, discordant_pairs, preds_ties, target_ties, total def _kendall_corrcoef_compute( @@ -129,16 +126,21 @@ def _kendall_corrcoef_compute( preds_ties: Optional[Tensor], target_ties: Optional[Tensor], total: Optional[Tensor], - variant: Literal["a", "b", "c"] = "a", + variant: Literal["a", "b", "c"], ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Compute the value of Kendall rank correlation coefficient given pre-computed state variables.""" con_min_dis_pairs = concordant_pairs - discordant_pairs if variant == "a": - return con_min_dis_pairs / (concordant_pairs + discordant_pairs) - if variant == "b": - combinations = total * (total - 1) // 2 - denominator = (combinations - preds_ties) * (combinations - target_ties) - return con_min_dis_pairs / torch.sqrt(denominator) + tau = con_min_dis_pairs / (concordant_pairs + discordant_pairs) + elif variant == "b": + total_combinations = total * (total - 1) // 2 + denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) + tau = con_min_dis_pairs / torch.sqrt(denominator) + else: + tau = 2 * con_min_dis_pairs / (total**2) + + return tau.clamp(-1, 1) def kendall_rank_corrcoef( @@ -149,7 +151,7 @@ def kendall_rank_corrcoef( Args: preds: Ordered sequence of data target: Ordered sequence of data - variant: Indication of which variant of test to be used + variant: Indication of which variant of Kendall's tau to be used Return: Correlation tau statistic @@ -159,19 +161,21 @@ def kendall_rank_corrcoef( Example (single output regression): >>> from 
torchmetrics.functional.regression import kendal_rank_corrcoef - >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> target = torch.tensor([3, -0.5, 2, 1]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) >>> kendal_rank_corrcoef(preds, target) + tensor([0.3333]) Example (multi output regression): >>> from torchmetrics.functional.regression import kendal_rank_corrcoef - >>> target = torch.tensor([[3, -0.5], [2, 7]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) >>> kendal_rank_corrcoef(preds, target) + tensor([ 1., -1.]) """ if variant not in ["a", "b", "c"]: raise ValueError(f"Argument `variant` is expected to be one of ['a', 'b', 'c'], but got {variant!r}.") - d = preds.shape[1] if preds.ndim == 2 else 1 + d = preds.shape[0] if preds.ndim == 2 else 1 _temp = torch.zeros(d, dtype=preds.dtype, device=preds.device) concordant_pairs, discordant_pairs = _temp.clone(), _temp.clone() if variant == "b": @@ -179,7 +183,7 @@ def kendall_rank_corrcoef( else: preds_ties = target_ties = total = None - concordant_pairs, discordant_pairs, preds_ties, target_ties = _kendall_corrcoef_update( + concordant_pairs, discordant_pairs, preds_ties, target_ties, total = _kendall_corrcoef_update( preds, target, concordant_pairs, From 60184649319baf8a79c27c2865f49ac3717395e6 Mon Sep 17 00:00:00 2001 From: stancld Date: Sat, 15 Oct 2022 19:44:22 +0200 Subject: [PATCH 04/38] WIP: Add class metric + another refactor --- src/torchmetrics/__init__.py | 2 + .../functional/regression/kendall.py | 109 +++++++++--------- src/torchmetrics/regression/__init__.py | 1 + src/torchmetrics/regression/kendall.py | 100 ++++++++++++++++ 4 files changed, 157 insertions(+), 55 deletions(-) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 8aa046a9938..12264137594 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -64,6 +64,7 @@ ConcordanceCorrCoef, CosineSimilarity, ExplainedVariance, + KendallRankCorrCoef, KLDivergence, MeanAbsoluteError, MeanAbsolutePercentageError, @@ -141,6 +142,7 @@ "HammingDistance", "HingeLoss", "JaccardIndex", + "KendallRankCorrCoef", "KLDivergence", "LabelRankingAveragePrecision", "LabelRankingLoss", diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 7513cebcbda..bd9c4020714 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import Tensor @@ -24,6 +24,8 @@ def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: """Sort sequences in an ascent order according to the sequence ``x``.""" + # We need to clone `y` tensor not to change it in memory + y = torch.clone(y) x, perm = x.sort(stable=False) for i in range(x.shape[0]): y[i] = y[i][perm[i]] @@ -71,17 +73,40 @@ def _get_ties(x: Tensor) -> Tensor: return ties -def _kendall_corrcoef_update( - preds: Tensor, - target: Tensor, - concordant_pairs: Tensor, - discordant_pairs: Tensor, - preds_ties: Optional[Tensor], - target_ties: Optional[Tensor], - total: Optional[Tensor], - num_outputs: int, - variant: Literal["a", "b", "c"] = "b", +def _dim_one_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: + """Concatenation along the one dimension.""" + x = x if isinstance(x, (list, tuple)) else [x] + if not x: # empty list + raise ValueError("No samples to concatenate") + return torch.cat(x, dim=1) + + +def _get_metric_metadata( + preds: Tensor, target: Tensor, variant: Literal["a", "b", "c"] ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + """Obtain statistics to calculate metric value.""" + # Sort on target and convert it to dense rank + preds, target = _sort_on_first_sequence(preds, target) + preds, target = preds.T, target.T + + concordant_pairs = _count_concordant_pairs(preds, target) + discordant_pairs = _count_discordant_pairs(preds, target) + + if variant == "b": + preds = _convert_sequence_to_dense_rank(preds) + target = _convert_sequence_to_dense_rank(target) + preds_ties = _get_ties(preds) + target_ties = _get_ties(target) + n_total = preds.shape[0] + else: + preds_ties = target_ties = n_total = None + + return concordant_pairs, discordant_pairs, preds_ties, target_ties, n_total + + +def _kendall_corrcoef_update( + preds: Tensor, target: Tensor, concat_preds: List[Tensor], concat_target: List[Tensor], num_outputs: int +) -> Tuple[List[Tensor], List[Tensor]]: """Update variables required to compute Kendall rank correlation coefficient. 
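The new signature above reflects the refactored design: ``update`` now only validates shapes and appends the raw batches, while all pair counting is deferred to ``compute`` over the concatenated data. A small sketch of the property this relies on, written against the class interface added later in this same patch (not the actual test suite): feeding the data batch by batch must give the same result as one call on the full tensors, because Kendall's tau cannot be accumulated per batch.

```python
import torch

from torchmetrics.regression import KendallRankCorrCoef

full_preds, full_target = torch.randn(40), torch.randn(40)

metric = KendallRankCorrCoef()  # defaults: variant="b", num_outputs=1
for p, t in zip(full_preds.chunk(4), full_target.chunk(4)):
    metric.update(p, t)  # only stores the batch, no pair counting yet

print(metric.compute())                                # counting happens here
print(KendallRankCorrCoef()(full_preds, full_target))  # should match the accumulated result
```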
Raises: @@ -103,42 +128,29 @@ def _kendall_corrcoef_update( preds = preds.unsqueeze(0) target = target.unsqueeze(0) - # Sort on target and convert it to dense rank - preds, target = _sort_on_first_sequence(preds, target) - preds, target = preds.T, target.T # [num_outputs, seq_len] - - concordant_pairs += _count_concordant_pairs(preds, target) - discordant_pairs += _count_discordant_pairs(preds, target) + concat_preds.append(preds) + concat_target.append(target) - if variant == "b": - preds = _convert_sequence_to_dense_rank(preds) - target = _convert_sequence_to_dense_rank(target) - preds_ties += _get_ties(preds) - target_ties += _get_ties(target) - total += preds.shape[0] - - return concordant_pairs, discordant_pairs, preds_ties, target_ties, total + return concat_preds, concat_target def _kendall_corrcoef_compute( - concordant_pairs: Tensor, - discordant_pairs: Tensor, - preds_ties: Optional[Tensor], - target_ties: Optional[Tensor], - total: Optional[Tensor], + preds: Tensor, + target: Tensor, variant: Literal["a", "b", "c"], ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Compute the value of Kendall rank correlation coefficient given pre-computed state variables.""" + concordant_pairs, discordant_pairs, preds_ties, target_ties, n_total = _get_metric_metadata(preds, target, variant) con_min_dis_pairs = concordant_pairs - discordant_pairs if variant == "a": tau = con_min_dis_pairs / (concordant_pairs + discordant_pairs) elif variant == "b": - total_combinations = total * (total - 1) // 2 + total_combinations = n_total * (n_total - 1) // 2 denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) tau = con_min_dis_pairs / torch.sqrt(denominator) else: - tau = 2 * con_min_dis_pairs / (total**2) + tau = 2 * con_min_dis_pairs / (n_total**2) return tau.clamp(-1, 1) @@ -160,38 +172,25 @@ def kendall_rank_corrcoef( ValueError: If ``variant`` is not from ``['a', 'b', 'c']`` Example (single output regression): - >>> from torchmetrics.functional.regression import kendal_rank_corrcoef + >>> from torchmetrics.functional.regression import kendall_rank_corrcoef >>> target = torch.tensor([3, -0.5, 2, 1]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) - >>> kendal_rank_corrcoef(preds, target) + >>> kendall_rank_corrcoef(preds, target) tensor([0.3333]) Example (multi output regression): - >>> from torchmetrics.functional.regression import kendal_rank_corrcoef + >>> from torchmetrics.functional.regression import kendall_rank_corrcoef >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) - >>> kendal_rank_corrcoef(preds, target) + >>> kendall_rank_corrcoef(preds, target) tensor([ 1., -1.]) """ if variant not in ["a", "b", "c"]: raise ValueError(f"Argument `variant` is expected to be one of ['a', 'b', 'c'], but got {variant!r}.") - d = preds.shape[0] if preds.ndim == 2 else 1 - _temp = torch.zeros(d, dtype=preds.dtype, device=preds.device) - concordant_pairs, discordant_pairs = _temp.clone(), _temp.clone() - if variant == "b": - preds_ties, target_ties, total = _temp.clone(), _temp.clone(), _temp.clone() - else: - preds_ties = target_ties = total = None - - concordant_pairs, discordant_pairs, preds_ties, target_ties, total = _kendall_corrcoef_update( - preds, - target, - concordant_pairs, - discordant_pairs, - preds_ties, - target_ties, - total, - num_outputs=1 if preds.ndim == 1 else preds.shape[-1], - variant=variant, + concat_preds, concat_target = [], [] + + concat_preds, concat_target = _kendall_corrcoef_update( + preds, target, concat_preds, 
concat_target, num_outputs=1 if preds.ndim == 1 else preds.shape[-1] ) - return _kendall_corrcoef_compute(concordant_pairs, discordant_pairs, preds_ties, target_ties, total, variant) + + return _kendall_corrcoef_compute(_dim_one_cat(concat_preds), _dim_one_cat(concat_target), variant) diff --git a/src/torchmetrics/regression/__init__.py b/src/torchmetrics/regression/__init__.py index 71ded88a5db..eecb56a452f 100644 --- a/src/torchmetrics/regression/__init__.py +++ b/src/torchmetrics/regression/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.regression.concordance import ConcordanceCorrCoef # noqa: F401 from torchmetrics.regression.cosine_similarity import CosineSimilarity # noqa: F401 from torchmetrics.regression.explained_variance import ExplainedVariance # noqa: F401 +from torchmetrics.regression.kendall import KendallRankCorrCoef # noqa: F401 from torchmetrics.regression.kl_divergence import KLDivergence # noqa: F401 from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401 from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401 diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index e69de29bb2d..d624fe52275 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -0,0 +1,100 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.kendall import _dim_one_cat, _kendall_corrcoef_compute, _kendall_corrcoef_update +from torchmetrics.metric import Metric + + +class KendallRankCorrCoef(Metric): + r"""Computes `Kendal Rank Correlation Coefficient`_: + + Where :math:`y` is a tensor of target values, and :math:`x` is a tensor of predictions. + + Forward accepts + + - ``preds``: Ordered sequence of data + - ``target``: Ordered sequence of data + + Args: + variant: Indication of which variant of Kendall's tau to be used + num_outputs: Number of outputs in multioutput setting + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
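The class docstring above references the coefficient but omits its formula, so for reference these are the textbook definitions of the three variants, where n_c and n_d are the numbers of concordant and discordant pairs, n the number of observations, t_x and t_y the numbers of tied pairs within preds and target respectively, and m the smaller number of distinct values in the two sequences (with no ties, tau_a and tau_b coincide):

```latex
\tau_a = \frac{n_c - n_d}{\binom{n}{2}}, \qquad
\tau_b = \frac{n_c - n_d}{\sqrt{\left(\binom{n}{2} - t_x\right)\left(\binom{n}{2} - t_y\right)}}, \qquad
\tau_c = \frac{2\,(n_c - n_d)}{n^2 \, \frac{m - 1}{m}}
```

Note that the code in this WIP patch uses slightly simplified denominators for variants "a" and "c" (it divides by n_c + n_d and by n^2, respectively).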
+ + Example (single output regression): + >>> from torchmetrics.regression import KendallRankCorrCoef + >>> target = torch.tensor([3, -0.5, 2, 1]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> kendall = KendallRankCorrCoef() + >>> kendall(preds, target) + tensor([0.3333]) + + Example (multi output regression): + kendall + >>> target = torch.tensor([[3, -0.5], [2, 1]]) + >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> kendall = KendallRankCorrCoef() + >>> kendall(preds, target) + tensor([ 1., -1.]) + """ + + is_differentiable = True + higher_is_better = None + full_state_update = True + preds: List[Tensor] + target: List[Tensor] + + def __init__( + self, + variant: Literal["a", "b", "c"] = "b", + num_outputs: int = 1, + **kwargs: Any, + ): + super().__init__(**kwargs) + if variant not in ["a", "b", "c"]: + raise ValueError(f"Argument `variant` is expected to be one of ['a', 'b', 'c'], but got {variant!r}.") + self.variant = variant + if not isinstance(num_outputs, int) and num_outputs < 1: + raise ValueError("Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}") + self.num_outputs = num_outputs + + self.add_state("preds", [], dist_reduce_fx="cat") + self.add_state("target", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update variables required to compute Kendall rank correlation coefficient. + + Args: + preds: Ordered sequence of data + target: Ordered sequence of data + """ + self.preds, self.target = _kendall_corrcoef_update( + preds, + target, + self.preds, + self.target, + num_outputs=self.num_outputs, + ) + + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Compute Kendall rank correlation coefficient, commonly also known as Kendall's tau.""" + preds = _dim_one_cat(self.preds) + target = _dim_one_cat(self.target) + + return _kendall_corrcoef_compute(preds, target, self.variant) From 22628b1ccab38086871959f536c89a23d49d30a9 Mon Sep 17 00:00:00 2001 From: stancld Date: Sat, 15 Oct 2022 20:11:52 +0200 Subject: [PATCH 05/38] Add sorting for the other sequence to calcualte ties properly --- src/torchmetrics/functional/regression/kendall.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index bd9c4020714..55ddd1f0e8a 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -26,7 +26,7 @@ def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: """Sort sequences in an ascent order according to the sequence ``x``.""" # We need to clone `y` tensor not to change it in memory y = torch.clone(y) - x, perm = x.sort(stable=False) + x, perm = x.sort() for i in range(x.shape[0]): y[i] = y[i][perm[i]] return x, y @@ -57,8 +57,11 @@ def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: return torch.cat([_discordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0) -def _convert_sequence_to_dense_rank(x: Tensor) -> Tensor: +def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor: """Convert a sequence to the rank tensor.""" + # Sort if a sequence has not been sorted before + if sort: + x = x.sort(dim=0).values _ones = torch.zeros(1, x.shape[1], dtype=torch.int32, device=x.device) return torch.cat([_ones, (x[1:] != x[:-1]).int()], dim=0).cumsum(0) @@ -92,14 +95,13 @@ def _get_metric_metadata( concordant_pairs = 
_count_concordant_pairs(preds, target) discordant_pairs = _count_discordant_pairs(preds, target) + preds_ties = target_ties = n_total = None if variant == "b": preds = _convert_sequence_to_dense_rank(preds) - target = _convert_sequence_to_dense_rank(target) + target = _convert_sequence_to_dense_rank(target, sort=True) preds_ties = _get_ties(preds) target_ties = _get_ties(target) n_total = preds.shape[0] - else: - preds_ties = target_ties = n_total = None return concordant_pairs, discordant_pairs, preds_ties, target_ties, n_total From 302577b018492b2c2cc116369325d0a66e88163c Mon Sep 17 00:00:00 2001 From: stancld Date: Sat, 15 Oct 2022 23:29:35 +0200 Subject: [PATCH 06/38] Add hypothesis testing + links to docs --- docs/source/links.rst | 1 + .../regression/kendall_rank_corr_coef.rst | 22 +++ .../functional/regression/kendall.py | 173 +++++++++++++++--- src/torchmetrics/regression/kendall.py | 41 ++++- 4 files changed, 205 insertions(+), 32 deletions(-) create mode 100644 docs/source/regression/kendall_rank_corr_coef.rst diff --git a/docs/source/links.rst b/docs/source/links.rst index 2cb9be7918f..1b0273b8556 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -93,3 +93,4 @@ .. _AB divergence: https://pdfs.semanticscholar.org/744b/1166de34cb099100f151f3b1459f141ae25b.pdf .. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf .. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric +.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient diff --git a/docs/source/regression/kendall_rank_corr_coef.rst b/docs/source/regression/kendall_rank_corr_coef.rst new file mode 100644 index 00000000000..fa03151d0e3 --- /dev/null +++ b/docs/source/regression/kendall_rank_corr_coef.rst @@ -0,0 +1,22 @@ +.. customcarditem:: + :header: Kendall Rank Correlation Coefficient + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Regression + +.. include:: ../links.rst + +####################### +Kendal Rank Corr. Coef. +####################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.KendallRankCorrCoeff + :noindex: + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.kendall.kendall_rank_corrcoef + :noindex: diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 55ddd1f0e8a..2d4d008c4dd 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -20,6 +20,27 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import EnumStr + + +class _TestAlternative(EnumStr): + TWO_SIDED = "two-sided" + LESS = "less" + GREATER = "greater" + + @classmethod + def from_str(cls, value: str) -> Optional["EnumStr"]: + """ + Raises: + ValueError: + If required test alternativeis not among the supported options. + """ + _allowed_im = [im.lower().replace("_", "-") for im in _TestAlternative._member_names_] + + enum_key = super().from_str(value.replace("-", "_")) + if enum_key is not None and enum_key in _allowed_im: + return enum_key + raise ValueError(f"Invalid test alternative. 
Expected one of {_allowed_im}, but got {enum_key}.") def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: @@ -66,14 +87,19 @@ def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor: return torch.cat([_ones, (x[1:] != x[:-1]).int()], dim=0).cumsum(0) -def _get_ties(x: Tensor) -> Tensor: - """Get number of ties in a given sequence.""" +def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """Get number of ties and staistics for p-value calculation for a given sequence.""" ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) + ties_p1 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) + ties_p2 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) for dim in range(x.shape[1]): n_ties = _bincount(x[:, dim]) n_ties = n_ties[n_ties > 1] ties[dim] = (n_ties * (n_ties - 1) // 2).sum() - return ties + ties_p1[dim] = (n_ties * (n_ties - 1.0) * (n_ties - 2)).sum() + ties_p2[dim] = (n_ties * (n_ties - 1.0) * (2 * n_ties + 5)).sum() + + return ties, ties_p1, ties_p2 def _dim_one_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: @@ -86,7 +112,7 @@ def _dim_one_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: def _get_metric_metadata( preds: Tensor, target: Tensor, variant: Literal["a", "b", "c"] -) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: +) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Tensor]: """Obtain statistics to calculate metric value.""" # Sort on target and convert it to dense rank preds, target = _sort_on_first_sequence(preds, target) @@ -95,19 +121,66 @@ def _get_metric_metadata( concordant_pairs = _count_concordant_pairs(preds, target) discordant_pairs = _count_discordant_pairs(preds, target) - preds_ties = target_ties = n_total = None + n_total = torch.tensor(preds.shape[0], device=preds.device) + preds_ties = target_ties = None + preds_ties_p1 = preds_ties_p2 = target_ties_p1 = target_ties_p2 = None if variant == "b": preds = _convert_sequence_to_dense_rank(preds) target = _convert_sequence_to_dense_rank(target, sort=True) - preds_ties = _get_ties(preds) - target_ties = _get_ties(target) - n_total = preds.shape[0] + preds_ties, preds_ties_p1, preds_ties_p2 = _get_ties(preds) + target_ties, target_ties_p1, target_ties_p2 = _get_ties(target) + + return ( + concordant_pairs, + discordant_pairs, + preds_ties, + preds_ties_p1, + preds_ties_p2, + target_ties, + target_ties_p1, + target_ties_p2, + n_total, + ) + + +def _calculate_p_value( + con_min_dis_pairs: Tensor, + n_total: Tensor, + preds_ties: Optional[Tensor], + preds_ties_p1: Optional[Tensor], + preds_ties_p2: Optional[Tensor], + target_ties: Optional[Tensor], + target_ties_p1: Optional[Tensor], + target_ties_p2: Optional[Tensor], + variant: Literal["a", "b", "c"], + alternative: _TestAlternative, +) -> Tensor: + normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])) + + t_value_denominator_base = n_total * (n_total - 1) * (2 * n_total + 5) + if variant == "a": + t_value = 3 * con_min_dis_pairs / torch.sqrt(t_value_denominator_base / 2) + else: + m = n_total * (n_total - 1) + t_value_denominator = (t_value_denominator_base - preds_ties_p2 - target_ties_p2) / 18 + t_value_denominator += (2 * preds_ties * target_ties) / m + t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) + t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator) - return concordant_pairs, discordant_pairs, preds_ties, target_ties, n_total + if alternative in 
[_TestAlternative.TWO_SIDED, _TestAlternative.GREATER]: + t_value *= -1 + p_value = normal_dist.cdf(t_value) + if alternative == _TestAlternative.TWO_SIDED: + p_value *= 2 + return p_value def _kendall_corrcoef_update( - preds: Tensor, target: Tensor, concat_preds: List[Tensor], concat_target: List[Tensor], num_outputs: int + preds: Tensor, + target: Tensor, + concat_preds: List[Tensor] = [], + concat_target: List[Tensor] = [], + num_outputs: int = 1, ) -> Tuple[List[Tensor], List[Tensor]]: """Update variables required to compute Kendall rank correlation coefficient. @@ -140,35 +213,78 @@ def _kendall_corrcoef_compute( preds: Tensor, target: Tensor, variant: Literal["a", "b", "c"], -) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Compute the value of Kendall rank correlation coefficient given pre-computed state variables.""" - concordant_pairs, discordant_pairs, preds_ties, target_ties, n_total = _get_metric_metadata(preds, target, variant) + alternative: Optional[_TestAlternative] = None, +) -> Tuple[Tensor, Optional[Tensor]]: + """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test.""" + ( + concordant_pairs, + discordant_pairs, + preds_ties, + preds_ties_p1, + preds_ties_p2, + target_ties, + target_ties_p1, + target_ties_p2, + n_total, + ) = _get_metric_metadata(preds, target, variant) con_min_dis_pairs = concordant_pairs - discordant_pairs if variant == "a": tau = con_min_dis_pairs / (concordant_pairs + discordant_pairs) elif variant == "b": - total_combinations = n_total * (n_total - 1) // 2 + total_combinations: Tensor = n_total * (n_total - 1) // 2 denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) tau = con_min_dis_pairs / torch.sqrt(denominator) else: tau = 2 * con_min_dis_pairs / (n_total**2) - return tau.clamp(-1, 1) + p_value = ( + _calculate_p_value( + con_min_dis_pairs, + n_total, + preds_ties, + preds_ties_p1, + preds_ties_p2, + target_ties, + target_ties_p1, + target_ties_p2, + variant, + alternative, + ) + if alternative + else None + ) + + # Squeeze tensor if num_outputs=1 + if tau.shape[0] == 1: + tau = tau.squeeze() + p_value = p_value.squeeze() if p_value is not None else None + + return tau.clamp(-1, 1), p_value def kendall_rank_corrcoef( - preds: Tensor, target: Tensor, variant: Literal["a", "b", "c"] = "b" + preds: Tensor, + target: Tensor, + variant: Literal["a", "b", "c"] = "b", + t_test: bool = False, + alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided", ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Compute Kendall rank correlation coefficient, commonly also known as Kendall's tau. + """Computes `Kendall Rank Correlation Coefficient`_. Args: preds: Ordered sequence of data target: Ordered sequence of data variant: Indication of which variant of Kendall's tau to be used + t_test: Indication whether to run t-test + alternative: Alternative hypothesis for for t-test. 
Possible values: + - 'two-sided': the rank correlation is nonzero + - 'less': the rank correlation is negative (less than zero) + - 'greater': the rank correlation is positive (greater than zero) Return: Correlation tau statistic + (Optional) p-value of corresponding statistical test (asymptotic) Raises: ValueError: If ``variant`` is not from ``['a', 'b', 'c']`` @@ -178,7 +294,7 @@ def kendall_rank_corrcoef( >>> target = torch.tensor([3, -0.5, 2, 1]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) >>> kendall_rank_corrcoef(preds, target) - tensor([0.3333]) + tensor(0.3333) Example (multi output regression): >>> from torchmetrics.functional.regression import kendall_rank_corrcoef @@ -188,11 +304,20 @@ def kendall_rank_corrcoef( tensor([ 1., -1.]) """ if variant not in ["a", "b", "c"]: - raise ValueError(f"Argument `variant` is expected to be one of ['a', 'b', 'c'], but got {variant!r}.") - concat_preds, concat_target = [], [] - - concat_preds, concat_target = _kendall_corrcoef_update( - preds, target, concat_preds, concat_target, num_outputs=1 if preds.ndim == 1 else preds.shape[-1] + raise ValueError(f"Argument `variant` is expected to be one of `['a', 'b', 'c']`, but got {variant!r}.") + if not isinstance(t_test, bool): + raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") + if not t_test: + alternative = None + if t_test and not alternative: + raise ValueError("Alternative must be specified when `t_test=True`.") + _alternative = _TestAlternative.from_str(alternative) + + _preds, _target = _kendall_corrcoef_update( + preds, target, [], [], num_outputs=1 if preds.ndim == 1 else preds.shape[-1] ) + tau, p_value = _kendall_corrcoef_compute(_dim_one_cat(_preds), _dim_one_cat(_target), variant, _alternative) - return _kendall_corrcoef_compute(_dim_one_cat(concat_preds), _dim_one_cat(concat_target), variant) + if p_value is not None: + return tau, p_value + return tau diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index d624fe52275..5f348b924f8 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -12,18 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torch import Tensor from typing_extensions import Literal -from torchmetrics.functional.regression.kendall import _dim_one_cat, _kendall_corrcoef_compute, _kendall_corrcoef_update +from torchmetrics.functional.regression.kendall import ( + _dim_one_cat, + _kendall_corrcoef_compute, + _kendall_corrcoef_update, + _TestAlternative, +) from torchmetrics.metric import Metric class KendallRankCorrCoef(Metric): - r"""Computes `Kendal Rank Correlation Coefficient`_: + r"""Computes `Kendall Rank Correlation Coefficient`_: Where :math:`y` is a tensor of target values, and :math:`x` is a tensor of predictions. @@ -34,6 +39,11 @@ class KendallRankCorrCoef(Metric): Args: variant: Indication of which variant of Kendall's tau to be used + t_test: Indication whether to run t-test + alternative: Alternative hypothesis for for t-test. Possible values: + - 'two-sided': the rank correlation is nonzero + - 'less': the rank correlation is negative (less than zero) + - 'greater': the rank correlation is positive (greater than zero) num_outputs: Number of outputs in multioutput setting kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
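To complement the ``t_test`` and ``alternative`` arguments introduced in this patch, here is a minimal sketch of the asymptotic test the p-value is based on, restricted to the no-ties case: under the null hypothesis of no association, z = 3 (n_c - n_d) / sqrt(n (n - 1) (2 n + 5) / 2) is approximately standard normal, and the two-sided p-value is 2 (1 - Phi(|z|)). The helper name below is ad hoc and not part of torchmetrics.

```python
import math


def kendall_p_value_no_ties(concordant: int, discordant: int, n: int) -> float:
    """Two-sided asymptotic p-value for Kendall's tau when there are no ties."""
    z = 3 * (concordant - discordant) / math.sqrt(n * (n - 1) * (2 * n + 5) / 2)
    phi = 0.5 * (1 + math.erf(abs(z) / math.sqrt(2)))  # standard normal CDF at |z|
    return 2 * (1 - phi)


# For a 4-point sample with 4 concordant and 2 discordant pairs this gives ~0.497,
# which should agree with scipy.stats.kendalltau(..., method="asymptotic") on tie-free data.
print(kendall_p_value_no_ties(4, 2, 4))
```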
@@ -43,7 +53,7 @@ class KendallRankCorrCoef(Metric): >>> preds = torch.tensor([2.5, 0.0, 2, 8]) >>> kendall = KendallRankCorrCoef() >>> kendall(preds, target) - tensor([0.3333]) + tensor(0.3333) Example (multi output regression): kendall @@ -63,6 +73,8 @@ class KendallRankCorrCoef(Metric): def __init__( self, variant: Literal["a", "b", "c"] = "b", + t_test: bool = False, + alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided", num_outputs: int = 1, **kwargs: Any, ): @@ -70,8 +82,13 @@ def __init__( if variant not in ["a", "b", "c"]: raise ValueError(f"Argument `variant` is expected to be one of ['a', 'b', 'c'], but got {variant!r}.") self.variant = variant - if not isinstance(num_outputs, int) and num_outputs < 1: - raise ValueError("Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}") + if not isinstance(t_test, bool): + raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") + if not t_test: + alternative = None + if t_test and not alternative: + raise ValueError("Alternative must be specified when `t_test=True`.") + self.alternative = _TestAlternative.from_str(alternative) self.num_outputs = num_outputs self.add_state("preds", [], dist_reduce_fx="cat") @@ -93,8 +110,16 @@ def update(self, preds: Tensor, target: Tensor) -> None: ) def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Compute Kendall rank correlation coefficient, commonly also known as Kendall's tau.""" + """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test. + + Return: + Correlation tau statistic + (Optional) p-value of corresponding statistical test (asymptotic) + """ preds = _dim_one_cat(self.preds) target = _dim_one_cat(self.target) + tau, p_value = _kendall_corrcoef_compute(preds, target, self.variant, self.alternative) - return _kendall_corrcoef_compute(preds, target, self.variant) + if p_value is not None: + return tau, p_value + return tau From 5976f5c10f71c12d2bf6108f64d2d6cadb819fc8 Mon Sep 17 00:00:00 2001 From: stancld Date: Sat, 15 Oct 2022 23:30:26 +0200 Subject: [PATCH 07/38] chlog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 39653683baa..a87daa3aaf4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `TotalVariation` to image package ([#978](https://github.com/Lightning-AI/metrics/pull/978)) -- Added `KendallRankCorrCoef` to regression package ([]()) + + +- Added `KendallRankCorrCoef` to regression package ([#1271](https://github.com/Lightning-AI/metrics/pull/1271)) + ### Changed From f625f15855deaf2be6343124cf4b6e87c94db910 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 00:37:49 +0200 Subject: [PATCH 08/38] WIP: Add some tests --- .../regression/kendall_rank_corr_coef.rst | 2 +- .../functional/regression/kendall.py | 48 +++++---- src/torchmetrics/regression/kendall.py | 6 +- tests/unittests/regression/test_kendall.py | 99 +++++++++++++++++++ 4 files changed, 129 insertions(+), 26 deletions(-) create mode 100644 tests/unittests/regression/test_kendall.py diff --git a/docs/source/regression/kendall_rank_corr_coef.rst b/docs/source/regression/kendall_rank_corr_coef.rst index fa03151d0e3..955ae2065d3 100644 --- a/docs/source/regression/kendall_rank_corr_coef.rst +++ b/docs/source/regression/kendall_rank_corr_coef.rst @@ -12,7 +12,7 @@ Kendal Rank Corr. 
Coef. Module Interface ________________ -.. autoclass:: torchmetrics.KendallRankCorrCoeff +.. autoclass:: torchmetrics.KendallRankCorrCoef :noindex: Functional Interface diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 2d4d008c4dd..3b04261ec1a 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -23,6 +23,15 @@ from torchmetrics.utilities.enums import EnumStr +def _dim_one_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: + """Concatenation along the one dimension.""" + x = x if isinstance(x, (list, tuple)) else [x] + x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] + if not x: # empty list + raise ValueError("No samples to concatenate") + return torch.cat(x, dim=1) + + class _TestAlternative(EnumStr): TWO_SIDED = "two-sided" LESS = "less" @@ -35,12 +44,12 @@ def from_str(cls, value: str) -> Optional["EnumStr"]: ValueError: If required test alternativeis not among the supported options. """ - _allowed_im = [im.lower().replace("_", "-") for im in _TestAlternative._member_names_] + _allowed_alternative = [im.lower().replace("_", "-") for im in _TestAlternative._member_names_] enum_key = super().from_str(value.replace("-", "_")) - if enum_key is not None and enum_key in _allowed_im: + if enum_key is not None and enum_key in _allowed_alternative: return enum_key - raise ValueError(f"Invalid test alternative. Expected one of {_allowed_im}, but got {enum_key}.") + raise ValueError(f"Invalid test alternative. Expected one of {_allowed_alternative}, but got {enum_key}.") def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: @@ -102,17 +111,19 @@ def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: return ties, ties_p1, ties_p2 -def _dim_one_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: - """Concatenation along the one dimension.""" - x = x if isinstance(x, (list, tuple)) else [x] - if not x: # empty list - raise ValueError("No samples to concatenate") - return torch.cat(x, dim=1) - - def _get_metric_metadata( preds: Tensor, target: Tensor, variant: Literal["a", "b", "c"] -) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Tensor]: +) -> Tuple[ + Tensor, + Tensor, + Optional[Tensor], + Optional[Tensor], + Optional[Tensor], + Optional[Tensor], + Optional[Tensor], + Optional[Tensor], + Tensor, +]: """Obtain statistics to calculate metric value.""" # Sort on target and convert it to dense rank preds, target = _sort_on_first_sequence(preds, target) @@ -162,7 +173,7 @@ def _calculate_p_value( t_value = 3 * con_min_dis_pairs / torch.sqrt(t_value_denominator_base / 2) else: m = n_total * (n_total - 1) - t_value_denominator = (t_value_denominator_base - preds_ties_p2 - target_ties_p2) / 18 + t_value_denominator: Tensor = (t_value_denominator_base - preds_ties_p2 - target_ties_p2) / 18 t_value_denominator += (2 * preds_ties * target_ties) / m t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator) @@ -194,10 +205,10 @@ def _kendall_corrcoef_update( f"Expected both predictions and target to be either 1- or 2-dimensional tensors," f" but got {target.ndim} and {preds.ndim}." 
) - if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[-1]): + if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[1]): raise ValueError( f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" - f" and {preds.ndim}." + f" and {preds.shape[1]}." ) if num_outputs == 1: preds = preds.unsqueeze(0) @@ -307,15 +318,12 @@ def kendall_rank_corrcoef( raise ValueError(f"Argument `variant` is expected to be one of `['a', 'b', 'c']`, but got {variant!r}.") if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") - if not t_test: - alternative = None - if t_test and not alternative: - raise ValueError("Alternative must be specified when `t_test=True`.") - _alternative = _TestAlternative.from_str(alternative) + _alternative = _TestAlternative.from_str(alternative) if t_test else None _preds, _target = _kendall_corrcoef_update( preds, target, [], [], num_outputs=1 if preds.ndim == 1 else preds.shape[-1] ) + print(_preds[0].shape) tau, p_value = _kendall_corrcoef_compute(_dim_one_cat(_preds), _dim_one_cat(_target), variant, _alternative) if p_value is not None: diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 5f348b924f8..e0918380519 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -84,11 +84,7 @@ def __init__( self.variant = variant if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") - if not t_test: - alternative = None - if t_test and not alternative: - raise ValueError("Alternative must be specified when `t_test=True`.") - self.alternative = _TestAlternative.from_str(alternative) + self.alternative = _TestAlternative.from_str(alternative) if t_test else None self.num_outputs = num_outputs self.add_state("preds", [], dist_reduce_fx="cat") diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py new file mode 100644 index 00000000000..963a234f987 --- /dev/null +++ b/tests/unittests/regression/test_kendall.py @@ -0,0 +1,99 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
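The tests in this new file use ``scipy.stats.kendalltau`` as the reference implementation. A standalone sketch of the comparison that ``MetricTester`` automates below (shapes and tolerances are simplified, and the variable names are ad hoc):

```python
import torch
from scipy.stats import kendalltau

from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef

preds = torch.randn(100)
target = torch.randn(100)

tau_tm = kendall_rank_corrcoef(preds, target)          # torchmetrics, variant "b" by default
tau_sp, _ = kendalltau(preds.numpy(), target.numpy())  # scipy reference, also tau-b
print(tau_tm.item(), tau_sp)  # expected to agree to within floating-point tolerance
```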
+from collections import namedtuple +from functools import partial + +import pytest +import torch +from scipy.stats import kendalltau + +from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef +from torchmetrics.regression.kendall import KendallRankCorrCoef +from unittests.helpers import seed_all +from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +_single_inputs1 = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE)) +_single_inputs2 = Input(preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randn(NUM_BATCHES, BATCH_SIZE)) +_single_inputs3 = Input( + preds=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE)), target=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE)) +) +_multi_inputs1 = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM) +) +_multi_inputs2 = Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM) +) +_multi_inputs3 = Input( + preds=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), +) + + +def _sk_metric(preds, target, alternative="two-sided"): + _alternative = alternative or "two-sided" + if preds.ndim == 2: + out = [ + kendalltau(p.numpy(), t.numpy(), method="asymptotic", alternative=_alternative) + for p, t in zip(preds, target) + ] + tau = torch.cat([torch.tensor(o[0]).unsqueeze(0) for o in out]) + p_value = torch.cat([torch.tensor(o[1]).unsqueeze(0) for o in out]) + if alternative is not None: + return tau, p_value + return tau + + tau, p_value = kendalltau(preds.numpy(), target.numpy(), method="asymptotic", alternative=_alternative) + + if alternative is not None: + return torch.tensor(tau), torch.tensor(p_value) + return torch.tensor(tau) + + +@pytest.mark.parametrize( + "preds, target, alternative", + [ + (_single_inputs1.preds, _single_inputs1.target, None), + (_single_inputs2.preds, _single_inputs2.target, "less"), + (_single_inputs3.preds, _single_inputs3.target, "greater"), + (_multi_inputs1.preds, _multi_inputs1.target, None), + (_multi_inputs2.preds, _multi_inputs2.target, "two-sided"), + (_multi_inputs3.preds, _multi_inputs3.target, "greater"), + ], +) +class TestKendallRankCorrCoef(MetricTester): + @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("dist_sync_on_step", [False]) + def test_kendall_rank_corrcoef(self, preds, target, alternative, ddp, dist_sync_on_step): + num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 + t_test = True if alternative is not None else False + _sk_kendall_tau = partial(_sk_metric, alternative=alternative) + + self.run_class_metric_test( + ddp, + preds, + target, + KendallRankCorrCoef, + _sk_kendall_tau, + dist_sync_on_step, + metric_args={"t_test": t_test, "alternative": alternative, "num_outputs": num_outputs}, + ) + + def test_kendall_rank_corrcoef_functional(self, preds, target, alternative): + t_test = True if alternative is not None else False + metric_args = {"t_test": t_test, "alternative": alternative} + _sk_kendall_tau = partial(_sk_metric, alternative=alternative) + self.run_functional_metric_test(preds, target, kendall_rank_corrcoef, _sk_kendall_tau, metric_args=metric_args) From fcfe90ac93376f3b1c2f5941f3e7f933401f2773 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:01:02 +0200 Subject: [PATCH 09/38] Fix 
dimension handling --- .../functional/regression/kendall.py | 26 +++++++------------ src/torchmetrics/regression/kendall.py | 15 ++++++----- tests/unittests/regression/test_kendall.py | 11 +++++++- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 3b04261ec1a..3b2f4369e36 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -19,19 +19,10 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.data import _bincount, dim_zero_cat from torchmetrics.utilities.enums import EnumStr -def _dim_one_cat(x: Union[Tensor, List[Tensor]]) -> Tensor: - """Concatenation along the one dimension.""" - x = x if isinstance(x, (list, tuple)) else [x] - x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x] - if not x: # empty list - raise ValueError("No samples to concatenate") - return torch.cat(x, dim=1) - - class _TestAlternative(EnumStr): TWO_SIDED = "two-sided" LESS = "less" @@ -56,10 +47,11 @@ def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: """Sort sequences in an ascent order according to the sequence ``x``.""" # We need to clone `y` tensor not to change it in memory y = torch.clone(y) + x, y = x.T, y.T x, perm = x.sort() for i in range(x.shape[0]): y[i] = y[i][perm[i]] - return x, y + return x.T, y.T def _count_concordant_pairs(preds: Tensor, target: Tensor) -> Tensor: @@ -127,10 +119,10 @@ def _get_metric_metadata( """Obtain statistics to calculate metric value.""" # Sort on target and convert it to dense rank preds, target = _sort_on_first_sequence(preds, target) - preds, target = preds.T, target.T concordant_pairs = _count_concordant_pairs(preds, target) discordant_pairs = _count_discordant_pairs(preds, target) + # preds, target = preds.T, target.T n_total = torch.tensor(preds.shape[0], device=preds.device) preds_ties = target_ties = None @@ -140,7 +132,6 @@ def _get_metric_metadata( target = _convert_sequence_to_dense_rank(target, sort=True) preds_ties, preds_ties_p1, preds_ties_p2 = _get_ties(preds) target_ties, target_ties_p1, target_ties_p2 = _get_ties(target) - return ( concordant_pairs, discordant_pairs, @@ -178,6 +169,8 @@ def _calculate_p_value( t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator) + if alternative == _TestAlternative.TWO_SIDED: + t_value = torch.abs(t_value) if alternative in [_TestAlternative.TWO_SIDED, _TestAlternative.GREATER]: t_value *= -1 p_value = normal_dist.cdf(t_value) @@ -211,8 +204,8 @@ def _kendall_corrcoef_update( f" and {preds.shape[1]}." 
) if num_outputs == 1: - preds = preds.unsqueeze(0) - target = target.unsqueeze(0) + preds = preds.unsqueeze(1) + target = target.unsqueeze(1) concat_preds.append(preds) concat_target.append(target) @@ -323,8 +316,7 @@ def kendall_rank_corrcoef( _preds, _target = _kendall_corrcoef_update( preds, target, [], [], num_outputs=1 if preds.ndim == 1 else preds.shape[-1] ) - print(_preds[0].shape) - tau, p_value = _kendall_corrcoef_compute(_dim_one_cat(_preds), _dim_one_cat(_target), variant, _alternative) + tau, p_value = _kendall_corrcoef_compute(dim_zero_cat(_preds), dim_zero_cat(_target), variant, _alternative) if p_value is not None: return tau, p_value diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index e0918380519..1142b35f264 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -14,17 +14,16 @@ from typing import Any, List, Optional, Tuple, Union -import torch from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.regression.kendall import ( - _dim_one_cat, _kendall_corrcoef_compute, _kendall_corrcoef_update, _TestAlternative, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat class KendallRankCorrCoef(Metric): @@ -48,6 +47,7 @@ class KendallRankCorrCoef(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (single output regression): + >>> import torch >>> from torchmetrics.regression import KendallRankCorrCoef >>> target = torch.tensor([3, -0.5, 2, 1]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) @@ -56,15 +56,16 @@ class KendallRankCorrCoef(Metric): tensor(0.3333) Example (multi output regression): - kendall + >>> import torch + >>> from torchmetrics.regression import KendallRankCorrCoef >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) - >>> kendall = KendallRankCorrCoef() + >>> kendall = KendallRankCorrCoef(num_outputs=2) >>> kendall(preds, target) tensor([ 1., -1.]) """ - is_differentiable = True + is_differentiable = False higher_is_better = None full_state_update = True preds: List[Tensor] @@ -112,8 +113,8 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: Correlation tau statistic (Optional) p-value of corresponding statistical test (asymptotic) """ - preds = _dim_one_cat(self.preds) - target = _dim_one_cat(self.target) + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) tau, p_value = _kendall_corrcoef_compute(preds, target, self.variant, self.alternative) if p_value is not None: diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index 963a234f987..fa87934535c 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -48,7 +48,7 @@ def _sk_metric(preds, target, alternative="two-sided"): if preds.ndim == 2: out = [ kendalltau(p.numpy(), t.numpy(), method="asymptotic", alternative=_alternative) - for p, t in zip(preds, target) + for p, t in zip(preds.T, target.T) ] tau = torch.cat([torch.tensor(o[0]).unsqueeze(0) for o in out]) p_value = torch.cat([torch.tensor(o[1]).unsqueeze(0) for o in out]) @@ -97,3 +97,12 @@ def test_kendall_rank_corrcoef_functional(self, preds, target, alternative): metric_args = {"t_test": t_test, "alternative": alternative} _sk_kendall_tau = partial(_sk_metric, alternative=alternative) self.run_functional_metric_test(preds, target, kendall_rank_corrcoef, _sk_kendall_tau, 
metric_args=metric_args) + + def test_kendall_rank_corrcoef_differentiability(self, preds, target, alternative): + num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=partial(KendallRankCorrCoef, num_outputs=num_outputs), + metric_functional=kendall_rank_corrcoef, + ) From 87fe7239946eda52ecebbc1cde9137034939233b Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:09:07 +0200 Subject: [PATCH 10/38] Fix doctest --- docs/source/regression/kendall_rank_corr_coef.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/regression/kendall_rank_corr_coef.rst b/docs/source/regression/kendall_rank_corr_coef.rst index 955ae2065d3..66ccc774193 100644 --- a/docs/source/regression/kendall_rank_corr_coef.rst +++ b/docs/source/regression/kendall_rank_corr_coef.rst @@ -18,5 +18,5 @@ ________________ Functional Interface ____________________ -.. autofunction:: torchmetrics.functional.kendall.kendall_rank_corrcoef +.. autofunction:: torchmetrics.functional.kendall_rank_corrcoef :noindex: From 12e2500efb35056fb8fca097204c3da078ffca69 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:14:56 +0200 Subject: [PATCH 11/38] Add some docs + handle some mypy errors --- src/torchmetrics/functional/regression/kendall.py | 8 ++++++-- src/torchmetrics/regression/kendall.py | 7 +++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 3b2f4369e36..dbaca274412 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -165,8 +165,8 @@ def _calculate_p_value( else: m = n_total * (n_total - 1) t_value_denominator: Tensor = (t_value_denominator_base - preds_ties_p2 - target_ties_p2) / 18 - t_value_denominator += (2 * preds_ties * target_ties) / m - t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) + t_value_denominator += (2 * preds_ties * target_ties) / m # typing: ignore (is Tensor) + t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) # typing: ignore (is Tensor) t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator) if alternative == _TestAlternative.TWO_SIDED: @@ -292,6 +292,8 @@ def kendall_rank_corrcoef( Raises: ValueError: If ``variant`` is not from ``['a', 'b', 'c']`` + ValueError: If ``t_test`` is not of a type bool + ValueError: If ``t_test=True`` and ``alternative=None`` Example (single output regression): >>> from torchmetrics.functional.regression import kendall_rank_corrcoef @@ -311,6 +313,8 @@ def kendall_rank_corrcoef( raise ValueError(f"Argument `variant` is expected to be one of `['a', 'b', 'c']`, but got {variant!r}.") if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") + if t_test and alternative is None: + raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.") _alternative = _TestAlternative.from_str(alternative) if t_test else None _preds, _target = _kendall_corrcoef_update( diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 1142b35f264..d47f16a771f 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -46,6 +46,11 @@ class KendallRankCorrCoef(Metric): num_outputs: Number of outputs in multioutput setting kwargs: Additional 
keyword arguments, see :ref:`Metric kwargs` for more info. + Raises: + ValueError: If ``variant`` is not from ``['a', 'b', 'c']`` + ValueError: If ``t_test`` is not of a type bool + ValueError: If ``t_test=True`` and ``alternative=None`` + Example (single output regression): >>> import torch >>> from torchmetrics.regression import KendallRankCorrCoef @@ -85,6 +90,8 @@ def __init__( self.variant = variant if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") + if t_test and alternative is None: + raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.") self.alternative = _TestAlternative.from_str(alternative) if t_test else None self.num_outputs = num_outputs From c773f8c7e70896f8c2216e31b893dbdba18f4731 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:28:32 +0200 Subject: [PATCH 12/38] Update docs + fix some mypy errors --- docs/source/links.rst | 1 + .../functional/regression/kendall.py | 44 +++++++++++++++---- src/torchmetrics/regression/kendall.py | 8 +++- 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index 1b0273b8556..b305ac49fd8 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -94,3 +94,4 @@ .. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf .. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric .. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient +.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303 diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index dbaca274412..eaa5e4c286b 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -24,6 +24,8 @@ class _TestAlternative(EnumStr): + """Enumerate for test altenative options.""" + TWO_SIDED = "two-sided" LESS = "less" GREATER = "greater" @@ -45,7 +47,7 @@ def from_str(cls, value: str) -> Optional["EnumStr"]: def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: """Sort sequences in an ascent order according to the sequence ``x``.""" - # We need to clone `y` tensor not to change it in memory + # We need to clone `y` tensor not to change an object in memory y = torch.clone(y) x, y = x.T, y.T x, perm = x.sort() @@ -89,7 +91,7 @@ def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor: def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """Get number of ties and staistics for p-value calculation for a given sequence.""" + """Get a total number of ties and staistics for p-value calculation for a given sequence.""" ties = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) ties_p1 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) ties_p2 = torch.zeros(x.shape[1], dtype=x.dtype, device=x.device) @@ -117,12 +119,10 @@ def _get_metric_metadata( Tensor, ]: """Obtain statistics to calculate metric value.""" - # Sort on target and convert it to dense rank preds, target = _sort_on_first_sequence(preds, target) concordant_pairs = _count_concordant_pairs(preds, target) discordant_pairs = _count_discordant_pairs(preds, target) - # preds, target = preds.T, target.T n_total = torch.tensor(preds.shape[0], device=preds.device) preds_ties = target_ties = None @@ -155,7 +155,7 @@ def 
_calculate_p_value( target_ties_p1: Optional[Tensor], target_ties_p2: Optional[Tensor], variant: Literal["a", "b", "c"], - alternative: _TestAlternative, + alternative: Optional[_TestAlternative], ) -> Tensor: normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])) @@ -165,8 +165,8 @@ def _calculate_p_value( else: m = n_total * (n_total - 1) t_value_denominator: Tensor = (t_value_denominator_base - preds_ties_p2 - target_ties_p2) / 18 - t_value_denominator += (2 * preds_ties * target_ties) / m # typing: ignore (is Tensor) - t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) # typing: ignore (is Tensor) + t_value_denominator += (2 * preds_ties * target_ties) / m # type: ignore (is Tensor) + t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) # type: ignore (is Tensor) t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator) if alternative == _TestAlternative.TWO_SIDED: @@ -188,6 +188,13 @@ def _kendall_corrcoef_update( ) -> Tuple[List[Tensor], List[Tensor]]: """Update variables required to compute Kendall rank correlation coefficient. + Args: + preds: Ordered sequence of data + target: Ordered sequence of data + concat_preds: List of batches of preds sequence to be concatenated + concat_target: List of batches of target sequence to be concatenated + num_outputs: Number of outputs in multioutput setting + Raises: RuntimeError: If ``preds`` and ``target`` do not have the same shape """ @@ -219,7 +226,18 @@ def _kendall_corrcoef_compute( variant: Literal["a", "b", "c"], alternative: Optional[_TestAlternative] = None, ) -> Tuple[Tensor, Optional[Tensor]]: - """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test.""" + """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test. + + Args: + Args: + preds: Ordered sequence of data + target: Ordered sequence of data + variant: Indication of which variant of Kendall's tau to be used + alternative: Alternative hypothesis for for t-test. Possible values: + - 'two-sided': the rank correlation is nonzero + - 'less': the rank correlation is negative (less than zero) + - 'greater': the rank correlation is positive (greater than zero) + """ ( concordant_pairs, discordant_pairs, @@ -274,7 +292,15 @@ def kendall_rank_corrcoef( t_test: bool = False, alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided", ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Computes `Kendall Rank Correlation Coefficient`_. + r"""Computes `Kendall Rank Correlation Coefficient`_. + + .. math: + tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})} + + tau_c = 2 * \frac{C - D}{n ** 2 * \frac{m - 1}{m}} + + where :math:`C` is represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents + a total number of ties. Definition according to `The Treatment of Ties in Ranking Problems`_. Args: preds: Ordered sequence of data diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index d47f16a771f..91da27515a9 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -29,7 +29,13 @@ class KendallRankCorrCoef(Metric): r"""Computes `Kendall Rank Correlation Coefficient`_: - Where :math:`y` is a tensor of target values, and :math:`x` is a tensor of predictions. + .. 
math: + tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})} + + tau_c = 2 * \frac{C - D}{n ** 2 * \frac{m - 1}{m}} + + where :math:`C` is represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents + a total number of ties. Definition according to `The Treatment of Ties in Ranking Problems`_. Forward accepts From 90694112f6b3dc2e86dae94058a455172c0b4afb Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:32:04 +0200 Subject: [PATCH 13/38] Re-run mypy From aa609323b3a93eeee9773716adffd198c2070c8d Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:42:51 +0200 Subject: [PATCH 14/38] Refactor: Separate calculate tau method --- .../functional/regression/kendall.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index eaa5e4c286b..4fc865016a0 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -145,6 +145,28 @@ def _get_metric_metadata( ) +def _calculate_tau( + concordant_pairs: Tensor, + discordant_pairs: Tensor, + con_min_dis_pairs: Tensor, + n_total: Tensor, + preds_ties: Optional[Tensor], + target_ties: Optional[Tensor], + variant: Literal["a", "b", "c"], +) -> Tensor: + """""" + if variant == "a": + tau = con_min_dis_pairs / (concordant_pairs + discordant_pairs) + elif variant == "b": + total_combinations: Tensor = n_total * (n_total - 1) // 2 + denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) # type: ignore (is Tensor) + tau = con_min_dis_pairs / torch.sqrt(denominator) + else: + tau = 2 * con_min_dis_pairs / (n_total**2) + + return tau + + def _calculate_p_value( con_min_dis_pairs: Tensor, n_total: Tensor, @@ -157,6 +179,7 @@ def _calculate_p_value( variant: Literal["a", "b", "c"], alternative: Optional[_TestAlternative], ) -> Tensor: + """""" normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])) t_value_denominator_base = n_total * (n_total - 1) * (2 * n_total + 5) @@ -251,15 +274,9 @@ def _kendall_corrcoef_compute( ) = _get_metric_metadata(preds, target, variant) con_min_dis_pairs = concordant_pairs - discordant_pairs - if variant == "a": - tau = con_min_dis_pairs / (concordant_pairs + discordant_pairs) - elif variant == "b": - total_combinations: Tensor = n_total * (n_total - 1) // 2 - denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) - tau = con_min_dis_pairs / torch.sqrt(denominator) - else: - tau = 2 * con_min_dis_pairs / (n_total**2) - + tau = _calculate_tau( + concordant_pairs, discordant_pairs, con_min_dis_pairs, n_total, preds_ties, target_ties, variant + ) p_value = ( _calculate_p_value( con_min_dis_pairs, From ead166b609976b15621c1e551b91517c9db152ab Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:55:45 +0200 Subject: [PATCH 15/38] Refactor + Add variant c * Refactor: Internally, represent metric variant as enumerate * Add variant 'c' * Update docs --- .../functional/regression/kendall.py | 66 +++++++++++++------ src/torchmetrics/regression/kendall.py | 7 +- tests/unittests/regression/test_kendall.py | 23 ++++--- 3 files changed, 62 insertions(+), 34 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 4fc865016a0..4cb895f564c 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ 
b/src/torchmetrics/functional/regression/kendall.py @@ -23,6 +23,28 @@ from torchmetrics.utilities.enums import EnumStr +class _MetricVariant(EnumStr): + """Enumerate for metric variants.""" + + A = "a" + B = "B" + C = "c" + + @classmethod + def from_str(cls, value: str) -> Optional["EnumStr"]: + """ + Raises: + ValueError: + If required metric variant is not among the supported options. + """ + _allowed_variants = [im.lower() for im in _MetricVariant._member_names_] + + enum_key = super().from_str(value) + if enum_key is not None and enum_key in _allowed_variants: + return enum_key + raise ValueError(f"Invalid metric variant. Expected one of {_allowed_variants}, but got {enum_key}.") + + class _TestAlternative(EnumStr): """Enumerate for test altenative options.""" @@ -35,14 +57,14 @@ def from_str(cls, value: str) -> Optional["EnumStr"]: """ Raises: ValueError: - If required test alternativeis not among the supported options. + If required test alternative is not among the supported options. """ - _allowed_alternative = [im.lower().replace("_", "-") for im in _TestAlternative._member_names_] + _allowed_alternatives = [im.lower().replace("_", "-") for im in _TestAlternative._member_names_] enum_key = super().from_str(value.replace("-", "_")) - if enum_key is not None and enum_key in _allowed_alternative: + if enum_key is not None and enum_key in _allowed_alternatives: return enum_key - raise ValueError(f"Invalid test alternative. Expected one of {_allowed_alternative}, but got {enum_key}.") + raise ValueError(f"Invalid test alternative. Expected one of {_allowed_alternatives}, but got {enum_key}.") def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: @@ -106,7 +128,7 @@ def _get_ties(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: def _get_metric_metadata( - preds: Tensor, target: Tensor, variant: Literal["a", "b", "c"] + preds: Tensor, target: Tensor, variant: _MetricVariant ) -> Tuple[ Tensor, Tensor, @@ -127,7 +149,7 @@ def _get_metric_metadata( n_total = torch.tensor(preds.shape[0], device=preds.device) preds_ties = target_ties = None preds_ties_p1 = preds_ties_p2 = target_ties_p1 = target_ties_p2 = None - if variant == "b": + if variant != _MetricVariant.A: preds = _convert_sequence_to_dense_rank(preds) target = _convert_sequence_to_dense_rank(target, sort=True) preds_ties, preds_ties_p1, preds_ties_p2 = _get_ties(preds) @@ -146,23 +168,28 @@ def _get_metric_metadata( def _calculate_tau( + preds: Tensor, + target: Tensor, concordant_pairs: Tensor, discordant_pairs: Tensor, con_min_dis_pairs: Tensor, n_total: Tensor, preds_ties: Optional[Tensor], target_ties: Optional[Tensor], - variant: Literal["a", "b", "c"], + variant: _MetricVariant, ) -> Tensor: - """""" - if variant == "a": + """Calculate Kendall's tau from metric metadata.""" + if variant == _MetricVariant.A: tau = con_min_dis_pairs / (concordant_pairs + discordant_pairs) - elif variant == "b": + elif variant == _MetricVariant.B: total_combinations: Tensor = n_total * (n_total - 1) // 2 denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) # type: ignore (is Tensor) tau = con_min_dis_pairs / torch.sqrt(denominator) else: - tau = 2 * con_min_dis_pairs / (n_total**2) + preds_unique = torch.tensor([len(p.unique()) for p in preds.T]) + target_unique = torch.tensor([len(t.unique()) for t in target.T]) + min_classes = torch.minimum(preds_unique, target_unique) + tau = 2 * con_min_dis_pairs / ((min_classes - 1) / min_classes * n_total**2) return tau @@ -176,14 +203,14 @@ def 
_calculate_p_value( target_ties: Optional[Tensor], target_ties_p1: Optional[Tensor], target_ties_p2: Optional[Tensor], - variant: Literal["a", "b", "c"], + variant: _MetricVariant, alternative: Optional[_TestAlternative], ) -> Tensor: - """""" + """Calculate p-value for Kendall's tau from metric metadata.""" normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])) t_value_denominator_base = n_total * (n_total - 1) * (2 * n_total + 5) - if variant == "a": + if variant == _MetricVariant.A: t_value = 3 * con_min_dis_pairs / torch.sqrt(t_value_denominator_base / 2) else: m = n_total * (n_total - 1) @@ -246,7 +273,7 @@ def _kendall_corrcoef_update( def _kendall_corrcoef_compute( preds: Tensor, target: Tensor, - variant: Literal["a", "b", "c"], + variant: _MetricVariant, alternative: Optional[_TestAlternative] = None, ) -> Tuple[Tensor, Optional[Tensor]]: """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test. @@ -275,7 +302,7 @@ def _kendall_corrcoef_compute( con_min_dis_pairs = concordant_pairs - discordant_pairs tau = _calculate_tau( - concordant_pairs, discordant_pairs, con_min_dis_pairs, n_total, preds_ties, target_ties, variant + preds, target, concordant_pairs, discordant_pairs, con_min_dis_pairs, n_total, preds_ties, target_ties, variant ) p_value = ( _calculate_p_value( @@ -334,7 +361,6 @@ def kendall_rank_corrcoef( (Optional) p-value of corresponding statistical test (asymptotic) Raises: - ValueError: If ``variant`` is not from ``['a', 'b', 'c']`` ValueError: If ``t_test`` is not of a type bool ValueError: If ``t_test=True`` and ``alternative=None`` @@ -352,18 +378,18 @@ def kendall_rank_corrcoef( >>> kendall_rank_corrcoef(preds, target) tensor([ 1., -1.]) """ - if variant not in ["a", "b", "c"]: - raise ValueError(f"Argument `variant` is expected to be one of `['a', 'b', 'c']`, but got {variant!r}.") if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") if t_test and alternative is None: raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.") + + _variant = _MetricVariant.from_str(variant) _alternative = _TestAlternative.from_str(alternative) if t_test else None _preds, _target = _kendall_corrcoef_update( preds, target, [], [], num_outputs=1 if preds.ndim == 1 else preds.shape[-1] ) - tau, p_value = _kendall_corrcoef_compute(dim_zero_cat(_preds), dim_zero_cat(_target), variant, _alternative) + tau, p_value = _kendall_corrcoef_compute(dim_zero_cat(_preds), dim_zero_cat(_target), _variant, _alternative) if p_value is not None: return tau, p_value diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 91da27515a9..3a23b6f3544 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -20,6 +20,7 @@ from torchmetrics.functional.regression.kendall import ( _kendall_corrcoef_compute, _kendall_corrcoef_update, + _MetricVariant, _TestAlternative, ) from torchmetrics.metric import Metric @@ -53,7 +54,6 @@ class KendallRankCorrCoef(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
Raises: - ValueError: If ``variant`` is not from ``['a', 'b', 'c']`` ValueError: If ``t_test`` is not of a type bool ValueError: If ``t_test=True`` and ``alternative=None`` @@ -91,13 +91,12 @@ def __init__( **kwargs: Any, ): super().__init__(**kwargs) - if variant not in ["a", "b", "c"]: - raise ValueError(f"Argument `variant` is expected to be one of ['a', 'b', 'c'], but got {variant!r}.") - self.variant = variant if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") if t_test and alternative is None: raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.") + + self.variant = _MetricVariant.from_str(variant) self.alternative = _TestAlternative.from_str(alternative) if t_test else None self.num_outputs = num_outputs diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index fa87934535c..2f7940303af 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -43,11 +43,11 @@ ) -def _sk_metric(preds, target, alternative="two-sided"): +def _sk_metric(preds, target, alternative, variant): _alternative = alternative or "two-sided" if preds.ndim == 2: out = [ - kendalltau(p.numpy(), t.numpy(), method="asymptotic", alternative=_alternative) + kendalltau(p.numpy(), t.numpy(), method="asymptotic", alternative=_alternative, variant=variant) for p, t in zip(preds.T, target.T) ] tau = torch.cat([torch.tensor(o[0]).unsqueeze(0) for o in out]) @@ -56,7 +56,9 @@ def _sk_metric(preds, target, alternative="two-sided"): return tau, p_value return tau - tau, p_value = kendalltau(preds.numpy(), target.numpy(), method="asymptotic", alternative=_alternative) + tau, p_value = kendalltau( + preds.numpy(), target.numpy(), method="asymptotic", alternative=_alternative, variant=variant + ) if alternative is not None: return torch.tensor(tau), torch.tensor(p_value) @@ -74,13 +76,14 @@ def _sk_metric(preds, target, alternative="two-sided"): (_multi_inputs3.preds, _multi_inputs3.target, "greater"), ], ) +@pytest.mark.parametrize("variant", ["b", "c"]) class TestKendallRankCorrCoef(MetricTester): @pytest.mark.parametrize("ddp", [False]) @pytest.mark.parametrize("dist_sync_on_step", [False]) - def test_kendall_rank_corrcoef(self, preds, target, alternative, ddp, dist_sync_on_step): + def test_kendall_rank_corrcoef(self, preds, target, alternative, variant, ddp, dist_sync_on_step): num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 t_test = True if alternative is not None else False - _sk_kendall_tau = partial(_sk_metric, alternative=alternative) + _sk_kendall_tau = partial(_sk_metric, alternative=alternative, variant=variant) self.run_class_metric_test( ddp, @@ -89,16 +92,16 @@ def test_kendall_rank_corrcoef(self, preds, target, alternative, ddp, dist_sync_ KendallRankCorrCoef, _sk_kendall_tau, dist_sync_on_step, - metric_args={"t_test": t_test, "alternative": alternative, "num_outputs": num_outputs}, + metric_args={"t_test": t_test, "alternative": alternative, "variant": variant, "num_outputs": num_outputs}, ) - def test_kendall_rank_corrcoef_functional(self, preds, target, alternative): + def test_kendall_rank_corrcoef_functional(self, preds, target, alternative, variant): t_test = True if alternative is not None else False - metric_args = {"t_test": t_test, "alternative": alternative} - _sk_kendall_tau = partial(_sk_metric, alternative=alternative) + metric_args = {"t_test": t_test, "alternative": alternative, 
"variant": variant} + _sk_kendall_tau = partial(_sk_metric, alternative=alternative, variant=variant) self.run_functional_metric_test(preds, target, kendall_rank_corrcoef, _sk_kendall_tau, metric_args=metric_args) - def test_kendall_rank_corrcoef_differentiability(self, preds, target, alternative): + def test_kendall_rank_corrcoef_differentiability(self, preds, target, alternative, variant): num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 self.run_differentiability_test( preds=preds, From ec7ec6a7cc59bf41d703200f0df267760786b7de Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:56:49 +0200 Subject: [PATCH 16/38] Fix a typo --- src/torchmetrics/functional/regression/kendall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 4cb895f564c..7a40b4fb480 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -27,7 +27,7 @@ class _MetricVariant(EnumStr): """Enumerate for metric variants.""" A = "a" - B = "B" + B = "b" C = "c" @classmethod From 6c63a4ab8e149cd837e3885faa263e1eb9369c8b Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 13:59:06 +0200 Subject: [PATCH 17/38] Add testing for ddp=True and dist_sync_on_step=True options --- tests/unittests/regression/test_kendall.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index 2f7940303af..348d5a2cbed 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -78,8 +78,8 @@ def _sk_metric(preds, target, alternative, variant): ) @pytest.mark.parametrize("variant", ["b", "c"]) class TestKendallRankCorrCoef(MetricTester): - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [False]) + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def test_kendall_rank_corrcoef(self, preds, target, alternative, variant, ddp, dist_sync_on_step): num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 t_test = True if alternative is not None else False From 8053a481585f0fae2bab90d32798705763ca5c72 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 14:02:47 +0200 Subject: [PATCH 18/38] Fix doctest --- src/torchmetrics/functional/regression/kendall.py | 2 +- src/torchmetrics/regression/kendall.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 7a40b4fb480..2f362e874b9 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -376,7 +376,7 @@ def kendall_rank_corrcoef( >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) >>> kendall_rank_corrcoef(preds, target) - tensor([ 1., -1.]) + tensor([ 1., 1.]) """ if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 3a23b6f3544..9571f0329fe 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -73,7 +73,7 @@ class KendallRankCorrCoef(Metric): >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) >>> kendall = 
KendallRankCorrCoef(num_outputs=2) >>> kendall(preds, target) - tensor([ 1., -1.]) + tensor([ 1., 1.]) """ is_differentiable = False From 98a8033c1c4804a67213674462591858d437cacb Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 14:06:11 +0200 Subject: [PATCH 19/38] . --- src/torchmetrics/functional/regression/kendall.py | 2 +- src/torchmetrics/regression/kendall.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 2f362e874b9..4632b79eb69 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -376,7 +376,7 @@ def kendall_rank_corrcoef( >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) >>> kendall_rank_corrcoef(preds, target) - tensor([ 1., 1.]) + tensor([1., 1.]) """ if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 9571f0329fe..46c4e0b2449 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -73,7 +73,7 @@ class KendallRankCorrCoef(Metric): >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) >>> kendall = KendallRankCorrCoef(num_outputs=2) >>> kendall(preds, target) - tensor([ 1., 1.]) + tensor([1., 1.]) """ is_differentiable = False From 3c74ebbe70c2a706603e79f4efec211b86745084 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 14:34:58 +0200 Subject: [PATCH 20/38] Fix tests for scipy<1.8.0 --- tests/unittests/regression/test_kendall.py | 23 +++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index 348d5a2cbed..3bd6c9379e2 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
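The guard introduced in the test change below is needed because `scipy.stats.kendalltau` only accepts the `alternative` keyword from SciPy 1.8.0 onwards. A rough sketch of the gating idea, reusing the `_compare_version` helper imported in the test (illustrative, not part of the patch):

    import operator
    from scipy.stats import kendalltau
    from torchmetrics.utilities.imports import _compare_version

    _SCIPY_GREATER_EQUAL_1_8 = _compare_version("scipy", operator.ge, "1.8.0")

    def reference_kendall(preds, target, alternative, variant):
        # Older SciPy falls back to the default two-sided test.
        kwargs = {"alternative": alternative or "two-sided"} if _SCIPY_GREATER_EQUAL_1_8 else {}
        return kendalltau(preds.numpy(), target.numpy(), method="asymptotic", variant=variant, **kwargs)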
+import operator from collections import namedtuple from functools import partial @@ -20,9 +21,12 @@ from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef from torchmetrics.regression.kendall import KendallRankCorrCoef +from torchmetrics.utilities.imports import _compare_version from unittests.helpers import seed_all from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, MetricTester +_SCIPY_GREATER_EQUAL_1_8 = _compare_version("scipy", operator.ge, "1.8.0") + seed_all(42) Input = namedtuple("Input", ["preds", "target"]) @@ -44,10 +48,12 @@ def _sk_metric(preds, target, alternative, variant): - _alternative = alternative or "two-sided" + metric_args = {} + if _SCIPY_GREATER_EQUAL_1_8: + metric_args = {"alternative": alternative or "two-sided"} # scipy cannot accept `None` if preds.ndim == 2: out = [ - kendalltau(p.numpy(), t.numpy(), method="asymptotic", alternative=_alternative, variant=variant) + kendalltau(p.numpy(), t.numpy(), method="asymptotic", variant=variant, **metric_args) for p, t in zip(preds.T, target.T) ] tau = torch.cat([torch.tensor(o[0]).unsqueeze(0) for o in out]) @@ -56,9 +62,7 @@ def _sk_metric(preds, target, alternative, variant): return tau, p_value return tau - tau, p_value = kendalltau( - preds.numpy(), target.numpy(), method="asymptotic", alternative=_alternative, variant=variant - ) + tau, p_value = kendalltau(preds.numpy(), target.numpy(), method="asymptotic", variant=variant, **metric_args) if alternative is not None: return torch.tensor(tau), torch.tensor(p_value) @@ -84,6 +88,7 @@ def test_kendall_rank_corrcoef(self, preds, target, alternative, variant, ddp, d num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 t_test = True if alternative is not None else False _sk_kendall_tau = partial(_sk_metric, alternative=alternative, variant=variant) + alternative = _adjust_alternative_to_scipy(alternative) self.run_class_metric_test( ddp, @@ -97,6 +102,7 @@ def test_kendall_rank_corrcoef(self, preds, target, alternative, variant, ddp, d def test_kendall_rank_corrcoef_functional(self, preds, target, alternative, variant): t_test = True if alternative is not None else False + alternative = _adjust_alternative_to_scipy(alternative) metric_args = {"t_test": t_test, "alternative": alternative, "variant": variant} _sk_kendall_tau = partial(_sk_metric, alternative=alternative, variant=variant) self.run_functional_metric_test(preds, target, kendall_rank_corrcoef, _sk_kendall_tau, metric_args=metric_args) @@ -109,3 +115,10 @@ def test_kendall_rank_corrcoef_differentiability(self, preds, target, alternativ metric_module=partial(KendallRankCorrCoef, num_outputs=num_outputs), metric_functional=kendall_rank_corrcoef, ) + + +def _adjust_alternative_to_scipy(alternative): + """Scipy<1.8.0 supports only two-sided hypothesis testing.""" + if alternative is not None and not _compare_version("scipy", operator.ge, "1.8.0"): + alternative = "two-sided" + return alternative From f6a8d72d4b937217289e7209ebdd7b177f110477 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 14:56:06 +0200 Subject: [PATCH 21/38] Fix some stuff in docs --- src/torchmetrics/functional/regression/kendall.py | 6 +++--- src/torchmetrics/regression/kendall.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 4632b79eb69..b93faa25c33 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ 
b/src/torchmetrics/functional/regression/kendall.py @@ -338,8 +338,8 @@ def kendall_rank_corrcoef( ) -> Union[Tensor, Tuple[Tensor, Tensor]]: r"""Computes `Kendall Rank Correlation Coefficient`_. - .. math: - tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})} + .. math:: + tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}} tau_c = 2 * \frac{C - D}{n ** 2 * \frac{m - 1}{m}} @@ -351,7 +351,7 @@ def kendall_rank_corrcoef( target: Ordered sequence of data variant: Indication of which variant of Kendall's tau to be used t_test: Indication whether to run t-test - alternative: Alternative hypothesis for for t-test. Possible values: + alternative: Alternative hypothesis for t-test. Possible values: - 'two-sided': the rank correlation is nonzero - 'less': the rank correlation is negative (less than zero) - 'greater': the rank correlation is positive (greater than zero) diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 46c4e0b2449..594146af636 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -30,8 +30,8 @@ class KendallRankCorrCoef(Metric): r"""Computes `Kendall Rank Correlation Coefficient`_: - .. math: - tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})} + .. math:: + tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}} tau_c = 2 * \frac{C - D}{n ** 2 * \frac{m - 1}{m}} @@ -46,7 +46,7 @@ class KendallRankCorrCoef(Metric): Args: variant: Indication of which variant of Kendall's tau to be used t_test: Indication whether to run t-test - alternative: Alternative hypothesis for for t-test. Possible values: + alternative: Alternative hypothesis for t-test. Possible values: - 'two-sided': the rank correlation is nonzero - 'less': the rank correlation is negative (less than zero) - 'greater': the rank correlation is positive (greater than zero) From 66f6f46478a053bf0046c1bed75b0389c855dcc2 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 15:23:51 +0200 Subject: [PATCH 22/38] Add missing device placement --- src/torchmetrics/functional/regression/kendall.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index b93faa25c33..624c7c599b2 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -186,8 +186,8 @@ def _calculate_tau( denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) # type: ignore (is Tensor) tau = con_min_dis_pairs / torch.sqrt(denominator) else: - preds_unique = torch.tensor([len(p.unique()) for p in preds.T]) - target_unique = torch.tensor([len(t.unique()) for t in target.T]) + preds_unique = torch.tensor([len(p.unique()) for p in preds.T], dtype=preds.dtype, device=preds.device) + target_unique = torch.tensor([len(t.unique()) for t in target.T], dtype=target.dtype, device=target.device) min_classes = torch.minimum(preds_unique, target_unique) tau = 2 * con_min_dis_pairs / ((min_classes - 1) / min_classes * n_total**2) From b50f3fb585e42f6bd2af9ce5356e3cc000f66cee Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 18:18:46 +0200 Subject: [PATCH 23/38] Fix some nits + add t-test examples to docs --- .../functional/regression/kendall.py | 79 +++++++++++++------ src/torchmetrics/regression/kendall.py | 30 +++++-- 2 files changed, 78 insertions(+), 31 deletions(-) diff --git 
a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 624c7c599b2..228f5a23b32 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -78,28 +78,30 @@ def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: return x.T, y.T +def _concordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: + """Count a total number of concordant pairs in a single sequence.""" + return torch.logical_and(x[i] < x[(i + 1) :], y[i] < y[(i + 1) :]).sum(0).unsqueeze(0) + + def _count_concordant_pairs(preds: Tensor, target: Tensor) -> Tensor: """Count a total number of concordant pairs in given sequences.""" + return torch.cat([_concordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0) - def _concordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: - return torch.logical_and(x[i] < x[(i + 1) :], y[i] < y[(i + 1) :]).sum(0).unsqueeze(0) - return torch.cat([_concordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0) +def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: + """Count a total number of discordant pairs in a single sequences.""" + return ( + torch.logical_or( + torch.logical_and(x[i] > x[(i + 1) :], y[i] < y[(i + 1) :]), + torch.logical_and(x[i] < x[(i + 1) :], y[i] > y[(i + 1) :]), + ) + .sum(0) + .unsqueeze(0) + ) def _count_discordant_pairs(preds: Tensor, target: Tensor) -> Tensor: """Count a total number of discordant pairs in given sequences.""" - - def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor: - return ( - torch.logical_or( - torch.logical_and(x[i] > x[(i + 1) :], y[i] < y[(i + 1) :]), - torch.logical_and(x[i] < x[(i + 1) :], y[i] > y[(i + 1) :]), - ) - .sum(0) - .unsqueeze(0) - ) - return torch.cat([_discordant_element_sum(preds, target, i) for i in range(preds.shape[0])]).sum(0) @@ -194,6 +196,21 @@ def _calculate_tau( return tau +def _get_p_value_for_t_value_from_dist(t_value: Tensor) -> Tensor: + """Obtain p-value for a given Tensor of t-values. Handle ``nan`` which cannot be passed into torch + distributions. + + When t-value is ``nan``, a resulted p-value should be alson ``nan``. + """ + device = t_value + normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]).to(device), torch.tensor([1.0]).to(device)) + + is_nan = t_value.isnan() + t_value = t_value.nan_to_num() + p_value = normal_dist.cdf(t_value) + return p_value.where(~is_nan, torch.tensor(float("nan"), dtype=p_value.dtype, device=p_value.device)) + + def _calculate_p_value( con_min_dis_pairs: Tensor, n_total: Tensor, @@ -207,8 +224,6 @@ def _calculate_p_value( alternative: Optional[_TestAlternative], ) -> Tensor: """Calculate p-value for Kendall's tau from metric metadata.""" - normal_dist = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])) - t_value_denominator_base = n_total * (n_total - 1) * (2 * n_total + 5) if variant == _MetricVariant.A: t_value = 3 * con_min_dis_pairs / torch.sqrt(t_value_denominator_base / 2) @@ -223,7 +238,7 @@ def _calculate_p_value( t_value = torch.abs(t_value) if alternative in [_TestAlternative.TWO_SIDED, _TestAlternative.GREATER]: t_value *= -1 - p_value = normal_dist.cdf(t_value) + p_value = _get_p_value_for_t_value_from_dist(t_value) if alternative == _TestAlternative.TWO_SIDED: p_value *= 2 return p_value @@ -239,8 +254,8 @@ def _kendall_corrcoef_update( """Update variables required to compute Kendall rank correlation coefficient. 
Args: - preds: Ordered sequence of data - target: Ordered sequence of data + preds: Sequence of data + target: Sequence of data concat_preds: List of batches of preds sequence to be concatenated concat_target: List of batches of target sequence to be concatenated num_outputs: Number of outputs in multioutput setting @@ -280,8 +295,8 @@ def _kendall_corrcoef_compute( Args: Args: - preds: Ordered sequence of data - target: Ordered sequence of data + preds: Sequence of data + target: Sequence of data variant: Indication of which variant of Kendall's tau to be used alternative: Alternative hypothesis for for t-test. Possible values: - 'two-sided': the rank correlation is nonzero @@ -347,8 +362,8 @@ def kendall_rank_corrcoef( a total number of ties. Definition according to `The Treatment of Ties in Ranking Problems`_. Args: - preds: Ordered sequence of data - target: Ordered sequence of data + preds: Sequence of data + target: Sequence of data variant: Indication of which variant of Kendall's tau to be used t_test: Indication whether to run t-test alternative: Alternative hypothesis for t-test. Possible values: @@ -366,17 +381,31 @@ def kendall_rank_corrcoef( Example (single output regression): >>> from torchmetrics.functional.regression import kendall_rank_corrcoef - >>> target = torch.tensor([3, -0.5, 2, 1]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = torch.tensor([3, -0.5, 2, 1]) >>> kendall_rank_corrcoef(preds, target) tensor(0.3333) Example (multi output regression): >>> from torchmetrics.functional.regression import kendall_rank_corrcoef - >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> kendall_rank_corrcoef(preds, target) tensor([1., 1.]) + + Example (single output regression with t-test) + >>> from torchmetrics.functional.regression import kendall_rank_corrcoef + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = torch.tensor([3, -0.5, 2, 1]) + >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided') + (tensor(0.3333), tensor(0.4969))) + + Example (multi output regression with t-test): + >>> from torchmetrics.functional.regression import kendall_rank_corrcoef + >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) + >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided') + (tensor([1., 1.]), tensor([nan, nan])) """ if not isinstance(t_test, bool): raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.") diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 594146af636..9a89a7b8e9b 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -40,8 +40,8 @@ class KendallRankCorrCoef(Metric): Forward accepts - - ``preds``: Ordered sequence of data - - ``target``: Ordered sequence of data + - ``preds``: Sequence of data + - ``target``: Sequence of data Args: variant: Indication of which variant of Kendall's tau to be used @@ -60,8 +60,8 @@ class KendallRankCorrCoef(Metric): Example (single output regression): >>> import torch >>> from torchmetrics.regression import KendallRankCorrCoef - >>> target = torch.tensor([3, -0.5, 2, 1]) >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = torch.tensor([3, -0.5, 2, 1]) >>> kendall = KendallRankCorrCoef() >>> kendall(preds, target) tensor(0.3333) @@ -69,11 +69,29 @@ class KendallRankCorrCoef(Metric): Example (multi 
output regression): >>> import torch >>> from torchmetrics.regression import KendallRankCorrCoef - >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) >>> kendall = KendallRankCorrCoef(num_outputs=2) >>> kendall(preds, target) tensor([1., 1.]) + + Example (single output regression with t-test): + >>> import torch + >>> from torchmetrics.regression import KendallRankCorrCoef + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> target = torch.tensor([3, -0.5, 2, 1]) + >>> kendall = KendallRankCorrCoef() + >>> kendall(preds, target) + (tensor(0.3333), tensor(0.4969))) + + Example (multi output regression with t-test): + >>> import torch + >>> from torchmetrics.regression import KendallRankCorrCoef + >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) + >>> target = torch.tensor([[3, -0.5], [2, 1]]) + >>> kendall = KendallRankCorrCoef(num_outputs=2) + >>> kendall(preds, target) + (tensor([1., 1.]), tensor([nan, nan])) """ is_differentiable = False @@ -107,8 +125,8 @@ def update(self, preds: Tensor, target: Tensor) -> None: """Update variables required to compute Kendall rank correlation coefficient. Args: - preds: Ordered sequence of data - target: Ordered sequence of data + preds: Sequence of data + target: Sequence of data """ self.preds, self.target = _kendall_corrcoef_update( preds, From 9eda894f819a201cf0f170e23a09fd3710f8ba4b Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 18:26:48 +0200 Subject: [PATCH 24/38] Refactor: Move _check_data_shape_for_corr_coef to utils file --- .../functional/regression/kendall.py | 13 ++------- .../functional/regression/pearson.py | 12 ++------ .../functional/regression/spearman.py | 12 ++------ .../functional/regression/utils.py | 28 +++++++++++++++++++ 4 files changed, 35 insertions(+), 30 deletions(-) create mode 100644 src/torchmetrics/functional/regression/utils.py diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 228f5a23b32..2f8659416e7 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -18,6 +18,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.data import _bincount, dim_zero_cat from torchmetrics.utilities.enums import EnumStr @@ -265,16 +266,8 @@ def _kendall_corrcoef_update( """ # Data checking _check_same_shape(preds, target) - if preds.ndim > 2 or target.ndim > 2: - raise ValueError( - f"Expected both predictions and target to be either 1- or 2-dimensional tensors," - f" but got {target.ndim} and {preds.ndim}." - ) - if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[1]): - raise ValueError( - f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" - f" and {preds.shape[1]}." 
- ) + _check_data_shape_for_corr_coef(preds, target, num_outputs) + if num_outputs == 1: preds = preds.unsqueeze(1) target = target.unsqueeze(1) diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index 5273bbb8388..3ac970c28d3 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -16,6 +16,7 @@ import torch from torch import Tensor +from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef from torchmetrics.utilities.checks import _check_same_shape @@ -44,16 +45,7 @@ def _pearson_corrcoef_update( """ # Data checking _check_same_shape(preds, target) - if preds.ndim > 2 or target.ndim > 2: - raise ValueError( - f"Expected both predictions and target to be either 1- or 2-dimensional tensors," - f" but got {target.ndim} and {preds.ndim}." - ) - if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[-1]): - raise ValueError( - f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" - f" and {preds.ndim}." - ) + _check_data_shape_for_corr_coef(preds, target, num_outputs) n_obs = preds.shape[0] mx_new = (n_prior * mean_x + preds.mean(0) * n_obs) / (n_prior + n_obs) diff --git a/src/torchmetrics/functional/regression/spearman.py b/src/torchmetrics/functional/regression/spearman.py index 2530d12039b..e4e07d39ade 100644 --- a/src/torchmetrics/functional/regression/spearman.py +++ b/src/torchmetrics/functional/regression/spearman.py @@ -16,6 +16,7 @@ import torch from torch import Tensor +from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef from torchmetrics.utilities.checks import _check_same_shape @@ -68,16 +69,7 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) - f" Got preds: {preds.dtype} and target: {target.dtype}." ) _check_same_shape(preds, target) - if preds.ndim > 2 or target.ndim > 2: - raise ValueError( - f"Expected both predictions and target to be either 1- or 2-dimensional tensors," - f" but got {target.ndim} and {preds.ndim}." - ) - if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[-1]): - raise ValueError( - f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" - f" and {preds.ndim}." - ) + _check_data_shape_for_corr_coef(preds, target, num_outputs) return preds, target diff --git a/src/torchmetrics/functional/regression/utils.py b/src/torchmetrics/functional/regression/utils.py new file mode 100644 index 00000000000..82e542ae500 --- /dev/null +++ b/src/torchmetrics/functional/regression/utils.py @@ -0,0 +1,28 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
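A quick sketch of what the shared shape check added below accepts and rejects, assuming this patch is applied so that `torchmetrics.functional.regression.utils` exists (illustrative only):

    import torch
    from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef

    preds, target = torch.rand(10, 3), torch.rand(10, 3)

    _check_data_shape_for_corr_coef(preds, target, num_outputs=3)                   # passes: second dim matches
    _check_data_shape_for_corr_coef(torch.rand(10), torch.rand(10), num_outputs=1)  # passes: 1D input, one output

    try:
        _check_data_shape_for_corr_coef(preds, target, num_outputs=2)               # mismatch with second dim
    except ValueError as err:
        print(err)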
+from torch import Tensor + + +def _check_data_shape_for_corr_coef(preds: Tensor, target: Tensor, num_outputs: int) -> Tensor: + """Check that predictions and target have the correct shape, else raise error.""" + if preds.ndim > 2 or target.ndim > 2: + raise ValueError( + f"Expected both predictions and target to be either 1- or 2-dimensional tensors," + f" but got {target.ndim} and {preds.ndim}." + ) + if (num_outputs == 1 and preds.ndim != 1) or (num_outputs > 1 and num_outputs != preds.shape[1]): + raise ValueError( + f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" + f" and {preds.shape[1]}." + ) From c5553d556c810cf554240dd2e69c41a2b337c58f Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 18:30:06 +0200 Subject: [PATCH 25/38] Fix a typo --- src/torchmetrics/functional/regression/kendall.py | 2 +- src/torchmetrics/regression/kendall.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 2f8659416e7..e509231dfe7 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -391,7 +391,7 @@ def kendall_rank_corrcoef( >>> preds = torch.tensor([2.5, 0.0, 2, 8]) >>> target = torch.tensor([3, -0.5, 2, 1]) >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided') - (tensor(0.3333), tensor(0.4969))) + (tensor(0.3333), tensor(0.4969)) Example (multi output regression with t-test): >>> from torchmetrics.functional.regression import kendall_rank_corrcoef diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 9a89a7b8e9b..e8a6d9e8bba 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -82,7 +82,7 @@ class KendallRankCorrCoef(Metric): >>> target = torch.tensor([3, -0.5, 2, 1]) >>> kendall = KendallRankCorrCoef() >>> kendall(preds, target) - (tensor(0.3333), tensor(0.4969))) + (tensor(0.3333), tensor(0.4969)) Example (multi output regression with t-test): >>> import torch From 2470205e979408c125fc3dbbd40929d45c65dcbf Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 18:47:00 +0200 Subject: [PATCH 26/38] Fix docs example for class metric --- src/torchmetrics/regression/kendall.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index e8a6d9e8bba..e8e0d4c260c 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -80,7 +80,7 @@ class KendallRankCorrCoef(Metric): >>> from torchmetrics.regression import KendallRankCorrCoef >>> preds = torch.tensor([2.5, 0.0, 2, 8]) >>> target = torch.tensor([3, -0.5, 2, 1]) - >>> kendall = KendallRankCorrCoef() + >>> kendall = KendallRankCorrCoef(t_test=True, alternative='two-sided') >>> kendall(preds, target) (tensor(0.3333), tensor(0.4969)) @@ -89,7 +89,7 @@ class KendallRankCorrCoef(Metric): >>> from torchmetrics.regression import KendallRankCorrCoef >>> preds = torch.tensor([[2.5, 0.0], [2, 8]]) >>> target = torch.tensor([[3, -0.5], [2, 1]]) - >>> kendall = KendallRankCorrCoef(num_outputs=2) + >>> kendall = KendallRankCorrCoef(t_test=True, alternative='two-sided', num_outputs=2) >>> kendall(preds, target) (tensor([1., 1.]), tensor([nan, nan])) """ From bb81c433e595d77f827be31a7aae1b76f3a05a78 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 
20:03:06 +0200 Subject: [PATCH 27/38] Try to fix docs --- .../functional/regression/kendall.py | 17 ++++++++++++++--- src/torchmetrics/regression/kendall.py | 17 ++++++++++++++--- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index e509231dfe7..16b05c05cb2 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -346,13 +346,24 @@ def kendall_rank_corrcoef( ) -> Union[Tensor, Tuple[Tensor, Tensor]]: r"""Computes `Kendall Rank Correlation Coefficient`_. + .. math:: + tau_a = \frac{C - D}{C + D} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs. + .. math:: tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}} - tau_c = 2 * \frac{C - D}{n ** 2 * \frac{m - 1}{m}} + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents + a total number of ties. + + .. math:: + tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a number + of observations and :math:`m` is a ``min`` of uniques values in ``preds`` and ``target`` sequence. - where :math:`C` is represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents - a total number of ties. Definition according to `The Treatment of Ties in Ranking Problems`_. + Definitions according to Definition according to `The Treatment of Ties in Ranking Problems`_. Args: preds: Sequence of data diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index e8e0d4c260c..402d41ca519 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -30,13 +30,24 @@ class KendallRankCorrCoef(Metric): r"""Computes `Kendall Rank Correlation Coefficient`_: + .. math:: + tau_a = \frac{C - D}{C + D} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs. + .. math:: tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}} - tau_c = 2 * \frac{C - D}{n ** 2 * \frac{m - 1}{m}} + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents + a total number of ties. + + .. math:: + tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}} + + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a number + of observations and :math:`m` is a ``min`` of uniques values in ``preds`` and ``target`` sequence. - where :math:`C` is represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents - a total number of ties. Definition according to `The Treatment of Ties in Ranking Problems`_. + Definitions according to Definition according to `The Treatment of Ties in Ranking Problems`_. Forward accepts From 1103bbfe35d69ad08e1ea1565d2a9d8ef9e93624 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 20:03:57 +0200 Subject: [PATCH 28/38] . 
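A quick sanity check on the tau-a and tau-b definitions documented in the hunks above. The snippet below is illustrative only and is not part of any patch: it naively counts pairs for the single-output docstring example rather than using the batched logic in kendall.py.

    from itertools import combinations
    from math import sqrt

    preds = [2.5, 0.0, 2, 8]
    target = [3, -0.5, 2, 1]

    concordant = discordant = 0
    for (p1, t1), (p2, t2) in combinations(zip(preds, target), 2):
        s = (p1 - p2) * (t1 - t2)
        if s > 0:
            concordant += 1   # pair ordered the same way in preds and target
        elif s < 0:
            discordant += 1   # pair ordered the opposite way
        # s == 0 would indicate a tie; this example has none

    tau_a = (concordant - discordant) / (concordant + discordant)
    # with no ties, T_preds = T_target = 0 and tau-b reduces to tau-a
    tau_b = (concordant - discordant) / sqrt((concordant + discordant) ** 2)
    print(tau_a, tau_b)  # 0.3333... for both, matching tensor(0.3333) in the docstrings

For tau-c, with n = 4 observations and m = min(4, 4) = 4 unique values, 2 * (4 - 2) / (4 ** 2 * 3 / 4) = 0.3333 as well.
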
--- src/torchmetrics/functional/regression/kendall.py | 4 ++-- src/torchmetrics/regression/kendall.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 16b05c05cb2..0828f839747 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -360,8 +360,8 @@ def kendall_rank_corrcoef( .. math:: tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}} - where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a number - of observations and :math:`m` is a ``min`` of uniques values in ``preds`` and ``target`` sequence. + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a total number + of observations and :math:`m` is a ``min`` of unique values in ``preds`` and ``target`` sequence. Definitions according to Definition according to `The Treatment of Ties in Ranking Problems`_. diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 402d41ca519..5f5cf880f11 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -44,8 +44,8 @@ class KendallRankCorrCoef(Metric): .. math:: tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}} - where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a number - of observations and :math:`m` is a ``min`` of uniques values in ``preds`` and ``target`` sequence. + where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a total number + of observations and :math:`m` is a ``min`` of unique values in ``preds`` and ``target`` sequence. Definitions according to Definition according to `The Treatment of Ties in Ranking Problems`_. From 87d2d9420110272dfe427598c2790aca19e38015 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 20:12:46 +0200 Subject: [PATCH 29/38] Use different link to check if works --- docs/source/links.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index b305ac49fd8..7e5f3a415af 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -94,4 +94,4 @@ .. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf .. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric .. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient -.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303 +.. 
_The Treatment of Ties in Ranking Problems: https://doi.org/10.1093/biomet/33.3.239 From 9b845a57b69afd1a832ab8deeed9f86d146a1ca8 Mon Sep 17 00:00:00 2001 From: stancld Date: Sun, 16 Oct 2022 20:25:54 +0200 Subject: [PATCH 30/38] sphinx: Ignore jstor link for linkcheck as cannot be accesed from python --- docs/source/conf.py | 4 ++++ docs/source/links.rst | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b0c002c025a..2d0968d395a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -385,3 +385,7 @@ def find_source(): """ coverage_skip_undoc_in_source = True + +linkcheck_ignore = [ + "https://www.jstor.org/stable/2332303" # jstor cannot be accessed from python, but link work fine in a local doc +] diff --git a/docs/source/links.rst b/docs/source/links.rst index 7e5f3a415af..b305ac49fd8 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -94,4 +94,4 @@ .. _Rényi divergence: https://static.renyi.hu/renyi_cikkek/1961_on_measures_of_entropy_and_information.pdf .. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric .. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient -.. _The Treatment of Ties in Ranking Problems: https://doi.org/10.1093/biomet/33.3.239 +.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303 From 64d427f794265841f95f134d91f34039d2a85e57 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 24 Oct 2022 12:34:53 +0200 Subject: [PATCH 31/38] Apply suggestions from code review Co-authored-by: Nicki Skafte Detlefsen --- src/torchmetrics/functional/regression/kendall.py | 4 ++-- src/torchmetrics/regression/kendall.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 0828f839747..b66f3143975 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -366,8 +366,8 @@ def kendall_rank_corrcoef( Definitions according to Definition according to `The Treatment of Ties in Ranking Problems`_. Args: - preds: Sequence of data - target: Sequence of data + preds: Sequence of data of either shape ``(N,)`` or ``(N,d)`` + target: Sequence of data of either shape ``(N,)`` or ``(N,d)`` variant: Indication of which variant of Kendall's tau to be used t_test: Indication whether to run t-test alternative: Alternative hypothesis for t-test. Possible values: diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 5f5cf880f11..819544a1560 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -51,8 +51,8 @@ class KendallRankCorrCoef(Metric): Forward accepts - - ``preds``: Sequence of data - - ``target``: Sequence of data + - ``preds`` (float tensor): Sequence of data of either shape ``(N,)`` or ``(N,d)`` + - ``target`` (float tensor): Sequence of data of either shape ``(N,)`` or ``(N,d)`` Args: variant: Indication of which variant of Kendall's tau to be used @@ -136,8 +136,8 @@ def update(self, preds: Tensor, target: Tensor) -> None: """Update variables required to compute Kendall rank correlation coefficient. 
Args: - preds: Sequence of data - target: Sequence of data + preds: Sequence of data of either shape ``(N,)`` or ``(N,d)`` + target: Sequence of data of either shape ``(N,)`` or ``(N,d)`` """ self.preds, self.target = _kendall_corrcoef_update( preds, From 330bf4a88b5b845d3e72f48c9082fdb5e9a6905c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 27 Oct 2022 14:10:39 +0200 Subject: [PATCH 32/38] try something --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 896b5723eb7..4fcff9a66df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ module = [ "torchmetrics.functional.image.uqi", "torchmetrics.functional.regression.cosine_similarity", "torchmetrics.functional.regression.explained_variance", + "torchmetrics.functional.regression.kendall.py", "torchmetrics.functional.regression.kl_divergence", "torchmetrics.functional.regression.r2", "torchmetrics.functional.regression.wmape", @@ -100,6 +101,7 @@ module = [ "torchmetrics.image.tv", "torchmetrics.image.uqi", "torchmetrics.metric", + "torchmetrics.regression.kendall.py", "torchmetrics.regression.kl_divergence", "torchmetrics.regression.log_mse", "torchmetrics.regression.mae", From 3ffe505f1c53b874bb958a808d86d3d8fa07ef82 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 27 Oct 2022 14:15:40 +0200 Subject: [PATCH 33/38] try again --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 4fcff9a66df..5cefee90e74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ module = [ "torchmetrics.functional.regression.kendall.py", "torchmetrics.functional.regression.kl_divergence", "torchmetrics.functional.regression.r2", + "torchmetrics.functional.regression.utils.py", "torchmetrics.functional.regression.wmape", "torchmetrics.functional.retrieval.r_precision", "torchmetrics.functional.text.squad", From 558dc10e854105e3e5344222eee95a9dba5f12de Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 31 Oct 2022 13:24:06 +0100 Subject: [PATCH 34/38] remove mypy ignore --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5cefee90e74..896b5723eb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,10 +87,8 @@ module = [ "torchmetrics.functional.image.uqi", "torchmetrics.functional.regression.cosine_similarity", "torchmetrics.functional.regression.explained_variance", - "torchmetrics.functional.regression.kendall.py", "torchmetrics.functional.regression.kl_divergence", "torchmetrics.functional.regression.r2", - "torchmetrics.functional.regression.utils.py", "torchmetrics.functional.regression.wmape", "torchmetrics.functional.retrieval.r_precision", "torchmetrics.functional.text.squad", @@ -102,7 +100,6 @@ module = [ "torchmetrics.image.tv", "torchmetrics.image.uqi", "torchmetrics.metric", - "torchmetrics.regression.kendall.py", "torchmetrics.regression.kl_divergence", "torchmetrics.regression.log_mse", "torchmetrics.regression.mae", From fefe50529c6e75169f3b69b84b2c7900b1495f70 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 31 Oct 2022 13:35:30 +0100 Subject: [PATCH 35/38] remove comments --- src/torchmetrics/functional/regression/kendall.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index b66f3143975..6359f105d11 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ 
b/src/torchmetrics/functional/regression/kendall.py @@ -186,7 +186,7 @@ def _calculate_tau( tau = con_min_dis_pairs / (concordant_pairs + discordant_pairs) elif variant == _MetricVariant.B: total_combinations: Tensor = n_total * (n_total - 1) // 2 - denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) # type: ignore (is Tensor) + denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) # type: ignore tau = con_min_dis_pairs / torch.sqrt(denominator) else: preds_unique = torch.tensor([len(p.unique()) for p in preds.T], dtype=preds.dtype, device=preds.device) @@ -231,8 +231,8 @@ def _calculate_p_value( else: m = n_total * (n_total - 1) t_value_denominator: Tensor = (t_value_denominator_base - preds_ties_p2 - target_ties_p2) / 18 - t_value_denominator += (2 * preds_ties * target_ties) / m # type: ignore (is Tensor) - t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) # type: ignore (is Tensor) + t_value_denominator += (2 * preds_ties * target_ties) / m # type: ignore + t_value_denominator += preds_ties_p1 * target_ties_p1 / (9 * m * (n_total - 2)) # type: ignore t_value = con_min_dis_pairs / torch.sqrt(t_value_denominator) if alternative == _TestAlternative.TWO_SIDED: From 1a710a21c2c498d89bcebe34f591a715f198c5cb Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 31 Oct 2022 13:43:43 +0100 Subject: [PATCH 36/38] fix some typing --- src/torchmetrics/functional/regression/kendall.py | 2 +- src/torchmetrics/functional/regression/utils.py | 2 +- src/torchmetrics/regression/kendall.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 6359f105d11..6b45406c293 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -186,7 +186,7 @@ def _calculate_tau( tau = con_min_dis_pairs / (concordant_pairs + discordant_pairs) elif variant == _MetricVariant.B: total_combinations: Tensor = n_total * (n_total - 1) // 2 - denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) # type: ignore + denominator = (total_combinations - preds_ties) * (total_combinations - target_ties) tau = con_min_dis_pairs / torch.sqrt(denominator) else: preds_unique = torch.tensor([len(p.unique()) for p in preds.T], dtype=preds.dtype, device=preds.device) diff --git a/src/torchmetrics/functional/regression/utils.py b/src/torchmetrics/functional/regression/utils.py index 82e542ae500..b9c80987ddc 100644 --- a/src/torchmetrics/functional/regression/utils.py +++ b/src/torchmetrics/functional/regression/utils.py @@ -14,7 +14,7 @@ from torch import Tensor -def _check_data_shape_for_corr_coef(preds: Tensor, target: Tensor, num_outputs: int) -> Tensor: +def _check_data_shape_for_corr_coef(preds: Tensor, target: Tensor, num_outputs: int) -> None: """Check that predictions and target have the correct shape, else raise error.""" if preds.ndim > 2 or target.ndim > 2: raise ValueError( diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 819544a1560..9ff4070a23f 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -132,7 +132,7 @@ def __init__( self.add_state("preds", [], dist_reduce_fx="cat") self.add_state("target", [], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Tensor) 
-> None: # type: ignore """Update variables required to compute Kendall rank correlation coefficient. Args: From 04e420c089fc0446aef94cf059f7af610c1f9797 Mon Sep 17 00:00:00 2001 From: stancld Date: Mon, 31 Oct 2022 19:59:39 +0100 Subject: [PATCH 37/38] Fix mypy issues --- src/torchmetrics/functional/regression/kendall.py | 5 ++--- src/torchmetrics/regression/kendall.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 6b45406c293..d0cd8d73f52 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import List, Optional, Tuple, Union import torch @@ -32,7 +31,7 @@ class _MetricVariant(EnumStr): C = "c" @classmethod - def from_str(cls, value: str) -> Optional["EnumStr"]: + def from_str(cls, value: Literal["a", "b", "c"]) -> "_MetricVariant": """ Raises: ValueError: @@ -54,7 +53,7 @@ class _TestAlternative(EnumStr): GREATER = "greater" @classmethod - def from_str(cls, value: str) -> Optional["EnumStr"]: + def from_str(cls, value: Literal["two-sided", "less", "greater"]) -> "_TestAlternative": """ Raises: ValueError: diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 9ff4070a23f..819544a1560 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -132,7 +132,7 @@ def __init__( self.add_state("preds", [], dist_reduce_fx="cat") self.add_state("target", [], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + def update(self, preds: Tensor, target: Tensor) -> None: """Update variables required to compute Kendall rank correlation coefficient. Args: From 2fae0f9b40b96a5c522e02b973c2c8d5345ec5b0 Mon Sep 17 00:00:00 2001 From: stancld Date: Mon, 31 Oct 2022 20:12:04 +0100 Subject: [PATCH 38/38] Nudge mypy a bit --- src/torchmetrics/functional/regression/kendall.py | 10 +++++----- src/torchmetrics/regression/kendall.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index d0cd8d73f52..698cb54ec73 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -31,7 +31,7 @@ class _MetricVariant(EnumStr): C = "c" @classmethod - def from_str(cls, value: Literal["a", "b", "c"]) -> "_MetricVariant": + def from_str(cls, value: Literal["a", "b", "c"]) -> "_MetricVariant": # type: ignore[override] """ Raises: ValueError: @@ -41,7 +41,7 @@ def from_str(cls, value: Literal["a", "b", "c"]) -> "_MetricVariant": enum_key = super().from_str(value) if enum_key is not None and enum_key in _allowed_variants: - return enum_key + return enum_key # type: ignore[return-value] # use override raise ValueError(f"Invalid metric variant. 
Expected one of {_allowed_variants}, but got {enum_key}.") @@ -53,7 +53,7 @@ class _TestAlternative(EnumStr): GREATER = "greater" @classmethod - def from_str(cls, value: Literal["two-sided", "less", "greater"]) -> "_TestAlternative": + def from_str(cls, value: Literal["two-sided", "less", "greater"]) -> "_TestAlternative": # type: ignore[override] """ Raises: ValueError: @@ -63,7 +63,7 @@ def from_str(cls, value: Literal["two-sided", "less", "greater"]) -> "_TestAlter enum_key = super().from_str(value.replace("-", "_")) if enum_key is not None and enum_key in _allowed_alternatives: - return enum_key + return enum_key # type: ignore[return-value] # use override raise ValueError(f"Invalid test alternative. Expected one of {_allowed_alternatives}, but got {enum_key}.") @@ -416,7 +416,7 @@ def kendall_rank_corrcoef( raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.") _variant = _MetricVariant.from_str(variant) - _alternative = _TestAlternative.from_str(alternative) if t_test else None + _alternative = _TestAlternative.from_str(alternative) if t_test and alternative else None _preds, _target = _kendall_corrcoef_update( preds, target, [], [], num_outputs=1 if preds.ndim == 1 else preds.shape[-1] diff --git a/src/torchmetrics/regression/kendall.py b/src/torchmetrics/regression/kendall.py index 819544a1560..43abfbbac83 100644 --- a/src/torchmetrics/regression/kendall.py +++ b/src/torchmetrics/regression/kendall.py @@ -126,7 +126,7 @@ def __init__( raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.") self.variant = _MetricVariant.from_str(variant) - self.alternative = _TestAlternative.from_str(alternative) if t_test else None + self.alternative = _TestAlternative.from_str(alternative) if t_test and alternative else None self.num_outputs = num_outputs self.add_state("preds", [], dist_reduce_fx="cat")
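
For reference, here is a minimal usage sketch of the API this series converges on. It is illustrative only and not part of any patch; the import paths and expected values mirror the docstring examples in the patches above and assume the final state of the branch.

    >>> import torch
    >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
    >>> from torchmetrics.regression import KendallRankCorrCoef
    >>> preds = torch.tensor([2.5, 0.0, 2, 8])
    >>> target = torch.tensor([3, -0.5, 2, 1])
    >>> # functional interface: tau-b together with a two-sided t-test
    >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
    (tensor(0.3333), tensor(0.4969))
    >>> # module interface: multi-output regression, no t-test
    >>> kendall = KendallRankCorrCoef(num_outputs=2)
    >>> kendall(torch.tensor([[2.5, 0.0], [2, 8]]), torch.tensor([[3, -0.5], [2, 1]]))
    tensor([1., 1.])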