Commit feb3247
committed

feat: support custom mask for flash attention

1 parent b602c04 commit feb3247

File tree: 3 files changed, +215 −13 lines changed

python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py

Lines changed: 132 additions & 10 deletions
@@ -13,6 +13,7 @@
 from jax import lax
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
+from numpy import int32

 from sgl_jax.srt.layers.attention.flash_attn_kernel.tuned_block_sizes import (
     get_tuned_block_sizes,
@@ -101,20 +102,31 @@ def ref_ragged_paged_attention(
     cu_q_lens: jax.Array,  # i32[padded_batch_size + 1]
     num_seqs: jax.Array,  # i32[1],
     *,
+    custom_mask: jax.Array = None,  # [pattern_total_kv_len]
+    causal: bool = True,
     sm_scale: float = 1.0,
     sliding_window: int | None = None,
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
     k_scale: float | None = None,
     v_scale: float | None = None,
 ):
+    if causal:
+        if custom_mask is not None:
+            raise ValueError("causal mask is used, so custom_mask must be None")
+    else:
+        if custom_mask is None or custom_mask.size < jnp.cumsum(kv_lens)[-1]:
+            raise ValueError(
+                "custom_mask is used, so its length must be at least the total kv length"
+            )
     if mask_value is None:
         mask_value = DEFAULT_MASK_VALUE
     _, _, num_kv_heads, head_dim = k_pages.shape
     num_q_heads = queries.shape[1]
     assert num_q_heads % num_kv_heads == 0
     num_query_per_kv = num_q_heads // num_kv_heads
     outputs = []
+    cu_kv_lens = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(kv_lens)])
     for i in range(num_seqs[0]):
         q_start = cu_q_lens[i]
         q_end = cu_q_lens[i + 1]
@@ -134,9 +146,15 @@ def ref_ragged_paged_attention(
         v = jnp.repeat(v, num_query_per_kv, axis=1)
         attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32)
         attn *= sm_scale
-        q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(jnp.int32, attn.shape, 1)
-        kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
-        mask = q_span < kv_span
+        if causal:
+            q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
+                jnp.int32, attn.shape, 1
+            )
+            kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
+            mask = q_span < kv_span
+        else:
+            mask_start = cu_kv_lens[i]
+            mask = custom_mask[mask_start : mask_start + kv_len]
         if sliding_window is not None:
             mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
         if soft_cap is not None:
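Note: in this reference path the flattened custom mask is indexed by kv offset only: each sequence contributes kv_len entries, located through the cu_kv_lens prefix sums built above, and the 1-D slice broadcasts over heads and query rows. A minimal sketch of that layout with toy lengths (not values from the commit):

    import jax.numpy as jnp

    kv_lens = jnp.array([3, 2], dtype=jnp.int32)                # two toy sequences
    cu_kv_lens = jnp.concatenate(
        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(kv_lens)]
    )                                                           # [0, 3, 5]
    custom_mask = jnp.array([True, True, False, True, False])   # flattened per-sequence mask

    for i in range(kv_lens.shape[0]):
        mask_start = int(cu_kv_lens[i])
        seq_mask = custom_mask[mask_start : mask_start + int(kv_lens[i])]
        print(i, seq_mask)                                      # seq 0 -> [T T F], seq 1 -> [T F]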
@@ -239,18 +257,21 @@ def _ragged_paged_attention_kernel(
     q_hbm_ref,  # [actual_num_kv_heads, padded_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref,  # [padded_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim] - Fused KV with interleaved [K1,V1,K2,V2,...]
     kv_cache_fused_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
+    custom_mask_ref,  # (flatten_total_kv_len,)
     # Output
     o_hbm_ref,  # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     updated_kv_cache_fused_hbm_ref,  # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
     # Scratch
+    bkvmask_ref,  # [2, bq_sz, bkv_sz]
     bkv_fused_x2_ref,  # [2, bkv_sz, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
     bq_x2_ref,  # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     bo_x2_ref,  # [2, actual_num_kv_heads, bq_sz, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
-    sems,  # [4, 2]
+    sems,  # [5, 2]
     l_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128],
     m_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, 128],
     acc_ref,  # [actual_num_kv_heads, bq_sz * num_q_heads_per_kv_head, head_dim],
     *,
+    causal: int,  # 0: False, 1: True
     sm_scale: float,
     sliding_window: int | None = None,
     soft_cap: float | None = None,
@@ -297,6 +318,13 @@ def _ragged_paged_attention_kernel(
     prefill_end = distribution_ref[1]
     mixed_end = distribution_ref[2]

+    kv_lens = cu_kv_lens_ref[1:] - cu_kv_lens_ref[:-1]
+    q_lens = cu_q_lens_ref[1:] - cu_q_lens_ref[:-1]
+    seq_mask_lens = kv_lens * q_lens
+    cu_seq_mask_lens = jnp.concatenate(
+        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_mask_lens)]
+    )
+
     q_start = cu_q_lens_ref[seq_idx]
     q_end = cu_q_lens_ref[seq_idx + 1]
     q_len = q_end - q_start
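Note: the prefix sums added here assume the flattened custom mask stores one row-major [q_len, kv_len] block per sequence, so each sequence's block starts at the cumulative sum of q_len * kv_len. A small numeric sketch of the arithmetic (toy values):

    import jax.numpy as jnp

    cu_q_lens = jnp.array([0, 2, 5], dtype=jnp.int32)    # q_lens  = [2, 3]
    cu_kv_lens = jnp.array([0, 4, 10], dtype=jnp.int32)  # kv_lens = [4, 6]

    kv_lens = cu_kv_lens[1:] - cu_kv_lens[:-1]
    q_lens = cu_q_lens[1:] - cu_q_lens[:-1]
    seq_mask_lens = kv_lens * q_lens                     # [8, 18]
    cu_seq_mask_lens = jnp.concatenate(
        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_mask_lens)]
    )
    print(cu_seq_mask_lens)                              # [0, 8, 26]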
@@ -309,6 +337,46 @@ def _async_copy(src, dst, sem, wait):
         else:
             cp.start()

+    def _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx, *, wait=False):
+        sem = sems.at[4, bkvmask_sem_idx]
+        kvmask_fused_vmem_ref = bkvmask_ref.at[bkvmask_sem_idx]
+
+        kv_len = kv_lens_ref[seq_idx]
+        mask_len = kv_len
+        mask_start = bkvmask_idx * bkv_sz
+        mask_left = mask_len - mask_start
+        load_kv_sz = jnp.minimum(bkv_sz, mask_left)
+
+        q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
+        q_end = cu_q_lens_ref[seq_idx + 1]
+        load_q_sz = jnp.minimum(bq_sz, q_end - q_len_start)
+
+        cur_seq_mask_start = cu_seq_mask_lens[seq_idx]
+        cur_bq_mask_start = cur_seq_mask_start + bq_idx * bq_sz * kv_len
+
+        # Whether the custom mask is used depends on the causal flag.
+        # flatten mask: [TTTTTTFFFFTFTTFFFTTFFTTTTTFFFFTTTTTTFT,FFFTFFTFTTTTTFTFFFFFTTFTTTTFTFTTFTTT]
+        #                ^kv_start     ^mask_start
+        #                              <--load_sz-->
+
+        def loop_body(i, _):
+            start = cur_bq_mask_start + i * kv_len + mask_start
+            start = jnp.minimum(custom_mask_ref.shape[0], start)
+            _async_copy(
+                custom_mask_ref.at[pl.ds(start, load_kv_sz)],
+                kvmask_fused_vmem_ref.at[i, pl.ds(0, load_kv_sz)],
+                sem,
+                wait,
+            )
+
+        lax.fori_loop(
+            0,
+            load_q_sz,
+            loop_body,
+            None,
+            unroll=False,
+        )
+
     def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
         sem = sems.at[0, bkv_sem_idx]
         kv_fused_vmem_ref = bkv_fused_x2_ref.at[bkv_sem_idx]
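Note: _fetch_mask copies the tile row by row because consecutive query rows of one sequence sit kv_len elements apart in the flattened layout, so a (bq_sz, bkv_sz) tile is not contiguous in HBM. The source offset of row i reduces to the arithmetic below (plain-Python illustration with toy values, not the kernel itself):

    def mask_row_offset(cu_seq_mask_lens, kv_len, bq_idx, bkv_idx, bq_sz, bkv_sz, seq_idx, i):
        cur_seq_mask_start = cu_seq_mask_lens[seq_idx]                    # this sequence's mask block
        cur_bq_mask_start = cur_seq_mask_start + bq_idx * bq_sz * kv_len  # first row of this q block
        mask_start = bkv_idx * bkv_sz                                     # kv offset within a row
        return cur_bq_mask_start + i * kv_len + mask_start                # row i of the (bq, bkv) tile

    # seq 0 with kv_len=10: row 2 of the tile at q block 0, kv block 1 starts at 2*10 + 4 = 24.
    print(mask_row_offset([0, 80], kv_len=10, bq_idx=0, bkv_idx=1, bq_sz=8, bkv_sz=4, seq_idx=0, i=2))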
@@ -442,6 +510,12 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
             wait,
         )

+    def start_fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx):
+        return _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx)
+
+    def wait_fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx):
+        return _fetch_mask(seq_idx, bq_idx, bkvmask_idx, bkvmask_sem_idx, wait=True)
+
     def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
         return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
@@ -489,9 +563,10 @@ def load_bq(bq_sem_idx, kv_head_idx, *, actual_bq_sz=bq_sz):
             .at[bq_sem_idx, kv_head_idx]
             .reshape(bq_sz * num_q_heads_per_kv_head_per_packing, head_dim)
         )
-        return pltpu.bitcast(
+        res = pltpu.bitcast(
             q_ref[: actual_bq_sz * num_q_heads_per_kv_head_per_packing], q_dtype
         )
+        return res

     def strided_load(ref, start, step, *, dtype=None):
         assert get_dtype_packing(ref.dtype) == 1
@@ -621,6 +696,12 @@ def prefetch_next_bkv():
                 sem_ids_ref[1] = next_bkv_sem_idx
                 start_fetch_bkv(next_seq_idx, next_bkv_idx, next_bkv_sem_idx)

+                @pl.when(causal == 0)
+                def _():
+                    start_fetch_mask(
+                        next_seq_idx, bq_idx, next_bkv_idx, next_bkv_sem_idx
+                    )
+
         # Wait for cur bq if not ready yet
         @pl.when(bkv_idx == 0)
         def wait_cur_bq():
@@ -629,6 +710,11 @@ def wait_cur_bq():
         # Wait for cur bkv
         offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

+        # Wait for the kv mask when the causal mask is not used.
+        @pl.when(causal == 0)
+        def _():
+            wait_fetch_mask(seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
+
         # Start updating bkv to kv cache if applicable.
         # Only needed in first bq loop.
         @pl.when(jnp.logical_and(update_sz > 0, bq_idx == 0))
@@ -664,9 +750,26 @@ def batch_prepare_queries():

         return jnp.stack(q_heads, axis=0)

+    def load_mask():
+        mask = bkvmask_ref[bkv_sem_idx, :actual_bq_sz]
+        # assert False, f'{mask.shape=} {jnp.zeros((actual_num_kv_heads, actual_bq_sz*num_q_heads_per_kv_head, mask.shape[-1])).shape=}'
+        num_q_heads_per_kv_head_mask = jnp.concat(
+            [mask] * num_q_heads_per_kv_head
+        )
+        num_kv_heads_mask = jnp.concat(
+            [
+                num_q_heads_per_kv_head_mask.reshape(
+                    1, *num_q_heads_per_kv_head_mask.shape
+                )
+            ]
+            * actual_num_kv_heads
+        )
+        return num_kv_heads_mask
+
     # Load batched data
     k_batch, v_batch = batch_load_all_heads_kv()
     q_batch = batch_prepare_queries()
+    custom_mask = load_mask()

     def flash_attention(q_batch, k_batch, v_batch):
         q_batch_f32 = q_batch.astype(jnp.float32)
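Note: load_mask inflates the (bq, bkv) boolean tile so it matches the batched score tensor, stacking it once per query head in a kv group and then once per kv head. A rough shape-only sketch of that replication (toy sizes; names taken from the scratch comments above):

    import jax.numpy as jnp

    bq, bkv = 4, 8
    num_q_heads_per_kv_head, actual_num_kv_heads = 2, 3
    tile = jnp.zeros((bq, bkv), dtype=jnp.bool_)                             # one mask tile

    per_kv_head = jnp.concatenate([tile] * num_q_heads_per_kv_head)          # (bq * q_per_kv, bkv)
    full = jnp.concatenate([per_kv_head[None, ...]] * actual_num_kv_heads)   # (kv_heads, bq * q_per_kv, bkv)
    print(full.shape)                                                        # (3, 8, 8)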
@@ -701,8 +804,11 @@ def flash_attention(q_batch, k_batch, v_batch):
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(
             jnp.int32, s.shape, 2
         )
-        mask = q_span < k_span
-
+        mask = lax.select(
+            causal == 0,
+            custom_mask,
+            q_span < k_span,
+        )
         if sliding_window is not None:
             mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

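Note: lax.select requires both branches to have the same shape, which is why load_mask above expands the mask tile to the full batched score shape before this point. A tiny standalone illustration (toy shapes, not the kernel's):

    from jax import lax
    import jax.numpy as jnp

    causal = 0                                   # static flag, as passed into the kernel
    q_span = jnp.arange(4)[:, None]              # toy query positions
    k_span = jnp.arange(4)[None, :]              # toy key positions
    custom_mask = jnp.eye(4, dtype=jnp.bool_)    # stand-in for the loaded mask tile
    mask = lax.select(causal == 0, custom_mask, q_span < k_span)
    print(mask.shape)                            # (4, 4): both branches must match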
@@ -1079,6 +1185,7 @@ def static_validate_inputs_fused(
@functools.partial(
    jax.jit,
    static_argnames=(
+        "causal",
        "sm_scale",
        "sliding_window",
        "soft_cap",
@@ -1103,7 +1210,9 @@ def ragged_paged_attention(
    cu_q_lens: jax.Array,  # i32[padded_batch_size + 1]
    cu_kv_lens: jax.Array,  # i32[padded_batch_size + 1]
    distribution: jax.Array,  # i32[3]
+    custom_mask: jax.Array,  # [pattern_total_kv_len] when causal is 0; [0] when the causal mask is used
    *,
+    causal: int = 1,  # 1: True, 0: False
    sm_scale: float = 1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
@@ -1132,8 +1241,10 @@ def ragged_paged_attention(
      distribution: (i, j, k) represents that sequences[0:i] are decode-only,
        sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
        k is also the total number of sequences.
+      custom_mask: the flattened custom mask used to compute attention when causal is 0.
      actual_head_dim: the actual head size of the attention. Here we assume k and
        v have the same actual head size.
+      causal: if set to 1 (True), use the causal mask; otherwise use custom_mask.
      sm_scale: the softmax scale which will be applied to the Q@K^T.
      sliding_window: the sliding window size for the attention.
      soft_cap: the logit soft cap for the attention.
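Note: the kernel-side indexing added above treats custom_mask as the concatenation of one row-major [q_len, kv_len] boolean block per sequence. A hedged sketch of how such a flattened mask could be assembled (the helper name and the verify-style example mask are illustrative, not from the commit):

    import jax.numpy as jnp

    def flatten_custom_mask(per_seq_masks):
        """per_seq_masks: list of [q_len_i, kv_len_i] boolean arrays, one per sequence."""
        return jnp.concatenate([m.reshape(-1) for m in per_seq_masks])

    mask_seq0 = jnp.tril(jnp.ones((2, 4), dtype=jnp.bool_), k=2)    # e.g. a verify-style mask
    mask_seq1 = jnp.ones((3, 6), dtype=jnp.bool_)
    flat_mask = flatten_custom_mask([mask_seq0, mask_seq1])
    print(flat_mask.shape)                                          # (2*4 + 3*6,) = (26,)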
@@ -1232,6 +1343,7 @@ def ragged_paged_attention(
        pl.BlockSpec(memory_space=pltpu.ANY),  # q
        pl.BlockSpec(memory_space=pltpu.ANY),  # kv_fused
        pl.BlockSpec(memory_space=pltpu.ANY),  # kv_cache_fused
+        pl.BlockSpec(memory_space=pltpu.ANY),  # custom_mask
    ]

    out_specs = [
@@ -1244,6 +1356,11 @@ def ragged_paged_attention(
        kv_cache_fused_processed.dtype,
    )

+    bkvmask_double_buf = pltpu.VMEM(
+        (2, bq_sz, bkv_sz),
+        jnp.bool,
+    )
+
    bq_double_buf = pltpu.VMEM(
        (2, actual_num_kv_heads, bq_sz, *q.shape[2:]),
        q.dtype,
@@ -1263,11 +1380,12 @@ def ragged_paged_attention(
    )

    scratch_shapes = [
+        bkvmask_double_buf,  # Double buffering for the kv mask block.
        bkv_fused_double_buf,  # Double buffering for fused kv block with head interleaving.
        bq_double_buf,  # Double buffering for q block.
        bo_double_buf,  # Double buffering for output block.
        # Semaphores for double buffering of bkv, bq, bo and bkv_update.
-        pltpu.SemaphoreType.DMA((4, 2)),
+        pltpu.SemaphoreType.DMA((5, 2)),
        # Intermediate buffers per kv head for flash attention.
        l_scratch,
        m_scratch,
@@ -1287,12 +1405,12 @@ def ragged_paged_attention(
        # (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
        jnp.full((6,), -1, jnp.int32),
    )
-
    scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
    kernel = jax.named_scope(scope_name)(
        pl.pallas_call(
            functools.partial(
                _ragged_paged_attention_kernel,
+                causal=causal,
                sm_scale=sm_scale,
                sliding_window=sliding_window,
                soft_cap=soft_cap,
@@ -1333,7 +1451,11 @@ def ragged_paged_attention(
    )

    output, updated_kv_cache_fused = kernel(
-        *scalar_prefetches, q, kv, kv_cache_fused_processed
+        *scalar_prefetches,
+        q,
+        kv,
+        kv_cache_fused_processed,
+        custom_mask,
    )
    return (
        prepare_outputs(output, actual_num_q_heads_per_kv_head, actual_head_dim),

python/sgl_jax/srt/layers/attention/flashattention_backend.py

Lines changed: 16 additions & 0 deletions
@@ -34,6 +34,7 @@ class FlashAttentionMetadata:
    page_indices: jax.Array = None
    seq_lens: jax.Array = None
    distribution: jax.Array = None
+    custom_mask: jax.Array = None

    def tree_flatten(self):
        children = (
@@ -43,6 +44,7 @@ def tree_flatten(self):
            self.page_indices,
            self.seq_lens,
            self.distribution,
+            self.custom_mask,
        )

        aux_data = {}
@@ -58,6 +60,7 @@ def tree_unflatten(cls, aux_data, children):
        obj.page_indices = children[3]
        obj.seq_lens = children[4]
        obj.distribution = children[5]
+        obj.custom_mask = children[6]

        return obj
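Note: registering custom_mask as an extra pytree child means it is traced and carried alongside the other metadata arrays. A generic, self-contained sketch of the same flatten/unflatten pattern (class and field names here are illustrative, not the backend's):

    import jax
    import jax.numpy as jnp
    from dataclasses import dataclass

    @jax.tree_util.register_pytree_node_class
    @dataclass
    class TinyMetadata:
        seq_lens: jax.Array = None
        custom_mask: jax.Array = None

        def tree_flatten(self):
            return (self.seq_lens, self.custom_mask), {}

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            return cls(*children)

    md = TinyMetadata(jnp.array([3, 5]), jnp.array([], dtype=jnp.bool_))
    leaves, _ = jax.tree_util.tree_flatten(md)
    assert len(leaves) == 2   # custom_mask travels with the other jax.Array leaves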
@@ -97,6 +100,11 @@ def get_forward_metadata(self, batch: ModelWorkerBatch):
        selected_cache_locs = batch.cache_loc[indices]
        page_indices = (selected_cache_locs // self.page_size).astype(np.int32)

+        if batch.forward_mode == ForwardMode.TARGET_VERIFY:
+            metadata.custom_mask = batch.spec_info.custom_mask
+        else:
+            metadata.custom_mask = jnp.array([], dtype=jnp.bool)
+
        if batch.forward_mode.is_extend():
            cu_q_lens = np.concatenate(
                [
@@ -215,6 +223,11 @@ def __call__(
            num_pages, self.page_size, -1, self.head_dim
        )

+        causal = 1
+        custom_mask = self.forward_metadata.custom_mask
+        if forward_batch.forward_mode == ForwardMode.TARGET_VERIFY:
+            causal = 0
+
        in_specs = (
            P(None, self.kv_partition_axis),  # queries
            P(None, self.kv_partition_axis),  # keys (new tokens)
@@ -227,6 +240,7 @@ def __call__(
            P(),  # cu_q_lens
            P(),  # cu_kv_lens
            P(),  # distribution
+            P(),  # custom_mask
        )
        out_specs = (
            P(None, self.kv_partition_axis),  # attention output
@@ -246,6 +260,7 @@ def _ragged_paged_attention_with_fused_kv(*args):
                values,
                kv_cache_fused,
                *other_args,
+                causal=causal,
                sm_scale=scale,
                sliding_window=None,
                soft_cap=None,
@@ -272,6 +287,7 @@ def _ragged_paged_attention_with_fused_kv(*args):
            self.forward_metadata.cu_q_lens,
            self.forward_metadata.cu_kv_lens,
            self.forward_metadata.distribution,
+            self.forward_metadata.custom_mask,
        )

        return (
