Fix in-place modification when autotuning triton Lion update #36

Merged
merged 1 commit into from
Mar 30, 2024


@yousufmo (Contributor) commented Mar 19, 2024

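Currently, autotuning benchmarks the kernel by launching it repeatedly on the real tensors. Because the Lion update modifies the parameters (and `exp_avg`) in place, every benchmark launch applies another update, so the final results are incorrect.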

You can reproduce this with:

```python
import torch
import lion_pytorch.triton as lpt
from lion_pytorch.lion_pytorch import update_fn

# try both the Triton and the pure PyTorch implementations on identical inputs
for update in [lpt.update_fn, update_fn]:
    param = torch.ones(5, dtype=torch.float32, device="cuda:0")
    grad = torch.full_like(param, fill_value=1e-3)  # inherits dtype and device from param
    exp = torch.zeros_like(param)
    # a single Lion step; both implementations modify param and exp in place
    update(p=param, grad=grad, exp_avg=exp, lr=1e-3, wd=0, beta1=0.9, beta2=0.99)
    print(f"{update.__module__}: set param to {param}")
```

Running this script gives different results:
a) between the Triton and PyTorch implementations, and
b) between repeated runs of the script with the Triton implementation.
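For reference, one Lion step on these inputs has a closed-form answer. The sketch below is a plain-Python restatement of the standard Lion update, not the `lion_pytorch` code itself, assuming the rule p ← p·(1 − lr·wd) − lr·sign(β1·m + (1 − β1)·g) followed by m ← β2·m + (1 − β2)·g. With m = 0 and g = 1e-3 the sign is +1, so every element should end at 1 − 1e-3 = 0.999. The hypothetical `steps` argument mimics an autotuner launching the same in-place kernel several times while benchmarking. Each extra launch subtracts another `lr`, which would explain (a). It would also explain (b) if the number of benchmark launches differs between runs, since Triton's benchmarking derives repetition counts from measured timings.

```python
def sign(x: float) -> float:
    # same convention as torch.sign: -1.0, 0.0 or +1.0
    return float((x > 0) - (x < 0))


def lion_step(p, grad, exp_avg, lr, wd, beta1, beta2):
    """One Lion update applied in place to the lists p and exp_avg (illustrative sketch)."""
    for i in range(len(p)):
        # decoupled weight decay
        p[i] *= 1 - lr * wd
        # interpolate momentum and gradient, then keep only the sign
        update = sign(exp_avg[i] * beta1 + grad[i] * (1 - beta1))
        p[i] -= lr * update
        # the momentum is updated with beta2 after the parameter step
        exp_avg[i] = exp_avg[i] * beta2 + grad[i] * (1 - beta2)


def run(steps: int):
    # hypothetical: steps > 1 emulates extra in-place launches during benchmarking
    p = [1.0] * 5
    grad = [1e-3] * 5
    exp_avg = [0.0] * 5
    for _ in range(steps):
        lion_step(p, grad, exp_avg, lr=1e-3, wd=0, beta1=0.9, beta2=0.99)
    return p


print(run(1))   # intended result: every element is (approximately) 0.999
print(run(10))  # what ten in-place launches would leave behind: about 0.99
```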

This PR resolves this by using the `restore_value` option of `triton.autotune`, which saves the listed arguments before autotuning and restores their values once autotuning completes. This avoids the problems caused by in-place modification.
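The idea behind `restore_value` can be illustrated without a GPU. The sketch below is a toy autotuner written only for this explanation. The names `toy_autotune` and `add_one_kernel`, the config dicts, and the use of Python lists as "tensors" are all hypothetical and are not Triton's API. It snapshots the arguments named in `restore_value`, benchmarks every config on the real buffers, copies the snapshots back in place, and only then performs the single real launch. In Triton itself the option takes a list of kernel argument names, e.g. `restore_value=["p_ptr", "exp_avg_ptr"]`, where those pointer names are assumptions about the kernel signature.

```python
import copy
import time


def toy_autotune(kernel, configs, args, restore_value=(), reps=3):
    """Pick the fastest config, then run the kernel once with it (illustrative only)."""
    # snapshot every argument whose contents the kernel mutates
    saved = {name: copy.deepcopy(args[name]) for name in restore_value}

    timings = []
    for cfg in configs:
        start = time.perf_counter()
        for _ in range(reps):
            kernel(**args, **cfg)  # benchmark launches mutate the real buffers
        timings.append(time.perf_counter() - start)

    # undo all benchmarking side effects in place (Triton would use tensor.copy_)
    for name, value in saved.items():
        args[name][:] = value

    best = configs[timings.index(min(timings))]
    kernel(**args, **best)  # the one launch whose result the caller should see
    return best


def add_one_kernel(x, BLOCK_SIZE):
    # stand-in for an in-place update such as p -= lr * sign(...)
    for i in range(len(x)):
        x[i] += 1


configs = [{"BLOCK_SIZE": 128}, {"BLOCK_SIZE": 1024}]

x = [0, 0, 0]
toy_autotune(add_one_kernel, configs, {"x": x})
print(x)  # [7, 7, 7]: 2 configs x 3 reps of benchmarking leaked into the result

y = [0, 0, 0]
toy_autotune(add_one_kernel, configs, {"x": y}, restore_value=("x",))
print(y)  # [1, 1, 1]: only the final launch is visible
```

Restoring in place rather than rebinding the argument matters: the caller still holds a reference to the original buffer, just as the optimizer holds references to its parameter tensors.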

@yousufmo (Contributor, Author) commented:

@lucidrains - what are your thoughts?

@lucidrains (Owner) commented:

@yousufmo yea, triton did not have this feature at the time i wrote this repository. willing to accept as long as you also change the pip install to one with a version after this feature was made available

@lucidrains (Owner) commented Mar 30, 2024

@yousufmo is the restore_value feature documented anywhere? nevermind, i see it

@lucidrains (Owner) commented:

@yousufmo ok, i'm going to accept and just force version 2.2.0

thank you!
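Pinning the dependency can be complemented by a runtime check. On the packaging side a pin would typically be a requirement such as `triton>=2.2.0`; the exact specifier used in the merged commit is not shown here. The sketch below assumes that 2.2.0, the version chosen in this thread, is the minimum release that provides `restore_value`. The helper names are hypothetical and not part of `lion_pytorch`. It uses only the standard library, so it compares just the leading numeric `major.minor.patch` segment.

```python
import re
from importlib import metadata

MIN_TRITON = (2, 2, 0)  # assumption: first release with triton.autotune(restore_value=...)


def parse_release(version: str) -> tuple:
    # keep only the leading numeric segment, e.g. "2.2.0+cu121" -> (2, 2, 0);
    # pre-release suffixes are ignored, so "2.2.0rc1" is treated as 2.2.0
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = tuple(int(p) for p in match.group(1).split("."))
    # pad so that "2.2" compares equal to "2.2.0"
    return parts + (0,) * max(0, 3 - len(parts))


def supports_restore_value(version: str) -> bool:
    return parse_release(version) >= MIN_TRITON


def check_installed_triton() -> None:
    # hypothetical import-time guard for the Triton-backed update function
    try:
        installed = metadata.version("triton")
    except metadata.PackageNotFoundError:
        raise ImportError("triton is not installed; triton>=2.2.0 is required") from None
    if not supports_restore_value(installed):
        raise ImportError(f"found triton {installed}, but triton>=2.2.0 is required")
```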

@lucidrains merged commit 85b985a into lucidrains:main on Mar 30, 2024

@yousufmo (Contributor, Author) commented Apr 2, 2024

@lucidrains - sorry was on vacation! Thanks for merging much appreciated 🚀
