Suppress `reset_parameters` of `torch.nn.Linear`, `Conv2d`... inside `no_init_weights` #18505
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@younesbelkada @pacman100 Can I suggest reopening this issue - it's pretty important IMO and it's hitting lots of people.

May be interesting to you @ArthurZucker
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you updated all relevant documentation?

- [x] Yes
- [ ] No

## Description

This PR optimizes the time to load models from disk. In my local testing, SDXL text_encoder_2 models saw the greatest improvement:

- Before change, load time (disk to cpu): 14 secs
- After change, load time (disk to cpu): 4 secs

See the in-code documentation for an explanation of how this speedup is achieved.

## Related Tickets & Documents

This change was previously proposed on the HF transformers repo, but did not get any traction: huggingface/transformers#18505 (comment)

## QA Instructions, Screenshots, Recordings

I don't expect any adverse effects, but the new context manager is applied while loading **all** models, so it would make sense to exercise everything.

## Added/updated tests?

- [x] Yes
- [ ] No
Yes! I'll take this one, makes sense 😉

Why not just do `with torch.device("meta"): ...`?
Yes, the meta device or FakeTensor is the correct and recommended choice for deferring/skipping initialization. https://pytorch.org/torchdistx/latest/fake_tensor_and_deferred_init.html https://huggingface.co/blog/accelerate-large-models But see pytorch/pytorch#97951 (comment). TBH, I didn't know about the meta device when I posted this issue (2022); I only learned about it this year.
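A minimal sketch of the meta-device approach described above (requires PyTorch ≥ 2.0 for the `torch.device` context manager; `MyModel` and the checkpoint path are placeholders):

```python
import torch
import torch.nn as nn

# Placeholder model for illustration.
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4096, 4096)

# Tensors created under the meta device carry only shape/dtype metadata,
# so the init math in reset_parameters() does no real work and
# construction is nearly instant.
with torch.device("meta"):
    model = MyModel()

# Materialize uninitialized storage on the target device, then fill it
# from a checkpoint; load_state_dict overwrites every parameter.
model = model.to_empty(device="cpu")
model.load_state_dict(torch.load("checkpoint.pt"))  # placeholder path
```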
Oops, sorry for the delay! Have not forgotten about this 😉
The meta device doesn't solve the problem, because buffers (e.g. sin/cos in Llama 2) don't get initialized in that case.

You are right, the meta device would skip all initializations, built-in or user-defined, parameters or buffers.
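To make the caveat concrete, here is a hypothetical module in the spirit of Llama's rotary-embedding caches; nothing below is the actual `transformers` code:

```python
import torch
import torch.nn as nn

class RotaryCache(nn.Module):
    def __init__(self, dim: int = 64, max_len: int = 2048):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        freqs = torch.outer(torch.arange(max_len).float(), inv_freq)
        # Computed in __init__ and, being non-persistent, never saved
        # to (or restored from) the checkpoint.
        self.register_buffer("cos_cached", freqs.cos(), persistent=False)
        self.register_buffer("sin_cached", freqs.sin(), persistent=False)

with torch.device("meta"):
    m = RotaryCache()

m = m.to_empty(device="cpu")
# cos_cached/sin_cached now hold uninitialized memory: the values from
# __init__ were never materialized on a real device, and load_state_dict
# cannot restore a non-persistent buffer because it is absent from the
# state dict.
```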
We have a lot of buffers with `persistent=False`. I answered on the other issue, but I'll most probably go about this by skipping initialization for all layers except the ones that are missing from the checkpoint, roughly as sketched below.
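A rough sketch of that strategy, under the assumption that the checkpoint keys can be enumerated up front (the helper below is hypothetical, not the code that eventually landed):

```python
import torch.nn as nn

def init_missing_modules(model: nn.Module, ckpt_keys: set, init_fn) -> None:
    """Re-initialize only submodules that have at least one parameter
    or buffer not covered by the checkpoint."""
    for name, module in model.named_modules():
        own = [n for n, _ in module.named_parameters(recurse=False)]
        own += [n for n, _ in module.named_buffers(recurse=False)]
        full = {f"{name}.{n}" if name else n for n in own}
        # Subset test: if every tensor of this module is in the
        # checkpoint, its init can safely be skipped.
        if full and not full <= ckpt_keys:
            init_fn(module)  # e.g. the model's _init_weights hook
```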
Hello, please also see this comment #26258 (comment) |
Fixed by #27709 🤗 |
Feature request

`torch.nn.Linear`, `Conv2d`... will call `self.reset_parameters()` inside their `__init__`. I'd like to make `reset_parameters` a no-op inside the `no_init_weights` context manager.

Motivation

`no_init_weights` is used in `from_pretrained` to speed up loading large models. However, torch built-in modules like `torch.nn.Linear` are heavily used in `transformers` models, yet their weight initialization cannot be disabled by `no_init_weights`, even though the docstring of `no_init_weights` says it should "globally disable weight initialization".

Your contribution

possible implementation
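One possible implementation, as a hedged sketch (an illustration of the idea, not the actual `no_init_weights` code): a context manager that temporarily patches `reset_parameters` on the given classes to a no-op.

```python
import contextlib
import torch.nn as nn

@contextlib.contextmanager
def suppress_reset_parameters(*classes):
    """Temporarily replace reset_parameters with a no-op on each class."""
    saved = {cls: cls.reset_parameters for cls in classes}
    try:
        for cls in classes:
            cls.reset_parameters = lambda self: None
        yield
    finally:
        for cls, fn in saved.items():
            cls.reset_parameters = fn

# Weights stay as torch.empty allocations and are expected to be
# overwritten immediately afterwards by load_state_dict.
with suppress_reset_parameters(nn.Linear, nn.Conv2d, nn.LayerNorm):
    layer = nn.Linear(4096, 4096)  # skips kaiming_uniform_ etc.
```

Note that the restore step writes the (possibly inherited) method directly onto each class, which is harmless for this sketch but something a real implementation would want to handle more carefully.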