Skip to content

Commit 4d18b12

Browse files
committed
fix
1 parent 4c7aab3 commit 4d18b12

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def ref_ragged_paged_attention(
115115
if custom_mask != None:
116116
raise ValueError(f"use causal mask, custom_mask is not None")
117117
else:
118+
print(f"######### {custom_mask=} {kv_lens=}")
118119
if custom_mask == None or custom_mask.size() < jnp.cumsum(kv_lens)[-1]:
119120
raise ValueError(
120121
f"use custom_mask, custom_mask length must larger than total kv length"

python/sgl_jax/test/test_flashattention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def run_test(self, mode, lens, mode_args):
405405
cache_loc_list.append(padded_page_indices)
406406
page_table = jnp.stack(cache_loc_list)
407407

408+
print(f"@@@@@ {forward_batch.spec_info=}")
408409
expected = ref_ragged_paged_attention(
409410
q.reshape(q.shape[0], num_heads, head_dim),
410411
k.reshape(k.shape[0] // page_size, page_size, num_kv_heads, head_dim),

0 commit comments

Comments
 (0)