refactor(trainer): adapt to new compute_loss signature
dacorvo committed Jan 7, 2025
1 parent f3a82ee commit 6639b1d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optimum/neuron/trainers.py
@@ -401,14 +401,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s
             self._update_input_specs_in_model_cache_entry(input_specs_for_cache_entry)
         return inputs
 
-    def compute_loss(self, model, inputs, return_outputs: bool = False):
+    def compute_loss(self, model, inputs, num_items_in_batch):
         from neuronx_distributed.pipeline import NxDPPModel
 
         if isinstance(model, NxDPPModel):
             inputs = self._prepare_inputs(inputs)
             loss = model.run_train(**inputs)
         else:
-            loss = super().compute_loss(model, inputs, return_outputs=return_outputs)
+            loss = super().compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
         return loss
 
     def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
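For context: this tracks an upstream transformers change (introduced around v4.46 as part of the gradient accumulation loss fix) that added a num_items_in_batch argument to Trainer.compute_loss, so the loss can be normalized by the true number of items in the accumulated batch. The sketch below is illustrative and not part of this commit: it shows one way to keep an override compatible with both the old and the new signature by inspecting the parent method. CompatTrainer and _parent_accepts_num_items are hypothetical names.

import inspect

from transformers import Trainer


class CompatTrainer(Trainer):
    # Detect once whether the installed transformers version threads a
    # num_items_in_batch argument through Trainer.compute_loss.
    _parent_accepts_num_items = (
        "num_items_in_batch" in inspect.signature(Trainer.compute_loss).parameters
    )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if self._parent_accepts_num_items:
            # New signature: forward the count so the base class can
            # normalize the loss under gradient accumulation.
            return super().compute_loss(
                model,
                inputs,
                return_outputs=return_outputs,
                num_items_in_batch=num_items_in_batch,
            )
        # Old signature: drop the argument the parent does not accept.
        return super().compute_loss(model, inputs, return_outputs=return_outputs)

The commit itself takes the simpler route of matching how the new Trainer.training_step calls compute_loss (num_items_in_batch passed by keyword, no return_outputs), which suffices when only the newer transformers release is targeted.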
