Skip to content

Commit

Permalink
warmup_iter argument
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangir-azerbayev committed Oct 21, 2023
1 parent e001a04 commit 439f6a7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
7 changes: 6 additions & 1 deletion megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,11 +451,16 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
Minimum value for learning rate. The scheduler clips values below this threshold.
"""

warmup: float = 0.01
warmup: float = None
"""
Percentage of total iterations to warmup on (.01 = 1 percent of all training iters).
"""

warmup_iter: int = None
"""
Number of warmup iterations
"""

override_lr_scheduler: bool = False
"""
Reset the values of the scheduler (learning rate,warmup iterations, minimum learning rate, maximum number of iterations, and decay style from input arguments and ignore values from checkpoints. Note that all the above values will be reset.
Expand Down
10 changes: 9 additions & 1 deletion megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,17 @@ def get_learning_rate_scheduler(optimizer, neox_args):
num_iters = neox_args.lr_decay_iters
else:
num_iters = neox_args.train_iters

num_iters = max(1, num_iters)

assert not (neox_args.warmup_iter and neox_args.warmup)
if neox_args.warmup:
warmup_iter = neox_args.warmup*num_iters
else:
warmup_iter = neox_args.warmup_iter

init_step = 0
warmup_iter = neox_args.warmup * num_iters

lr_scheduler = AnnealingLR(
optimizer,
start_lr=neox_args.lr,
Expand Down

0 comments on commit 439f6a7

Please sign in to comment.