Skip to content

Commit f7511b2

Browse files
committed
rename log_intevals to train_log_freq
1 parent 6514de1 commit f7511b2

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

config/default_config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ data_loader_rng_seed: ???
151151
run_id: ???
152152

153153
# The period to log in the training loop (in number of batch steps)
154-
log_intervals:
154+
train_log_freq:
155155
terminal: 10
156156
metrics: 20
157157
checkpoint: 250

src/weathergen/run_train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,13 @@ def inference_from_args(argl: list[str]):
7070

7171
cf.run_history += [(args.from_run_id, cf.istep)]
7272

73+
<<<<<<< HEAD
7374
trainer = Trainer(cf.log_intervals)
7475
trainer.inference(cf, devices, args.from_run_id, args.epoch)
76+
=======
77+
trainer = Trainer(cf.train_log_freq)
78+
trainer.inference(cf, args.from_run_id, args.epoch)
79+
>>>>>>> acf76bd (rename `log_intevals` to `train_log_freq`)
7580

7681

7782
####################################################################################################
@@ -138,7 +143,7 @@ def train_continue_from_args(argl: list[str]):
138143
# track history of run to ensure traceability of results
139144
cf.run_history += [(args.from_run_id, cf.istep)]
140145

141-
trainer = Trainer(cf.log_intervals)
146+
trainer = Trainer(cf.train_log_freq)
142147
trainer.run(cf, devices, args.from_run_id, args.epoch)
143148

144149

@@ -184,7 +189,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):
184189
if cf.with_flash_attention:
185190
assert cf.with_mixed_precision
186191

187-
trainer = Trainer(cf.log_intervals)
192+
trainer = Trainer(cf.train_log_freq)
188193

189194
try:
190195
trainer.run(cf, devices)

src/weathergen/train/trainer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@
5454

5555

5656
class Trainer(TrainerBase):
57-
def __init__(self, log_intervals: Config):
57+
def __init__(self, train_log_freq: Config):
5858
TrainerBase.__init__(self)
5959

60-
self.log_intervals = log_intervals
60+
self.train_log_freq = train_log_freq
6161

6262
def init(self, cf: Config, devices):
6363
self.cf = OmegaConf.merge(
@@ -589,11 +589,11 @@ def train(self, epoch):
589589
self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item()
590590

591591
self._log_terminal(bidx, epoch, TRAIN)
592-
if bidx % self.log_intervals.metrics == 0:
592+
if bidx % self.train_log_freq.metrics == 0:
593593
self._log(TRAIN)
594594

595595
# save model checkpoint (with designation _latest)
596-
if bidx % self.log_intervals.checkpoint == 0 and bidx > 0:
596+
if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0:
597597
self.save_model(-1)
598598

599599
self.cf.istep += cf.batch_size_per_gpu
@@ -943,7 +943,7 @@ def _log(self, stage: Stage):
943943
self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
944944

945945
def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
946-
print_freq = self.log_intervals.terminal
946+
print_freq = self.train_log_freq.terminal
947947
if bidx % print_freq == 0 and bidx > 0 or stage == VAL:
948948
# compute from last iteration
949949
avg_loss, losses_all, _ = self._prepare_losses_for_logging()

0 commit comments

Comments
 (0)