perf: optimize DeepSeek-V4 #13
Merged
Blaizzy merged 43 commits into Blaizzy:pc/add-deepseekv4flash-model on Apr 26, 2026
Conversation
2abd684 to 9b9f012
- Add an n_rows bounds guard to prevent out-of-bounds writes when the grid is not a multiple of the threadgroup size (256)
- #if HC == 4 fast path: float4-vectorized pre/post sigmoid, tree-reduce max for the comb softmax, dot-product row sums, metal::fast::recip/exp throughout, aligned 128-bit stores for all outputs
- Sinkhorn passes use float4 row vectors — column sums reduce to 4 vector additions, row sums to a single dot(row, 1) instruction
- Scalar fallback (#else): replace 1/x with metal::fast::recip and a*b+c with fma for all mix/base linear combinations
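For reference, a pure-MLX sketch of the Sinkhorn normalization the kernel vectorizes (the iteration count and the exact row/column ordering here are assumptions for illustration, not the kernel's literal schedule):

```python
import mlx.core as mx

def sinkhorn_reference(comb_logits: mx.array, n_iters: int = 4) -> mx.array:
    # Start from a softmax over the 4x4 combination logits, then alternately
    # rescale columns and rows so the matrix approaches doubly stochastic.
    # In the Metal kernel each row lives in a float4 register, so a column sum
    # is 4 vector adds and a row sum is a single dot(row, 1).
    m = mx.softmax(comb_logits, axis=-1)
    for _ in range(n_iters):
        m = m / m.sum(axis=-2, keepdims=True)  # column normalization
        m = m / m.sum(axis=-1, keepdims=True)  # row normalization
    return m
```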
MLX injects template params as constexpr values, not #define macros. The preprocessor sees HC as undefined (it expands to 0), so #if HC == 4 always evaluated to false and the scalar fallback ran unconditionally, silently defeating all the float4 optimizations. Drop the #if/#else/#endif entirely, since HC is always 4 for this model.
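A minimal sketch of the difference using MLX's `mx.fast.metal_kernel` template mechanism (the kernel name and body are illustrative, not the PR's kernel): the template value is a compile-time constant the Metal compiler can fold through an ordinary `if`, but the preprocessor never sees it, so `#if HC == 4` silently takes the wrong branch.

```python
import mlx.core as mx

source = """
    uint i = thread_position_in_grid.x;
    // HC arrives as a template parameter (a compile-time constant), NOT a
    // #define. `#if HC == 4` would see HC as undefined (== 0) and silently
    // drop this branch; a plain `if` is constant-folded and works.
    if (HC == 4) {
        out[i] = x[i] * 4.0f;
    } else {
        out[i] = x[i];
    }
"""

kernel = mx.fast.metal_kernel(
    name="hc_demo", input_names=["x"], output_names=["out"], source=source
)

x = mx.ones((256,), dtype=mx.float32)
(out,) = kernel(
    inputs=[x],
    template=[("HC", 4)],
    grid=(x.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[x.shape],
    output_dtypes=[mx.float32],
)
print(out[0])  # 4.0 -- the HC == 4 branch ran
```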
- Fix initial softmax eps: remove the spurious eps from the recip denominator to match the Python reference (mx.softmax) and the old kernel behavior
- Guard the Metal kernel against hc_mult != 4 by falling back to scalar ops
- Add per-query split sparse attention for prefill (L > 1): each query attends only to its own top-k compressed keys instead of all L*topk flattened keys, reducing attention scores from O(L*(T+L*topk)) to O(L*(T+topk)) and fixing cross-position key leakage
- Generation (L == 1) keeps the SDPA path for Flash Attention speed
- Document the omitted Hadamard rotation and FP4 query simulation gaps
- Stack wo_a.0..7 into a single wo_a by concatenating along the output dim
- Add embed/head scales/biases to top_remap for 4-bit quant models
- Include the biases suffix in the expert stacking loop (alongside weight/scales)
- Handle already-stacked expert weights (experts.w1.weight → switch_mlp)
- Drop expert .biases keys (QuantizedSwitchLinear doesn't use them)
- Remove hardcoded per-layer expert quant defaults from __post_init__ so the checkpoint's actual group_size flows through to nn.quantize
AutoTokenizer.from_pretrained fails for unknown model types (e.g. deepseek_v4) when config.json has rope_scaling — transformers' config standardization accesses max_position_embeddings on a generic PretrainedConfig. Fall back to PreTrainedTokenizerFast directly.
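A minimal sketch of the fallback described above (the helper name is illustrative):

```python
from transformers import AutoTokenizer, PreTrainedTokenizerFast

def load_tokenizer_with_fallback(model_path: str):
    # AutoTokenizer's config standardization can raise AttributeError for an
    # unknown model type whose config.json carries rope_scaling; fall back to
    # loading the fast tokenizer directly from tokenizer.json in that case.
    try:
        return AutoTokenizer.from_pretrained(model_path)
    except AttributeError:
        return PreTrainedTokenizerFast.from_pretrained(model_path)
```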
Single-pass online softmax over the local window + index-gathered compressed KV. Eliminates ~1 GB of intermediates per sparse layer during prefill by avoiding the broadcast+gather of [B,L,topk,D] and all score/exp tensors.
- 256 threads (4 per head × 64 heads), SIMD-shuffle dot-product reduction
- Masked local KV with early skip, sparse KV gathered by index on the fly
- FlashAttention-style running max/sum, attention sink in the denominator
- Falls back to pure MLX split attention on CPU / non-Metal
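For intuition, a pure-MLX sketch of the FlashAttention-style running max/sum accumulation the kernel performs per query (block layout, shapes, and the sink handling are simplified assumptions, not the kernel's exact structure):

```python
import mlx.core as mx

def online_softmax_attention(q, kv_blocks, scale, sink=None):
    # q: [D]; kv_blocks: iterable of (k, v) pairs with k: [n, D], v: [n, Dv].
    # Keep a running max m, running denominator s, and running weighted sum
    # acc, rescaling previous partials whenever the max increases -- no full
    # score matrix or exp tensor is ever materialized.
    m = mx.array(-float("inf"))
    s = mx.array(0.0)
    acc = None
    for k, v in kv_blocks:
        scores = (k @ q) * scale
        new_m = mx.maximum(m, scores.max())
        corr = mx.exp(m - new_m)          # rescale previous partial sums
        p = mx.exp(scores - new_m)
        s = s * corr + p.sum()
        acc = p @ v if acc is None else acc * corr + p @ v
        m = new_m
    if sink is not None:
        s = s + mx.exp(sink - m)          # attention sink enters the denominator only
    return acc / s
```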
- Use the constant address space for the topk_idxs pointer (matches the MLX int input)
- Use remove_reference_t<decltype(*op)> for the output cast (decltype(*op) resolves to a reference type, which can't be the target of a functional cast)
… pointer to device
…put write
Three correctness bugs:
1. Address space: all pointer declarations in the kernel used `const device auto*`, which fails when MLX places small tensors in `constant` space. Changed all to `auto` so the compiler deduces the correct address space from context.
2. Grid/gid: the MLX grid is the total thread count, not the number of threadgroups. Was dispatching grid=(B*L) with threadgroup=(H*4), giving only 1 thread per (b,l) pair instead of H*4. Fixed: grid=(B*L*H*4) so each (b,l) gets its own threadgroup of H*4 threads. Changed gid from thread_position_in_grid to threadgroup_position_in_grid to get the (b,l) index.
3. Output write: static_cast<remove_reference_t<decltype(*op)>>() silently fails for address-space-qualified types in Metal. Added a store_elem header helper that deduces the base element type T via `device T&` to strip the device qualifier, then assigns T(v) for correct float→dtype conversion.
Verified: float32 max diff < 3e-7, bfloat16 max diff < 3e-3 vs the CPU reference.
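A self-contained illustration of bug 2 (the kernel here is a toy, not the PR's attention kernel): in MLX's custom-kernel API the `grid` argument is the total thread count, so a layout of one threadgroup of TG threads per row requires `grid = rows * TG`.

```python
import mlx.core as mx

source = """
    uint row  = threadgroup_position_in_grid.x;   // which row this group owns
    uint lane = thread_position_in_grid.x % 32;   // lane within the 32-thread group
    if (lane == 0) {
        out[row] = x[row] * 2.0f;
    }
"""

kernel = mx.fast.metal_kernel(
    name="grid_demo", input_names=["x"], output_names=["out"], source=source
)

x = mx.arange(8, dtype=mx.float32)
TG = 32
(out,) = kernel(
    inputs=[x],
    grid=(x.shape[0] * TG, 1, 1),   # total threads, NOT the threadgroup count
    threadgroup=(TG, 1, 1),          # one 32-thread group per row
    output_shapes=[x.shape],
    output_dtypes=[mx.float32],
)
print(out)  # [0, 2, 4, ..., 14]
```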
generate_step does prompt[None], so a (1, L) server input becomes (1, 1, L). embed_tokens then produces (1, 1, L, D), causing h.shape[2] in the hc_mult broadcast target to read L instead of D, crashing with shape mismatch. Flatten inputs to (B, L) at the top of __call__ before embedding.
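A minimal sketch of the fix (the helper name is illustrative):

```python
import mlx.core as mx

def flatten_to_batch_tokens(inputs: mx.array) -> mx.array:
    # generate_step adds a leading axis, so a (1, L) server prompt arrives as
    # (1, 1, L); collapse every leading dim into the batch so embed_tokens
    # always sees (B, L) and the downstream shape[2] reads D, not L.
    return inputs.reshape(-1, inputs.shape[-1])
```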
…ctly
During decode (L=1), C (the number of compressed tokens accumulated so far) is always <= index_topk=512, so the indexer would select all tokens anyway. Running it wastes matmuls + argpartition, and the subsequent expand → take_along_axis → concat path creates intermediate tensors versus simply using pooled[:, None]. Restrict the indexer call to prefill (L > 1) only. Decode now takes the pooled[:, None] path directly, matching SDPA's efficient single-query path.
…diates
Replaces the split-rotate-concat path in _apply_partial_rope with a single Metal kernel dispatch. Each threadgroup (one SIMD group of 32 threads) handles one (batch, head, token) triplet: threads stride-copy the nope prefix, then each lane rotates one interleaved pair of the RoPE suffix in f32 and writes back in the original dtype.
For prefill (L=512) this eliminates the intermediate nope/pe split tensors and the concatenation across 183 calls per forward pass, reducing bandwidth from ~6 passes over the head tensor to 2 (read + write). For decode (L=1) the shapes are tiny, so the benefit is lower dispatch overhead.
Verified: float32 inputs agree with the reference to < 5e-7; the bfloat16 round-trip (forward then inverse) recovers the input to within 2 ULPs; bf16-vs-reference diffs are purely arithmetic-order (f32 intermediates vs bf16 intermediates).
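For context, a pure-MLX sketch of the split-rotate-concat path the kernel replaces (the function signature and interleaved-pair convention are assumptions based on the description above, not the file's exact code):

```python
import mlx.core as mx

def partial_rope_reference(x, cos, sin, rope_dims):
    # x: [..., head_dim]; only the trailing `rope_dims` are rotated as
    # interleaved (even, odd) pairs, the leading "nope" prefix passes through.
    # cos/sin are assumed broadcastable to [..., rope_dims // 2].
    nope, pe = x[..., :-rope_dims], x[..., -rope_dims:]
    pairs = pe.reshape(*pe.shape[:-1], rope_dims // 2, 2)
    even, odd = pairs[..., 0], pairs[..., 1]
    rot_even = even * cos - odd * sin
    rot_odd = even * sin + odd * cos
    rotated = mx.stack([rot_even, rot_odd], axis=-1).reshape(pe.shape)
    # The split tensors, the stack, and the final concat are exactly the
    # intermediates the fused kernel avoids.
    return mx.concatenate([nope, rotated], axis=-1)
```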
- Q head RMS norm: fuse the 5-op sequence into a single Metal kernel using a 32-thread SIMD group reduction (simd_sum), avoiding 3 intermediate allocations per forward pass
- _overlap_transform: replace mx.full(large) + scatter with a small mx.full((B,1,R,head_dim)) + concatenate; eliminates one large allocation and two masked writes
- HC fn weights: store fn as bfloat16 instead of float32, halving weight-read bandwidth for the 183 MB/token HC GEMV; the norm is kept in float32 for correctness; narrowed cast_predicate to only exclude base/scale, not fn
Buffer the raw hidden state x instead of the projected kv/gate in the DeepseekV4Cache compressor and indexer state dicts. Compressor.__call__ now calls wkv/wgate only once per ratio decode steps rather than every step, saving (ratio-1)/ratio of those GEMV reads:
ratio=4 layers (~21): saves 3/4 × 2 × 4096×1024 bf16 per step ≈ 263 MB
ratio=128 layers (~21): saves ~99% × 2 × 4096×512 bf16 per step ≈ 131 MB
Total: ~394 MB/step of weight-read traffic eliminated, ≈ 0.5 ms at 800 GB/s.
The math is identical to the old path because wkv/wgate are linear — batching four single-token projections gives the same result as projecting each token individually and concatenating (see the sketch below).
Verified: max_diff = 0.0 vs the old kv/gate buffering path across 8 decode steps.
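The linearity claim can be checked directly; a small sketch with illustrative shapes and a random stand-in weight (not the model's actual wkv/wgate):

```python
import mlx.core as mx

D, H, ratio = 4096, 1024, 4
w = mx.random.normal((H, D))                        # stand-in for wkv / wgate
tokens = [mx.random.normal((1, D)) for _ in range(ratio)]

# Project the buffered window in one GEMM ...
batched = mx.concatenate(tokens, axis=0) @ w.T
# ... versus projecting each token as it arrived and concatenating.
per_step = mx.concatenate([t @ w.T for t in tokens], axis=0)

print(mx.abs(batched - per_step).max())             # 0 up to float rounding
```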
1. HC collapse Metal kernel (_make_hc_collapse_kernel): fuses (pre[..., None] * x.astype(f32)).sum(axis=2).astype(dtype) into a single 32-thread-per-(b,l) pass, eliminating ~3 intermediate tensor allocations. Used in HyperConnection.collapse and HyperHead (87 calls/step).
2. HC expand Metal kernel (_make_hc_expand_kernel): fuses post*block_out + matmul(comb, residual) into one pass, eliminating the 64 KB f32 matmul intermediate MLX must materialize before the add. 32 threads per (b, l, h) triple (86 calls/step).
Together the two kernels reduce the HC portion of the compute graph from ~774 MLX ops to ~172, saving ~600 graph nodes of Python overhead per step.
3. MoE gate bf16 GEMV: the gate weight is stored as bf16 via cast_predicate but was being upcast to f32 before the matmul, forcing a 2 MB f32 materialization per layer. Change to (flat @ weight.T).astype(float32) → saves ~86 MB/step. Relative error vs the f32 GEMV: 0.9% at trained weight scale, negligible for top-k routing.
Total expected savings: ~0.4-0.7 ms/token (graph overhead + bandwidth).
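The collapse expression the first kernel fuses, as a pure-MLX reference (taken from the description above; the function name and shape comments are illustrative assumptions):

```python
import mlx.core as mx

def hc_collapse_reference(pre: mx.array, x: mx.array) -> mx.array:
    # Assumed shapes: pre [B, L, HC] mixing weights, x [B, L, HC, D] streams.
    # Weighted sum over the HC axis in float32, cast back to the input dtype.
    # In plain MLX this materializes the broadcast product, the f32 copy of x,
    # and the f32 sum before the final cast; the kernel does it in one pass.
    return (pre[..., None] * x.astype(mx.float32)).sum(axis=2).astype(x.dtype)
```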
…E gate bf16"
This reverts commit a8a2925.
…it__ method accordingly.
…_attention functions (30 -> 32 tok/s)
Blaizzy reviewed Apr 25, 2026
Blaizzy (Owner) left a comment:
Thanks a lot! Left a few comments
…n bf16, cast_predicate, tokenizer fallback; remove FP4 comment
- Revert the n_rows bounds guard + grid rounding in hc_split_sinkhorn
- Remove the fused sparse attn Metal kernel; inline gather + _split_sparse_attention in V4Attention
- Revert HC fn bfloat16 in HyperConnection and HyperHead (no measurable difference)
- Revert cast_predicate: remove the .attn_hc/.ffn_hc/.hc_head base/scale exclusions
- Remove the FP4 query simulation comment from Indexer
- Revert the tokenizer AttributeError fallback: fixed upstream in transformers#45643
Per review feedback: remove loops, maximize Metal bandwidth.
- Replace 4 individual element casts per row with single float4 vector loads from mix/base (both float32, 16-byte aligned)
- Unroll the for(i<4) comb-row init loop: 4 explicit float4 load/mul/exp/recip statements the compiler can schedule in parallel
- Unroll the for(i<4) Sinkhorn row-norm inner loop: 4 explicit multiply statements per iteration instead of a serial loop
Verified: max diff vs the Python fallback < 6e-8 (float32 rounding only). All 4 DeepSeek V4 unit tests pass.
Blaizzy reviewed Apr 26, 2026
Comment on lines 1644 to 1646
@@ -1645,9 +1696,6 @@
    return not (
        "attn_sink" in k
Gather indexed pooled KV, flatten into full_kv, and use a single scaled_dot_product_attention call (Flash Attention backend) instead of the two-pass _split_sparse_attention + log-sum-exp merge. Matches the maintainer's approach: one SDPA call over [local_kv | pooled] instead of separate local + sparse attention computations.
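A decode-shaped sketch of that path, assuming mx.take_along_axis gather semantics and illustrative [B, H, seq, D] layouts; this is not the file's exact code:

```python
import mlx.core as mx

def decode_sparse_sdpa(q, local_k, local_v, pooled_k, pooled_v, topk_idx, scale):
    # Assumed shapes: q [B, H, 1, D]; local_k/v [B, H, T_local, D];
    # pooled_k/v [B, H, C, D]; topk_idx [B, topk] indices into the compressed axis C.
    idx = topk_idx[:, None, :, None]                   # [B, 1, topk, 1]
    sel_k = mx.take_along_axis(pooled_k, idx, axis=2)  # [B, H, topk, D]
    sel_v = mx.take_along_axis(pooled_v, idx, axis=2)
    full_k = mx.concatenate([local_k, sel_k], axis=2)  # [local_kv | pooled top-k]
    full_v = mx.concatenate([local_v, sel_v], axis=2)
    # One Flash-Attention SDPA call replaces the two-pass local + sparse
    # attention and its log-sum-exp merge.
    return mx.fast.scaled_dot_product_attention(q, full_k, full_v, scale=scale)
```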
Restores maintainer's V4Attention compress path (select_all optimization, mask trimming, pooled_mask/pooled_bias handling), SwitchGLU score multiplication inside the layer, and HyperConnection float32 matmul. Retains approved changes: vectorized sinkhorn kernel, overlap_transform concatenate, and q-norm Metal kernel.
Blaizzy reviewed Apr 26, 2026
Owner left a comment:
Let’s remove the readme changes for now
Neither was wired up or called. q-norm kernel replaced by mx.fast.rms_norm; _split_sparse_attention superseded by SDPA path.
- Extract non-hash MoE gate scoring into an @mx.compile'd _expert_select, fusing sqrt-softplus + bias + argpartition + gather + normalize + scale into a single cached graph (matches the deepseek_v32.py pattern)
- Guard the compress path with pooled.shape[1] > 0 to skip indexer checks, mask handling, and the empty concat on most decode steps
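A hedged sketch of what such a compiled selection graph can look like (the function name, argument layout, and the explicit softplus are assumptions following the deepseek_v32.py-style pattern, not the PR's code):

```python
from functools import partial

import mlx.core as mx

@partial(mx.compile, shapeless=True)
def _expert_select_sketch(logits, bias, top_k, routed_scaling_factor):
    # sqrt-softplus scoring (softplus written out explicitly), bias applied only
    # for selection, argpartition for top-k, gather the unbiased scores,
    # normalize, scale -- traced once and replayed as a single cached graph.
    scores = mx.sqrt(mx.log1p(mx.exp(logits)))
    inds = mx.argpartition(-(scores + bias), kth=top_k - 1, axis=-1)[..., :top_k]
    picked = mx.take_along_axis(scores, inds, axis=-1)
    picked = picked / picked.sum(axis=-1, keepdims=True)
    return inds, picked * routed_scaling_factor
```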
machiabeli added a commit to machiabeli/mlx-lm-1 that referenced this pull request on Apr 26, 2026
… wo_a
Three integrations from sibling PRs and shared community work:
1. Fused partial-RoPE Metal kernel — adapted from @0xClandestine's optimization PR (Blaizzy#13). Collapses the scalar-Python rotation chain (~5 graph ops per rope call) into a single Metal kernel per (b, h, l) work item, with one SIMD-group lane per interleaved pair. ~600 fewer dispatches per token on the decode path at L=1, where DeepseekV4 invokes rope ~3x per attention layer (q_pe, k_pe, inverse on attention output). An env-var escape hatch (MLX_LM_DISABLE_PARTIAL_ROPE_KERNEL=1) lets benchmarks A/B kernel ON vs OFF without monkey-patching.
2. Separate self.rope and self.compress_rope instances — main Q/K always rotate with rope_theta; compressed-pool RoPE uses compress_rope_theta. Same intent as @Blaizzy's b78ccb1 fix on ml-explore#1192, ported to HEAD's 3-arg DeepseekV4RoPE signature. Fixes the periodic CJK token drops reported by @Shinka-Man on ml-explore#1192.
3. Batched grouped quantized wo_a — adapted from @Blaizzy's pc/add-deepseekv4flash-model branch. Replaces the 8-dispatch per-group Python loop in _grouped_output_projection with a single mx.quantized_matmul call by treating the group dim as a broadcast batch dim. Same numerical result, fewer dispatches.
Co-authored-by: clandestine.eth <[email protected]>
Co-authored-by: Prince Canuma <[email protected]>
Reported-by: Shinkaman <[email protected]>
6e6fea7 to c55bc53
Add @mx.compile _hyper_head_op that fuses RMS-rsqrt + matmul + sigmoid + weighted sum for the output HyperHead (1 call per token). No weight layout changes — inference only, training path unchanged.
Pre-compute and cache per-layer tensors that were redundantly recomputed every decode step (43 layers × every token):
- wo_a weight/scales/biases reshapes for the grouped output projection
- attn_sink and q_norm_weight dtype casts
- MoE gate weight transpose + float32 cast
9074e95 to ea8ea24
Profiling shows HyperConnection ops are ~52% of decode time (86 HC cycles × 4-5 kernel dispatches each = ~430 dispatches/token). The new _hc_mixes compiled function fuses rms_rsqrt + matmul + scale into one dispatch and caches fn.T to avoid a repeated transpose. Net effect: ~2 fewer dispatches per HC cycle = ~172 fewer dispatches/token.
Eliminates one kernel dispatch per HC cycle by computing the sinkhorn normalization and the collapse weighted sum in the same Metal kernel. Key optimizations:
- Parallel sinkhorn: threads 0-3 each own one comb row, with column normalization via simd_sum (a free SIMD shuffle, no shared memory)
- Vectorized collapse: float4 loads process 4 bfloat16 elements at once
- Thread layout: the first SIMD group handles sinkhorn in parallel, then all 256 threads collaborate on the collapse reduction over D
5366d6d to 9929953
Replace dot(e, float4(1)) with explicit e.x+e.y+e.z+e.w to avoid unnecessary fmul. Restructure softmax to compute all 4 row maxes before any exp for better instruction-level parallelism. Apply same pattern to both standalone and fused sinkhorn kernels.
9d86f5b to 9a69e7b
- Replace divergent if(lane<HC) branches with a multiplicative active mask so all 32 SIMD lanes execute identical instructions (no serialization)
- Use native bfloat4 vector loads (a single 64-bit load per 4 elements) instead of 4 scalar bfloat16 loads + manual float4 construction
- FMA chains in the collapse: fma(p0,x0, fma(p1,x1, fma(p2,x2, p3*x3)))
- A compile-time #if (D%4)!=0 eliminates dead scalar tail code
Merged commit 910b120 into Blaizzy:pc/add-deepseekv4flash-model. 2 checks passed.
DeepSeek V4 Optimization Changelog
Targets the DeepSeek-V4/V4-Flash Metal kernels in `mlx_lm/models/deepseek_v4.py`.

What stuck
- Vectorized Sinkhorn kernel — rewrote `_make_hc_split_sinkhorn_kernel` to use `float4` SIMD loads/stores, `dot()` for row sums, and fully unrolled 4x4 Sinkhorn iterations instead of scalar loops.
- `_overlap_transform` concatenate — Replaced `mx.full` + scatter-assign with `mx.concatenate` of sliced tensors, avoiding a large intermediate allocation.
- Q-norm Metal kernel (`ds4_q_norm`) — Fused RMS norm for query vectors using a 32-thread SIMD group reduction, replacing `mx.rsqrt` + elementwise ops.
- `hc_mult != 4` guard — Falls back to pure-MLX Sinkhorn when `hc_mult` is not 4, since the float4 kernel is hardcoded for 4x4.

What was tried and reverted
- Fused sparse attention Metal kernel (`ds4_fused_sparse_attn`) — Online-softmax over local + sparse KV in a single MSL kernel. Multiple iterations fixing address spaces, dispatch, output shapes.
- `_split_sparse_attention` (MLX fallback) — Two-pass local+sparse attention with log-sum-exp merge, `@mx.compile`-decorated. Replaced by standard SDPA with gather+concatenate.
- Partial-RoPE Metal kernel (`ds4_partial_rope`) — Eliminated the split/concat intermediates for partial rotary embeddings.
- Deferred `wkv`/`wgate` GEMVs until a full compression window was ready.
- HC `fn` in model dtype — Changed `self.fn` from `float32` to the default (bf16), split `flat`/`flat_f32` to do the matmul in model dtype.
- SwitchGLU `scores` param — used the `_gather_sort` helper, moved score multiplication outside to `DeepseekV4MoE.__call__`.
- Removed the `select_all` optimization, `lengths`/`pooled_mask` tracking, mask trimming, and `pooled_bias` handling — reverted because the `select_all` fast-path and mask trim are critical.
- `PreTrainedTokenizerFast` fallback — Caught `AttributeError` from `AutoTokenizer` for custom models with `rope_scaling`.

Final benchmark