[from_pretrained] Make from_pretrained fast again #27709

Merged
21 commits merged on Dec 11, 2023
Changes from 14 commits
64 changes: 63 additions & 1 deletion src/transformers/modeling_utils.py
@@ -154,6 +154,22 @@ def is_local_dist_rank_0():
if is_peft_available():
    from .utils import find_adapter_config_file

TORCH_INIT_FUNCTIONS = {
    "uniform_": nn.init.uniform_,
    "normal_": nn.init.normal_,
    "trunc_normal_": nn.init.trunc_normal_,
    "constant_": nn.init.constant_,
    "xavier_uniform_": nn.init.xavier_uniform_,
    "xavier_normal_": nn.init.xavier_normal_,
    "kaiming_uniform_": nn.init.kaiming_uniform_,
    "kaiming_normal_": nn.init.kaiming_normal_,
    "uniform": nn.init.uniform,
    "normal": nn.init.normal,
    "xavier_uniform": nn.init.xavier_uniform,
    "xavier_normal": nn.init.xavier_normal,
    "kaiming_uniform": nn.init.kaiming_uniform,
    "kaiming_normal": nn.init.kaiming_normal,
}

@contextmanager
def no_init_weights(_enable=True):
@@ -164,12 +180,54 @@ def no_init_weights(_enable=True):
"""
global _init_weights
old_init_weights = _init_weights

if _enable:
    _init_weights = False

    def _skip_init(*args, **kwargs):
        pass

    # Replace every torch.nn.init function with a no-op (the originals are kept in
    # TORCH_INIT_FUNCTIONS) so module constructors skip their default weight initialization.
    for name in TORCH_INIT_FUNCTIONS:
        setattr(torch.nn.init, name, _skip_init)
try:
    yield
finally:
    _init_weights = old_init_weights
    if _enable:
        # Restore the original initialization functions
        for name, init_func in TORCH_INIT_FUNCTIONS.items():
            setattr(torch.nn.init, name, init_func)
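
For readers skimming the diff: the entire speed-up comes from these `torch.nn.init` calls becoming no-ops while the model skeleton is constructed. A small illustrative micro-benchmark of the context manager on its own (not part of this PR; the layer size is arbitrary):

import time

import torch.nn as nn

from transformers.modeling_utils import no_init_weights

def build():
    # nn.Linear's constructor normally runs kaiming_uniform_ / uniform_ over ~67M parameters
    return nn.Linear(8192, 8192)

start = time.perf_counter()
build()
print(f"default init: {time.perf_counter() - start:.3f}s")

with no_init_weights(True):
    start = time.perf_counter()
    build()  # the init functions are no-ops here, so only the allocation happens
    print(f"init skipped: {time.perf_counter() - start:.3f}s")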


def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
@@ -1505,7 +1563,10 @@ def get_output_embeddings(self) -> nn.Module:

def _init_weights(self, module):
    """
    Initialize the weights. This method should be overridden by derived class.
    Initialize the weights. This method should be overridden by derived classes and is
    the only initialization method that will be called when loading a checkpoint
    using `from_pretrained`. Any attempt to initialize weights outside of this function
    will have no effect, as the `torch.nn.init` functions are all replaced with no-ops.
    """
    pass
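
As a concrete illustration of that contract, a derived class typically overrides `_init_weights` along these lines (a minimal sketch; the `MyModel` name and `std=0.02` are illustrative, not taken from this PR):

import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel

class MyModel(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(16, 16)

    def _init_weights(self, module):
        # With _fast_init, from_pretrained calls this only for weights missing from the checkpoint.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()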

@@ -3413,6 +3474,7 @@ def from_pretrained(
)

with ContextManagers(init_contexts):
    # Let's make sure we don't run the init function of buffer modules
    model = cls(config, *model_args, **model_kwargs)

# make sure we use the model's config since the __init__ call might have copied it
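
From the user's side the fast path is on by default; a minimal usage sketch for comparison (the model name is illustrative, and `_fast_init` is an internal keyword shown here only to contrast the two code paths):

from transformers import AutoModelForCausalLM

# Fast initialization (default): skips torch.nn.init while building the model,
# then runs _init_weights only for weights missing from the checkpoint.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Legacy behavior: fully initialize every weight, then overwrite with the checkpoint.
model_slow = AutoModelForCausalLM.from_pretrained("gpt2", _fast_init=False)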
57 changes: 56 additions & 1 deletion tests/test_modeling_common.py
@@ -36,6 +36,7 @@
AutoModelForCausalLM,
AutoModelForSequenceClassification,
PretrainedConfig,
PreTrainedModel,
is_torch_available,
logging,
)
@@ -85,7 +86,9 @@
is_torch_fx_available,
)
from transformers.utils.generic import ModelOutput

from transformers import set_seed
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers

if is_accelerate_available():
    from accelerate.utils import compute_module_sizes
@@ -427,6 +430,58 @@ class CopyClass(model_class):
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_fast_init_context_manager(self):
    # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
    class MyClass(PreTrainedModel):
        config_class = PretrainedConfig

        def __init__(self, config=PretrainedConfig()):
            super().__init__(config)
            self.linear = nn.Linear(10, 10, bias=True)
            self.embedding = nn.Embedding(10, 10)
            self.std = 1

        def _init_weights(self, module):
            if isinstance(module, nn.Linear):
                module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
                if module.bias is not None:
                    module.bias.data.normal_(mean=0.0, std=self.std)

    # 2. Make sure a linear layer's reset params is properly skipped:
    with ContextManagers([no_init_weights(True)]):
        no_init_instance = MyClass()

    torch.testing.assert_allclose(no_init_instance.linear.bias, torch.zeros(10), rtol=1e-4, atol=1e-4)

    set_seed(0)
    expected_bias = torch.tensor(
        [0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475]
    )
    init_instance = MyClass()
    torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4)

    set_seed(0)
    torch.testing.assert_allclose(
        init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5))
    )

    # 3. Make sure weights that are missing from the checkpoint use `_init_weights` and get the expected values
    with tempfile.TemporaryDirectory() as tmpdirname:
        state_dict = init_instance.state_dict()
        del state_dict["linear.weight"]

        init_instance.config.save_pretrained(tmpdirname)
        torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
        set_seed(0)
        model_fast_init = MyClass.from_pretrained(tmpdirname)

        set_seed(0)
        model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False)

        for key in model_fast_init.state_dict().keys():
            max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]))
            self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical")

def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING: