@@ -20,7 +20,6 @@
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
-from torch.utils.tensorboard import SummaryWriter
 import numpy as np
 
 import returnn
@@ -132,7 +131,12 @@ def __init__(self, config: Config):
         self._reset_dev_memory_caches = config.bool("reset_dev_memory_caches", False)
         self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
         self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)
-        self._tensorboard_writer = SummaryWriter()
+
+        if config.bool("use_tensorboard", False):
+            from torch.utils.tensorboard import SummaryWriter
+            self._tensorboard_writer = SummaryWriter()
+        else:
+            self._tensorboard_writer = None
 
         default_float_dtype = config.value("default_float_dtype", None)
         if default_float_dtype is not None:
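
Note: with this change, TensorBoard logging becomes opt-in via the config. A minimal sketch of
the corresponding config entry (a hypothetical example, not part of this diff; the flag name
comes from the config.bool("use_tensorboard", False) read above, and RETURNN configs are plain
Python files):

    # in your RETURNN config file
    use_tensorboard = True  # default False: no SummaryWriter is created, tensorboard stays unimported
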
@@ -257,7 +261,8 @@ def train(self):
             self.init_train_epoch()
             self.train_epoch()
 
-        self._tensorboard_writer.close()
+        if self._tensorboard_writer:
+            self._tensorboard_writer.close()
 
         print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3)
 
@@ -485,9 +490,10 @@ def train_epoch(self):
                 batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
                 log_memory_usage_device=self._device if self._log_memory_usage else None,
             )
-            # write losses/errors to tensorboard
-            for key, val in eval_info.items():
-                self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
+            if self._tensorboard_writer:
+                # write losses/errors to tensorboard
+                for key, val in eval_info.items():
+                    self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
 
             if self._stop_on_nonfinite_train_score:
                 if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()):
@@ -672,13 +678,12 @@ def eval_model(self, *, skip_already_evaluated: bool = False):
                     start_elapsed=step_end_time - eval_start_time,
                     log_memory_usage_device=self._device if self._log_memory_usage else None,
                 )
-                # write losses/errors to tensorboard
-                for key, val in eval_info.items():
-                    self._tensorboard_writer.add_scalar(
-                        f"{dataset_name}/{key}",
-                        val,
-                        global_step=self.global_train_step
-                    )
+                if self._tensorboard_writer:
+                    # write losses/errors to tensorboard
+                    for key, val in eval_info.items():
+                        self._tensorboard_writer.add_scalar(
+                            f"{dataset_name}/{key}", val, global_step=self.global_train_step
+                        )
 
                 step_idx += 1
 
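For reference, a self-contained sketch of the torch.utils.tensorboard pattern the patched code
relies on (standard PyTorch API; the eval_info dict and step counter below are made-up
stand-ins for the engine's eval_info and self.global_train_step):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter()  # default: writes event files under ./runs/<timestamp>
    eval_info = {"loss": 0.42, "error": 0.07}  # stand-in for the accumulated losses/errors
    global_train_step = 100
    for key, val in eval_info.items():
        # same tag scheme as the patch: "train/<key>" during training,
        # "<dataset_name>/<key>" during eval
        writer.add_scalar(f"train/{key}", val, global_step=global_train_step)
    writer.close()  # flush pending events, mirroring the guarded close() in train()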