From 2ce0efb9ac48ab53a017932bd9bb815497aac985 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 31 Jan 2023 19:48:05 +0100 Subject: [PATCH] Fixed `bincount` on XLA (#1471) --- CHANGELOG.md | 3 +++ src/torchmetrics/utilities/data.py | 4 ++-- src/torchmetrics/utilities/imports.py | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6b4834b741..f1951c1b9a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471)) + + - diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index b5b72f6a9e1..214ee5a954b 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -16,7 +16,7 @@ import torch from torch import Tensor -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _XLA_AVAILABLE METRIC_EPS = 1e-6 @@ -220,7 +220,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: """ if minlength is None: minlength = len(torch.unique(x)) - if torch.are_deterministic_algorithms_enabled() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: + if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: output = torch.zeros(minlength, device=x.device, dtype=torch.long) for i in range(minlength): output[i] = (x == i).sum() diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 3f62f053d4d..cee6e058821 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -42,3 +42,4 @@ _FAST_BSS_EVAL_AVAILABLE: bool = package_available("fast_bss_eval") _MATPLOTLIB_AVAILABLE: bool = package_available("matplotlib") _MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing") +_XLA_AVAILABLE: bool = package_available("torch_xla")