Commit 57d7f3f

fix
1 parent c3d902d commit 57d7f3f

3 files changed: +14 -5 lines changed

python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py

Lines changed: 8 additions & 3 deletions
@@ -258,7 +258,7 @@ def _ragged_paged_attention_kernel(
     q_hbm_ref, # [actual_num_kv_heads, padded_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     kv_hbm_ref, # [padded_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim] - Fused KV with interleaved [K1,V1,K2,V2,...]
     kv_cache_fused_hbm_ref, # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
-    custom_mask_ref, # (flatten_total_kv_len,)
+    custom_mask_ref, # (flatten_total_kv_len,), int8; DMA does not support the bool dtype
     # Output
     o_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
     updated_kv_cache_fused_hbm_ref, # [total_num_pages, page_size, num_kv_heads_interleaved // kv_packing, kv_packing, head_dim]
@@ -799,9 +799,10 @@ def flash_attention(q_batch, k_batch, v_batch):
         k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(
             jnp.int32, s.shape, 2
         )
+        # Convert custom_mask from int8 back to bool before lax.select.
         mask = lax.select(
             causal == 0,
-            custom_mask,
+            custom_mask.astype(jnp.bool),
             q_span < k_span,
         )
         if sliding_window is not None:
@@ -1340,7 +1341,11 @@ def ragged_paged_attention(
     )
     if custom_mask is None:
         # fix bug: XLA layout ({0}) does not match Mosaic layout ({0:T(128)}) for an operand of shape s32[0]
-        custom_mask = jnp.empty((1, 128))
+        custom_mask = jnp.empty((1, 128), dtype=jnp.int8)
+    else:
+        assert (
+            custom_mask.dtype != jnp.bool
+        ), "custom_mask bool dtype is not supported; use int8 instead (0: False, 1: True)"

     grid = (distribution[2],)
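
The kernel-side contract after this change is that custom_mask arrives as int8 (0: False, 1: True) and is viewed as bool only at the point of comparison. A minimal standalone sketch of that pattern, not the Pallas kernel itself (the scores and mask values below are made up):

import jax.numpy as jnp
from jax import lax

# Illustrative only: carry the mask as int8 through DMA-constrained inputs and
# convert to bool just where lax.select needs a predicate, mirroring the hunk above.
scores = jnp.array([[0.1, 0.2, 0.3]], dtype=jnp.float32)
custom_mask_i8 = jnp.array([[1, 0, 1]], dtype=jnp.int8)  # 1: attend, 0: mask out

masked_scores = lax.select(
    custom_mask_i8.astype(jnp.bool_),   # bool view only at the point of use
    scores,
    jnp.full_like(scores, -jnp.inf),    # masked positions drop out after softmax
)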

python/sgl_jax/srt/layers/attention/flashattention_backend.py

Lines changed: 5 additions & 1 deletion
@@ -101,7 +101,11 @@ def get_forward_metadata(self, batch: ModelWorkerBatch):
         page_indices = (selected_cache_locs // self.page_size).astype(np.int32)

         if batch.forward_mode == ForwardMode.TARGET_VERIFY:
-            metadata.custom_mask = batch.spec_info.custom_mask
+            # Convert custom_mask from bool to int8, because DMA does not support the bool dtype.
+            if batch.spec_info.custom_mask.dtype == jnp.bool:
+                metadata.custom_mask = batch.spec_info.custom_mask.astype(jnp.int8)
+            else:
+                metadata.custom_mask = batch.spec_info.custom_mask
         else:
             metadata.custom_mask = None
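
On the backend side, the contract is simply to normalize any bool speculative-decoding mask to int8 before it reaches the kernel. A hedged sketch of that caller-side step, with a hypothetical upstream mask (the real one comes from batch.spec_info):

import jax.numpy as jnp

# Sketch only: normalize a bool mask to int8 (0: False, 1: True) before it is
# handed to the attention kernel, as the backend change above does.
spec_custom_mask = jnp.array([True, False, True, True])  # hypothetical upstream mask

custom_mask = (
    spec_custom_mask.astype(jnp.int8)
    if spec_custom_mask.dtype == jnp.bool_
    else spec_custom_mask
)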

python/sgl_jax/test/test_flashattention.py

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ def align_to_size(l, size, value=0):
     attention_backend = FlashAttentionBackend(
         num_heads, num_kv_heads, head_dim, page_size=page_size, mesh=mesh
     )
-    print(f"!!!!!!!! {causal=}")
+
     if not causal:
         forward_mode = ForwardMode.TARGET_VERIFY
         custom_mask = create_custom_mask(lens)
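
For context, the TARGET_VERIFY path in the test feeds the kernel a flattened mask built from the per-sequence lengths. A hypothetical illustration of such a construction follows; this is not the test's create_custom_mask helper, whose implementation is not shown in this diff:

import numpy as np

# Hypothetical: build a flattened causal mask over ragged sequences, one row of
# kv positions per query token, then cast to int8 (0: False, 1: True) because
# the kernel's custom_mask input is int8.
def make_flat_causal_mask(lens):
    rows = []
    for seq_len in lens:
        for i in range(seq_len):
            row = np.zeros(seq_len, dtype=np.int8)
            row[: i + 1] = 1  # token i may attend to kv positions 0..i
            rows.append(row)
    return np.concatenate(rows) if rows else np.zeros(0, dtype=np.int8)

flat_mask = make_flat_causal_mask([3, 2])  # shape (3*3 + 2*2,) == (13,)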
