Skip to content

Commit cc34a32

Browse files
committed
Merge commit '0056b981' into teddy/jax-triton-permutation
To eliminate pytorch failures
2 parents 0f1e719 + 0056b98 commit cc34a32

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

transformer_engine/pytorch/triton/permutation.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ def make_row_id_map(
7272
# [0, 0, 0, r, r, r, r]]
7373
_row_id_map_pass_1_kernel[grid](
7474
routing_map,
75-
row_id_map,
76-
workspace_tensor,
7775
num_tokens,
7876
routing_map.stride(0),
7977
routing_map.stride(1),
8078
row_id_map.stride(0),
8179
row_id_map.stride(1),
80+
row_id_map,
81+
workspace_tensor,
8282
block_size,
8383
)
8484

@@ -110,9 +110,9 @@ def make_row_id_map(
110110
grid = (num_tokens,)
111111
_row_id_map_pass_3_kernel[grid](
112112
row_id_map,
113-
num_experts,
114113
row_id_map.stride(0),
115114
row_id_map.stride(1),
115+
num_experts,
116116
triton.next_power_of_2(num_experts),
117117
)
118118
return row_id_map
@@ -169,14 +169,10 @@ def permute_with_mask_map(
169169
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
170170
_permute_kernel[grid](
171171
inp,
172-
output,
173172
row_id_map,
174173
probs,
175174
scale,
176-
permuted_probs,
177175
permuted_scale,
178-
num_experts,
179-
hidden_size,
180176
scale_hidden_dim,
181177
row_id_map.stride(0),
182178
row_id_map.stride(1),
@@ -191,6 +187,10 @@ def permute_with_mask_map(
191187
permuted_probs.stride(0) if permuted_probs is not None else None,
192188
permuted_scale.stride(0) if permuted_scale is not None else None,
193189
permuted_scale.stride(1) if permuted_scale is not None else None,
190+
output,
191+
permuted_probs,
192+
num_experts,
193+
hidden_size,
194194
PERMUTE_PROBS=probs is not None,
195195
PERMUTE_SCALE=scale is not None,
196196
)
@@ -238,13 +238,9 @@ def unpermute_with_mask_map(
238238
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
239239
_unpermute_kernel[grid](
240240
inp,
241-
output,
242241
row_id_map,
243242
merging_probs,
244243
permuted_probs,
245-
unpermuted_probs,
246-
num_experts,
247-
hidden_size,
248244
row_id_map.stride(0),
249245
row_id_map.stride(1),
250246
inp.stride(0),
@@ -256,6 +252,10 @@ def unpermute_with_mask_map(
256252
permuted_probs.stride(0) if permuted_probs is not None else None,
257253
unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
258254
unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
255+
output,
256+
unpermuted_probs,
257+
num_experts,
258+
hidden_size,
259259
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
260260
WITH_MERGING_PROBS=merging_probs is not None,
261261
PERMUTE_PROBS=permuted_probs is not None,
@@ -395,17 +395,17 @@ def sort_chunks_by_map(
395395
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
396396
_sort_chunks_by_map_kernel[grid](
397397
inp,
398-
output,
399398
row_id_map,
400399
probs,
401-
permuted_probs,
402-
hidden_size,
403400
inp.stride(0),
404401
inp.stride(1),
405402
output.stride(0),
406403
output.stride(1),
407404
probs.stride(0) if probs is not None else None,
408405
permuted_probs.stride(0) if permuted_probs is not None else None,
406+
output,
407+
permuted_probs,
408+
hidden_size,
409409
PERMUTE_PROBS=probs is not None,
410410
FORWARD=is_forward,
411411
)

0 commit comments

Comments
 (0)