
Commit 394f35c

update
1 parent ce3d1d8 commit 394f35c

File tree: 1 file changed (+18 −13 lines)

returnn/torch/engine.py (+18 −13)
@@ -20,7 +20,6 @@
 from torch.utils.data import DataLoader
 from torch import autocast
 from torch.cuda import amp
-from torch.utils.tensorboard import SummaryWriter
 import numpy as np
 
 import returnn
@@ -132,7 +131,12 @@ def __init__(self, config: Config):
         self._reset_dev_memory_caches = config.bool("reset_dev_memory_caches", False)
         self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
         self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)
-        self._tensorboard_writer = SummaryWriter()
+
+        if config.bool("use_tensorboard", False):
+            from torch.utils.tensorboard import SummaryWriter
+            self._tensorboard_writer = SummaryWriter()
+        else:
+            self._tensorboard_writer = None
 
         default_float_dtype = config.value("default_float_dtype", None)
         if default_float_dtype is not None:
@@ -257,7 +261,8 @@ def train(self):
             self.init_train_epoch()
             self.train_epoch()
 
-        self._tensorboard_writer.close()
+        if self._tensorboard_writer:
+            self._tensorboard_writer.close()
 
         print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3)
 
@@ -485,9 +490,10 @@ def train_epoch(self):
                 batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
                 log_memory_usage_device=self._device if self._log_memory_usage else None,
             )
-            # write losses/errors to tensorboard
-            for key, val in eval_info.items():
-                self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
+            if self._tensorboard_writer:
+                # write losses/errors to tensorboard
+                for key, val in eval_info.items():
+                    self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)
 
             if self._stop_on_nonfinite_train_score:
                 if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()):
@@ -672,13 +678,12 @@ def eval_model(self, *, skip_already_evaluated: bool = False):
                     start_elapsed=step_end_time - eval_start_time,
                     log_memory_usage_device=self._device if self._log_memory_usage else None,
                 )
-                # write losses/errors to tensorboard
-                for key, val in eval_info.items():
-                    self._tensorboard_writer.add_scalar(
-                        f"{dataset_name}/{key}",
-                        val,
-                        global_step=self.global_train_step
-                    )
+                if self._tensorboard_writer:
+                    # write losses/errors to tensorboard
+                    for key, val in eval_info.items():
+                        self._tensorboard_writer.add_scalar(
+                            f"{dataset_name}/{key}", val, global_step=self.global_train_step
+                        )
 
                 step_idx += 1
 
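For reference, the pattern this commit introduces is an opt-in, lazily imported TensorBoard writer that is guarded at every use. Below is a minimal standalone sketch of that pattern; only the use_tensorboard config flag and the guarded SummaryWriter calls come from the diff above, while the class and method names are illustrative placeholders, not the actual RETURNN engine code.

# Sketch of the opt-in TensorBoard pattern from this commit.
# Only "use_tensorboard" and the guarded SummaryWriter usage mirror the diff;
# the class and method names here are illustrative placeholders.
class _TensorBoardLoggingSketch:
    def __init__(self, use_tensorboard: bool = False):
        if use_tensorboard:
            # Lazy import: torch.utils.tensorboard is only required when enabled.
            from torch.utils.tensorboard import SummaryWriter

            self._tensorboard_writer = SummaryWriter()
        else:
            self._tensorboard_writer = None

    def log_scalar(self, key: str, value: float, step: int):
        # Every writer access is guarded, so the disabled case is a no-op.
        if self._tensorboard_writer:
            self._tensorboard_writer.add_scalar(key, value, global_step=step)

    def close(self):
        if self._tensorboard_writer:
            self._tensorboard_writer.close()

With the new default of use_tensorboard = False, no SummaryWriter is created and no scalars are written; setting use_tensorboard = True in the config restores the previous behavior of logging train/* and per-dataset scalars to TensorBoard.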