File tree: 1 file changed, +16 -9 lines changed
Original file line number | Diff line number | Diff line change
@@ -431,17 +431,24 @@ def forward(
431
431
block_size = key_cache .shape [1 ]
432
432
query = query .view (num_tokens , 1 ,
433
433
self .num_heads * self .head_size ).contiguous ()
434
- output = torch_npu .npu_incre_flash_attention (
435
- query ,
436
- key_cache ,
437
- value_cache ,
438
- num_key_value_heads = self .num_kv_heads ,
434
+ output , _ = torch_npu .npu_fused_infer_attention_score (
435
+ query = query ,
436
+ key = key_cache ,
437
+ value = value_cache ,
438
+ query_rope = None ,
439
+ key_rope = None ,
439
440
num_heads = self .num_heads ,
440
- actual_seq_lengths = seq_lens ,
441
- scale_value = self .scale ,
442
- block_table = block_table ,
441
+ num_key_value_heads = self .num_kv_heads ,
443
442
input_layout = 'BSH' ,
444
- block_size = block_size )
443
+ atten_mask = attn_metadata .attn_mask ,
444
+ sparse_mode = 0 ,
445
+ scale = self .scale ,
446
+ antiquant_mode = 0 ,
447
+ antiquant_scale = None ,
448
+ block_table = block_table ,
449
+ block_size = block_size ,
450
+ actual_seq_lengths_kv = seq_lens ,
451
+ )
445
452
else :
446
453
raise NotImplementedError (
447
454
"Torchair graph mode with non-MLA attention backend is still experimental."
You can’t perform that action at this time.
0 commit comments