Commit b9a0a75

zhaozx-cn and zhaozixin authored
fix qwen torchair attention PrefillCacheHit (#2787)
### What this PR does / why we need it?

Fix qwen torchair attention PrefillCacheHit.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

vLLM version: v0.10.1.1
vLLM main: vllm-project/vllm@e599e2c

- vLLM version: main
- vLLM main: vllm-project/vllm@0b9a612

Signed-off-by: zhaozixin <[email protected]>
Co-authored-by: zhaozixin <[email protected]>
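For context, here is a minimal sketch of the control flow the first hunk fixes. This is a simplified, hypothetical stand-in, not the vllm-ascend source: the `PrefillCacheHit` attention call in the second hunk reads `self.key_cache` and `self.value_cache`, so the freshly written caches must be stashed on the module right after the KV write, and `index_put_`-style assignment stands in for `torch_npu.npu_scatter_nd_update_`, which needs an NPU.

```python
import torch


class AscendAttentionState:
    # Simplified stand-in for the real enum; only the states used here.
    PrefillNoCache = 0
    PrefillCacheHit = 1


class AttentionImplSketch:
    def __init__(self):
        self.key_cache = None
        self.value_cache = None

    def write_kv(self, attn_state, key_cache, value_cache, indices, key, value):
        # Stand-in for torch_npu.npu_scatter_nd_update_: write each token's
        # key/value into its (block, slot) position in the paged cache.
        key_cache[indices[:, 0], indices[:, 1]] = key
        value_cache[indices[:, 0], indices[:, 1]] = value
        if attn_state == AscendAttentionState.PrefillCacheHit:
            # The fix: without this, self.key_cache could still be unset (or
            # stale) when the PrefillCacheHit branch runs attention.
            self.key_cache = key_cache
            self.value_cache = value_cache


# Example: 2 blocks of 4 slots, head dim 8, 3 tokens written.
impl = AttentionImplSketch()
kc, vc = torch.zeros(2, 4, 8), torch.zeros(2, 4, 8)
idx = torch.tensor([[0, 1], [0, 2], [1, 0]])  # (block, slot) per token
impl.write_kv(AscendAttentionState.PrefillCacheHit, kc, vc, idx,
              torch.ones(3, 8), torch.ones(3, 8))
assert impl.key_cache is kc
```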
1 parent 7b2ecc1 commit b9a0a75

File tree

1 file changed (+6 -1 lines changed)


vllm_ascend/torchair/torchair_attention.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -374,6 +374,9 @@ def forward(
             indices = torch.cat((block_indices, slots_indices), dim=1)
             torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
             torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+            if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+                self.key_cache = key_cache
+                self.value_cache = value_cache
 
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             assert attn_metadata is not None
@@ -411,11 +414,13 @@ def forward(
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             compress_mask = attn_metadata.attn_mask
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
             torch_npu._npu_flash_attention_qlens(
                 query=query,
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
-                block_table=attn_metadata.block_tables,
+                block_table=block_table,
                 mask=compress_mask,
                 seq_len=attn_metadata.query_lens,
                 context_lens=attn_metadata.seq_lens,
```
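The second hunk trims `attn_metadata.block_tables` to the real batch before the fused prefill kernel. Below is a minimal illustration of that slice in plain torch; the padded-batch shapes are assumptions for illustration, and `_npu_flash_attention_qlens` itself is NPU-only and not invoked here.

```python
import torch

# Assumed scenario: the runner holds block_tables rows for a padded batch
# (e.g. a captured graph size), while query_lens has one entry per real
# request. Slicing keeps block_table consistent with seq_len/context_lens.
padded_batch, max_blocks = 8, 4
block_tables = torch.arange(padded_batch * max_blocks).reshape(
    padded_batch, max_blocks)
query_lens = torch.tensor([5, 2, 7])  # three real requests

batch_size = query_lens.shape[0]            # 3
block_table = block_tables[:batch_size, :]  # drop padded rows, as in the diff

assert block_table.shape == (3, max_blocks)
```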
