From 14ce462b58d424f464d9950398288832e1b3d236 Mon Sep 17 00:00:00 2001 From: Saurabh750 Date: Mon, 2 Jun 2025 22:47:19 +0530 Subject: [PATCH 1/6] muon commit-1 --- torchtune/modules/muon.py | 238 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 torchtune/modules/muon.py diff --git a/torchtune/modules/muon.py b/torchtune/modules/muon.py new file mode 100644 index 0000000000..23f845f6e0 --- /dev/null +++ b/torchtune/modules/muon.py @@ -0,0 +1,238 @@ +import torch +import torch.distributed as dist + + +def zeropower_via_newtonschulz5(G, steps: int): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.mT + B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True): + momentum.lerp_(grad, 1 - beta) + update = grad.lerp_(momentum, beta) if nesterov else momentum + if update.ndim == 4: # for the case of conv filters + update = update.view(len(update), -1) + update = zeropower_via_newtonschulz5(update, steps=ns_steps) + update *= max(1, grad.size(-2) / grad.size(-1))**0.5 + return update + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the + advantage that it can be stably run in bfloat16 on the GPU. + + Muon should only be used for hidden weight layers. The input embedding, final output layer, + and any internal gains or biases should be optimized using a standard method such as AdamW. + Hidden convolutional weights can be trained using Muon by viewing them as 2D and then + collapsing their last 3 dimensions. + + Arguments: + lr: The learning rate, in units of spectral norm per update. + weight_decay: The AdamW-style weight decay. + momentum: The momentum. A value of 0.95 here is usually fine. 
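+
+    A minimal usage sketch (illustrative only: `model.blocks` and the parameter split
+    below are assumptions about the surrounding model code, and this class expects
+    torch.distributed to be initialized since `step()` issues collective ops):
+
+    ```
+    hidden_weights = [p for n, p in model.blocks.named_parameters()
+                      if p.ndim >= 2 and "embed" not in n]
+    hidden_set = set(hidden_weights)
+    other_params = [p for p in model.parameters() if p not in hidden_set]
+    muon = Muon(hidden_weights, lr=0.02, momentum=0.95, weight_decay=0.01)
+    adamw = torch.optim.AdamW(other_params, lr=3e-4)
+    # step() both optimizers every iteration; MuonWithAuxAdam below wraps this
+    # pattern into a single optimizer.
+    ```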
+ """ + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) + params = sorted(params, key=lambda x: x.size(), reverse=True) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + params = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * (len(params) % dist.get_world_size()) + for base_i in range(len(params))[::dist.get_world_size()]: + if base_i + dist.get_rank() < len(params): + p = params[base_i + dist.get_rank()] + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + + +class SingleDeviceMuon(torch.optim.Optimizer): + """ + Muon variant for usage in non-distributed settings. + """ + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + + +def adam_update(grad, buf1, buf2, step, betas, eps): + buf1.lerp_(grad, 1 - betas[0]) + buf2.lerp_(grad.square(), 1 - betas[1]) + buf1c = buf1 / (1 - betas[0]**step) + buf2c = buf2 / (1 - betas[1]**step) + return buf1c / (buf2c.sqrt() + eps) + + +class MuonWithAuxAdam(torch.optim.Optimizer): + """ + Distributed Muon variant that can be used for all parameters in the network, since it runs an + internal AdamW for the parameters that are not compatible with Muon. The user must manually + specify which parameters shall be optimized with Muon and which with Adam by passing in a + list of param_groups with the `use_muon` flag set. + + The point of this class is to allow the user to have a single Opimizer in their code, rather + than having both a Muon and an Adam which each need to be stepped. 
+ + You can see an example usage below: + + https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470 + ``` + hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] + embed_params = [p for n, p in model.named_parameters() if "embed" in n] + scalar_params = [p for p in model.parameters() if p.ndim < 2] + head_params = [model.lm_head.weight] + + from muon import MuonWithAuxAdam + adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)] + adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups] + muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True) + param_groups = [*adam_groups, muon_group] + optimizer = MuonWithAuxAdam(param_groups) + ``` + """ + def __init__(self, param_groups): + for group in param_groups: + assert "use_muon" in group + if group["use_muon"]: + group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True) + # defaults + group["lr"] = group.get("lr", 0.02) + group["momentum"] = group.get("momentum", 0.95) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + else: + # defaults + group["lr"] = group.get("lr", 3e-4) + group["betas"] = group.get("betas", (0.9, 0.95)) + group["eps"] = group.get("eps", 1e-10) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + super().__init__(param_groups, dict()) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + if group["use_muon"]: + params = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * (len(params) % dist.get_world_size()) + for base_i in range(len(params))[::dist.get_world_size()]: + if base_i + dist.get_rank() < len(params): + p = params[base_i + dist.get_rank()] + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + else: + beta1, beta2 = group["betas"] + for p in group["params"]: + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], + state["step"], group["betas"], group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + + +class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): + """ + Non-distributed variant of MuonWithAuxAdam. 
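+
+    A usage sketch (illustrative; the parameter lists are assumptions, but the
+    param_groups layout is the same as for MuonWithAuxAdam above):
+
+    ```
+    muon_group = dict(params=hidden_matrix_params, lr=0.02, momentum=0.95, use_muon=True)
+    adam_group = dict(params=other_params, lr=3e-4, betas=(0.9, 0.95), eps=1e-10, use_muon=False)
+    optimizer = SingleDeviceMuonWithAuxAdam([muon_group, adam_group])
+    ```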
+ """ + def __init__(self, param_groups): + for group in param_groups: + assert "use_muon" in group + if group["use_muon"]: + # defaults + group["lr"] = group.get("lr", 0.02) + group["momentum"] = group.get("momentum", 0.95) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) + else: + # defaults + group["lr"] = group.get("lr", 3e-4) + group["betas"] = group.get("betas", (0.9, 0.95)) + group["eps"] = group.get("eps", 1e-10) + group["weight_decay"] = group.get("weight_decay", 0) + assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) + super().__init__(param_groups, dict()) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + if group["use_muon"]: + for p in group["params"]: + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + else: + beta1, beta2 = group["betas"] + for p in group["params"]: + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], + state["step"], group["betas"], group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) \ No newline at end of file From a706a2fe42e7517862a6e2cc27aa81a368187f4f Mon Sep 17 00:00:00 2001 From: Saurabh750 Date: Sun, 8 Jun 2025 21:38:38 +0530 Subject: [PATCH 2/6] Testing muon on 0.5B qwen --- .../qwen2/0.5B_full_single_device_muon.yaml | 114 ++++++++++++++++++ recipes/full_finetune_single_device.py | 73 ++++++++++- torchtune/modules/__init__.py | 4 + torchtune/modules/muon.py | 15 +++ 4 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 recipes/configs/qwen2/0.5B_full_single_device_muon.yaml diff --git a/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml b/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml new file mode 100644 index 0000000000..843b1a7ed0 --- /dev/null +++ b/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml @@ -0,0 +1,114 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Qwen2 0.5B +# +# This config assumes that you've run the following command before launching +# this run: +# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +output_dir: /tmp/torchtune/qwen2_0_5B/full_single_device # /tmp may be deleted by your system. Change it to your preference. 
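+
+# Note: with this config the recipe routes matrix-shaped hidden weights to the Muon
+# optimizer configured under `muon:` below, while embeddings, the output head, and
+# scalar parameters are handled by the AdamW optimizer configured under `optimizer:`.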
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed +seed: null +shuffle: False #True + +# Model Arguments +model: + _component_: torchtune.models.qwen2.qwen2_0_5b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: ${output_dir} + model_type: QWEN2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 3 +epochs: 1 +muon: + enabled: True + _component_: torchtune.modules.SingleDeviceMuon + momentum: 0.95 + lr: 5e-4 #0.02 + weight_decay: 0 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +loss: + _component_: torchtune.modules.loss.LinearCrossEntropyLoss +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 + +max_steps_per_epoch: null +gradient_accumulation_steps: 1 # Use to increase effective batch size +clip_grad_norm: null +compile: False # torch.compile the model + loss, True increases speed + decreases memory + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: True +log_level: INFO # DEBUG, WARN, etc. + + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 16aa0dbb0e..b3c52bf631 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -272,8 +272,8 @@ def setup(self, cfg: DictConfig) -> None: # _setup_optimizer should take in ckpt_dict only if training is resumed from # checkpoint. 
Transforming the opt state dict is handled by this method - self.optimizer = self._setup_optimizer( - cfg_optimizer=cfg.optimizer, + self.optimizer, self.muon = self._setup_optimizer( + cfg=cfg, opt_state_dict=( ckpt_dict[training.OPT_KEY] if training.OPT_KEY in ckpt_dict else None ), @@ -424,7 +424,8 @@ def _setup_model( return model - def _setup_optimizer( +# TODO: Remove this function + def _setup_optimizer_delete( self, cfg_optimizer: DictConfig, opt_state_dict: Optional[dict[str, Any]] = None, @@ -444,6 +445,66 @@ def _setup_optimizer( optimizer.load_state_dict(opt_state_dict) self._logger.info("Optimizer is initialized.") return optimizer + + def _setup_optimizer(self, cfg, opt_state_dict: Optional[dict[str, Any]] = None,) -> Optimizer: + cfg_optimizer = cfg.optimizer + muon_enabled = cfg.muon.pop("enabled") + cfg_muon = cfg.muon + + if muon_enabled: + if self.optimizer_in_bwd: + # TODO: Modify optimizer_in_bwd for muon + pass + else: + muon_params = [] + non_muon_params = [] + + for name, module in self._model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + full_name = f"{name}.{param_name}" if name else param_name + + if not param.requires_grad: + continue + + # Skip if embedding + if isinstance(module, nn.Embedding) or "embed" in full_name.lower(): + non_muon_params.append(param) + # Skip if scalar (ndim < 2) + elif param.ndim < 2: + non_muon_params.append(param) + # Skip known head layers + elif "lm_head" in full_name.lower(): + non_muon_params.append(param) + else: + muon_params.append(param) + + optimizer = config.instantiate( + cfg_optimizer, params=non_muon_params + ) + muon = config.instantiate( + cfg_muon, params=muon_params + ) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + self._logger.info("Optimizer is initialized.") + return optimizer, muon + + else: + if self.optimizer_in_bwd: + optimizer_cls = _get_component_from_path(cfg_optimizer.pop("_component_")) + optimizer = OptimizerInBackward( + params=self._model.parameters(), + optimizer_cls=optimizer_cls, + **cfg_optimizer, + ) + else: + optimizer = config.instantiate( + cfg_optimizer, params=self._model.parameters() + ) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + self._logger.info("Optimizer is initialized.") + return optimizer, None def _setup_lr_scheduler( self, @@ -555,6 +616,8 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: def train(self) -> None: self.optimizer.zero_grad() + if self.muon: + self.muon.zero_grad() t0 = time.perf_counter() running_loss, num_tokens = 0.0, 0 self._profiler.start() @@ -599,11 +662,15 @@ def train(self) -> None: # This will be a no-op for optim in bwd, but prevents a warning w/ LR Scheduler self.optimizer.step() self.optimizer.zero_grad(set_to_none=True) + if self.muon: + self.muon.step() + self.muon.zero_grad(set_to_none=True) if self.lr_scheduler is not None: self.lr_scheduler.step() self.global_step += 1 + print(f"running_loss: {running_loss} ; num_tokens: {num_tokens}") loss_value = ( running_loss / (num_tokens if not self.optimizer_in_bwd else 1.0) diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 2e25d424a1..97a149ea6f 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -36,6 +36,8 @@ from .vq_embeddings import VectorQuantizedEmbeddings from .embedding_utils import resize_token_embeddings # usort: skip +from .muon import Muon, SingleDeviceMuon + __all__ = [ "MultiHeadAttention", "TanhGate", @@ -63,4 +65,6 @@ 
"classifier_model", "rms_norm", "resize_token_embeddings", + "Muon", + "SingleDeviceMuon" ] diff --git a/torchtune/modules/muon.py b/torchtune/modules/muon.py index 23f845f6e0..1417bbc0d2 100644 --- a/torchtune/modules/muon.py +++ b/torchtune/modules/muon.py @@ -1,3 +1,18 @@ +###################################################### +# +# This code is referred from https://github.com/KellerJordan/Muon repo. +# @misc{jordan2024muon, +# author = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and +# Franz Cesista and Laker Newhouse and Jeremy Bernstein}, +# title = {Muon: An optimizer for hidden layers in neural networks}, +# year = {2024}, +# url = {https://kellerjordan.github.io/posts/muon/} +# } +# Changes have been made wherever necessary. +# +###################################################### + + import torch import torch.distributed as dist From 72f86e738360a39b542895852f3de299640939ea Mon Sep 17 00:00:00 2001 From: Saurabh750 Date: Tue, 17 Jun 2025 22:55:11 +0530 Subject: [PATCH 3/6] Muon shifted to optim, muon checks removed as possible --- recipes/full_finetune_single_device.py | 76 +------- torchtune/modules/__init__.py | 4 - torchtune/modules/muon.py | 253 ------------------------- torchtune/modules/optim.py | 208 ++++++++++++++++++++ torchtune/training/lr_schedulers.py | 9 + 5 files changed, 223 insertions(+), 327 deletions(-) delete mode 100644 torchtune/modules/muon.py diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index b3c52bf631..680e280400 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -272,8 +272,8 @@ def setup(self, cfg: DictConfig) -> None: # _setup_optimizer should take in ckpt_dict only if training is resumed from # checkpoint. 
Transforming the opt state dict is handled by this method - self.optimizer, self.muon = self._setup_optimizer( - cfg=cfg, + self.optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, opt_state_dict=( ckpt_dict[training.OPT_KEY] if training.OPT_KEY in ckpt_dict else None ), @@ -424,8 +424,7 @@ def _setup_model( return model -# TODO: Remove this function - def _setup_optimizer_delete( + def _setup_optimizer( self, cfg_optimizer: DictConfig, opt_state_dict: Optional[dict[str, Any]] = None, @@ -438,73 +437,15 @@ def _setup_optimizer_delete( **cfg_optimizer, ) else: + optimizer_cls = cfg_optimizer["_component_"] + params = self._model.named_parameters() if 'muon' in optimizer_cls.lower() else self._model.parameters() optimizer = config.instantiate( - cfg_optimizer, params=self._model.parameters() + cfg_optimizer, params=params ) if opt_state_dict: optimizer.load_state_dict(opt_state_dict) self._logger.info("Optimizer is initialized.") return optimizer - - def _setup_optimizer(self, cfg, opt_state_dict: Optional[dict[str, Any]] = None,) -> Optimizer: - cfg_optimizer = cfg.optimizer - muon_enabled = cfg.muon.pop("enabled") - cfg_muon = cfg.muon - - if muon_enabled: - if self.optimizer_in_bwd: - # TODO: Modify optimizer_in_bwd for muon - pass - else: - muon_params = [] - non_muon_params = [] - - for name, module in self._model.named_modules(): - for param_name, param in module.named_parameters(recurse=False): - full_name = f"{name}.{param_name}" if name else param_name - - if not param.requires_grad: - continue - - # Skip if embedding - if isinstance(module, nn.Embedding) or "embed" in full_name.lower(): - non_muon_params.append(param) - # Skip if scalar (ndim < 2) - elif param.ndim < 2: - non_muon_params.append(param) - # Skip known head layers - elif "lm_head" in full_name.lower(): - non_muon_params.append(param) - else: - muon_params.append(param) - - optimizer = config.instantiate( - cfg_optimizer, params=non_muon_params - ) - muon = config.instantiate( - cfg_muon, params=muon_params - ) - if opt_state_dict: - optimizer.load_state_dict(opt_state_dict) - self._logger.info("Optimizer is initialized.") - return optimizer, muon - - else: - if self.optimizer_in_bwd: - optimizer_cls = _get_component_from_path(cfg_optimizer.pop("_component_")) - optimizer = OptimizerInBackward( - params=self._model.parameters(), - optimizer_cls=optimizer_cls, - **cfg_optimizer, - ) - else: - optimizer = config.instantiate( - cfg_optimizer, params=self._model.parameters() - ) - if opt_state_dict: - optimizer.load_state_dict(opt_state_dict) - self._logger.info("Optimizer is initialized.") - return optimizer, None def _setup_lr_scheduler( self, @@ -616,8 +557,6 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: def train(self) -> None: self.optimizer.zero_grad() - if self.muon: - self.muon.zero_grad() t0 = time.perf_counter() running_loss, num_tokens = 0.0, 0 self._profiler.start() @@ -662,9 +601,6 @@ def train(self) -> None: # This will be a no-op for optim in bwd, but prevents a warning w/ LR Scheduler self.optimizer.step() self.optimizer.zero_grad(set_to_none=True) - if self.muon: - self.muon.step() - self.muon.zero_grad(set_to_none=True) if self.lr_scheduler is not None: self.lr_scheduler.step() diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 97a149ea6f..2e25d424a1 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -36,8 +36,6 @@ from .vq_embeddings import VectorQuantizedEmbeddings from .embedding_utils import 
resize_token_embeddings # usort: skip -from .muon import Muon, SingleDeviceMuon - __all__ = [ "MultiHeadAttention", "TanhGate", @@ -65,6 +63,4 @@ "classifier_model", "rms_norm", "resize_token_embeddings", - "Muon", - "SingleDeviceMuon" ] diff --git a/torchtune/modules/muon.py b/torchtune/modules/muon.py deleted file mode 100644 index 1417bbc0d2..0000000000 --- a/torchtune/modules/muon.py +++ /dev/null @@ -1,253 +0,0 @@ -###################################################### -# -# This code is referred from https://github.com/KellerJordan/Muon repo. -# @misc{jordan2024muon, -# author = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and -# Franz Cesista and Laker Newhouse and Jeremy Bernstein}, -# title = {Muon: An optimizer for hidden layers in neural networks}, -# year = {2024}, -# url = {https://kellerjordan.github.io/posts/muon/} -# } -# Changes have been made wherever necessary. -# -###################################################### - - -import torch -import torch.distributed as dist - - -def zeropower_via_newtonschulz5(G, steps: int): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - # Perform the NS iterations - for _ in range(steps): - A = X @ X.mT - B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X - - if G.size(-2) > G.size(-1): - X = X.mT - return X - - -def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True): - momentum.lerp_(grad, 1 - beta) - update = grad.lerp_(momentum, beta) if nesterov else momentum - if update.ndim == 4: # for the case of conv filters - update = update.view(len(update), -1) - update = zeropower_via_newtonschulz5(update, steps=ns_steps) - update *= max(1, grad.size(-2) / grad.size(-1))**0.5 - return update - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the - advantage that it can be stably run in bfloat16 on the GPU. - - Muon should only be used for hidden weight layers. The input embedding, final output layer, - and any internal gains or biases should be optimized using a standard method such as AdamW. - Hidden convolutional weights can be trained using Muon by viewing them as 2D and then - collapsing their last 3 dimensions. 
- - Arguments: - lr: The learning rate, in units of spectral norm per update. - weight_decay: The AdamW-style weight decay. - momentum: The momentum. A value of 0.95 here is usually fine. - """ - def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) - params = sorted(params, key=lambda x: x.size(), reverse=True) - super().__init__(params, defaults) - - @torch.no_grad() - def step(self): - for group in self.param_groups: - params = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * (len(params) % dist.get_world_size()) - for base_i in range(len(params))[::dist.get_world_size()]: - if base_i + dist.get_rank() < len(params): - p = params[base_i + dist.get_rank()] - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) - - -class SingleDeviceMuon(torch.optim.Optimizer): - """ - Muon variant for usage in non-distributed settings. - """ - def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - super().__init__(params, defaults) - - @torch.no_grad() - def step(self): - for group in self.param_groups: - for p in group["params"]: - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) - - -def adam_update(grad, buf1, buf2, step, betas, eps): - buf1.lerp_(grad, 1 - betas[0]) - buf2.lerp_(grad.square(), 1 - betas[1]) - buf1c = buf1 / (1 - betas[0]**step) - buf2c = buf2 / (1 - betas[1]**step) - return buf1c / (buf2c.sqrt() + eps) - - -class MuonWithAuxAdam(torch.optim.Optimizer): - """ - Distributed Muon variant that can be used for all parameters in the network, since it runs an - internal AdamW for the parameters that are not compatible with Muon. The user must manually - specify which parameters shall be optimized with Muon and which with Adam by passing in a - list of param_groups with the `use_muon` flag set. - - The point of this class is to allow the user to have a single Opimizer in their code, rather - than having both a Muon and an Adam which each need to be stepped. 
- - You can see an example usage below: - - https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470 - ``` - hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] - embed_params = [p for n, p in model.named_parameters() if "embed" in n] - scalar_params = [p for p in model.parameters() if p.ndim < 2] - head_params = [model.lm_head.weight] - - from muon import MuonWithAuxAdam - adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)] - adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups] - muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True) - param_groups = [*adam_groups, muon_group] - optimizer = MuonWithAuxAdam(param_groups) - ``` - """ - def __init__(self, param_groups): - for group in param_groups: - assert "use_muon" in group - if group["use_muon"]: - group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True) - # defaults - group["lr"] = group.get("lr", 0.02) - group["momentum"] = group.get("momentum", 0.95) - group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) - else: - # defaults - group["lr"] = group.get("lr", 3e-4) - group["betas"] = group.get("betas", (0.9, 0.95)) - group["eps"] = group.get("eps", 1e-10) - group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) - super().__init__(param_groups, dict()) - - @torch.no_grad() - def step(self): - for group in self.param_groups: - if group["use_muon"]: - params = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * (len(params) % dist.get_world_size()) - for base_i in range(len(params))[::dist.get_world_size()]: - if base_i + dist.get_rank() < len(params): - p = params[base_i + dist.get_rank()] - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) - else: - beta1, beta2 = group["betas"] - for p in group["params"]: - state = self.state[p] - if len(state) == 0: - state["exp_avg"] = torch.zeros_like(p) - state["exp_avg_sq"] = torch.zeros_like(p) - state["step"] = 0 - state["step"] += 1 - update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], - state["step"], group["betas"], group["eps"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) - - -class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): - """ - Non-distributed variant of MuonWithAuxAdam. 
- """ - def __init__(self, param_groups): - for group in param_groups: - assert "use_muon" in group - if group["use_muon"]: - # defaults - group["lr"] = group.get("lr", 0.02) - group["momentum"] = group.get("momentum", 0.95) - group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"]) - else: - # defaults - group["lr"] = group.get("lr", 3e-4) - group["betas"] = group.get("betas", (0.9, 0.95)) - group["eps"] = group.get("eps", 1e-10) - group["weight_decay"] = group.get("weight_decay", 0) - assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"]) - super().__init__(param_groups, dict()) - - @torch.no_grad() - def step(self): - for group in self.param_groups: - if group["use_muon"]: - for p in group["params"]: - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) - else: - beta1, beta2 = group["betas"] - for p in group["params"]: - state = self.state[p] - if len(state) == 0: - state["exp_avg"] = torch.zeros_like(p) - state["exp_avg_sq"] = torch.zeros_like(p) - state["step"] = 0 - state["step"] += 1 - update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], - state["step"], group["betas"], group["eps"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) \ No newline at end of file diff --git a/torchtune/modules/optim.py b/torchtune/modules/optim.py index 4e7d53d45c..231e2b610e 100644 --- a/torchtune/modules/optim.py +++ b/torchtune/modules/optim.py @@ -8,6 +8,7 @@ import torch from torch.optim import Optimizer +import torch.distributed as dist __all__ = ["OptimizerInBackward"] @@ -82,3 +83,210 @@ def load_state_dict(self, state_dict): ) for idx, opt in self._optimizers.items(): opt.load_state_dict(state_dict["optimizers"][str(idx)]) + +###################################################### +# +# This code is referred from https://github.com/KellerJordan/Muon repo. +# @misc{jordan2024muon, +# author = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and +# Franz Cesista and Laker Newhouse and Jeremy Bernstein}, +# title = {Muon: An optimizer for hidden layers in neural networks}, +# year = {2024}, +# url = {https://kellerjordan.github.io/posts/muon/} +# } +# Changes have been made wherever necessary. +# +###################################################### + +def zeropower_via_newtonschulz5(G, steps: int): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
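+
+    Concretely, X is first rescaled by its Frobenius norm so its spectral norm is at most
+    1, and each iteration then applies X <- a*X + b*(X X^T) X + c*(X X^T)^2 X with
+    (a, b, c) = (3.4445, -4.7750, 2.0315); equivalently, every singular value s of X is
+    mapped to a*s + b*s^3 + c*s^5 while the singular vectors are left unchanged.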
+ """ + assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.mT + B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True): + momentum.lerp_(grad, 1 - beta) + update = grad.lerp_(momentum, beta) if nesterov else momentum + if update.ndim == 4: # for the case of conv filters + update = update.view(len(update), -1) + update = zeropower_via_newtonschulz5(update, steps=ns_steps) + update *= max(1, grad.size(-2) / grad.size(-1))**0.5 + return update + +def adam_update(grad, buf1, buf2, step, betas, eps): + buf1.lerp_(grad, 1 - betas[0]) + buf2.lerp_(grad.square(), 1 - betas[1]) + buf1c = buf1 / (1 - betas[0]**step) + buf2c = buf2 / (1 - betas[1]**step) + return buf1c / (buf2c.sqrt() + eps) + +class SingleDeviceMuonWithAuxAdam(Optimizer): + def __init__( + self, + params, # Pass model.named_parameters() + *, + muon_selector=None, + muon_lr: float = 0.02, + muon_momentum: float = 0.95, + adam_lr: float = 3e-4, + adam_betas=(0.9, 0.95), + adam_eps: float = 1e-10, + weight_decay: float = 0.0, + ): + if muon_selector is None: + muon_selector = lambda name, param: ( + param.requires_grad and + param.ndim >= 2 and # Check if scalar + "embed" not in name.lower() and # Check if embedding layer + "tok" not in name.lower() and # Check if token embeddings + "head" not in name.lower() and # Check if output head + "bias" not in name.lower() # Check if bias term + ) + + named_params = list(params) + + muon_params = [p for n, p in named_params if muon_selector(n, p)] + adam_params = [p for n, p in named_params if not muon_selector(n, p)] + + muon_params.sort(key=lambda p: p.size(), reverse=True) + + super().__init__( + [ + dict(params=muon_params, + lr=muon_lr, + momentum=muon_momentum, + weight_decay=weight_decay, + use_muon=True), + dict(params=adam_params, + lr=adam_lr, + betas=adam_betas, + eps=adam_eps, + weight_decay=weight_decay, + use_muon=False), + ], + defaults={} + ) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + assert "use_muon" in group + if group["use_muon"]: + for p in group["params"]: + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + else: + for p in group["params"]: + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], + state["step"], group["betas"], group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + +class MuonWithAuxAdam(torch.optim.Optimizer): + def __init__( + self, + params, # Pass model.named_parameters() + *, + muon_selector=None, + muon_lr: float = 0.02, + muon_momentum: float = 0.95, + adam_lr: float = 3e-4, + adam_betas=(0.9, 0.95), + adam_eps: float = 1e-10, + weight_decay: float = 
0.0, + ): + if muon_selector is None: + muon_selector = lambda name, param: ( + param.requires_grad and + param.ndim >= 2 and # Check if scalar + "embed" not in name.lower() and # Check if embedding layer + "tok" not in name.lower() and # Check if token embeddings + "head" not in name.lower() and # Check if output head + "bias" not in name.lower() # Check if bias term + ) + + named_params = list(params) + + muon_params = [p for n, p in named_params if muon_selector(n, p)] + adam_params = [p for n, p in named_params if not muon_selector(n, p)] + + muon_params.sort(key=lambda p: p.size(), reverse=True) + + super().__init__( + [ + dict(params=muon_params, + lr=muon_lr, + momentum=muon_momentum, + weight_decay=weight_decay, + use_muon=True), + dict(params=adam_params, + lr=adam_lr, + betas=adam_betas, + eps=adam_eps, + weight_decay=weight_decay, + use_muon=False), + ], + defaults={} + ) + + @torch.no_grad() + def step(self): + for group in self.param_groups: + if group["use_muon"]: + params = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * (len(params) % dist.get_world_size()) + for base_i in range(len(params))[::dist.get_world_size()]: + if base_i + dist.get_rank() < len(params): + p = params[base_i + dist.get_rank()] + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) + dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + else: + for p in group["params"]: + state = self.state[p] + if len(state) == 0: + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + state["step"] = 0 + state["step"] += 1 + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], + state["step"], group["betas"], group["eps"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) + p.add_(update, alpha=-group["lr"]) \ No newline at end of file diff --git a/torchtune/training/lr_schedulers.py b/torchtune/training/lr_schedulers.py index 6f431c9f37..f7a2e024a8 100644 --- a/torchtune/training/lr_schedulers.py +++ b/torchtune/training/lr_schedulers.py @@ -10,6 +10,7 @@ import torch from torch.optim.lr_scheduler import LambdaLR from torchtune.training.memory import OptimizerInBackwardWrapper +from torchtune.modules.optim import SingleDeviceMuonWithAuxAdam def get_cosine_schedule_with_warmup( @@ -88,6 +89,14 @@ def get_lr( ) # LR Schedulers are the same across all param groups for full_finetune right now + + if isinstance(optimizer, SingleDeviceMuonWithAuxAdam): + for group in param_groups: + lr = group["lr"] + if group['use_muon']: # Returning only Muon learning rate + return lr + return lr + lr = param_groups[0]["lr"] for group in param_groups: if group["lr"] != lr: From 550a9b39bdcb30a11278f284edfa938adeedc087 Mon Sep 17 00:00:00 2001 From: Saurabh750 Date: Tue, 17 Jun 2025 22:59:43 +0530 Subject: [PATCH 4/6] correct config for muon optimizer --- .../qwen2/0.5B_full_single_device_muon.yaml | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml b/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml index 843b1a7ed0..c16388fb99 100644 --- a/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml +++ b/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml @@ -49,16 +49,14 @@ resume_from_checkpoint: False # Fine-tuning 
arguments batch_size: 3 epochs: 1 -muon: - enabled: True - _component_: torchtune.modules.SingleDeviceMuon - momentum: 0.95 - lr: 5e-4 #0.02 - weight_decay: 0 optimizer: - _component_: torch.optim.AdamW - fused: True - lr: 2e-5 + _component_: torchtune.modules.optim.SingleDeviceMuonWithAuxAdam + muon_lr: 0.02 + muon_momentum: 0.95 + weight_decay: 0 + adam_lr: 2e-5 + adam_betas: [0.9, 0.95] + adam_eps: 1e-10 loss: _component_: torchtune.modules.loss.LinearCrossEntropyLoss From 942a9a2895e1d31c36a6b68b22a9e4b5815868ba Mon Sep 17 00:00:00 2001 From: Saurabh750 Date: Thu, 19 Jun 2025 22:12:07 +0530 Subject: [PATCH 5/6] Testing with different implementation of Muon --- torchtune/modules/muon.py | 246 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 torchtune/modules/muon.py diff --git a/torchtune/modules/muon.py b/torchtune/modules/muon.py new file mode 100644 index 0000000000..35613030be --- /dev/null +++ b/torchtune/modules/muon.py @@ -0,0 +1,246 @@ +import os +import torch +import torch.distributed as dist +from torch.distributed.tensor import DTensor +from torch.utils._pytree import tree_map, tree_flatten +from typing import Generator +# from utils import to_local, to_dist + +import functools +import gc +import math +import random +import string +from typing import List, Optional, Tuple, Callable, Union + +import numpy as np +import torch +from torch.backends import cudnn, opt_einsum +from torch.utils._pytree import tree_map + +from torch.distributed.tensor import distribute_tensor, DTensor + +def to_dist(x, from_local=False, **meta): + if from_local: + return DTensor.from_local( + x, + device_mesh=meta["device_mesh"], + placements=meta["placements"], + shape=meta["shape"], + stride=meta["stride"], + ) + else: + return distribute_tensor(x, device_mesh=meta["device_mesh"], placements=meta["placements"]) + + +def to_local(x, keep_sharded=False): + if isinstance(x, DTensor): + meta = dict( + device_mesh=x.device_mesh, + placements=x.placements, + shape=x.shape, + stride=x.stride(), + ) + if keep_sharded: + return x.to_local(), meta + else: + return x.full_tensor(), meta + + return x, None + + +def local_op(x, fn, keep_sharded=False): + """ + converts to Tensor, does a thing, then back to Dtensor + """ + x, meta = to_local(x, keep_sharded) + x = fn(x) + if meta is not None: + x = to_dist(x, from_local=keep_sharded, **meta) + return x + +# @torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
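+
+    Unlike the batched implementation in optim.py, this version expects a single 2-D
+    matrix; higher-dimensional gradients are flattened to 2-D by the caller (see
+    Muon.step below) before reaching this function.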
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= (X.norm() + eps) # ensure top singular value <= 1 + if G.size(0) > G.size(1): + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if G.size(0) > G.size(1): + X = X.T + return X + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. 
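+
+    Example (illustrative; note that, unlike the argument list above, this implementation
+    takes (name, parameter) pairs plus an optional `muon_selector` rather than separate
+    muon_params/adamw_params lists):
+
+    ```
+    optimizer = Muon(model.named_parameters(), lr=0.02, momentum=0.95, adamw_lr=3e-4)
+    ```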
+ """ + # def __init__(self, muon_params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=6, + # adamw_params=None, adamw_lr=3e-4, adamw_betas=(0.95, 0.95), adamw_eps=1e-8, adamw_wd=0): + def __init__(self, params, muon_selector=None, lr=0.02, momentum=0.95, nesterov=True, ns_steps=6, + adamw_lr=3e-4, adamw_betas=[0.95, 0.95], adamw_eps=1e-8, adamw_wd=0): + + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, + adamw_lr_ratio=adamw_lr/lr, adamw_betas=adamw_betas, + adamw_eps=adamw_eps, adamw_wd=adamw_wd) + + if muon_selector is None: + muon_selector = lambda name, param: ( + param.requires_grad and + param.ndim >= 2 and # Check if scalar + "embed" not in name.lower() and # Check if embedding layer + "tok" not in name.lower() and # Check if token embeddings + "head" not in name.lower() and # Check if output head + "bias" not in name.lower() # Check if bias term + ) + + # handle list of params or list of dicts + # if isinstance(muon_params, Generator): + # muon_params = list(muon_params) + # if isinstance(adamw_params, Generator): + # adamw_params = list(adamw_params) + # elif adamw_params is None: + # adamw_params = [] + + named_params = list(params) + + muon_params = [p for n, p in named_params if muon_selector(n, p)] + adamw_params = [p for n, p in named_params if not muon_selector(n, p)] + + super().__init__([*muon_params, *adamw_params], defaults) + + # Sort parameters into those for which we will use Muon, and those for which we will not + # we cant pickle booleans for saving, so we will use 1=True, 0=False + def assign_muon(p): + if p.ndim >= 2 and p.size(0) < 10000: + self.state[p]['use_muon'] = 1 + else: + self.state[p]['use_muon'] = 0 + + if isinstance(muon_params[0], dict): + for group in muon_params: + for p in group['params']: + assign_muon(p) + else: + for p in muon_params: + assign_muon(p) + + def assign_adamw(p): + # Do not use Muon for parameters in adamw_params + self.state[p]['use_muon'] = 0 + + if len(adamw_params) and isinstance(adamw_params[0], dict): + for group in adamw_params: + for p in group['params']: + assign_adamw(p) + else: + for p in adamw_params: + assign_adamw(p) + + if torch.distributed.is_initialized(): + self.world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + else: + self.world_size = 1 + self.rank = 0 + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
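+
+        Note: for parameters routed to Muon, a DTensor gradient is first gathered to a
+        full local tensor, orthogonalized with Newton-Schulz, and then redistributed,
+        because running the iteration directly on DTensor silently produces NaNs (see
+        the to_local/to_dist calls below).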
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + momentum = group['momentum'] + for i, p in enumerate(group['params']): + if self.state[p]['use_muon'] == 1: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + buf.mul_(momentum).add_(g) + if group['nesterov']: + g = g.add(buf, alpha=momentum) + + meta = None + if isinstance(g, DTensor): + g, meta = to_local(g, keep_sharded=False) + # gives NaNs when done with Dtensor, instead of throwing a typical op not supported error, quite sneaky + g = zeropower_via_newtonschulz5(g, steps=group['ns_steps']) + if meta is not None: + g = to_dist(g, **meta) + g *= max(1, g.size(0)/g.size(1))**0.5 + + g = g.view_as(p.data).type_as(p.data) + p.data.add_(g, alpha=-lr) + else: + # these are all pointwise so we can stay in Dtensor + g = p.grad + if g is None: + continue + state = self.state[p] + if 'step' not in state: + state['step'] = 0 + state['moment1'] = torch.zeros_like(g) + state['moment2'] = torch.zeros_like(g) + state['step'] += 1 + step = state['step'] + buf1 = state['moment1'] + buf2 = state['moment2'] + buf1.lerp_(g, 1-group['adamw_betas'][0]) + buf2.lerp_(g.square(), 1-group['adamw_betas'][1]) + + g = buf1 / (group['adamw_eps'] + buf2.sqrt()) + + bias_correction1 = 1 - group['adamw_betas'][0]**step + bias_correction2 = 1 - group['adamw_betas'][1]**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * group['adamw_wd']) + p.data.add_(g, alpha=-lr/scale) \ No newline at end of file From cf2193e139834a197b886cfb3cd453c33b293e7b Mon Sep 17 00:00:00 2001 From: Saurabh750 Date: Mon, 23 Jun 2025 23:22:27 +0530 Subject: [PATCH 6/6] Stable version of Muon --- torchtune/modules/muon.py | 246 ------------------- torchtune/modules/optim.py | 361 ++++++++++++++-------------- torchtune/training/lr_schedulers.py | 10 +- 3 files changed, 180 insertions(+), 437 deletions(-) delete mode 100644 torchtune/modules/muon.py diff --git a/torchtune/modules/muon.py b/torchtune/modules/muon.py deleted file mode 100644 index 35613030be..0000000000 --- a/torchtune/modules/muon.py +++ /dev/null @@ -1,246 +0,0 @@ -import os -import torch -import torch.distributed as dist -from torch.distributed.tensor import DTensor -from torch.utils._pytree import tree_map, tree_flatten -from typing import Generator -# from utils import to_local, to_dist - -import functools -import gc -import math -import random -import string -from typing import List, Optional, Tuple, Callable, Union - -import numpy as np -import torch -from torch.backends import cudnn, opt_einsum -from torch.utils._pytree import tree_map - -from torch.distributed.tensor import distribute_tensor, DTensor - -def to_dist(x, from_local=False, **meta): - if from_local: - return DTensor.from_local( - x, - device_mesh=meta["device_mesh"], - placements=meta["placements"], - shape=meta["shape"], - stride=meta["stride"], - ) - else: - return distribute_tensor(x, device_mesh=meta["device_mesh"], placements=meta["placements"]) - - -def to_local(x, keep_sharded=False): - if isinstance(x, DTensor): - meta = dict( - device_mesh=x.device_mesh, - placements=x.placements, - shape=x.shape, - stride=x.stride(), - ) - if keep_sharded: - return x.to_local(), meta - else: - return x.full_tensor(), meta - - return x, None - - -def local_op(x, fn, 
keep_sharded=False): - """ - converts to Tensor, does a thing, then back to Dtensor - """ - x, meta = to_local(x, keep_sharded) - x = fn(x) - if meta is not None: - x = to_dist(x, from_local=keep_sharded, **meta) - return x - -# @torch.compile -def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - X /= (X.norm() + eps) # ensure top singular value <= 1 - if G.size(0) > G.size(1): - X = X.T - for _ in range(steps): - A = X @ X.T - B = b * A + c * A @ A - X = a * X + B @ X - if G.size(0) > G.size(1): - X = X.T - return X - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - muon_params: The parameters to be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - adamw_wd: The weight decay for the internal AdamW. 
- """ - # def __init__(self, muon_params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=6, - # adamw_params=None, adamw_lr=3e-4, adamw_betas=(0.95, 0.95), adamw_eps=1e-8, adamw_wd=0): - def __init__(self, params, muon_selector=None, lr=0.02, momentum=0.95, nesterov=True, ns_steps=6, - adamw_lr=3e-4, adamw_betas=[0.95, 0.95], adamw_eps=1e-8, adamw_wd=0): - - defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, - adamw_lr_ratio=adamw_lr/lr, adamw_betas=adamw_betas, - adamw_eps=adamw_eps, adamw_wd=adamw_wd) - - if muon_selector is None: - muon_selector = lambda name, param: ( - param.requires_grad and - param.ndim >= 2 and # Check if scalar - "embed" not in name.lower() and # Check if embedding layer - "tok" not in name.lower() and # Check if token embeddings - "head" not in name.lower() and # Check if output head - "bias" not in name.lower() # Check if bias term - ) - - # handle list of params or list of dicts - # if isinstance(muon_params, Generator): - # muon_params = list(muon_params) - # if isinstance(adamw_params, Generator): - # adamw_params = list(adamw_params) - # elif adamw_params is None: - # adamw_params = [] - - named_params = list(params) - - muon_params = [p for n, p in named_params if muon_selector(n, p)] - adamw_params = [p for n, p in named_params if not muon_selector(n, p)] - - super().__init__([*muon_params, *adamw_params], defaults) - - # Sort parameters into those for which we will use Muon, and those for which we will not - # we cant pickle booleans for saving, so we will use 1=True, 0=False - def assign_muon(p): - if p.ndim >= 2 and p.size(0) < 10000: - self.state[p]['use_muon'] = 1 - else: - self.state[p]['use_muon'] = 0 - - if isinstance(muon_params[0], dict): - for group in muon_params: - for p in group['params']: - assign_muon(p) - else: - for p in muon_params: - assign_muon(p) - - def assign_adamw(p): - # Do not use Muon for parameters in adamw_params - self.state[p]['use_muon'] = 0 - - if len(adamw_params) and isinstance(adamw_params[0], dict): - for group in adamw_params: - for p in group['params']: - assign_adamw(p) - else: - for p in adamw_params: - assign_adamw(p) - - if torch.distributed.is_initialized(): - self.world_size = torch.distributed.get_world_size() - self.rank = torch.distributed.get_rank() - else: - self.world_size = 1 - self.rank = 0 - - def step(self, closure=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - lr = group["lr"] - momentum = group['momentum'] - for i, p in enumerate(group['params']): - if self.state[p]['use_muon'] == 1: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - state = self.state[p] - if 'momentum_buffer' not in state: - state['momentum_buffer'] = torch.zeros_like(g) - buf = state['momentum_buffer'] - buf.mul_(momentum).add_(g) - if group['nesterov']: - g = g.add(buf, alpha=momentum) - - meta = None - if isinstance(g, DTensor): - g, meta = to_local(g, keep_sharded=False) - # gives NaNs when done with Dtensor, instead of throwing a typical op not supported error, quite sneaky - g = zeropower_via_newtonschulz5(g, steps=group['ns_steps']) - if meta is not None: - g = to_dist(g, **meta) - g *= max(1, g.size(0)/g.size(1))**0.5 - - g = g.view_as(p.data).type_as(p.data) - p.data.add_(g, alpha=-lr) - else: - # these are all pointwise so we can stay in Dtensor - g = p.grad - if g is None: - continue - state = self.state[p] - if 'step' not in state: - state['step'] = 0 - state['moment1'] = torch.zeros_like(g) - state['moment2'] = torch.zeros_like(g) - state['step'] += 1 - step = state['step'] - buf1 = state['moment1'] - buf2 = state['moment2'] - buf1.lerp_(g, 1-group['adamw_betas'][0]) - buf2.lerp_(g.square(), 1-group['adamw_betas'][1]) - - g = buf1 / (group['adamw_eps'] + buf2.sqrt()) - - bias_correction1 = 1 - group['adamw_betas'][0]**step - bias_correction2 = 1 - group['adamw_betas'][1]**step - scale = bias_correction1 / bias_correction2**0.5 - p.data.mul_(1 - lr * group['adamw_wd']) - p.data.add_(g, alpha=-lr/scale) \ No newline at end of file diff --git a/torchtune/modules/optim.py b/torchtune/modules/optim.py index 231e2b610e..16fc0a1e87 100644 --- a/torchtune/modules/optim.py +++ b/torchtune/modules/optim.py @@ -9,6 +9,7 @@ import torch from torch.optim import Optimizer import torch.distributed as dist +from torch.distributed.tensor import distribute_tensor, DTensor __all__ = ["OptimizerInBackward"] @@ -84,78 +85,37 @@ def load_state_dict(self, state_dict): for idx, opt in self._optimizers.items(): opt.load_state_dict(state_dict["optimizers"][str(idx)]) -###################################################### -# -# This code is referred from https://github.com/KellerJordan/Muon repo. -# @misc{jordan2024muon, -# author = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and -# Franz Cesista and Laker Newhouse and Jeremy Bernstein}, -# title = {Muon: An optimizer for hidden layers in neural networks}, -# year = {2024}, -# url = {https://kellerjordan.github.io/posts/muon/} -# } -# Changes have been made wherever necessary. -# -###################################################### - -def zeropower_via_newtonschulz5(G, steps: int): +class Muon(torch.optim.Optimizer): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + params: The parameters to be optimized. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. """ - assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - # Perform the NS iterations - for _ in range(steps): - A = X @ X.mT - B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X - - if G.size(-2) > G.size(-1): - X = X.mT - return X - - -def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True): - momentum.lerp_(grad, 1 - beta) - update = grad.lerp_(momentum, beta) if nesterov else momentum - if update.ndim == 4: # for the case of conv filters - update = update.view(len(update), -1) - update = zeropower_via_newtonschulz5(update, steps=ns_steps) - update *= max(1, grad.size(-2) / grad.size(-1))**0.5 - return update - -def adam_update(grad, buf1, buf2, step, betas, eps): - buf1.lerp_(grad, 1 - betas[0]) - buf2.lerp_(grad.square(), 1 - betas[1]) - buf1c = buf1 / (1 - betas[0]**step) - buf2c = buf2 / (1 - betas[1]**step) - return buf1c / (buf2c.sqrt() + eps) - -class SingleDeviceMuonWithAuxAdam(Optimizer): - def __init__( - self, - params, # Pass model.named_parameters() - *, - muon_selector=None, - muon_lr: float = 0.02, - muon_momentum: float = 0.95, - adam_lr: float = 3e-4, - adam_betas=(0.9, 0.95), - adam_eps: float = 1e-10, - weight_decay: float = 0.0, - ): + def __init__(self, params, muon_selector=None, lr=0.02, momentum=0.95, nesterov=True, ns_steps=6, + adamw_lr=3e-4, adamw_betas=[0.95, 0.95], adamw_eps=1e-8, adamw_wd=0): + + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, + adamw_lr_ratio=adamw_lr/lr, adamw_betas=adamw_betas, + adamw_eps=adamw_eps, adamw_wd=adamw_wd) + if muon_selector is None: muon_selector = lambda name, param: ( param.requires_grad and @@ -169,124 +129,159 @@ def __init__( named_params = list(params) muon_params = [p for n, p in named_params if 
muon_selector(n, p)] - adam_params = [p for n, p in named_params if not muon_selector(n, p)] - - muon_params.sort(key=lambda p: p.size(), reverse=True) - - super().__init__( - [ - dict(params=muon_params, - lr=muon_lr, - momentum=muon_momentum, - weight_decay=weight_decay, - use_muon=True), - dict(params=adam_params, - lr=adam_lr, - betas=adam_betas, - eps=adam_eps, - weight_decay=weight_decay, - use_muon=False), - ], - defaults={} - ) + adamw_params = [p for n, p in named_params if not muon_selector(n, p)] - @torch.no_grad() - def step(self): - for group in self.param_groups: - assert "use_muon" in group - if group["use_muon"]: - for p in group["params"]: - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) + super().__init__([*muon_params, *adamw_params], defaults) + + # Sort parameters into those for which we will use Muon, and those for which we will not + # we cant pickle booleans for saving, so we will use 1=True, 0=False + def assign_muon(p): + if p.ndim >= 2 and p.size(0) < 10000: + self.state[p]['use_muon'] = 1 else: - for p in group["params"]: - state = self.state[p] - if len(state) == 0: - state["exp_avg"] = torch.zeros_like(p) - state["exp_avg_sq"] = torch.zeros_like(p) - state["step"] = 0 - state["step"] += 1 - update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], - state["step"], group["betas"], group["eps"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) - -class MuonWithAuxAdam(torch.optim.Optimizer): - def __init__( - self, - params, # Pass model.named_parameters() - *, - muon_selector=None, - muon_lr: float = 0.02, - muon_momentum: float = 0.95, - adam_lr: float = 3e-4, - adam_betas=(0.9, 0.95), - adam_eps: float = 1e-10, - weight_decay: float = 0.0, - ): - if muon_selector is None: - muon_selector = lambda name, param: ( - param.requires_grad and - param.ndim >= 2 and # Check if scalar - "embed" not in name.lower() and # Check if embedding layer - "tok" not in name.lower() and # Check if token embeddings - "head" not in name.lower() and # Check if output head - "bias" not in name.lower() # Check if bias term + self.state[p]['use_muon'] = 0 + + if isinstance(muon_params[0], dict): + for group in muon_params: + for p in group['params']: + assign_muon(p) + else: + for p in muon_params: + assign_muon(p) + + def assign_adamw(p): + # Do not use Muon for parameters in adamw_params + self.state[p]['use_muon'] = 0 + + if len(adamw_params) and isinstance(adamw_params[0], dict): + for group in adamw_params: + for p in group['params']: + assign_adamw(p) + else: + for p in adamw_params: + assign_adamw(p) + + if torch.distributed.is_initialized(): + self.world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + else: + self.world_size = 1 + self.rank = 0 + + def to_dist(self, x, from_local=False, **meta): + if from_local: + return DTensor.from_local( + x, + device_mesh=meta["device_mesh"], + placements=meta["placements"], + shape=meta["shape"], + stride=meta["stride"], ) + else: + return distribute_tensor(x, device_mesh=meta["device_mesh"], placements=meta["placements"]) - named_params = list(params) - muon_params = [p for n, p in named_params if muon_selector(n, p)] - adam_params = [p for n, p in named_params if not muon_selector(n, p)] - - muon_params.sort(key=lambda p: p.size(), 
reverse=True) - - super().__init__( - [ - dict(params=muon_params, - lr=muon_lr, - momentum=muon_momentum, - weight_decay=weight_decay, - use_muon=True), - dict(params=adam_params, - lr=adam_lr, - betas=adam_betas, - eps=adam_eps, - weight_decay=weight_decay, - use_muon=False), - ], - defaults={} - ) + def to_local(self, x, keep_sharded=False): + if isinstance(x, DTensor): + meta = dict( + device_mesh=x.device_mesh, + placements=x.placements, + shape=x.shape, + stride=x.stride(), + ) + if keep_sharded: + return x.to_local(), meta + else: + return x.full_tensor(), meta + + return x, None + + def zeropower_via_newtonschulz5(self, G, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= (X.norm() + eps) # ensure top singular value <= 1 + if G.size(0) > G.size(1): + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if G.size(0) > G.size(1): + X = X.T + return X + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() - @torch.no_grad() - def step(self): for group in self.param_groups: - if group["use_muon"]: - params = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * (len(params) % dist.get_world_size()) - for base_i in range(len(params))[::dist.get_world_size()]: - if base_i + dist.get_rank() < len(params): - p = params[base_i + dist.get_rank()] - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(p) - update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) - else: - for p in group["params"]: + lr = group["lr"] + momentum = group['momentum'] + for i, p in enumerate(group['params']): + if self.state[p]['use_muon'] == 1: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + buf.mul_(momentum).add_(g) + if group['nesterov']: + g = g.add(buf, alpha=momentum) + + meta = None + if isinstance(g, DTensor): + g, meta = self.to_local(g, keep_sharded=False) + # gives NaNs when done with Dtensor, instead of throwing a typical op not supported error, quite sneaky + g = self.zeropower_via_newtonschulz5(g, steps=group['ns_steps']) + if meta is not None: + g = self.to_dist(g, **meta) + g *= max(1, g.size(0)/g.size(1))**0.5 + + g = g.view_as(p.data).type_as(p.data) + p.data.add_(g, alpha=-lr) + else: + # these are all pointwise so we can stay in Dtensor + g = p.grad + if g is None: + continue state = self.state[p] - if len(state) == 0: - state["exp_avg"] = torch.zeros_like(p) - state["exp_avg_sq"] = torch.zeros_like(p) - state["step"] = 0 - state["step"] += 1 - update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], - state["step"], group["betas"], group["eps"]) - p.mul_(1 - group["lr"] * group["weight_decay"]) - p.add_(update, alpha=-group["lr"]) \ No newline at end of file + if 'step' not in state: + state['step'] = 0 + state['moment1'] = torch.zeros_like(g) + state['moment2'] = torch.zeros_like(g) + state['step'] += 1 + step = state['step'] + buf1 = state['moment1'] + buf2 = state['moment2'] + buf1.lerp_(g, 1-group['adamw_betas'][0]) + buf2.lerp_(g.square(), 1-group['adamw_betas'][1]) + + g = buf1 / (group['adamw_eps'] + buf2.sqrt()) + + bias_correction1 = 1 - group['adamw_betas'][0]**step + bias_correction2 = 1 - group['adamw_betas'][1]**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * group['adamw_wd']) + p.data.add_(g, alpha=-lr/scale) \ No newline at end of file diff --git a/torchtune/training/lr_schedulers.py b/torchtune/training/lr_schedulers.py index f7a2e024a8..be8d85debf 100644 --- a/torchtune/training/lr_schedulers.py +++ b/torchtune/training/lr_schedulers.py @@ -10,7 +10,7 @@ import torch from torch.optim.lr_scheduler import LambdaLR from torchtune.training.memory import OptimizerInBackwardWrapper -from torchtune.modules.optim import SingleDeviceMuonWithAuxAdam +from torchtune.modules.optim import Muon def get_cosine_schedule_with_warmup( @@ -89,15 +89,9 @@ def get_lr( ) # LR Schedulers are the same across all param groups for full_finetune right now - - if isinstance(optimizer, SingleDeviceMuonWithAuxAdam): - for group in 
param_groups: - lr = group["lr"] - if group['use_muon']: # Returning only Muon learning rate - return lr - return lr lr = param_groups[0]["lr"] + if isinstance(optimizer, Muon): return lr # return Muon learning rate if Muon optimizer for group in param_groups: if group["lr"] != lr: raise RuntimeError("LR Schedulers are different across all param groups ")
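
For context, the sketch below shows one way the `Muon` optimizer added in `torchtune/modules/optim.py` by this patch might be exercised on a single device. It is not part of the patch: the toy model, data, and training loop are illustrative placeholders, and the only assumptions beyond this diff are the import paths shown and that `get_cosine_schedule_with_warmup` accepts `(optimizer, num_warmup_steps, num_training_steps)` as in torchtune's existing scheduler utility; the constructor keywords (`lr`, `momentum`, `ns_steps`, `adamw_lr`, `adamw_wd`) are taken directly from the `Muon.__init__` signature above.

# Usage sketch (illustrative, not part of this patch). Muon consumes (name, param)
# pairs so its default muon_selector can send >=2D hidden weights through the
# orthogonalized Newton-Schulz update and route embeddings/heads/biases/1D
# params to the internal AdamW branch of step().
import torch
import torch.nn as nn

from torchtune.modules.optim import Muon
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

optimizer = Muon(
    list(model.named_parameters()),  # (name, param) pairs, not bare parameters
    lr=0.02,          # spectral-norm scale of each orthogonalized Muon update
    momentum=0.95,
    ns_steps=6,       # Newton-Schulz iterations per update
    adamw_lr=3e-4,    # stored as adamw_lr_ratio (adamw_lr / lr) in the defaults
    adamw_wd=0.0,
)
# Assumed signature, matching torchtune's existing scheduler helper.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=10, num_training_steps=100
)

for _ in range(100):
    x = torch.randn(32, 64)
    y = torch.randint(0, 10, (32,))
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()      # routes each param through the Muon or AdamW branch
    optimizer.zero_grad()
    lr_scheduler.step()

In a distributed run the same loop applies, except that gradients arrive as DTensors and `Muon.step()` converts them to local tensors for the Newton-Schulz iteration before redistributing, as implemented in the hunk above.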