diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 413138597..d6d8733f1 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -7,7 +7,7 @@ LR Scheduler Arguments
 
 
 
-- **lr_decay_style**: typing.Literal['constant', 'linear', 'cosine', 'exponential']
+- **lr_decay_style**: typing.Literal['constant', 'linear', 'cosine', 'exponential', 'infinite_cosine', 'infinite_inv_sqrt']
 
     Default = linear
 
@@ -31,6 +31,14 @@ LR Scheduler Arguments
 
 
 
+- **constant_lr**: float
+
+    Default = 0.0
+
+    Constant learning rate when using infinite cosine or infinite inv sqrt decay styles.
+
+
+
 - **warmup**: float
 
     Default = 0.01
 
@@ -39,6 +47,30 @@ LR Scheduler Arguments
 
 
 
+- **cooldown_iters_perc**: float
+
+    Default = 0.0
+
+    Percentage of total iterations to cooldown for.
+
+
+
+- **constant_iters_perc**: float
+
+    Default = 0.0
+
+    Percentage of total iterations to keep the learning rate constant for.
+
+
+
+- **timescale**: float
+
+    Default = 1.0
+
+    Timescale for the steepness of the inverse square root cooldown.
+
+
+
 - **override_lr_scheduler**: bool
 
     Default = False
diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py
index 9db951aa0..0b9d0c128 100644
--- a/megatron/learning_rates.py
+++ b/megatron/learning_rates.py
@@ -34,6 +34,10 @@ def __init__(
         decay_style,
         last_iter,
         min_lr=0.0,
+        constant_lr=0.0,
+        constant_iters=None,
+        cooldown_iters=None,
+        timescale=None,
         use_checkpoint_lr_scheduler=True,
         override_lr_scheduler=False,
         use_mup=False,
@@ -43,9 +47,13 @@ def __init__(
         self.optimizer = optimizer
         self.start_lr = start_lr
         self.min_lr = min_lr
+        self.constant_lr = constant_lr
         self.warmup_iter = warmup_iter
         self.num_iters = last_iter
         self.end_iter = total_iters
+        self.constant_iters = constant_iters
+        self.cooldown_iters = cooldown_iters
+        self.timescale = timescale
         assert self.end_iter > 0
         self.decay_style = decay_style
         self.override_lr_scheduler = override_lr_scheduler
@@ -84,6 +92,34 @@ def get_lr(self):
             # exp(-0.693) = 1/2
             end_iter = self.end_iter - self.warmup_iter
             lr = self.start_lr * math.exp(-0.693 * num_iters_ / end_iter)
+        elif self.decay_style == "infinite_cosine" or self.decay_style == "infinite_inv_sqrt":
+            if num_iters_ <= self.cooldown_iters:
+                if self.decay_style == "infinite_cosine":
+                    lr = self.constant_lr + (
+                        (self.start_lr - self.constant_lr)
+                        / 2.0
+                        * (math.cos(math.pi * num_iters_ / self.cooldown_iters) + 1)
+                    )
+                else:
+                    def inv_f(t):
+                        return (1 / math.sqrt(1 + (self.timescale * t))) - 1
+                    lr = self.start_lr + (
+                        (self.constant_lr - self.start_lr)
+                        / inv_f(1)
+                        * (inv_f(num_iters_ / self.cooldown_iters))
+                    )
+                return lr
+            else:
+                num_iters_ = num_iters_ - self.cooldown_iters
+                if num_iters_ <= self.constant_iters:
+                    # Stay constant for constant_iters
+                    lr = self.constant_lr
+                else:
+                    # Decay exponentially from constant_lr to min_lr over the remaining iters
+                    end_iter_ = self.end_iter - self.warmup_iter - self.cooldown_iters - self.constant_iters
+                    num_iters_ = num_iters_ - self.constant_iters
+                    exp_factor = -math.log(self.min_lr / self.constant_lr) / end_iter_
+                    lr = self.constant_lr * math.exp(-1 * exp_factor * num_iters_)
         else:
             lr = self.start_lr
         return max(lr, self.min_lr)
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index dd51c7778..6963c53d1 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -539,7 +539,7 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
     LR Scheduler Arguments
     """
 
-    lr_decay_style: Literal["constant", "linear", "cosine", "exponential"] = "linear"
+    lr_decay_style: Literal["constant", "linear", "cosine", "exponential", "infinite_cosine", "infinite_inv_sqrt"] = "linear"
     """
     Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'.
     """
@@ -554,11 +554,31 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
     Minimum value for learning rate. The scheduler clips values below this threshold.
     """
 
+    constant_lr: float = 0.0
+    """
+    Constant learning rate when using infinite cosine or infinite inv sqrt decay styles.
+    """
+
     warmup: float = 0.01
     """
     Percentage of total iterations to warmup on (.01 = 1 percent of all training iters).
     """
 
+    cooldown_iters_perc: float = 0.0
+    """
+    Percentage of total iterations to cooldown for.
+    """
+
+    constant_iters_perc: float = 0.0
+    """
+    Percentage of total iterations to keep the learning rate constant for.
+    """
+
+    timescale: float = 1.0
+    """
+    Timescale for the steepness of the inverse square root cooldown.
+    """
+
     override_lr_scheduler: bool = False
     """
     Reset the values of the scheduler (learning rate,warmup iterations, minimum learning rate, maximum number of iterations, and decay style from input arguments and ignore values from checkpoints. Note that all the above values will be reset.
diff --git a/megatron/training.py b/megatron/training.py
index fc3d9e129..42afead6f 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -717,6 +717,8 @@ def get_learning_rate_scheduler(optimizer, neox_args):
     num_iters = max(1, num_iters)
     init_step = 0
     warmup_iter = neox_args.warmup * num_iters
+    constant_iters = neox_args.constant_iters_perc * num_iters
+    cooldown_iters = neox_args.cooldown_iters_perc * num_iters
     lr_scheduler = AnnealingLR(
         optimizer,
         start_lr=neox_args.lr,
@@ -725,6 +727,10 @@
         decay_style=neox_args.lr_decay_style,
         last_iter=init_step,
         min_lr=neox_args.min_lr,
+        constant_lr=neox_args.constant_lr,
+        constant_iters=constant_iters,
+        cooldown_iters=cooldown_iters,
+        timescale=neox_args.timescale,
         use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=neox_args.override_lr_scheduler,
         use_mup=neox_args.use_mup,