fix(deepseek_v4): drop full-attention sharding for MoE-only strategy #1996

Open

adurham wants to merge 1 commit into exo-explore:main from
Conversation
Summary

PR exo-explore#1978 shipped a head-parallel attention sharding strategy for DeepSeek V4 (`_shard_v4_attention_heads` interleaved-per-group split of `wq_b`/`attn_sink`, plus `_AllSumLinear` to reduce `wo_b`'s input across ranks). On the deployable mlx-community 6-bit checkpoint (`mlx-community/DeepSeek-V4-Flash-6bit`) the resulting model loads cleanly and runs to READY, but produces saturated-softmax gibberish at inference:

* `logprob: 0.0` on every chosen token
* empty `top_logprobs` even at `temperature=1.0`, `top_p=1.0`
* output is uniform-distribution BPE fragments: `'197197197197197197733733Ca733ca…'` on the prompt `"The capital of France is"`

The fault is the head-parallel slicing of V4Attention's LoRA-decomposed projections. `V4Attention._grouped_output_projection` manually reshapes `wo_a.weight/.scales/.biases` based on `n_groups`/`o_lora_rank`/`head_dim`; the head-parallel `wq_b` slice changes the per-group input dimensionality that reshape sees, so even though the linear layers all have valid shapes, the math composing them no longer produces correct activations. Replicated attention sidesteps this entirely.

What this PR does

Drops the head-parallel attention sharding and ships a MoE-only strategy that:

* replicates attention on every rank (no V4Attention math touched);
* shards `ffn.shared_experts.{gate,down,up}_proj` and `ffn.switch_mlp.{gate,down,up}_proj` (all-to-sharded / sharded-to-all on the input/output dim);
* wraps `layer.ffn` in the existing `ShardedMoEV4` so cross-rank reduction happens via `sum_gradients` on the input plus `all_sum` on the MoE output.

Removes ~125 lines of helper code (`_shard_quantized_rows`, `_AllSumLinear`, `_shard_v4_attention_heads`); adds ~50 lines of strategy. `ShardedMoEV4` is retained; it's the cross-rank reduction wrapper the new strategy uses.

Memory profile (Flash, 158B-total / 13B-active, 4-bit / mxfp4 experts)

| section          | full sharding | MoE-only |
| ---------------- | ------------- | -------- |
| attention/rank   | ~15 GB        | ~30 GB   |
| MoE/rank         | ~65 GB        | ~65 GB   |
| total/rank @ N=2 | ~80 GB        | ~95 GB   |

Both fit on 128 GB nodes. MoE-only also drops one `all_sum` collective per layer (no `wo_b` reduction); single-stream decode is dispatch-bound on the cluster, so this matters.

Bench

`mlx-community/DeepSeek-V4-Flash-6bit`, 2× M4 Max via Thunderbolt RDMA, MlxJaccl backend, single 13B-active forward per decode step, warmup excluded:

| concurrency | aggregate tok/s | per-request tok/s |
| ----------- | --------------- | ----------------- |
| 1           | 34.2            | 34.2              |
| 2           | 44.1            | 22.1              |
| 4           | 52.8 (mean)     | 13.2              |

Single-stream is rock-solid: tail_ratio=1.00, ±0.0 across iterations.

Caveats / scope

* Tested with the Blaizzy DSv4 model implementation (`mlx-lm` PR exo-explore#1192). Verified MoE-only sharding also loads cleanly with rltakashige's variant currently pinned in `pyproject.toml` (`leo/deepseek-v4`), but that combination still emits gibberish on the same checkpoint; the bug there is in the model implementation itself (pipeline mode also fails, so it isn't sharding-related). Tracking that separately; this PR is independent of it and is strictly a sharding cleanup, correct against any DSv4 implementation whose `DeepseekV4MoE` exposes `shared_experts`/`switch_mlp` and is wrapped by `ShardedMoEV4`.
* No existing DSv4 tests in the repo.
* Issue: <link to repro thread / PR comment once filed>

Test plan

* `uv run ruff check src/exo/worker/engines/mlx/auto_parallel.py`: clean
* `uv run basedpyright src/exo/worker/engines/mlx/auto_parallel.py`: 0 errors
* `"The capital of France is"` → coherent answer with reasoning trace

🤖 Generated with Claude Code
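For reference, a minimal sketch of the MoE-only wiring described under "What this PR does". It is illustrative only: `apply_moe_only_sharding`, `_slice_output_dim`, and `_slice_input_dim` are hypothetical stand-ins for the actual strategy in `auto_parallel.py`, which operates on quantized weights (scales/biases included) and handles rank bookkeeping; `ShardedMoEV4` is the existing exo wrapper referenced above and is assumed importable.

```python
def _slice_output_dim(proj, rank: int, size: int) -> None:
    # all-to-sharded: keep this rank's slice of the output features.
    # Handles both (out, in) Linear weights and (experts, out, in) switch weights;
    # quantized packing and bias slicing are omitted in this sketch.
    w = proj.weight
    step = w.shape[-2] // size
    proj.weight = w[..., rank * step:(rank + 1) * step, :]


def _slice_input_dim(proj, rank: int, size: int) -> None:
    # sharded-to-all: consume this rank's slice of the input features; the
    # result is a partial sum that the wrapper's all_sum completes.
    w = proj.weight
    step = w.shape[-1] // size
    proj.weight = w[..., rank * step:(rank + 1) * step]


def apply_moe_only_sharding(model, group) -> None:
    rank, size = group.rank(), group.size()
    for layer in model.layers:
        # Attention is replicated on every rank, so V4Attention's grouped
        # output projection is never touched.
        if not hasattr(layer.ffn, "switch_mlp"):
            continue  # dense FFN layers stay replicated too
        for mlp in (layer.ffn.shared_experts, layer.ffn.switch_mlp):
            _slice_output_dim(mlp.gate_proj, rank, size)
            _slice_output_dim(mlp.up_proj, rank, size)
            _slice_input_dim(mlp.down_proj, rank, size)
        # ShardedMoEV4 adds sum_gradients on the input and all_sum on the MoE
        # output, folding each rank's partial result back into the full activation.
        layer.ffn = ShardedMoEV4(layer.ffn, group)
```

Splitting gate/up on the output dim and down on the input dim is the usual column-/row-parallel GLU arrangement; it is what lets the whole MoE block finish with a single `all_sum` per layer.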
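A quick probe for the saturated-softmax signature from the Summary (every chosen token at `logprob: 0.0` with empty `top_logprobs`). This is a hypothetical sketch: it assumes an OpenAI-compatible exo endpoint on `localhost:52415` that honors the `logprobs`/`top_logprobs` request fields; adjust host, port, and model name for your deployment.

```python
import requests

resp = requests.post(
    "http://localhost:52415/v1/chat/completions",  # assumed exo endpoint; adjust as needed
    json={
        "model": "mlx-community/DeepSeek-V4-Flash-6bit",
        "messages": [{"role": "user", "content": "The capital of France is"}],
        "max_tokens": 32,
        "temperature": 1.0,
        "logprobs": True,
        "top_logprobs": 5,
    },
    timeout=300,
)
resp.raise_for_status()
choice = resp.json()["choices"][0]
tokens = (choice.get("logprobs") or {}).get("content") or []

# Healthy decode: varied negative logprobs with populated top_logprobs.
# The failure reported above: logprob == 0.0 on every token and top_logprobs == [].
degenerate = bool(tokens) and all(
    t["logprob"] == 0.0 and not t["top_logprobs"] for t in tokens
)
print("text:", choice["message"]["content"])
print("saturated-softmax signature:", degenerate)
```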
adurham pushed a commit to adurham/exo that referenced this pull request on Apr 27, 2026:
…s mlx-lm pin

- docs/upstream-prs.md: Open PRs 10 → 11 with exo-explore#1999 entry. Bench evidence (+1.2% c=1 / +1.1% c=2) noted in status.
- docs/fork-notes.md:
  - Bump adurham/mlx-lm pin info (was stale at 65655ce; now 8dfce1f) and document the dsv4-perf-bisect@034a42a active pin + the rlt-deepseek_v4 graft on :main from today's pivot test.
  - Add the `_quantize` predicate fix as a tracked fork patch (filed as mlx-lm PR exo-explore#1216).
  - Add a "DeepSeek V4 sharding + fused MoE" subsection under Open optimization projects with PR exo-explore#1996 + exo-explore#1999 references and the bench numbers from today's session.
Collaborator

`mlx-community/DeepSeek-V4-Flash-6bit` → This is a model sanitised in a different way to our implementation at the moment, so it will not work (only the provided model cards do). I'll eventually resolve the two branches, but I believe our current implementation is considerably better in terms of performance and stability.