|
54 | 54 |
55 | 55 |
56 | 56 | class Trainer(TrainerBase):
57 | | -     def __init__(self, checkpoint_freq=250, print_freq=10):
| 57 | +     def __init__(self, train_log_freq: Config):
58 | 58 |         TrainerBase.__init__(self)
59 | 59 |
60 | | -         self.checkpoint_freq = checkpoint_freq
61 | | -         self.print_freq = print_freq
| 60 | +         self.train_log_freq = train_log_freq
62 | 61 |
63 | 62 |     def init(self, cf: Config, devices):
64 | 63 |         self.cf = OmegaConf.merge(
@@ -533,7 +532,6 @@ def train(self, epoch): |
533 | 532 |         cf = self.cf
534 | 533 |         self.model.train()
535 | 534 |         # torch.autograd.set_detect_anomaly(True)
536 | | -         log_interval = self.cf.train_log.log_interval
537 | 535 |
538 | 536 |         dataset_iter = iter(self.data_loader)
539 | 537 |
@@ -591,11 +589,11 @@ def train(self, epoch): |
591 | 589 |             self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item()
592 | 590 |
593 | 591 |             self._log_terminal(bidx, epoch, TRAIN)
594 | | -             if bidx % log_interval == 0:
| 592 | +             if bidx % self.train_log_freq.metrics == 0:
595 | 593 |                 self._log(TRAIN)
596 | 594 |
597 | | -             # model checkpoint
598 | | -             if bidx % self.checkpoint_freq == 0 and bidx > 0:
| 595 | +             # save model checkpoint (with designation _latest)
| 596 | +             if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0:
599 | 597 |                 self.save_model(-1)
600 | 598 |
601 | 599 |             self.cf.istep += cf.batch_size_per_gpu
@@ -945,7 +943,8 @@ def _log(self, stage: Stage): |
945 | 943 |         self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
946 | 944 |
947 | 945 |     def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
948 | | -         if bidx % self.print_freq == 0 and bidx > 0 or stage == VAL:
| 946 | +         print_freq = self.train_log_freq.terminal
| 947 | +         if bidx % print_freq == 0 and bidx > 0 or stage == VAL:
949 | 948 |             # compute from last iteration
950 | 949 |             avg_loss, losses_all, _ = self._prepare_losses_for_logging()
951 | 950 |
@@ -975,7 +974,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): |
975 | 974 |                     self.cf.istep,
976 | 975 |                     avg_loss.nanmean().item(),
977 | 976 |                     self.lr_scheduler.get_lr(),
978 | | -                     (self.print_freq * self.cf.batch_size_per_gpu) / dt,
| 977 | +                     (print_freq * self.cf.batch_size_per_gpu) / dt,
979 | 978 |                 ),
980 | 979 |             )
981 | 980 |             logger.info("\t")
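Note on the new constructor argument: the diff reads `self.train_log_freq.terminal`, `self.train_log_freq.metrics`, and `self.train_log_freq.checkpoint`, so `Trainer` now takes a small config node instead of the two integers it took before. Below is a minimal sketch of what such a node could look like, assuming an OmegaConf `DictConfig`; the key names follow the accesses above, `terminal=10` and `checkpoint=250` reuse the removed defaults, and `metrics=20` is only a placeholder value.

```python
# Hypothetical sketch, not part of this commit: one possible shape for the
# train_log_freq node passed to Trainer(train_log_freq).
from omegaconf import OmegaConf

train_log_freq = OmegaConf.create(
    {
        "terminal": 10,     # batches between terminal prints (_log_terminal)
        "metrics": 20,      # batches between metric logging calls (_log)
        "checkpoint": 250,  # batches between "_latest" checkpoint saves
    }
)

trainer = Trainer(train_log_freq)
```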
|