 NS_COEFF_C: float = 2.0315
 
 
-def _maybe_compile(
-    fn: callable,
-) -> callable:
-    """Compile a function if torch.compile is available."""
-    if not hasattr(torch, "compile"):
-        return fn
-    # Skip compile if default device is CUDA but CUDA is unavailable.
-    if hasattr(torch, "get_default_device"):
-        default_device = torch.get_default_device()
-        if default_device.type == "cuda" and not torch.cuda.is_available():
-            return fn
-    return torch.compile(fn, fullgraph=True, dynamic=True)
-
-
-@_maybe_compile
-def _zeropower_via_newtonschulz5_2d(
+def _newton_schulz_orth(
     G: torch.Tensor,
 ) -> torch.Tensor:
     """
@@ -132,70 +117,6 @@ def _zeropower_via_newtonschulz5_2d(
     return X
 
 
-@_maybe_compile
-def _zeropower_via_newtonschulz5_3d(
-    G: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Orthogonalize a 3D batch of matrices via quintic Newton-Schulz iteration.
-
-    Mathematical formulation:
-        X_0 = G / ||G||_F
-        X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T
-    Coefficients: a=3.4445, b=-4.7750, c=2.0315
-    """
-    # === Step 1. Cast to bf16 and transpose tall matrices ===
-    X = G.to(dtype=torch.bfloat16)
-    transposed = X.size(-2) > X.size(-1)
-    if transposed:
-        X = X.transpose(-2, -1)
-
-    # === Step 2. Normalize Frobenius norm to at most 1 ===
-    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)
-
-    # === Step 3. Newton-Schulz iterations with batched fused GEMM ===
-    for _ in range(NS_STEPS):
-        A = torch.bmm(X, X.transpose(-2, -1))
-        gram_update = torch.baddbmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C)
-        X = torch.baddbmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0)
-
-    # === Step 4. Transpose back if needed ===
-    if transposed:
-        X = X.transpose(-2, -1)
-
-    return X
-
-
-def zeropower_via_newtonschulz5(
-    G: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute the zeroth power (orthogonalization) via Newton-Schulz iteration.
-
-    Dispatches to compiled 2D or 3D kernels for best performance.
-
-    Parameters
-    ----------
-    G : torch.Tensor
-        Input matrix with shape (M, N) or batched input with shape (B, M, N).
-
-    Returns
-    -------
-    torch.Tensor
-        Orthogonalized tensor in bfloat16 with same shape as input.
-
-    Raises
-    ------
-    ValueError
-        If input is not 2D or 3D.
-    """
-    if G.ndim == 2:
-        return _zeropower_via_newtonschulz5_2d(G)
-    if G.ndim == 3:
-        return _zeropower_via_newtonschulz5_3d(G)
-    raise ValueError("Input must be 2D or 3D for Newton-Schulz orthogonalization.")
-
-
 def should_fallback_to_adam_for_matrix(
     p: torch.Tensor,
     min_2d_dim: int,
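
For reference, the quintic recurrence documented in the removed 3D variant, X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k with A_k = X_k @ X_k^T, maps onto the fused calls as torch.baddbmm(A, A, A, beta=b, alpha=c) = b*A + c*(A @ A) and torch.baddbmm(X, B, X, beta=a, alpha=1) = a*X + B @ X. A minimal unbatched sketch of the same recurrence (standalone; the step count and epsilon are illustrative assumptions, not the module's constants):

    import torch

    A_COEFF, B_COEFF, C_COEFF = 3.4445, -4.7750, 2.0315
    STEPS, EPS = 5, 1e-7  # assumed values, for illustration only

    def ns_orth_sketch(G: torch.Tensor) -> torch.Tensor:
        # Keep the wide orientation so the Gram matrix A = X @ X.T stays small.
        X = G.to(torch.bfloat16)
        transposed = X.size(-2) > X.size(-1)
        if transposed:
            X = X.transpose(-2, -1)
        X = X / X.norm().clamp(min=EPS)          # ||X_0||_F <= 1
        for _ in range(STEPS):
            A = X @ X.transpose(-2, -1)          # A_k = X_k @ X_k^T
            B = B_COEFF * A + C_COEFF * (A @ A)  # b*A_k + c*A_k^2
            X = A_COEFF * X + B @ X              # X_{k+1} = a*X_k + B @ X_k
        return X.transpose(-2, -1) if transposed else X
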
@@ -478,9 +399,11 @@ def step(
 
             # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
             # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
-            torch._foreach_lerp_(adam_exp_avgs, adam_grads_fp32, 1 - adam_betas[0])
-            grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32)
-            torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1])
+            for ea, g in zip(adam_exp_avgs, adam_grads_fp32):
+                ea.lerp_(g, 1 - adam_betas[0])
+            grad_sq = [g * g for g in adam_grads_fp32]
+            for eas, gsq in zip(adam_exp_avg_sqs, grad_sq):
+                eas.lerp_(gsq, 1 - adam_betas[1])
 
             # === Step 1.3. Bias correction and parameter update ===
             for i, p in enumerate(adam_params):
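
The loop form relies on the identity behind Tensor.lerp_: t.lerp_(s, w) computes t + w * (s - t) = (1 - w) * t + w * s, so ea.lerp_(g, 1 - adam_betas[0]) is exactly the moving average beta1 * exp_avg + (1 - beta1) * grad from the comment above. A tiny check with hypothetical values (the same identity covers the adam_nd and adam_matrix branches below):

    import torch

    beta1 = 0.9
    ea = torch.tensor([1.0, 2.0])   # stand-in for one exp_avg tensor
    g = torch.tensor([0.5, -1.0])   # stand-in for its fp32 gradient
    expected = beta1 * ea + (1 - beta1) * g
    ea.lerp_(g, 1 - beta1)          # in-place EMA step
    assert torch.allclose(ea, expected)
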
@@ -531,11 +454,11 @@ def step(
 
             # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
             # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
-            torch._foreach_lerp_(
-                adam_nd_exp_avgs, adam_nd_grads_fp32, 1 - adam_betas[0]
-            )
-            grad_sq = torch._foreach_mul(adam_nd_grads_fp32, adam_nd_grads_fp32)
-            torch._foreach_lerp_(adam_nd_exp_avg_sqs, grad_sq, 1 - adam_betas[1])
+            for ea, g in zip(adam_nd_exp_avgs, adam_nd_grads_fp32):
+                ea.lerp_(g, 1 - adam_betas[0])
+            grad_sq = [g * g for g in adam_nd_grads_fp32]
+            for eas, gsq in zip(adam_nd_exp_avg_sqs, grad_sq):
+                eas.lerp_(gsq, 1 - adam_betas[1])
 
             # === Step 2.3. Bias correction and parameter update ===
             for i, p in enumerate(adam_nd_params):
@@ -589,15 +512,11 @@ def step(
 
             # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
             # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
-            torch._foreach_lerp_(
-                adam_matrix_exp_avgs, adam_matrix_grads_fp32, 1 - adam_betas[0]
-            )
-            grad_sq_m = torch._foreach_mul(
-                adam_matrix_grads_fp32, adam_matrix_grads_fp32
-            )
-            torch._foreach_lerp_(
-                adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1]
-            )
+            for ea, g in zip(adam_matrix_exp_avgs, adam_matrix_grads_fp32):
+                ea.lerp_(g, 1 - adam_betas[0])
+            grad_sq_m = [g * g for g in adam_matrix_grads_fp32]
+            for eas, gsq in zip(adam_matrix_exp_avg_sqs, grad_sq_m):
+                eas.lerp_(gsq, 1 - adam_betas[1])
 
             # === Step 3.3. Compute unclipped deltas ===
             raw_deltas: list[torch.Tensor] = []
@@ -611,8 +530,8 @@ def step(
 
             # === Step 3.4. Clip updates by relative norm and apply ===
             max_rel_change = 0.05
-            p_norms = torch.stack(torch._foreach_norm(adam_matrix_params))
-            delta_norms = torch.stack(torch._foreach_norm(raw_deltas))
+            p_norms = torch.stack([p.norm() for p in adam_matrix_params])
+            delta_norms = torch.stack([d.norm() for d in raw_deltas])
             floors = torch.tensor(
                 adam_matrix_abs_floor,
                 device=p_norms.device,
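
The clamp that consumes these norms falls outside this hunk; a plausible reading of the tensors built here (an assumption, not code from the commit) is a per-parameter budget of max(max_rel_change * ||p||, floor), with any delta whose norm exceeds the budget scaled down before it is applied:

    import torch

    # Hypothetical stand-ins for the stacked norms built above.
    max_rel_change = 0.05
    p_norms = torch.tensor([10.0, 0.02])
    delta_norms = torch.tensor([1.0, 0.5])
    floors = torch.tensor([1e-3, 1e-3])

    budget = torch.maximum(max_rel_change * p_norms, floors)
    scales = (budget / delta_norms.clamp(min=1e-12)).clamp(max=1.0)
    # each raw delta would then be multiplied by its entry in `scales`
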
@@ -653,18 +572,21 @@ def step(
 
             # === Step 4.2. Apply weight decay (Muon path only) ===
             if weight_decay > 0 and muon_params_for_decay:
-                torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay)
+                for p in muon_params_for_decay:
+                    p.mul_(1.0 - lr * weight_decay)
 
             if not active_entries:
                 continue
 
             # === Step 4.3. Momentum update (Nesterov) ===
             # m_t = beta * m_{t-1} + (1 - beta) * g_t
-            torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum)
+            for buf, g in zip(muon_momentum_buffers, muon_grads):
+                buf.lerp_(g, 1 - momentum)
             # update = beta * m_t + (1 - beta) * g_t
-            muon_updates = torch._foreach_lerp(
-                muon_grads, muon_momentum_buffers, momentum
-            )
+            muon_updates = [
+                torch.lerp(g, buf, momentum)
+                for g, buf in zip(muon_grads, muon_momentum_buffers)
+            ]
 
             # === Step 4.4. Bucket by shape/device/dtype for batched NS ===
             buckets: dict[
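
The two lerp calls in Step 4.3 compose the Nesterov-style update: the in-place call advances the buffer, m_t = beta * m_{t-1} + (1 - beta) * g_t, and the out-of-place torch.lerp(g, buf, momentum) mixes the fresh gradient back in, since lerp(g, buf, beta) = (1 - beta) * g + beta * buf. A small check with hypothetical values:

    import torch

    momentum = 0.95
    buf = torch.zeros(3)                    # stand-in momentum buffer
    g = torch.tensor([1.0, -2.0, 0.5])      # stand-in gradient

    buf.lerp_(g, 1 - momentum)              # m_t = beta*m_{t-1} + (1 - beta)*g_t
    update = torch.lerp(g, buf, momentum)   # beta*m_t + (1 - beta)*g_t
    assert torch.allclose(update, momentum * buf + (1 - momentum) * g)
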
@@ -689,37 +611,16 @@ def step(
                 else:
                     scale = max(1.0, rows / cols) ** 0.5
 
-                if len(bucket_entries) == 1:
-                    entry, update_tensor = bucket_entries[0]
+                # Process each entry individually with _newton_schulz_orth; this
+                # stays compatible with sharding propagation under FSDP2.
+                for entry, update_tensor in bucket_entries:
                     update_matrix = update_tensor.reshape(rows, cols)
                     if not update_matrix.is_contiguous():
                         update_matrix = update_matrix.contiguous()
 
-                    orth = _zeropower_via_newtonschulz5_2d(update_matrix)
+                    orth = _newton_schulz_orth(update_matrix)
                     orth.mul_(scale)
                     delta = orth.reshape(entry["param"].shape)
                     entry["param"].add_(delta, alpha=-lr)
-                    continue
-
-                matrices: list[torch.Tensor] = []
-                params: list[torch.Tensor] = []
-                orig_shapes: list[tuple[int, ...]] = []
-
-                for entry, update_tensor in bucket_entries:
-                    update_matrix = update_tensor.reshape(rows, cols)
-                    matrices.append(
-                        update_matrix
-                        if update_matrix.is_contiguous()
-                        else update_matrix.contiguous()
-                    )
-                    params.append(entry["param"])
-                    orig_shapes.append(entry["param"].shape)
-
-                stacked = torch.stack(matrices, dim=0)
-                orth = _zeropower_via_newtonschulz5_3d(stacked)
-                orth.mul_(scale)
-
-                for i, _ in enumerate(bucket_entries):
-                    params[i].add_(orth[i].reshape(orig_shapes[i]), alpha=-lr)
 
         return loss
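
Condensed, every bucketed entry now follows the same per-tensor path: reshape the update to (rows, cols), orthogonalize with _newton_schulz_orth, rescale by the shape factor from the else branch, and apply with a negative learning rate. A hypothetical helper capturing that flow (names and the 2D-only restriction are assumptions made for the sketch):

    import torch

    def apply_muon_update(param: torch.Tensor, update: torch.Tensor,
                          lr: float, ns_orth) -> None:
        # Hypothetical condensation of the per-entry loop in Step 4.4 (2D case).
        rows, cols = update.shape
        scale = max(1.0, rows / cols) ** 0.5     # shape-dependent scaling
        orth = ns_orth(update.contiguous())      # orthogonalized direction
        delta = orth.reshape(param.shape).to(param.dtype)
        param.add_(delta * scale, alpha=-lr)

This folds the loop's orth.mul_(scale) and add_(delta, alpha=-lr) pair into a single expression.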