From 34d59f44c27a4e6ab153ab48491e95808d27e5a7 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Fri, 31 Jan 2025 18:14:27 +0100
Subject: [PATCH] Fix grad norm

---
 optimum/neuron/accelerate/accelerator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py
index 193dffdf3..aae04fe7d 100644
--- a/optimum/neuron/accelerate/accelerator.py
+++ b/optimum/neuron/accelerate/accelerator.py
@@ -76,6 +76,7 @@

 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
+    from torch_xla.distributed.parallel_loader import MpDeviceLoader
 else:
     xm = None

@@ -226,8 +227,6 @@ def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optiona
         if use_mp_device_loader and self.state.mp_plugin.pipeline_parallel_size == 1:
             data_loader = MpDeviceLoader(data_loader, self.device)
         return data_loader
-        # TODO: fix that.
-        # return super().prepare_data_loader(data_loader, device_placement=device_placement)

     def _prepare_optimizer_for_mp(self, optimizer: torch.optim.Optimizer, device_placement=None):
         cpu_parameters_to_xla = collections.ChainMap(*self._model_cpu_parameters_to_xla.values())
@@ -537,7 +536,8 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2, postpone_clipping_t
                 "prepared by the NeuronAccelerator."
             )
             self._optimizers[0].prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
-        return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
+        else:
+            return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

     def _custom_save_state(
         self,
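
For context, the sketch below illustrates the control-flow change made by the last hunk: before the patch, super().clip_grad_norm_ ran unconditionally, so gradients were clipped immediately even when clipping had been postponed to the optimizer step; after the patch, immediate clipping only happens in the else branch. This is a minimal, self-contained sketch, not the optimum-neuron implementation: the class name, the postpone_clipping flag (the real parameter name is truncated in the hunk header above), and the _DummyOptimizer stub are hypothetical stand-ins.

# Hypothetical sketch of the fixed control flow in clip_grad_norm_.
# Names (SketchAccelerator, postpone_clipping, _DummyOptimizer) are illustrative,
# not the real optimum-neuron API.
import torch


class _DummyOptimizer:
    def __init__(self):
        self.pending_clip_args = None

    def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2):
        # Record the clipping arguments so they can be applied at optimizer-step time.
        self.pending_clip_args = (list(parameters), max_norm, norm_type)


class SketchAccelerator:
    def __init__(self, optimizers):
        self._optimizers = optimizers

    def clip_grad_norm_(self, parameters, max_norm, norm_type=2, postpone_clipping=False):
        if postpone_clipping:
            if len(self._optimizers) > 1:
                raise RuntimeError("Postponed clipping requires a single prepared optimizer.")
            self._optimizers[0].prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
        else:
            # Only clip immediately when clipping is NOT postponed (this is the fix).
            return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)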