Commit 6054f6b

fix
1 parent feb3247 commit 6054f6b

File tree

3 files changed (+23, -16 lines)

python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py

Lines changed: 14 additions & 9 deletions
@@ -249,6 +249,7 @@ def _ragged_paged_attention_kernel(
     page_indices_ref,  # [(padded_batch_size * model_context_len + page_size - 1) // page_size]
     cu_q_lens_ref,  # [padded_batch_size + 1]
     cu_kv_lens_ref,  # [padded_batch_size + 1]
+    cu_seq_mask_lens,
     distribution_ref,  # [3] (decode_end, prefill_end, mixed_end)
     sem_ids_ref,  # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
     bo_ids_ref,  # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
@@ -318,13 +319,6 @@ def _ragged_paged_attention_kernel(
     prefill_end = distribution_ref[1]
     mixed_end = distribution_ref[2]
 
-    kv_lens = cu_kv_lens_ref[1:] - cu_kv_lens_ref[:-1]
-    q_lens = cu_q_lens_ref[1:] - cu_q_lens_ref[:-1]
-    seq_mask_lens = kv_lens * q_lens
-    cu_seq_mask_lens = jnp.concatenate(
-        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_mask_lens)]
-    )
-
     q_start = cu_q_lens_ref[seq_idx]
     q_end = cu_q_lens_ref[seq_idx + 1]
     q_len = q_end - q_start
@@ -1337,6 +1331,16 @@ def ragged_paged_attention(
         )
         * 2.4
     )
+
+    q_lens = cu_q_lens[1:] - cu_q_lens[:-1]
+    seq_mask_lens = kv_lens * q_lens
+    cu_seq_mask_lens = jnp.concatenate(
+        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_mask_lens)]
+    )
+    if custom_mask is None:
+        # fix bug: XLA layout ({0}) does not match Mosaic layout ({0:T(128)}) for an operand of shape s32[0]
+        custom_mask = jnp.empty((1, 128))
+
     grid = (distribution[2],)
 
     in_specs = [
@@ -1397,6 +1401,7 @@
         page_indices,
         cu_q_lens,
         cu_kv_lens,
+        cu_seq_mask_lens,
         distribution,
         # (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
         jnp.zeros((3,), jnp.int32),
@@ -1443,8 +1448,8 @@
             ),
         ],
         input_output_aliases={
-            8: 0,  # q input -> q output
-            10: 1,  # kv_cache_fused input -> updated kv_cache_fused output
+            9: 0,  # q input -> q output
+            11: 1,  # kv_cache_fused input -> updated kv_cache_fused output
        },
        name=scope_name,
    )
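
In short, the cu_seq_mask_lens prefix sums are now computed once on the host in ragged_paged_attention and handed to the kernel as an extra operand, instead of being rebuilt from cu_q_lens_ref / cu_kv_lens_ref inside _ragged_paged_attention_kernel; because the new operand precedes q and the fused KV cache in the Pallas call's input list, the input_output_aliases indices shift by one (8/10 to 9/11). Below is a minimal, self-contained sketch of the same prefix-sum computation; the demo_ name and the example lengths are illustrative, not part of the kernel API.

import jax.numpy as jnp

def demo_cu_seq_mask_lens(cu_q_lens, cu_kv_lens):
    # Per-sequence query/KV lengths recovered from the cumulative arrays.
    q_lens = cu_q_lens[1:] - cu_q_lens[:-1]
    kv_lens = cu_kv_lens[1:] - cu_kv_lens[:-1]
    # Each sequence's flattened mask occupies q_len * kv_len entries.
    seq_mask_lens = kv_lens * q_lens
    # Exclusive prefix sum: start offset of each sequence's mask slice.
    return jnp.concatenate(
        [jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_mask_lens)]
    )

# Example: two sequences with (q_len, kv_len) = (2, 3) and (1, 4).
cu_q_lens = jnp.array([0, 2, 3], dtype=jnp.int32)
cu_kv_lens = jnp.array([0, 3, 7], dtype=jnp.int32)
print(demo_cu_seq_mask_lens(cu_q_lens, cu_kv_lens))  # [0 6 10]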

python/sgl_jax/srt/layers/attention/flashattention_backend.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def get_forward_metadata(self, batch: ModelWorkerBatch):
         if batch.forward_mode == ForwardMode.TARGET_VERIFY:
             metadata.custom_mask = batch.spec_info.custom_mask
         else:
-            metadata.custom_mask = jnp.array([], dtype=jnp.bool)
+            metadata.custom_mask = None
 
         if batch.forward_mode.is_extend():
             cu_q_lens = np.concatenate(
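
Taken together with the kernel-side change above, the backend now signals "no custom mask" with None instead of a zero-length array, and ragged_paged_attention substitutes a small 128-lane placeholder before the Pallas call so lowering never sees the zero-sized operand behind the "XLA layout ({0}) does not match Mosaic layout ({0:T(128)})" error. A rough sketch of that pattern under a simplified call path; resolve_custom_mask is an illustrative name, not the actual helper.

import jax.numpy as jnp

def resolve_custom_mask(custom_mask):
    # None means "no mask" (every mode except TARGET_VERIFY).
    if custom_mask is None:
        # A zero-sized mask operand is what the commit's comment ties to the
        # XLA/Mosaic layout mismatch, so hand the kernel a small
        # lane-aligned placeholder instead.
        return jnp.empty((1, 128))
    return custom_mask

print(resolve_custom_mask(None).shape)                 # (1, 128)
print(resolve_custom_mask(jnp.ones((4, 128))).shape)   # (4, 128)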

python/sgl_jax/test/test_flashattention.py

Lines changed: 8 additions & 6 deletions
@@ -648,12 +648,14 @@ def test_mha_prefill_with_custom_mask(self):
             (2, 22),
             (42, 42),
         ]
-        self.run_test(
-            "prefill",
-            lens,
-            (num_heads, head_dim, num_kv_heads, 1, jnp.bfloat16),
-            False,
-        )
+        page_size = [1, 64]
+        causal_mask = False
+        for size in page_size:
+            self.run_test(
+                "prefill",
+                lens,
+                (num_heads, head_dim, num_kv_heads, size, jnp.bfloat16, causal_mask),
+            )
 
     def test_mha_decode_with_custom_mask(self):
         pass
