Skip to content

Commit e65e46c

Browse files
committed
try to support etp
1 parent 4c213e6 commit e65e46c

File tree

1 file changed

+3
-0
lines changed
  • python/sgl_jax/srt/layers

1 file changed

+3
-0
lines changed

python/sgl_jax/srt/layers/moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,9 @@ def _gmm_compute_with_sharded_weights(
380380
tiling=tiling_down,
381381
)
382382

383+
if self.tp_size > 1:
384+
intermediate_output = jax.lax.psum(intermediate_output, "tensor")
385+
383386
return intermediate_output
384387

385388
def _expert_all_to_all_dispatch(self, data, sorted_experts, expert_shard_id):

0 commit comments

Comments
 (0)