diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index f3769f19..2bb31c81 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -266,6 +266,60 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
             pred_std = self.per_var_std  # (d_f,)
 
         return prediction, pred_std
+
+    def _sync2skip(self, flag_skip):
+        """Reduce a boolean skip flag across all DDP ranks (logical OR).
+
+        Workaround until Lightning supports skipping a batch on all ranks
+        at once, see:
+        https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1552650013
+        """
+        # Single-process (non-DDP) runs have no process group to sync with
+        if not (
+            torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+        ):
+            return bool(flag_skip.item())
+        world_size = torch.distributed.get_world_size()
+        torch.distributed.barrier()
+        # Gather the flag from every rank, then OR-reduce via sum > 0
+        result = [torch.zeros_like(flag_skip) for _ in range(world_size)]
+        torch.distributed.all_gather(result, flag_skip)
+        any_invalid = torch.sum(torch.stack(result)).bool().item()
+        return any_invalid
+
+    def _check_nan_inf_loss(self, loss, batch_id=None):
+        """Return True (on every rank) if loss is NaN/inf on any rank.
+
+        Raises RuntimeError once 100 consecutive invalid batches are seen.
+        """
+        # ~isfinite already covers NaN as well as +/-inf
+        mask_nan_inf = ~torch.isfinite(loss)
+        if torch.any(mask_nan_inf):
+            # if any is invalid then we must flag this to all DDP processes
+            flag_skip = torch.ones((), device=loss.device, dtype=torch.bool)
+        else:
+            flag_skip = torch.zeros((), device=loss.device, dtype=torch.bool)
+
+        # sub-optimal but will do, till they fix it in
+        # https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1552650013
+        any_invalid = self._sync2skip(flag_skip)
+        if any_invalid:
+            # Defensive default: __init__ may not define nan_countdown
+            if getattr(self, "nan_countdown", 1) >= 100:
+                raise RuntimeError(
+                    "Too many NaNs loss iterations encountered, stopping!"
+                )
+            logger.warning(
+                f"NaN loss in batch {batch_id} of epoch {self.current_epoch}, "
+                "skipping the whole batch across all workers."
+            )
+            self.nan_countdown = getattr(self, "nan_countdown", 1) + 1
+        else:
+            # reset counter
+            self.nan_countdown = 1
+
+        return any_invalid
 
     def common_step(self, batch):
         """
@@ -298,6 +352,11 @@ def training_step(self, batch):
             )
         )  # mean over unrolled times and batch
 
+        any_invalid = self._check_nan_inf_loss(batch_loss)
+        if any_invalid:
+            # skip this batch altogether on all workers.
+            return None
+
         log_dict = {"train_loss": batch_loss}
         self.log_dict(
             log_dict,
@@ -336,6 +395,11 @@ def validation_step(self, batch, batch_idx):
         )  # (time_steps-1)
         mean_loss = torch.mean(time_step_loss)
 
+        any_invalid = self._check_nan_inf_loss(mean_loss, batch_idx)
+        if any_invalid:
+            # skip this batch altogether on all workers.
+            return None
+
         # Log loss per time step forward and mean
         val_log_dict = {
             f"val_loss_unroll{step}": time_step_loss[step - 1]