Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NPU backend support for val and inference #2109

Merged
merged 2 commits into from
Oct 19, 2024

Conversation

MengqingCao
Copy link
Contributor

I am a user of NPU. When I used TIMM recently, I found that it does not support NPU natively. It's pleasure to see that someone has made some contributions on leveraging NPU to TIMM #2102. But it currently only offers the feature of using NPU during training. This PR extends NPU support to the validate and inference entries, thus addressing this limitation.

Specify the device as "npu", then you can use NPU as accelerator during inferencing and validating.

It is tested on:

  • model: tiny_vit_21m_512
  • dataset: the val subset of ImageNet-1K

Validate Scripts

python validate.py ../open_clip/data/ImageNet-1000/val/ --device npu --model ./model_ckpts/tiny_vit_21m_512 --batch-size 64 --pretrained

ScreenShot

It shows the validation results on val subset of ImageNet-1K are as following:

top-1 acc top-5 acc
86.040% 97.750%

image-20240314164130874

Inference Scripts

python inference.py ./data/ --device npu --batch-size 64 --model ./model_ckpts/tiny_vit_21m_512 --label-type detail --topk 5

ScreenShot

image-20240314171146196

results

Here offers some results of predicting the top-5 classification results by inferencing on tiny_vit_21m_512. Everything goes well on npu.

image-20240314170101799

@MengqingCao MengqingCao changed the title add npu support for val and inference Add NPU backend support for val and inference Mar 14, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@MengqingCao
Copy link
Contributor Author

cc @rwightman

@rwightman
Copy link
Collaborator

@MengqingCao see #2138 ... I need a better design to centralize device specific acccelerator module loading, etc instead of spreading it out across many files, it's not a sustainable approach.

Also, another challenge here is I don't have easy access to many potential accelerators so definitely need help testing as I can't realistically run my normal CI or tests with them as I do across my github and local CI right now...

@MengqingCao
Copy link
Contributor Author

@MengqingCao see #2138 ... I need a better design to centralize device specific acccelerator module loading, etc instead of spreading it out across many files, it's not a sustainable approach.

Also, another challenge here is I don't have easy access to many potential accelerators so definitely need help testing as I can't realistically run my normal CI or tests with them as I do across my github and local CI right now...

Good day! @rwightman, thanks for your reply.

for your first concern, I agree that importing the device specific modules in many files is not a smart way to enable the devices. I was inspired by the way of centralizing device related modules loading in train.py#L415 that we could do a autoloading when the whole lib is initing. Because the way in train.py#L415 also needs to do redundant processing in many files.

My initial idea was to load the device accelerator module via a specific environment variable (e.g. TIMM_DEVICE_EXT). This variable is set in timm/init.py by reading the configuration infos in a specific file (e.g., a json file), and then the module is preloaded according to this variable, so that device-related modules import can be activated within the entire TIMM library, instead of having to import them separately everywhere. But I think the device-specific hardcoding has to be modified.

For your second concern, making a mechine with Ascend NPU available to community is on my to-do list, so that we could ensure that the correctness of the code could be verified and maintained.

Let me know if you have any ideas or confusion!

@MengqingCao
Copy link
Contributor Author

Hi, @rwightman. I have just committed the code implementation of the above solution, please review it, thx!

@MengqingCao
Copy link
Contributor Author

Hi, @rwightman I'm sorry for bothering you. Could you help reviewing the latest code in this PR? Thanks in adavance!

@rwightman
Copy link
Collaborator

@MengqingCao I don't really have any way to test this so don't want to have support for other hardware like this touching as many files. Same thing for Intel and other hardware that requires extra imports, etc. PyTorch 2.4 should have a mechanism for auto-importing device dependencies so I'll probably wait for that ....

@MengqingCao
Copy link
Contributor Author

@MengqingCao I don't really have any way to test this so don't want to have support for other hardware like this touching as many files. Same thing for Intel and other hardware that requires extra imports, etc. PyTorch 2.4 should have a mechanism for auto-importing device dependencies so I'll probably wait for that ....

Thanks a lot for your reply! I‘m applying a NPU machine for CI, thus you can attach NPU for testing. The latest code also avoid touching too many files. However, as far as I know, the auto-importing maybe postponed to PyTorch 2.5
So if you don't mind being a little late, maybe we could wait for PyTorch supportting auto-importing device dependencies

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
@MengqingCao
Copy link
Contributor Author

MengqingCao commented Oct 16, 2024

@rwightman Good day! I'm happy to tell you that PyTorch has supported autoloading device-related dependencies through pytorch/pytorch#127074. This feature will be included in torch 2.5.0. The latest commit is tested on torch 2.5 dev version, and everything goes well on Ascend NPU.

Plz review the code and if these changes are acceptable, maybe we could merge it as soon as PyTorch 2.5 is released?

@rwightman
Copy link
Collaborator

@MengqingCao thanks, this is looking better, two issues flagged above, will merge once 2.25 is out

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
@MengqingCao
Copy link
Contributor Author

@MengqingCao thanks, this is looking better, two issues flagged above, will merge once 2.25 is out

Thanks!

@rwightman
Copy link
Collaborator

@MengqingCao I tried these additions on pytorch 2.5 and 2.4 to ensure nothing broke in normal use. Seems fine.

I noticed there were some other possible errors where different devices might not be supported so I did a bit of cleanup on #2308 ... I think that would be needed for the grad scaler & amp to work fully with NPU?

I don't have an NPU, could you confirm that your changes here + my new ones on that branch work well?

@rwightman
Copy link
Collaborator

@MengqingCao I merged the contents of this branch into the device_amp_cleanup mentioned in comment above. It'd be great if you could try the combination before I merge.

@MengqingCao
Copy link
Contributor Author

@rwightman The cleanup you did is necessary for NPU and makes the code cleaner. I have tested #2308 on my NPU device and everything works fine. Thanks!

@rwightman rwightman merged commit 81b59fa into huggingface:main Oct 19, 2024
22 checks passed
@rwightman
Copy link
Collaborator

@MengqingCao all merged, I'll tweet about the torch autoload support and this addition for NPU in a day or so, and will look at OpenCLIP merge and test tomorrow or Monday. Feel free to let other Ascend users know this should work now.

@MengqingCao
Copy link
Contributor Author

Thanks! I'm excited to announce this good news to more TIMM & Ascend users :-)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants