
perf(deepseek_v4): fuse switch_mlp gate_proj + up_proj into single gather_qmm#1999

Open
adurham wants to merge 2 commits into exo-explore:main from adurham:dsv4-fused-moe-gate-up

Conversation

@adurham (Contributor) commented Apr 27, 2026

Stacks on #1996 (MoE-only DSv4 sharding). The diff visible here includes that PR's content; once #1996 merges, this PR's incremental change is just the fused-MoE addition (~140 lines).

Summary

Adds an opt-in _FusedSwitchGLU (with _install_fused_switch_glu) that pre-concatenates a SwitchGLU's gate_proj + up_proj weights along the output-dim axis so a single mx.gather_qmm produces both halves. We split, apply the original self.activation(x_up, x_gate) (preserving custom activations like DSv4's _DSV4SwiGLU(swiglu_limit)), then run down_proj unchanged.
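
A minimal sketch of that forward path, assuming the concatenated projection lives on a hypothetical fused_gate_up attribute (illustrative names and internals, not the PR's literal code):

```python
import mlx.core as mx
import mlx.nn as nn


class _FusedSwitchGLU(nn.Module):
    """SwitchGLU whose gate_proj/up_proj weights were pre-concatenated into a
    single quantized expert projection, so one gather_qmm dispatch produces
    both halves per decode step (sketch; attribute names are assumptions)."""

    def __call__(self, x, indices):
        # One fused expert matmul instead of two: output dim is [up | gate].
        fused = self.fused_gate_up(x, indices)
        # Split back into the two halves; concat order was [up, gate].
        x_up, x_gate = mx.split(fused, 2, axis=-1)
        # Original activation preserved, e.g. DSv4's _DSV4SwiGLU(swiglu_limit).
        h = self.activation(x_up, x_gate)
        # down_proj runs unchanged.
        return self.down_proj(h, indices)
```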

Saves one Metal dispatch per decoder layer per decode token: 43 dispatches × ~100-200 µs each on the deployed 2× M4 Max RDMA cluster.

Bench

mlx-community/DeepSeek-V4-Flash-6bit, 2× M4 Max via Thunderbolt RDMA, MlxJaccl backend, warmup excluded:

| concurrency | baseline (tok/s) | fused (tok/s) | Δ     |
| ----------- | ---------------- | ------------- | ----- |
| 1           | 34.2             | 34.6          | +1.2% |
| 2           | 45.4             | 45.9          | +1.1% |
| 4           | 62.2             | 62.5          | +0.5% |
| 6           | 70.0             | 70.1          | +0.1% |
| 8           | 75.2             | 75.2          | 0%    |

tail_ratio_mean=1.00 across all concurrency levels, zero errors. Smoke test on "The capital of France is" produces identical reasoning trace + answer between unfused and fused paths.

How it's wired

Wired into DeepseekV4ShardingStrategy.shard_model (the strategy from #1996) after tensor-parallel sharding of gate_proj / up_proj, so the concat happens on the post-shard local output dim. Concat order is [up, gate] to match SwitchGLU's call sequence self.activation(x_up, x_gate).

The instance rebind happens via __class__ = _FusedSwitchGLU, preserving the SwitchGLU attributes (activation, sort_threshold, down_proj). The pre-fuse gate_proj / up_proj instances are replaced with empty modules to free their weights — without this they'd live alongside the fused weight and triple the per-layer MoE memory footprint.
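
A sketch of how the install step might look, assuming mlx-lm's quantized expert projections expose per-expert weight / scales / biases arrays with the output dim on axis -2 (both assumptions, not the PR's literal code):

```python
import mlx.core as mx
import mlx.nn as nn


def _install_fused_switch_glu(switch_glu) -> None:
    """Rebind a SwitchGLU instance onto the fused path (sketch)."""
    up, gate = switch_glu.up_proj, switch_glu.gate_proj

    # Concat order is [up, gate] so the split in _FusedSwitchGLU.__call__
    # matches self.activation(x_up, x_gate). axis=-2 assumes a layout of
    # (num_experts, output_dim, packed_input_dim) for the quantized arrays.
    up.weight = mx.concatenate([up.weight, gate.weight], axis=-2)
    up.scales = mx.concatenate([up.scales, gate.scales], axis=-2)
    up.biases = mx.concatenate([up.biases, gate.biases], axis=-2)
    switch_glu.fused_gate_up = up  # reuse the up_proj module as the container

    # Replace the standalone projections with empty modules so their original
    # weights can be freed; otherwise they'd sit alongside the fused weight.
    switch_glu.gate_proj = nn.Module()
    switch_glu.up_proj = nn.Module()

    # Instance rebind: keeps activation, sort_threshold, down_proj as-is.
    switch_glu.__class__ = _FusedSwitchGLU
```

Reusing the up_proj module as the fused container is just a convenience of this sketch; the essential points from the PR are the [up, gate] concat order, the emptied gate_proj / up_proj, and the __class__ rebind.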

Off by default; opt in with EXO_DSV4_FUSED_MOE=1.

Caveats

  • Smaller relative gain than the fork's MiniMax fused MoE (which sees ~+6% from the same kind of fusion). DSv4's compressor + indexer + LoRA-decomposed V4Attention add work outside the MoE, so the MoE is a smaller fraction of decode time and saving its dispatches has a smaller proportional impact. DSv4 also has fewer layers (43 vs 62), so ~30% fewer dispatches are saved.
  • Bench was on the Blaizzy DSv4 mlx-lm variant (PR ml-explore/mlx-lm#1192). Code is interface-compatible with rltakashige's variant (leo/deepseek-v4, the branch upstream's pyproject.toml pins) — same SwitchGLU.__call__(x, indices) signature, same pluggable self.activation — but rltakashige's variant has a separate inference-quality issue (saturated softmax / gibberish output, also reproducible with #1978's sharding) that I haven't been able to fix, so decode quality couldn't be validated there.

Test plan

  • uv run ruff check src/exo/worker/engines/mlx/auto_parallel.py — clean
  • uv run basedpyright src/exo/worker/engines/mlx/auto_parallel.py — only the deliberate self_any: Any reportAny entries (mirrors the existing fork pattern for bypassing mlx-lm's untyped APIs)
  • Smoke test on cluster: EXO_DSV4_FUSED_MOE=0 and =1 both produce coherent output with identical reasoning traces
  • Bench at concurrency 1/2/4/6/8 (numbers above)

🤖 Generated with Claude Code

Adam Durham added 2 commits April 27, 2026 14:46
PR exo-explore#1978 shipped a head-parallel attention sharding strategy for
DeepSeek V4 (`_shard_v4_attention_heads` interleaved-per-group split of
`wq_b`/`attn_sink`, plus `_AllSumLinear` to reduce wo_b's input across
ranks). On the deployable mlx-community 6-bit checkpoint
(`mlx-community/DeepSeek-V4-Flash-6bit`) the resulting model loads
cleanly and runs to READY but produces saturated-softmax gibberish at
inference: `logprob: 0.0` on every chosen token, empty `top_logprobs`
even at temp=1, output is uniform-distribution BPE fragments
(`'197197197197197197733733Ca733ca…'` on `"The capital of France is"`).

The fault is the head-parallel slicing of V4Attention's LoRA-decomposed
projections. V4Attention's `_grouped_output_projection` manually
reshapes `wo_a.weight/.scales/.biases` based on `n_groups`/`o_lora_rank`/
`head_dim`; the head-parallel `wq_b` slice changes the per-group input
dimensionality that reshape sees, so even though the linear layers all
have valid shapes the math composing them no longer produces correct
activations. Replicated attention sidesteps this entirely.

This PR drops the head-parallel attention sharding and ships a
MoE-only strategy that:

  * replicates attention on every rank (no V4Attention math touched);
  * shards `ffn.shared_experts.{gate,down,up}_proj` and
    `ffn.switch_mlp.{gate,down,up}_proj` (all-to-sharded /
    sharded-to-all on the input/output dim);
  * wraps `layer.ffn` in the existing `ShardedMoEV4` so cross-rank
    reduction happens via `sum_gradients` on input + `all_sum` on the
    MoE output.
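
A structural sketch of that strategy (helper names such as `shard_cols` / `shard_rows` and the `ShardedMoEV4` constructor signature are hypothetical; only the overall shape follows the description above):

```python
class DeepseekV4ShardingStrategy:
    def shard_model(self, model, group):
        for layer in model.layers:           # DSv4 model layout is assumed
            ffn = layer.ffn
            if not hasattr(ffn, "switch_mlp"):
                continue                     # dense layers stay fully replicated
            # Attention is untouched: replicated on every rank.
            for mlp in (ffn.shared_experts, ffn.switch_mlp):
                mlp.gate_proj = shard_cols(mlp.gate_proj, group)  # all-to-sharded (output dim)
                mlp.up_proj = shard_cols(mlp.up_proj, group)      # all-to-sharded (output dim)
                mlp.down_proj = shard_rows(mlp.down_proj, group)  # sharded-to-all (input dim)
            # ShardedMoEV4 adds sum_gradients on the input and all_sum on the
            # MoE output for the cross-rank reduction.
            layer.ffn = ShardedMoEV4(ffn, group)
```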

Memory profile on Flash (158B-total / 13B-active, 4-bit / mxfp4 experts):

| section          | full sharding | MoE-only   |
| ---------------- | ------------- | ---------- |
| attention/rank   | ~15 GB        | ~30 GB     |
| MoE/rank         | ~65 GB        | ~65 GB     |
| total/rank @ N=2 | ~80 GB        | ~95 GB     |

Both fit on 128 GB nodes. MoE-only also drops one `all_sum` collective
per layer (no wo_b reduction) — single-stream decode is dispatch-bound
on the cluster, so this matters.

Removes ~125 lines of helper code (`_shard_quantized_rows`,
`_AllSumLinear`, `_shard_v4_attention_heads`); adds ~50 lines of
strategy. `ShardedMoEV4` retained — it's the cross-rank reduction
wrapper the new strategy uses.

Bench (mlx-community/DeepSeek-V4-Flash-6bit, 2× M4 Max via Thunderbolt
RDMA, MlxJaccl backend, single 13B-active forward per decode step,
warmup excluded):

| concurrency | aggregate tok/s | per-request tok/s |
| ----------- | --------------- | ----------------- |
| 1           | 34.2            | 34.2              |
| 2           | 44.1            | 22.1              |
| 4           | 52.8 (mean)     | 13.2              |

Notes:

  * Tested with the Blaizzy DSv4 model implementation (`mlx-lm`
    PR ml-explore/mlx-lm#1192). Verified MoE-only sharding also loads cleanly with
    rltakashige's variant currently pinned in `pyproject.toml`
    (`leo/deepseek-v4`), but that combination still emits gibberish —
    the bug there is in the model implementation itself
    (Pipeline-mode also fails, so it isn't sharding-related). Tracking
    that separately; this PR is independent of it.
  * No existing DSv4 tests in the repo.

Issue: <link to repro thread / PR comment once filed>
perf(deepseek_v4): fuse switch_mlp gate_proj + up_proj into single gather_qmm

Adds an opt-in `_FusedSwitchGLU` (with `_install_fused_switch_glu`)
that pre-concatenates a SwitchGLU's `gate_proj` + `up_proj` weights
along the output-dim axis so a single `mx.gather_qmm` produces both
halves. We split, apply the original `self.activation(x_up, x_gate)`
(preserving custom activations like DSv4's `_DSV4SwiGLU(swiglu_limit)`),
then run `down_proj` unchanged.

Saves one Metal dispatch per decoder layer per decode token: 43
dispatches × ~100-200 µs each on the deployed 2× M4 Max RDMA cluster.
Bench-validated **+1.2% c=1, +1.1% c=2** on
`mlx-community/DeepSeek-V4-Flash-6bit` (34.2 → 34.6 tok/s c=1, 45.4 →
45.9 tok/s c=2; 0% regression at c≥6, no quality regression — coherent
output on `"The capital of France is"` smoke prompt and reasoning
trace identical to the unfused path).

Wired into `DeepseekV4ShardingStrategy` after tensor-parallel sharding
of `gate_proj` / `up_proj` so the concat happens on the post-shard
local output dim. Concat order is `[up, gate]` to match SwitchGLU's
call sequence `self.activation(x_up, x_gate)`.

**Off by default** behind `EXO_DSV4_FUSED_MOE=1`. The instance rebind
happens via `__class__ = _FusedSwitchGLU`, preserving the SwitchGLU
attributes (activation, sort_threshold, down_proj). The pre-fuse
`gate_proj` / `up_proj` instances are replaced with empty modules to
free their weights — without this they'd live alongside the fused
weight and triple the per-layer MoE memory footprint.

Caveats:
- Smaller relative gain than the fork's MiniMax fused MoE (which sees
  ~+6%) because DSv4's compressor + indexer + LoRA-decomposed
  V4Attention add work outside the MoE — MoE is a smaller fraction of
  decode time, so saving its dispatches has smaller proportional
  impact. DSv4 also has fewer layers (43 vs 62), 30% fewer dispatches
  saved.
- Bench was on the Blaizzy DSv4 mlx-lm variant (`mlx-lm` PR ml-explore/mlx-lm#1192).
  Code is interface-compatible with rltakashige's variant
  (`leo/deepseek-v4`, the branch upstream's `pyproject.toml` pins) —
  same `SwitchGLU.__call__(x, indices)` signature, same pluggable
  `self.activation` — but rltakashige's variant has a separate
  inference-quality issue (saturated softmax / gibberish output)
  unrelated to sharding that I haven't been able to fix; couldn't
  validate decode quality there.

Stacks on PR exo-explore#1996 (the MoE-only DSv4 sharding strategy this fused
path slots into).
adurham pushed a commit to adurham/exo that referenced this pull request Apr 27, 2026
…s mlx-lm pin

- docs/upstream-prs.md: Open PRs 10 → 11 with exo-explore#1999 entry. Bench
  evidence (+1.2% c=1 / +1.1% c=2) noted in status.
- docs/fork-notes.md:
  - Bump adurham/mlx-lm pin info (was stale at 65655ce; now
    8dfce1f) and document the dsv4-perf-bisect@034a42a active pin
    + the rlt-deepseek_v4 graft on :main from today's pivot test.
  - Add the `_quantize` predicate fix as a tracked fork patch
  (filed as mlx-lm PR ml-explore/mlx-lm#1216).
  - Add a "DeepSeek V4 sharding + fused MoE" subsection under
    Open optimization projects with PR exo-explore#1996 + exo-explore#1999 references and
    the bench numbers from today's session.