
Commit eab1d26

xianxl authored and facebook-github-bot committed
Make setup_train more modular
Summary: Two refactors to make train reusable by adversarial training:
1. Modularize the different steps in setup_train, in particular allowing a trainer class with corresponding kwargs to be passed in.
2. Allow train_step to take additional args.

Reviewed By: liezl200

Differential Revision: D9994930

fbshipit-source-id: 6cafa6833e316e6adc77f212aa83c030723d81c1
1 parent 0bcfad8 commit eab1d26

File tree: 1 file changed, 47 additions and 18 deletions

pytorch_translate/train.py

@@ -177,8 +177,8 @@ def validate_and_set_default_args(args):
     ]


-def setup_training(args):
-    """Parse args, load dataset, and load model trainer."""
+def setup_training_model(args):
+    """Parse args, load dataset, and build model with criterion."""
     if not torch.cuda.is_available():
         raise NotImplementedError("Training on CPU is not supported")
     torch.cuda.set_device(args.device_id)
@@ -205,21 +205,12 @@ def setup_training(args):
         f"| num. model params: \
         {sum(p.numel() for p in model.parameters())}"
     )
+    return task, model, criterion

-    # Build trainer
-    if args.fp16:
-        trainer = FP16Trainer(args, task, model, criterion)
-    else:
-        if torch.cuda.get_device_capability(0)[0] >= 7:
-            print("| NOTICE: your device may support faster training with --fp16")
-        trainer = Trainer(args, task, model, criterion)
-    print(f"| training on {args.distributed_world_size} GPUs")
-    print(
-        f"| max tokens per GPU = {args.max_tokens} and \
-        max sentences per GPU = {args.max_sentences}",
-        flush=True,
-    )

+def setup_training_state(args, trainer, task):
+    """Set up the directory for saving checkpoints.
+    Load pretrained model if specified."""
     os.makedirs(args.save_dir, exist_ok=True)

     # If --restore-file is already present under --save-dir, use that one
@@ -273,6 +264,21 @@ def setup_training(args):
         dataset_split=args.valid_subset,
     )
     print(f"| extra_state: {extra_state}")
+    return extra_state
+
+
+def build_trainer(args, task, model, criterion, trainer_class, **kwargs):
+    """ Build trainer with provided trainer_class, and set up training state.
+    """
+    trainer = trainer_class(args, task, model, criterion, **kwargs)
+
+    print(f"| training on {args.distributed_world_size} GPUs")
+    print(
+        f"| max tokens per GPU = {args.max_tokens} and \
+        max sentences per GPU = {args.max_sentences}",
+        flush=True,
+    )
+    extra_state = setup_training_state(args, trainer, task)

     epoch_itr = data.EpochBatchIterator(
         dataset=task.dataset(args.train_subset),
@@ -290,6 +296,26 @@ def setup_training(args):
     epoch_itr.load_state_dict(
         {"epoch": epoch, "iterations_in_epoch": extra_state["batch_offset"]}
     )
+    return trainer, extra_state, epoch_itr
+
+
+def setup_training(args):
+    """ Perform several steps:
+    - build model using provided criterion and task
+    - load data
+    - build trainer, and set up training state
+    """
+    task, model, criterion = setup_training_model(args)
+    if args.fp16:
+        trainer_class = FP16Trainer
+    else:
+        if torch.cuda.get_device_capability(0)[0] >= 7:
+            print("| NOTICE: your device may support faster training with --fp16")
+        trainer_class = Trainer
+
+    trainer, extra_state, epoch_itr = build_trainer(
+        args, task, model, criterion, trainer_class
+    )

     return extra_state, trainer, task, epoch_itr

@@ -337,13 +363,14 @@ def single_process_main(args):
         trainer=trainer,
         task=task,
         epoch_itr=epoch_itr,
+        update_params=True,
     )

     for _ in train_iterator:
         pass


-def train(args, extra_state, trainer, task, epoch_itr):
+def train(args, extra_state, trainer, task, epoch_itr, **train_step_kwargs):
     # offset for current epoch (may be different from checkpoint offset)
     starting_offset = extra_state["batch_offset"]

@@ -377,10 +404,12 @@ def train(args, extra_state, trainer, task, epoch_itr):
     for i, sample in enumerate(progress, start=starting_offset):
         if i < num_batches - 1 and (i + 1) % update_freq > 0:
             # buffer updates according to --update-freq
-            trainer.train_step(sample, update_params=False)
+            train_step_kwargs["update_params"] = False
+            trainer.train_step(sample, **train_step_kwargs)
             continue
         else:
-            log_output = trainer.train_step(sample, update_params=True)
+            train_step_kwargs["update_params"] = True
+            log_output = trainer.train_step(sample, **train_step_kwargs)

         if do_prune:
             apply_prune_masks(prune_masks, trainer)
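
Usage note (refactor 1): after this commit a caller can assemble training with its own trainer class and constructor kwargs instead of the Trainer/FP16Trainer choice hard-coded inside setup_training. A minimal sketch of that pattern follows; AdversarialTrainer and its adversary kwarg are hypothetical illustrations, while setup_training_model, build_trainer, and Trainer are the names used in pytorch_translate/train.py after this change.

# Sketch only: `AdversarialTrainer` and `adversary` are hypothetical; the rest
# is the API introduced by this commit.
from pytorch_translate.train import Trainer, build_trainer, setup_training_model


class AdversarialTrainer(Trainer):
    """Hypothetical trainer subclass that takes one extra constructor kwarg."""

    def __init__(self, args, task, model, criterion, adversary=None):
        super().__init__(args, task, model, criterion)
        self.adversary = adversary


def setup_adversarial_training(args, adversary):
    # Step 1: parse args, load data, build model + criterion.
    task, model, criterion = setup_training_model(args)
    # Step 2: build_trainer calls trainer_class(args, task, model, criterion,
    # **kwargs), then restores checkpoints / training state and the epoch iterator.
    trainer, extra_state, epoch_itr = build_trainer(
        args, task, model, criterion, AdversarialTrainer, adversary=adversary
    )
    return extra_state, trainer, task, epoch_itr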

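Usage note (refactor 2): the new **train_step_kwargs hook lets a caller thread extra keyword arguments through to every trainer.train_step call, while train() keeps toggling update_params per batch itself. A sketch under assumptions: adversarial_mode is a hypothetical keyword accepted by a custom trainer's train_step, not part of this diff.

# Sketch only: assumes `trainer.train_step` understands `adversarial_mode`
# (e.g. a custom adversarial trainer built as in the previous sketch).
from pytorch_translate.train import train


def run_adversarial_epochs(args, extra_state, trainer, task, epoch_itr):
    # train() is a generator; it forwards these kwargs to each train_step call
    # in addition to the update_params flag it manages per batch.
    for _ in train(
        args, extra_state, trainer, task, epoch_itr, adversarial_mode=True
    ):
        pass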