We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4c213e6, commit e65e46c (Copy full SHA for e65e46c)
python/sgl_jax/srt/layers/moe.py
@@ -380,6 +380,9 @@ def _gmm_compute_with_sharded_weights(
380              tiling=tiling_down,
381          )
382
383 +        if self.tp_size > 1:
384 +            intermediate_output = jax.lax.psum(intermediate_output, "tensor")
385 +
386          return intermediate_output
387
388      def _expert_all_to_all_dispatch(self, data, sorted_experts, expert_shard_id):
0 commit comments