diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py
index b07be4052..445dfab05 100644
--- a/returnn/torch/engine.py
+++ b/returnn/torch/engine.py
@@ -134,6 +134,13 @@ def __init__(self, config: Config):
         self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
         self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)
 
+        if config.bool("use_tensorboard", False):
+            from torch.utils.tensorboard import SummaryWriter
+
+            self._tensorboard_writer = SummaryWriter()
+        else:
+            self._tensorboard_writer = None
+
         default_float_dtype = config.value("default_float_dtype", None)
         if default_float_dtype is not None:
             assert isinstance(default_float_dtype, str)
@@ -257,6 +264,9 @@ def train(self):
             self.init_train_epoch()
             self.train_epoch()
 
+        if self._tensorboard_writer:
+            self._tensorboard_writer.close()
+
         print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3)
 
     def init_train_epoch(self):
@@ -506,6 +516,10 @@ def train_epoch(self):
                 batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
                 log_memory_usage_device=self._device if self._log_memory_usage else None,
             )
+            if self._tensorboard_writer:
+                # write losses/errors to tensorboard
+                for key, val in eval_info.items():
+                    self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
 
             if self._stop_on_nonfinite_train_score:
                 if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()):
@@ -695,12 +709,18 @@ def eval_model(self, *, skip_already_evaluated: bool = False):
                     start_elapsed=step_end_time - eval_start_time,
                     log_memory_usage_device=self._device if self._log_memory_usage else None,
                 )
+                step_idx += 1
 
             assert step_idx > 0, f"No data in dataset {dataset_name!r}."
 
             accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict
             accumulated_losses_dict = self._maybe_extend_losses_info(accumulated_losses_dict)
 
+            if self._tensorboard_writer:
+                # write losses/errors to tensorboard
+                for key, val in accumulated_losses_dict.items():
+                    self._tensorboard_writer.add_scalar(f"{dataset_name}/{key}", val, global_step=self.epoch)
+
             self.learning_rate_control.set_epoch_error(
                 self.epoch, {f"{dataset_name}_loss_{k}": v for k, v in accumulated_losses_dict.items()}
             )
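
Usage note (not part of the diff; a short sketch inferred from the patch): the writer is gated on the new config flag read via config.bool("use_tensorboard", False), so a RETURNN config (a plain Python file) would enable it with:

    use_tensorboard = True  # default is False; enables the SummaryWriter in Engine.__init__

Because the patch constructs SummaryWriter() with no arguments, event files land in PyTorch's default location (./runs/<datetime>_<hostname>/) and can be inspected with "tensorboard --logdir runs". Training scalars are logged per step under train/<loss-key> against global_train_step; eval scalars are logged once per epoch under <dataset_name>/<loss-key> against epoch.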