83 changes: 68 additions & 15 deletions vllm_kunlun/v1/attention/backends/kunlun_attn.py
@@ -455,6 +455,73 @@ def build_for_cudagraph_capture(
return attn_metadata

def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
"""build"""
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping


max_seq_len = common_attn_metadata.max_seq_len
query_start_loc_host = common_attn_metadata.query_start_loc_cpu
query_start_loc = common_attn_metadata.query_start_loc

seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu

# kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
# kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
# kv_lod_xpu = kv_lod_cpu.to(self.device)
kv_lod_cpu = None
kv_lod_xpu = None

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
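        # Decode requests are assumed to be ordered ahead of prefill requests
        # in the batch, so the four counts above describe two contiguous groups.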

num_scheduled_tokens = common_attn_metadata.query_start_loc[1:] - common_attn_metadata.query_start_loc[:-1]
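        # num_scheduled_tokens[i] is the number of query tokens scheduled for
        # request i in this step, e.g. query_start_loc = [0, 1, 2, 7] yields
        # per-request counts [1, 1, 5] (two decodes followed by one prefill).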

if num_decode_tokens == 0:
max_decode_seq_len = 0
else:
tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
max_decode_seq_len = torch.max(tmp_decode_scheduled_tokens).item()

if num_prefill_tokens == 0:
max_prefill_seq_len = 0
else:
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]
max_prefill_seq_len = torch.max(tmp_prefill_scheduled_tokens).item()
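        # Both maxima above are taken over per-request scheduled (query) token
        # counts, not over the full KV sequence lengths held in seq_lens.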

use_cascade = False
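        # Cascade attention is disabled in this build path, so common_prefix_len
        # does not influence the metadata assembled below.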

attn_metadata = KunlunMetadata(
num_actual_tokens=num_actual_tokens,
num_prefills=num_prefills,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens_tensor=seq_lens,
seq_lens_tensor_cpu=seq_lens_cpu,
kv_lod_xpu=kv_lod_xpu,
kv_lod_cpu=kv_lod_cpu,
max_query_len=max_prefill_seq_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
query_start_loc_host=query_start_loc_host,
context_lens_tensor=None,
block_tables=block_table_tensor,
use_cuda_graph=False,
use_cascade=use_cascade,
)
return attn_metadata

    def build_bak(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
"""build"""
num_reqs = common_attn_metadata.num_reqs
@@ -645,27 +712,13 @@ def forward(
        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
value = value.contiguous()
# if key_cache.is_contiguous():
kunlun_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
updated_slot_mapping,
BLHD_LAYOUT=False)
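        # The call above scatters the new key/value tensors into the paged KV
        # cache at the slot indices given by updated_slot_mapping.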
# else:
# cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
# cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
# # print("key", key.shape)
# # print("value", value.shape)
# # print("cast_key_cache", cast_key_cache.shape)
# # print("cast_value_cache", key_cache.shape)
# kunlun_ops.reshape_and_cache_flash(
# key,
# value,
# cast_key_cache,
# cast_value_cache,
# updated_slot_mapping)

assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.
@@ -801,4 +854,4 @@ def use_cascade_attention(
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

# Use cascade attention if it is faster than FlashDecoding.
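    # Both quantities are rough proxy costs derived from estimated CTA counts
    # and SM occupancy, not measured latencies, so this comparison is a heuristic.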
    return cascade_time < flash_decoding_time