diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index ab07bc2f1..68a19ca4f 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -827,9 +827,6 @@ def _inner_training_loop(
             self.lr_scheduler = None
             self._created_lr_scheduler = False

-        if not delay_optimizer_creation:
-            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
         self.state = TrainerState()
         self.state.is_hyper_param_search = trial is not None

@@ -865,26 +862,28 @@ def _inner_training_loop(
         # as the model is wrapped, don't use `accelerator.prepare`
         # this is for unhandled cases such as
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
-        use_accelerator_prepare = True if model is self.model else False
+        # use_accelerator_prepare = True if model is self.model else False
+
+        self.model = self.accelerator.prepare(self.model)
+        self.model.train()
+
+        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+        self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.optimizer, self.lr_scheduler)

-        if delay_optimizer_creation:
-            if use_accelerator_prepare:
-                self.model = self.accelerator.prepare(self.model)
-            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         # prepare using `accelerator` prepare
-        if use_accelerator_prepare:
-            self.model.train()
-            if hasattr(self.lr_scheduler, "step"):
-                if self.use_apex:
-                    model = self.accelerator.prepare(self.model)
-                else:
-                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
-            else:
-                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
-                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
-                    self.model, self.optimizer, self.lr_scheduler
-                )
+        # if use_accelerator_prepare:
+        #     self.model.train()
+        #     if hasattr(self.lr_scheduler, "step"):
+        #         if self.use_apex:
+        #             model = self.accelerator.prepare(self.model)
+        #         else:
+        #             model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+        #     else:
+        #         # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+        #         model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+        #             self.model, self.optimizer, self.lr_scheduler
+        #         )

         if isinstance(model, NxDPPModel):
             self.model = model
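For context, the second hunk switches to a fixed preparation order: wrap the model with `accelerator.prepare` first, then create the optimizer and LR scheduler, then prepare those two in a separate call. Below is a minimal sketch of that ordering, assuming a generic `accelerate.Accelerator` with a toy `Linear` model and `AdamW`/`StepLR`; these names are illustrative stand-ins, not code from optimum-neuron.

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    # 1) Wrap the model first, then switch it to training mode.
    model = accelerator.prepare(torch.nn.Linear(8, 8))
    model.train()

    # 2) Only then build the optimizer and LR scheduler, so they are created
    #    against the parameters of the already-wrapped model.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    # 3) Prepare the optimizer and scheduler in a second call.
    optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)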