@@ -254,16 +254,15 @@ def _ragged_paged_attention_kernel(
254254 sem_ids_ref , # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
255255 bo_ids_ref , # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
256256 bkv_update_ids_ref , # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
257+ custom_mask_ref , # (flatten_total_kv_len,),
257258 # Input
258259 q_hbm_ref , # [actual_num_kv_heads, padded_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
259260 kv_hbm_ref , # [padded_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim] - Fused KV with interleaved [K1,V1,K2,V2,...]
260261 kv_cache_fused_hbm_ref , # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
261- custom_mask_ref , # (flatten_total_kv_len,), int8, dma not support bool type
262262 # Output
263263 o_hbm_ref , # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
264264 updated_kv_cache_fused_hbm_ref , # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
265265 # Scratch
266- bkvmask_ref , # [2, bq_sz, bkv_sz]
267266 bkv_fused_x2_ref , # [2, bkv_sz, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
268267 bq_x2_ref , # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
269268 bo_x2_ref , # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
@@ -324,54 +323,19 @@ def _ragged_paged_attention_kernel(
324323 q_len = q_end - q_start
325324 kv_len = kv_lens_ref [seq_idx ]
326325
326+ cur_seq_mask_start = cu_seq_mask_lens [seq_idx ]
327+ cur_seq_mask_len = q_len * kv_len
328+ cur_seq_mask = custom_mask_ref [
329+ cur_seq_mask_start : cur_seq_mask_start + cur_seq_mask_len
330+ ].reshape (q_len , kv_len )
331+
def _async_copy(src, dst, sem, wait):
  """Start (or block on) an async DMA copy from `src` to `dst` guarded by `sem`.

  When `wait` is False the copy is only launched; when True we block until the
  previously launched copy on the same semaphore signals completion.
  """
  copy_op = pltpu.make_async_copy(src, dst, sem)
  # Dispatch on the requested phase of the double-buffered transfer.
  (copy_op.wait if wait else copy_op.start)()
334- def _fetch_mask (seq_idx , bq_idx , bkvmask_idx , bkvmask_sem_idx , * , wait = False ):
335- sem = sems .at [4 , bkvmask_sem_idx ]
336- assert sem .dtype == sems .dtype , f"######## { sem .dtype = } { sems .dtype = } "
337- kvmask_fused_vmem_ref = bkvmask_ref .at [bkvmask_sem_idx ]
338-
339- kv_len = kv_lens_ref [seq_idx ]
340- mask_len = kv_len
341- mask_start = bkvmask_idx * bkv_sz
342- mask_left = mask_len - mask_start
343- load_kv_sz = jnp .minimum (bkv_sz , mask_left )
344-
345- q_len_start = cu_q_lens_ref [seq_idx ] + bq_idx * bq_sz
346- q_end = cu_q_lens_ref [seq_idx + 1 ]
347- load_q_sz = jnp .minimum (bq_sz , q_end - q_len_start )
348-
349- cur_seq_mask_start = cu_seq_mask_lens [seq_idx ]
350- cur_bq_mask_start = cur_seq_mask_start + bq_idx * bq_sz * kv_len
351-
352- # Whether using custom mask, depends on causal args
353- # flatten mask: [TTTTTTFFFFTFTTFFFTTFFTTTTTFFFFTTTTTTFT,FFFTFFTFTTTTTFTFFFFFTTFTTTTFTFTTFTTT]
354- # ^kv_start ^mask_start
355- # <--load_sz-->
356-
357- def loop_body (i , _ ):
358- start = cur_bq_mask_start + i * kv_len + mask_start
359- start = jnp .minimum (custom_mask_ref .shape [0 ], start )
360- _async_copy (
361- custom_mask_ref .at [pl .ds (start , load_kv_sz )],
362- kvmask_fused_vmem_ref .at [i , pl .ds (0 , load_kv_sz )],
363- sem ,
364- wait ,
365- )
366-
367- lax .fori_loop (
368- 0 ,
369- load_q_sz ,
370- loop_body ,
371- None ,
372- unroll = False ,
373- )
374-
375339 def _fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx , * , wait = False ):
376340 sem = sems .at [0 , bkv_sem_idx ]
377341 kv_fused_vmem_ref = bkv_fused_x2_ref .at [bkv_sem_idx ]
@@ -505,12 +469,6 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
505469 wait ,
506470 )
507471
508- def start_fetch_mask (seq_idx , bq_idx , bkvmask_idx , bkvmask_sem_idx ):
509- return _fetch_mask (seq_idx , bq_idx , bkvmask_idx , bkvmask_sem_idx )
510-
511- def wait_fetch_mask (seq_idx , bq_idx , bkvmask_idx , bkvmask_sem_idx ):
512- return _fetch_mask (seq_idx , bq_idx , bkvmask_idx , bkvmask_sem_idx , wait = True )
513-
def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
  """Kick off (without waiting) the DMA fetch of kv block `bkv_idx` of
  sequence `seq_idx` into double-buffer slot `bkv_sem_idx`."""
  return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=False)
