From feefe18b1a10df6dc15d0b9b0bf2deb01f91c211 Mon Sep 17 00:00:00 2001 From: Adam Durham Date: Mon, 27 Apr 2026 14:46:39 -0500 Subject: [PATCH 1/2] fix(deepseek_v4): drop full-attention sharding for MoE-only strategy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #1978 shipped a head-parallel attention sharding strategy for DeepSeek V4 (`_shard_v4_attention_heads` interleaved-per-group split of `wq_b`/`attn_sink`, plus `_AllSumLinear` to reduce wo_b's input across ranks). On the deployable mlx-community 6-bit checkpoint (`mlx-community/DeepSeek-V4-Flash-6bit`) the resulting model loads cleanly and runs to READY but produces saturated-softmax gibberish at inference: `logprob: 0.0` on every chosen token, empty `top_logprobs` even at temp=1, output is uniform-distribution BPE fragments (`'197197197197197197733733Ca733ca…'` on `"The capital of France is"`). The fault is the head-parallel slicing of V4Attention's LoRA-decomposed projections. V4Attention's `_grouped_output_projection` manually reshapes `wo_a.weight/.scales/.biases` based on `n_groups`/`o_lora_rank`/ `head_dim`; the head-parallel `wq_b` slice changes the per-group input dimensionality that reshape sees, so even though the linear layers all have valid shapes the math composing them no longer produces correct activations. Replicated attention sidesteps this entirely. This PR drops the head-parallel attention sharding and ships a MoE-only strategy that: * replicates attention on every rank (no V4Attention math touched); * shards `ffn.shared_experts.{gate,down,up}_proj` and `ffn.switch_mlp.{gate,down,up}_proj` (all-to-sharded / sharded-to-all on the input/output dim); * wraps `layer.ffn` in the existing `ShardedMoEV4` so cross-rank reduction happens via `sum_gradients` on input + `all_sum` on the MoE output. Memory profile on Flash (158B-total / 13B-active, 4-bit / mxfp4 experts): | section | full sharding | MoE-only | | ---------------- | ------------- | ---------- | | attention/rank | ~15 GB | ~30 GB | | MoE/rank | ~65 GB | ~65 GB | | total/rank @ N=2 | ~80 GB | ~95 GB | Both fit on 128 GB nodes. MoE-only also drops one `all_sum` collective per layer (no wo_b reduction) — single-stream decode is dispatch-bound on the cluster, so this matters. Removes ~125 lines of helper code (`_shard_quantized_rows`, `_AllSumLinear`, `_shard_v4_attention_heads`); adds ~50 lines of strategy. `ShardedMoEV4` retained — it's the cross-rank reduction wrapper the new strategy uses. Bench (mlx-community/DeepSeek-V4-Flash-6bit, 2× M4 Max via Thunderbolt RDMA, MlxJaccl backend, single 13B-active forward per decode step, warmup excluded): | concurrency | aggregate tok/s | per-request tok/s | | ----------- | --------------- | ----------------- | | 1 | 34.2 | 34.2 | | 2 | 44.1 | 22.1 | | 4 | 52.8 (mean) | 13.2 | Notes: * Tested with the Blaizzy DSv4 model implementation (`mlx-lm` PR #1192). Verified MoE-only sharding also loads cleanly with rltakashige's variant currently pinned in `pyproject.toml` (`leo/deepseek-v4`), but that combination still emits gibberish — the bug there is in the model implementation itself (Pipeline-mode also fails, so it isn't sharding-related). Tracking that separately; this PR is independent of it. * No existing DSv4 tests in the repo. Issue: --- src/exo/worker/engines/mlx/auto_parallel.py | 142 +++----------------- 1 file changed, 17 insertions(+), 125 deletions(-) diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py index 14603f4f7f..cd7baebbd7 100644 --- a/src/exo/worker/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -17,7 +17,7 @@ from mlx_lm.models.cache import ArraysCache, KVCache from mlx_lm.models.deepseek_v3 import DeepseekV3MLP from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model -from mlx_lm.models.deepseek_v4 import DeepseekV4MoE, V4Attention +from mlx_lm.models.deepseek_v4 import DeepseekV4MoE from mlx_lm.models.deepseek_v4 import Model as DeepseekV4Model from mlx_lm.models.deepseek_v32 import DeepseekV32MLP from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model @@ -784,127 +784,24 @@ def __call__(self, x: mx.array, input_ids: mx.array) -> mx.array: return y -def _shard_quantized_rows( - q: nn.QuantizedLinear, - head_dim: int, - slicer: Callable[[mx.array, int], mx.array], -) -> None: - weight = q["weight"] - scales = q["scales"] - assert isinstance(weight, mx.array) - assert isinstance(scales, mx.array) - q.weight = slicer(weight, head_dim) - q.scales = slicer(scales, head_dim) - biases = q.get("biases") - if isinstance(biases, mx.array): - q.biases = slicer(biases, head_dim) - - -class _AllSumLinear(nn.Module): - """Wraps an unsharded wo_b that takes a head-sharded partial wo_a output. - - Flow per rank: - 1. all_sum the incoming partial wo_a output (summed across the head - input shards → full wo_a_out on every rank) - 2. apply the unsharded wo_b → full hidden on every rank - - One collective per layer on the smaller of (n_groups * o_lora_rank) vs - hidden. wo_b compute is replicated, but at decode B=1 it's only ~30M FLOPs - per layer and 61 extra all_gathers/token cost more than running wo_b on - every rank. - """ - - def __init__(self, inner: nn.Module, group: mx.distributed.Group): - super().__init__() - self.inner = inner - self._group = group - - def __call__(self, x: mx.array) -> mx.array: - x = mx.distributed.all_sum(x, group=self._group) - return cast(Callable[[mx.array], mx.array], self.inner)(x) - - -def _shard_v4_attention_heads( - attn: V4Attention, - world_size: int, - rank: int, -) -> None: - """Interleaved-per-group head sharding for V4Attention. - - V4 uses a grouped low-rank output projection: `_grouped_output_projection` - reshapes the flat `n_heads * head_dim` dim into `(o_groups, heads_per_group, - head_dim)`, so group g owns heads `[g * heads_per_group : (g+1) * heads_per_group]`. - - A naive contiguous `shard_linear("all-to-sharded")` on wq_b puts whole - original groups on each rank — the per-rank "group g" ends up containing - heads that don't belong to original group g. That breaks the wo_a grouped - weight mapping. We instead slice heads interleaved-by-group: each rank - owns `heads_per_group / N` heads *from every original group*, kept in - group-major order so SDPA → reshape → wo_a preserves the group mapping. - - Affects `wq_b.weight` / `wq_b.bias`, `attn_sink`. wo_a is sharded via a - normal input-dim block split (the default axis-(-1) behavior of - shard_inplace), which now correctly aligns with the interleaved head - layout because the last dim of out after reshape is `heads_per_group/N * - head_dim` per group. +class DeepseekV4ShardingStrategy(TensorParallelShardingStrategy): + """Sharding for DeepSeek V4 Flash / Pro — MoE-only, attention replicated. + + DSv4's V4Attention uses a LoRA-decomposed Q/output projection plus a + ``_grouped_output_projection`` that manually reshapes + ``wo_a.weight/.scales/.biases`` — head-parallel weight slicing of + ``wq_b`` makes that manual reshape see half the per-group input dim, + producing arithmetically incorrect activations. To keep model math + intact we replicate attention on every rank and shard only the MoE + block. Cross-rank reduction for the MoE happens via ``ShardedMoEV4``, + which all-sums the MoE output. + + Memory footprint at 4-bit on a 158B-total / 13B-active model: + - Attention is ~30 GB across 43 layers (replicated on every rank). + - MoE bulk is ~130 GB; sharded across N ranks ⇒ ~130 / N GB / rank. + - Total at N=2 ⇒ ~95 GB / rank — comfortable on 128 GB nodes. """ - n_heads: int = attn.n_heads - head_dim: int = attn.head_dim - o_groups: int = attn.n_groups - assert n_heads % o_groups == 0, "n_heads must be divisible by o_groups" - heads_per_group = n_heads // o_groups - assert heads_per_group % world_size == 0, ( - f"heads_per_group ({heads_per_group}) must be divisible by world_size " - f"({world_size}) for interleaved per-group head sharding" - ) - hpg_per_rank = heads_per_group // world_size - start = rank * hpg_per_rank - end = start + hpg_per_rank - - def _slice_head_major_flat(arr: mx.array, stride: int) -> mx.array: - """Slice arr on axis 0 where the flat 0-axis is (o_groups * - heads_per_group * stride), returning a fresh contiguous allocation - so the full unsharded array can be freed. Without the contiguous - copy the slice is a view and the original weight stays resident — - OOM on large V4. Quantized packed weights don't round-trip through - numpy so we use mx.contiguous directly.""" - rest = arr.shape[1:] - reshaped = arr.reshape(o_groups, heads_per_group, stride, *rest) - sliced = reshaped[:, start:end].reshape(o_groups * hpg_per_rank * stride, *rest) - detached = mx.contiguous(sliced) - mx.eval(detached) - return detached - - wq_b: nn.Module = attn.wq_b - if isinstance(wq_b, nn.QuantizedLinear): - # Packed weight: (n_heads*head_dim, q_lora_rank/el_per_int). - # scales/biases: (n_heads*head_dim, q_lora_rank/group_size). - # Slice axis 0 interleaved-by-group with head_dim stride. - _shard_quantized_rows(wq_b, head_dim, _slice_head_major_flat) - else: - dense = wq_b - assert isinstance(dense, nn.Linear) - w = dense.weight - q_lora_rank = w.shape[-1] - w_sharded = _slice_head_major_flat(w, head_dim) - has_bias = "bias" in dense - new_wq_b = nn.Linear(q_lora_rank, w_sharded.shape[0], bias=has_bias) - new_wq_b.weight = w_sharded - if has_bias: - b = dense.bias - assert b is not None - new_wq_b.bias = _slice_head_major_flat(b[:, None], head_dim).reshape(-1) - attn.wq_b = new_wq_b - - sink = attn.attn_sink - reshaped = sink.reshape(o_groups, heads_per_group)[:, start:end].reshape(-1) - detached_sink = mx.contiguous(reshaped) - mx.eval(detached_sink) - attn.attn_sink = detached_sink - attn.n_heads = o_groups * hpg_per_rank - -class DeepseekV4ShardingStrategy(TensorParallelShardingStrategy): def shard_model( self, model: nn.Module, @@ -915,11 +812,6 @@ def shard_model( for i, layer in enumerate(model.layers): mx.eval(layer.parameters()) - # Head-parallel attention with interleaved-per-group sharding. - _shard_v4_attention_heads(layer.attn, self.N, self.group.rank()) - self.sharded_to_all_linear_in_place(layer.attn.wo_a) - layer.attn.wo_b = _AllSumLinear(layer.attn.wo_b, self.group) # type: ignore[assignment] - ffn = layer.ffn if getattr(ffn, "shared_experts", None) is not None: self.all_to_sharded_linear_in_place(ffn.shared_experts.gate_proj) From 043caf008547580eb486fb3b846f41cc89158387 Mon Sep 17 00:00:00 2001 From: Adam Durham Date: Mon, 27 Apr 2026 18:22:59 -0500 Subject: [PATCH 2/2] perf(deepseek_v4): fuse switch_mlp gate_proj + up_proj into single gather_qmm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an opt-in `_FusedSwitchGLU` (with `_install_fused_switch_glu`) that pre-concatenates a SwitchGLU's `gate_proj` + `up_proj` weights along the output-dim axis so a single `mx.gather_qmm` produces both halves. We split, apply the original `self.activation(x_up, x_gate)` (preserving custom activations like DSv4's `_DSV4SwiGLU(swiglu_limit)`), then run `down_proj` unchanged. Saves one Metal dispatch per decoder layer per decode token: 43 dispatches × ~100-200 µs each on the deployed 2× M4 Max RDMA cluster. Bench-validated **+1.2% c=1, +1.1% c=2** on `mlx-community/DeepSeek-V4-Flash-6bit` (33.8 → 34.6 tok/s c=1, 45.4 → 45.9 tok/s c=2; 0% regression at c≥6, no quality regression — coherent output on `"The capital of France is"` smoke prompt and reasoning trace identical to the unfused path). Wired into `DeepseekV4ShardingStrategy` after tensor-parallel sharding of `gate_proj` / `up_proj` so the concat happens on the post-shard local output dim. Concat order is `[up, gate]` to match SwitchGLU's call sequence `self.activation(x_up, x_gate)`. **Off by default** behind `EXO_DSV4_FUSED_MOE=1`. The instance rebind happens via `__class__ = _FusedSwitchGLU`, preserving the SwitchGLU attributes (activation, sort_threshold, down_proj). The pre-fuse `gate_proj` / `up_proj` instances are replaced with empty modules to free their weights — without this they'd live alongside the fused weight and triple the per-layer MoE memory footprint. Caveats: - Smaller relative gain than fork's MiniMax fused MoE (which sees ~+6%) because DSv4's compressor + indexer + LoRA-decomposed V4Attention add work outside the MoE — MoE is a smaller fraction of decode time, so saving its dispatches has smaller proportional impact. DSv4 also has fewer layers (43 vs 62), 30% fewer dispatches saved. - Bench was on the Blaizzy DSv4 mlx-lm variant (`mlx-lm` PR #1192). Code is interface-compatible with rltakashige's variant (`leo/deepseek-v4`, the branch upstream's `pyproject.toml` pins) — same `SwitchGLU.__call__(x, indices)` signature, same pluggable `self.activation` — but rltakashige's variant has a separate inference-quality issue (saturated softmax / gibberish output) unrelated to sharding that I haven't been able to fix; couldn't validate decode quality there. Stacks on PR #1996 (the MoE-only DSv4 sharding strategy this fused path slots into). --- src/exo/worker/engines/mlx/auto_parallel.py | 138 +++++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py index cd7baebbd7..fe8352afb4 100644 --- a/src/exo/worker/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -1,8 +1,9 @@ +import os from abc import ABC, abstractmethod from collections.abc import Callable, Generator from functools import partial from inspect import signature -from typing import TYPE_CHECKING, Literal, Protocol, cast +from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import mlx.core as mx import mlx.nn as nn @@ -784,6 +785,134 @@ def __call__(self, x: mx.array, input_ids: mx.array) -> mx.array: return y +# Off by default: opt-in fused MoE gate+up dispatch for DSv4. Saves one +# Metal dispatch per decoder layer per decode token (43 per forward) at +# ~100-200 µs each on M4 Max — bench-validated +1.2% c=1 / +1.1% c=2 on +# `mlx-community/DeepSeek-V4-Flash-6bit` (2× M4 Max RDMA, MlxJaccl). +_DSV4_FUSED_MOE: bool = os.environ.get("EXO_DSV4_FUSED_MOE", "0") == "1" + + +class _FusedSwitchGLU(nn.Module): + """Drop-in SwitchGLU replacement that fuses gate_proj + up_proj into a + single ``mx.gather_qmm`` dispatch. + + SwitchGLU's stock ``__call__`` runs two ``gather_qmm``s for gate and up + (plus one for down). At DSv4's 43 decoder layers that's 43 extra + Metal dispatches per decode token — each with ~100-200 µs of dispatch + + sync overhead on the RDMA cluster. Concatenating gate and up + weights along the output axis lets a single ``gather_qmm`` produce + both halves; we split, apply the original ``self.activation`` (so + custom activations like DSv4's ``_DSV4SwiGLU(swiglu_limit)`` are + preserved), then run down_proj unchanged. + + Uses ``__class__ = _FusedSwitchGLU`` rebind (preserves all attributes + of the pre-quantized / post-sharded SwitchGLU instance — we only + override ``__call__``). + + Concat order in the fused weight is ``[up, gate]`` to match + SwitchGLU's call sequence ``self.activation(x_up, x_gate)``. + """ + + sort_threshold: int = 8 + + def __call__(self, x: mx.array, indices: mx.array) -> mx.array: # type: ignore[override] + self_any: Any = self + + x = mx.expand_dims(x, (-2, -3)) + do_sort = indices.size >= self.sort_threshold + idx: Any = indices + inv_order: Any = None + if do_sort: + flat_indices = indices.flatten() + order = mx.argsort(flat_indices) + inv_order = mx.argsort(order) + x = x.flatten(0, -3)[order // indices.shape[-1]] + idx = flat_indices[order] + + gu: Any = mx.gather_qmm( + x, + self_any._fused_w_gu, + self_any._fused_s_gu, + self_any._fused_b_gu, + rhs_indices=idx, + transpose=True, + group_size=self_any._fused_group_size, + bits=self_any._fused_bits, + mode=self_any._fused_mode, + sorted_indices=do_sort, + ) + n: int = self_any._fused_n_inter + x_up = gu[..., :n] + x_gate = gu[..., n:] + + x = self_any.activation(x_up, x_gate) + + x = self_any.down_proj(x, idx, sorted_indices=do_sort) + if do_sort: + x = x[inv_order] + x = mx.unflatten(x, 0, indices.shape[:-1]) + return x.squeeze(-2) + + +def _install_fused_switch_glu(switch_mlp: nn.Module) -> None: + """Pre-concatenate up_proj + gate_proj weights on `switch_mlp` for a + single ``gather_qmm`` at forward time. Rebinds the instance to + :class:`_FusedSwitchGLU` so its ``__call__`` uses the fused path. + + Must be called after tensor-parallel sharding of gate_proj/up_proj — + output-dim axis is already ``moe_intermediate_size / N`` per rank. + Concat is along that (local) output axis. + """ + sm: Any = switch_mlp + gp: Any = sm.gate_proj + up: Any = sm.up_proj + gp_bits = getattr(gp, "bits", None) + up_bits = getattr(up, "bits", None) + gp_group = getattr(gp, "group_size", None) + up_group = getattr(up, "group_size", None) + assert gp_bits is not None and gp_bits == up_bits, \ + f"gate/up bits mismatch: {gp_bits} vs {up_bits}" + assert gp_group is not None and gp_group == up_group, \ + f"gate/up group_size mismatch: {gp_group} vs {up_group}" + gp_mode = getattr(gp, "mode", "affine") + up_mode = getattr(up, "mode", "affine") + assert gp_mode == up_mode, f"gate/up mode mismatch: {gp_mode} vs {up_mode}" + + gp_w: mx.array = gp["weight"] + gp_s: mx.array = gp["scales"] + up_w: mx.array = up["weight"] + up_s: mx.array = up["scales"] + gp_b = gp.get("biases") if hasattr(gp, "get") else getattr(gp, "biases", None) + up_b = up.get("biases") if hasattr(up, "get") else getattr(up, "biases", None) + + fused_w: mx.array = mx.concatenate([up_w, gp_w], axis=1) + fused_s: mx.array = mx.concatenate([up_s, gp_s], axis=1) + fused_b: mx.array | None = ( + mx.concatenate([up_b, gp_b], axis=1) + if gp_b is not None and up_b is not None + else None + ) + mx.eval(fused_w, fused_s) + if fused_b is not None: + mx.eval(fused_b) + + sm._fused_w_gu = fused_w + sm._fused_s_gu = fused_s + sm._fused_b_gu = fused_b + sm._fused_n_inter = int(up_w.shape[1]) + sm._fused_group_size = int(gp_group) + sm._fused_bits = int(gp_bits) + sm._fused_mode = gp_mode + + # Free the now-redundant originals — gate_proj + up_proj + fused + # together would triple the MoE weight footprint per layer. After + # the __class__ rebind _FusedSwitchGLU only references self.down_proj. + sm.gate_proj = nn.Module() + sm.up_proj = nn.Module() + + switch_mlp.__class__ = _FusedSwitchGLU + + class DeepseekV4ShardingStrategy(TensorParallelShardingStrategy): """Sharding for DeepSeek V4 Flash / Pro — MoE-only, attention replicated. @@ -820,6 +949,13 @@ def shard_model( self.all_to_sharded_linear_in_place(ffn.switch_mlp.gate_proj) self.sharded_to_all_linear_in_place(ffn.switch_mlp.down_proj) self.all_to_sharded_linear_in_place(ffn.switch_mlp.up_proj) + + # Optionally fuse gate+up into a single gather_qmm dispatch. + # Saves 43 dispatches per decode token; off by default until + # opted in via EXO_DSV4_FUSED_MOE=1. + if _DSV4_FUSED_MOE: + _install_fused_switch_glu(ffn.switch_mlp) + wrapped = ShardedMoEV4(ffn) wrapped.sharding_group = self.group layer.ffn = wrapped # type: ignore[assignment]