Commit 8b8bd8d

Author: p00465316 (committed)
replace npu_incre_flash_attention with npu_fused_infer_attention_score
Signed-off-by: p00465316 <[email protected]>
Parent: a58b43b

File tree: 1 file changed (+17 −9 lines)

vllm_ascend/torchair/torchair_attention.py

Lines changed: 17 additions & 9 deletions
@@ -431,17 +431,25 @@ def forward(
             block_size = key_cache.shape[1]
             query = query.view(num_tokens, 1,
                                self.num_heads * self.head_size).contiguous()
-            output = torch_npu.npu_incre_flash_attention(
-                query,
-                key_cache,
-                value_cache,
-                num_key_value_heads=self.num_kv_heads,
+
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key_cache,
+                value=value_cache,
+                query_rope=None,
+                key_rope=None,
                 num_heads=self.num_heads,
-                actual_seq_lengths=seq_lens,
-                scale_value=self.scale,
-                block_table=block_table,
+                num_key_value_heads=self.num_kv_heads,
                 input_layout='BSH',
-                block_size=block_size)
+                atten_mask=attn_metadata.attn_mask,
+                sparse_mode=0,
+                scale=self.scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens,
+            )
         else:
             raise NotImplementedError(
                 "Torchair graph mode with non-MLA attention backend is still experimental."

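The change keeps the same paged-KV decode path but maps the old npu_incre_flash_attention arguments onto npu_fused_infer_attention_score, which returns a tuple (only the first element, the attention output, is used) and renames actual_seq_lengths / scale_value to actual_seq_lengths_kv / scale. The sketch below isolates that mapping in a standalone helper. The helper name and parameter list are hypothetical and not part of vllm_ascend; the keyword arguments are taken verbatim from the diff, and the operator's exact dtype/shape requirements depend on the installed torch_npu/CANN version, so treat this as an illustrative sketch rather than a drop-in implementation.

# Illustrative sketch only: a hypothetical helper wrapping the new call site.
# Assumes an Ascend NPU environment with torch_npu installed; the tensors
# (query, key_cache, value_cache, block_table, attn_mask) and the sequence
# lengths seq_lens are prepared by the attention backend, as in the diff.
import torch_npu


def fused_infer_attention(query, key_cache, value_cache, block_table, seq_lens,
                          block_size, attn_mask, num_heads, num_kv_heads, scale):
    # Old path (removed by this commit): npu_incre_flash_attention returned a
    # single tensor and used the actual_seq_lengths / scale_value keywords:
    #
    #   output = torch_npu.npu_incre_flash_attention(
    #       query, key_cache, value_cache,
    #       num_key_value_heads=num_kv_heads,
    #       num_heads=num_heads,
    #       actual_seq_lengths=seq_lens,
    #       scale_value=scale,
    #       block_table=block_table,
    #       input_layout='BSH',
    #       block_size=block_size)
    #
    # New path: npu_fused_infer_attention_score returns a tuple; only the first
    # element (the attention output) is used, mirroring the diff above.
    output, _ = torch_npu.npu_fused_infer_attention_score(
        query=query,
        key=key_cache,
        value=value_cache,
        query_rope=None,           # no separate rope tensors passed in this path
        key_rope=None,
        num_heads=num_heads,
        num_key_value_heads=num_kv_heads,
        input_layout='BSH',        # query shaped [num_tokens, 1, num_heads * head_size]
        atten_mask=attn_mask,      # the diff passes attn_metadata.attn_mask here
        sparse_mode=0,
        scale=scale,
        antiquant_mode=0,          # no KV-cache anti-quantization in this path
        antiquant_scale=None,
        block_table=block_table,   # paged KV cache: block indices per sequence
        block_size=block_size,     # taken from key_cache.shape[1] in the diff
        actual_seq_lengths_kv=seq_lens,
    )
    return output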