Skip to content

Commit cd5bb43

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e7f81c2 commit cd5bb43

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

transformer_engine/jax/triton/permutation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def permute_with_mask_map(
177177
permuted_probs : Optional[jnp.ndarray]
178178
Permuted probabilities if probs was provided, None otherwise.
179179
"""
180+
180181
# one block per token, multiple blocks for hidden dimension
181182
def grid_fn(meta):
182183
return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))
@@ -225,8 +226,8 @@ def grid_fn(meta):
225226
inp,
226227
row_id_map,
227228
probs,
228-
dummy_scale, # scale
229-
dummy_permuted_scale, # permuted_scale
229+
dummy_scale, # scale
230+
dummy_permuted_scale, # permuted_scale
230231
0,
231232
row_id_stride_token,
232233
row_id_stride_expert,

0 commit comments

Comments
 (0)