Skip to content

Commit 8daabcc

Browse files
authored
Shard LongCat Flash (ml-explore#771)
1 parent cd7d9a5 commit 8daabcc

1 file changed

Lines changed: 49 additions & 5 deletions

File tree

mlx_lm/models/longcat_flash.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import mlx.core as mx
66
import mlx.nn as nn
7+
from mlx.nn.layers.distributed import shard_inplace, shard_linear, sum_gradients
78

89
from .activations import swiglu
910
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@@ -238,8 +239,11 @@ def __init__(self, args: ModelArgs):
238239
)
239240

240241
self.router = LongcatFlashTopkRouter(args)
242+
self.sharding_group = None
241243

242244
def __call__(self, hidden_states):
245+
if self.sharding_group is not None:
246+
hidden_states = sum_gradients(self.sharding_group)(hidden_states)
243247

244248
topk_indices, topk_weights = self.router(hidden_states)
245249

@@ -251,14 +255,20 @@ def __call__(self, hidden_states):
251255
regular_outputs = self.switch_mlp(hidden_states, topk_indices)
252256

253257
weighted_outputs = regular_outputs * regular_weights[..., None]
258+
final_output = mx.sum(weighted_outputs, axis=-2)
259+
260+
if self.sharding_group is not None:
261+
final_output = mx.distributed.all_sum(
262+
final_output, group=self.sharding_group
263+
)
254264

255-
# Add identity expert contribution if needed
265+
# Add identity expert contribution after all_sum to avoid summing it N times
256266
assert self.zero_expert_type == "identity"
257-
identity_weights = mx.where(mask, topk_weights, 0.0)
258-
identity_outputs = hidden_states[..., None, :] * identity_weights[..., None]
259-
weighted_outputs = weighted_outputs + identity_outputs
267+
identity_weights_sum = mx.sum(
268+
mx.where(mask, topk_weights, 0.0), axis=-1, keepdims=True
269+
)
270+
final_output = final_output + hidden_states * identity_weights_sum
260271

261-
final_output = mx.sum(weighted_outputs, axis=-2)
262272
return final_output
263273

264274

@@ -394,3 +404,37 @@ def sanitize(self, weights):
394404

395405
def make_cache(self):
396406
return [CacheList(KVCache(), KVCache()) for _ in self.model.layers]
407+
408+
def shard(self, group: Optional[mx.distributed.Group] = None):
409+
group = group or mx.distributed.init()
410+
N = group.size()
411+
412+
for layer in self.model.layers:
413+
for attn in layer.self_attn:
414+
if attn.q_lora_rank is None:
415+
attn.q_proj = shard_linear(
416+
attn.q_proj, "all-to-sharded", group=group
417+
)
418+
else:
419+
attn.q_b_proj = shard_linear(
420+
attn.q_b_proj, "all-to-sharded", group=group
421+
)
422+
attn.kv_b_proj = shard_linear(
423+
attn.kv_b_proj, "all-to-sharded", group=group
424+
)
425+
attn.o_proj = shard_linear(attn.o_proj, "sharded-to-all", group=group)
426+
attn.num_attention_heads //= N
427+
428+
for mlp in layer.mlps:
429+
mlp.gate_proj = shard_linear(
430+
mlp.gate_proj, "all-to-sharded", group=group
431+
)
432+
mlp.up_proj = shard_linear(mlp.up_proj, "all-to-sharded", group=group)
433+
mlp.down_proj = shard_linear(
434+
mlp.down_proj, "sharded-to-all", group=group
435+
)
436+
437+
layer.mlp.sharding_group = group
438+
shard_inplace(layer.mlp.switch_mlp.gate_proj, "all-to-sharded", group=group)
439+
shard_inplace(layer.mlp.switch_mlp.up_proj, "all-to-sharded", group=group)
440+
shard_inplace(layer.mlp.switch_mlp.down_proj, "sharded-to-all", group=group)

0 commit comments

Comments
 (0)