Skip to content

Commit

Permalink
Fixed bincount on XLA (#1471)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Jan 31, 2023
1 parent 012c6a4 commit 2ce0efb
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -40,6 +40,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471))


-


Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/utilities/data.py
Expand Up @@ -16,7 +16,7 @@
import torch
from torch import Tensor

from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _XLA_AVAILABLE

METRIC_EPS = 1e-6

Expand Down Expand Up @@ -220,7 +220,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""
if minlength is None:
minlength = len(torch.unique(x))
if torch.are_deterministic_algorithms_enabled() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
output = torch.zeros(minlength, device=x.device, dtype=torch.long)
for i in range(minlength):
output[i] = (x == i).sum()
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Expand Up @@ -42,3 +42,4 @@
_FAST_BSS_EVAL_AVAILABLE: bool = package_available("fast_bss_eval")
_MATPLOTLIB_AVAILABLE: bool = package_available("matplotlib")
_MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing")
_XLA_AVAILABLE: bool = package_available("torch_xla")

0 comments on commit 2ce0efb

Please sign in to comment.