Skip to content

Commit 6479514

Browse files
committed
Correct setting of permuted_probs_stride_token, unpermuted_probs_stride_token and unpermuted_probs_stride_expert in unpermutation
1 parent 60b74c3 commit 6479514

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

transformer_engine/jax/triton/permutation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,9 @@ def grid_fn(meta):
354354
output_stride_hidden,
355355
merging_probs_stride_token,
356356
merging_probs_stride_expert,
357-
1,
358-
0,
359-
0,
357+
permuted_probs_stride_token,
358+
unpermuted_probs_stride_token,
359+
unpermuted_probs_stride_expert,
360360
kernel=_unpermute_kernel,
361361
out_shape=out_shape,
362362
grid=grid_fn,

0 commit comments

Comments
 (0)