|
64 | 64 | has_functorch = True
|
65 | 65 | except ImportError as e:
|
66 | 66 | has_functorch = False
|
67 |
| - |
| 67 | +#test tensorboard install |
| 68 | +try: |
| 69 | + from torch.utils.tensorboard import SummaryWriter |
| 70 | + has_tensorboard = True |
| 71 | +except ImportError as e: |
| 72 | + has_tensorboard = False |
68 | 73 | has_compile = hasattr(torch, 'compile')
|
69 | 74 |
|
70 | 75 |
|
|
347 | 352 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
|
348 | 353 | group.add_argument('--log-wandb', action='store_true', default=False,
|
349 | 354 | help='log training and validation metrics to wandb')
|
350 |
| - |
351 |
| - |
| 355 | +group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH', |
| 356 | + help='log training and validation metrics to TensorBoard') |
352 | 357 | def _parse_args():
|
353 | 358 | # Do we have a config file to parse?
|
354 | 359 | args_config, remaining = config_parser.parse_known_args()
|
@@ -726,6 +731,16 @@ def main():
|
726 | 731 | "You've requested to log metrics to wandb but package not found. "
|
727 | 732 | "Metrics not being logged to wandb, try `pip install wandb`")
|
728 | 733 |
|
| 734 | + if utils.is_primary(args) and args.log_tensorboard: |
| 735 | + if has_tensorboard: |
| 736 | + writer = SummaryWriter(args.log_tensorboard) |
| 737 | + else: |
| 738 | + _logger.warning( |
| 739 | + "You've requested to log metrics to tensorboard but package not found. " |
| 740 | + "Metrics not being logged to tensorboard, try `pip install tensorboard`") |
| 741 | + |
| 742 | + |
| 743 | + |
729 | 744 | # setup learning rate schedule and starting epoch
|
730 | 745 | updates_per_epoch = len(loader_train)
|
731 | 746 | lr_scheduler, num_epochs = create_scheduler_v2(
|
@@ -809,6 +824,7 @@ def main():
|
809 | 824 | lr=sum(lrs) / len(lrs),
|
810 | 825 | write_header=best_metric is None,
|
811 | 826 | log_wandb=args.log_wandb and has_wandb,
|
| 827 | + tensorboard_writer=writer if writer is not None and has_tensorboard else False, |
812 | 828 | )
|
813 | 829 |
|
814 | 830 | if saver is not None:
|
|
0 commit comments