From fe2087e2067301b24371333f396f7e5d410c6f94 Mon Sep 17 00:00:00 2001
From: ldh2020 <62470572+ldh2020@users.noreply.github.com>
Date: Fri, 6 Feb 2026 13:24:44 +0800
Subject: [PATCH] [Attention] optimize the build of attn_metadata

Optimize the build of attn_metadata: rework build() to construct
KunlunMetadata directly from CommonAttentionMetadata, splitting decode and
prefill requests with split_decodes_and_prefills() and deriving the max
decode/prefill lengths from the per-request scheduled token counts. The
kv_lod tensors are passed as None, the previous implementation is kept as
build_bak(), and dead commented-out caching code is removed from forward().
---
 .../v1/attention/backends/kunlun_attn.py | 83 +++++++++++++++----
 1 file changed, 68 insertions(+), 15 deletions(-)

diff --git a/vllm_kunlun/v1/attention/backends/kunlun_attn.py b/vllm_kunlun/v1/attention/backends/kunlun_attn.py
index 2c97aa1..3f5e496 100644
--- a/vllm_kunlun/v1/attention/backends/kunlun_attn.py
+++ b/vllm_kunlun/v1/attention/backends/kunlun_attn.py
@@ -455,6 +455,73 @@ def build_for_cudagraph_capture(
         return attn_metadata
 
     def build(self, common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
+        """build"""
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+        common_prefix_len = common_prefix_len
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping
+
+
+        max_seq_len = common_attn_metadata.max_seq_len
+        query_start_loc_host = common_attn_metadata.query_start_loc_cpu
+        query_start_loc = common_attn_metadata.query_start_loc
+
+        seq_lens = common_attn_metadata.seq_lens
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+
+        # kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
+        # kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
+        # kv_lod_xpu = kv_lod_cpu.to(self.device)
+        kv_lod_cpu = None
+        kv_lod_xpu = None
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
+            split_decodes_and_prefills(common_attn_metadata)
+
+        num_scheduled_tokens = common_attn_metadata.query_start_loc[1:] - common_attn_metadata.query_start_loc[:-1]
+
+        if num_decode_tokens == 0:
+            max_decode_seq_len = 0
+        else:
+            tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
+            max_decode_seq_len = torch.max(tmp_decode_scheduled_tokens).item()
+
+        if num_prefill_tokens == 0:
+            max_prefill_seq_len = 0
+        else:
+            tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]
+            max_prefill_seq_len = torch.max(tmp_prefill_scheduled_tokens).item()
+
+        use_cascade = False
+
+        attn_metadata = KunlunMetadata(
+            num_actual_tokens=num_actual_tokens,
+            num_prefills=num_prefills,
+            slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            seq_lens_tensor=seq_lens,
+            seq_lens_tensor_cpu=seq_lens_cpu,
+            kv_lod_xpu=kv_lod_xpu,
+            kv_lod_cpu=kv_lod_cpu,
+            max_query_len=max_prefill_seq_len,
+            max_prefill_seq_len=max_prefill_seq_len,
+            max_decode_seq_len=max_decode_seq_len,
+            query_start_loc=query_start_loc,
+            query_start_loc_host=query_start_loc_host,
+            context_lens_tensor=None,
+            block_tables=block_table_tensor,
+            use_cuda_graph=False,
+            use_cascade=use_cascade,
+        )
+        return attn_metadata
+
+    def build_bak(self, common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
         """build"""
         num_reqs = common_attn_metadata.num_reqs
@@ -645,7 +712,6 @@ def forward(
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory
             value = value.contiguous()
-            # if key_cache.is_contiguous():
             kunlun_ops.reshape_and_cache_flash(
                 key,
                 value,
@@ -653,19 +719,6 @@ def forward(
                 value_cache,
                 updated_slot_mapping,
                 BLHD_LAYOUT=False)
-            # else:
-            #     cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
-            #     cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
-            #     # print("key", key.shape)
-            #     # print("value", value.shape)
-            #     # print("cast_key_cache", cast_key_cache.shape)
-            #     # print("cast_value_cache", key_cache.shape)
-            #     kunlun_ops.reshape_and_cache_flash(
-            #         key,
-            #         value,
-            #         cast_key_cache,
-            #         cast_value_cache,
-            #         updated_slot_mapping)
 
         assert attn_type == AttentionType.DECODER
         # Decoder self-attention supports chunked prefill.
@@ -801,4 +854,4 @@ def use_cascade_attention(
     flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
 
     # Use cascade attention if it is faster than FlashDecoding.
-    return cascade_time < flash_decoding_time
\ No newline at end of file
+    return cascade_time < flash_decoding_time
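A minimal sketch of the indexing the new build() relies on: adjacent differences of query_start_loc give the number of scheduled tokens per request, and the decode/prefill split assumes decode requests are ordered before prefills. The helper below is hypothetical (it is not the vllm_kunlun or vLLM API); it only mirrors the patch's max-length computation.

```python
import torch


def split_and_max_lens(query_start_loc: torch.Tensor, num_decodes: int):
    """Hypothetical helper mirroring the patch: per-request query lengths
    from cumulative offsets, then max over the decode and prefill parts."""
    # query_start_loc has num_reqs + 1 entries; adjacent differences give
    # the number of scheduled tokens for each request.
    num_scheduled_tokens = query_start_loc[1:] - query_start_loc[:-1]
    decode_lens = num_scheduled_tokens[:num_decodes]
    prefill_lens = num_scheduled_tokens[num_decodes:]
    max_decode_len = int(decode_lens.max()) if decode_lens.numel() else 0
    max_prefill_len = int(prefill_lens.max()) if prefill_lens.numel() else 0
    return max_decode_len, max_prefill_len


# Example: two decode requests (1 token each) followed by one prefill of 5 tokens.
qsl = torch.tensor([0, 1, 2, 7], dtype=torch.int32)
print(split_and_max_lens(qsl, num_decodes=2))  # (1, 5)
```

Note that the patch feeds max_prefill_seq_len into KunlunMetadata.max_query_len as well, so a decode-only batch ends up with max_query_len == 0.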