class WarmupCosineAnnealingLR(torch.optim.lr_scheduler.LRScheduler):
    """Linear warmup followed by cosine annealing, then a constant floor.

    The learning rate ramps linearly from ``min_factor * lr`` up to
    ``max_factor * lr`` over ``warmup_steps`` steps, then follows a cosine
    decay back down to ``min_factor * lr`` over ``annealing_steps`` steps,
    after which it stays constant at that floor value.

    Args:
        optimizer: Wrapped optimizer; must have exactly one parameter group.
        warmup_steps: Number of steps in the linear warmup phase.
        annealing_steps: Number of steps in the cosine annealing phase.
        max_factor: Multiple of the base lr reached at the end of warmup.
        min_factor: Multiple of the base lr at the start of warmup and
            after annealing has finished.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps=1000,
        annealing_steps=100000,
        max_factor=1.0,
        min_factor=0.001,
    ):
        self.warmup_steps = warmup_steps
        self.annealing_steps = annealing_steps

        # TODO generalize this to support multiple parameter groups
        assert (
            len(optimizer.param_groups) == 1
        ), "WarmupCosineAnnealingLR only supports training with one parameter group"
        [param_group] = optimizer.param_groups
        initial_learning_rate = param_group["lr"]

        self.warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=min_factor,
            end_factor=max_factor,
            total_iters=warmup_steps,
        )

        self.annealing_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=annealing_steps,
            eta_min=min_factor * initial_learning_rate,
        )

        # Note: LRScheduler.__init__ performs one initial call of self.step()
        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rate(s) for the current step.

        During warmup this is the warmup scheduler's last value; during and
        after annealing it is the annealing scheduler's last value (which,
        once annealing has finished, is the constant floor
        ``min_factor * initial_lr``).
        """
        if self._step_count <= self.warmup_steps:
            return self.warmup_scheduler.get_last_lr()
        # Fix: previously this method fell through to `return True` after
        # warmup_steps + annealing_steps, violating the LRScheduler contract
        # of always returning a list of learning rates. After annealing the
        # rate holds at the annealing floor, i.e. its last produced value.
        return self.annealing_scheduler.get_last_lr()

    def step(self):
        """Advance whichever sub-scheduler governs the current phase."""
        # _step_count == 0 is the initial step() call made by
        # LRScheduler.__init__; the sub-schedulers already applied their
        # initial learning rate at construction time, so do nothing.
        if self._step_count == 0:
            pass
        elif self._step_count <= self.warmup_steps:
            self.warmup_scheduler.step()
        elif self._step_count <= self.warmup_steps + self.annealing_steps:
            self.annealing_scheduler.step()
        # Past the annealing phase the lr stays at the floor; only count.
        self._step_count += 1
# Third-party
import numpy as np
import pytest
import torch

# First-party
from neural_lam import lr_scheduler


@pytest.fixture
def model():
    # Minimal model: the scheduler only needs an optimizer with one
    # parameter group to wrap.
    return torch.nn.Linear(1, 1)


@pytest.fixture
def optimizer(model):
    # Adam's default lr serves as the base rate the factors multiply.
    return torch.optim.Adam(model.parameters())


def test_warmup_cosine_annealing_produces_expected_schedule(optimizer):
    """The lr should ramp linearly, cosine-anneal down, then hold the floor."""
    min_factor = 0.01
    max_factor = 1
    warmup_steps = 10
    annealing_steps = 10
    initial_lr = optimizer.param_groups[0]["lr"]

    scheduler = lr_scheduler.WarmupCosineAnnealingLR(
        optimizer,
        min_factor=min_factor,
        max_factor=max_factor,
        annealing_steps=annealing_steps,
        warmup_steps=warmup_steps,
    )

    # Record the lr seen by the optimizer at each step, then advance.
    lrs = []
    for _ in range(25):
        lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    # Warmup phase: linear ramp from min_factor*lr towards max_factor*lr
    # (endpoint excluded: the final warmup value is reached at step
    # warmup_steps, which is the first recorded annealing-phase value).
    expected_warmup_lr = np.linspace(
        min_factor * initial_lr,
        max_factor * initial_lr,
        warmup_steps,
        endpoint=False,
    )
    warmup_lr = lrs[:warmup_steps]
    assert np.allclose(warmup_lr, expected_warmup_lr)

    annealing_lr = lrs[warmup_steps : warmup_steps + annealing_steps]

    # Closed-form formula for the cosine annealing phase
    expected_annealing_lr = min_factor * initial_lr + 0.5 * (
        max_factor * initial_lr - min_factor * initial_lr
    ) * (1 + np.cos(np.pi * np.arange(annealing_steps) / annealing_steps))
    assert np.allclose(annealing_lr, expected_annealing_lr)

    # After annealing the lr holds constant at the floor. Compare with a
    # tolerance: exact float equality (the previous `end_lr == ...` check)
    # is brittle against accumulated floating-point error in the recursive
    # cosine update used by CosineAnnealingLR.step().
    end_lr = np.array(lrs[warmup_steps + annealing_steps :])
    assert np.allclose(end_lr, min_factor * initial_lr)