We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e7f81c2, commit cd5bb43 (Copy full SHA for cd5bb43)
transformer_engine/jax/triton/permutation.py
@@ -177,6 +177,7 @@ def permute_with_mask_map(
177
permuted_probs : Optional[jnp.ndarray]
178
Permuted probabilities if probs was provided, None otherwise.
179
"""
180
+
181
# one block per token, multiple blocks for hidden dimension
182
def grid_fn(meta):
183
return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))
@@ -225,8 +226,8 @@ def grid_fn(meta):
225
226
inp,
227
row_id_map,
228
probs,
- dummy_scale, # scale
229
- dummy_permuted_scale, # permuted_scale
+ dummy_scale, # scale
230
+ dummy_permuted_scale, # permuted_scale
231
0,
232
row_id_stride_token,
233
row_id_stride_expert,
0 commit comments