
perf(deepseek_v4): use mx.fast.rms_norm in HyperConnection and HyperHead#16

Open
0xClandestine wants to merge 1 commit into Blaizzy:pc/add-deepseekv4flash-model from 0xClandestine:phew/optimize-blaizzy-deepseekv4flash

Conversation

@0xClandestine

Summary

Replaces the manual rsqrt(mean(x²) + eps) + scaled matmul pattern in HyperConnection and HyperHead with mx.fast.rms_norm, which dispatches a fused Metal kernel.

  • Removes _rms_rsqrt helper (subsumed)
  • _hc_mixes: (flat @ fn_T) * rsqrt → mx.fast.rms_norm(flat, None, eps) @ fn_T
  • _hyper_head_op: same rewrite
  • HyperConnection.compute_weights training path: unified with inference path (both now go through _hc_mixes)
  • HyperHead.__call__ training path: same rewrite

The identity (x @ W) * s = rms_norm(x) @ W holds because s = 1/rms(x) is a per-row scalar, so it distributes over matrix multiplication.
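The identity can be sanity-checked with a small NumPy sketch (a stand-in `rms_norm` is defined here to mirror `mx.fast.rms_norm(x, None, eps)`; the shapes and eps are illustrative, not taken from the model):

```python
import numpy as np

def rms_norm(x, eps):
    # Stand-in for mx.fast.rms_norm(x, None, eps): per-row RMS normalization.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
eps = 1e-6
flat = rng.standard_normal((4, 64))   # rows of activations
fn_T = rng.standard_normal((64, 16))  # projection weights

# Baseline: matmul first, then scale each row by s = 1/rms(row).
s = 1.0 / np.sqrt(np.mean(flat * flat, axis=-1, keepdims=True) + eps)
baseline = (flat @ fn_T) * s

# Rewrite: normalize first, then matmul.
optimised = rms_norm(flat, eps) @ fn_T

assert np.allclose(baseline, optimised, atol=1e-6)
```

Because s is constant along the contraction axis, scaling before or after the matmul gives the same result up to floating-point rounding.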

Perf

Measured on M-series (DeepSeek-V4 geometry: HC=4, D=7168, flat_dim=28672):

| kernel | baseline | optimised | speedup |
| --- | --- | --- | --- |
| _hc_mixes (B=2, L=256) | 0.77 ms | 0.63 ms | 1.22× |
| _hyper_head (B=2, L=256) | 0.72 ms | 0.63 ms | 1.14× |
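A minimal harness in the same spirit as the measurement above (NumPy stand-ins on deliberately smaller shapes, so absolute times and speedups will not match the Metal numbers; this only sketches the comparison methodology):

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
eps = 1e-6
# Illustrative shapes, much smaller than the DeepSeek-V4 geometry quoted above.
flat = rng.standard_normal((512, 1024)).astype(np.float32)
fn_T = rng.standard_normal((1024, 256)).astype(np.float32)

def baseline():
    # Manual pattern: matmul, then per-row rsqrt(mean(x^2) + eps) scaling.
    s = 1.0 / np.sqrt(np.mean(flat * flat, axis=-1, keepdims=True) + eps)
    return (flat @ fn_T) * s

def optimised():
    # Rewritten pattern: RMS-normalize the input, then matmul.
    normed = flat / np.sqrt(np.mean(flat * flat, axis=-1, keepdims=True) + eps)
    return normed @ fn_T

t_base = min(timeit.repeat(baseline, number=10, repeat=3))
t_opt = min(timeit.repeat(optimised, number=10, repeat=3))
print(f"baseline {t_base * 100:.3f} ms/iter, optimised {t_opt * 100:.3f} ms/iter")
```

On MLX the gain comes from the fused kernel behind mx.fast.rms_norm, which this NumPy sketch cannot reproduce; the harness only checks that both paths agree and reports wall-clock times.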

Verified numerically with phew-mlx equivalence checker (SubstitutionClass.normed_matmul, atol=1e-3).

Replace manual rsqrt(mean(x²)+eps) + scaled matmul with mx.fast.rms_norm
throughout the HyperConnection/HyperHead compute paths:

  - Remove _rms_rsqrt helper (subsumed by _hc_mixes)
  - _hc_mixes: (flat @ fn_T) * rsqrt → rms_norm(flat) @ fn_T
  - _hyper_head_op: same rewrite
  - HyperConnection.compute_weights training path: unified with inference path
  - HyperHead.__call__ training path: same rewrite

mx.fast.rms_norm dispatches a fused Metal kernel; the rewrite is algebraically
exact ((x@W)*s = rms_norm(x)@W for per-row scalar s) and verified ~1.2× faster
on DeepSeek-V4 geometry (HC=4, D=7168, flat_dim=28672).
