diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 77f4bcd84..b257800e4 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -451,6 +451,11 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
     Minimum value for learning rate. The scheduler clips values below this threshold.
     """
 
+    warmup_ratio: float = None
+    """
+    Fraction of total training iterations to use for learning rate warmup. Mutually exclusive with warmup_iters.
+    """
+
     warmup_iters: int = None
     """
     Number of warmup iterations
diff --git a/megatron/training.py b/megatron/training.py
index 779b15839..a5bfdb7b7 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -606,10 +606,16 @@ def get_learning_rate_scheduler(optimizer, neox_args):
     init_step = 0
 
+    assert not (neox_args.warmup_ratio and neox_args.warmup_iters), "warmup_ratio and warmup_iters are mutually exclusive"
+    if neox_args.warmup_ratio:
+        warmup_iters = neox_args.warmup_ratio * num_iters
+    else:
+        warmup_iters = neox_args.warmup_iters
+
     lr_scheduler = AnnealingLR(
         optimizer,
         start_lr=neox_args.lr,
-        warmup_iter=neox_args.warmup_iters,
+        warmup_iter=warmup_iters,
         total_iters=num_iters,
         decay_style=neox_args.lr_decay_style,
         last_iter=init_step,
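
For context, below is a minimal, self-contained sketch of the warmup resolution this diff implements. It is not the repository's `AnnealingLR`: the helper names (`resolve_warmup_iters`, `warmup_lr`) and the linear-warmup convention are illustrative assumptions, and the sketch casts the ratio-derived count to `int`, whereas the diff passes the raw float product through.

```python
# Illustrative sketch only: resolve_warmup_iters and warmup_lr are
# hypothetical helpers, not part of gpt-neox. They mirror the logic
# added in this diff.


def resolve_warmup_iters(warmup_ratio, warmup_iters, total_iters):
    # Same mutual-exclusion check as the diff's assert: a ratio is
    # converted to an absolute iteration count, otherwise the explicit
    # iteration count is used directly.
    assert not (warmup_ratio and warmup_iters), (
        "Set warmup_ratio or warmup_iters, not both"
    )
    if warmup_ratio:
        # The diff passes the raw product; casting to int here is an
        # assumption made for clarity.
        return int(warmup_ratio * total_iters)
    return warmup_iters or 0


def warmup_lr(step, start_lr, warmup_iter):
    # One common convention: ramp linearly from ~0 to start_lr over
    # warmup_iter steps, then hold (a real scheduler decays afterwards).
    if warmup_iter > 0 and step < warmup_iter:
        return start_lr * (step + 1) / warmup_iter
    return start_lr


if __name__ == "__main__":
    total_iters = 10_000
    w = resolve_warmup_iters(warmup_ratio=0.01, warmup_iters=None,
                             total_iters=total_iters)
    print(w)                       # 100 iterations for a 1% ratio
    print(warmup_lr(50, 6e-4, w))  # 3.06e-04, halfway through warmup
```

The practical upside of the ratio form is that the warmup length scales automatically when the total iteration count changes, instead of requiring `warmup_iters` to be retuned for every run length.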