@@ -72,13 +72,13 @@ def make_row_id_map(
7272 # [0, 0, 0, r, r, r, r]]
7373 _row_id_map_pass_1_kernel [grid ](
7474 routing_map ,
75- row_id_map ,
76- workspace_tensor ,
7775 num_tokens ,
7876 routing_map .stride (0 ),
7977 routing_map .stride (1 ),
8078 row_id_map .stride (0 ),
8179 row_id_map .stride (1 ),
80+ row_id_map ,
81+ workspace_tensor ,
8282 block_size ,
8383 )
8484
@@ -110,9 +110,9 @@ def make_row_id_map(
110110 grid = (num_tokens ,)
111111 _row_id_map_pass_3_kernel [grid ](
112112 row_id_map ,
113- num_experts ,
114113 row_id_map .stride (0 ),
115114 row_id_map .stride (1 ),
115+ num_experts ,
116116 triton .next_power_of_2 (num_experts ),
117117 )
118118 return row_id_map
@@ -169,14 +169,10 @@ def permute_with_mask_map(
169169 grid = lambda META : (num_tokens , triton .cdiv (hidden_size , META ["BLOCK_SIZE" ]))
170170 _permute_kernel [grid ](
171171 inp ,
172- output ,
173172 row_id_map ,
174173 probs ,
175174 scale ,
176- permuted_probs ,
177175 permuted_scale ,
178- num_experts ,
179- hidden_size ,
180176 scale_hidden_dim ,
181177 row_id_map .stride (0 ),
182178 row_id_map .stride (1 ),
@@ -191,6 +187,10 @@ def permute_with_mask_map(
191187 permuted_probs .stride (0 ) if permuted_probs is not None else None ,
192188 permuted_scale .stride (0 ) if permuted_scale is not None else None ,
193189 permuted_scale .stride (1 ) if permuted_scale is not None else None ,
190+ output ,
191+ permuted_probs ,
192+ num_experts ,
193+ hidden_size ,
194194 PERMUTE_PROBS = probs is not None ,
195195 PERMUTE_SCALE = scale is not None ,
196196 )
@@ -238,13 +238,9 @@ def unpermute_with_mask_map(
238238 grid = lambda META : (num_tokens , triton .cdiv (hidden_size , META ["BLOCK_SIZE" ]))
239239 _unpermute_kernel [grid ](
240240 inp ,
241- output ,
242241 row_id_map ,
243242 merging_probs ,
244243 permuted_probs ,
245- unpermuted_probs ,
246- num_experts ,
247- hidden_size ,
248244 row_id_map .stride (0 ),
249245 row_id_map .stride (1 ),
250246 inp .stride (0 ),
@@ -256,6 +252,10 @@ def unpermute_with_mask_map(
256252 permuted_probs .stride (0 ) if permuted_probs is not None else None ,
257253 unpermuted_probs .stride (0 ) if unpermuted_probs is not None else None ,
258254 unpermuted_probs .stride (1 ) if unpermuted_probs is not None else None ,
255+ output ,
256+ unpermuted_probs ,
257+ num_experts ,
258+ hidden_size ,
259259 PROBS_LOAD_WIDTH = triton .next_power_of_2 (num_experts ),
260260 WITH_MERGING_PROBS = merging_probs is not None ,
261261 PERMUTE_PROBS = permuted_probs is not None ,
@@ -395,17 +395,17 @@ def sort_chunks_by_map(
395395 grid = lambda META : (num_tokens , triton .cdiv (hidden_size , META ["BLOCK_SIZE" ]))
396396 _sort_chunks_by_map_kernel [grid ](
397397 inp ,
398- output ,
399398 row_id_map ,
400399 probs ,
401- permuted_probs ,
402- hidden_size ,
403400 inp .stride (0 ),
404401 inp .stride (1 ),
405402 output .stride (0 ),
406403 output .stride (1 ),
407404 probs .stride (0 ) if probs is not None else None ,
408405 permuted_probs .stride (0 ) if permuted_probs is not None else None ,
406+ output ,
407+ permuted_probs ,
408+ hidden_size ,
409409 PERMUTE_PROBS = probs is not None ,
410410 FORWARD = is_forward ,
411411 )
0 commit comments