
[Trainer] Correct behavior of _load_best_model for PEFT models #24103

Merged

younesbelkada merged 4 commits into huggingface:main from trainer-resume-fix on Jun 8, 2023

Conversation

younesbelkada (Contributor) commented on Jun 8, 2023

What does this PR do?

Fixes #24096

This PR fixes the bugs related to PEFT models and `load_best_model_at_end`. It also slightly refactors the current logic to extend it to all LoRA models in general, not only 8-bit base models + LoRA. (A hedged sketch of the fixed loading path follows the repro script below.)

Repro script
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import TrainingArguments

dataset = load_dataset("imdb", split="train")

# LoRA adapter configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Save and evaluate every step so a "best" checkpoint exists after one step;
# load_best_model_at_end=True triggers Trainer._load_best_model, where the bug lived.
args = TrainingArguments(
    max_steps=1,
    save_steps=1,
    eval_steps=1,
    evaluation_strategy="steps",
    per_device_train_batch_size=1,
    resume_from_checkpoint=True,
    output_dir="test_trainer",
    load_best_model_at_end=True,
)

trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",
    train_dataset=dataset,
    eval_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config,
    max_seq_length=128,
    args=args,
)
trainer.train()
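
For context, a minimal sketch of what the fixed adapter branch in `_load_best_model` roughly does. This is a sketch under assumptions, not the verbatim Trainer code: `active_adapter` and `load_adapter(model_id, adapter_name)` are public `peft.PeftModel` APIs, and the filenames come from the diff snippets quoted in the review below.

import os


def load_best_adapter_sketch(model, best_model_checkpoint):
    """Hedged sketch of the adapter branch this PR adds to `_load_best_model`."""
    best_adapter_model_path = os.path.join(best_model_checkpoint, "adapter_model.bin")
    best_safe_adapter_model_path = os.path.join(best_model_checkpoint, "adapter_model.safetensors")

    # Only PEFT models expose these attributes.
    if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
        if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
            # Reload only the LoRA weights: a quantized (e.g. 8-bit) base
            # model cannot be reloaded from a full state dict.
            model.load_adapter(best_model_checkpoint, model.active_adapter)
            return True
    return False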

cc @sgugger @pacman100

younesbelkada changed the title from "[Trainer] Correct behavior of _load_best_model" to "[Trainer] Correct behavior of _load_best_model for PEFT models" on Jun 8, 2023
HuggingFaceDocBuilderDev commented on Jun 8, 2023

The documentation is not available anymore as the PR was closed or merged.

pacman100 (Contributor) left a comment:

Thank you @younesbelkada for simplifying Trainer usage with PEFT in terms of saving/loading, as this has been the cause of numerous issues 🚀. Left a few comments.

@@ -2177,11 +2177,18 @@ def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+        adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")
Contributor: It can also be safetensors ckpts, right?

Contributor: Maybe adding `best_safe_adapter_model_path` should serve the purpose?

younesbelkada (Author): Perfect, will refactor that a bit.

Comment on lines 2233 to 2236

else:
    # We can't do pure 8bit training using transformers.
    logger.warning("Could not load a quantized checkpoint.")
    has_been_loaded = False
Contributor: Should this be removed now?

younesbelkada (Author): I think this is needed so that it can be used in the block below for the check; otherwise it will throw an error similar to #24096. (A toy sketch of this failure mode follows.)
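
To make that concrete, a toy sketch of the failure pattern, with hypothetical function names (not the actual Trainer code):

# `load_result` is only assigned on the full-checkpoint branch, but the final
# check used to read it unconditionally.
def buggy(best_is_full_checkpoint: bool) -> None:
    if best_is_full_checkpoint:
        load_result = "loaded"
    print(load_result)  # UnboundLocalError when only an adapter was saved


def fixed(best_is_full_checkpoint: bool) -> None:
    has_been_loaded = False
    if best_is_full_checkpoint:
        load_result = "loaded"
        has_been_loaded = True
    if has_been_loaded:  # the guard discussed in this thread
        print(load_result)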

younesbelkada (Author): Ah sorry, I see what you meant; yes, will remove it.

younesbelkada (Author): Proposed something in bf31c5e.

Comment on lines 2180 to 2181
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.safetensors")
Collaborator: Those two should be in constants (like WEIGHTS_NAME), as they are now used several times across the file.

younesbelkada (Author): Makes sense, just added it!
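
For reference, the two constants come out roughly as follows (names are taken from the commit message below, filenames from the diff quoted above):

# Sketch of the constants added by this PR.
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"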

younesbelkada merged commit 2200bf7 into huggingface:main on Jun 8, 2023
22 checks passed
younesbelkada deleted the trainer-resume-fix branch on June 8, 2023 at 13:38
sgugger pushed a commit that referenced this pull request Jun 8, 2023
…24103)

* v1

* some refactor

- add ST format as well

* fix

* add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME`
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
…uggingface#24103)

* v1

* some refactor

- add ST format as well

* fix

* add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME`
Successfully merging this pull request may close these issues.

Exception when saving weights from QLORA due to UnboundLocalError