Commit 3e5e1ba

fix: FSDP2 does not support foreach ops in HybridMuon
1 parent f6d5d95 commit 3e5e1ba

File tree

2 files changed: +37, -136 lines

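Per the commit message, the fused torch._foreach_* multi-tensor ops used in HybridMuon are not supported under FSDP2, so the diff below replaces each of them with an equivalent per-tensor loop. A minimal sketch of that rewrite pattern (the names beta1, exp_avgs, and grads are illustrative, not the optimizer's actual state):

```python
# Sketch of the rewrite applied throughout this commit: replace a fused
# multi-tensor op with a plain per-tensor loop so every update dispatches
# through each parameter tensor individually (e.g. FSDP2-sharded params).
import torch

beta1 = 0.9
exp_avgs = [torch.zeros(3, 3), torch.zeros(2, 2)]
grads = [torch.randn(3, 3), torch.randn(2, 2)]

# Before: one fused call over the whole list (the pattern being removed).
# torch._foreach_lerp_(exp_avgs, grads, 1 - beta1)

# After: numerically equivalent per-tensor loop (the pattern being added).
for ea, g in zip(exp_avgs, grads):
    ea.lerp_(g, 1 - beta1)
```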

deepmd/pt/optimizer/hybrid_muon.py

Lines changed: 30 additions & 129 deletions
```diff
@@ -84,22 +84,7 @@
 NS_COEFF_C: float = 2.0315
 
 
-def _maybe_compile(
-    fn: callable,
-) -> callable:
-    """Compile a function if torch.compile is available."""
-    if not hasattr(torch, "compile"):
-        return fn
-    # Skip compile if default device is CUDA but CUDA is unavailable.
-    if hasattr(torch, "get_default_device"):
-        default_device = torch.get_default_device()
-        if default_device.type == "cuda" and not torch.cuda.is_available():
-            return fn
-    return torch.compile(fn, fullgraph=True, dynamic=True)
-
-
-@_maybe_compile
-def _zeropower_via_newtonschulz5_2d(
+def _newton_schulz_orth(
     G: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -132,70 +117,6 @@ def _zeropower_via_newtonschulz5_2d(
     return X
 
 
-@_maybe_compile
-def _zeropower_via_newtonschulz5_3d(
-    G: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Orthogonalize a 3D batch of matrices via quintic Newton-Schulz iteration.
-
-    Mathematical formulation:
-        X_0 = G / ||G||_F
-        X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T
-        Coefficients: a=3.4445, b=-4.7750, c=2.0315
-    """
-    # === Step 1. Cast to bf16 and transpose tall matrices ===
-    X = G.to(dtype=torch.bfloat16)
-    transposed = X.size(-2) > X.size(-1)
-    if transposed:
-        X = X.transpose(-2, -1)
-
-    # === Step 2. Normalize Frobenius norm to at most 1 ===
-    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)
-
-    # === Step 3. Newton-Schulz iterations with batched fused GEMM ===
-    for _ in range(NS_STEPS):
-        A = torch.bmm(X, X.transpose(-2, -1))
-        gram_update = torch.baddbmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C)
-        X = torch.baddbmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0)
-
-    # === Step 4. Transpose back if needed ===
-    if transposed:
-        X = X.transpose(-2, -1)
-
-    return X
-
-
-def zeropower_via_newtonschulz5(
-    G: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute the zeroth power (orthogonalization) via Newton-Schulz iteration.
-
-    Dispatches to compiled 2D or 3D kernels for best performance.
-
-    Parameters
-    ----------
-    G : torch.Tensor
-        Input matrix with shape (M, N) or batched input with shape (B, M, N).
-
-    Returns
-    -------
-    torch.Tensor
-        Orthogonalized tensor in bfloat16 with same shape as input.
-
-    Raises
-    ------
-    ValueError
-        If input is not 2D or 3D.
-    """
-    if G.ndim == 2:
-        return _zeropower_via_newtonschulz5_2d(G)
-    if G.ndim == 3:
-        return _zeropower_via_newtonschulz5_3d(G)
-    raise ValueError("Input must be 2D or 3D for Newton-Schulz orthogonalization.")
-
-
 def should_fallback_to_adam_for_matrix(
     p: torch.Tensor,
     min_2d_dim: int,
@@ -478,9 +399,11 @@ def step(
 
             # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
             # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
-            torch._foreach_lerp_(adam_exp_avgs, adam_grads_fp32, 1 - adam_betas[0])
-            grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32)
-            torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1])
+            for ea, g in zip(adam_exp_avgs, adam_grads_fp32):
+                ea.lerp_(g, 1 - adam_betas[0])
+            grad_sq = [g * g for g in adam_grads_fp32]
+            for eas, gsq in zip(adam_exp_avg_sqs, grad_sq):
+                eas.lerp_(gsq, 1 - adam_betas[1])
 
             # === Step 1.3. Bias correction and parameter update ===
             for i, p in enumerate(adam_params):
@@ -531,11 +454,11 @@ def step(
 
             # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
             # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
-            torch._foreach_lerp_(
-                adam_nd_exp_avgs, adam_nd_grads_fp32, 1 - adam_betas[0]
-            )
-            grad_sq = torch._foreach_mul(adam_nd_grads_fp32, adam_nd_grads_fp32)
-            torch._foreach_lerp_(adam_nd_exp_avg_sqs, grad_sq, 1 - adam_betas[1])
+            for ea, g in zip(adam_nd_exp_avgs, adam_nd_grads_fp32):
+                ea.lerp_(g, 1 - adam_betas[0])
+            grad_sq = [g * g for g in adam_nd_grads_fp32]
+            for eas, gsq in zip(adam_nd_exp_avg_sqs, grad_sq):
+                eas.lerp_(gsq, 1 - adam_betas[1])
 
             # === Step 2.3. Bias correction and parameter update ===
             for i, p in enumerate(adam_nd_params):
@@ -589,15 +512,11 @@ def step(
 
             # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
             # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
-            torch._foreach_lerp_(
-                adam_matrix_exp_avgs, adam_matrix_grads_fp32, 1 - adam_betas[0]
-            )
-            grad_sq_m = torch._foreach_mul(
-                adam_matrix_grads_fp32, adam_matrix_grads_fp32
-            )
-            torch._foreach_lerp_(
-                adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1]
-            )
+            for ea, g in zip(adam_matrix_exp_avgs, adam_matrix_grads_fp32):
+                ea.lerp_(g, 1 - adam_betas[0])
+            grad_sq_m = [g * g for g in adam_matrix_grads_fp32]
+            for eas, gsq in zip(adam_matrix_exp_avg_sqs, grad_sq_m):
+                eas.lerp_(gsq, 1 - adam_betas[1])
 
             # === Step 3.3. Compute unclipped deltas ===
             raw_deltas: list[torch.Tensor] = []
@@ -611,8 +530,8 @@ def step(
 
             # === Step 3.4. Clip updates by relative norm and apply ===
             max_rel_change = 0.05
-            p_norms = torch.stack(torch._foreach_norm(adam_matrix_params))
-            delta_norms = torch.stack(torch._foreach_norm(raw_deltas))
+            p_norms = torch.stack([p.norm() for p in adam_matrix_params])
+            delta_norms = torch.stack([d.norm() for d in raw_deltas])
             floors = torch.tensor(
                 adam_matrix_abs_floor,
                 device=p_norms.device,
@@ -653,18 +572,21 @@ def step(
 
             # === Step 4.2. Apply weight decay (Muon path only) ===
             if weight_decay > 0 and muon_params_for_decay:
-                torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay)
+                for p in muon_params_for_decay:
+                    p.mul_(1.0 - lr * weight_decay)
 
             if not active_entries:
                 continue
 
             # === Step 4.3. Momentum update (Nesterov) ===
             # m_t = beta * m_{t-1} + (1 - beta) * g_t
-            torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum)
+            for buf, g in zip(muon_momentum_buffers, muon_grads):
+                buf.lerp_(g, 1 - momentum)
             # update = beta * m_t + (1 - beta) * g_t
-            muon_updates = torch._foreach_lerp(
-                muon_grads, muon_momentum_buffers, momentum
-            )
+            muon_updates = [
+                torch.lerp(g, buf, momentum)
+                for g, buf in zip(muon_grads, muon_momentum_buffers)
+            ]
 
             # === Step 4.4. Bucket by shape/device/dtype for batched NS ===
             buckets: dict[
@@ -689,37 +611,16 @@
                 else:
                     scale = max(1.0, rows / cols) ** 0.5
 
-                if len(bucket_entries) == 1:
-                    entry, update_tensor = bucket_entries[0]
+                # Process each entry individually with _newton_schulz_orth.
+                # compatible with sharding propagation under FSDP2.
+                for entry, update_tensor in bucket_entries:
                     update_matrix = update_tensor.reshape(rows, cols)
                     if not update_matrix.is_contiguous():
                         update_matrix = update_matrix.contiguous()
 
-                    orth = _zeropower_via_newtonschulz5_2d(update_matrix)
+                    orth = _newton_schulz_orth(update_matrix)
                     orth.mul_(scale)
                     delta = orth.reshape(entry["param"].shape)
                     entry["param"].add_(delta, alpha=-lr)
-                    continue
-
-                matrices: list[torch.Tensor] = []
-                params: list[torch.Tensor] = []
-                orig_shapes: list[tuple[int, ...]] = []
-
-                for entry, update_tensor in bucket_entries:
-                    update_matrix = update_tensor.reshape(rows, cols)
-                    matrices.append(
-                        update_matrix
-                        if update_matrix.is_contiguous()
-                        else update_matrix.contiguous()
-                    )
-                    params.append(entry["param"])
-                    orig_shapes.append(entry["param"].shape)
-
-                stacked = torch.stack(matrices, dim=0)
-                orth = _zeropower_via_newtonschulz5_3d(stacked)
-                orth.mul_(scale)
-
-                for i, _ in enumerate(bucket_entries):
-                    params[i].add_(orth[i].reshape(orig_shapes[i]), alpha=-lr)
 
         return loss
```
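For context on what the surviving 2D path computes, here is a self-contained sketch of the quintic Newton-Schulz orthogonalization that _newton_schulz_orth performs, following the formulation documented in the removed 3D variant (X_0 = G / ||G||_F, X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k with a=3.4445, b=-4.7750, c=2.0315). The NS_STEPS=5 and EPS=1e-7 values below are assumptions standing in for the module's own constants:

```python
# Illustrative sketch of the 2D Newton-Schulz orthogonalization; not the
# module's code. NS_STEPS and EPS are assumed values here.
import torch

NS_COEFF_A, NS_COEFF_B, NS_COEFF_C = 3.4445, -4.7750, 2.0315
NS_STEPS, EPS = 5, 1e-7


def newton_schulz_orth_sketch(G: torch.Tensor) -> torch.Tensor:
    X = G.to(dtype=torch.bfloat16)
    transposed = X.size(-2) > X.size(-1)
    if transposed:  # work on the wide orientation, transpose back at the end
        X = X.transpose(-2, -1)
    X = X / X.norm().clamp(min=EPS)  # normalize Frobenius norm to at most 1
    for _ in range(NS_STEPS):
        A = X @ X.transpose(-2, -1)  # A_k = X_k @ X_k^T
        X = NS_COEFF_A * X + (NS_COEFF_B * A + NS_COEFF_C * (A @ A)) @ X
    if transposed:
        X = X.transpose(-2, -1)
    return X
```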

source/tests/pt/test_hybrid_muon.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -5,7 +5,7 @@
 
 from deepmd.pt.optimizer.hybrid_muon import (
     HybridMuonOptimizer,
-    zeropower_via_newtonschulz5,
+    _newton_schulz_orth,
 )
 from deepmd.pt.utils import (
     env,
@@ -48,7 +48,7 @@ def test_orthogonalization(self) -> None:
         """Test that NS produces approximately orthogonal output."""
         torch.manual_seed(42)
         G = torch.randn(4, 4, dtype=torch.float32, device=self.device)
-        X = zeropower_via_newtonschulz5(G)
+        X = _newton_schulz_orth(G)
 
         # X @ X.T should be approximately identity
         # Note: NS uses bf16 internally, 5 iterations gives ~0.1-0.3 error
@@ -68,17 +68,17 @@ def test_orthogonalization(self) -> None:
     def test_shape_and_dtype(self) -> None:
         """Test that output preserves shape and returns bf16."""
         torch.manual_seed(42)
-        for shape in [(4, 4), (6, 4), (3, 4, 4)]:
+        for shape in [(4, 4), (6, 4)]:
             G = torch.randn(*shape, dtype=torch.float32, device=self.device)
-            X = zeropower_via_newtonschulz5(G)
+            X = _newton_schulz_orth(G)
             self.assertEqual(X.shape, G.shape)
             self.assertEqual(X.dtype, torch.bfloat16)
 
     def test_invalid_input(self) -> None:
-        """Test that <2D input raises ValueError."""
+        """Test that 1D input raises error."""
         G_1d = torch.randn(10, dtype=torch.float32, device=self.device)
-        with self.assertRaises(ValueError):
-            zeropower_via_newtonschulz5(G_1d)
+        with self.assertRaises((ValueError, RuntimeError, IndexError)):
+            _newton_schulz_orth(G_1d)
 
 
 @unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device")
```
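As a usage note, the orthogonality property the updated test checks can be reproduced against the sketch above with a loose tolerance, since the iteration runs in bf16 and, per the test's comment, 5 iterations leave roughly 0.1-0.3 error:

```python
# Quick orthogonality check in the spirit of test_orthogonalization.
# newton_schulz_orth_sketch is the illustrative function from the earlier
# sketch, not the module's private helper.
import torch

torch.manual_seed(42)
G = torch.randn(4, 4, dtype=torch.float32)
X = newton_schulz_orth_sketch(G)
gram = (X @ X.transpose(-2, -1)).float()
max_err = (gram - torch.eye(4)).abs().max().item()
print(f"max |X @ X.T - I| = {max_err:.3f}")  # expect roughly 0.1-0.3
```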
