perf(deepseek_v4): fuse switch_mlp gate_proj + up_proj into single gather_qmm#1999
Open
adurham wants to merge 2 commits intoexo-explore:mainfrom
Open
perf(deepseek_v4): fuse switch_mlp gate_proj + up_proj into single gather_qmm#1999adurham wants to merge 2 commits intoexo-explore:mainfrom
adurham wants to merge 2 commits intoexo-explore:mainfrom
Conversation
added 2 commits
April 27, 2026 14:46
PR exo-explore#1978 shipped a head-parallel attention sharding strategy for DeepSeek V4 (`_shard_v4_attention_heads` interleaved-per-group split of `wq_b`/`attn_sink`, plus `_AllSumLinear` to reduce wo_b's input across ranks). On the deployable mlx-community 6-bit checkpoint (`mlx-community/DeepSeek-V4-Flash-6bit`) the resulting model loads cleanly and runs to READY but produces saturated-softmax gibberish at inference: `logprob: 0.0` on every chosen token, empty `top_logprobs` even at temp=1, output is uniform-distribution BPE fragments (`'197197197197197197733733Ca733ca…'` on `"The capital of France is"`). The fault is the head-parallel slicing of V4Attention's LoRA-decomposed projections. V4Attention's `_grouped_output_projection` manually reshapes `wo_a.weight/.scales/.biases` based on `n_groups`/`o_lora_rank`/ `head_dim`; the head-parallel `wq_b` slice changes the per-group input dimensionality that reshape sees, so even though the linear layers all have valid shapes the math composing them no longer produces correct activations. Replicated attention sidesteps this entirely. This PR drops the head-parallel attention sharding and ships a MoE-only strategy that: * replicates attention on every rank (no V4Attention math touched); * shards `ffn.shared_experts.{gate,down,up}_proj` and `ffn.switch_mlp.{gate,down,up}_proj` (all-to-sharded / sharded-to-all on the input/output dim); * wraps `layer.ffn` in the existing `ShardedMoEV4` so cross-rank reduction happens via `sum_gradients` on input + `all_sum` on the MoE output. Memory profile on Flash (158B-total / 13B-active, 4-bit / mxfp4 experts): | section | full sharding | MoE-only | | ---------------- | ------------- | ---------- | | attention/rank | ~15 GB | ~30 GB | | MoE/rank | ~65 GB | ~65 GB | | total/rank @ N=2 | ~80 GB | ~95 GB | Both fit on 128 GB nodes. MoE-only also drops one `all_sum` collective per layer (no wo_b reduction) — single-stream decode is dispatch-bound on the cluster, so this matters. Removes ~125 lines of helper code (`_shard_quantized_rows`, `_AllSumLinear`, `_shard_v4_attention_heads`); adds ~50 lines of strategy. `ShardedMoEV4` retained — it's the cross-rank reduction wrapper the new strategy uses. Bench (mlx-community/DeepSeek-V4-Flash-6bit, 2× M4 Max via Thunderbolt RDMA, MlxJaccl backend, single 13B-active forward per decode step, warmup excluded): | concurrency | aggregate tok/s | per-request tok/s | | ----------- | --------------- | ----------------- | | 1 | 34.2 | 34.2 | | 2 | 44.1 | 22.1 | | 4 | 52.8 (mean) | 13.2 | Notes: * Tested with the Blaizzy DSv4 model implementation (`mlx-lm` PR exo-explore#1192). Verified MoE-only sharding also loads cleanly with rltakashige's variant currently pinned in `pyproject.toml` (`leo/deepseek-v4`), but that combination still emits gibberish — the bug there is in the model implementation itself (Pipeline-mode also fails, so it isn't sharding-related). Tracking that separately; this PR is independent of it. * No existing DSv4 tests in the repo. Issue: <link to repro thread / PR comment once filed>
…ther_qmm Adds an opt-in `_FusedSwitchGLU` (with `_install_fused_switch_glu`) that pre-concatenates a SwitchGLU's `gate_proj` + `up_proj` weights along the output-dim axis so a single `mx.gather_qmm` produces both halves. We split, apply the original `self.activation(x_up, x_gate)` (preserving custom activations like DSv4's `_DSV4SwiGLU(swiglu_limit)`), then run `down_proj` unchanged. Saves one Metal dispatch per decoder layer per decode token: 43 dispatches × ~100-200 µs each on the deployed 2× M4 Max RDMA cluster. Bench-validated **+1.2% c=1, +1.1% c=2** on `mlx-community/DeepSeek-V4-Flash-6bit` (33.8 → 34.6 tok/s c=1, 45.4 → 45.9 tok/s c=2; 0% regression at c≥6, no quality regression — coherent output on `"The capital of France is"` smoke prompt and reasoning trace identical to the unfused path). Wired into `DeepseekV4ShardingStrategy` after tensor-parallel sharding of `gate_proj` / `up_proj` so the concat happens on the post-shard local output dim. Concat order is `[up, gate]` to match SwitchGLU's call sequence `self.activation(x_up, x_gate)`. **Off by default** behind `EXO_DSV4_FUSED_MOE=1`. The instance rebind happens via `__class__ = _FusedSwitchGLU`, preserving the SwitchGLU attributes (activation, sort_threshold, down_proj). The pre-fuse `gate_proj` / `up_proj` instances are replaced with empty modules to free their weights — without this they'd live alongside the fused weight and triple the per-layer MoE memory footprint. Caveats: - Smaller relative gain than fork's MiniMax fused MoE (which sees ~+6%) because DSv4's compressor + indexer + LoRA-decomposed V4Attention add work outside the MoE — MoE is a smaller fraction of decode time, so saving its dispatches has smaller proportional impact. DSv4 also has fewer layers (43 vs 62), 30% fewer dispatches saved. - Bench was on the Blaizzy DSv4 mlx-lm variant (`mlx-lm` PR exo-explore#1192). Code is interface-compatible with rltakashige's variant (`leo/deepseek-v4`, the branch upstream's `pyproject.toml` pins) — same `SwitchGLU.__call__(x, indices)` signature, same pluggable `self.activation` — but rltakashige's variant has a separate inference-quality issue (saturated softmax / gibberish output) unrelated to sharding that I haven't been able to fix; couldn't validate decode quality there. Stacks on PR exo-explore#1996 (the MoE-only DSv4 sharding strategy this fused path slots into).
adurham
pushed a commit
to adurham/exo
that referenced
this pull request
Apr 27, 2026
…s mlx-lm pin - docs/upstream-prs.md: Open PRs 10 → 11 with exo-explore#1999 entry. Bench evidence (+1.2% c=1 / +1.1% c=2) noted in status. - docs/fork-notes.md: - Bump adurham/mlx-lm pin info (was stale at 65655ce; now 8dfce1f) and document the dsv4-perf-bisect@034a42a active pin + the rlt-deepseek_v4 graft on :main from today's pivot test. - Add the `_quantize` predicate fix as a tracked fork patch (filed as mlx-lm PR exo-explore#1216). - Add a "DeepSeek V4 sharding + fused MoE" subsection under Open optimization projects with PR exo-explore#1996 + exo-explore#1999 references and the bench numbers from today's session.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds an opt-in
_FusedSwitchGLU(with_install_fused_switch_glu) that pre-concatenates a SwitchGLU'sgate_proj+up_projweights along the output-dim axis so a singlemx.gather_qmmproduces both halves. We split, apply the originalself.activation(x_up, x_gate)(preserving custom activations like DSv4's_DSV4SwiGLU(swiglu_limit)), then rundown_projunchanged.Saves one Metal dispatch per decoder layer per decode token: 43 dispatches × ~100-200 µs each on the deployed 2× M4 Max RDMA cluster.
Bench
mlx-community/DeepSeek-V4-Flash-6bit, 2× M4 Max via Thunderbolt RDMA, MlxJaccl backend, warmup excluded:tail_ratio_mean=1.00across all concurrency levels, zero errors. Smoke test on\"The capital of France is\"produces identical reasoning trace + answer between unfused and fused paths.How it's wired
Wired into
DeepseekV4ShardingStrategy.shard_model(the strategy from #1996) after tensor-parallel sharding ofgate_proj/up_proj, so the concat happens on the post-shard local output dim. Concat order is[up, gate]to match SwitchGLU's call sequenceself.activation(x_up, x_gate).The instance rebind happens via
__class__ = _FusedSwitchGLU— preserves the SwitchGLU attributes (activation,sort_threshold,down_proj). The pre-fusegate_proj/up_projinstances are replaced with empty modules to free their weights — without this they'd live alongside the fused weight and triple the per-layer MoE memory footprint.Off by default behind
EXO_DSV4_FUSED_MOE=1.Caveats
leo/deepseek-v4, the branch upstream'spyproject.tomlpins) — sameSwitchGLU.__call__(x, indices)signature, same pluggableself.activation— but rlt's variant has a separate inference-quality issue (saturated softmax / gibberish output, reproducible with PR Add DeepSeek V4 Flash/Pro #1978's sharding too) that I haven't been able to fix; couldn't validate decode quality there.Test plan
uv run ruff check src/exo/worker/engines/mlx/auto_parallel.py— cleanuv run basedpyright src/exo/worker/engines/mlx/auto_parallel.py— only the deliberateself_any: AnyreportAny entries (mirrors the existing fork pattern for bypassing mlx-lm's untyped APIs)EXO_DSV4_FUSED_MOE=0and=1both produce coherent output with identical reasoning traces🤖 Generated with Claude Code