From 90d9bb32a34268841906b636b65bf438b7fc4446 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 31 Jan 2023 11:47:17 +0100 Subject: [PATCH 1/3] try --- src/torchmetrics/utilities/data.py | 4 ++-- src/torchmetrics/utilities/imports.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index b5b72f6a9e1..214ee5a954b 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -16,7 +16,7 @@ import torch from torch import Tensor -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _XLA_AVAILABLE METRIC_EPS = 1e-6 @@ -220,7 +220,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: """ if minlength is None: minlength = len(torch.unique(x)) - if torch.are_deterministic_algorithms_enabled() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: + if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: output = torch.zeros(minlength, device=x.device, dtype=torch.long) for i in range(minlength): output[i] = (x == i).sum() diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 3f62f053d4d..cee6e058821 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -42,3 +42,4 @@ _FAST_BSS_EVAL_AVAILABLE: bool = package_available("fast_bss_eval") _MATPLOTLIB_AVAILABLE: bool = package_available("matplotlib") _MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing") +_XLA_AVAILABLE: bool = package_available("torch_xla") From ca66422dcbd2cf311d0dc028a6c7aa4ef27f1b31 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 31 Jan 2023 14:38:43 +0100 Subject: [PATCH 2/3] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b9cdc0b7679..5a479ec8180 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix dtype checking in `PrecisionRecallCurve` for `target` tensor ([#1457](https://github.com/Lightning-AI/metrics/pull/1457)) +- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471)) + + ## [0.11.0] - 2022-11-30 ### Added From 0d897657ba084ea423b5830791c17fd0a6d8900a Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 31 Jan 2023 14:48:47 +0100 Subject: [PATCH 3/3] chlog --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf279a38002..f1951c1b9a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471)) + + - @@ -54,9 +57,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed dtype checking in `PrecisionRecallCurve` for `target` tensor ([#1457](https://github.com/Lightning-AI/metrics/pull/1457)) -- Fix compatibility between XLA in `_bincount` function ([#1471](https://github.com/Lightning-AI/metrics/pull/1471)) - - ## [0.11.0] - 2022-11-30 ### Added