Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix accelerator prepare during eval only mode #24014

Merged
merged 4 commits into from
Jun 7, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 30 additions & 6 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3157,14 +3157,26 @@ def evaluation_loop(

prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

# if eval is called w/o train init deepspeed here
# if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.model_wrapped is self.model:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self.accelerator.prepare(self.model)
self.model_wrapped = self.deepspeed = model

model = self._wrap_model(self.model, training=False, dataloader=dataloader)

if len(self.accelerator._models) == 0 and model is self.model:
model = self.accelerator.prepare(model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we only want to do this for DeepSpeed, not all the time. Putting a model in DistributedDataParallel just for evaluation will waste some memory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do agree on the DDP case and hence I didn't update it earlier but as mentioned below we will be missing mixed precision coverage for eval-only mode


if self.is_fsdp_enabled:
self.model = model

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model

# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped

# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
Expand Down Expand Up @@ -3752,14 +3764,26 @@ def prediction_loop(

prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

# if eval is called w/o train init deepspeed here
# if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.model_wrapped is self.model:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self.accelerator.prepare(self.model)
self.model_wrapped = self.deepspeed = model

model = self._wrap_model(self.model, training=False, dataloader=dataloader)

if len(self.accelerator._models) == 0 and model is self.model:
model = self.accelerator.prepare(model)

if self.is_fsdp_enabled:
self.model = model

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model

# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped

# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
Expand Down