Skip to content

Commit

Permalink
ruff: fix couple of issues in src (#1586)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Mar 4, 2023
1 parent 7c885d0 commit 7f01332
Show file tree
Hide file tree
Showing 10 changed files with 14 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/group_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, groups: torch.Tensor
def compute(
self,
) -> Dict[str, torch.Tensor]:
"""Computes tp, fp, tn and fn rates based on inputs passed in to ``update`` previously."""
"""Compute tp, fp, tn and fn rates based on inputs passed in to ``update`` previously."""
results = torch.stack((self.tp, self.fp, self.tn, self.fn), dim=1)

return {f"group_{i}": group / group.sum() for i, group in enumerate(results)}
Expand Down Expand Up @@ -247,7 +247,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, groups: Optional[tor
def compute(
self,
) -> Dict[str, torch.Tensor]:
"""Computes fairness criteria based on inputs passed in to ``update`` previously."""
"""Compute fairness criteria based on inputs passed in to ``update`` previously."""
if self.task == "demographic_parity":
return _compute_binary_demographic_parity(self.tp, self.fp, self.tn, self.fn)

Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/classification/group_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _binary_groups_stat_scores(
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Computes the true/false positives and true/false negatives rates for binary classification by group.
"""Compute the true/false positives and true/false negatives rates for binary classification by group.
Related to `Type I and Type II errors`_.
"""
Expand Down Expand Up @@ -113,7 +113,7 @@ def binary_groups_stat_rates(
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Dict[str, torch.Tensor]:
r"""Computes the true/false positives and true/false negatives rates for binary classification by group.
r"""Compute the true/false positives and true/false negatives rates for binary classification by group.
Related to `Type I and Type II errors`_.
Expand Down Expand Up @@ -331,7 +331,7 @@ def binary_fairness(
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Dict[str, torch.Tensor]:
r"""This function is a simple wrapper to get the task specific versions of the metric.
r"""Compute either `Demographic parity`_ and `Equal opportunity`_ ratio for binary classification problems.
This is done by setting the ``task`` argument to either ``'demographic_parity'``, ``'equal_opportunity'``
or ``all``. See the documentation of :func:`_compute_binary_demographic_parity`
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _reflection_pad_2d(inputs: Tensor, pad: int, outer_pad: int = 0) -> Tensor:


def _uniform_filter(inputs: Tensor, window_size: int) -> Tensor:
"""Applies uniform filtew with a window of a given size over the input image.
"""Apply uniform filter with a window of a given size over the input image.
Args:
inputs: Input image
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/image/rase.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
def _rase_update(
preds: Tensor, target: Tensor, window_size: int, rmse_map: Tensor, target_sum: Tensor, total_images: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
"""Calculates the sum of RMSE map values for the batch of examples and update intermediate states.
"""Calculate the sum of RMSE map values for the batch of examples and update intermediate states.
Args:
preds: Deformed image
Expand Down Expand Up @@ -67,7 +67,7 @@ def _rase_compute(rmse_map: Tensor, target_sum: Tensor, total_images: Tensor, wi


def relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: int = 8) -> Tensor:
"""Computes Relative Average Spectral Error (RASE) (RelativeAverageSpectralError_).
"""Compute Relative Average Spectral Error (RASE) (RelativeAverageSpectralError_).
Args:
preds: Deformed image
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/image/rmse_sw.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _rmse_sw_update(
rmse_map: Optional[Tensor],
total_images: Optional[Tensor],
) -> Tuple[Tensor, Tensor, Tensor]:
"""Calculates the sum of RMSE values and RMSE map for the batch of examples and update intermediate states.
"""Calculate the sum of RMSE values and RMSE map for the batch of examples and update intermediate states.
Args:
preds: Deformed image
Expand Down Expand Up @@ -89,7 +89,7 @@ def _rmse_sw_update(
def _rmse_sw_compute(
rmse_val_sum: Optional[Tensor], rmse_map: Tensor, total_images: Tensor
) -> Tuple[Optional[Tensor], Tensor]:
"""Computes RMSE from the aggregated RMSE value. Optionally also computes the mean value for RMSE map.
"""Compute RMSE from the aggregated RMSE value. Optionally also computes the mean value for RMSE map.
Args:
rmse_val_sum: Sum of RMSE over all examples
Expand All @@ -109,7 +109,7 @@ def _rmse_sw_compute(
def root_mean_squared_error_using_sliding_window(
preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False
) -> Union[Optional[Tensor], Tuple[Optional[Tensor], Tensor]]:
"""Computes Root Mean Squared Error (RMSE) using sliding window.
"""Compute Root Mean Squared Error (RMSE) using sliding window.
Args:
preds: Deformed image
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/rase.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
self.target.append(target)

def compute(self) -> Tensor:
"""Computes Relative Average Spectral Error (RASE)."""
"""Compute Relative Average Spectral Error (RASE)."""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return relative_average_spectral_error(preds, target, self.window_size)
2 changes: 1 addition & 1 deletion src/torchmetrics/image/rmse_sw.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
)

def compute(self) -> Optional[Tensor]:
"""Computes Root Mean Squared Error (using sliding window) and potentially return RMSE map."""
"""Compute Root Mean Squared Error (using sliding window) and potentially return RMSE map."""
assert self.rmse_map is not None
rmse, _ = _rmse_sw_compute(self.rmse_val_sum, self.rmse_map, self.total_images)
return rmse
2 changes: 1 addition & 1 deletion src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

@contextmanager
def style_change(*args: Any, **kwargs: Any) -> Generator:
"""Default no-ops decorator if matplotlib is not installed."""
"""No-ops decorator if matplotlib is not installed."""
yield


Expand Down
2 changes: 0 additions & 2 deletions tests/unittests/regression/test_mean_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ class TestMeanError(MetricTester):
def test_mean_error_class(
self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args, ddp
):
# todo: `metric_functional` is unused
self.run_class_metric_test(
ddp=ddp,
preds=preds,
Expand All @@ -186,7 +185,6 @@ def test_mean_error_class(
def test_mean_error_functional(
self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args
):
# todo: `metric_class` is unused
self.run_functional_metric_test(
preds=preds,
target=target,
Expand Down
1 change: 0 additions & 1 deletion tests/unittests/regression/test_r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def test_r2(self, adjusted, multioutput, preds, target, ref_metric, num_outputs,
)

def test_r2_functional(self, adjusted, multioutput, preds, target, ref_metric, num_outputs):
# todo: `num_outputs` is unused
self.run_functional_metric_test(
preds,
target,
Expand Down

0 comments on commit 7f01332

Please sign in to comment.