
Commit 115b2da

use scalar prefetch to get custom mask
1 parent 57d7f3f commit 115b2da

2 files changed (+30, -86 lines)

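Before this change the custom mask was DMA-copied from HBM into a VMEM double buffer as int8 (the DMA path did not support bool); the kernel now receives the flattened bool mask as a scalar-prefetch operand and slices it directly. As a rough illustration of the mechanism only, not of this repository's kernel, the sketch below assumes jax.experimental.pallas on a TPU backend and uses invented names and shapes.

# Minimal sketch of TPU Pallas scalar prefetch (not this repo's kernel): the
# first operand is prefetched into SMEM and is readable directly inside the
# kernel, with no async copy into a VMEM scratch buffer.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _scale_kernel(scale_ref, x_ref, o_ref):
    # scale_ref is the scalar-prefetch operand (SMEM); index it like an array.
    o_ref[...] = x_ref[...] * scale_ref[0]


def scale(x, s):
    # s: int32[1] scalar-prefetch operand, x: float32[8, 128] regular input.
    return pl.pallas_call(
        _scale_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=1,  # first argument is prefetched to SMEM
            grid=(1,),
            # index maps receive the scalar-prefetch ref after the grid indices
            in_specs=[
                pl.BlockSpec(block_shape=x.shape, index_map=lambda i, s_ref: (0, 0))
            ],
            out_specs=pl.BlockSpec(block_shape=x.shape, index_map=lambda i, s_ref: (0, 0)),
        ),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(s, x)


x = jnp.ones((8, 128), jnp.float32)
print(scale(x, jnp.array([3], jnp.int32))[0, 0])  # 3.0

Prefetched operands are also passed to the BlockSpec index maps, which is the standard use of scalar prefetch for data-dependent block indexing.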

python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py

Lines changed: 29 additions & 81 deletions
@@ -254,16 +254,15 @@ def _ragged_paged_attention_kernel(
     sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
     bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
     bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
+    custom_mask_ref, # (flatten_total_kv_len,),
     # Input
     q_hbm_ref, # [actual_num_kv_heads, padded_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref, # [padded_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim] - Fused KV with interleaved [K1,V1,K2,V2,...]
     kv_cache_fused_hbm_ref, # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
-    custom_mask_ref, # (flatten_total_kv_len,), int8, dma not support bool type
     # Output
     o_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     updated_kv_cache_fused_hbm_ref, # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
     # Scratch
-    bkvmask_ref, # [2, bq_sz, bkv_sz]
     bkv_fused_x2_ref, # [2, bkv_sz, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
     bq_x2_ref, # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     bo_x2_ref, # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
@@ -324,54 +323,19 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]

+    cur_seq_mask_start = cu_seq_mask_lens[seq_idx]
+    cur_seq_mask_len = q_len * kv_len
+    cur_seq_mask = custom_mask_ref[
+        cur_seq_mask_start : cur_seq_mask_start + cur_seq_mask_len
+    ].reshape(q_len, kv_len)
+
     def _async_copy(src, dst, sem, wait):
         cp = pltpu.make_async_copy(src, dst, sem)
         if wait:
             cp.wait()
         else:
             cp.start()

-    def _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx, *, wait=False):
-        sem = sems.at[4, bkvmask_sem_idx]
-        assert sem.dtype == sems.dtype, f"######## {sem.dtype=} {sems.dtype=}"
-        kvmask_fused_vmem_ref = bkvmask_ref.at[bkvmask_sem_idx]
-
-        kv_len = kv_lens_ref[seq_idx]
-        mask_len = kv_len
-        mask_start = bkvmask_idx * bkv_sz
-        mask_left = mask_len - mask_start
-        load_kv_sz = jnp.minimum(bkv_sz, mask_left)
-
-        q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
-        q_end = cu_q_lens_ref[seq_idx + 1]
-        load_q_sz = jnp.minimum(bq_sz, q_end - q_len_start)
-
-        cur_seq_mask_start = cu_seq_mask_lens[seq_idx]
-        cur_bq_mask_start = cur_seq_mask_start + bq_idx * bq_sz * kv_len
-
-        # Whether using custom mask, depends on causal args
-        # flatten mask: [TTTTTTFFFFTFTTFFFTTFFTTTTTFFFFTTTTTTFT,FFFTFFTFTTTTTFTFFFFFTTFTTTTFTFTTFTTT]
-        #                ^kv_start      ^mask_start
-        #                               <--load_sz-->
-
-        def loop_body(i, _):
-            start = cur_bq_mask_start + i * kv_len + mask_start
-            start = jnp.minimum(custom_mask_ref.shape[0], start)
-            _async_copy(
-                custom_mask_ref.at[pl.ds(start, load_kv_sz)],
-                kvmask_fused_vmem_ref.at[i, pl.ds(0, load_kv_sz)],
-                sem,
-                wait,
-            )
-
-        lax.fori_loop(
-            0,
-            load_q_sz,
-            loop_body,
-            None,
-            unroll=False,
-        )
-
     def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
         sem = sems.at[0, bkv_sem_idx]
         kv_fused_vmem_ref = bkv_fused_x2_ref.at[bkv_sem_idx]
@@ -505,12 +469,6 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
             wait,
         )

-    def start_fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx):
-        return _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx)
-
-    def wait_fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx):
-        return _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx, wait=True)
-
     def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
         return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

@@ -691,12 +649,6 @@ def prefetch_next_bkv():
         sem_ids_ref[1] = next_bkv_sem_idx
         start_fetch_bkv(next_seq_idx, next_bkv_idx, next_bkv_sem_idx)

-        @pl.when(causal == 0)
-        def _():
-            start_fetch_mask(
-                next_seq_idx, bq_idx, next_bkv_idx, next_bkv_sem_idx
-            )
-
     # Wait for cur bq if not ready yet
     @pl.when(bkv_idx == 0)
     def wait_cur_bq():
@@ -705,11 +657,6 @@ def wait_cur_bq():
     # Wait for cur bkv
     offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

-    # Wait for kv mask if not use causal mask
-    @pl.when(causal == 0)
-    def _():
-        wait_fetch_mask(seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
-
     # Start updating bkv to kv cache if applicable.
     # Only needed in first bq loop.
     @pl.when(jnp.logical_and(update_sz > 0, bq_idx == 0))
@@ -746,10 +693,23 @@ def batch_prepare_queries():
         return jnp.stack(q_heads, axis=0)

     def load_mask():
-        mask = bkvmask_ref[bkv_sem_idx, :actual_bq_sz]
-        # assert False, f'{mask.shape=} {jnp.zeros((actual_num_kv_heads, actual_bq_sz*num_q_heads_per_kv_head, mask.shape[-1])).shape=}'
+        bq_mask_start = bq_idx * bq_sz
+        bq_mask_end = (bq_idx + 1) * bq_sz
+        bq_mask_offset = lax.select(
+            bq_mask_end - q_len < 0, bq_sz, bq_mask_end - q_len
+        )
+
+        bkv_mask_start = bkv_idx * bkv_sz
+        bkv_mask_end = (bkv_idx + 1) * bkv_sz
+        bkv_mask_offset = lax.select(
+            bkv_mask_end - kv_len < 0, bkv_sz, bkv_mask_end - kv_len
+        )
+
+        cur_bq_bkv_mask = cur_seq_mask[
+            bq_mask_start:bq_mask_offset, bkv_mask_start:bkv_mask_offset
+        ]
         num_q_heads_per_kv_head_mask = jnp.concat(
-            [mask] * num_q_heads_per_kv_head
+            [cur_bq_bkv_mask] * num_q_heads_per_kv_head
         )
         num_kv_heads_mask = jnp.concat(
             [
@@ -764,7 +724,6 @@ def load_mask():
     # Load batched data
     k_batch, v_batch = batch_load_all_heads_kv()
     q_batch = batch_prepare_queries()
-    custom_mask = load_mask()

     def flash_attention(q_batch, k_batch, v_batch):
         q_batch_f32 = q_batch.astype(jnp.float32)
@@ -799,12 +758,8 @@ def flash_attention(q_batch, k_batch, v_batch):
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(
             jnp.int32, s.shape, 2
         )
-        # convert custom_mask from int8 to bool
-        mask = lax.select(
-            causal == 0,
-            custom_mask.astype(jnp.bool),
-            q_span < k_span,
-        )
+        mask = lax.cond(causal == 1, lambda: q_span < k_span, load_mask)
+
         if sliding_window is not None:
             mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

@@ -1341,7 +1296,7 @@ def ragged_paged_attention(
     )
     if custom_mask is None:
         # fix bug: XLA layout ({0}) does not match Mosaic layout ({0:T(128)}) for an operand of shape s32[0]
-        custom_mask = jnp.empty((1, 128), dtype=jnp.int8)
+        custom_mask = jnp.empty((1, 128), dtype=jnp.bool)
     else:
         assert (
             custom_mask.dtype != jnp.bool
@@ -1353,7 +1308,6 @@ def ragged_paged_attention(
         pl.BlockSpec(memory_space=pltpu.ANY), # q
        pl.BlockSpec(memory_space=pltpu.ANY), # kv_fused
        pl.BlockSpec(memory_space=pltpu.ANY), # kv_cache_fused
-        pl.BlockSpec(memory_space=pltpu.ANY), # custom_mask
     ]

     out_specs = [
@@ -1366,11 +1320,6 @@ def ragged_paged_attention(
         kv_cache_fused_processed.dtype,
     )

-    bkvmask_double_buf = pltpu.VMEM(
-        (2, bq_sz, bkv_sz),
-        jnp.bool,
-    )
-
     bq_double_buf = pltpu.VMEM(
         (2, actual_num_kv_heads, bq_sz, *q.shape[2:]),
         q.dtype,
@@ -1390,7 +1339,6 @@ def ragged_paged_attention(
     )

     scratch_shapes = [
-        bkvmask_double_buf, # Double buffering for fused kv mask block with head interleaving.
         bkv_fused_double_buf, # Double buffering for fused kv block with head interleaving.
         bq_double_buf, # Double buffering for q block.
         bo_double_buf, # Double buffering for output block.
@@ -1415,6 +1363,7 @@ def ragged_paged_attention(
         jnp.full((4,), -1, jnp.int32),
         # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
         jnp.full((6,), -1, jnp.int32),
+        custom_mask,
     )
     scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
     kernel = jax.named_scope(scope_name)(
@@ -1454,8 +1403,8 @@ def ragged_paged_attention(
             ),
         ],
         input_output_aliases={
-            9: 0, # q input -> q output
-            11: 1, # kv_cache_fused input -> updated kv_cache_fused output
+            10: 0, # q input -> q output
+            12: 1, # kv_cache_fused input -> updated kv_cache_fused output
         },
         name=scope_name,
     )
@@ -1466,7 +1415,6 @@ def ragged_paged_attention(
         q,
         kv,
         kv_cache_fused_processed,
-        custom_mask,
     )
     return (
         prepare_outputs(output, actual_num_q_heads_per_kv_head, actual_head_dim),
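With the flattened mask visible inside the kernel, load_mask() rebuilds each (bq, bkv) tile by reshaping the per-sequence slice to (q_len, kv_len) and cropping at the sequence boundaries. The standalone sketch below (plain jnp with made-up toy sizes, concrete Python indices, and a hypothetical mask_tile helper) illustrates that tiling idea; it is not a transcription of the kernel's slicing arithmetic.

# Standalone illustration of the (bq, bkv) mask-tiling idea; not kernel code.
import jax.numpy as jnp


def mask_tile(flat_mask, q_len, kv_len, bq_idx, bkv_idx, bq_sz, bkv_sz):
    # Per-sequence slice of the flattened mask, viewed as a (q_len, kv_len) grid.
    seq_mask = flat_mask[: q_len * kv_len].reshape(q_len, kv_len)
    # Crop the (bq_idx, bkv_idx) tile so it never runs past the sequence bounds.
    q_end = min((bq_idx + 1) * bq_sz, q_len)
    kv_end = min((bkv_idx + 1) * bkv_sz, kv_len)
    return seq_mask[bq_idx * bq_sz : q_end, bkv_idx * bkv_sz : kv_end]


q_len, kv_len, bq_sz, bkv_sz = 5, 7, 4, 4      # toy sizes
flat = (jnp.arange(q_len * kv_len) % 2) == 0   # stand-in custom bool mask
print(mask_tile(flat, q_len, kv_len, 0, 1, bq_sz, bkv_sz).shape)  # (4, 3)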

python/sgl_jax/srt/layers/attention/flashattention_backend.py

Lines changed: 1 addition & 5 deletions
@@ -101,11 +101,7 @@ def get_forward_metadata(self, batch: ModelWorkerBatch):
         page_indices = (selected_cache_locs // self.page_size).astype(np.int32)

         if batch.forward_mode == ForwardMode.TARGET_VERIFY:
-            # convert custom_mask from bool to int8, because dma not support bool type
-            if batch.spec_info.custom_mask.dtype == jnp.bool:
-                metadata.custom_mask = batch.spec_info.custom_mask.astype(jnp.int8)
-            else:
-                metadata.custom_mask = batch.spec_info.custom_mask
+            metadata.custom_mask = batch.spec_info.custom_mask
         else:
             metadata.custom_mask = None
