25 changes: 16 additions & 9 deletions vllm_ascend/torchair/torchair_attention.py
@@ -431,17 +431,24 @@ def forward(
             block_size = key_cache.shape[1]
             query = query.view(num_tokens, 1,
                                self.num_heads * self.head_size).contiguous()
-            output = torch_npu.npu_incre_flash_attention(
-                query,
-                key_cache,
-                value_cache,
-                num_key_value_heads=self.num_kv_heads,
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key_cache,
+                value=value_cache,
+                query_rope=None,
+                key_rope=None,
                 num_heads=self.num_heads,
-                actual_seq_lengths=seq_lens,
-                scale_value=self.scale,
-                block_table=block_table,
+                num_key_value_heads=self.num_kv_heads,
                 input_layout='BSH',
-                block_size=block_size)
+                atten_mask=attn_metadata.attn_mask,
Contributor (severity: high):
For consistency and correctness, atten_mask should be sourced from decode_meta like other parameters in this function call (e.g., block_table and seq_lens). In the DecodeOnly attention state, an attention mask is typically not required, and the previous API (npu_incre_flash_attention) did not accept one. Using decode_meta.attn_mask ensures that None is passed, preventing an unintended mask from being applied, which could happen if attn_metadata.attn_mask is not None.

Suggested change:
-                atten_mask=attn_metadata.attn_mask,
+                atten_mask=decode_meta.attn_mask,
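A minimal sketch of the point above, assuming `decode_meta` mirrors the decode-only metadata already used for `block_table` and `seq_lens`; the `SimpleNamespace` objects and mask values are placeholders, not the module's real metadata classes:

```python
# Illustrative only: placeholders stand in for the real attention metadata.
# decode_meta.attn_mask is assumed to be None in the DecodeOnly state,
# as described in the comment above.
from types import SimpleNamespace

attn_metadata = SimpleNamespace(attn_mask="prefill-mask")  # may be non-None
decode_meta = SimpleNamespace(attn_mask=None)              # no mask for pure decode

current_arg = attn_metadata.attn_mask    # a stray mask would be applied to decode tokens
suggested_arg = decode_meta.attn_mask    # None, matching the old npu_incre_flash_attention behavior

print(current_arg, suggested_arg)  # -> prefill-mask None
```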

+                sparse_mode=0,
+                scale=self.scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens,
+            )
         else:
             raise NotImplementedError(
                 "Torchair graph mode with non-MLA attention backend is still experimental."