@@ -880,6 +880,26 @@ def _gather_mm_embeddings(
                 mm_embeds.append(mm_embeds_item)
         return mm_embeds
 
+    def _get_cumsum_and_arange(
+        self,
+        num_tokens: np.ndarray,
+        cumsum_dtype: Optional[np.dtype] = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Get the cumulative sum and batched arange of the given array.
+        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
+        # Equivalent to but faster than:
+        # np.concatenate([np.arange(n) for n in num_tokens])
+        """
+        # Step 1. [2, 5, 3] -> [2, 7, 10]
+        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
+        total_num_tokens = cu_num_tokens[-1]
+        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
+        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
+        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        arange = self.arange_np[:total_num_tokens] - cumsums_offsets
+
+        return cu_num_tokens, arange
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
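
As a worked example of what the new helper returns, here is a minimal standalone sketch of the same logic in plain NumPy (illustrative only, not part of the patch; the runner version indexes a preallocated self.arange_np buffer instead of calling np.arange):

import numpy as np

def get_cumsum_and_arange(num_tokens: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # Cumulative sum of scheduled tokens: [2, 5, 3] -> [2, 7, 10]
    cu_num_tokens = np.cumsum(num_tokens)
    total_num_tokens = cu_num_tokens[-1]
    # Starting offset of each request, repeated per token:
    # [0, 2, 7] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
    cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
    # Batched arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    return cu_num_tokens, np.arange(total_num_tokens) - cumsums_offsets

cu, ar = get_cumsum_and_arange(np.array([2, 5, 3], dtype=np.int32))
# cu -> [2 7 10], ar -> [0 1 0 1 2 3 4 0 1 2]
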
@@ -901,17 +921,16 @@ def _prepare_inputs(
         self.input_batch.block_table.commit_block_table(num_reqs)
 
         # Get the number of scheduled tokens for each request.
-        # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
-        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
-        max_num_scheduled_tokens = 0
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens[i] = num_tokens
-            num_valid_tokens[i] = num_tokens - \
-                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
-                                           num_tokens)
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+        max_num_scheduled_tokens = max(tokens)
+        num_valid_tokens = np.array([
+            num_tokens -
+            len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
+            for num_tokens, i in zip(tokens, req_ids)
+        ],
+                                    dtype=np.int32)
 
         if (self.use_aclgraph and total_num_scheduled_tokens
                 <= self.aclgraph_batch_sizes[-1]):
@@ -952,13 +971,15 @@ def _prepare_inputs(
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        # Prepare positions
+        # Get request indices.
+        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
-        cu_num_tokens = np.cumsum(num_scheduled_tokens)
-        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
-                                    num_scheduled_tokens)
-        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
+
+        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
+        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        cu_num_tokens, arange = self._get_cumsum_and_arange(
+            num_scheduled_tokens)
 
         positions_np = self.positions_np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
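
The request indices and the batched arange combine with the per-request computed-token counts to give per-token positions. A toy sketch with assumed values (the real code writes into the preallocated self.positions_np buffer via np.add(..., out=...)):

import numpy as np

# Assumed toy state: three requests with 10, 0 and 4 tokens already computed,
# scheduled for 2, 5 and 3 new tokens respectively.
num_computed_tokens = np.array([10, 0, 4], dtype=np.int32)
num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)

req_indices = np.repeat(np.arange(len(num_scheduled_tokens)),
                        num_scheduled_tokens)
# req_indices -> [0 0 1 1 1 1 1 2 2 2]

cu_num_tokens = np.cumsum(num_scheduled_tokens)
arange = np.arange(cu_num_tokens[-1]) - np.repeat(
    cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens)
# arange -> [0 1 0 1 2 3 4 0 1 2]

positions = num_computed_tokens[req_indices] + arange
# positions -> [10 11 0 1 2 3 4 4 5 6]
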
@@ -975,50 +996,73 @@ def _prepare_inputs(
                 self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                 non_blocking=True)
 
-        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:num_input_tokens].copy_(
-            self.positions_cpu[:num_input_tokens], non_blocking=True)
-        positions_cpu = self.positions_cpu[:num_input_tokens]
-        positions = self.positions[:num_input_tokens]
-        self.query_lens = torch.from_numpy(num_scheduled_tokens)
+        # Get token indices.
+        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+        # where M is the max_model_len.
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+
+        # Prepare input_ids.
+        # NOTE(woosuk): We use torch.index_select instead of np.take here
+        # because torch.index_select is much faster than np.take for large
+        # tensors.
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+
+        # Prepare some information for building Attention-Metadata
+        # Compute and commit slot mapping
+        self.input_batch.block_table.compute_slot_mapping(
+            req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
+        self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
+            self.input_batch.block_table[0].
+            slot_mapping_cpu[:total_num_scheduled_tokens])
+
+        self.query_start_loc_np[0] = 0
+        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
 
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions_np // self.block_size)
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
 
-        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
-        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-        block_offsets = positions_np % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+        self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
+        # Copy the tensors to the NPU.
+        self.input_ids[:total_num_scheduled_tokens].copy_(
+            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+
+        # Make Attention metadata
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                    position=positions_cpu,
                                                    attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
 
-        self.query_start_loc_np[0] = 0
-        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
-        self.query_start_loc[:num_reqs + 1].copy_(
-            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
-                                       non_blocking=True)
-
-        # Fill unused with -1. Needed for reshape_and_cache
-        self.seq_lens[num_reqs:].fill_(0)
-        self.query_start_loc[num_reqs + 1:].fill_(-1)
-
         self.with_prefill = with_prefill
         self.num_tokens_across_dp = num_tokens_across_dp
         self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
+
+        # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
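
The token_indices computation in the hunk above relies on token_ids_cpu being a 2-D (max_num_reqs, max_model_len) buffer, so row-major flattening puts element (req, pos) at req * max_model_len + pos. A small self-contained sketch with assumed shapes (M = 8 stands in for max_model_len; not part of the patch):

import numpy as np
import torch

M = 8  # assumed max_model_len for this sketch
# Toy token buffer: request r holds token ids r*M .. r*M+M-1.
token_ids_cpu = torch.arange(3 * M, dtype=torch.int32).reshape(3, M)

req_indices = np.array([0, 0, 1, 1, 1], dtype=np.int64)
positions_np = np.array([0, 1, 0, 1, 2], dtype=np.int64)

# Element (req, pos) of the 2-D buffer lives at req * M + pos after flattening.
token_indices = positions_np + req_indices * M
input_ids = torch.index_select(token_ids_cpu.flatten(), 0,
                               torch.from_numpy(token_indices))
# input_ids -> tensor([0, 1, 8, 9, 10], dtype=torch.int32)
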
@@ -1044,19 +1088,8 @@ def _prepare_inputs(
         if self.vllm_config.model_config.use_mla:
             attn_metadata.num_input_tokens = num_input_tokens
 
-        # Prepare input_ids
-        token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the NPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-
-        # _prepare_inputs may reorder the batch, so we must gather multi
-        # modal outputs after that to ensure the correct order
+        # _prepare_inputs may reorder the batch, so we must gather
+        # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)