
Commit 43eb49b

Sgrasse/develop/issue 898 checkpoint freq conf (#905)
* add new/changed parameters in default_config
* implement backward compatibility
* remove `train_log.log_interval` from default config
* use new configuration arguments in Trainer
* fix: wrong variable name
* ruffed
* Rework method structure
* fix bug
* rename `log_intevals` to `train_log_freq`
* fix integration tests
* fix forgot renaming
* fix rebasing artifact
1 parent 5586d79 commit 43eb49b


5 files changed: +56 −24 lines changed


config/default_config.yml

Lines changed: 5 additions & 4 deletions
```diff
@@ -150,7 +150,8 @@ desc: ""
 data_loader_rng_seed: ???
 run_id: ???
 
-# Parameters for logging/printing in the training loop
-train_log:
-  # The period to log metrics (in number of batch steps)
-  log_interval: 20
+# The period to log in the training loop (in number of batch steps)
+train_log_freq:
+  terminal: 10
+  metrics: 20
+  checkpoint: 250
```
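The net effect is that the single `train_log.log_interval` knob is split into three independent frequencies. A minimal sketch of how the new block is read, assuming the file is loaded with plain OmegaConf; only the key names and defaults come from the diff above, the access pattern is illustrative:

```python
# Minimal sketch, assuming plain OmegaConf loading of the repo's default config.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/default_config.yml")

# Old access pattern (removed): cfg.train_log.log_interval
# New pattern: one frequency per logging target, in batch steps.
print(cfg.train_log_freq.terminal)    # 10  -> console prints
print(cfg.train_log_freq.metrics)     # 20  -> metric logging
print(cfg.train_log_freq.checkpoint)  # 250 -> "_latest" checkpoint saves
```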

integration_tests/small1_test.py

Lines changed: 11 additions & 7 deletions
```diff
@@ -100,13 +100,17 @@ def evaluate_results(run_id):
     logger.info("run evaluation")
     cfg = omegaconf.OmegaConf.create(
         {
-            "verbose": True,
-            "image_format": "png",
-            "dpi_val": 300,
-            "summary_plots": True,
-            "summary_dir": "./plots/",
-            "print_summary": True,
-            "evaluation": {"metrics": ["rmse", "l1", "mse"]},
+            "global_plotting_options": {
+                "image_format": "png",
+                "dpi_val": 300,
+            },
+            "evaluation": {
+                "metrics": ["rmse", "l1", "mse"],
+                "verbose": True,
+                "summary_plots": True,
+                "summary_dir": "./plots/",
+                "print_summary": True,
+            },
             "run_ids": {
                 run_id: { # would be nice if this could be done with option
                     "streams": {
```

packages/common/src/weathergen/common/config.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -86,7 +86,9 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
     with fname.open() as f:
         json_str = f.read()
 
-    return OmegaConf.create(json.loads(json_str))
+    config = OmegaConf.create(json.loads(json_str))
+
+    return _check_logging(config)
 
 
 def _get_model_config_file_name(run_id: str, epoch: int | None):
```
```diff
@@ -99,6 +101,32 @@ def _get_model_config_file_name(run_id: str, epoch: int | None):
     return f"model_{run_id}{epoch_str}.json"
 
 
+def _apply_fixes(config: Config) -> Config:
+    """
+    Apply fixes to maintain best-effort backward compatibility.
+
+    This method acts as a central hook for config backward-compatibility
+    fixes, which are needed to run inference or continue training from
+    "outdated" run configurations. The fixes in this function should
+    eventually be removed.
+    """
+    config = _check_logging(config)
+    return config
+
+
+def _check_logging(config: Config) -> Config:
+    """
+    Apply fixes to the log frequency config.
+    """
+    config = config.copy()
+    if config.get("train_log_freq") is None:  # TODO remove this for next version
+        config.train_log_freq = OmegaConf.create(
+            {"checkpoint": 250, "terminal": 10, "metrics": config.train_log.log_interval}
+        )
+
+    return config
+
+
 def load_config(
     private_home: Path | None,
     from_run_id: str | None,
```
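To see what `_check_logging` buys: an old-style config that only carries `train_log.log_interval` comes out with a fully populated `train_log_freq` block, with `checkpoint` and `terminal` falling back to the former hard-coded defaults. A small sketch; the inline dict stands in for a real `model_<run_id>.json`:

```python
# Sketch of the backward-compatibility shim, mirroring _check_logging above.
from omegaconf import OmegaConf

old_style = OmegaConf.create({"train_log": {"log_interval": 20}})

if old_style.get("train_log_freq") is None:
    old_style.train_log_freq = OmegaConf.create(
        {"checkpoint": 250, "terminal": 10, "metrics": old_style.train_log.log_interval}
    )

assert old_style.train_log_freq.metrics == 20  # carried over from the old key
```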

src/weathergen/run_train.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -70,7 +70,7 @@ def inference_from_args(argl: list[str]):
 
     cf.run_history += [(args.from_run_id, cf.istep)]
 
-    trainer = Trainer()
+    trainer = Trainer(cf.train_log_freq)
     trainer.inference(cf, devices, args.from_run_id, args.epoch)
 
 
@@ -138,7 +138,7 @@ def train_continue_from_args(argl: list[str]):
     # track history of run to ensure traceability of results
     cf.run_history += [(args.from_run_id, cf.istep)]
 
-    trainer = Trainer()
+    trainer = Trainer(cf.train_log_freq)
     trainer.run(cf, devices, args.from_run_id, args.epoch)
 
 
@@ -184,7 +184,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):
     if cf.with_flash_attention:
         assert cf.with_mixed_precision
 
-    trainer = Trainer(checkpoint_freq=250, print_freq=10)
+    trainer = Trainer(cf.train_log_freq)
 
     try:
         trainer.run(cf, devices)
```
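Since all three entry points now read the frequencies from the config rather than from hard-coded constructor arguments, they can be overridden like any other OmegaConf value. A sketch using standard OmegaConf dotlist syntax; whether the project's CLI exposes overrides this way is an assumption:

```python
# Hypothetical override flow; only train_log_freq and its defaults come from the diff.
from omegaconf import OmegaConf

cf = OmegaConf.create({"train_log_freq": {"terminal": 10, "metrics": 20, "checkpoint": 250}})
cf = OmegaConf.merge(cf, OmegaConf.from_dotlist(["train_log_freq.checkpoint=500"]))

assert cf.train_log_freq.checkpoint == 500  # checkpoint every 500 batch steps
```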

src/weathergen/train/trainer.py

Lines changed: 8 additions & 9 deletions
```diff
@@ -54,11 +54,10 @@
 
 
 class Trainer(TrainerBase):
-    def __init__(self, checkpoint_freq=250, print_freq=10):
+    def __init__(self, train_log_freq: Config):
         TrainerBase.__init__(self)
 
-        self.checkpoint_freq = checkpoint_freq
-        self.print_freq = print_freq
+        self.train_log_freq = train_log_freq
 
     def init(self, cf: Config, devices):
         self.cf = OmegaConf.merge(
@@ -533,7 +532,6 @@ def train(self, epoch):
         cf = self.cf
         self.model.train()
         # torch.autograd.set_detect_anomaly(True)
-        log_interval = self.cf.train_log.log_interval
 
         dataset_iter = iter(self.data_loader)
 
@@ -591,11 +589,11 @@ def train(self, epoch):
             self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item()
 
             self._log_terminal(bidx, epoch, TRAIN)
-            if bidx % log_interval == 0:
+            if bidx % self.train_log_freq.metrics == 0:
                 self._log(TRAIN)
 
-            # model checkpoint
-            if bidx % self.checkpoint_freq == 0 and bidx > 0:
+            # save model checkpoint (with designation _latest)
+            if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0:
                 self.save_model(-1)
 
             self.cf.istep += cf.batch_size_per_gpu
@@ -945,7 +943,8 @@ def _log(self, stage: Stage):
         self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
 
     def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
-        if bidx % self.print_freq == 0 and bidx > 0 or stage == VAL:
+        print_freq = self.train_log_freq.terminal
+        if bidx % print_freq == 0 and bidx > 0 or stage == VAL:
             # compute from last iteration
             avg_loss, losses_all, _ = self._prepare_losses_for_logging()
 
@@ -975,7 +974,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
                         self.cf.istep,
                         avg_loss.nanmean().item(),
                         self.lr_scheduler.get_lr(),
-                        (self.print_freq * self.cf.batch_size_per_gpu) / dt,
+                        (print_freq * self.cf.batch_size_per_gpu) / dt,
                     ),
                 )
                 logger.info("\t")
```
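The gating logic is the same for all three frequencies: a modulo test on the batch index, with `bidx > 0` guarding terminal prints and checkpoint saves against firing on the first step (the metrics gate has no such guard). A standalone illustration with the default values:

```python
# Pure illustration of the gating above; no project code involved.
freqs = {"terminal": 10, "metrics": 20, "checkpoint": 250}

for bidx in range(0, 501):
    if bidx % freqs["terminal"] == 0 and bidx > 0:
        pass  # print progress to the console (bidx 10, 20, 30, ...)
    if bidx % freqs["metrics"] == 0:
        pass  # log metrics (bidx 0, 20, 40, ...)
    if bidx % freqs["checkpoint"] == 0 and bidx > 0:
        pass  # save the "_latest" checkpoint (bidx 250, 500)
```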
