
perf: optimize DeepSeek-V4 #13

Merged

Blaizzy merged 43 commits into Blaizzy:pc/add-deepseekv4flash-model from 0xClandestine:perf/optimize-ds4 on Apr 26, 2026

perf: optimize DeepSeek-V4#13
Blaizzy merged 43 commits intoBlaizzy:pc/add-deepseekv4flash-modelfrom
0xClandestine:perf/optimize-ds4

Conversation


@0xClandestine 0xClandestine commented Apr 24, 2026

DeepSeek V4 Optimization Changelog

Targets the DeepSeek-V4/V4-Flash Metal kernels in mlx_lm/models/deepseek_v4.py.

What stuck

| Change | Impact |
| --- | --- |
| Vectorized HC Sinkhorn kernel | Rewrote `_make_hc_split_sinkhorn_kernel` to use `float4` SIMD loads/stores, `dot()` for row sums, and fully unrolled 4x4 Sinkhorn iterations instead of scalar loops. Measurable prefill speedup; approved by maintainer. |
| `_overlap_transform` concatenate | Replaced `mx.full` + scatter-assign with `mx.concatenate` of sliced tensors, avoiding a large intermediate allocation. Approved by maintainer ("this is great"). |
| Per-head q-norm Metal kernel (`ds4_q_norm`) | Fused RMS norm for query vectors using a 32-thread SIMD group reduction, replacing `mx.rsqrt` + elementwise. Additive, no regression. |
| `hc_mult != 4` guard | Falls back to pure-MLX Sinkhorn when `hc_mult` is not 4, since the `float4` kernel is hardcoded for 4x4. Safety fix. |

What was tried and reverted

| Change | Why reverted |
| --- | --- |
| Fused sparse attention Metal kernel (`ds4_fused_sparse_attn`) | Online softmax over local + sparse KV in a single MSL kernel. Multiple iterations fixing address spaces, dispatch, and output shapes. Hung prompt processing; maintainer benchmarked it at 86 tok/s prefill vs 122 tok/s. |
| `_split_sparse_attention` (MLX fallback) | Two-pass local + sparse attention with log-sum-exp merge, `@mx.compile`-decorated. Replaced by standard SDPA with gather + concatenate; the SDPA path is faster. Still in the file as dead code, but unused. |
| Fused partial RoPE Metal kernel (`ds4_partial_rope`) | Eliminated the split/concat intermediates for partial rotary embeddings. Didn't move the needle on benchmarks. Reverted per maintainer feedback. |
| Compressor x-buffer decode deferral | Deferred `wkv`/`wgate` GEMVs until a full compression window was ready. Broke batching. Reverted per maintainer feedback. |
| HyperConnection `fn` in model dtype | Changed `self.fn` from float32 to the default dtype (bf16) and split `flat`/`flat_f32` to do the matmul in model dtype. Caused a prefill regression (78 tok/s, down from 120). Reverted. |
| SwitchGLU refactor | Removed the `scores` param, used a `_gather_sort` helper, and moved score multiplication out to `DeepseekV4MoE.__call__`. Caused a prefill regression. Reverted to the maintainer's approach. |
| Simplified compress path | Stripped the `select_all` optimization, `lengths`/`pooled_mask` tracking, mask trimming, and `pooled_bias` handling. Major prefill regression (down to 78 tok/s; back to 262 tok/s after restoring). The `select_all` fast path and mask trim are critical. |
| Collapse/expand Metal kernels + MoE gate bf16 | Custom kernels for the HyperConnection collapse/expand ops. Reverted immediately (broke things). |
| Tokenizer `PreTrainedTokenizerFast` fallback | Caught the `AttributeError` from `AutoTokenizer` for custom models with `rope_scaling`. An upstream transformers fix (PR #45643) made it unnecessary. |

Final benchmark

Prefill:    262.7 tok/s
Generation:  31.4 tok/s
Peak memory: 152.2 GB

@0xClandestine 0xClandestine changed the title from "Optimize HC sinkhorn Metal kernel with float4 SIMD and bounds guard" to "perf: optimize DeepSeek-V4" on Apr 24, 2026
@0xClandestine 0xClandestine force-pushed the perf/optimize-ds4 branch 3 times, most recently from 2abd684 to 9b9f012 on April 25, 2026 at 10:59
0xClandestine and others added 25 commits on April 25, 2026 at 16:43
- Add n_rows bounds guard to prevent out-of-bounds writes when grid
  is not a multiple of threadgroup size (256)
- #if HC == 4 fast path: float4 vectorized pre/post sigmoid, tree-reduce
  max for comb softmax, dot-product row sum, metal::fast::recip/exp
  throughout, aligned 128-bit stores for all outputs
- Sinkhorn passes use float4 row vectors — column sums reduce to 4
  vector additions, row sums to a single dot(row, 1) instruction
- Scalar fallback (#else): replace 1/x with metal::fast::recip and
  a*b+c with fma for all mix/base linear combinations
MLX injects template params as constexpr variables, not #define macros.
The preprocessor sees HC as undefined (expands to 0), so #if HC == 4
always evaluated false and the scalar fallback ran unconditionally,
silently defeating all float4 optimizations.

Drop the #if/#else/#endif entirely since HC is always 4 for this model.
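A minimal sketch of the difference, assuming MLX's mx.fast.metal_kernel template API; this is an illustrative toy kernel, not the PR's Sinkhorn kernel. The template value is visible to the compiler, so an ordinary `if` works, but a preprocessor `#if` never sees it:

```python
import mlx.core as mx

# HC arrives as a kernel template value: the C++ compiler can branch on it,
# but `#if HC == 4` would compare an undefined macro (treated as 0) and
# silently take the wrong branch.
source = """
    uint row = thread_position_in_grid.x;
    float acc = 0.0f;
    if (HC == 4) {   // fine: HC is a compile-time template value
        acc = inp[row * 4] + inp[row * 4 + 1] + inp[row * 4 + 2] + inp[row * 4 + 3];
    } else {         // generic path for other widths
        for (int i = 0; i < HC; ++i) acc += inp[row * HC + i];
    }
    out[row] = acc;
"""

row_sum = mx.fast.metal_kernel(
    name="row_sum_hc",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

x = mx.random.normal((256, 4)).astype(mx.float32)
(y,) = row_sum(
    inputs=[x],
    template=[("HC", 4)],        # injected as a template value, not a #define
    grid=(x.shape[0], 1, 1),
    threadgroup=(64, 1, 1),
    output_shapes=[(x.shape[0],)],
    output_dtypes=[mx.float32],
)
print(mx.allclose(y, x.sum(axis=-1)))
```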
- Fix initial softmax eps: remove spurious eps from recip denominator
  to match Python reference (mx.softmax) and old kernel behavior
- Guard Metal kernel against hc_mult != 4 by falling back to scalar ops
- Add per-query split sparse attention for prefill (L > 1): each query
  attends only to its own top-k compressed keys instead of all L*topk
  flattened keys, reducing attention scores from O(L*(T+L*topk)) to
  O(L*(T+topk)) and fixing cross-position key leakage
- Generation (L==1) keeps SDPA path for Flash Attention speed
- Document omitted Hadamard rotation and FP4 query simulation gaps
- Stack wo_a.0..7 into single wo_a by concatenating along output dim
- Add embed/head scales/biases to top_remap for 4-bit quant models
- Include biases suffix in expert stacking loop (alongside weight/scales)
- Handle already-stacked expert weights (experts.w1.weight → switch_mlp)
- Drop expert .biases keys (QuantizedSwitchLinear doesn't use them)
- Remove hardcoded per-layer expert quant defaults from __post_init__
  so checkpoint's actual group_size flows through to nn.quantize
AutoTokenizer.from_pretrained fails for unknown model types (e.g.
deepseek_v4) when config.json has rope_scaling — transformers'
config standardization accesses max_position_embeddings on a generic
PreTrainedConfig. Fall back to PreTrainedTokenizerFast directly.
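A sketch of the fallback (since reverted after the upstream fix); the helper name is illustrative, not the exact mlx_lm code:

```python
from transformers import AutoTokenizer, PreTrainedTokenizerFast

def load_tokenizer(model_path: str):
    try:
        return AutoTokenizer.from_pretrained(model_path)
    except AttributeError:
        # Unknown model_type + rope_scaling: config standardization blows up
        # before the tokenizer is even built, so load the fast tokenizer
        # directly from tokenizer.json.
        return PreTrainedTokenizerFast.from_pretrained(model_path)
```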
Single-pass online softmax over local window + index-gathered compressed
KV. Eliminates ~1GB intermediates per sparse layer during prefill by
avoiding the broadcast+gather of [B,L,topk,D] and all score/exp tensors.

- 256 threads (4 per head × 64 heads), SIMD shuffle dot product reduction
- Masked local KV with early skip, sparse KV gathered by index on the fly
- FlashAttention-style running max/sum, attention sink in denominator
- Falls back to pure MLX split attention on CPU / non-Metal
- Use constant address space for topk_idxs pointer (matches MLX int input)
- Use remove_reference_t<decltype(*op)> for output cast (decltype(*op)
  resolves to a reference type, can't functional-cast to it)
…put write

Three correctness bugs:

1. Address space: all pointer declarations in kernel used `const device auto*`
   which fails when MLX places small tensors in `constant` space. Changed all
   to `auto` so the compiler deduces the correct address space from context.

2. Grid/gid: MLX grid = total thread count, not number of threadgroups.
   Was dispatching grid=(B*L) with threadgroup=(H*4), giving only 1 thread
   per (b,l) pair instead of H*4. Fixed: grid=(B*L*H*4) so each (b,l) gets
   its own threadgroup of H*4 threads. Changed gid from thread_position_in_grid
   to threadgroup_position_in_grid to get the (b,l) index.

3. Output write: static_cast<remove_reference_t<decltype(*op)>>() silently
   fails for address-space-qualified types in Metal. Added store_elem header
   helper that deduces the base element type T via `device T&` to strip the
   device qualifier, then assigns T(v) for correct float->dtype conversion.

Verified: float32 max diff < 3e-7, bfloat16 max diff < 3e-3 vs CPU reference.
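The dispatch arithmetic from bug 2, written out with assumed dimension names (B, L, H); in MLX the grid is the total thread count, so the per-(b, l) threadgroup size has to be folded into it:

```python
# Hypothetical sizes for a kernel that wants H*4 cooperating threads per (b, l).
B, L, H = 1, 512, 64
threads_per_group = H * 4

# Wrong: grid counts threads in MLX, so this launches one thread per (b, l).
# grid = (B * L, 1, 1)

# Right: B*L threadgroups of H*4 threads each. Inside the kernel the (b, l)
# index then comes from threadgroup_position_in_grid, not thread_position_in_grid.
grid = (B * L * threads_per_group, 1, 1)
threadgroup = (threads_per_group, 1, 1)
```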
generate_step does prompt[None], so a (1, L) server input becomes (1, 1, L).
embed_tokens then produces (1, 1, L, D), causing h.shape[2] in the hc_mult
broadcast target to read L instead of D, crashing with shape mismatch.

Flatten inputs to (B, L) at the top of __call__ before embedding.
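A shape-only illustration of the problem and the fix (not the model code):

```python
import mlx.core as mx

prompt = mx.array([[1, 2, 3, 4]])        # (1, L) as sent by the server
batched = prompt[None]                   # generate_step's prompt[None] -> (1, 1, L)

# Flatten to (B, L) before embed_tokens so the hidden state is (B, L, D),
# not (1, 1, L, D), and h.shape[2] means D again.
tokens = batched.reshape(-1, batched.shape[-1])
print(tokens.shape)                      # (1, 4)
```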
…ctly

During decode (L=1), C (compressed tokens accumulated so far) is always
<= index_topk=512, so the indexer would select all tokens anyway. Running
it wastes matmuls + argpartition, and the subsequent expand→take_along_axis
→concat path creates intermediate tensors versus simply using pooled[:, None].

Restrict the indexer call to prefill (L > 1) only. Decode now takes the
pooled[:, None] path directly, matching SDPA's efficient single-query path.
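A sketch of the resulting control flow; `indexer` and `pooled` are hypothetical stand-ins for the real cache plumbing:

```python
import mlx.core as mx

def select_compressed(pooled: mx.array, L: int, indexer=None):
    # pooled: (B, C, D) compressed tokens accumulated so far. During decode
    # (L == 1), C <= index_topk, so the indexer would select every token
    # anyway; skip its matmuls/argpartition and the expand/gather/concat path.
    if L > 1 and indexer is not None:
        return indexer(pooled)    # prefill: per-query top-k selection
    return pooled[:, None]        # decode: (B, 1, C, D) directly
```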
…diates

Replaces the split-rotate-concat path in _apply_partial_rope with a single
Metal kernel dispatch. Each threadgroup (one SIMD group of 32 threads) handles
one (batch, head, token) triplet: threads stride-copy the nope prefix, then
each lane rotates one interleaved pair of the RoPE suffix in f32 and writes
back in the original dtype.

For prefill (L=512) this eliminates the intermediate nope/pe split tensors and
concatenation across 183 calls per forward pass, reducing bandwidth from ~6
passes over the head tensor to 2 (read + write). For decode (L=1) the shapes
are tiny so the benefit is lower dispatch overhead.

Verified: float32 inputs agree with reference to <5e-7; bfloat16 round-trip
(forward then inverse) recovers input to within 2 ULPs; bf16 vs reference
diffs are purely arithmetic-order (f32 intermediates vs bf16 intermediates).
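For reference, a pure-MLX version of the split-rotate-concat path this kernel replaces; the trailing-suffix layout and the interleaved (traditional) rotation are assumptions based on the description above, not a copy of the mlx_lm code:

```python
import mlx.core as mx

def apply_partial_rope(x: mx.array, rope_dims: int, offset: int, base: float = 10000.0):
    # x: (B, H, L, D). Only the trailing rope_dims features are rotated; the
    # "nope" prefix passes through. The three intermediates (nope, pe, concat)
    # are what the fused kernel collapses into a single read + write.
    nope, pe = x[..., :-rope_dims], x[..., -rope_dims:]
    pe = mx.fast.rope(pe, rope_dims, traditional=True, base=base, scale=1.0, offset=offset)
    return mx.concatenate([nope, pe], axis=-1)

y = apply_partial_rope(mx.random.normal((1, 8, 16, 128)), rope_dims=64, offset=0)
print(y.shape)  # (1, 8, 16, 128)
```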
- Q head RMS norm: fuse 5-op sequence into a single Metal kernel using
  32-thread SIMD group reduction (simd_sum), avoiding 3 intermediate
  allocations per forward pass

- _overlap_transform: replace mx.full(large)+scatter with a small
  mx.full((B,1,R,head_dim))+concatenate; eliminates one large
  allocation and two masked writes

- HC fn weights: store fn as bfloat16 instead of float32, halving
  weight-read bandwidth for the 183 MB/token HC GEMV; norm kept in
  float32 for correctness; narrowed cast_predicate to only exclude
  base/scale, not fn
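A toy version of the _overlap_transform swap from the second bullet, with assumed shape names and zero padding standing in for whatever the real pad values are:

```python
import mlx.core as mx

B, R, L, head_dim = 1, 32, 512, 128
x = mx.random.normal((B, 1, L, head_dim)).astype(mx.bfloat16)

# Before: allocate the full output, then scatter-write the payload into it.
out_scatter = mx.full((B, 1, R + L, head_dim), 0.0, dtype=x.dtype)
out_scatter[:, :, R:, :] = x

# After: allocate only the small pad and concatenate with the existing tensor.
pad = mx.full((B, 1, R, head_dim), 0.0, dtype=x.dtype)
out_concat = mx.concatenate([pad, x], axis=2)

print(mx.array_equal(out_scatter, out_concat))
```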
Buffer the raw hidden state x instead of the projected kv/gate in the
DeepseekV4Cache compressor and indexer state dicts. Compressor.__call__
now calls wkv/wgate only once per compression window (every ratio decode
steps) rather than on every step, saving (ratio-1)/ratio of those GEMV reads:

  ratio=4  layers (~21): saves 3/4 × 2 × 4096×1024 bf16 per step ≈ 263 MB
  ratio=128 layers (~21): saves ~99% × 2 × 4096×512 bf16 per step ≈ 131 MB

Total ~394 MB/step of weight-read traffic eliminated, ≈0.5 ms at 800 GB/s.
Math is identical to the old path because wkv/wgate are linear — batching
four single-token projections gives the same result as projecting each
token individually and concatenating.

Verified: max_diff=0.0 vs old kv/gate buffering path across 8 decode steps.
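The linearity claim is easy to sanity-check in isolation (toy shapes and a stand-in projection, not the model's wkv):

```python
import mlx.core as mx

wkv = mx.random.normal((1024, 4096)).astype(mx.bfloat16)              # stand-in projection
xs = [mx.random.normal((1, 4096)).astype(mx.bfloat16) for _ in range(4)]

per_step = mx.concatenate([x @ wkv.T for x in xs], axis=0)            # old: project every step
batched = mx.concatenate(xs, axis=0) @ wkv.T                          # new: buffer x, project once

print(mx.allclose(per_step, batched))
```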
1. HC collapse Metal kernel (_make_hc_collapse_kernel):
   Fuses (pre[..., None] * x.astype(f32)).sum(axis=2).astype(dtype) into a
   single 32-thread-per-(b,l) pass, eliminating ~3 intermediate tensor
   allocations. Used in HyperConnection.collapse and HyperHead (87 calls/step).

2. HC expand Metal kernel (_make_hc_expand_kernel):
   Fuses post*block_out + matmul(comb, residual) into one pass, eliminating
   the 64 KB f32 matmul intermediate MLX must materialize before the add.
   32 threads per (b, l, h) triple (86 calls/step).

   Together the two kernels reduce the HC portion of the compute graph from
   ~774 MLX ops to ~172, saving ~600 graph nodes of Python overhead per step.

3. MoEGate bf16 GEMV:
   Gate weight is stored as bf16 via cast_predicate but was being upcast to
   f32 before the matmul, forcing a 2 MB f32 materialization per layer.
   Change to (flat @ weight.T).astype(float32) → saves ~86 MB/step.
   Relative error vs f32 GEMV: 0.9% at trained weight scale, negligible for
   top-k routing.

Total expected savings: ~0.4-0.7 ms/token (graph overhead + bandwidth).
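For reference, the two pure-MLX expressions mentioned in items 1 and 3 above, written out with assumed shape names (the collapse expression is the one the kernel fuses; the gate GEMV stays in bf16 and only the small score tensor is upcast):

```python
import mlx.core as mx

B, L, HC, D, n_experts = 1, 1, 4, 4096, 256
pre = mx.random.normal((B, L, HC)).astype(mx.float32)
x = mx.random.normal((B, L, HC, D)).astype(mx.bfloat16)

# 1. HC collapse, as fused by _make_hc_collapse_kernel:
collapsed = (pre[..., None] * x.astype(mx.float32)).sum(axis=2).astype(x.dtype)

# 3. MoE gate scoring without the f32 weight upcast:
gate_w = mx.random.normal((n_experts, D)).astype(mx.bfloat16)
flat = mx.random.normal((B * L, D)).astype(mx.bfloat16)
scores = (flat @ gate_w.T).astype(mx.float32)

print(collapsed.shape, scores.shape)   # (1, 1, 4096) (1, 256)
```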
@Blaizzy Blaizzy (Owner) left a comment

Thanks a lot! Left a few comments

…n bf16, cast_predicate, tokenizer fallback; remove FP4 comment

- Revert n_rows bounds guard + grid rounding in hc_split_sinkhorn
- Remove fused sparse attn Metal kernel; inline gather + _split_sparse_attention in V4Attention
- Revert HC fn bfloat16 in HyperConnection and HyperHead (no measurable difference)
- Revert cast_predicate: remove .attn_hc/.ffn_hc/.hc_head base/scale exclusions
- Remove FP4 query simulation comment from Indexer
- Revert tokenizer AttributeError fallback: fixed upstream in transformers#45643
Per review feedback: remove loops, maximize Metal bandwidth.

- Replace 4 individual element casts per row with single float4 vector
  loads from mix/base (both float32, 16-byte aligned)
- Unroll the for(i<4) comb-row init loop: 4 explicit float4
  load/mul/exp/recip statements the compiler can schedule in parallel
- Unroll the for(i<4) Sinkhorn row-norm inner loop: 4 explicit
  multiply statements per iteration instead of a serial loop

Verified: max diff vs Python fallback < 6e-8 (float32 rounding only).
All 4 DeepSeek V4 unit tests pass.
Comment on lines 1644 to 1646
@@ -1645,9 +1696,6 @@
return not (
"attn_sink" in k
Owner:

revert this deletion

Author:

Gather indexed pooled KV, flatten into full_kv, and use a single
scaled_dot_product_attention call (Flash Attention backend) instead of
the two-pass _split_sparse_attention + log-sum-exp merge.

Matches the maintainer's approach: one SDPA call over [local_kv | pooled]
instead of separate local + sparse attention computations.
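A shape sketch of that single SDPA call using mx.fast.scaled_dot_product_attention; tensor names and lengths are illustrative, and the top-k gather and masking are omitted:

```python
import mlx.core as mx

B, H, L, D = 1, 8, 512, 128
local_k = mx.random.normal((B, H, L, D))
local_v = mx.random.normal((B, H, L, D))
pooled_k = mx.random.normal((B, H, 256, D))   # already gathered by index
pooled_v = mx.random.normal((B, H, 256, D))

# One Flash-Attention-backed call over [local | pooled] instead of two
# attention passes merged with log-sum-exp.
full_k = mx.concatenate([local_k, pooled_k], axis=2)
full_v = mx.concatenate([local_v, pooled_v], axis=2)
q = mx.random.normal((B, H, L, D))
out = mx.fast.scaled_dot_product_attention(q, full_k, full_v, scale=D ** -0.5)
print(out.shape)   # (1, 8, 512, 128)
```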
Restores maintainer's V4Attention compress path (select_all optimization,
mask trimming, pooled_mask/pooled_bias handling), SwitchGLU score
multiplication inside the layer, and HyperConnection float32 matmul.
Retains approved changes: vectorized sinkhorn kernel, overlap_transform
concatenate, and q-norm Metal kernel.
Comment thread README.md
Owner:

Let’s remove the readme changes for now

Neither the q-norm kernel nor _split_sparse_attention was wired up or
called: the q-norm kernel was replaced by mx.fast.rms_norm, and
_split_sparse_attention was superseded by the SDPA path.
- Extract non-hash MoE gate scoring into @mx.compile _expert_select,
  fusing sqrt-softplus + bias + argpartition + gather + normalize + scale
  into a single cached graph (matches deepseek_v32.py pattern)
- Guard compress path with pooled.shape[1] > 0 to skip indexer checks,
  mask handling, and empty concat on most decode steps
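A hedged sketch of the compiled select pattern; the real scoring in mlx_lm's deepseek models differs in detail, this only shows score → biased top-k → gather → normalize living in one @mx.compile graph:

```python
import mlx.core as mx

@mx.compile
def _expert_select(logits: mx.array, bias: mx.array, top_k: int = 8):
    # logits: (tokens, n_experts); bias: (n_experts,) selection-only bias.
    scores = mx.sqrt(mx.logaddexp(logits, 0.0))               # sqrt-softplus scoring
    kth = logits.shape[-1] - top_k
    idx = mx.argpartition(scores + bias, kth=kth, axis=-1)[..., -top_k:]
    picked = mx.take_along_axis(scores, idx, axis=-1)
    return idx, picked / picked.sum(axis=-1, keepdims=True)   # normalized expert weights

idx, w = _expert_select(mx.random.normal((4, 256)), mx.zeros((256,)))
print(idx.shape, w.shape)   # (4, 8) (4, 8)
```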
machiabeli added a commit to machiabeli/mlx-lm-1 that referenced this pull request Apr 26, 2026
… wo_a

Three integrations from sibling PRs and shared community work:

1. Fused partial-RoPE Metal kernel — adapted from @0xClandestine's
   optimization PR (Blaizzy#13). Collapses the scalar-Python
   rotation chain (~5 graph ops per rope call) into a single Metal
   kernel per (b, h, l) work item, with one SIMD-group lane per
   interleaved pair. ~600 fewer dispatches per token on the decode
   path at L=1, where DeepseekV4 invokes rope ~3x per attention
   layer (q_pe, k_pe, inverse on attention output).

   Env-var escape hatch (MLX_LM_DISABLE_PARTIAL_ROPE_KERNEL=1) lets
   benchmarks A/B kernel ON vs OFF without monkey-patching.

2. Separate self.rope and self.compress_rope instances — main Q/K
   always rotate with rope_theta; compressed-pool RoPE uses
   compress_rope_theta. Same intent as @Blaizzy's b78ccb1 fix on
   ml-explore#1192 ported to HEAD's 3-arg DeepseekV4RoPE signature. Fixes the
   periodic CJK token drops reported by @Shinka-Man on ml-explore#1192.

3. Batched grouped quantized wo_a — adapted from @Blaizzy's
   pc/add-deepseekv4flash-model branch. Replaces the 8-dispatch
   per-group Python loop in _grouped_output_projection with a
   single mx.quantized_matmul call by treating the group dim as a
   broadcast batch dim. Same numerical result, fewer dispatches.

Co-authored-by: clandestine.eth <[email protected]>
Co-authored-by: Prince Canuma <[email protected]>
Reported-by: Shinkaman <[email protected]>
@0xClandestine 0xClandestine force-pushed the perf/optimize-ds4 branch 2 times, most recently from 6e6fea7 to c55bc53 on April 26, 2026 at 04:17
Add @mx.compile _hyper_head_op that fuses RMS-rsqrt + matmul + sigmoid
+ weighted sum for the output HyperHead (1 call per token). No weight
layout changes — inference only, training path unchanged.
Pre-compute and cache per-layer tensors that were redundantly
recomputed every decode step (43 layers x every token):
- wo_a weight/scales/biases reshapes for grouped output projection
- attn_sink and q_norm_weight dtype casts
- MoE gate weight transpose + float32 cast
Profiling shows HyperConnection ops are ~52% of decode time (86
HC cycles × 4-5 kernel dispatches each = ~430 dispatches/token).

New _hc_mixes compiled function fuses rms_rsqrt + matmul + scale
into one dispatch, and caches fn.T to avoid repeated transpose.
Reduces ~2 dispatches per HC cycle = ~172 fewer dispatches/token.
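A sketch of what such a compiled mix helper could look like, with hypothetical names; the real _hc_mixes operates on the HyperConnection fn weights, and caching fn.T just means the transpose happens once outside the hot loop:

```python
import mlx.core as mx

@mx.compile
def _hc_mixes(x: mx.array, fn_t: mx.array, scale: mx.array, eps: float = 1e-6):
    # x: (B, L, D) hidden state; fn_t: (D, out) pre-transposed mixing weight.
    # rms-rsqrt + matmul + scale trace into one cached graph, so each HC cycle
    # pays one compiled call instead of several separate ops.
    inv_rms = mx.rsqrt(mx.mean(x.astype(mx.float32) ** 2, axis=-1, keepdims=True) + eps)
    return (x * inv_rms.astype(x.dtype)) @ fn_t * scale

fn_t = mx.random.normal((64, 16))          # cached once: fn.T
out = _hc_mixes(mx.random.normal((1, 1, 64)), fn_t, mx.array(1.0))
print(out.shape)                           # (1, 1, 16)
```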
Eliminates one kernel dispatch per HC cycle by computing the sinkhorn
normalization and collapse weighted sum in the same Metal kernel.

Key optimizations:
- Parallel sinkhorn: threads 0-3 each own one comb row, column
  normalization via simd_sum (free SIMD shuffle, no shared memory)
- Vectorized collapse: float4 loads process 4 bfloat16 elements at once
- Thread layout: first simd group handles sinkhorn in parallel,
  all 256 threads collaborate on the collapse reduction over D
Replace dot(e, float4(1)) with explicit e.x+e.y+e.z+e.w to avoid
unnecessary fmul. Restructure softmax to compute all 4 row maxes
before any exp for better instruction-level parallelism. Apply same
pattern to both standalone and fused sinkhorn kernels.
- Replace divergent if(lane<HC) branches with multiplicative active mask
  so all 32 SIMD lanes execute identical instructions (no serialization)
- Use native bfloat4 vector loads (single 64-bit load per 4 elements)
  instead of 4 scalar bfloat16 loads + manual float4 construction
- FMA chains in collapse: fma(p0,x0, fma(p1,x1, fma(p2,x2, p3*x3)))
- Compile-time #if (D%4)!=0 eliminates dead scalar tail code
@Blaizzy Blaizzy (Owner) left a comment

LGTM! 🚀

@Blaizzy Blaizzy merged commit 910b120 into Blaizzy:pc/add-deepseekv4flash-model on Apr 26, 2026
2 checks passed