516474
@@ -691,12 +649,6 @@ def prefetch_next_bkv():
691649 sem_ids_ref [1 ] = next_bkv_sem_idx
692650 start_fetch_bkv (next_seq_idx , next_bkv_idx , next_bkv_sem_idx )
693651
694- @pl .when (causal == 0 )
695- def _ ():
696- start_fetch_mask (
697- next_seq_idx , bq_idx , next_bkv_idx , next_bkv_sem_idx
698- )
699-
700652 # Wait for cur bq if not ready yet
701653 @pl .when (bkv_idx == 0 )
702654 def wait_cur_bq ():
@@ -705,11 +657,6 @@ def wait_cur_bq():
705657 # Wait for cur bkv
706658 offset , update_sz = wait_fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx )
707659
708- # Wait for kv mask if not use causal mask
709- @pl .when (causal == 0 )
710- def _ ():
711- wait_fetch_mask (seq_idx , bq_idx , bkv_idx , bkv_sem_idx )
712-
713660 # Start updating bkv to kv cache if applicable.
714661 # Only needed in first bq loop.
715662 @pl .when (jnp .logical_and (update_sz > 0 , bq_idx == 0 ))
@@ -746,10 +693,23 @@ def batch_prepare_queries():
746693 return jnp .stack (q_heads , axis = 0 )
747694
748695 def load_mask ():
749- mask = bkvmask_ref [bkv_sem_idx , :actual_bq_sz ]
750- # assert False, f'{mask.shape=} {jnp.zeros((actual_num_kv_heads, actual_bq_sz*num_q_heads_per_kv_head, mask.shape[-1])).shape=}'
696+ bq_mask_start = bq_idx * bq_sz
697+ bq_mask_end = (bq_idx + 1 ) * bq_sz
698+ bq_mask_offset = lax .select (
699+ bq_mask_end - q_len < 0 , bq_sz , bq_mask_end - q_len
700+ )
701+
702+ bkv_mask_start = bkv_idx * bkv_sz
703+ bkv_mask_end = (bkv_idx + 1 ) * bkv_sz
704+ bkv_mask_offset = lax .select (
705+ bkv_mask_end - kv_len < 0 , bkv_sz , bkv_mask_end - kv_len
706+ )
707+
708+ cur_bq_bkv_mask = cur_seq_mask [
709+ bq_mask_start :bq_mask_offset , bkv_mask_start :bkv_mask_offset
710+ ]
751711 num_q_heads_per_kv_head_mask = jnp .concat (
752- [mask ] * num_q_heads_per_kv_head
712+ [cur_bq_bkv_mask ] * num_q_heads_per_kv_head
753713 )
754714 num_kv_heads_mask = jnp .concat (
755715 [
@@ -764,7 +724,6 @@ def load_mask():
764724 # Load batched data
765725 k_batch , v_batch = batch_load_all_heads_kv ()
766726 q_batch = batch_prepare_queries ()
767- custom_mask = load_mask ()
768727
769728 def flash_attention (q_batch , k_batch , v_batch ):
770729 q_batch_f32 = q_batch .astype (jnp .float32 )
@@ -799,12 +758,8 @@ def flash_attention(q_batch, k_batch, v_batch):
799758 k_span = bkv_idx * bkv_sz + lax .broadcasted_iota (
800759 jnp .int32 , s .shape , 2
801760 )
802- # convert custom_mask from int8 to bool
803- mask = lax .select (
804- causal == 0 ,
805- custom_mask .astype (jnp .bool ),
806- q_span < k_span ,
807- )
761+ mask = lax .cond (causal == 1 , lambda : q_span < k_span , load_mask )
762+
808763 if sliding_window is not None :
809764 mask = jnp .logical_or (mask , q_span - sliding_window >= k_span )
810765
@@ -1341,7 +1296,7 @@ def ragged_paged_attention(
13411296 )
13421297 if custom_mask is None :
13431298 # fix bug: XLA layout ({0}) does not match Mosaic layout ({0:T(128)}) for an operand of shape s32[0]
1344- custom_mask = jnp .empty ((1 , 128 ), dtype = jnp .int8 )
1299+ custom_mask = jnp .empty ((1 , 128 ), dtype = jnp .bool )
13451300 else :
13461301 assert (
13471302 custom_mask .dtype != jnp .bool
@@ -1353,7 +1308,6 @@ def ragged_paged_attention(
13531308 pl .BlockSpec (memory_space = pltpu .ANY ), # q
13541309 pl .BlockSpec (memory_space = pltpu .ANY ), # kv_fused
13551310 pl .BlockSpec (memory_space = pltpu .ANY ), # kv_cache_fused
1356- pl .BlockSpec (memory_space = pltpu .ANY ), # custom_mask
13571311 ]
13581312
13591313 out_specs = [
@@ -1366,11 +1320,6 @@ def ragged_paged_attention(
13661320 kv_cache_fused_processed .dtype ,
13671321 )
13681322
1369- bkvmask_double_buf = pltpu .VMEM (
1370- (2 , bq_sz , bkv_sz ),
1371- jnp .bool ,
1372- )
1373-
13741323 bq_double_buf = pltpu .VMEM (
13751324 (2 , actual_num_kv_heads , bq_sz , * q .shape [2 :]),
13761325 q .dtype ,
@@ -1390,7 +1339,6 @@ def ragged_paged_attention(
13901339 )
13911340
13921341 scratch_shapes = [
1393- bkvmask_double_buf , # Double buffering for fused kv mask block with head interleaving.
13941342 bkv_fused_double_buf , # Double buffering for fused kv block with head interleaving.
13951343 bq_double_buf , # Double buffering for q block.
13961344 bo_double_buf , # Double buffering for output block.
@@ -1415,6 +1363,7 @@ def ragged_paged_attention(
14151363 jnp .full ((4 ,), - 1 , jnp .int32 ),
14161364 # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
14171365 jnp .full ((6 ,), - 1 , jnp .int32 ),
1366+ custom_mask ,
14181367 )
14191368 scope_name = f"RPA-bq_{ bq_sz } -bkvp_{ bkv_p } -p_{ page_size } "
14201369 kernel = jax .named_scope (scope_name )(
@@ -1454,8 +1403,8 @@ def ragged_paged_attention(
14541403 ),
14551404 ],
14561405 input_output_aliases = {
1457- 9 : 0 , # q input -> q output
1458- 11 : 1 , # kv_cache_fused input -> updated kv_cache_fused output
1406+ 10 : 0 , # q input -> q output
1407+ 12 : 1 , # kv_cache_fused input -> updated kv_cache_fused output
14591408 },
14601409 name = scope_name ,
14611410 )
@@ -1466,7 +1415,6 @@ def ragged_paged_attention(
14661415 q ,
14671416 kv ,
14681417 kv_cache_fused_processed ,
1469- custom_mask ,
14701418 )
14711419 return (
14721420 prepare_outputs (output , actual_num_q_heads_per_kv_head , actual_head_dim ),