diff --git a/config/default_config.yml b/config/default_config.yml
index 1e9f8be9b..e40fa3b7f 100644
--- a/config/default_config.yml
+++ b/config/default_config.yml
@@ -150,7 +150,8 @@
 desc: ""
 data_loader_rng_seed: ???
 run_id: ???
-# Parameters for logging/printing in the training loop
-train_log:
-  # The period to log metrics (in number of batch steps)
-  log_interval: 20
\ No newline at end of file
+# The period to log in the training loop (in number of batch steps)
+train_log_freq:
+  terminal: 10
+  metrics: 20
+  checkpoint: 250
diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py
index f8a39a0f1..4f975c140 100644
--- a/integration_tests/small1_test.py
+++ b/integration_tests/small1_test.py
@@ -100,13 +100,17 @@ def evaluate_results(run_id):
     logger.info("run evaluation")
     cfg = omegaconf.OmegaConf.create(
         {
-            "verbose": True,
-            "image_format": "png",
-            "dpi_val": 300,
-            "summary_plots": True,
-            "summary_dir": "./plots/",
-            "print_summary": True,
-            "evaluation": {"metrics": ["rmse", "l1", "mse"]},
+            "global_plotting_options": {
+                "image_format": "png",
+                "dpi_val": 300,
+            },
+            "evaluation": {
+                "metrics": ["rmse", "l1", "mse"],
+                "verbose": True,
+                "summary_plots": True,
+                "summary_dir": "./plots/",
+                "print_summary": True,
+            },
             "run_ids": {
                 run_id: {  # would be nice if this could be done with option
                     "streams": {
diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py
index b69cce01a..5e0822d28 100644
--- a/packages/common/src/weathergen/common/config.py
+++ b/packages/common/src/weathergen/common/config.py
@@ -86,7 +86,9 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
     with fname.open() as f:
         json_str = f.read()
 
-    return OmegaConf.create(json.loads(json_str))
+    config = OmegaConf.create(json.loads(json_str))
+
+    return _apply_fixes(config)
 
 
 def _get_model_config_file_name(run_id: str, epoch: int | None):
@@ -99,6 +101,32 @@ def _get_model_config_file_name(run_id: str, epoch: int | None):
     return f"model_{run_id}{epoch_str}.json"
 
 
+def _apply_fixes(config: Config) -> Config:
+    """
+    Apply fixes to maintain a best effort backward compatibility.
+
+    This method should act as a central hook to implement config backward
+    compatibility fixes. This is needed to run inference/continuing from
+    "outdated" run configurations. The fixes in this function should be
+    eventually removed.
+    """
+    config = _check_logging(config)
+    return config
+
+
+def _check_logging(config: Config) -> Config:
+    """
+    Apply fixes to log frequency config.
+    """
+    config = config.copy()
+    if config.get("train_log_freq") is None:  # TODO remove this for next version
+        config.train_log_freq = OmegaConf.create(
+            {"checkpoint": 250, "terminal": 10, "metrics": config.train_log.log_interval}
+        )
+
+    return config
+
+
 def load_config(
     private_home: Path | None,
     from_run_id: str | None,
diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py
index 5ccc891b9..eb2cab895 100644
--- a/src/weathergen/run_train.py
+++ b/src/weathergen/run_train.py
@@ -70,7 +70,7 @@ def inference_from_args(argl: list[str]):
 
     cf.run_history += [(args.from_run_id, cf.istep)]
 
-    trainer = Trainer()
+    trainer = Trainer(cf.train_log_freq)
 
     trainer.inference(cf, devices, args.from_run_id, args.epoch)
 
@@ -138,7 +138,7 @@ def train_continue_from_args(argl: list[str]):
     # track history of run to ensure traceability of results
     cf.run_history += [(args.from_run_id, cf.istep)]
 
-    trainer = Trainer()
+    trainer = Trainer(cf.train_log_freq)
 
     trainer.run(cf, devices, args.from_run_id, args.epoch)
 
@@ -184,7 +184,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):
     if cf.with_flash_attention:
         assert cf.with_mixed_precision
 
-    trainer = Trainer(checkpoint_freq=250, print_freq=10)
+    trainer = Trainer(cf.train_log_freq)
 
     try:
         trainer.run(cf, devices)
diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py
index 0fdbd0f6d..36e97aeef 100644
--- a/src/weathergen/train/trainer.py
+++ b/src/weathergen/train/trainer.py
@@ -54,11 +54,10 @@ class Trainer(TrainerBase):
 
 
 
-    def __init__(self, checkpoint_freq=250, print_freq=10):
+    def __init__(self, train_log_freq: Config):
         TrainerBase.__init__(self)
 
-        self.checkpoint_freq = checkpoint_freq
-        self.print_freq = print_freq
+        self.train_log_freq = train_log_freq
 
     def init(self, cf: Config, devices):
         self.cf = OmegaConf.merge(
@@ -533,7 +532,6 @@ def train(self, epoch):
         cf = self.cf
         self.model.train()
         # torch.autograd.set_detect_anomaly(True)
-        log_interval = self.cf.train_log.log_interval
 
         dataset_iter = iter(self.data_loader)
 
@@ -591,11 +589,11 @@ def train(self, epoch):
             self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item()
 
         self._log_terminal(bidx, epoch, TRAIN)
-        if bidx % log_interval == 0:
+        if bidx % self.train_log_freq.metrics == 0:
             self._log(TRAIN)
 
-        # model checkpoint
-        if bidx % self.checkpoint_freq == 0 and bidx > 0:
+        # save model checkpoint (with designation _latest)
+        if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0:
             self.save_model(-1)
 
         self.cf.istep += cf.batch_size_per_gpu
@@ -945,7 +943,8 @@ def _log(self, stage: Stage):
         self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
 
     def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
-        if bidx % self.print_freq == 0 and bidx > 0 or stage == VAL:
+        print_freq = self.train_log_freq.terminal
+        if bidx % print_freq == 0 and bidx > 0 or stage == VAL:
             # compute from last iteration
             avg_loss, losses_all, _ = self._prepare_losses_for_logging()
 
@@ -975,7 +974,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
                     self.cf.istep,
                     avg_loss.nanmean().item(),
                     self.lr_scheduler.get_lr(),
-                    (self.print_freq * self.cf.batch_size_per_gpu) / dt,
+                    (print_freq * self.cf.batch_size_per_gpu) / dt,
                 ),
             )
             logger.info("\t")