@@ -177,8 +177,8 @@ def validate_and_set_default_args(args):
     ]


-def setup_training(args):
-    """Parse args, load dataset, and load model trainer."""
+def setup_training_model(args):
+    """Parse args, load dataset, and build model with criterion."""
     if not torch.cuda.is_available():
         raise NotImplementedError("Training on CPU is not supported")
     torch.cuda.set_device(args.device_id)
@@ -205,21 +205,12 @@ def setup_training(args):
         f"| num. model params: \
         {sum(p.numel() for p in model.parameters())}"
     )
+    return task, model, criterion

-    # Build trainer
-    if args.fp16:
-        trainer = FP16Trainer(args, task, model, criterion)
-    else:
-        if torch.cuda.get_device_capability(0)[0] >= 7:
-            print("| NOTICE: your device may support faster training with --fp16")
-        trainer = Trainer(args, task, model, criterion)
-    print(f"| training on {args.distributed_world_size} GPUs")
-    print(
-        f"| max tokens per GPU = {args.max_tokens} and \
-        max sentences per GPU = {args.max_sentences}",
-        flush=True,
-    )

+def setup_training_state(args, trainer, task):
+    """Set up the directory for saving checkpoints.
+    Load pretrained model if specified."""
     os.makedirs(args.save_dir, exist_ok=True)

     # If --restore-file is already present under --save-dir, use that one
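With model construction and checkpoint-state setup now in separate helpers, each half can be exercised on its own. A minimal sketch of the intended call order, assuming the names from this diff (the bare `Trainer` construction stands in for the `build_trainer` helper added below):

    # Sketch only: wiring the split helpers together by hand.
    task, model, criterion = setup_training_model(args)       # parse args, load data, build model
    trainer = Trainer(args, task, model, criterion)           # FP16Trainer instead when args.fp16 is set
    extra_state = setup_training_state(args, trainer, task)   # create save_dir, maybe restore a checkpoint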
@@ -273,6 +264,21 @@ def setup_training(args):
         dataset_split=args.valid_subset,
     )
     print(f"| extra_state: {extra_state}")
+    return extra_state
+
+
+def build_trainer(args, task, model, criterion, trainer_class, **kwargs):
+    """Build trainer with the provided trainer_class, and set up training state.
+    """
+    trainer = trainer_class(args, task, model, criterion, **kwargs)
+
+    print(f"| training on {args.distributed_world_size} GPUs")
+    print(
+        f"| max tokens per GPU = {args.max_tokens} and \
+        max sentences per GPU = {args.max_sentences}",
+        flush=True,
+    )
+    extra_state = setup_training_state(args, trainer, task)

     epoch_itr = data.EpochBatchIterator(
         dataset=task.dataset(args.train_subset),
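Since `build_trainer` receives the trainer class as a parameter and forwards `**kwargs` to its constructor, a custom trainer can be injected without modifying this module. A hedged sketch: `LoggingTrainer` is hypothetical; `Trainer` and `build_trainer` come from the diff, and the `train_step` signature mirrors how it is called in `train` below.

    # Hypothetical subclass, only to illustrate the trainer_class injection point.
    class LoggingTrainer(Trainer):
        def train_step(self, sample, update_params=True):
            print("| entering train_step")  # extra logging before each step
            return super().train_step(sample, update_params=update_params)

    trainer, extra_state, epoch_itr = build_trainer(
        args, task, model, criterion, trainer_class=LoggingTrainer
    )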
@@ -290,6 +296,26 @@ def setup_training(args):
         epoch_itr.load_state_dict(
             {"epoch": epoch, "iterations_in_epoch": extra_state["batch_offset"]}
         )
+    return trainer, extra_state, epoch_itr
+
+
+def setup_training(args):
+    """Perform several steps:
+    - build model using provided criterion and task
+    - load data
+    - build trainer, and set up training state
+    """
+    task, model, criterion = setup_training_model(args)
+    if args.fp16:
+        trainer_class = FP16Trainer
+    else:
+        if torch.cuda.get_device_capability(0)[0] >= 7:
+            print("| NOTICE: your device may support faster training with --fp16")
+        trainer_class = Trainer
+
+    trainer, extra_state, epoch_itr = build_trainer(
+        args, task, model, criterion, trainer_class
+    )

     return extra_state, trainer, task, epoch_itr

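The recomposed `setup_training` keeps the original return order `(extra_state, trainer, task, epoch_itr)`, so existing callers such as `single_process_main` below need no changes:

    # Return tuple is unchanged from the pre-refactor setup_training.
    extra_state, trainer, task, epoch_itr = setup_training(args)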
@@ -337,13 +363,14 @@ def single_process_main(args):
         trainer=trainer,
         task=task,
         epoch_itr=epoch_itr,
+        update_params=True,
     )

     for _ in train_iterator:
         pass


-def train(args, extra_state, trainer, task, epoch_itr):
+def train(args, extra_state, trainer, task, epoch_itr, **train_step_kwargs):
     # offset for current epoch (may be different from checkpoint offset)
     starting_offset = extra_state["batch_offset"]

@@ -377,10 +404,12 @@ def train(args, extra_state, trainer, task, epoch_itr):
     for i, sample in enumerate(progress, start=starting_offset):
         if i < num_batches - 1 and (i + 1) % update_freq > 0:
             # buffer updates according to --update-freq
-            trainer.train_step(sample, update_params=False)
+            train_step_kwargs["update_params"] = False
+            trainer.train_step(sample, **train_step_kwargs)
             continue
         else:
-            log_output = trainer.train_step(sample, update_params=True)
+            train_step_kwargs["update_params"] = True
+            log_output = trainer.train_step(sample, **train_step_kwargs)

         if do_prune:
             apply_prune_masks(prune_masks, trainer)
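Any extra keyword arguments passed to `train` now land in `train_step_kwargs` and are forwarded to every `trainer.train_step` call, with `update_params` overwritten per batch according to `--update-freq`. A minimal sketch, assuming a trainer whose `train_step` accepts a hypothetical `clip_grads` flag:

    # clip_grads is a made-up kwarg; update_params is managed per batch by train() itself.
    train_iterator = train(
        args, extra_state, trainer, task, epoch_itr, clip_grads=True
    )
    for _ in train_iterator:
        pass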