Add tensorboard to torch engine #1704


Open. Wants to merge 39 commits into base: master.

Commits (39):
e1618e5
add tensorboard to torch engine
robin-p-schmitt Mar 31, 2025
8950739
cleanup
robin-p-schmitt Mar 31, 2025
ce3d1d8
cleanup
robin-p-schmitt Mar 31, 2025
394f35c
update
robin-p-schmitt Apr 2, 2025
4f82abd
update
robin-p-schmitt Apr 2, 2025
a2d4234
CI fix transformers version, fix #1706
albertz Apr 3, 2025
03a7526
CI fix transformers version on older Python
albertz Apr 3, 2025
36abfc6
Vocabulary, parse utf8
albertz Apr 4, 2025
a11ee1a
Vocabulary, better error exception on duplicates
albertz Apr 4, 2025
34e3448
LibriSpeechCorpus, orth_post_process, small fix type
albertz Apr 4, 2025
b95b743
LmDataset, support orth_post_process
albertz Apr 4, 2025
f6e07d8
Vocabulary, line-based vocab, use as-is, even with duplicates
albertz Apr 4, 2025
7a1d54e
Vocabulary, small fix warning, cleanup
albertz Apr 4, 2025
cd1cbcc
PT preload_from_files without filename, allow prefix
albertz Apr 4, 2025
048bc87
Torch, fix _print_process handling of undef complete_frac (#1711)
dorian-K Apr 16, 2025
027061f
PostprocessingDataset, implement get_all_tags
albertz Apr 16, 2025
e568db2
SimpleHDFWriter, sanity checks (#1713)
albertz Apr 16, 2025
9dd16e3
RF RunCtx train_flag per func (#1714)
albertz Apr 17, 2025
6e0542a
DummyDataset, fix dtypes
albertz Apr 18, 2025
31f4423
PostprocessingDataset, fix pickling
albertz Apr 18, 2025
dc3667a
is_running_on_cluster: also check SLURM_JOB_ID
albertz Apr 18, 2025
a9b12e7
cf: catch also OSError
albertz Apr 18, 2025
2c99a76
MetaDataset, better err msg
albertz Apr 20, 2025
b72475a
VariableDataset, always_same_tags option, support get_all_tags
albertz Apr 20, 2025
aa9f932
dump-dataset, faster and more generic serialization
albertz Apr 22, 2025
622e4ef
FileCache: hold lock and refresh mtime during cleanup (#1709)
NeoLegends Apr 22, 2025
3a84cb1
Torch engine, epoch_start, epoch_end func for training/eval
albertz Apr 23, 2025
850a2ee
`DistributeFilesDataset`: allow specifying files via list file (#1717)
NeoLegends Apr 23, 2025
6779deb
SentencePieces, handle sis Path
albertz Apr 23, 2025
8588507
Torch, handle available_for_inference in extern_data
albertz Apr 25, 2025
2860715
Make SimpleHDFWriter a context manager
NeoLegends Apr 29, 2025
352af14
SimpleHDFWriter: small comment fix
NeoLegends Apr 29, 2025
639f800
RF scatter_mean
albertz Apr 30, 2025
8cd49a1
update better_exchook
albertz May 8, 2025
f157fa9
get_complete_frac fix when num_seqs is None (#1722)
albertz May 8, 2025
1112cab
tooling: change formatter to ruff (#1719)
NeoLegends May 8, 2025
1a5d0db
update
robin-p-schmitt May 13, 2025
a77e5f4
update
robin-p-schmitt May 13, 2025
5110779
Merge branch 'master' into robin-tensorboard
robin-p-schmitt May 19, 2025
returnn/torch/engine.py: 20 additions, 0 deletions
@@ -134,6 +134,13 @@ def __init__(self, config: Config):
self._forward_auto_split_batch_on_oom = config.bool("forward_auto_split_batch_on_oom", False)
self._stop_on_nonfinite_train_score = config.bool("stop_on_nonfinite_train_score", True)

if config.bool("use_tensorboard", False):
from torch.utils.tensorboard import SummaryWriter

self._tensorboard_writer = SummaryWriter()
else:
self._tensorboard_writer = None

default_float_dtype = config.value("default_float_dtype", None)
if default_float_dtype is not None:
assert isinstance(default_float_dtype, str)
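
For reference, a minimal sketch of how this option would be switched on from a user config, assuming the usual Python-based RETURNN config file; the flag name is the one read above, and the log directory is simply the SummaryWriter default:

    # hypothetical excerpt from a RETURNN config file
    use_tensorboard = True
    # The engine then constructs torch.utils.tensorboard.SummaryWriter() without arguments,
    # so event files go to ./runs/<datetime>_<hostname> and can be inspected with
    # `tensorboard --logdir runs`.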
@@ -257,6 +264,9 @@ def train(self):
self.init_train_epoch()
self.train_epoch()

if self._tensorboard_writer:
self._tensorboard_writer.close()

print(f"Finished training at epoch {self.epoch}, global train step {self.global_train_step}", file=log.v3)

def init_train_epoch(self):
@@ -506,6 +516,10 @@ def train_epoch(self):
batch_size_info=_get_batch_size_info(extern_data) if self._log_batch_size else None,
log_memory_usage_device=self._device if self._log_memory_usage else None,
)
if self._tensorboard_writer:
# write losses/errors to tensorboard
for key, val in eval_info.items():
self._tensorboard_writer.add_scalar(f"train/{key}", val, global_step=self.global_train_step)

if self._stop_on_nonfinite_train_score:
if any(np.isinf(v) or np.isnan(v) for v in accumulated_losses_dict.values()):
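
For illustration, a standalone sketch (not engine code, with made-up values) of the per-step logging added above; the "train/" prefix makes TensorBoard group all training scalars under one section:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter()
    eval_info = {"loss_ce": 1.23, "error_frames": 0.07}  # hypothetical loss/error values
    global_train_step = 100  # hypothetical step counter
    for key, val in eval_info.items():
        # tags containing "/" are grouped by their prefix in the TensorBoard UI
        writer.add_scalar(f"train/{key}", val, global_step=global_train_step)
    writer.close()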
@@ -695,12 +709,18 @@ def eval_model(self, *, skip_already_evaluated: bool = False):
start_elapsed=step_end_time - eval_start_time,
log_memory_usage_device=self._device if self._log_memory_usage else None,
)

step_idx += 1

assert step_idx > 0, f"No data in dataset {dataset_name!r}."
accumulated_losses_dict = accumulated_losses_dict / accumulated_inv_norm_factors_dict
accumulated_losses_dict = self._maybe_extend_losses_info(accumulated_losses_dict)

if self._tensorboard_writer:
# write losses/errors to tensorboard
for key, val in accumulated_losses_dict.items():
self._tensorboard_writer.add_scalar(f"{dataset_name}/{key}", val, global_step=self.epoch)

Review comment (repo member):
I wonder whether we should use self.global_train_step here instead.

Suggested change
self._tensorboard_writer.add_scalar(f"{dataset_name}/{key}", val, global_step=self.epoch)
self._tensorboard_writer.add_scalar(f"{dataset_name}/{key}", val, global_step=self.global_train_step)

We did it like that for TF. Also to have it comparable to the train scores. But I was also not so sure about that.

Other opinions? @NeoLegends? @Atticus1806? @JackTemaki? @vieting? @michelwi?
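
For context, the global_step argument only sets the x-coordinate of the logged point, so the two variants differ in whether the eval curves are plotted over epochs or over global train steps; only the latter lines up with the train/* curves on a shared axis. A sketch with hypothetical names and values (only one of the two variants would be used in practice):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter()
    epoch, global_train_step = 5, 4200  # hypothetical counters
    # epoch as x-axis: one point per epoch, not aligned with the train/* step axis
    writer.add_scalar("dev/loss_ce", 1.10, global_step=epoch)
    # global train step as x-axis: directly comparable to the train/* scalars
    writer.add_scalar("dev/loss_ce", 1.10, global_step=global_train_step)
    writer.close()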


self.learning_rate_control.set_epoch_error(
self.epoch, {f"{dataset_name}_loss_{k}": v for k, v in accumulated_losses_dict.items()}
)