Skip to content

Commit 4c213e6

Browse files
committed
try to support etp
1 parent 0f61aa0 commit 4c213e6

File tree

1 file changed

+1
-3
lines changed
  • python/sgl_jax/srt/layers

1 file changed

+1
-3
lines changed

python/sgl_jax/srt/layers/moe.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,8 @@ def __call__(self, hidden_states, topk_weights, topk_ids):
279279
)
280280

281281
def _forward(self, hidden_states, topk_weights, topk_ids, w0_weights, w1_weights, wo_weights):
282-
data_index = jax.lax.axis_index("data")
283282
tensor_index = jax.lax.axis_index("tensor")
284-
tensor_size = jax.lax.axis_size("tensor")
285-
expert_shard_id = data_index * tensor_size + tensor_index
283+
expert_shard_id = tensor_index
286284

287285
if hidden_states.ndim == 2:
288286
total_tokens = hidden_states.shape[0]

0 commit comments

Comments
 (0)