
Commit

Fix grad norm
michaelbenayoun committed Jan 31, 2025
1 parent 8b9dcd7 commit 34d59f4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions optimum/neuron/accelerate/accelerator.py
@@ -76,6 +76,7 @@

 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
+    from torch_xla.distributed.parallel_loader import MpDeviceLoader
 else:
     xm = None

@@ -226,8 +227,6 @@ def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optiona
         if use_mp_device_loader and self.state.mp_plugin.pipeline_parallel_size == 1:
             data_loader = MpDeviceLoader(data_loader, self.device)
         return data_loader
-        # TODO: fix that.
-        # return super().prepare_data_loader(data_loader, device_placement=device_placement)

     def _prepare_optimizer_for_mp(self, optimizer: torch.optim.Optimizer, device_placement=None):
         cpu_parameters_to_xla = collections.ChainMap(*self._model_cpu_parameters_to_xla.values())
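
The hunk above keeps the path where the accelerator wraps the prepared DataLoader in torch_xla's MpDeviceLoader whenever pipeline parallelism is disabled. As a hedged illustration only (not code from this commit), the standalone sketch below shows what that wrapping does; the dataset and batch size are made up for the example, and torch_xla is assumed to be installed:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

# Hypothetical dataset and loader, stand-ins for whatever the training script prepares.
dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
data_loader = DataLoader(dataset, batch_size=8)

device = xm.xla_device()
# Same call as the retained branch in prepare_data_loader: MpDeviceLoader streams
# each batch to the XLA device in the background as the loader is iterated.
xla_loader = MpDeviceLoader(data_loader, device)

for features, labels in xla_loader:
    # features and labels already live on the XLA device here.
    pass
```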
@@ -537,7 +536,8 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2, postpone_clipping_t
"prepared by the NeuronAccelerator."
)
self._optimizers[0].prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
else:
return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

     def _custom_save_state(
         self,
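The last hunk is the actual grad-norm fix: with postpone_clipping_to_optimizer_step=True, the accelerator previously both registered the postponed clipping on the optimizer and then immediately called the parent clip_grad_norm_ as well; after the change, the immediate call only runs in the else branch. The sketch below is a hedged, self-contained stand-in for that control flow (it is not the NeuronAccelerator class itself, and the base-class call is replaced here by torch's own clipping utility):

```python
import torch


class _GradClipSketch:
    """Minimal stand-in mirroring the fixed control flow of clip_grad_norm_."""

    def __init__(self, optimizers):
        self._optimizers = optimizers

    def clip_grad_norm_(self, parameters, max_norm, norm_type=2, postpone_clipping_to_optimizer_step=False):
        if postpone_clipping_to_optimizer_step:
            if len(self._optimizers) != 1:
                raise RuntimeError("Postponed clipping needs exactly one prepared optimizer.")
            # Record the clipping request; the prepared optimizer applies it at step time.
            self._optimizers[0].prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
        else:
            # Only clip immediately when clipping is not postponed. Before this commit,
            # the equivalent call ran unconditionally, so it also fired on the postponed path.
            return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
```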
