264 changes: 146 additions & 118 deletions src/exo/worker/engines/mlx/auto_parallel.py
@@ -1,8 +1,9 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Literal, Protocol, cast
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast

import mlx.core as mx
import mlx.nn as nn
@@ -17,7 +18,7 @@
from mlx_lm.models.cache import ArraysCache, KVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v4 import DeepseekV4MoE, V4Attention
from mlx_lm.models.deepseek_v4 import DeepseekV4MoE
from mlx_lm.models.deepseek_v4 import Model as DeepseekV4Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
@@ -784,127 +785,152 @@ def __call__(self, x: mx.array, input_ids: mx.array) -> mx.array:
return y


def _shard_quantized_rows(
q: nn.QuantizedLinear,
head_dim: int,
slicer: Callable[[mx.array, int], mx.array],
) -> None:
weight = q["weight"]
scales = q["scales"]
assert isinstance(weight, mx.array)
assert isinstance(scales, mx.array)
q.weight = slicer(weight, head_dim)
q.scales = slicer(scales, head_dim)
biases = q.get("biases")
if isinstance(biases, mx.array):
q.biases = slicer(biases, head_dim)


class _AllSumLinear(nn.Module):
"""Wraps an unsharded wo_b that takes a head-sharded partial wo_a output.

Flow per rank:
1. all_sum the incoming partial wo_a output (summed across the head
input shards → full wo_a_out on every rank)
2. apply the unsharded wo_b → full hidden on every rank

One collective per layer on the smaller of (n_groups * o_lora_rank) vs
hidden. wo_b compute is replicated, but at decode B=1 it's only ~30M FLOPs
per layer and 61 extra all_gathers/token cost more than running wo_b on
every rank.
# Off by default: opt-in fused MoE gate+up dispatch for DSv4. Saves one
# Metal dispatch per decoder layer per decode token (43 per forward) at
# ~100-200 µs each on M4 Max — bench-validated +1.2% c=1 / +1.1% c=2 on
# `mlx-community/DeepSeek-V4-Flash-6bit` (2× M4 Max RDMA, MlxJaccl).
_DSV4_FUSED_MOE: bool = os.environ.get("EXO_DSV4_FUSED_MOE", "0") == "1"
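
Usage note (not part of the diff): the flag is read once at import time, so it must be in the environment before this module is imported. A minimal sketch, assuming the module path mirrors the file path above:

import os
os.environ["EXO_DSV4_FUSED_MOE"] = "1"  # must be set before the import below

from exo.worker.engines.mlx import auto_parallel  # assumed module path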


class _FusedSwitchGLU(nn.Module):
"""Drop-in SwitchGLU replacement that fuses gate_proj + up_proj into a
single ``mx.gather_qmm`` dispatch.

SwitchGLU's stock ``__call__`` runs two ``gather_qmm``s for gate and up
(plus one for down). At DSv4's 43 decoder layers that's 43 extra
Metal dispatches per decode token — each with ~100-200 µs of dispatch
+ sync overhead on the RDMA cluster. Concatenating gate and up
weights along the output axis lets a single ``gather_qmm`` produce
both halves; we split, apply the original ``self.activation`` (so
custom activations like DSv4's ``_DSV4SwiGLU(swiglu_limit)`` are
preserved), then run down_proj unchanged.

Uses ``__class__ = _FusedSwitchGLU`` rebind (preserves all attributes
of the pre-quantized / post-sharded SwitchGLU instance — we only
override ``__call__``).

Concat order in the fused weight is ``[up, gate]`` to match
SwitchGLU's call sequence ``self.activation(x_up, x_gate)``.
"""

def __init__(self, inner: nn.Module, group: mx.distributed.Group):
super().__init__()
self.inner = inner
self._group = group
sort_threshold: int = 8

def __call__(self, x: mx.array, indices: mx.array) -> mx.array: # type: ignore[override]
self_any: Any = self

x = mx.expand_dims(x, (-2, -3))
do_sort = indices.size >= self.sort_threshold
idx: Any = indices
inv_order: Any = None
if do_sort:
flat_indices = indices.flatten()
order = mx.argsort(flat_indices)
inv_order = mx.argsort(order)
x = x.flatten(0, -3)[order // indices.shape[-1]]
idx = flat_indices[order]

gu: Any = mx.gather_qmm(
x,
self_any._fused_w_gu,
self_any._fused_s_gu,
self_any._fused_b_gu,
rhs_indices=idx,
transpose=True,
group_size=self_any._fused_group_size,
bits=self_any._fused_bits,
mode=self_any._fused_mode,
sorted_indices=do_sort,
)
n: int = self_any._fused_n_inter
x_up = gu[..., :n]
x_gate = gu[..., n:]

def __call__(self, x: mx.array) -> mx.array:
x = mx.distributed.all_sum(x, group=self._group)
return cast(Callable[[mx.array], mx.array], self.inner)(x)
x = self_any.activation(x_up, x_gate)

x = self_any.down_proj(x, idx, sorted_indices=do_sort)
if do_sort:
x = x[inv_order]
x = mx.unflatten(x, 0, indices.shape[:-1])
return x.squeeze(-2)
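
For intuition, a minimal unquantized sketch of the fusion the class implements, with illustrative shapes; the real path concatenates packed quantized weights along axis 1 (SwitchGLU weights carry a leading experts axis) and dispatches via mx.gather_qmm:

import mlx.core as mx

x = mx.random.normal((1, 7168))          # hidden vector (size illustrative)
w_up = mx.random.normal((2048, 7168))    # up_proj weight, (out, in)
w_gate = mx.random.normal((2048, 7168))  # gate_proj weight, (out, in)

# Two dispatches: one matmul per projection.
up_ref = x @ w_up.T
gate_ref = x @ w_gate.T

# One dispatch: concatenate along the output axis, split after the matmul.
# [up, gate] order matches the fused weight layout described in the docstring.
w_fused = mx.concatenate([w_up, w_gate], axis=0)
fused = x @ w_fused.T
up_one, gate_one = fused[..., :2048], fused[..., 2048:]

assert mx.allclose(up_ref, up_one).item()
assert mx.allclose(gate_ref, gate_one).item()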

def _shard_v4_attention_heads(
attn: V4Attention,
world_size: int,
rank: int,
) -> None:
"""Interleaved-per-group head sharding for V4Attention.

V4 uses a grouped low-rank output projection: `_grouped_output_projection`
reshapes the flat `n_heads * head_dim` dim into `(o_groups, heads_per_group,
head_dim)`, so group g owns heads `[g * heads_per_group : (g+1) * heads_per_group]`.

A naive contiguous `shard_linear("all-to-sharded")` on wq_b puts whole
original groups on each rank — the per-rank "group g" ends up containing
heads that don't belong to original group g. That breaks the wo_a grouped
weight mapping. We instead slice heads interleaved-by-group: each rank
owns `heads_per_group / N` heads *from every original group*, kept in
group-major order so SDPA → reshape → wo_a preserves the group mapping.

Affects `wq_b.weight` / `wq_b.bias`, `attn_sink`. wo_a is sharded via a
normal input-dim block split (the default axis-(-1) behavior of
shard_inplace), which now correctly aligns with the interleaved head
layout because the last dim of out after reshape is `heads_per_group/N *
head_dim` per group.

def _install_fused_switch_glu(switch_mlp: nn.Module) -> None:
"""Pre-concatenate up_proj + gate_proj weights on `switch_mlp` for a
single ``gather_qmm`` at forward time. Rebinds the instance to
:class:`_FusedSwitchGLU` so its ``__call__`` uses the fused path.

Must be called after tensor-parallel sharding of gate_proj/up_proj —
output-dim axis is already ``moe_intermediate_size / N`` per rank.
Concat is along that (local) output axis.
"""
n_heads: int = attn.n_heads
head_dim: int = attn.head_dim
o_groups: int = attn.n_groups
assert n_heads % o_groups == 0, "n_heads must be divisible by o_groups"
heads_per_group = n_heads // o_groups
assert heads_per_group % world_size == 0, (
f"heads_per_group ({heads_per_group}) must be divisible by world_size "
f"({world_size}) for interleaved per-group head sharding"
sm: Any = switch_mlp
gp: Any = sm.gate_proj
up: Any = sm.up_proj
gp_bits = getattr(gp, "bits", None)
up_bits = getattr(up, "bits", None)
gp_group = getattr(gp, "group_size", None)
up_group = getattr(up, "group_size", None)
assert gp_bits is not None and gp_bits == up_bits, \
f"gate/up bits mismatch: {gp_bits} vs {up_bits}"
assert gp_group is not None and gp_group == up_group, \
f"gate/up group_size mismatch: {gp_group} vs {up_group}"
gp_mode = getattr(gp, "mode", "affine")
up_mode = getattr(up, "mode", "affine")
assert gp_mode == up_mode, f"gate/up mode mismatch: {gp_mode} vs {up_mode}"

gp_w: mx.array = gp["weight"]
gp_s: mx.array = gp["scales"]
up_w: mx.array = up["weight"]
up_s: mx.array = up["scales"]
gp_b = gp.get("biases") if hasattr(gp, "get") else getattr(gp, "biases", None)
up_b = up.get("biases") if hasattr(up, "get") else getattr(up, "biases", None)

fused_w: mx.array = mx.concatenate([up_w, gp_w], axis=1)
fused_s: mx.array = mx.concatenate([up_s, gp_s], axis=1)
fused_b: mx.array | None = (
mx.concatenate([up_b, gp_b], axis=1)
if gp_b is not None and up_b is not None
else None
)
hpg_per_rank = heads_per_group // world_size
start = rank * hpg_per_rank
end = start + hpg_per_rank

def _slice_head_major_flat(arr: mx.array, stride: int) -> mx.array:
"""Slice arr on axis 0 where the flat 0-axis is (o_groups *
heads_per_group * stride), returning a fresh contiguous allocation
so the full unsharded array can be freed. Without the contiguous
copy the slice is a view and the original weight stays resident —
OOM on large V4. Quantized packed weights don't round-trip through
numpy so we use mx.contiguous directly."""
rest = arr.shape[1:]
reshaped = arr.reshape(o_groups, heads_per_group, stride, *rest)
sliced = reshaped[:, start:end].reshape(o_groups * hpg_per_rank * stride, *rest)
detached = mx.contiguous(sliced)
mx.eval(detached)
return detached

wq_b: nn.Module = attn.wq_b
if isinstance(wq_b, nn.QuantizedLinear):
# Packed weight: (n_heads*head_dim, q_lora_rank/el_per_int).
# scales/biases: (n_heads*head_dim, q_lora_rank/group_size).
# Slice axis 0 interleaved-by-group with head_dim stride.
_shard_quantized_rows(wq_b, head_dim, _slice_head_major_flat)
else:
dense = wq_b
assert isinstance(dense, nn.Linear)
w = dense.weight
q_lora_rank = w.shape[-1]
w_sharded = _slice_head_major_flat(w, head_dim)
has_bias = "bias" in dense
new_wq_b = nn.Linear(q_lora_rank, w_sharded.shape[0], bias=has_bias)
new_wq_b.weight = w_sharded
if has_bias:
b = dense.bias
assert b is not None
new_wq_b.bias = _slice_head_major_flat(b[:, None], head_dim).reshape(-1)
attn.wq_b = new_wq_b

sink = attn.attn_sink
reshaped = sink.reshape(o_groups, heads_per_group)[:, start:end].reshape(-1)
detached_sink = mx.contiguous(reshaped)
mx.eval(detached_sink)
attn.attn_sink = detached_sink
attn.n_heads = o_groups * hpg_per_rank
mx.eval(fused_w, fused_s)
if fused_b is not None:
mx.eval(fused_b)

sm._fused_w_gu = fused_w
sm._fused_s_gu = fused_s
sm._fused_b_gu = fused_b
sm._fused_n_inter = int(up_w.shape[1])
sm._fused_group_size = int(gp_group)
sm._fused_bits = int(gp_bits)
sm._fused_mode = gp_mode

# Free the now-redundant originals: keeping gate_proj and up_proj alongside
# the fused copy would double the gate/up weight footprint per layer. After
# the __class__ rebind _FusedSwitchGLU only references self.down_proj.
sm.gate_proj = nn.Module()
sm.up_proj = nn.Module()

switch_mlp.__class__ = _FusedSwitchGLU


class DeepseekV4ShardingStrategy(TensorParallelShardingStrategy):
"""Sharding for DeepSeek V4 Flash / Pro — MoE-only, attention replicated.

DSv4's V4Attention uses a LoRA-decomposed Q/output projection plus a
``_grouped_output_projection`` that manually reshapes
``wo_a.weight/.scales/.biases`` — head-parallel weight slicing of
``wq_b`` makes that manual reshape see half the per-group input dim,
producing arithmetically incorrect activations. To keep model math
intact we replicate attention on every rank and shard only the MoE
block. Cross-rank reduction for the MoE happens via ``ShardedMoEV4``,
which all-sums the MoE output.

Memory footprint at 4-bit on a 158B-total / 13B-active model:
- Attention is ~30 GB across 43 layers (replicated on every rank).
- MoE bulk is ~130 GB; sharded across N ranks ⇒ ~130 / N GB / rank.
- Total at N=2 ⇒ ~95 GB / rank — comfortable on 128 GB nodes.
"""

def shard_model(
self,
model: nn.Module,
@@ -915,11 +941,6 @@ def shard_model(
for i, layer in enumerate(model.layers):
mx.eval(layer.parameters())

# Head-parallel attention with interleaved-per-group sharding.
_shard_v4_attention_heads(layer.attn, self.N, self.group.rank())
self.sharded_to_all_linear_in_place(layer.attn.wo_a)
layer.attn.wo_b = _AllSumLinear(layer.attn.wo_b, self.group) # type: ignore[assignment]

ffn = layer.ffn
if getattr(ffn, "shared_experts", None) is not None:
self.all_to_sharded_linear_in_place(ffn.shared_experts.gate_proj)
@@ -928,6 +949,13 @@ def shard_model(
self.all_to_sharded_linear_in_place(ffn.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(ffn.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(ffn.switch_mlp.up_proj)

# Optionally fuse gate+up into a single gather_qmm dispatch.
# Saves 43 dispatches per decode token; off by default until
# opted in via EXO_DSV4_FUSED_MOE=1.
if _DSV4_FUSED_MOE:
_install_fused_switch_glu(ffn.switch_mlp)

wrapped = ShardedMoEV4(ffn)
wrapped.sharding_group = self.group
layer.ffn = wrapped # type: ignore[assignment]