44
55import mlx .core as mx
66import mlx .nn as nn
7+ from mlx .nn .layers .distributed import shard_inplace , shard_linear , sum_gradients
78
89from .activations import swiglu
910from .base import BaseModelArgs , create_attention_mask , scaled_dot_product_attention
@@ -238,8 +239,11 @@ def __init__(self, args: ModelArgs):
238239 )
239240
240241 self .router = LongcatFlashTopkRouter (args )
242+ self .sharding_group = None
241243
242244 def __call__ (self , hidden_states ):
245+ if self .sharding_group is not None :
246+ hidden_states = sum_gradients (self .sharding_group )(hidden_states )
243247
244248 topk_indices , topk_weights = self .router (hidden_states )
245249
@@ -251,14 +255,20 @@ def __call__(self, hidden_states):
251255 regular_outputs = self .switch_mlp (hidden_states , topk_indices )
252256
253257 weighted_outputs = regular_outputs * regular_weights [..., None ]
258+ final_output = mx .sum (weighted_outputs , axis = - 2 )
259+
260+ if self .sharding_group is not None :
261+ final_output = mx .distributed .all_sum (
262+ final_output , group = self .sharding_group
263+ )
254264
255- # Add identity expert contribution if needed
265+ # Add identity expert contribution after all_sum to avoid summing it N times
256266 assert self .zero_expert_type == "identity"
257- identity_weights = mx .where (mask , topk_weights , 0.0 )
258- identity_outputs = hidden_states [..., None , :] * identity_weights [..., None ]
259- weighted_outputs = weighted_outputs + identity_outputs
267+ identity_weights_sum = mx .sum (
268+ mx .where (mask , topk_weights , 0.0 ), axis = - 1 , keepdims = True
269+ )
270+ final_output = final_output + hidden_states * identity_weights_sum
260271
261- final_output = mx .sum (weighted_outputs , axis = - 2 )
262272 return final_output
263273
264274
@@ -394,3 +404,37 @@ def sanitize(self, weights):
394404
395405 def make_cache (self ):
396406 return [CacheList (KVCache (), KVCache ()) for _ in self .model .layers ]
407+
408+ def shard (self , group : Optional [mx .distributed .Group ] = None ):
409+ group = group or mx .distributed .init ()
410+ N = group .size ()
411+
412+ for layer in self .model .layers :
413+ for attn in layer .self_attn :
414+ if attn .q_lora_rank is None :
415+ attn .q_proj = shard_linear (
416+ attn .q_proj , "all-to-sharded" , group = group
417+ )
418+ else :
419+ attn .q_b_proj = shard_linear (
420+ attn .q_b_proj , "all-to-sharded" , group = group
421+ )
422+ attn .kv_b_proj = shard_linear (
423+ attn .kv_b_proj , "all-to-sharded" , group = group
424+ )
425+ attn .o_proj = shard_linear (attn .o_proj , "sharded-to-all" , group = group )
426+ attn .num_attention_heads //= N
427+
428+ for mlp in layer .mlps :
429+ mlp .gate_proj = shard_linear (
430+ mlp .gate_proj , "all-to-sharded" , group = group
431+ )
432+ mlp .up_proj = shard_linear (mlp .up_proj , "all-to-sharded" , group = group )
433+ mlp .down_proj = shard_linear (
434+ mlp .down_proj , "sharded-to-all" , group = group
435+ )
436+
437+ layer .mlp .sharding_group = group
438+ shard_inplace (layer .mlp .switch_mlp .gate_proj , "all-to-sharded" , group = group )
439+ shard_inplace (layer .mlp .switch_mlp .up_proj , "all-to-sharded" , group = group )
440+ shard_inplace (layer .mlp .switch_mlp .down_proj , "sharded-to-all" , group = group )
0 commit comments