diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py
index 14603f4f7f..fe8352afb4 100644
--- a/src/exo/worker/engines/mlx/auto_parallel.py
+++ b/src/exo/worker/engines/mlx/auto_parallel.py
@@ -1,8 +1,9 @@
+import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Generator
 from functools import partial
 from inspect import signature
-from typing import TYPE_CHECKING, Literal, Protocol, cast
+from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -17,7 +18,7 @@
 from mlx_lm.models.cache import ArraysCache, KVCache
 from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
 from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
-from mlx_lm.models.deepseek_v4 import DeepseekV4MoE, V4Attention
+from mlx_lm.models.deepseek_v4 import DeepseekV4MoE
 from mlx_lm.models.deepseek_v4 import Model as DeepseekV4Model
 from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
 from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
@@ -784,127 +785,152 @@ def __call__(self, x: mx.array, input_ids: mx.array) -> mx.array:
         return y
 
 
-def _shard_quantized_rows(
-    q: nn.QuantizedLinear,
-    head_dim: int,
-    slicer: Callable[[mx.array, int], mx.array],
-) -> None:
-    weight = q["weight"]
-    scales = q["scales"]
-    assert isinstance(weight, mx.array)
-    assert isinstance(scales, mx.array)
-    q.weight = slicer(weight, head_dim)
-    q.scales = slicer(scales, head_dim)
-    biases = q.get("biases")
-    if isinstance(biases, mx.array):
-        q.biases = slicer(biases, head_dim)
-
-
-class _AllSumLinear(nn.Module):
-    """Wraps an unsharded wo_b that takes a head-sharded partial wo_a output.
-
-    Flow per rank:
-    1. all_sum the incoming partial wo_a output (summed across the head
-       input shards → full wo_a_out on every rank)
-    2. apply the unsharded wo_b → full hidden on every rank
-
-    One collective per layer on the smaller of (n_groups * o_lora_rank) vs
-    hidden. wo_b compute is replicated, but at decode B=1 it's only ~30M FLOPs
-    per layer and 61 extra all_gathers/token cost more than running wo_b on
-    every rank.
+# Off by default: opt-in fused MoE gate+up dispatch for DSv4. Saves one
+# Metal dispatch per decoder layer per decode token (43 per forward) at
+# ~100-200 µs each on M4 Max — bench-validated +1.2% c=1 / +1.1% c=2 on
+# `mlx-community/DeepSeek-V4-Flash-6bit` (2× M4 Max RDMA, MlxJaccl).
+_DSV4_FUSED_MOE: bool = os.environ.get("EXO_DSV4_FUSED_MOE", "0") == "1"
+
+
+class _FusedSwitchGLU(nn.Module):
+    """Drop-in SwitchGLU replacement that fuses gate_proj + up_proj into a
+    single ``mx.gather_qmm`` dispatch.
+
+    SwitchGLU's stock ``__call__`` runs two ``gather_qmm``s for gate and up
+    (plus one for down). At DSv4's 43 decoder layers that's 43 extra
+    Metal dispatches per decode token — each with ~100-200 µs of dispatch +
+    sync overhead on the RDMA cluster. Concatenating gate and up
+    weights along the output axis lets a single ``gather_qmm`` produce
+    both halves; we split, apply the original ``self.activation`` (so
+    custom activations like DSv4's ``_DSV4SwiGLU(swiglu_limit)`` are
+    preserved), then run down_proj unchanged.
+
+    Uses ``__class__ = _FusedSwitchGLU`` rebind (preserves all attributes
+    of the pre-quantized / post-sharded SwitchGLU instance — we only
+    override ``__call__``).
+
+    Concat order in the fused weight is ``[up, gate]`` to match
+    SwitchGLU's call sequence ``self.activation(x_up, x_gate)``.
     """
 
-    def __init__(self, inner: nn.Module, group: mx.distributed.Group):
-        super().__init__()
-        self.inner = inner
-        self._group = group
+    sort_threshold: int = 8
+
+    def __call__(self, x: mx.array, indices: mx.array) -> mx.array:  # type: ignore[override]
+        self_any: Any = self
+
+        x = mx.expand_dims(x, (-2, -3))
+        do_sort = indices.size >= self.sort_threshold
+        idx: Any = indices
+        inv_order: Any = None
+        if do_sort:
+            flat_indices = indices.flatten()
+            order = mx.argsort(flat_indices)
+            inv_order = mx.argsort(order)
+            x = x.flatten(0, -3)[order // indices.shape[-1]]
+            idx = flat_indices[order]
+
+        gu: Any = mx.gather_qmm(
+            x,
+            self_any._fused_w_gu,
+            self_any._fused_s_gu,
+            self_any._fused_b_gu,
+            rhs_indices=idx,
+            transpose=True,
+            group_size=self_any._fused_group_size,
+            bits=self_any._fused_bits,
+            mode=self_any._fused_mode,
+            sorted_indices=do_sort,
+        )
+        n: int = self_any._fused_n_inter
+        x_up = gu[..., :n]
+        x_gate = gu[..., n:]
 
-    def __call__(self, x: mx.array) -> mx.array:
-        x = mx.distributed.all_sum(x, group=self._group)
-        return cast(Callable[[mx.array], mx.array], self.inner)(x)
+        x = self_any.activation(x_up, x_gate)
+        x = self_any.down_proj(x, idx, sorted_indices=do_sort)
+        if do_sort:
+            x = x[inv_order]
+            x = mx.unflatten(x, 0, indices.shape[:-1])
+        return x.squeeze(-2)
 
 
-def _shard_v4_attention_heads(
-    attn: V4Attention,
-    world_size: int,
-    rank: int,
-) -> None:
-    """Interleaved-per-group head sharding for V4Attention.
-
-    V4 uses a grouped low-rank output projection: `_grouped_output_projection`
-    reshapes the flat `n_heads * head_dim` dim into `(o_groups, heads_per_group,
-    head_dim)`, so group g owns heads `[g * heads_per_group : (g+1) * heads_per_group]`.
-
-    A naive contiguous `shard_linear("all-to-sharded")` on wq_b puts whole
-    original groups on each rank — the per-rank "group g" ends up containing
-    heads that don't belong to original group g. That breaks the wo_a grouped
-    weight mapping. We instead slice heads interleaved-by-group: each rank
-    owns `heads_per_group / N` heads *from every original group*, kept in
-    group-major order so SDPA → reshape → wo_a preserves the group mapping.
-
-    Affects `wq_b.weight` / `wq_b.bias`, `attn_sink`. wo_a is sharded via a
-    normal input-dim block split (the default axis-(-1) behavior of
-    shard_inplace), which now correctly aligns with the interleaved head
-    layout because the last dim of out after reshape is `heads_per_group/N *
-    head_dim` per group.
+
+def _install_fused_switch_glu(switch_mlp: nn.Module) -> None:
+    """Pre-concatenate up_proj + gate_proj weights on `switch_mlp` for a
+    single ``gather_qmm`` at forward time. Rebinds the instance to
+    :class:`_FusedSwitchGLU` so its ``__call__`` uses the fused path.
+
+    Must be called after tensor-parallel sharding of gate_proj/up_proj —
+    output-dim axis is already ``moe_intermediate_size / N`` per rank.
+    Concat is along that (local) output axis.
     """
-    n_heads: int = attn.n_heads
-    head_dim: int = attn.head_dim
-    o_groups: int = attn.n_groups
-    assert n_heads % o_groups == 0, "n_heads must be divisible by o_groups"
-    heads_per_group = n_heads // o_groups
-    assert heads_per_group % world_size == 0, (
-        f"heads_per_group ({heads_per_group}) must be divisible by world_size "
-        f"({world_size}) for interleaved per-group head sharding"
+    sm: Any = switch_mlp
+    gp: Any = sm.gate_proj
+    up: Any = sm.up_proj
+    gp_bits = getattr(gp, "bits", None)
+    up_bits = getattr(up, "bits", None)
+    gp_group = getattr(gp, "group_size", None)
+    up_group = getattr(up, "group_size", None)
+    assert gp_bits is not None and gp_bits == up_bits, \
+        f"gate/up bits mismatch: {gp_bits} vs {up_bits}"
+    assert gp_group is not None and gp_group == up_group, \
+        f"gate/up group_size mismatch: {gp_group} vs {up_group}"
+    gp_mode = getattr(gp, "mode", "affine")
+    up_mode = getattr(up, "mode", "affine")
+    assert gp_mode == up_mode, f"gate/up mode mismatch: {gp_mode} vs {up_mode}"
+
+    gp_w: mx.array = gp["weight"]
+    gp_s: mx.array = gp["scales"]
+    up_w: mx.array = up["weight"]
+    up_s: mx.array = up["scales"]
+    gp_b = gp.get("biases") if hasattr(gp, "get") else getattr(gp, "biases", None)
+    up_b = up.get("biases") if hasattr(up, "get") else getattr(up, "biases", None)
+
+    fused_w: mx.array = mx.concatenate([up_w, gp_w], axis=1)
+    fused_s: mx.array = mx.concatenate([up_s, gp_s], axis=1)
+    fused_b: mx.array | None = (
+        mx.concatenate([up_b, gp_b], axis=1)
+        if gp_b is not None and up_b is not None
+        else None
     )
-    hpg_per_rank = heads_per_group // world_size
-    start = rank * hpg_per_rank
-    end = start + hpg_per_rank
-
-    def _slice_head_major_flat(arr: mx.array, stride: int) -> mx.array:
-        """Slice arr on axis 0 where the flat 0-axis is (o_groups *
-        heads_per_group * stride), returning a fresh contiguous allocation
-        so the full unsharded array can be freed. Without the contiguous
-        copy the slice is a view and the original weight stays resident —
-        OOM on large V4. Quantized packed weights don't round-trip through
-        numpy so we use mx.contiguous directly."""
-        rest = arr.shape[1:]
-        reshaped = arr.reshape(o_groups, heads_per_group, stride, *rest)
-        sliced = reshaped[:, start:end].reshape(o_groups * hpg_per_rank * stride, *rest)
-        detached = mx.contiguous(sliced)
-        mx.eval(detached)
-        return detached
-
-    wq_b: nn.Module = attn.wq_b
-    if isinstance(wq_b, nn.QuantizedLinear):
-        # Packed weight: (n_heads*head_dim, q_lora_rank/el_per_int).
-        # scales/biases: (n_heads*head_dim, q_lora_rank/group_size).
-        # Slice axis 0 interleaved-by-group with head_dim stride.
-        _shard_quantized_rows(wq_b, head_dim, _slice_head_major_flat)
-    else:
-        dense = wq_b
-        assert isinstance(dense, nn.Linear)
-        w = dense.weight
-        q_lora_rank = w.shape[-1]
-        w_sharded = _slice_head_major_flat(w, head_dim)
-        has_bias = "bias" in dense
-        new_wq_b = nn.Linear(q_lora_rank, w_sharded.shape[0], bias=has_bias)
-        new_wq_b.weight = w_sharded
-        if has_bias:
-            b = dense.bias
-            assert b is not None
-            new_wq_b.bias = _slice_head_major_flat(b[:, None], head_dim).reshape(-1)
-        attn.wq_b = new_wq_b
-
-    sink = attn.attn_sink
-    reshaped = sink.reshape(o_groups, heads_per_group)[:, start:end].reshape(-1)
-    detached_sink = mx.contiguous(reshaped)
-    mx.eval(detached_sink)
-    attn.attn_sink = detached_sink
-    attn.n_heads = o_groups * hpg_per_rank
+    mx.eval(fused_w, fused_s)
+    if fused_b is not None:
+        mx.eval(fused_b)
+
+    sm._fused_w_gu = fused_w
+    sm._fused_s_gu = fused_s
+    sm._fused_b_gu = fused_b
+    sm._fused_n_inter = int(up_w.shape[1])
+    sm._fused_group_size = int(gp_group)
+    sm._fused_bits = int(gp_bits)
+    sm._fused_mode = gp_mode
+
+    # Free the now-redundant originals — gate_proj + up_proj + fused
+    # together would triple the MoE weight footprint per layer. After
+    # the __class__ rebind _FusedSwitchGLU only references self.down_proj.
+    sm.gate_proj = nn.Module()
+    sm.up_proj = nn.Module()
+
+    switch_mlp.__class__ = _FusedSwitchGLU
 
 
 class DeepseekV4ShardingStrategy(TensorParallelShardingStrategy):
+    """Sharding for DeepSeek V4 Flash / Pro — MoE-only, attention replicated.
+
+    DSv4's V4Attention uses a LoRA-decomposed Q/output projection plus a
+    ``_grouped_output_projection`` that manually reshapes
+    ``wo_a.weight/.scales/.biases`` — head-parallel weight slicing of
+    ``wq_b`` makes that manual reshape see half the per-group input dim,
+    producing arithmetically incorrect activations. To keep model math
+    intact we replicate attention on every rank and shard only the MoE
+    block. Cross-rank reduction for the MoE happens via ``ShardedMoEV4``,
+    which all-sums the MoE output.
+
+    Memory footprint at 4-bit on a 158B-total / 13B-active model:
+    - Attention is ~30 GB across 43 layers (replicated on every rank).
+    - MoE bulk is ~130 GB; sharded across N ranks ⇒ ~130 / N GB / rank.
+    - Total at N=2 ⇒ ~95 GB / rank — comfortable on 128 GB nodes.
+    """
+
     def shard_model(
         self,
         model: nn.Module,
@@ -915,11 +941,6 @@ def shard_model(
         for i, layer in enumerate(model.layers):
             mx.eval(layer.parameters())
 
-            # Head-parallel attention with interleaved-per-group sharding.
-            _shard_v4_attention_heads(layer.attn, self.N, self.group.rank())
-            self.sharded_to_all_linear_in_place(layer.attn.wo_a)
-            layer.attn.wo_b = _AllSumLinear(layer.attn.wo_b, self.group)  # type: ignore[assignment]
-
             ffn = layer.ffn
             if getattr(ffn, "shared_experts", None) is not None:
                 self.all_to_sharded_linear_in_place(ffn.shared_experts.gate_proj)
@@ -928,6 +949,13 @@
             self.all_to_sharded_linear_in_place(ffn.switch_mlp.gate_proj)
             self.sharded_to_all_linear_in_place(ffn.switch_mlp.down_proj)
             self.all_to_sharded_linear_in_place(ffn.switch_mlp.up_proj)
+
+            # Optionally fuse gate+up into a single gather_qmm dispatch.
+            # Saves 43 dispatches per decode token; off by default until
+            # opted in via EXO_DSV4_FUSED_MOE=1.
+            if _DSV4_FUSED_MOE:
+                _install_fused_switch_glu(ffn.switch_mlp)
+
             wrapped = ShardedMoEV4(ffn)
             wrapped.sharding_group = self.group
            layer.ffn = wrapped  # type: ignore[assignment]
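Reviewer note (not part of the patch): the fused path hinges on the `[up, gate]` concat-and-split order described in the `_FusedSwitchGLU` docstring. Below is a minimal standalone sketch of that order, using plain float matmuls and a stock SiLU in place of the quantized `mx.gather_qmm` call and DSv4's custom activation; every size and name in it is made up for illustration.

```python
# Illustrative only: float weights and nn.silu stand in for the quantized
# expert weights and DSv4's _DSV4SwiGLU; sizes are arbitrary.
import mlx.core as mx
import mlx.nn as nn

d_model, d_inter = 8, 16
x = mx.random.normal((1, d_model))
w_gate = mx.random.normal((d_inter, d_model))
w_up = mx.random.normal((d_inter, d_model))

# Reference: two separate projections combined SwiGLU-style.
ref = nn.silu(x @ w_gate.T) * (x @ w_up.T)

# Fused: concatenate along the output axis in [up, gate] order, run one
# matmul, then split the halves back out (mirrors gu[..., :n] / gu[..., n:]).
w_gu = mx.concatenate([w_up, w_gate], axis=0)  # axis 0 = output dim of a 2-D (out, in) weight
gu = x @ w_gu.T
x_up, x_gate = gu[..., :d_inter], gu[..., d_inter:]
fused = nn.silu(x_gate) * x_up

assert mx.allclose(ref, fused).item()
```

In `_install_fused_switch_glu` itself the concat runs along `axis=1` rather than `axis=0`, presumably because the quantized expert weights carry a leading experts dimension; that is also why the split point `_fused_n_inter` is taken from `up_w.shape[1]`. The path stays off unless `EXO_DSV4_FUSED_MOE=1` is exported before the worker starts.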