Skip to content

fix(deepseek_v4): drop full-attention sharding for MoE-only strategy#1996

Open
adurham wants to merge 1 commit intoexo-explore:mainfrom
adurham:dsv4-moe-only-sharding
Open

fix(deepseek_v4): drop full-attention sharding for MoE-only strategy#1996
adurham wants to merge 1 commit intoexo-explore:mainfrom
adurham:dsv4-moe-only-sharding

Conversation

@adurham
Copy link
Copy Markdown
Contributor

@adurham adurham commented Apr 27, 2026

Summary

PR #1978 shipped a head-parallel attention sharding strategy for DeepSeek V4 (_shard_v4_attention_heads interleaved-per-group split of wq_b/attn_sink, plus _AllSumLinear to reduce wo_b's input across ranks). On the deployable mlx-community/DeepSeek-V4-Flash-6bit checkpoint the resulting model loads cleanly and reaches RunnerReady, but produces saturated-softmax gibberish at inference:

  • logprob: 0.0 on every chosen token
  • top_logprobs: [] even at temperature=1.0, top_p=1.0
  • Output is uniform-distribution BPE fragments — e.g. '197197197197197197733733Ca733ca…' on prompt "The capital of France is"

The fault is the head-parallel slicing of V4Attention's LoRA-decomposed projections. V4Attention._grouped_output_projection manually reshapes wo_a.weight/.scales/.biases based on n_groups/o_lora_rank/head_dim; the head-parallel wq_b slice changes the per-group input dimensionality that reshape sees, so even though the linear layers all have valid shapes the math composing them no longer produces correct activations. Replicated attention sidesteps this entirely.

What this PR does

Drops the head-parallel attention sharding and ships a MoE-only strategy:

  • Replicates attention on every rank (no V4Attention math touched)
  • Shards ffn.shared_experts.{gate,down,up}_proj and ffn.switch_mlp.{gate,down,up}_proj (all-to-sharded / sharded-to-all on the input/output dim)
  • Wraps layer.ffn in the existing ShardedMoEV4 so cross-rank reduction happens via sum_gradients on input + all_sum on the MoE output

Removes ~125 lines of helper code (_shard_quantized_rows, _AllSumLinear, _shard_v4_attention_heads); adds ~50 lines of strategy. ShardedMoEV4 is retained — it's the cross-rank reduction wrapper the new strategy uses.

Memory profile (Flash, 158B-total / 13B-active, 4-bit / mxfp4 experts)

