from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
+from numpy import int32

from sgl_jax.srt.layers.attention.flash_attn_kernel.tuned_block_sizes import (
    get_tuned_block_sizes,
@@ -101,20 +102,31 @@ def ref_ragged_paged_attention(
    cu_q_lens: jax.Array,  # i32[padded_batch_size + 1]
    num_seqs: jax.Array,  # i32[1]
    *,
+    custom_mask: jax.Array | None = None,  # [pattern_total_kv_len]
+    causal: bool = True,
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
    k_scale: float | None = None,
    v_scale: float | None = None,
):
+    if causal:
+        if custom_mask is not None:
+            raise ValueError("causal=True, so custom_mask must be None")
+    else:
+        if custom_mask is None or custom_mask.size < jnp.cumsum(kv_lens)[-1]:
+            raise ValueError(
+                "custom_mask is required and its length must be at least the total kv length"
+            )
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    _, _, num_kv_heads, head_dim = k_pages.shape
    num_q_heads = queries.shape[1]
    assert num_q_heads % num_kv_heads == 0
    num_query_per_kv = num_q_heads // num_kv_heads
    outputs = []
+    cu_kv_lens = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(kv_lens)])
    for i in range(num_seqs[0]):
        q_start = cu_q_lens[i]
        q_end = cu_q_lens[i + 1]
@@ -134,9 +146,15 @@ def ref_ragged_paged_attention(
        v = jnp.repeat(v, num_query_per_kv, axis=1)
        attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32)
        attn *= sm_scale
        q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(jnp.int32, attn.shape, 1)
        kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
-        mask = q_span < kv_span
+        if causal:
+            mask = q_span < kv_span
+        else:
+            mask_start = cu_kv_lens[i]
+            mask = custom_mask[mask_start : mask_start + kv_len]
        if sliding_window is not None:
            mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
        if soft_cap is not None:
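
For the custom-mask path of the reference above, a minimal sketch of building the flattened mask and slicing it per sequence with the cu_kv_lens offsets. The helper name is hypothetical, and the polarity is assumed to match the causal branch (True marks a masked-out position); the 1-D slice broadcasts over the query axis of attn.

import jax.numpy as jnp

def build_flat_kv_mask(per_seq_masks):
    # One bool array of shape [kv_len_i] per sequence, concatenated in order.
    return jnp.concatenate([m.astype(jnp.bool_) for m in per_seq_masks])

kv_lens = jnp.array([3, 5], dtype=jnp.int32)
flat = build_flat_kv_mask([
    jnp.array([False, False, True]),
    jnp.array([False, True, True, False, True]),
])
cu_kv_lens = jnp.concatenate([jnp.array([0], jnp.int32), jnp.cumsum(kv_lens)])
assert flat.shape[0] == int(cu_kv_lens[-1])
# Same slicing as the reference: custom_mask[mask_start : mask_start + kv_len]
seq1_mask = flat[int(cu_kv_lens[1]) : int(cu_kv_lens[1]) + 5]
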
@@ -239,18 +257,21 @@ def _ragged_paged_attention_kernel(
    q_hbm_ref,  # [actual_num_kv_heads, padded_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
    kv_hbm_ref,  # [padded_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim] - Fused KV with interleaved [K1,V1,K2,V2,...]
    kv_cache_fused_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
+    custom_mask_ref,  # (flatten_total_kv_len,)
    # Output
    o_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
    updated_kv_cache_fused_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
    # Scratch
+    bkvmask_ref,  # [2, bq_sz, bkv_sz]
    bkv_fused_x2_ref,  # [2, bkv_sz, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
    bq_x2_ref,  # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
    bo_x2_ref,  # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
-    sems,  # [4, 2]
+    sems,  # [5, 2]
    l_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128]
    m_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128]
    acc_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, head_dim]
    *,
+    causal: int,  # static flag: 0 = use custom_mask, 1 = use causal mask
    sm_scale: float,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
@@ -297,6 +318,13 @@ def _ragged_paged_attention_kernel(
    prefill_end = distribution_ref[1]
    mixed_end = distribution_ref[2]

+    kv_lens = cu_kv_lens_ref[1:] - cu_kv_lens_ref[:-1]
+    q_lens = cu_q_lens_ref[1:] - cu_q_lens_ref[:-1]
+    seq_mask_lens = kv_lens * q_lens
+    cu_seq_mask_lens = jnp.concatenate(
+        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_mask_lens)]
+    )
+
    q_start = cu_q_lens_ref[seq_idx]
    q_end = cu_q_lens_ref[seq_idx + 1]
    q_len = q_end - q_start
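
A small numeric illustration (made-up lengths) of the offsets computed in the hunk above: cu_seq_mask_lens[i] is where sequence i's flattened [q_len, kv_len] mask block starts inside custom_mask.

import jax.numpy as jnp

cu_q_lens = jnp.array([0, 2, 6], dtype=jnp.int32)    # q_lens  = [2, 4]
cu_kv_lens = jnp.array([0, 3, 10], dtype=jnp.int32)  # kv_lens = [3, 7]
kv_lens = cu_kv_lens[1:] - cu_kv_lens[:-1]
q_lens = cu_q_lens[1:] - cu_q_lens[:-1]
seq_mask_lens = kv_lens * q_lens                     # [6, 28]
cu_seq_mask_lens = jnp.concatenate(
    [jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_mask_lens)]
)                                                    # [0, 6, 34]
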
@@ -309,6 +337,46 @@ def _async_copy(src, dst, sem, wait):
        else:
            cp.start()

+    def _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx, *, wait=False):
+        sem = sems.at[4, bkvmask_sem_idx]
+        kvmask_fused_vmem_ref = bkvmask_ref.at[bkvmask_sem_idx]
+
+        kv_len = kv_lens_ref[seq_idx]
+        mask_len = kv_len
+        mask_start = bkvmask_idx * bkv_sz
+        mask_left = mask_len - mask_start
+        load_kv_sz = jnp.minimum(bkv_sz, mask_left)
+
+        q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
+        q_end = cu_q_lens_ref[seq_idx + 1]
+        load_q_sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+        cur_seq_mask_start = cu_seq_mask_lens[seq_idx]
+        cur_bq_mask_start = cur_seq_mask_start + bq_idx * bq_sz * kv_len
+
+        # The custom mask is only fetched when causal == 0.
+        # Flattened mask layout (one kv_len-long boolean row per query token):
+        # [TTTTTTFFFFTFTTFFFTTFFTTTTTFFFFTTTTTTFT,FFFTFFTFTTTTTFTFFFFFTTFTTTTFTFTTFTTT]
+        #  ^kv_start          ^mask_start
+        #                     <--load_sz-->
+
+        def loop_body(i, _):
+            start = cur_bq_mask_start + i * kv_len + mask_start
+            start = jnp.minimum(custom_mask_ref.shape[0], start)
+            _async_copy(
+                custom_mask_ref.at[pl.ds(start, load_kv_sz)],
+                kvmask_fused_vmem_ref.at[i, pl.ds(0, load_kv_sz)],
+                sem,
+                wait,
+            )
+
+        lax.fori_loop(
+            0,
+            load_q_sz,
+            loop_body,
+            None,
+            unroll=False,
+        )
+
    def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
        sem = sems.at[0, bkv_sem_idx]
        kv_fused_vmem_ref = bkv_fused_x2_ref.at[bkv_sem_idx]
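
A host-side, non-Pallas sketch of the addressing that _fetch_mask implements above; the helper name is hypothetical and it mirrors loop_body: each of the load_q valid rows copies load_kv booleans starting at cur_bq_mask_start + i * kv_len + mask_start.

import numpy as np

def gather_mask_tile(flat_mask, cu_seq_mask_lens, q_len, kv_len,
                     bq_idx, bkv_idx, bq_sz, bkv_sz, seq_idx):
    """Gather the [bq_sz, bkv_sz] mask tile for (seq_idx, bq_idx, bkv_idx)."""
    seq_start = cu_seq_mask_lens[seq_idx]           # where this sequence's mask begins
    bq_start = seq_start + bq_idx * bq_sz * kv_len  # first row of this q block
    kv_start = bkv_idx * bkv_sz                     # column offset inside a row
    load_q = min(bq_sz, q_len - bq_idx * bq_sz)     # valid rows in this tile
    load_kv = min(bkv_sz, kv_len - kv_start)        # valid columns in this tile
    tile = np.zeros((bq_sz, bkv_sz), dtype=bool)
    for i in range(load_q):
        row_start = bq_start + i * kv_len + kv_start
        tile[i, :load_kv] = flat_mask[row_start:row_start + load_kv]
    return tile
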
@@ -442,6 +510,12 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
            wait,
        )

+    def start_fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx):
+        return _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx)
+
+    def wait_fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx):
+        return _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx, wait=True)
+
    def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
        return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

@@ -489,9 +563,10 @@ def load_bq(bq_sem_idx, kv_head_idx, *, actual_bq_sz=bq_sz):
            .at[bq_sem_idx, kv_head_idx]
            .reshape(bq_sz * num_q_heads_per_kv_head_per_packing, head_dim)
        )
-        return pltpu.bitcast(
+        res = pltpu.bitcast(
            q_ref[: actual_bq_sz * num_q_heads_per_kv_head_per_packing], q_dtype
        )
+        return res

    def strided_load(ref, start, step, *, dtype=None):
        assert get_dtype_packing(ref.dtype) == 1
@@ -621,6 +696,12 @@ def prefetch_next_bkv():
            sem_ids_ref[1] = next_bkv_sem_idx
            start_fetch_bkv(next_seq_idx, next_bkv_idx, next_bkv_sem_idx)

+            @pl.when(causal == 0)
+            def _():
+                start_fetch_mask(
+                    next_seq_idx, bq_idx, next_bkv_idx, next_bkv_sem_idx
+                )
+
        # Wait for cur bq if not ready yet
        @pl.when(bkv_idx == 0)
        def wait_cur_bq():
@@ -629,6 +710,11 @@ def wait_cur_bq():
        # Wait for cur bkv
        offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

+        # Wait for the kv mask when not using the causal mask
+        @pl.when(causal == 0)
+        def _():
+            wait_fetch_mask(seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
+
        # Start updating bkv to kv cache if applicable.
        # Only needed in first bq loop.
        @pl.when(jnp.logical_and(update_sz > 0, bq_idx == 0))
@@ -664,9 +750,26 @@ def batch_prepare_queries():

        return jnp.stack(q_heads, axis=0)

+    def load_mask():
+        # Slice the [bq, bkv] boolean tile for the current buffer ...
+        mask = bkvmask_ref[bkv_sem_idx, :actual_bq_sz]
+        # ... repeat it along the row axis for the q heads of each kv head ...
+        num_q_heads_per_kv_head_mask = jnp.concat(
+            [mask] * num_q_heads_per_kv_head
+        )
+        # ... and stack one copy per kv head to match the logits block shape.
+        num_kv_heads_mask = jnp.concat(
+            [
+                num_q_heads_per_kv_head_mask.reshape(
+                    1, *num_q_heads_per_kv_head_mask.shape
+                )
+            ]
+            * actual_num_kv_heads
+        )
+        return num_kv_heads_mask
+
    # Load batched data
    k_batch, v_batch = batch_load_all_heads_kv()
    q_batch = batch_prepare_queries()
+    custom_mask = load_mask()

    def flash_attention(q_batch, k_batch, v_batch):
        q_batch_f32 = q_batch.astype(jnp.float32)
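
What load_mask above produces, shown standalone with made-up sizes: the [bq, bkv] tile is repeated along the row axis for the q heads of each kv head, then stacked once per kv head to match the [num_kv_heads, bq * num_q_heads_per_kv_head, bkv] logits block.

import jax.numpy as jnp

bq_sz, bkv_sz = 4, 8
num_q_heads_per_kv_head, actual_num_kv_heads = 2, 3
tile = jnp.zeros((bq_sz, bkv_sz), dtype=jnp.bool_)  # stand-in for bkvmask_ref[bkv_sem_idx]

rows = jnp.concatenate([tile] * num_q_heads_per_kv_head, axis=0)    # [bq_sz * 2, bkv_sz]
full = jnp.concatenate([rows[None]] * actual_num_kv_heads, axis=0)  # [3, bq_sz * 2, bkv_sz]
assert full.shape == (actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, bkv_sz)
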
@@ -701,8 +804,11 @@ def flash_attention(q_batch, k_batch, v_batch):
            k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(
                jnp.int32, s.shape, 2
            )
-            mask = q_span < k_span
-
+            mask = lax.select(
+                causal == 0,
+                custom_mask,
+                q_span < k_span,
+            )
            if sliding_window is not None:
                mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

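
For reference, a standalone illustration of the causal branch selected above; the shapes and the q_span definition are assumed to mirror the kernel's. broadcasted_iota builds row/column indices over the [heads, bq, bkv] logits block, and True marks positions that get masked out.

import jax
import jax.numpy as jnp

heads, bq_sz, bkv_sz = 1, 4, 4
kv_len, q_len, bq_idx, bkv_idx = 4, 4, 0, 0
shape = (heads, bq_sz, bkv_sz)
q_span = (kv_len - q_len) + bq_idx * bq_sz + jax.lax.broadcasted_iota(jnp.int32, shape, 1)
k_span = bkv_idx * bkv_sz + jax.lax.broadcasted_iota(jnp.int32, shape, 2)
causal_drop = q_span < k_span  # True strictly above the diagonal
print(causal_drop[0].astype(jnp.int32))
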
@@ -1079,6 +1185,7 @@ def static_validate_inputs_fused(
@functools.partial(
    jax.jit,
    static_argnames=(
+        "causal",
        "sm_scale",
        "sliding_window",
        "soft_cap",
@@ -1103,7 +1210,9 @@ def ragged_paged_attention(
    cu_q_lens: jax.Array,  # i32[padded_batch_size + 1]
    cu_kv_lens: jax.Array,  # i32[padded_batch_size + 1]
    distribution: jax.Array,  # i32[3]
+    custom_mask: jax.Array,  # [pattern_total_kv_len] when causal == 0, else [0]
    *,
+    causal: int = 1,  # 1: use the causal mask, 0: use custom_mask
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
@@ -1132,8 +1241,10 @@ def ragged_paged_attention(
    distribution: (i, j, k) represents that sequences[0:i] are decode-only,
      sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
      k is also the total number of sequences.
+    custom_mask: flattened boolean mask applied to the attention logits when
+      causal is 0; ignored when causal is 1.
    actual_head_dim: the actual head size of the attention. Here we assume k and
      v have the same actual head size.
+    causal: if 1, apply the standard causal mask; if 0, apply custom_mask instead.
    sm_scale: the softmax scale which will be applied to the Q@K^T.
    sliding_window: the sliding window size for the attention.
    soft_cap: the logit soft cap for the attention.
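
A hedged sketch of assembling the custom_mask argument; the layout follows the kernel's _fetch_mask addressing (one kv_len-long boolean row per query token, blocks concatenated in sequence order), and the helper name is hypothetical rather than part of this API. Note the reference function above indexes the mask with per-sequence kv offsets only.

import jax.numpy as jnp

def flatten_custom_mask(per_seq_masks):
    # One bool array of shape [q_len_i, kv_len_i] per sequence,
    # flattened row-major and concatenated in sequence order.
    return jnp.concatenate([m.reshape(-1) for m in per_seq_masks])

custom_mask = flatten_custom_mask([
    jnp.zeros((2, 3), dtype=jnp.bool_),            # sequence 0: nothing masked
    ~jnp.tril(jnp.ones((4, 4), dtype=jnp.bool_)),  # sequence 1: causal-style block
])
empty_mask = jnp.zeros((0,), dtype=jnp.bool_)      # placeholder when causal == 1
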
@@ -1232,6 +1343,7 @@ def ragged_paged_attention(
        pl.BlockSpec(memory_space=pltpu.ANY),  # q
        pl.BlockSpec(memory_space=pltpu.ANY),  # kv_fused
        pl.BlockSpec(memory_space=pltpu.ANY),  # kv_cache_fused
+        pl.BlockSpec(memory_space=pltpu.ANY),  # custom_mask
    ]

    out_specs = [
@@ -1244,6 +1356,11 @@ def ragged_paged_attention(
        kv_cache_fused_processed.dtype,
    )

+    bkvmask_double_buf = pltpu.VMEM(
+        (2, bq_sz, bkv_sz),
+        jnp.bool,
+    )
+
    bq_double_buf = pltpu.VMEM(
        (2, actual_num_kv_heads, bq_sz, *q.shape[2:]),
        q.dtype,
@@ -1263,11 +1380,12 @@ def ragged_paged_attention(
    )

    scratch_shapes = [
+        bkvmask_double_buf,  # Double buffering for the kv mask block.
        bkv_fused_double_buf,  # Double buffering for fused kv block with head interleaving.
        bq_double_buf,  # Double buffering for q block.
        bo_double_buf,  # Double buffering for output block.
        # Semaphores for double buffering of bkv, bq, bo and bkv_update.
-        pltpu.SemaphoreType.DMA((4, 2)),
+        pltpu.SemaphoreType.DMA((5, 2)),
        # Intermediate buffers per kv head for flash attention.
        l_scratch,
        m_scratch,
@@ -1287,12 +1405,12 @@ def ragged_paged_attention(
        # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
        jnp.full((6,), -1, jnp.int32),
    )
-
    scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
    kernel = jax.named_scope(scope_name)(
        pl.pallas_call(
            functools.partial(
                _ragged_paged_attention_kernel,
+                causal=causal,
                sm_scale=sm_scale,
                sliding_window=sliding_window,
                soft_cap=soft_cap,
@@ -1333,7 +1451,11 @@ def ragged_paged_attention(
    )

    output, updated_kv_cache_fused = kernel(
-        *scalar_prefetches, q, kv, kv_cache_fused_processed
+        *scalar_prefetches,
+        q,
+        kv,
+        kv_cache_fused_processed,
+        custom_mask,
    )
    return (
        prepare_outputs(output, actual_num_q_heads_per_kv_head, actual_head_dim),