
Suppress reset_parameters of torch.nn.Linear,Conv2d... inside no_init_weights #18505

Closed
YouJiacheng opened this issue Aug 6, 2022 · 12 comments · Fixed by #27709

Comments

@YouJiacheng
Contributor

Feature request

torch.nn.Linear, torch.nn.Conv2d, etc. call self.reset_parameters() inside their __init__.
I'd like reset_parameters to be a no-op inside the no_init_weights context manager.

Motivation

no_init_weights is used in from_pretrained to speed up loading large models.
However, torch built-in modules like torch.nn.Linear are heavily used in transformers models, and their weight initialization cannot be disabled by no_init_weights.
Yet according to its docstring, no_init_weights should "globally disable weight initialization".

Your contribution

A possible implementation:

from contextlib import contextmanager
from typing import Iterable, Protocol, Type

import torch.nn as nn


class SupportsResetParameters(Protocol):
    def reset_parameters(self): ...


@contextmanager
def no_init(module_classes: Iterable[Type[SupportsResetParameters]]):
    # Snapshot each class's own reset_parameters first (None if only inherited),
    # since an Iterable can only be safely iterated through once.
    saved = {m: vars(m).get('reset_parameters') for m in module_classes}

    def no_op(self):
        pass

    for m in saved:
        m.reset_parameters = no_op
    try:
        yield
    finally:
        for m, init in saved.items():
            del m.reset_parameters
            if init is not None:
                m.reset_parameters = init


TORCH_BUILT_IN_MODULES = [nn.Linear, nn.Conv2d, ...]

@contextmanager
def no_init_weights():
    """
    Context manager to globally disable weight initialization to speed up loading large models.
    """
    global _init_weights
    saved = _init_weights
    _init_weights = False
    try:
        with no_init(TORCH_BUILT_IN_MODULES):
            yield
    finally:
        _init_weights = saved
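A runnable sketch of how the proposed helper behaves (it repeats no_init from the snippet above so the demo is standalone; the transformers-internal _init_weights machinery is omitted):

```python
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def no_init(module_classes):
    # Patch reset_parameters out on each class, restore on exit.
    saved = {m: vars(m).get("reset_parameters") for m in module_classes}

    def no_op(self):
        pass

    for m in saved:
        m.reset_parameters = no_op
    try:
        yield
    finally:
        for m, init in saved.items():
            del m.reset_parameters
            if init is not None:
                m.reset_parameters = init


with no_init([nn.Linear]):
    layer = nn.Linear(4, 4)  # storage is allocated, but Kaiming init is skipped

fresh = nn.Linear(4, 4)  # outside the context, initialization runs as usual
```

Because the patch is applied at class level, it covers every nn.Linear constructed inside the context, including those buried deep in a pretrained model's __init__.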
@github-actions

github-actions bot commented Sep 5, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@jph00

jph00 commented Oct 10, 2023

@younesbelkada @pacman100 Can I suggest reopening this issue - it's pretty important IMO and it's hitting lots of people.

@LysandreJik
Member

May be interesting to you @ArthurZucker

RyanJDick added a commit to invoke-ai/InvokeAI that referenced this issue Oct 10, 2023
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission
      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description

This PR optimizes the time to load models from disk.
In my local testing, SDXL text_encoder_2 models saw the greatest
improvement:
- Before change, load time (disk to cpu): 14 secs
- After change, load time (disk to cpu): 4 secs

See the in-code documentation for an explanation of how this speedup is
achieved.

## Related Tickets & Documents

This change was previously proposed on the HF transformers repo, but did
not get any traction:
huggingface/transformers#18505 (comment)

## QA Instructions, Screenshots, Recordings

I don't expect any adverse effects, but the new context manager is
applied while loading **all** models, so it would make sense to exercise
everything.

## Added/updated tests?

- [x] Yes
- [ ] No
@ArthurZucker
Collaborator

Yes! I'll take this one, makes sense 😉

@Chillee

Chillee commented Oct 13, 2023

Why not just do

with torch.device('meta'):
    model_init()

?
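Chillee's suggestion relies on PyTorch's meta device, which is usable as a context manager since PyTorch 2.0: tensors created under it carry shape and dtype but no storage, so no allocation or initializer work happens. A minimal sketch of the pattern:

```python
import torch
import torch.nn as nn

# Constructing a module under the meta device records shapes/dtypes only;
# no memory is allocated and no initializers run.
with torch.device("meta"):
    model = nn.Linear(1024, 1024)

assert model.weight.is_meta  # no real storage yet

# Before use, the weights must be materialized, e.g. with to_empty() followed
# by loading a checkpoint; load_state_dict(..., assign=True) (PyTorch >= 2.1)
# can also assign checkpoint tensors directly onto a meta-constructed module.
model = model.to_empty(device="cpu")  # real but uninitialized storage
```

This makes construction time independent of parameter count, which is exactly the cost no_init_weights is trying to avoid.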

@YouJiacheng
Contributor Author

YouJiacheng commented Oct 14, 2023

Yes, meta device or faketensor is the correct and recommended choice for deferring/skipping initialization.

https://pytorch.org/torchdistx/latest/fake_tensor_and_deferred_init.html

https://huggingface.co/blog/accelerate-large-models

But with torch.device('meta') requires PyTorch 2.0, IIUC.

pytorch/pytorch#97951 (comment)

TBH, I didn't know about the meta device when I posted this issue in 2022; I only learned about it this year.

@huggingface huggingface deleted a comment from github-actions bot Nov 7, 2023
@ArthurZucker
Collaborator

Oops, sorry for the delay! I have not forgotten about this 😉

@jph00

jph00 commented Nov 10, 2023

Yes, meta device or faketensor is the correct and recommended choice for deferring/skipping initialization.

meta device doesn't solve the problem, because buffers (e.g sin/cos in llama2) don't get initialized in that case.

@YouJiacheng
Contributor Author

YouJiacheng commented Nov 10, 2023

Yes, meta device or faketensor is the correct and recommended choice for deferring/skipping initialization.

meta device doesn't solve the problem, because buffers (e.g sin/cos in llama2) don't get initialized in that case.

You are right: the meta device would skip all initializations, built-in or user-defined, parameters and buffers alike.
But as long as those buffers are not marked persistent=False, they can still be loaded from the checkpoint.
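A small illustration of this distinction (RotaryCache is a hypothetical module, not the actual llama2 code): only persistent buffers appear in state_dict, so only they can be restored from a checkpoint after meta-device construction; persistent=False buffers must be recomputed by the model itself.

```python
import torch
import torch.nn as nn


class RotaryCache(nn.Module):
    """Hypothetical module with one persistent and one non-persistent buffer."""

    def __init__(self, n=4):
        super().__init__()
        self.register_buffer("cos_cached", torch.ones(n))                     # in state_dict
        self.register_buffer("sin_cached", torch.zeros(n), persistent=False)  # NOT in state_dict


m = RotaryCache()
sd = m.state_dict()
print(sorted(sd))  # only 'cos_cached' is saved; 'sin_cached' must be recomputed
```

A model built entirely under the meta device would therefore come back with its non-persistent buffers still unmaterialized, which is the failure mode jph00 describes above.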

@ArthurZucker
Collaborator

We have a lot of buffers with persistent=False. I answered on the other issue, but I'll most probably go about this by skipping initialization for all layers except the ones that are missing from the checkpoints.

@pacman100
Contributor

Hello, please also see this comment #26258 (comment)

@ArthurZucker
Collaborator

Fixed by #27709 🤗
