Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Trainer] Correct behavior of _load_best_model for PEFT models #24103

Merged
merged 4 commits into from
Jun 8, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2177,11 +2177,20 @@ def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.safetensors")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those two should be in constants (like WEIGHTS_NAME) as they are now used several time across the file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, just added it!


model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
if (
os.path.exists(best_model_path)
or os.path.exists(best_safe_model_path)
or os.path.exists(best_adapter_model_path)
or os.path.exists(best_safe_adapter_model_path)
):
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
else:
has_been_loaded = True
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
# If the 'user_content.pt' file exists, load with the new smp api.
Expand All @@ -2207,10 +2216,10 @@ def _load_best_model(self):
self.accelerator, model, self.state.best_model_checkpoint
)
else:
if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
# If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
if is_peft_available() and isinstance(model, PeftModel):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
# Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys
Expand All @@ -2222,9 +2231,10 @@ def _load_best_model(self):
"using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
"here are some examples https://github.com/huggingface/peft/issues/96"
)
has_been_loaded = False
else:
# We can't do pure 8bit training using transformers.
logger.warning("Could not loading a quantized checkpoint.")
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
has_been_loaded = False
else:
# We load the model state dict on the CPU to avoid an OOM error.
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
Expand All @@ -2236,7 +2246,7 @@ def _load_best_model(self):
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
if not is_sagemaker_mp_enabled():
if not is_sagemaker_mp_enabled() and has_been_loaded:
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
load_result = load_sharded_checkpoint(
Expand Down