54 changes: 54 additions & 0 deletions neural_lam/lr_scheduler.py
@@ -0,0 +1,54 @@
# Third-party
import torch


class WarmupCosineAnnealingLR(torch.optim.lr_scheduler.LRScheduler):
    """Linear warmup followed by cosine annealing.

    After warmup_steps + annealing_steps steps, the learning rate is
    held at min_factor times the initial learning rate.
    """

def __init__(
self,
optimizer,
warmup_steps=1000,
annealing_steps=100000,
max_factor=1.0,
min_factor=0.001,
):
self.warmup_steps = warmup_steps
self.annealing_steps = annealing_steps

# TODO generalize this to support multiple parameter groups
assert (
len(optimizer.param_groups) == 1
), "WarmupCosineAnnealingLR only supports training with one parameter group"
[param_group] = optimizer.param_groups
initial_learning_rate = param_group["lr"]

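        # Linear warmup: ramp the LR from min_factor * lr up to
        # max_factor * lr over the first warmup_steps steps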
self.warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=min_factor,
end_factor=max_factor,
total_iters=warmup_steps,
)

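        # Cosine annealing: decay the LR towards
        # eta_min = min_factor * lr over the following annealing_steps steps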
self.annealing_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=annealing_steps,
eta_min=min_factor * initial_learning_rate,
)

super().__init__(optimizer)

    def get_lr(self):
        if self._step_count <= self.warmup_steps:
            return self.warmup_scheduler.get_last_lr()
        elif self._step_count <= self.warmup_steps + self.annealing_steps:
            return self.annealing_scheduler.get_last_lr()

        # After annealing_steps the cosine scheduler has reached eta_min,
        # so its last value holds the learning rate at the minimum
        return self.annealing_scheduler.get_last_lr()

def step(self):
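        # LRScheduler.__init__ invokes step() once during construction;
        # _step_count == 0 identifies that call, which must not advance
        # the sub-schedulers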
if self._step_count == 0:
pass
elif self._step_count <= self.warmup_steps:
self.warmup_scheduler.step()
elif self._step_count <= self.warmup_steps + self.annealing_steps:
self.annealing_scheduler.step()
self._step_count += 1
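A minimal sketch of how the scheduler could be driven from a plain training loop (hypothetical usage, not part of this diff; the model, learning rate, and step counts here are placeholders):

# Third-party
import torch

# First-party
from neural_lam.lr_scheduler import WarmupCosineAnnealingLR

model = torch.nn.Linear(1, 1)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = WarmupCosineAnnealingLR(
    optimizer, warmup_steps=1000, annealing_steps=100000
)

for _ in range(3000):
    optimizer.step()  # one optimizer update per batch
    scheduler.step()  # then advance the schedule by one step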
1 change: 1 addition & 0 deletions neural_lam/models/ar_model.py
@@ -193,6 +193,7 @@ def configure_optimizers(self):
opt = torch.optim.AdamW(
self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
)

return opt

@property
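The diff leaves configure_optimizers returning only the optimizer, so the new scheduler is not yet wired into the model. A minimal sketch of how it could be attached, using Lightning's per-step scheduler dict (hypothetical wiring, not part of this diff):

    # Hypothetical wiring, not in this diff
    from neural_lam import lr_scheduler

    def configure_optimizers(self):
        opt = torch.optim.AdamW(
            self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
        )
        sched = lr_scheduler.WarmupCosineAnnealingLR(opt)
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sched, "interval": "step"},
        }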
7 changes: 7 additions & 0 deletions neural_lam/train_model.py
@@ -72,6 +72,12 @@ def main(input_args=None):
default=200,
help="upper epoch limit (default: 200)",
)
parser.add_argument(
"--steps",
type=int,
default=-1,
help="upper step limit (default: None)",
)
parser.add_argument(
"--batch_size", type=int, default=4, help="batch size (default: 4)"
)
@@ -308,6 +314,7 @@ def main(input_args=None):
)
trainer = pl.Trainer(
max_epochs=args.epochs,
max_steps=args.steps,
deterministic=True,
strategy="ddp",
accelerator=device_name,
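With the new flag, training can be capped by optimizer steps instead of epochs. A hypothetical invocation (other required arguments elided):

    python -m neural_lam.train_model --steps 100000

In PyTorch Lightning, max_steps=-1 disables the step limit, so the default of -1 preserves the existing epoch-based behaviour.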
58 changes: 58 additions & 0 deletions tests/test_lr_scheduler.py
@@ -0,0 +1,58 @@
# Third-party
import numpy as np
import pytest
import torch

# First-party
from neural_lam import lr_scheduler


@pytest.fixture
def model():
return torch.nn.Linear(1, 1)


@pytest.fixture
def optimizer(model):
return torch.optim.Adam(model.parameters())


def test_warmup_cosine_annealing_produces_expected_schedule(optimizer):
min_factor = 0.01
max_factor = 1
warmup_steps = 10
annealing_steps = 10
initial_lr = optimizer.param_groups[0]["lr"]

scheduler = lr_scheduler.WarmupCosineAnnealingLR(
optimizer,
min_factor=min_factor,
max_factor=max_factor,
annealing_steps=annealing_steps,
warmup_steps=warmup_steps,
)

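    # Record the LR before each step so lrs[0] is the LR set at
    # construction (the warmup starting point)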
lrs = []
for _ in range(25):
lrs.append(optimizer.param_groups[0]["lr"])
scheduler.step()

expected_warmup_lr = np.linspace(
min_factor * initial_lr,
max_factor * initial_lr,
warmup_steps,
endpoint=False,
)
warmup_lr = lrs[:warmup_steps]
assert np.allclose(warmup_lr, expected_warmup_lr)

annealing_lr = lrs[warmup_steps : warmup_steps + annealing_steps]

# Formula for the cosine annealing
expected_annealing_lr = min_factor * initial_lr + 0.5 * (
max_factor * initial_lr - min_factor * initial_lr
) * (1 + np.cos(np.pi * np.arange(annealing_steps) / annealing_steps))
assert np.allclose(annealing_lr, expected_annealing_lr)

    end_lr = np.array(lrs[warmup_steps + annealing_steps :])
    assert np.all(end_lr == min_factor * initial_lr)
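For reference, the closed form the test compares against is the standard cosine annealing schedule used by PyTorch's CosineAnnealingLR (without restarts), with t the number of annealing steps taken:

    eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max))

where eta_max = max_factor * initial_lr, eta_min = min_factor * initial_lr, and T_max = annealing_steps, matching the numpy expression in the test.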