From feefe18b1a10df6dc15d0b9b0bf2deb01f91c211 Mon Sep 17 00:00:00 2001 From: Adam Durham Date: Mon, 27 Apr 2026 14:46:39 -0500 Subject: [PATCH] fix(deepseek_v4): drop full-attention sharding for MoE-only strategy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #1978 shipped a head-parallel attention sharding strategy for DeepSeek V4 (`_shard_v4_attention_heads` interleaved-per-group split of `wq_b`/`attn_sink`, plus `_AllSumLinear` to reduce wo_b's input across ranks). On the deployable mlx-community 6-bit checkpoint (`mlx-community/DeepSeek-V4-Flash-6bit`) the resulting model loads cleanly and runs to READY but produces saturated-softmax gibberish at inference: `logprob: 0.0` on every chosen token, empty `top_logprobs` even at temp=1, output is uniform-distribution BPE fragments (`'197197197197197197733733Ca733ca…'` on `"The capital of France is"`). The fault is the head-parallel slicing of V4Attention's LoRA-decomposed projections. V4Attention's `_grouped_output_projection` manually reshapes `wo_a.weight/.scales/.biases` based on `n_groups`/`o_lora_rank`/ `head_dim`; the head-parallel `wq_b` slice changes the per-group input dimensionality that reshape sees, so even though the linear layers all have valid shapes the math composing them no longer produces correct activations. Replicated attention sidesteps this entirely. This PR drops the head-parallel attention sharding and ships a MoE-only strategy that: * replicates attention on every rank (no V4Attention math touched); * shards `ffn.shared_experts.{gate,down,up}_proj` and `ffn.switch_mlp.{gate,down,up}_proj` (all-to-sharded / sharded-to-all on the input/output dim); * wraps `layer.ffn` in the existing `ShardedMoEV4` so cross-rank reduction happens via `sum_gradients` on input + `all_sum` on the MoE output. Memory profile on Flash (158B-total / 13B-active, 4-bit / mxfp4 experts): | section | full sharding | MoE-only | | ---------------- | ------------- | ---------- | | attention/rank | ~15 GB | ~30 GB | | MoE/rank | ~65 GB | ~65 GB | | total/rank @ N=2 | ~80 GB | ~95 GB | Both fit on 128 GB nodes. MoE-only also drops one `all_sum` collective per layer (no wo_b reduction) — single-stream decode is dispatch-bound on the cluster, so this matters. Removes ~125 lines of helper code (`_shard_quantized_rows`, `_AllSumLinear`, `_shard_v4_attention_heads`); adds ~50 lines of strategy. `ShardedMoEV4` retained — it's the cross-rank reduction wrapper the new strategy uses. Bench (mlx-community/DeepSeek-V4-Flash-6bit, 2× M4 Max via Thunderbolt RDMA, MlxJaccl backend, single 13B-active forward per decode step, warmup excluded): | concurrency | aggregate tok/s | per-request tok/s | | ----------- | --------------- | ----------------- | | 1 | 34.2 | 34.2 | | 2 | 44.1 | 22.1 | | 4 | 52.8 (mean) | 13.2 | Notes: * Tested with the Blaizzy DSv4 model implementation (`mlx-lm` PR #1192). Verified MoE-only sharding also loads cleanly with rltakashige's variant currently pinned in `pyproject.toml` (`leo/deepseek-v4`), but that combination still emits gibberish — the bug there is in the model implementation itself (Pipeline-mode also fails, so it isn't sharding-related). Tracking that separately; this PR is independent of it. * No existing DSv4 tests in the repo. Issue: --- src/exo/worker/engines/mlx/auto_parallel.py | 142 +++----------------- 1 file changed, 17 insertions(+), 125 deletions(-) diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py index 14603f4f7f..cd7baebbd7 100644 --- a/src/exo/worker/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -17,7 +17,7 @@ from mlx_lm.models.cache import ArraysCache, KVCache from mlx_lm.models.deepseek_v3 import DeepseekV3MLP from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model -from mlx_lm.models.deepseek_v4 import DeepseekV4MoE, V4Attention +from mlx_lm.models.deepseek_v4 import DeepseekV4MoE from mlx_lm.models.deepseek_v4 import Model as DeepseekV4Model from mlx_lm.models.deepseek_v32 import DeepseekV32MLP from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model @@ -784,127 +784,24 @@ def __call__(self, x: mx.array, input_ids: mx.array) -> mx.array: return y -def _shard_quantized_rows( - q: nn.QuantizedLinear, - head_dim: int, - slicer: Callable[[mx.array, int], mx.array], -) -> None: - weight = q["weight"] - scales = q["scales"] - assert isinstance(weight, mx.array) - assert isinstance(scales, mx.array) - q.weight = slicer(weight, head_dim) - q.scales = slicer(scales, head_dim) - biases = q.get("biases") - if isinstance(biases, mx.array): - q.biases = slicer(biases, head_dim) - - -class _AllSumLinear(nn.Module): - """Wraps an unsharded wo_b that takes a head-sharded partial wo_a output. - - Flow per rank: - 1. all_sum the incoming partial wo_a output (summed across the head - input shards → full wo_a_out on every rank) - 2. apply the unsharded wo_b → full hidden on every rank - - One collective per layer on the smaller of (n_groups * o_lora_rank) vs - hidden. wo_b compute is replicated, but at decode B=1 it's only ~30M FLOPs - per layer and 61 extra all_gathers/token cost more than running wo_b on - every rank. - """ - - def __init__(self, inner: nn.Module, group: mx.distributed.Group): - super().__init__() - self.inner = inner - self._group = group - - def __call__(self, x: mx.array) -> mx.array: - x = mx.distributed.all_sum(x, group=self._group) - return cast(Callable[[mx.array], mx.array], self.inner)(x) - - -def _shard_v4_attention_heads( - attn: V4Attention, - world_size: int, - rank: int, -) -> None: - """Interleaved-per-group head sharding for V4Attention. - - V4 uses a grouped low-rank output projection: `_grouped_output_projection` - reshapes the flat `n_heads * head_dim` dim into `(o_groups, heads_per_group, - head_dim)`, so group g owns heads `[g * heads_per_group : (g+1) * heads_per_group]`. - - A naive contiguous `shard_linear("all-to-sharded")` on wq_b puts whole - original groups on each rank — the per-rank "group g" ends up containing - heads that don't belong to original group g. That breaks the wo_a grouped - weight mapping. We instead slice heads interleaved-by-group: each rank - owns `heads_per_group / N` heads *from every original group*, kept in - group-major order so SDPA → reshape → wo_a preserves the group mapping. - - Affects `wq_b.weight` / `wq_b.bias`, `attn_sink`. wo_a is sharded via a - normal input-dim block split (the default axis-(-1) behavior of - shard_inplace), which now correctly aligns with the interleaved head - layout because the last dim of out after reshape is `heads_per_group/N * - head_dim` per group. +class DeepseekV4ShardingStrategy(TensorParallelShardingStrategy): + """Sharding for DeepSeek V4 Flash / Pro — MoE-only, attention replicated. + + DSv4's V4Attention uses a LoRA-decomposed Q/output projection plus a + ``_grouped_output_projection`` that manually reshapes + ``wo_a.weight/.scales/.biases`` — head-parallel weight slicing of + ``wq_b`` makes that manual reshape see half the per-group input dim, + producing arithmetically incorrect activations. To keep model math + intact we replicate attention on every rank and shard only the MoE + block. Cross-rank reduction for the MoE happens via ``ShardedMoEV4``, + which all-sums the MoE output. + + Memory footprint at 4-bit on a 158B-total / 13B-active model: + - Attention is ~30 GB across 43 layers (replicated on every rank). + - MoE bulk is ~130 GB; sharded across N ranks ⇒ ~130 / N GB / rank. + - Total at N=2 ⇒ ~95 GB / rank — comfortable on 128 GB nodes. """ - n_heads: int = attn.n_heads - head_dim: int = attn.head_dim - o_groups: int = attn.n_groups - assert n_heads % o_groups == 0, "n_heads must be divisible by o_groups" - heads_per_group = n_heads // o_groups - assert heads_per_group % world_size == 0, ( - f"heads_per_group ({heads_per_group}) must be divisible by world_size " - f"({world_size}) for interleaved per-group head sharding" - ) - hpg_per_rank = heads_per_group // world_size - start = rank * hpg_per_rank - end = start + hpg_per_rank - - def _slice_head_major_flat(arr: mx.array, stride: int) -> mx.array: - """Slice arr on axis 0 where the flat 0-axis is (o_groups * - heads_per_group * stride), returning a fresh contiguous allocation - so the full unsharded array can be freed. Without the contiguous - copy the slice is a view and the original weight stays resident — - OOM on large V4. Quantized packed weights don't round-trip through - numpy so we use mx.contiguous directly.""" - rest = arr.shape[1:] - reshaped = arr.reshape(o_groups, heads_per_group, stride, *rest) - sliced = reshaped[:, start:end].reshape(o_groups * hpg_per_rank * stride, *rest) - detached = mx.contiguous(sliced) - mx.eval(detached) - return detached - - wq_b: nn.Module = attn.wq_b - if isinstance(wq_b, nn.QuantizedLinear): - # Packed weight: (n_heads*head_dim, q_lora_rank/el_per_int). - # scales/biases: (n_heads*head_dim, q_lora_rank/group_size). - # Slice axis 0 interleaved-by-group with head_dim stride. - _shard_quantized_rows(wq_b, head_dim, _slice_head_major_flat) - else: - dense = wq_b - assert isinstance(dense, nn.Linear) - w = dense.weight - q_lora_rank = w.shape[-1] - w_sharded = _slice_head_major_flat(w, head_dim) - has_bias = "bias" in dense - new_wq_b = nn.Linear(q_lora_rank, w_sharded.shape[0], bias=has_bias) - new_wq_b.weight = w_sharded - if has_bias: - b = dense.bias - assert b is not None - new_wq_b.bias = _slice_head_major_flat(b[:, None], head_dim).reshape(-1) - attn.wq_b = new_wq_b - - sink = attn.attn_sink - reshaped = sink.reshape(o_groups, heads_per_group)[:, start:end].reshape(-1) - detached_sink = mx.contiguous(reshaped) - mx.eval(detached_sink) - attn.attn_sink = detached_sink - attn.n_heads = o_groups * hpg_per_rank - -class DeepseekV4ShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, @@ -915,11 +812,6 @@ def shard_model( for i, layer in enumerate(model.layers): mx.eval(layer.parameters()) - # Head-parallel attention with interleaved-per-group sharding. - _shard_v4_attention_heads(layer.attn, self.N, self.group.rank()) - self.sharded_to_all_linear_in_place(layer.attn.wo_a) - layer.attn.wo_b = _AllSumLinear(layer.attn.wo_b, self.group) # type: ignore[assignment] - ffn = layer.ffn if getattr(ffn, "shared_experts", None) is not None: self.all_to_sharded_linear_in_place(ffn.shared_experts.gate_proj)