From 439f6a7d2190ea631db8965860821151e65abdd9 Mon Sep 17 00:00:00 2001
From: zhangir-azerbayev
Date: Sat, 21 Oct 2023 01:19:44 -0600
Subject: [PATCH] warmup_iter argument

---
 megatron/neox_arguments/neox_args.py |  7 ++++++-
 megatron/training.py                 | 10 +++++++++-
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index e427b2551..aeac958a2 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -451,11 +451,16 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
     Minimum value for learning rate. The scheduler clips values below this threshold.
     """
 
-    warmup: float = 0.01
+    warmup: float = None
     """
     Percentage of total iterations to warmup on (.01 = 1 percent of all training iters).
     """
 
+    warmup_iter: int = None
+    """
+    Number of warmup iterations
+    """
+
     override_lr_scheduler: bool = False
     """
     Reset the values of the scheduler (learning rate,warmup iterations, minimum learning rate, maximum number of iterations, and decay style from input arguments and ignore values from checkpoints. Note that all the above values will be reset.
diff --git a/megatron/training.py b/megatron/training.py
index 548f81cb0..e920725f8 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -601,9 +601,17 @@ def get_learning_rate_scheduler(optimizer, neox_args):
         num_iters = neox_args.lr_decay_iters
     else:
         num_iters = neox_args.train_iters
+
     num_iters = max(1, num_iters)
+
+    assert not (neox_args.warmup_iter and neox_args.warmup)
+    if neox_args.warmup:
+        warmup_iter = neox_args.warmup*num_iters
+    else:
+        warmup_iter = neox_args.warmup_iter
+
     init_step = 0
-    warmup_iter = neox_args.warmup * num_iters
+
     lr_scheduler = AnnealingLR(
         optimizer,
         start_lr=neox_args.lr,
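
Note (not part of the patch): the sketch below illustrates how the warmup horizon would be resolved after this change. The names Args and resolve_warmup_iter are hypothetical stand-ins for neox_args and the logic inside get_learning_rate_scheduler; only the warmup/warmup_iter selection shown in the diff is reproduced.

    # Illustrative sketch only; mirrors the selection logic added above.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Args:  # stand-in for neox_args
        train_iters: int = 10000
        lr_decay_iters: Optional[int] = None
        warmup: Optional[float] = None      # fraction of total iters (e.g. 0.01)
        warmup_iter: Optional[int] = None   # absolute number of warmup iters

    def resolve_warmup_iter(args: Args):
        # Decay horizon falls back to train_iters when lr_decay_iters is unset.
        num_iters = args.lr_decay_iters if args.lr_decay_iters is not None else args.train_iters
        num_iters = max(1, num_iters)
        # warmup (a fraction) and warmup_iter (a count) are mutually exclusive.
        assert not (args.warmup_iter and args.warmup)
        if args.warmup:
            return args.warmup * num_iters
        return args.warmup_iter

    print(resolve_warmup_iter(Args(warmup=0.01)))      # 100.0 (1% of 10000 iters)
    print(resolve_warmup_iter(Args(warmup_iter=500)))  # 500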