Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,45 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
pred_std = self.per_var_std # (d_f,)

return prediction, pred_std

def _sync2skip(self, flag_skip):
    """Synchronize a skip-flag across all DDP workers.

    Workaround for Lightning not supporting skipping a batch under DDP, see
    https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1552650013

    Parameters
    ----------
    flag_skip : torch.Tensor
        0-dim bool tensor on this rank's device; True if this rank wants
        to skip the current batch.

    Returns
    -------
    bool
        True if ANY rank raised the flag (logical OR across workers).
    """
    # A single all_reduce with MAX over {0, 1} is the logical OR of the
    # per-rank flags, replacing the barrier + all_gather + stack + sum
    # pattern: all_reduce is itself a synchronization point, so no
    # explicit barrier is needed. Cast to int because some backends
    # (e.g. NCCL) do not support bool tensors in collectives.
    flag = flag_skip.to(torch.int)
    torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.MAX)
    return bool(flag.item())

def _check_nan_inf_loss(self, loss, batch_id=None):
    """Check across all DDP workers whether the loss is NaN/Inf.

    If any worker produced a non-finite loss, every worker is told to
    skip the current batch so that no rank hangs waiting in a gradient
    all-reduce. Aborts training after 100 consecutive invalid batches.

    Parameters
    ----------
    loss : torch.Tensor
        Loss value(s) for this batch on this rank.
    batch_id : int, optional
        Batch index, only used in the warning message.

    Returns
    -------
    bool
        True if the batch should be skipped on all workers.

    Raises
    ------
    RuntimeError
        After 100 consecutive invalid-loss batches.
    """
    # torch.isfinite is False for NaN as well as +/-inf, so the previous
    # logical_or(isnan(loss), ~isfinite(loss)) was redundant. torch.any
    # already yields a 0-dim bool tensor on loss's device.
    flag_skip = torch.any(~torch.isfinite(loss))

    # sub-optimal but will do, till they fix it in
    # https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1552650013
    any_invalid = self._sync2skip(flag_skip)
    if any_invalid:
        # NOTE(review): assumes self.nan_countdown is initialised to 1
        # elsewhere (e.g. in __init__) — confirm.
        if self.nan_countdown >= 100:
            raise RuntimeError(
                "Too many NaNs loss iterations encountered, stopping!"
            )
        logger.warning(
            f"NaN loss in batch {batch_id} of epoch {self.current_epoch}, "
            f"skipping the whole batch across all workers."
        )
        self.nan_countdown += 1
    else:
        # reset counter
        self.nan_countdown = 1

    return any_invalid

def common_step(self, batch):
"""
Expand Down Expand Up @@ -298,6 +337,11 @@ def training_step(self, batch):
)
) # mean over unrolled times and batch

any_invalid = self._check_nan_inf_loss(batch_loss)
if any_invalid:
# skip this batch altogether on all workers.
return None

log_dict = {"train_loss": batch_loss}
self.log_dict(
log_dict,
Expand Down Expand Up @@ -336,6 +380,11 @@ def validation_step(self, batch, batch_idx):
) # (time_steps-1)
mean_loss = torch.mean(time_step_loss)

any_invalid = self._check_nan_inf_loss(mean_loss, batch_idx)
if any_invalid:
# skip this batch altogether on all workers.
return None

# Log loss per time step forward and mean
val_log_dict = {
f"val_loss_unroll{step}": time_step_loss[step - 1]
Expand Down