
Commit 9c3882b

ChenTaoyu-SJTU authored and offline0806 committed
Refactor prepare_inputs in model_runner_v1.py (vllm-project#2750)
### What this PR does / why we need it?
Refactor prepare_inputs in model_runner_v1.py to make it easier to read.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Passes CI.

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@e599e2c

---------

Signed-off-by: ChenTaoyu-SJTU <[email protected]>
Signed-off-by: offline0806 <[email protected]>
1 parent 88a17c7 commit 9c3882b

File tree

1 file changed (+89, -56 lines)


vllm_ascend/worker/model_runner_v1.py

Lines changed: 89 additions & 56 deletions
```diff
@@ -893,6 +893,26 @@ def _gather_mm_embeddings(
             mm_embeds.append(mm_embeds_item)
         return mm_embeds
 
+    def _get_cumsum_and_arange(
+        self,
+        num_tokens: np.ndarray,
+        cumsum_dtype: Optional[np.dtype] = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Get the cumulative sum and batched arange of the given array.
+        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
+        # Equivalent to but faster than:
+        # np.concatenate([np.arange(n) for n in num_tokens])
+        """
+        # Step 1. [2, 5, 3] -> [2, 7, 10]
+        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
+        total_num_tokens = cu_num_tokens[-1]
+        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
+        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
+        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        arange = self.arange_np[:total_num_tokens] - cumsums_offsets
+
+        return cu_num_tokens, arange
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
```
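For reference, a standalone sketch of the trick the new `_get_cumsum_and_arange` helper applies (the helper itself slices a preallocated `self.arange_np` instead of calling `np.arange` each time; the free function below is only for illustration):

```python
import numpy as np

def get_cumsum_and_arange(num_tokens: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Cumulative sums and a per-request restarting arange, fully vectorized."""
    # [2, 5, 3] -> [2, 7, 10]
    cu_num_tokens = np.cumsum(num_tokens)
    total_num_tokens = cu_num_tokens[-1]
    # Offset of each request's first token, repeated once per token:
    # [0, 2, 7] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
    cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
    # Subtracting the offsets makes the flat arange restart at 0 for each request.
    arange = np.arange(total_num_tokens) - cumsums_offsets
    return cu_num_tokens, arange

cu, arange = get_cumsum_and_arange(np.array([2, 5, 3]))
print(cu)      # [ 2  7 10]
print(arange)  # [0 1 0 1 2 3 4 0 1 2]
```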
```diff
@@ -914,17 +934,16 @@ def _prepare_inputs(
         self.input_batch.block_table.commit_block_table(num_reqs)
 
         # Get the number of scheduled tokens for each request.
-        # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
-        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
-        max_num_scheduled_tokens = 0
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens[i] = num_tokens
-            num_valid_tokens[i] = num_tokens - \
-                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
-                                           num_tokens)
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+        max_num_scheduled_tokens = max(tokens)
+        num_valid_tokens = np.array([
+            num_tokens -
+            len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
+            for num_tokens, i in zip(tokens, req_ids)
+        ],
+                                    dtype=np.int32)
 
         if (self.use_aclgraph and total_num_scheduled_tokens
                 <= self.aclgraph_batch_sizes[-1]):
```
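The old per-request Python loop is replaced by a list comprehension plus `np.array` construction. A toy example with hypothetical scheduler data (the request ids, token counts, and speculative tokens below are made up) showing the arrays this produces:

```python
import numpy as np

# Hypothetical scheduler output for three requests.
req_ids = ["req-a", "req-b", "req-c"]
num_scheduled = {"req-a": 2, "req-b": 5, "req-c": 3}
spec_decode_tokens = {"req-b": [101, 102]}  # two speculative tokens for req-b

tokens = [num_scheduled[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
num_valid_tokens = np.array(
    [n - len(spec_decode_tokens.get(i, [])) for n, i in zip(tokens, req_ids)],
    dtype=np.int32)

print(num_scheduled_tokens)      # [2 5 3]
print(max_num_scheduled_tokens)  # 5
print(num_valid_tokens)          # [2 3 3]
```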
```diff
@@ -965,13 +984,15 @@ def _prepare_inputs(
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        # Prepare positions
+        # Get request indices.
+        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
-        cu_num_tokens = np.cumsum(num_scheduled_tokens)
-        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
-                                    num_scheduled_tokens)
-        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
+
+        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
+        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        cu_num_tokens, arange = self._get_cumsum_and_arange(
+            num_scheduled_tokens)
 
         positions_np = self.positions_np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
```
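The request-index array pairs each scheduled token with the request it belongs to, and combined with the batched arange it gives every token's absolute position. A small sketch with made-up prefix lengths (the real code writes into preallocated buffers such as `self.positions_np` via `np.add(..., out=...)`):

```python
import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
num_computed_tokens = np.array([10, 0, 7], dtype=np.int32)  # hypothetical per-request prefixes
num_reqs = len(num_scheduled_tokens)

# One entry per scheduled token naming its request: [0 0 1 1 1 1 1 2 2 2]
req_indices = np.repeat(np.arange(num_reqs), num_scheduled_tokens)

# Batched arange as returned by _get_cumsum_and_arange: [0 1 0 1 2 3 4 0 1 2]
cu = np.cumsum(num_scheduled_tokens)
arange = np.arange(cu[-1]) - np.repeat(cu - num_scheduled_tokens, num_scheduled_tokens)

# Absolute position of every scheduled token within its request.
positions = num_computed_tokens[req_indices] + arange
print(positions)  # [10 11  0  1  2  3  4  7  8  9]
```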
```diff
@@ -988,50 +1009,73 @@ def _prepare_inputs(
             self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
             non_blocking=True)
 
-        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:num_input_tokens].copy_(
-            self.positions_cpu[:num_input_tokens], non_blocking=True)
-        positions_cpu = self.positions_cpu[:num_input_tokens]
-        positions = self.positions[:num_input_tokens]
-        self.query_lens = torch.from_numpy(num_scheduled_tokens)
+        # Get token indices.
+        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+        # where M is the max_model_len.
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+
+        # Prepare input_ids.
+        # NOTE(woosuk): We use torch.index_select instead of np.take here
+        # because torch.index_select is much faster than np.take for large
+        # tensors.
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+
+        # Prepare some information for building Attention-Metadata
+        # Compute and commit slot mapping
+        self.input_batch.block_table.compute_slot_mapping(
+            req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
+        self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
+            self.input_batch.block_table[0].
+            slot_mapping_cpu[:total_num_scheduled_tokens])
+
+        self.query_start_loc_np[0] = 0
+        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
 
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions_np // self.block_size)
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
 
-        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
-        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-        block_offsets = positions_np % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+        self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
+        # Copy the tensors to the NPU.
+        self.input_ids[:total_num_scheduled_tokens].copy_(
+            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+
+        # Make Attention metadata
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                    position=positions_cpu,
                                                    attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
 
-        self.query_start_loc_np[0] = 0
-        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
-        self.query_start_loc[:num_reqs + 1].copy_(
-            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
-                                       non_blocking=True)
-
-        # Fill unused with -1. Needed for reshape_and_cache
-        self.seq_lens[num_reqs:].fill_(0)
-        self.query_start_loc[num_reqs + 1:].fill_(-1)
-
         self.with_prefill = with_prefill
         self.num_tokens_across_dp = num_tokens_across_dp
         self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
+
+        # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
```
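The input-id gather now happens before the attention-metadata section: each token's flat index into the 2-D token buffer is its position plus its request row times the row width, and a single `torch.index_select` pulls all ids at once. A toy sketch of that indexing (buffer contents and sizes are made up; in the real code the row width is `self.input_batch.token_ids_cpu.shape[1]` and the result is written into `self.input_ids_cpu`):

```python
import numpy as np
import torch

max_model_len = 8  # toy row width of the (num_reqs, max_model_len) token buffer
token_ids_cpu = torch.arange(3 * max_model_len, dtype=torch.int32).reshape(3, max_model_len)

req_indices = np.array([0, 0, 1, 1, 1])   # request each scheduled token belongs to
positions_np = np.array([4, 5, 0, 1, 2])  # position of each token within its request

# Flat index into the 2-D buffer: row * row_width + column.
token_indices = positions_np + req_indices * max_model_len

# One index_select over the flattened buffer gathers every input id.
input_ids = torch.empty(len(token_indices), dtype=torch.int32)
torch.index_select(token_ids_cpu.flatten(), 0,
                   torch.from_numpy(token_indices), out=input_ids)
print(input_ids)  # tensor([ 4,  5,  8,  9, 10], dtype=torch.int32)
```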
```diff
@@ -1057,19 +1101,8 @@ def _prepare_inputs(
         if self.vllm_config.model_config.use_mla:
             attn_metadata.num_input_tokens = num_input_tokens
 
-        # Prepare input_ids
-        token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the NPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-
-        # _prepare_inputs may reorder the batch, so we must gather multi
-        # modal outputs after that to ensure the correct order
+        # _prepare_inputs may reorder the batch, so we must gather
+        # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
```
