Commit 90474db

Author: p00465316 (committed)
replace npu_incre_flash_attention with npu_fused_infer_attention_score
Signed-off-by: p00465316 <[email protected]>
1 parent a58b43b

1 file changed, 16 additions, 9 deletions

vllm_ascend/torchair/torchair_attention.py
@@ -431,17 +431,24 @@ def forward(
                 block_size = key_cache.shape[1]
                 query = query.view(num_tokens, 1,
                                    self.num_heads * self.head_size).contiguous()
-                output = torch_npu.npu_incre_flash_attention(
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_key_value_heads=self.num_kv_heads,
+                output, _ = torch_npu.npu_fused_infer_attention_score(
+                    query=query,
+                    key=key_cache,
+                    value=value_cache,
+                    query_rope=None,
+                    key_rope=None,
                     num_heads=self.num_heads,
-                    actual_seq_lengths=seq_lens,
-                    scale_value=self.scale,
-                    block_table=block_table,
+                    num_key_value_heads=self.num_kv_heads,
                     input_layout='BSH',
-                    block_size=block_size)
+                    atten_mask=attn_metadata.attn_mask,
+                    sparse_mode=0,
+                    scale=self.scale,
+                    antiquant_mode=0,
+                    antiquant_scale=None,
+                    block_table=block_table,
+                    block_size=block_size,
+                    actual_seq_lengths_kv=seq_lens,
+                )
             else:
                 raise NotImplementedError(
                     "Torchair graph mode with non-MLA attention backend is still experimental."
