10 changes: 10 additions & 0 deletions benchmark_utils/lr_scheduler.py
@@ -0,0 +1,10 @@

# Learning rate schedule: constant ("stable") phase, then linear decay to 0
# over the final `cooldown_frac` of training.
def get_lr(step, num_steps, cooldown_frac=0.4):
    x = step / num_steps  # progress in training, in [0, 1)
    assert 0 <= x < 1
    if x < 1 - cooldown_frac:
        return 1.0
    else:
        return (1 - x) / cooldown_frac
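A quick sanity check of the schedule's shape (minimal sketch; the values follow directly from the formula above):

from benchmark_utils.lr_scheduler import get_lr

# With num_steps=10 and the default cooldown_frac=0.4, the multiplier is
# 1.0 through step 6, then decays linearly: 0.75, 0.5, 0.25.
for step in range(10):
    print(step, round(get_lr(step, 10), 2))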
32 changes: 32 additions & 0 deletions benchmark_utils/running_setup.py
@@ -0,0 +1,32 @@
import os
import torch
import torch.distributed as t_dist


def get_running_setup():

    # Use submitit helpers to set up distributed training easily.
    try:
        import submitit
        submitit.helpers.TorchDistributedEnvironment().export()
        ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
    except (ImportError, RuntimeError):
        ddp = False
    if ddp:
        print("Running in Distributed Data Parallel (DDP) mode")
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        assert torch.cuda.is_available()
        # TorchDistributedEnvironment sets the visible devices to the
        # current rank, so we can use the default device.
        device = torch.device("cuda", 0)
        torch.cuda.set_device(device)
        t_dist.init_process_group(backend="nccl", device_id=device)
        dist = t_dist
    else:
        rank = 0
        world_size = 1
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dist = None

    return dist, rank, world_size, device
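Downstream, solvers use this helper as a drop-in replacement for their old inline setup; a minimal sketch (the Linear model is a stand-in, not from this benchmark):

import torch

from benchmark_utils.running_setup import get_running_setup

dist, rank, world_size, device = get_running_setup()
model = torch.nn.Linear(8, 8).to(device=device)  # stand-in model
if dist is not None:
    dist.barrier()  # make sure every rank is ready before training
print(f"rank {rank}/{world_size} on {device}")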
433 changes: 433 additions & 0 deletions benchmark_utils/soap.py

Large diffs are not rendered by default.
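For background (context, not from this diff): SOAP, short for "ShampoO with Adam in the Preconditioner's eigenbasis" (Vyas et al., 2024), runs an Adam-style update in the eigenbasis of Shampoo's preconditioners and refreshes the eigendecomposition only every `precondition_frequency` steps; the constructor arguments used in solvers/soap.py below map onto that design.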

4 changes: 3 additions & 1 deletion datasets/fineweb.py
@@ -55,7 +55,9 @@ def get_distributed_data_generator(self, batch_size, rank=0, world_size=1):
        ]
        effective_batch_size = batch_size * world_size * self.seq_len
        if self.max_tokens is not None:
            assert self.max_tokens % effective_batch_size == 0
            assert (self.max_tokens % effective_batch_size) == 0, (
                "max_tokens must be a multiple of the effective batch size"
            )

        # Compute local batch size per process
        # We use sequence length 1024 and load the token stream as a flat array
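To make the divisibility requirement concrete, a worked example (illustrative numbers; seq_len matches the 1024 mentioned in the comment above):

# 64 sequences per rank per step, 8 ranks, 1024-token sequences
batch_size, world_size, seq_len = 64, 8, 1024
effective_batch_size = batch_size * world_size * seq_len
assert effective_batch_size == 524_288   # tokens consumed per global step
# max_tokens must be a multiple of this, e.g. 20 * 524_288 = 10_485_760
assert 10_485_760 % effective_batch_size == 0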
4 changes: 2 additions & 2 deletions objective.py
@@ -30,12 +30,12 @@ def set_data(self, train_dataloader, val_dataloader, model):

    def evaluate_result(self, model, dist=None):
        model.eval()
        val_batch_size = 32 * 1024  # 32k tokens per batch
        val_batch_size = 64  # Batch of 64 sequences for validation
        if dist is not None:
            # In distributed mode, we use the distributed data generator
            rank, size = dist.get_rank(), dist.get_world_size()
            val_loader = self.val_dataloader.get_distributed_data_generator(
                batch_size=val_batch_size * size, rank=rank, world_size=size
                batch_size=val_batch_size, rank=rank, world_size=size
            )
        else:
            # In non-distributed mode, we use the regular data generator
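Note the unit change: `batch_size` here counts sequences per rank, and `get_distributed_data_generator` already multiplies by `world_size`, so the old `val_batch_size * size` appears to have double-counted the ranks. Illustrative accounting (8 ranks assumed):

val_batch_size, seq_len, world_size = 64, 1024, 8
tokens_per_rank = val_batch_size * seq_len            # 65_536 tokens
tokens_per_eval_step = tokens_per_rank * world_size   # 524_288 tokens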
52 changes: 5 additions & 47 deletions solvers/adam.py
@@ -1,36 +1,19 @@
from benchopt import BaseSolver

import os
from contextlib import nullcontext

from tqdm.auto import tqdm

import torch
import torch.distributed as dist
from torch.optim import AdamW
from tqdm.auto import tqdm


# learning rate schedule: stable then decay
def get_lr(step, num_step, cooldown_frac=0.4):
    x = step / num_step  # progress in training
    assert 0 <= x < 1
    if x < 1 - cooldown_frac:
        return 1.0
    else:
        return (1 - x) / cooldown_frac
    # return w * 1.0 + (1 - w) * 0.1
from benchmark_utils.lr_scheduler import get_lr
from benchmark_utils.running_setup import get_running_setup


# The benchmark solvers must be named `Solver` and
# inherit from `BaseSolver` for `benchopt` to work properly.
class Solver(BaseSolver):

    # Name to select the solver in the CLI and to display the results.
    name = 'Adam'

    # List of parameters for the solver. The benchmark will consider
    # the cross product for each key in the dictionary.
    # All parameters 'p' defined here are available as 'self.p'.
    parameters = {
        'learning_rate': [1e-3],
        'weight_decay': [1e-4],
@@ -44,37 +27,12 @@ class Solver(BaseSolver):
        "slurm_ntasks_per_node": 4,
    }

    # List of packages needed to run the solver. See the corresponding
    # section in objective.py
    requirements = []

    sampling_strategy = 'callback'

    def set_objective(self, train_dataloader, model):

        # Use submitit helpers to setup distributed training easily.
        try:
            import submitit
            submitit.helpers.TorchDistributedEnvironment().export()
            ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
        except (ImportError, RuntimeError):
            ddp = False
        if ddp:
            print("Running in Distributed Data Parallel (DDP) mode")
            self.rank = int(os.environ["RANK"])
            self.world_size = int(os.environ["WORLD_SIZE"])
            assert torch.cuda.is_available()
            # TorchDistributedEnvironment sets the visible devices to the
            # current rank, so we can use the default device.
            device = torch.device("cuda", 0)
            torch.cuda.set_device(device)
            dist.init_process_group(backend="nccl", device_id=device)
            self.dist = dist
        else:
            self.rank = 0
            self.world_size = 1
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.dist = None
        # Setup distributed training if needed
        self.dist, self.rank, self.world_size, device = get_running_setup()

        if self.sin_init:
            print("Using sinusoidal initialization")
44 changes: 6 additions & 38 deletions solvers/scionlight.py
@@ -1,12 +1,12 @@
from benchopt import BaseSolver

import os
from contextlib import nullcontext

import torch
from tqdm.auto import tqdm

import torch
import torch.distributed as dist
from benchmark_utils.lr_scheduler import get_lr
from benchmark_utils.running_setup import get_running_setup


# -----------------------------------------------------------------------------
@@ -163,17 +163,6 @@ def step(self):
                G.mul_(1 - momentum)


# learning rate schedule: stable then decay to 0
def get_lr(step, num_steps, cooldown_frac=0.4):
    x = step / num_steps  # progress in training
    assert 0 <= x < 1
    if x < 1 - cooldown_frac:
        return 1.0
    else:
        return (1 - x) / cooldown_frac
    # return w * 1.0 + (1 - w) * 0.1


# The benchmark solvers must be named `Solver` and
# inherit from `BaseSolver` for `benchopt` to work properly.
class Solver(BaseSolver):
@@ -204,30 +193,9 @@ class Solver(BaseSolver):
    sampling_strategy = "callback"

    def set_objective(self, train_dataloader, model):
        # Use submitit helpers to setup distributed training easily.
        try:
            import submitit

            submitit.helpers.TorchDistributedEnvironment().export()
            ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
        except (ImportError, RuntimeError):
            ddp = False
        if ddp:
            print("Running in Distributed Data Parallel (DDP) mode")
            self.rank = int(os.environ["RANK"])
            self.world_size = int(os.environ["WORLD_SIZE"])
            assert torch.cuda.is_available()
            # TorchDistributedEnvironment sets the visible devices to the
            # current rank, so we can use the default device.
            device = torch.device("cuda", 0)
            torch.cuda.set_device(device)
            dist.init_process_group(backend="nccl", device_id=device)
            self.dist = dist
        else:
            self.rank = 0
            self.world_size = 1
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.dist = None

        # Setup distributed training if needed
        self.dist, self.rank, self.world_size, device = get_running_setup()

        model = model.to(device=device)
        model.device = device  # store the device in the model
129 changes: 129 additions & 0 deletions solvers/soap.py
@@ -0,0 +1,129 @@
from benchopt import BaseSolver

from contextlib import nullcontext

from tqdm.auto import tqdm

import torch

from benchmark_utils.soap import SOAP
from benchmark_utils.lr_scheduler import get_lr
from benchmark_utils.running_setup import get_running_setup


class Solver(BaseSolver):
name = "SOAP"

    parameters = {
        "learning_rate": [1e-4],
        "weight_decay": [5e-3],
        "num_steps": [6200],
        "batch_size": [64],
        "precondition_frequency": [10],
        "max_precond_dim": [10000],
        "merge_dims": [False],
        "precondition_1d": [False],
        "normalize_grads": [False],
        "correct_bias": [True],
        "slurm_nodes": [1, 2],
    }
    slurm_params = {
        "slurm_gres": "gpu:4",
        "slurm_ntasks_per_node": 4,
    }

    sampling_strategy = "callback"

    def set_objective(self, train_dataloader, model):

        # Setup distributed training if needed
        self.dist, self.rank, self.world_size, device = get_running_setup()

        model = model.to(device=device)
        model.device = device
        self.train_dataloader = train_dataloader

        self.ctx = (
            torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            if torch.cuda.is_available()
            else nullcontext()
        )

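        # Compile the model and the optimizer step; torch.no_grad(SOAP.step)
        # applies no_grad as a decorator so the compiled update runs without
        # autograd tracking.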
        self.model = torch.compile(model, dynamic=False, fullgraph=True)
        SOAP.step = torch.compile(torch.no_grad(SOAP.step))

    def __del__(self):
        if getattr(self, "dist", None) is not None:
            self.dist.destroy_process_group()

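    # benchopt callback sampling: the objective is recorded roughly every
    # 250 training steps.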
    def get_next(self, stop_val):
        return stop_val + 250

    def warm_up(self):
        self.run_once(stop_val=10)

    def run(self, cb):
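        # Standard decay/no-decay split: weight decay is applied only to
        # parameters with >= 2 dims (matrices); biases and norm gains get 0.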
        param_dict = {
            pn: p for pn, p in self.model.named_parameters() if p.requires_grad
        }
        decay_params = [p for _, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for _, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": self.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]

        self.optimizer = SOAP(
            optim_groups,
            lr=torch.tensor(self.learning_rate),
            betas=(0.95, 0.95),
            precondition_frequency=self.precondition_frequency,
            max_precond_dim=self.max_precond_dim,
            merge_dims=self.merge_dims,
            precondition_1d=self.precondition_1d,
            normalize_grads=self.normalize_grads,
            correct_bias=self.correct_bias,
        )

        train_loader = self.train_dataloader.get_distributed_data_generator(
            batch_size=self.batch_size,
            world_size=self.world_size,
            rank=self.rank,
        )

        if self.dist is not None:
            self.dist.barrier()

        step = 0
        with tqdm(total=self.num_steps, desc="Training") as progress:
            while cb():
                self.model.train()
                self.optimizer.zero_grad(set_to_none=True)

                step += 1
                progress.update()
                if step == self.num_steps:
                    break

                data = next(train_loader)
                with self.ctx:
                    loss, *_ = self.model(*data)
                loss.backward()
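                # No DDP wrapper: gradients are averaged across ranks by hand
                # with an all_reduce after the backward pass.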
                if self.dist is not None:
                    for param in self.model.parameters():
                        self.dist.all_reduce(
                            param.grad, op=self.dist.ReduceOp.AVG
                        )

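                # Assumption: the lr is kept as a tensor so updating its value
                # each step does not force torch.compile to retrace the
                # compiled SOAP.step (scalar hyperparameters get burned into
                # the compiled graph).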
                scale_lr = get_lr(step, self.num_steps)
                for param_group in self.optimizer.param_groups:
                    param_group["lr"] = torch.tensor(
                        self.learning_rate * scale_lr
                    )

                self.optimizer.step()

    def get_result(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return dict(model=self.model, dist=self.dist)
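As with the other solvers, benchopt selects this one by its `name` attribute, e.g. `benchopt run . -s SOAP` from the benchmark root.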