[CD] Fix slim-wheel cuda_nvrtc import problem #145614


Merged: 1 commit merged into release/2.6 from cherry-pick-145582-by-pytorch_bot_bot_ on Jan 24, 2025

Conversation

pytorchbot
Collaborator

Similar fix as: #144816

Fixes: #145580

Found during testing of #138340

Please note that both nvrtc and nvjitlink exist for CUDA 11.8, 12.4, and 12.6, hence we can safely remove the if statement. Preloading can apply to all supported CUDA versions.

CUDA 11.8 path:

```
(.venv) root@b4ffe5c8ac8c:/pytorch/.ci/pytorch/smoke_test# ls /.venv/lib/python3.12/site-packages/torch/lib/../../nvidia/cuda_nvrtc/lib
__init__.py  __pycache__  libnvrtc-builtins.so.11.8  libnvrtc-builtins.so.12.4  libnvrtc.so.11.2  libnvrtc.so.12
(.venv) root@b4ffe5c8ac8c:/pytorch/.ci/pytorch/smoke_test# ls /.venv/lib/python3.12/site-packages/torch/lib/../../nvidia/nvjitlink/lib
__init__.py  __pycache__  libnvJitLink.so.12
```

Test with rc 2.6 and CUDA 11.8:

```
python cudnn_test.py
2.6.0+cu118
---------------------------------------------SDPA-Flash---------------------------------------------
ALL GOOD
---------------------------------------------SDPA-CuDNN---------------------------------------------
ALL GOOD
```
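The cudnn_test.py script itself is not included in this thread. For readers who want to reproduce the check, a minimal smoke test along the same lines (running scaled_dot_product_attention under the Flash and cuDNN backends, which is what triggers the nvrtc load in the failures discussed below) might look like the sketch here; the tensor shapes and dtype are illustrative assumptions, not taken from the actual script.

```
# Hypothetical stand-in for cudnn_test.py (the real script is not shown in this thread).
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

print(torch.__version__)

# Illustrative shapes/dtype; any CUDA-capable GPU will do.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

for name, backend in (("SDPA-Flash", SDPBackend.FLASH_ATTENTION),
                      ("SDPA-CuDNN", SDPBackend.CUDNN_ATTENTION)):
    with sdpa_kernel(backend):
        torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(f"{name}: ALL GOOD")
```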

Thank you @nWEIdia for discovering this issue

cc @seemethere @malfet @osalpekar


Pull Request resolved: #145582
Approved by: https://github.com/nWEIdia, https://github.com/eqy, https://github.com/kit1980, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
(cherry picked from commit 9752c7c)

pytorch-bot bot commented Jan 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145614

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 9 Pending

As of commit 34c3e25 with merge base f7e621c:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

malfet merged commit 3207040 into release/2.6 on Jan 24, 2025
111 of 120 checks passed
@nWEIdia
Collaborator

nWEIdia commented Jan 24, 2025

@atalman I am noticing that you might have tested cu124 first and then cu118; note that your test directory contains both libnvrtc.so.11.2 and libnvrtc.so.12.
So I went ahead and tested the vanilla cu118 binary (standalone, not alongside cu124). I have the impression that this line may have prevented things from working on cu118 (i.e. the cu118 binary seems to still be breaking):

```
if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps:
    return
```

The libcudart.so check above might be too strict; I guess the existence of libcudart.so.* should be fine?
Please see below what cuda_runtime/lib contains for cu118:

```
/usr/local/lib/python3.12/dist-packages/torch# ls ../nvidia/cuda_runtime/lib/
__init__.py  __pycache__  libOpenCL.so.1  libcudart.so.11.0
```
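A minimal sketch of what a guard like the quoted one boils down to is shown below; this is an illustration under the assumption that _maps holds the content of /proc/self/maps, not the actual torch/__init__.py code. Note that a plain substring test on the maps also matches versioned names such as libcudart.so.11.0, which is relevant to whether the check really is too strict.

```
# Illustrative sketch only (not the actual torch code): check whether cudart
# from the pip-installed nvidia wheel is already mapped into this process
# before attempting any further preloading.
with open("/proc/self/maps") as f:
    _maps = f.read()

# A substring test like this also matches "libcudart.so.11.0", because the
# versioned file name starts with "libcudart.so".
if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps:
    print("cudart from the pip wheel is not mapped; preload would be skipped")
else:
    print("cudart from the pip wheel is mapped; preload would proceed")
```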

@nWEIdia
Collaborator

nWEIdia commented Jan 24, 2025

Only if the /usr/local/lib/python3.12/dist-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so symlink is created, or cu124 installed it.

With a default installation there is only libnvrtc.so.11.2, and the code only checks for libnvrtc.so; otherwise it returns and never executes the preload.
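To make that concrete, here is a small, hypothetical check (the path is an example, and this is not the actual torch logic) contrasting a lookup of the exact unversioned file name with a glob over versioned names; on a vanilla cu118 install only libnvrtc.so.11.2 is present, so the exact-name lookup fails while the glob still finds the library.

```
# Hypothetical illustration, not the actual torch code.
import glob
import os

lib_dir = "/usr/local/lib/python3.12/dist-packages/nvidia/cuda_nvrtc/lib"  # example path

exact = os.path.join(lib_dir, "libnvrtc.so")
print(os.path.exists(exact))  # False on a default cu118 install: only libnvrtc.so.11.2 ships

versioned = sorted(glob.glob(os.path.join(lib_dir, "libnvrtc.so*")))
print(versioned)              # e.g. ['.../libnvrtc.so.11.2'] -- a glob still finds it
```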

@atalman
Contributor

atalman commented Jan 24, 2025

Looks like you are right; standalone cu118 is not loading:

```
---------------------------------------------SDPA-Flash---------------------------------------------
ALL GOOD
---------------------------------------------SDPA-CuDNN---------------------------------------------
Could not load library libnvrtc.so.11.2. Error: libnvrtc.so.11.2: cannot open shared object file: No such file or directory
Could not load library libnvrtc.so. Error: libnvrtc.so: cannot open shared object file: No such file or directory
Could not load library libnvrtc.so.11.2. Error: libnvrtc.so.11.2: cannot open shared object file: No such file or directory
Could not load library libnvrtc.so. Error: libnvrtc.so: cannot open shared object file: No such file or directory
Could not load library libnvrtc.so.11.2. Error: libnvrtc.so.11.2: cannot open shared object file: No such file or directory
Could not load library libnvrtc.so. Error: libnvrtc.so: cannot open shared object file: No such file or directory
```

However the file is there:

```
ldd /venv/lib/python3.12/site-packages/torch/lib/../../nvidia/cuda_nvrtc/lib/libnvrtc.so.11.2
	linux-vdso.so.1 (0x00007fff3d1eb000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x000079dff81d0000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x000079dff81cb000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x000079dff81c6000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x000079dff80dd000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x000079dff45ee000)
	/lib64/ld-linux-x86-64.so.2 (0x000079dff81da000)
```

FYI, the statement if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps: is not an issue.

As per @nWEIdia, the workaround is: ln -s libnvrtc.so.11.2 libnvrtc.so
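One way to see why the file is present yet the load still fails: loading by bare soname only searches the standard loader locations (the ld.so cache, LD_LIBRARY_PATH, rpath entries), which do not include the pip wheel directory, whereas loading by full path works. A small, hypothetical demonstration (the wheel path is an example and depends on the local install):

```
# Hypothetical demonstration of soname vs. full-path loading.
import ctypes

wheel_lib = ("/venv/lib/python3.12/site-packages/nvidia/cuda_nvrtc/lib/"
             "libnvrtc.so.11.2")  # example path

try:
    ctypes.CDLL("libnvrtc.so.11.2")  # bare soname: resolved via ld.so cache / LD_LIBRARY_PATH only
    print("loaded by soname")
except OSError as err:
    print(f"soname lookup failed: {err}")

try:
    ctypes.CDLL(wheel_lib)           # full path: loads regardless of the loader search path
    print("loaded by full path")
except OSError as err:
    print(f"full-path load failed: {err}")
```

This also fits the two workarounds: adding the directory to LD_LIBRARY_PATH makes the soname lookup succeed directly, while creating the libnvrtc.so symlink lets the preload logic (which looks for that exact name, as noted above) find and load the library up front.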

@nWEIdia
Collaborator

nWEIdia commented Jan 24, 2025

Yeah, not sure why, but two workarounds have been identified so far (either of them works):

```
export LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/nvidia/cuda_nvrtc/lib/:$LD_LIBRARY_PATH
ln -s libnvrtc.so.11.2 libnvrtc.so
```

@nWEIdia
Collaborator

nWEIdia commented Jan 24, 2025

I am going to switch the preload order, but I need the test case for the first issue. I do not want to fix one issue but regress the other.

Would the two be incompatible (both want to be preloaded first)?

Update: it seems the libnvjitlink test would just be python -c "import torch", so if the libnvrtc test case works, the libnvjitlink test must also have worked fine.
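The follow-up fix referenced in the commits below (#145638) changes the load order because CUDA 11.x wheels ship no nvjitlink, and a failure on that first load prevented nvrtc from being preloaded. A rough sketch of the idea, assuming each library is preloaded independently from the wheel directories (this is not the literal patch):

```
# Rough sketch of the ordering/robustness idea behind #145638, not the literal patch:
# preload nvrtc before nvjitlink, and do not let one missing library abort the
# remaining preloads (CUDA 11.x wheels ship no nvjitlink at all).
import ctypes
import glob
import os

NVIDIA_ROOT = "/usr/local/lib/python3.12/dist-packages/nvidia"  # example path

def _try_preload(pkg: str, pattern: str) -> None:
    for path in sorted(glob.glob(os.path.join(NVIDIA_ROOT, pkg, "lib", pattern))):
        try:
            ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)
            return  # first loadable match is enough
        except OSError:
            continue  # tolerate absent/unloadable libraries for this CUDA version

_try_preload("cuda_nvrtc", "libnvrtc.so*")     # nvrtc first
_try_preload("nvjitlink", "libnvJitLink.so*")  # nvjitlink second; simply absent on CUDA 11.x
```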

pytorchmergebot pushed a commit that referenced this pull request Jan 24, 2025
There is no libnvjitlink in CUDA-11.x, so attempts to load it first will abort the execution and prevent the script from preloading nvrtc

Fixes issues reported in #145614 (comment)

Pull Request resolved: #145638
Approved by: https://github.com/atalman, https://github.com/kit1980, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
pytorchbot pushed a commit that referenced this pull request Jan 24, 2025
There is no libnvjitlink in CUDA-11.x, so attempts to load it first will abort the execution and prevent the script from preloading nvrtc

Fixes issues reported in #145614 (comment)

Pull Request resolved: #145638
Approved by: https://github.com/atalman, https://github.com/kit1980, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
(cherry picked from commit 2a70de7)
malfet pushed a commit that referenced this pull request Jan 24, 2025

[CUDA] Change slim-wheel libraries load order (#145638)

There is no libnvjitlink in CUDA-11.x, so attempts to load it first will abort the execution and prevent the script from preloading nvrtc

Fixes issues reported in #145614 (comment)

Pull Request resolved: #145638
Approved by: https://github.com/atalman, https://github.com/kit1980, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
(cherry picked from commit 2a70de7)

Co-authored-by: Wei Wang <weiwan@nvidia.com>
nWEIdia added a commit to nWEIdia/pytorch that referenced this pull request Jan 27, 2025
There is no libnvjitlink in CUDA-11.x, so attempts to load it first will abort the execution and prevent the script from preloading nvrtc

Fixes issues reported in pytorch#145614 (comment)

Pull Request resolved: pytorch#145638
Approved by: https://github.com/atalman, https://github.com/kit1980, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
github-actions bot deleted the cherry-pick-145582-by-pytorch_bot_bot_ branch on February 24, 2025 02:07