1 file changed: +17 -9 lines changed
Original file line number | Diff line number | Diff line change
@@ -431,17 +431,25 @@ def forward(
431
431
block_size = key_cache .shape [1 ]
432
432
query = query .view (num_tokens , 1 ,
433
433
self .num_heads * self .head_size ).contiguous ()
434
- output = torch_npu .npu_incre_flash_attention (
435
- query ,
436
- key_cache ,
437
- value_cache ,
438
- num_key_value_heads = self .num_kv_heads ,
434
+
435
+ output , _ = torch_npu .npu_fused_infer_attention_score (
436
+ query = query ,
437
+ key = key_cache ,
438
+ value = value_cache ,
439
+ query_rope = None ,
440
+ key_rope = None ,
439
441
num_heads = self .num_heads ,
440
- actual_seq_lengths = seq_lens ,
441
- scale_value = self .scale ,
442
- block_table = block_table ,
442
+ num_key_value_heads = self .num_kv_heads ,
443
443
input_layout = 'BSH' ,
444
- block_size = block_size )
444
+ atten_mask = attn_metadata .attn_mask ,
445
+ sparse_mode = 0 ,
446
+ scale = self .scale ,
447
+ antiquant_mode = 0 ,
448
+ antiquant_scale = None ,
449
+ block_table = block_table ,
450
+ block_size = block_size ,
451
+ actual_seq_lengths_kv = seq_lens ,
452
+ )
445
453
else :
446
454
raise NotImplementedError (
447
455
"Torchair graph mode with non-MLA attention backend is still experimental."
You can’t perform that action at this time.
0 commit comments