@@ -258,7 +258,7 @@ def _ragged_paged_attention_kernel(
    q_hbm_ref,  # [actual_num_kv_heads, padded_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
    kv_hbm_ref,  # [padded_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim] - Fused KV with interleaved [K1,V1,K2,V2,...]
    kv_cache_fused_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
-   custom_mask_ref,  # (flatten_total_kv_len,), int8; DMA does not support bool dtype
+   custom_mask_ref,  # (flatten_total_kv_len,), int8; DMA does not support bool dtype
    # Output
    o_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
    updated_kv_cache_fused_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
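
The hunk above switches `custom_mask_ref` to int8 because the mask is moved by DMA, which cannot transfer bool data. A minimal sketch of how a caller might build such a buffer is below; the helper name, the flattening layout, and the 0/1 polarity are assumptions for illustration, not the kernel's documented contract:

```python
import jax.numpy as jnp

def flatten_masks_to_int8(bool_masks):
    """Flatten per-request boolean masks into one 1D int8 vector (hypothetical helper).

    bool_masks: list of boolean arrays, one per request (shape [q_len_i, kv_len_i]).
    The result is a single (flatten_total_kv_len,)-style int8 array of 0/1 values,
    i.e. the kind of buffer custom_mask_ref above expects to receive via DMA.
    """
    flat = [m.reshape(-1) for m in bool_masks]
    return jnp.concatenate(flat).astype(jnp.int8)

# Two requests with different shapes.
m0 = jnp.tril(jnp.ones((4, 4), dtype=jnp.bool_))  # causal-style 4x4 mask
m1 = jnp.ones((2, 6), dtype=jnp.bool_)            # dense 2x6 mask
custom_mask = flatten_masks_to_int8([m0, m1])     # shape (28,), dtype int8
```
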
@@ -799,9 +799,10 @@ def flash_attention(q_batch, k_batch, v_batch):
    k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(
        jnp.int32, s.shape, 2
    )
+   # convert custom_mask from int8 to bool
    mask = lax.select(
        causal == 0,
-       custom_mask,
+       custom_mask.astype(jnp.bool),
        q_span < k_span,
    )
    if sliding_window is not None:
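
Inside the kernel the int8 buffer is cast back to bool only at the point of use, so `lax.select` can pick between the user-supplied mask and the iota-based causal mask without branching. A self-contained sketch of that pattern under assumed shapes (the real kernel applies it to blocked, s-shaped tiles):

```python
import jax.numpy as jnp
from jax import lax

def select_mask(custom_mask_i8, causal, q_span, k_span):
    """Pick the user mask (causal == 0) or an iota-derived causal mask (hypothetical helper)."""
    return lax.select(
        causal == 0,                       # scalar predicate
        custom_mask_i8.astype(jnp.bool_),  # int8 -> bool, mirroring the diff above
        q_span < k_span,
    )

q_span = lax.broadcasted_iota(jnp.int32, (4, 4), 0)
k_span = lax.broadcasted_iota(jnp.int32, (4, 4), 1)
custom = jnp.eye(4, dtype=jnp.int8)  # arbitrary 0/1 mask
mask = select_mask(custom, jnp.array(0, dtype=jnp.int32), q_span, k_span)  # bool (4, 4)
```
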
@@ -1340,7 +1341,11 @@ def ragged_paged_attention(
    )
    if custom_mask is None:
        # fix bug: XLA layout ({0}) does not match Mosaic layout ({0:T(128)}) for an operand of shape s32[0]
-       custom_mask = jnp.empty((1, 128))
+       custom_mask = jnp.empty((1, 128), dtype=jnp.int8)
+   else:
+       assert (
+           custom_mask.dtype != jnp.bool
+       ), "custom_mask bool dtype is not supported; use an integer dtype such as int8 (0: False, 1: True)"

    grid = (distribution[2],)

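At the Python wrapper level the new branch keeps two invariants: when no custom mask is given, a small int8 placeholder is created so the XLA and Mosaic layouts agree, and when one is given, a bool dtype is rejected up front because the buffer will be DMA'd. A standalone sketch of that validation (the helper name and message wording are illustrative, not the wrapper's exact code):

```python
import jax.numpy as jnp

def prepare_custom_mask(custom_mask):
    """Return an int8 mask suitable for the kernel (hypothetical wrapper helper)."""
    if custom_mask is None:
        # Non-empty int8 placeholder; avoids the XLA/Mosaic layout mismatch on s32[0].
        return jnp.empty((1, 128), dtype=jnp.int8)
    assert custom_mask.dtype != jnp.bool_, (
        "custom_mask bool dtype is not supported; cast to an integer dtype (0/1) first"
    )
    return custom_mask

# Usage: a flattened 0/1 int8 mask passes; a bool mask would trip the assert.
ok = prepare_custom_mask(jnp.ones((28,), dtype=jnp.int8))
placeholder = prepare_custom_mask(None)  # shape (1, 128), dtype int8
```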