Skip to content

Add tensorboard to torch engine #1704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions returnn/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ def __init__(self, config: Config):
self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)

if config.bool("use_tensorboard", False):
from torch.utils.tensorboard import SummaryWriter

self._tensorboard_writer = SummaryWriter()
else:
self._tensorboard_writer = None

default_float_dtype = config.value("default_float_dtype", None)
if default_float_dtype is not None:
assert isinstance(default_float_dtype, str)
Expand Down Expand Up @@ -255,6 +262,9 @@ def train(self):
self.init_train_epoch()
self.train_epoch()

if self._tensorboard_writer:
self._tensorboard_writer.close()

print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3)

def init_train_epoch(self):
Expand Down Expand Up @@ -481,6 +491,10 @@ def train_epoch(self):
batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
log_memory_usage_device=self._device if self._log_memory_usage else None,
)
if self._tensorboard_writer:
# write losses/errors to tensorboard
for key, val in eval_info.items():
self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)

if self._stop_on_nonfinite_train_score:
if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()):
Expand Down Expand Up @@ -665,6 +679,13 @@ def eval_model(self, *, skip_already_evaluated: bool = False):
start_elapsed=step_end_time - eval_start_time,
log_memory_usage_device=self._device if self._log_memory_usage else None,
)
if self._tensorboard_writer:
# write losses/errors to tensorboard
for key, val in eval_info.items():
self._tensorboard_writer.add_scalar(
f"{dataset_name}/{key}", val, global_step=self.global_train_step
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

global_train_step always stays the same here for one specific epoch. Isn't that a problem?
Maybe you want to do it for the accumulated_losses_dict instead below?

)

step_idx += 1

assert step_idx > 0, f"No data in dataset {dataset_name!r}."
Expand Down
Loading