section full sharding (PR #1978) MoE-only (this PR)
attention / rank ~15 GB ~30 GB
MoE / rank ~65 GB ~65 GB
total / rank @ N=2 ~80 GB ~95 GB

Both fit on 128 GB nodes. MoE-only also drops one all_sum collective per layer (no wo_b reduction) — single-stream decode is dispatch-bound on the cluster, so this matters.

Bench

mlx-community/DeepSeek-V4-Flash-6bit, 2× M4 Max via Thunderbolt RDMA, MlxJaccl backend, warmup excluded:

concurrency aggregate tok/s per-request tok/s
1 34.2 34.2
2 44.1 22.1
4 52.8 (mean) 13.2

Single-stream is rock-solid: tail_ratio=1.00, ±0.0 across iterations.

Caveats / scope

  • Tested with the Blaizzy DSv4 model implementation (mlx-lm PR #1192). Verified MoE-only sharding also loads cleanly with rltakashige/mlx-lm:leo/deepseek-v4 (the variant currently pinned in pyproject.toml), but that combination still emits gibberish on the same checkpoint — the bug there is in the model implementation itself (Pipeline-mode also fails, so it isn't sharding-related). I'll file that separately as an issue once I've narrowed it down. This PR is independent of that — strictly a sharding cleanup that's correct against any DSv4 implementation whose DeepseekV4MoE exposes shared_experts/switch_mlp and is wrapped by ShardedMoEV4.
  • No existing DSv4 tests in the repo.

Test plan

  • uv run ruff check src/exo/worker/engines/mlx/auto_parallel.py — clean
  • uv run basedpyright src/exo/worker/engines/mlx/auto_parallel.py — 0 errors
  • Smoke test on 2× M4 Max RDMA cluster: "The capital of France is" → coherent answer with reasoning trace
  • Bench at concurrency 1/2/4 (numbers above)

🤖 Generated with Claude Code

PR exo-explore#1978 shipped a head-parallel attention sharding strategy for
DeepSeek V4 (`_shard_v4_attention_heads` interleaved-per-group split of
`wq_b`/`attn_sink`, plus `_AllSumLinear` to reduce wo_b's input across
ranks). On the deployable mlx-community 6-bit checkpoint
(`mlx-community/DeepSeek-V4-Flash-6bit`) the resulting model loads
cleanly and runs to READY but produces saturated-softmax gibberish at
inference: `logprob: 0.0` on every chosen token, empty `top_logprobs`
even at temp=1, output is uniform-distribution BPE fragments
(`'197197197197197197733733Ca733ca…'` on `"The capital of France is"`).

The fault is the head-parallel slicing of V4Attention's LoRA-decomposed
projections. V4Attention's `_grouped_output_projection` manually
reshapes `wo_a.weight/.scales/.biases` based on `n_groups`/`o_lora_rank`/
`head_dim`; the head-parallel `wq_b` slice changes the per-group input
dimensionality that reshape sees, so even though the linear layers all
have valid shapes the math composing them no longer produces correct
activations. Replicated attention sidesteps this entirely.

This PR drops the head-parallel attention sharding and ships a
MoE-only strategy that:

  * replicates attention on every rank (no V4Attention math touched);
  * shards `ffn.shared_experts.{gate,down,up}_proj` and
    `ffn.switch_mlp.{gate,down,up}_proj` (all-to-sharded /
    sharded-to-all on the input/output dim);
  * wraps `layer.ffn` in the existing `ShardedMoEV4` so cross-rank
    reduction happens via `sum_gradients` on input + `all_sum` on the
    MoE output.

Memory profile on Flash (158B-total / 13B-active, 4-bit / mxfp4 experts):

| section          | full sharding | MoE-only   |
| ---------------- | ------------- | ---------- |
| attention/rank   | ~15 GB        | ~30 GB     |
| MoE/rank         | ~65 GB        | ~65 GB     |
| total/rank @ N=2 | ~80 GB        | ~95 GB     |

Both fit on 128 GB nodes. MoE-only also drops one `all_sum` collective
per layer (no wo_b reduction) — single-stream decode is dispatch-bound
on the cluster, so this matters.

Removes ~125 lines of helper code (`_shard_quantized_rows`,
`_AllSumLinear`, `_shard_v4_attention_heads`); adds ~50 lines of
strategy. `ShardedMoEV4` retained — it's the cross-rank reduction
wrapper the new strategy uses.

Bench (mlx-community/DeepSeek-V4-Flash-6bit, 2× M4 Max via Thunderbolt
RDMA, MlxJaccl backend, single 13B-active forward per decode step,
warmup excluded):

| concurrency | aggregate tok/s | per-request tok/s |
| ----------- | --------------- | ----------------- |
| 1           | 34.2            | 34.2              |
| 2           | 44.1            | 22.1              |
| 4           | 52.8 (mean)     | 13.2              |

Notes:

  * Tested with the Blaizzy DSv4 model implementation (`mlx-lm`
    PR exo-explore#1192). Verified MoE-only sharding also loads cleanly with
    rltakashige's variant currently pinned in `pyproject.toml`
    (`leo/deepseek-v4`), but that combination still emits gibberish —
    the bug there is in the model implementation itself
    (Pipeline-mode also fails, so it isn't sharding-related). Tracking
    that separately; this PR is independent of it.
  * No existing DSv4 tests in the repo.

Issue: <link to repro thread / PR comment once filed>
adurham pushed a commit to adurham/exo that referenced this pull request Apr 27, 2026
adurham pushed a commit to adurham/exo that referenced this pull request Apr 27, 2026
…s mlx-lm pin

- docs/upstream-prs.md: Open PRs 10 → 11 with exo-explore#1999 entry. Bench
  evidence (+1.2% c=1 / +1.1% c=2) noted in status.
- docs/fork-notes.md:
  - Bump adurham/mlx-lm pin info (was stale at 65655ce; now
    8dfce1f) and document the dsv4-perf-bisect@034a42a active pin
    + the rlt-deepseek_v4 graft on :main from today's pivot test.
  - Add the `_quantize` predicate fix as a tracked fork patch
    (filed as mlx-lm PR exo-explore#1216).
  - Add a "DeepSeek V4 sharding + fused MoE" subsection under
    Open optimization projects with PR exo-explore#1996 + exo-explore#1999 references and
    the bench numbers from today's session.
@rltakashige
Copy link
Copy Markdown
Collaborator

mlx-community/DeepSeek-V4-Flash-6bit --> This is a model sanitised in a different way to our implementation at the moment, so it will not work (only the provided model cards do).

I'll eventually resolve the two branches, but I believe our current implementation is considerably better at the moment in terms of performance and stability.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants