From 1b47975f5f34a731e48c922b08b13b8f83e094c4 Mon Sep 17 00:00:00 2001
From: p00465316
Date: Sat, 6 Sep 2025 15:12:42 +0800
Subject: [PATCH] replace npu_incre_flash_attention with
 npu_fused_infer_attention_score

Signed-off-by: p00465316
---
 vllm_ascend/torchair/torchair_attention.py | 25 ++++++++++++++--------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py
index 81f2968a8e..18276d97c4 100644
--- a/vllm_ascend/torchair/torchair_attention.py
+++ b/vllm_ascend/torchair/torchair_attention.py
@@ -431,17 +431,24 @@ def forward(
             block_size = key_cache.shape[1]
             query = query.view(num_tokens, 1,
                                self.num_heads * self.head_size).contiguous()
-            output = torch_npu.npu_incre_flash_attention(
-                query,
-                key_cache,
-                value_cache,
-                num_key_value_heads=self.num_kv_heads,
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key_cache,
+                value=value_cache,
+                query_rope=None,
+                key_rope=None,
                 num_heads=self.num_heads,
-                actual_seq_lengths=seq_lens,
-                scale_value=self.scale,
-                block_table=block_table,
+                num_key_value_heads=self.num_kv_heads,
                 input_layout='BSH',
-                block_size=block_size)
+                atten_mask=attn_metadata.attn_mask,
+                sparse_mode=0,
+                scale=self.scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens,
+            )
         else:
             raise NotImplementedError(
                 "Torchair graph mode with non-MLA attention backend is still experimental.
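
For context, the sketch below shows a standalone invocation of the new kernel in the same paged-attention decode setup the hunk targets. It only mirrors the keyword arguments visible in the diff; the tensor shapes, dtypes, and values are illustrative placeholders (not taken from vllm-ascend), and it assumes an Ascend NPU environment whose torch_npu build exposes npu_fused_infer_attention_score.

# Hypothetical standalone sketch -- placeholder shapes/values; Ascend NPU + torch_npu assumed.
import torch
import torch_npu

num_seqs = 4        # decoding sequences (one new token each)
num_heads = 8       # query heads
num_kv_heads = 8    # KV heads (set lower than num_heads for GQA)
head_size = 128
block_size = 128    # tokens per KV-cache block
num_blocks = 16     # total blocks in the paged KV cache

# BSH layout: queries are (batch, seq_len=1, hidden); paged caches are (num_blocks, block_size, kv_hidden).
query = torch.randn(num_seqs, 1, num_heads * head_size, dtype=torch.float16).npu()
key_cache = torch.randn(num_blocks, block_size, num_kv_heads * head_size, dtype=torch.float16).npu()
value_cache = torch.randn_like(key_cache)

# Each row maps a sequence to its cache blocks; all sequences use block 0 here for simplicity.
block_table = torch.zeros(num_seqs, 1, dtype=torch.int32).npu()
seq_lens = [64] * num_seqs  # per-sequence KV lengths, each within the blocks mapped above

# Same call shape as the patched decode path; the fused kernel returns (attention_out, softmax_lse).
output, _ = torch_npu.npu_fused_infer_attention_score(
    query=query,
    key=key_cache,
    value=value_cache,
    num_heads=num_heads,
    num_key_value_heads=num_kv_heads,
    input_layout='BSH',
    scale=head_size ** -0.5,
    block_table=block_table,
    block_size=block_size,
    actual_seq_lengths_kv=seq_lens,
)
print(output.shape)  # expected: (num_seqs, 1, num_heads * head_size)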