Skip to content

Commit 2bdd443

Browse files
committed
adding tensorboard support
1 parent cd3ee78 commit 2bdd443

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

timm/utils/summary.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def update_summary(
3535
lr=None,
3636
write_header=False,
3737
log_wandb=False,
38+
tensorboard_writer=False,
3839
):
3940
rowd = OrderedDict(epoch=epoch)
4041
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
@@ -43,6 +44,12 @@ def update_summary(
4344
rowd['lr'] = lr
4445
if log_wandb:
4546
wandb.log(rowd)
47+
if tensorboard_writer:
48+
import torch
49+
for k, v in rowd.items():
50+
if isinstance(v, float):
51+
tensorboard_writer.add_scalar(k, v, epoch)
52+
4653
with open(filename, mode='a') as cf:
4754
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
4855
if write_header: # first iteration (epoch == 1 can't be used)

train.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@
6464
has_functorch = True
6565
except ImportError as e:
6666
has_functorch = False
67-
67+
#test tensorboard install
68+
try:
69+
from torch.utils.tensorboard import SummaryWriter
70+
has_tensorboard = True
71+
except ImportError as e:
72+
has_tensorboard = False
6873
has_compile = hasattr(torch, 'compile')
6974

7075

@@ -347,8 +352,8 @@
347352
help='use the multi-epochs-loader to save time at the beginning of every epoch')
348353
group.add_argument('--log-wandb', action='store_true', default=False,
349354
help='log training and validation metrics to wandb')
350-
351-
355+
group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH',
356+
help='log training and validation metrics to TensorBoard')
352357
def _parse_args():
353358
# Do we have a config file to parse?
354359
args_config, remaining = config_parser.parse_known_args()
@@ -726,6 +731,16 @@ def main():
726731
"You've requested to log metrics to wandb but package not found. "
727732
"Metrics not being logged to wandb, try `pip install wandb`")
728733

734+
if utils.is_primary(args) and args.log_tensorboard:
735+
if has_tensorboard:
736+
writer = SummaryWriter(args.log_tensorboard)
737+
else:
738+
_logger.warning(
739+
"You've requested to log metrics to tensorboard but package not found. "
740+
"Metrics not being logged to tensorboard, try `pip install tensorboard`")
741+
742+
743+
729744
# setup learning rate schedule and starting epoch
730745
updates_per_epoch = len(loader_train)
731746
lr_scheduler, num_epochs = create_scheduler_v2(
@@ -809,6 +824,7 @@ def main():
809824
lr=sum(lrs) / len(lrs),
810825
write_header=best_metric is None,
811826
log_wandb=args.log_wandb and has_wandb,
827+
tensorboard_writer=writer if writer is not None and has_tensorboard else False,
812828
)
813829

814830
if saver is not None:

0 commit comments

Comments
 (0)