@@ -249,6 +249,7 @@ def _ragged_paged_attention_kernel(
249249 page_indices_ref , # [(padded_batch_size * model_context_len + page_size - 1) // page_size]
250250 cu_q_lens_ref , # [padded_batch_size + 1]
251251 cu_kv_lens_ref , # [padded_batch_size + 1]
252+ cu_seq_mask_lens ,
252253 distribution_ref , # [3] (decode_end, prefill_end, mixed_end)
253254 sem_ids_ref , # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
254255 bo_ids_ref , # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
@@ -318,13 +319,6 @@ def _ragged_paged_attention_kernel(
318319 prefill_end = distribution_ref [1 ]
319320 mixed_end = distribution_ref [2 ]
320321
321- kv_lens = cu_kv_lens_ref [1 :] - cu_kv_lens_ref [:- 1 ]
322- q_lens = cu_q_lens_ref [1 :] - cu_q_lens_ref [:- 1 ]
323- seq_mask_lens = kv_lens * q_lens
324- cu_seq_mask_lens = jnp .concatenate (
325- [jnp .array ([0 ], dtype = jnp .int32 ), jnp .cumsum (seq_mask_lens )]
326- )
327-
328322 q_start = cu_q_lens_ref [seq_idx ]
329323 q_end = cu_q_lens_ref [seq_idx + 1 ]
330324 q_len = q_end - q_start
@@ -1337,6 +1331,16 @@ def ragged_paged_attention(
13371331 )
13381332 * 2.4
13391333 )
1334+
1335+ q_lens = cu_q_lens [1 :] - cu_q_lens [:- 1 ]
1336+ seq_mask_lens = kv_lens * q_lens
1337+ cu_seq_mask_lens = jnp .concatenate (
1338+ [jnp .array ([0 ], dtype = jnp .int32 ), jnp .cumsum (seq_mask_lens )]
1339+ )
1340+ if custom_mask is None :
1341+ # Workaround for compiler error: "XLA layout ({0}) does not match Mosaic layout ({0:T(128)}) for an operand of shape s32[0]".
1342+ custom_mask = jnp .empty ((1 , 128 ))
1343+
13401344 grid = (distribution [2 ],)
13411345
13421346 in_specs = [
@@ -1397,6 +1401,7 @@ def ragged_paged_attention(
13971401 page_indices ,
13981402 cu_q_lens ,
13991403 cu_kv_lens ,
1404+ cu_seq_mask_lens ,
14001405 distribution ,
14011406 # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
14021407 jnp .zeros ((3 ,), jnp .int32 ),
@@ -1443,8 +1448,8 @@ def ragged_paged_attention(
14431448 ),
14441449 ],
14451450 input_output_aliases = {
1446- 8 : 0 , # q input -> q output
1447- 10 : 1 , # kv_cache_fused input -> updated kv_cache_fused output
1451+ 9 : 0 , # q input -> q output
1452+ 11 : 1 , # kv_cache_fused input -> updated kv_cache_fused output
14481453 },
14491454 name = scope_name ,
14501455 )
0 commit comments