9 changes: 5 additions & 4 deletions config/default_config.yml
@@ -150,7 +150,8 @@ desc: ""
data_loader_rng_seed: ???
run_id: ???

-# Parameters for logging/printing in the training loop
-train_log:
-  # The period to log metrics (in number of batch steps)
-  log_interval: 20
+# The period to log in the training loop (in number of batch steps)
+train_log_freq:
+  terminal: 10
+  metrics: 20
+  checkpoint: 250
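
For context, a minimal sketch of how the three new frequencies are meant to gate the training loop (assumed usage, mirroring the trainer changes further down in this diff; not code from the PR itself):

```python
from omegaconf import OmegaConf

# Mirrors the new train_log_freq block in default_config.yml.
cf = OmegaConf.create(
    {"train_log_freq": {"terminal": 10, "metrics": 20, "checkpoint": 250}}
)

for bidx in range(1, 501):  # stand-in for the batch loop
    if bidx % cf.train_log_freq.terminal == 0:
        pass  # terminal/console logging would happen here
    if bidx % cf.train_log_freq.metrics == 0:
        pass  # metrics logging would happen here
    if bidx % cf.train_log_freq.checkpoint == 0:
        pass  # checkpointing would happen here
```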
18 changes: 11 additions & 7 deletions integration_tests/small1_test.py
@@ -100,13 +100,17 @@ def evaluate_results(run_id):
logger.info("run evaluation")
cfg = omegaconf.OmegaConf.create(
{
"verbose": True,
"image_format": "png",
"dpi_val": 300,
"summary_plots": True,
"summary_dir": "./plots/",
"print_summary": True,
"evaluation": {"metrics": ["rmse", "l1", "mse"]},
"global_plotting_options": {
"image_format": "png",

Collaborator (commenting on the line above): thank you. clearly we are not running the test suite like we should?

"dpi_val": 300,
},
"evaluation": {
"metrics": ["rmse", "l1", "mse"],
"verbose": True,
"summary_plots": True,
"summary_dir": "./plots/",
"print_summary": True,
},
"run_ids": {
run_id: { # would be nice if this could be done with option
"streams": {
30 changes: 29 additions & 1 deletion packages/common/src/weathergen/common/config.py
@@ -86,7 +86,9 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
with fname.open() as f:
json_str = f.read()

-    return OmegaConf.create(json.loads(json_str))
+    config = OmegaConf.create(json.loads(json_str))
+
+    return _check_logging(config)


def _get_model_config_file_name(run_id: str, epoch: int | None):
@@ -99,6 +101,32 @@ def _get_model_config_file_name(run_id: str, epoch: int | None):
return f"model_{run_id}{epoch_str}.json"


+def _apply_fixes(config: Config) -> Config:
+    """
+    Apply fixes to maintain best-effort backward compatibility.
+
+    This function should act as a central hook for implementing config
+    backward compatibility fixes. This is needed to run inference from, or
+    continue training from, "outdated" run configurations. The fixes in this
+    function should eventually be removed.
+    """
+    config = _check_logging(config)
+    return config
+
+
+def _check_logging(config: Config) -> Config:
+    """
+    Apply fixes to the log frequency config.
+    """
+    config = config.copy()
+    if config.get("train_log_freq") is None:  # TODO: remove this for the next version
+        config.train_log_freq = OmegaConf.create(
+            {"checkpoint": 250, "terminal": 10, "metrics": config.train_log.log_interval}
+        )
+
+    return config


def load_config(
private_home: Path | None,
from_run_id: str | None,
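
A quick standalone sanity check of the migration path (hypothetical snippet: it inlines `_check_logging` as fixed above rather than importing it from `weathergen.common.config`, so it runs without the package):

```python
from omegaconf import OmegaConf

def _check_logging(config):
    """Backfill train_log_freq on configs that predate it."""
    config = config.copy()
    if config.get("train_log_freq") is None:
        config.train_log_freq = OmegaConf.create(
            {"checkpoint": 250, "terminal": 10, "metrics": config.train_log.log_interval}
        )
    return config

# A legacy run config only carries train_log.log_interval.
legacy = OmegaConf.create({"train_log": {"log_interval": 20}})
migrated = _check_logging(legacy)

assert migrated.train_log_freq.metrics == 20    # carried over from log_interval
assert migrated.train_log_freq.terminal == 10   # new defaults
assert migrated.train_log_freq.checkpoint == 250
```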
6 changes: 3 additions & 3 deletions src/weathergen/run_train.py
@@ -70,7 +70,7 @@ def inference_from_args(argl: list[str]):

cf.run_history += [(args.from_run_id, cf.istep)]

-    trainer = Trainer()
+    trainer = Trainer(cf.train_log_freq)
trainer.inference(cf, devices, args.from_run_id, args.epoch)


@@ -138,7 +138,7 @@ def train_continue_from_args(argl: list[str]):
# track history of run to ensure traceability of results
cf.run_history += [(args.from_run_id, cf.istep)]

-    trainer = Trainer()
+    trainer = Trainer(cf.train_log_freq)
trainer.run(cf, devices, args.from_run_id, args.epoch)


@@ -184,7 +184,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):
if cf.with_flash_attention:
assert cf.with_mixed_precision

-    trainer = Trainer(checkpoint_freq=250, print_freq=10)
+    trainer = Trainer(cf.train_log_freq)

try:
trainer.run(cf, devices)
17 changes: 8 additions & 9 deletions src/weathergen/train/trainer.py
@@ -54,11 +54,10 @@


class Trainer(TrainerBase):
-    def __init__(self, checkpoint_freq=250, print_freq=10):
+    def __init__(self, train_log_freq: Config):
TrainerBase.__init__(self)

-        self.checkpoint_freq = checkpoint_freq
-        self.print_freq = print_freq
+        self.train_log_freq = train_log_freq

def init(self, cf: Config, devices):
self.cf = OmegaConf.merge(
@@ -533,7 +532,6 @@ def train(self, epoch):
cf = self.cf
self.model.train()
# torch.autograd.set_detect_anomaly(True)
-        log_interval = self.cf.train_log.log_interval

dataset_iter = iter(self.data_loader)

@@ -591,11 +589,11 @@
self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item()

self._log_terminal(bidx, epoch, TRAIN)
-            if bidx % log_interval == 0:
+            if bidx % self.train_log_freq.metrics == 0:
self._log(TRAIN)

-            # model checkpoint
-            if bidx % self.checkpoint_freq == 0 and bidx > 0:
+            # save model checkpoint (with designation _latest)
+            if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0:
self.save_model(-1)

self.cf.istep += cf.batch_size_per_gpu
@@ -945,7 +943,8 @@ def _log(self, stage: Stage):
self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []

def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
-        if bidx % self.print_freq == 0 and bidx > 0 or stage == VAL:
+        print_freq = self.train_log_freq.terminal
+        if bidx % print_freq == 0 and bidx > 0 or stage == VAL:
# compute from last iteration
avg_loss, losses_all, _ = self._prepare_losses_for_logging()

@@ -975,7 +974,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
self.cf.istep,
avg_loss.nanmean().item(),
self.lr_scheduler.get_lr(),
-                (self.print_freq * self.cf.batch_size_per_gpu) / dt,
+                (print_freq * self.cf.batch_size_per_gpu) / dt,
),
)
logger.info("\t")