Skip to content

Commit

Permalink
Fixed bincount on XLA (#1471)
Browse files Browse the repository at this point in the history
(cherry picked from commit 2ce0efb)
  • Loading branch information
SkafteNicki authored and Borda committed Feb 20, 2023
1 parent 2f8ec7b commit 416b079
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471))


-


Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import Tensor

from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _XLA_AVAILABLE

METRIC_EPS = 1e-6

Expand Down Expand Up @@ -220,7 +220,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""
if minlength is None:
minlength = len(torch.unique(x))
if torch.are_deterministic_algorithms_enabled() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
output = torch.zeros(minlength, device=x.device, dtype=torch.long)
for i in range(minlength):
output[i] = (x == i).sum()
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,4 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
_PYSTOI_AVAILABLE: bool = _package_available("pystoi")
_FAST_BSS_EVAL_AVAILABLE: bool = _package_available("fast_bss_eval")
_MULTIPROCESSING_AVAILABLE: bool = _package_available("multiprocessing")
_XLA_AVAILABLE: bool = _package_available("torch_xla")

0 comments on commit 416b079

Please sign in to comment.