
XLA support #1466

Closed · mfatih7 opened this issue Jan 28, 2023 · 13 comments · Fixed by #1471

Labels: enhancement (New feature or request)

Comments

@mfatih7 commented Jan 28, 2023

Hello

Until now, I have been using torchmetrics in my training scripts running on GPUs.
Now I want to use Google Tensor Processing Units (TPUs) in my work.
For the last few days, I have been observing that torchmetrics is not compatible with the XLA library.
The operations torchmetrics uses need to be lowered for TPU support.

best regards

mfatih7 added the enhancement (New feature or request) label on Jan 28, 2023
@github-actions
Hi! Thanks for your contribution, great first issue!

@justusschock (Member)

Hi @mfatih7,
Thanks for the issue. Would you be interested in working on this, or could you at least give us a hint about what isn't working on TPU?

@mfatih7 (Author) commented Jan 29, 2023

Hi @justusschock

I can give some information about the situation on the TPU.
I have a PyTorch project in which I use different kinds of models for classification.
In the program flow, I import your module with

from torchmetrics import ConfusionMatrix

and create an instance with

confmat = ConfusionMatrix(task="binary", num_classes=2).to(device)

When the device is a GPU, it works without any problem.
When the device is a TPU, confmat causes XLA compilations.

Actually, this is a normal situation.
The PyTorch XLA team welcomes lowering requests for torch functions.
The functions causing the compilations are _unique2 and bincount.
You can check my issue in the PyTorch XLA repository.

Maybe you need to implement your module without these functions.

I would be happy to give more information.
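
A minimal sketch of how the unlowered ops can be surfaced (assuming a TPU runtime with torch_xla installed; the exact counter names may differ between versions):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torchmetrics import ConfusionMatrix

device = xm.xla_device()
confmat = ConfusionMatrix(task="binary", num_classes=2).to(device)

preds = torch.randint(0, 2, (64,), device=device)
target = torch.randint(0, 2, (64,), device=device)
confmat(preds, target)

# Ops without an XLA lowering show up as "aten::..." counters in the report,
# here aten::_unique2 and aten::bincount.
print(met.metrics_report())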

@mfatih7 (Author) commented Jan 29, 2023

To prevent XLA compilations temporarily, I am using the simple function below.

import torch


def get_confusion_matrix_for_xla(outputs, labels):
    # Build a 2x2 binary confusion matrix from plain comparisons and sums only,
    # so no unlowered ops (bincount/_unique2) are hit on XLA devices.
    confusion_matrix = torch.zeros(2, 2, dtype=torch.int64, device=outputs.device)

    confusion_matrix[0, 0] = torch.sum(torch.logical_and(labels < 0.5, outputs < 0.5))    # true negatives
    confusion_matrix[0, 1] = torch.sum(torch.logical_and(labels < 0.5, outputs >= 0.5))   # false positives
    confusion_matrix[1, 0] = torch.sum(torch.logical_and(labels >= 0.5, outputs < 0.5))   # false negatives
    confusion_matrix[1, 1] = torch.sum(torch.logical_and(labels >= 0.5, outputs >= 0.5))  # true positives

    return confusion_matrix
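
A hypothetical usage example (model and inputs are assumptions; outputs are treated as probabilities thresholded at 0.5):

outputs = model(inputs).sigmoid().flatten()  # hypothetical model producing logits
labels = labels.float().flatten()            # binary targets as 0.0 / 1.0
confusion_matrix = get_confusion_matrix_for_xla(outputs, labels)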

@SkafteNicki (Member)

Hi @mfatih7,
Is it possible to reproduce this behaviour in a Colab notebook?
While we cannot rewrite our whole framework to support XLA, we could probably implement fallback solutions for specific devices. For example, _bincount is a function that is not supported by the MPS accelerator yet, and therefore we have a fallback solution for that:
https://github.com/Lightning-AI/metrics/blob/5d4ffe01aa09b7108f7e0e4034748bdfd64bf5f9/src/torchmetrics/utilities/data.py#L206-L228

@mfatih7 (Author) commented Jan 30, 2023

Hi @SkafteNicki

I could not get the warnings with code written purely in the Colab notebook, but here are a .py file and a Colab notebook where you can easily see the warnings.

I can give more support if needed.
More debugging options are available.

Do not forget to select TPU in the Colab runtime settings.

@SkafteNicki (Member)

Hi @mfatih7,
Could you try re-running it with the changes from this branch:
https://github.com/Lightning-AI/metrics/tree/xla_test
and additionally change the initialization of the metric to the line below?

confmat = ConfusionMatrix(task="binary", num_classes=2, validate_args=False).to(device)


@mfatih7 (Author) commented Jan 31, 2023

OK, but how can I install this version in my Colab notebook?
I have been using !pip install torchmetrics at the top of my notebooks.

@justusschock (Member)

Change !pip install torchmetrics to !pip install git+https://github.com/Lightning-AI/metrics@xla_test to install from this branch.

@mfatih7 (Author) commented Jan 31, 2023

OK

I don't see any recompilations due to torchmetrics now.
Aside from the parameter change in the instantiation, did you also change some parts of the source code?

Will you merge this into the main torchmetrics branch?

Have you considered making your library available without installation on Colab?

@SkafteNicki (Member)

@mfatih7, let me explain what I did:

  1. By setting validate_args=False, you skip an internal check that the input is in the right format. The check uses torch.unique, which XLA does not support.
  2. Secondly, in the branch you used, I implemented some logic for bincount so that if XLA is detected we fall back to a simple for-loop. This works on XLA, but note that it can be significantly slower if you have a large number of classes (a simplified sketch of the idea follows below).
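
For illustration only, a minimal sketch of that kind of device-conditioned fallback; it is simplified and not the exact code in the branch:

import torch


def _bincount_with_xla_fallback(x: torch.Tensor, minlength: int) -> torch.Tensor:
    # Sketch of the idea: avoid torch.bincount on XLA devices, where it has no lowering.
    if x.device.type == "xla":
        # One comparison + sum per bin: works on XLA, but O(minlength * len(x)),
        # so it can be much slower for a large number of classes.
        return torch.stack([(x == i).sum() for i in range(minlength)])
    return torch.bincount(x, minlength=minlength)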

I think we can include the change, but we are not going to officially support XLA for now.

SkafteNicki mentioned this issue on Jan 31, 2023
@mfatih7 (Author) commented Jan 31, 2023

Thank you

We can close this issue if you want.

I hope to hear about any updates in the future.

@SkafteNicki (Member)

Hi @mfatih7,
You are welcome; it will be closed when PR #1471 is merged.
