generated from benchopt/template_benchmark
-
Notifications
You must be signed in to change notification settings - Fork 1
SOLVER Implement SOAP #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
svaiter
wants to merge
11
commits into
benchopt:main
Choose a base branch
from
svaiter:soap
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
e4ef22a
SOLVER start working on SOAP (using codex)
svaiter 62c5b2c
CLN use the original soap.py file
svaiter 0d8e800
RFC refactor + fix some issues in dataloader
tomMoral eca5541
FIX flake8
tomMoral 2d68a07
Merge branch 'main' into soap
tomMoral 7ba58b9
Remove some hyperparameters
svaiter 6d2d659
Apply suggestions from code review
tomMoral a22b268
Updated solvers/soap.py to modify line 101 and line 105 as per the re…
svaiter 1fefcaa
flake8
svaiter 52c9167
ENH improve parameters
tomMoral 378a4a7
CLN a few last tweaks
tomMoral File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
|
|
||
| # learning rate schedule: stable then decay | ||
| def get_lr(step, num_step, cooldown_frac=0.4): | ||
| x = step / num_step # progress in training | ||
| assert 0 <= x < 1 | ||
| if x < 1 - cooldown_frac: | ||
| return 1.0 | ||
| else: | ||
| return (1 - x) / cooldown_frac | ||
| # return w * 1.0 + (1 - w) * 0.1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| import os | ||
| import torch | ||
| import torch.distributed as t_dist | ||
|
|
||
|
|
||
| def get_running_setup(): | ||
|
|
||
| # Use submitit helpers to setup distributed training easily. | ||
| try: | ||
| import submitit | ||
| submitit.helpers.TorchDistributedEnvironment().export() | ||
| ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? | ||
| except (ImportError, RuntimeError): | ||
| ddp = False | ||
| if ddp: | ||
| print("Running in Distributed Data Parallel (DDP) mode") | ||
| rank = int(os.environ["RANK"]) | ||
| world_size = int(os.environ["WORLD_SIZE"]) | ||
| assert torch.cuda.is_available() | ||
| # TorchDistributedEnvironment sets the visible devices to the | ||
| # current rank, so we can use the default device. | ||
| device = torch.device("cuda", 0) | ||
| torch.cuda.set_device(device) | ||
| t_dist.init_process_group(backend="nccl", device_id=device) | ||
| dist = t_dist | ||
| else: | ||
| rank = 0 | ||
| world_size = 1 | ||
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||
| dist = None | ||
|
|
||
| return dist, rank, world_size, device |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| from benchopt import BaseSolver | ||
|
|
||
| from contextlib import nullcontext | ||
|
|
||
| from tqdm.auto import tqdm | ||
|
|
||
| import torch | ||
|
|
||
| from benchmark_utils.soap import SOAP | ||
| from benchmark_utils.lr_scheduler import get_lr | ||
| from benchmark_utils.running_setup import get_running_setup | ||
|
|
||
|
|
||
| class Solver(BaseSolver): | ||
| name = "SOAP" | ||
|
|
||
| parameters = { | ||
| "learning_rate": [1e-4], | ||
| "weight_decay": [5e-3], | ||
| "num_steps": [6200], | ||
| "batch_size": [64], | ||
| "precondition_frequency": [10], | ||
| "max_precond_dim": [10000], | ||
| "merge_dims": [False], | ||
| "precondition_1d": [False], | ||
| "normalize_grads": [False], | ||
| "correct_bias": [True], | ||
| "slurm_nodes": [1, 2], | ||
tomMoral marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| slurm_params = { | ||
| "slurm_gres": "gpu:4", | ||
| "slurm_ntasks_per_node": 4, | ||
| } | ||
|
|
||
| sampling_strategy = "callback" | ||
|
|
||
| def set_objective(self, train_dataloader, model): | ||
|
|
||
| # Setup distributed training if needed | ||
| self.dist, self.rank, self.world_size, device = get_running_setup() | ||
|
|
||
| model = model.to(device=device) | ||
| model.device = device | ||
| self.train_dataloader = train_dataloader | ||
|
|
||
| self.ctx = ( | ||
| torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) | ||
| if torch.cuda.is_available() | ||
| else nullcontext() | ||
| ) | ||
|
|
||
| self.model = torch.compile(model, dynamic=False, fullgraph=True) | ||
| SOAP.step = torch.compile(torch.no_grad(SOAP.step)) | ||
|
|
||
| def __del__(self): | ||
| if getattr(self, "dist", None) is not None: | ||
| self.dist.destroy_process_group() | ||
|
|
||
| def get_next(self, stop_val): | ||
| return stop_val + 250 | ||
|
|
||
| def warm_up(self): | ||
| self.run_once(stop_val=10) | ||
|
|
||
| def run(self, cb): | ||
| param_dict = { | ||
| pn: p for pn, p in self.model.named_parameters() if p.requires_grad | ||
| } | ||
| decay_params = [p for _, p in param_dict.items() if p.dim() >= 2] | ||
| nodecay_params = [p for _, p in param_dict.items() if p.dim() < 2] | ||
| optim_groups = [ | ||
| {"params": decay_params, "weight_decay": self.weight_decay}, | ||
| {"params": nodecay_params, "weight_decay": 0.0}, | ||
| ] | ||
|
|
||
| self.optimizer = SOAP( | ||
| optim_groups, | ||
| lr=torch.tensor(self.learning_rate), | ||
| betas=(0.95, 0.95), | ||
| precondition_frequency=self.precondition_frequency, | ||
| max_precond_dim=self.max_precond_dim, | ||
| merge_dims=self.merge_dims, | ||
| precondition_1d=self.precondition_1d, | ||
| normalize_grads=self.normalize_grads, | ||
| correct_bias=self.correct_bias, | ||
| ) | ||
|
|
||
| train_loader = self.train_dataloader.get_distributed_data_generator( | ||
| batch_size=self.batch_size, | ||
| world_size=self.world_size, | ||
| rank=self.rank, | ||
| ) | ||
|
|
||
| if self.dist is not None: | ||
| self.dist.barrier() | ||
|
|
||
| step = 0 | ||
| with tqdm(total=self.num_steps, desc="Training") as progress: | ||
| while cb(): | ||
| self.model.train() | ||
| self.optimizer.zero_grad(set_to_none=True) | ||
|
|
||
| step += 1 | ||
| progress.update() | ||
| if step == self.num_steps: | ||
| break | ||
|
|
||
| data = next(train_loader) | ||
| with self.ctx: | ||
| loss, *_ = self.model(*data) | ||
| loss.backward() | ||
| if self.dist is not None: | ||
| for param in self.model.parameters(): | ||
| self.dist.all_reduce( | ||
| param.grad, op=self.dist.ReduceOp.AVG | ||
| ) | ||
|
|
||
| scale_lr = get_lr(step, self.num_steps) | ||
| for param_group in self.optimizer.param_groups: | ||
| param_group["lr"] = torch.tensor( | ||
| self.learning_rate * scale_lr | ||
| ) | ||
|
|
||
| self.optimizer.step() | ||
|
|
||
| def get_result(self): | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.synchronize() | ||
| return dict(model=self.model, dist=self.dist) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.