Commit 0a23be3: Changes the order of creation / preparation

michaelbenayoun committed Jan 28, 2025
1 parent 9d38084 commit 0a23be3
Showing 1 changed file with 19 additions and 20 deletions.
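
In short: the commit removes the early `create_optimizer_and_scheduler` call and the `delay_optimizer_creation` / `use_accelerator_prepare` branching, and instead always prepares the model first, then creates and prepares the optimizer and LR scheduler (the old branching is kept commented out).
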
39 changes: 19 additions & 20 deletions optimum/neuron/trainers.py
@@ -827,9 +827,6 @@ def _inner_training_loop(
         self.lr_scheduler = None
         self._created_lr_scheduler = False
 
-        if not delay_optimizer_creation:
-            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None
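
Note that the call removed above is not dropped: it reappears after model preparation in the second hunk below.
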
@@ -865,26 +862,28 @@ def _inner_training_loop(
         # as the model is wrapped, don't use `accelerator.prepare`
         # this is for unhandled cases such as
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
-        use_accelerator_prepare = True if model is self.model else False
+        # use_accelerator_prepare = True if model is self.model else False
+
+        self.model = self.accelerator.prepare(self.model)
+        self.model.train()
+
+        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+        self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.optimizer, self.lr_scheduler)
 
-        if delay_optimizer_creation:
-            if use_accelerator_prepare:
-                self.model = self.accelerator.prepare(self.model)
-            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
-        # prepare using `accelerator` prepare
-        if use_accelerator_prepare:
-            self.model.train()
-            if hasattr(self.lr_scheduler, "step"):
-                if self.use_apex:
-                    model = self.accelerator.prepare(self.model)
-                else:
-                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
-            else:
-                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
-                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
-                    self.model, self.optimizer, self.lr_scheduler
-                )
+        # if use_accelerator_prepare:
+        #     self.model.train()
+        #     if hasattr(self.lr_scheduler, "step"):
+        #         if self.use_apex:
+        #             model = self.accelerator.prepare(self.model)
+        #         else:
+        #             model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+        #     else:
+        #         # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+        #         model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+        #             self.model, self.optimizer, self.lr_scheduler
+        #         )
 
         if isinstance(model, NxDPPModel):
             self.model = model
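
For context, the practical effect of the reordering is that the optimizer is created from parameters the accelerator has already prepared (wrapped and/or moved). Below is a minimal sketch of that ordering in plain PyTorch; the `prepare` helper is an illustrative stand-in for `self.accelerator.prepare`, not the real accelerator API:

    import torch

    def prepare(module: torch.nn.Module) -> torch.nn.Module:
        # Illustrative stand-in for `accelerator.prepare`: the real call may
        # wrap the model and move it to the target device (e.g. an XLA device).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return module.to(device)

    model = torch.nn.Linear(4, 2)

    # New order (this commit): prepare the model first...
    model = prepare(model)
    model.train()

    # ...then create the optimizer and LR scheduler from the *prepared*
    # parameters, so optimizer state is tied to the tensors that will
    # actually be trained, and prepare them last.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

Creating the optimizer only after the model is prepared keeps its parameter references, and therefore its state, attached to the wrapped/moved tensors; that appears to be the motivation for the commit.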
