7 changes: 6 additions & 1 deletion vllm_ascend/torchair/torchair_attention.py
@@ -374,6 +374,9 @@ def forward(
             indices = torch.cat((block_indices, slots_indices), dim=1)
             torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
             torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+            if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+                self.key_cache = key_cache
+                self.value_cache = value_cache
Comment on lines +377 to +379
Contributor (severity: high)

self.key_cache and self.value_cache are used by the PrefillCacheHit attention state. This change correctly updates them when the state is PrefillCacheHit. However, the cache update in lines 375-376 happens for any state that has a kv_cache (e.g., DecodeOnly as well). To make the logic more robust and to prevent issues if other states start reading self.key_cache in the future, it would be better to update self.key_cache and self.value_cache unconditionally whenever the cache is modified, so they always reflect the latest cache tensors passed into this forward pass.

            self.key_cache = key_cache
            self.value_cache = value_cache


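A minimal, self-contained sketch of the behaviour the reviewer is asking for, with plain torch indexing standing in for torch_npu.npu_scatter_nd_update_ so it runs anywhere; the class and method names here are hypothetical and not the real vllm-ascend attention layer:

```python
import torch


class TinyKVCacheLayer:
    """Toy stand-in for the attention layer in the diff above (hypothetical).

    Illustrates the reviewer's point: refresh the cached references every
    time the cache tensors are written, not only for one attention state.
    """

    def __init__(self):
        self.key_cache = None
        self.value_cache = None

    def write_kv(self, key_cache, value_cache, slot, key, value):
        # Stand-in for torch_npu.npu_scatter_nd_update_: write this step's
        # key/value into the preallocated cache tensors in place.
        key_cache[slot] = key
        value_cache[slot] = value
        # Unconditional update (the suggestion): the layer's references
        # always point at the tensors that were just modified, so any later
        # state that reads them (e.g. PrefillCacheHit) sees current data.
        self.key_cache = key_cache
        self.value_cache = value_cache


if __name__ == "__main__":
    layer = TinyKVCacheLayer()
    k_cache = torch.zeros(4, 8)
    v_cache = torch.zeros(4, 8)
    layer.write_kv(k_cache, v_cache, slot=2,
                   key=torch.ones(8), value=torch.full((8,), 2.0))
    assert layer.key_cache is k_cache and layer.value_cache is v_cache
```

The point is only that the reference refresh sits next to the in-place write, so it cannot drift out of sync with whichever attn_state later reads self.key_cache.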
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             assert attn_metadata is not None
@@ -411,11 +414,13 @@ def forward(
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             compress_mask = attn_metadata.attn_mask
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
             torch_npu._npu_flash_attention_qlens(
                 query=query,
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
-                block_table=attn_metadata.block_tables,
+                block_table=block_table,
                 mask=compress_mask,
                 seq_len=attn_metadata.query_lens,
                 context_lens=attn_metadata.seq_lens,
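On the second hunk: the call now receives attn_metadata.block_tables sliced to the first query_lens.shape[0] rows, presumably because the block table can carry more rows than the requests actually present in this step (e.g. padding for torchair graph mode). A tiny sketch of the slicing with hypothetical shapes, independent of torch_npu:

```python
import torch

# Hypothetical shapes for illustration: block_tables is padded to a larger
# batch, while query_lens only covers the requests actually in this step.
query_lens = torch.tensor([3, 1])                      # 2 real requests
block_tables = torch.zeros(8, 16, dtype=torch.int32)   # padded to 8 rows

batch_size = query_lens.shape[0]
block_table = block_tables[:batch_size, :]             # one row per real request

# The sliced view is what the diff now passes as block_table= to
# torch_npu._npu_flash_attention_qlens, so the kernel sees a block table
# whose row count matches seq_len / context_lens.
assert block_table.shape[0] == batch_size
```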