@@ -893,6 +893,26 @@ def _gather_mm_embeddings(
             mm_embeds.append(mm_embeds_item)
         return mm_embeds
 
+    def _get_cumsum_and_arange(
+        self,
+        num_tokens: np.ndarray,
+        cumsum_dtype: Optional[np.dtype] = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Get the cumulative sum and batched arange of the given array.
+        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
+        # Equivalent to but faster than:
+        # np.concatenate([np.arange(n) for n in num_tokens])
+        """
+        # Step 1. [2, 5, 3] -> [2, 7, 10]
+        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
+        total_num_tokens = cu_num_tokens[-1]
+        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
+        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
+        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        arange = self.arange_np[:total_num_tokens] - cumsums_offsets
+
+        return cu_num_tokens, arange
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
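As a quick sanity check, here is a minimal standalone sketch of what the new helper computes, using plain NumPy in place of the preallocated self.arange_np cache (the names below are illustrative, not part of the patch):

import numpy as np

def get_cumsum_and_arange(num_tokens: np.ndarray):
    # [2, 5, 3] -> [2, 7, 10]
    cu_num_tokens = np.cumsum(num_tokens)
    # [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
    offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
    # [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    arange = np.arange(cu_num_tokens[-1]) - offsets
    return cu_num_tokens, arange

cu, ar = get_cumsum_and_arange(np.array([2, 5, 3], dtype=np.int32))
# cu -> [2, 7, 10]; ar -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]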
@@ -914,17 +934,16 @@ def _prepare_inputs(
         self.input_batch.block_table.commit_block_table(num_reqs)
 
         # Get the number of scheduled tokens for each request.
-        # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
-        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
-        max_num_scheduled_tokens = 0
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens[i] = num_tokens
-            num_valid_tokens[i] = num_tokens - \
-                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
-                                           num_tokens)
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+        max_num_scheduled_tokens = max(tokens)
+        num_valid_tokens = np.array([
+            num_tokens -
+            len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
+            for num_tokens, i in zip(tokens, req_ids)
+        ],
+                                    dtype=np.int32)
 
         if (self.use_aclgraph and total_num_scheduled_tokens
                 <= self.aclgraph_batch_sizes[-1]):
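The vectorized bookkeeping above replaces the per-request Python loop. A toy illustration with hypothetical scheduler data (plain dicts stand in for the scheduler_output fields):

import numpy as np

req_ids = ["req-0", "req-1", "req-2"]
num_scheduled = {"req-0": 2, "req-1": 5, "req-2": 3}  # num_scheduled_tokens
spec_decode = {"req-1": [11, 12]}  # scheduled_spec_decode_tokens

tokens = [num_scheduled[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)  # [2, 5, 3]
max_num_scheduled_tokens = max(tokens)  # 5
num_valid_tokens = np.array(
    [n - len(spec_decode.get(i, [])) for n, i in zip(tokens, req_ids)],
    dtype=np.int32)  # [2, 3, 3]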
@@ -965,13 +984,15 @@ def _prepare_inputs(
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        # Prepare positions
+        # Get request indices.
+        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
-        cu_num_tokens = np.cumsum(num_scheduled_tokens)
-        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
-                                    num_scheduled_tokens)
-        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
+
+        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
+        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        cu_num_tokens, arange = self._get_cumsum_and_arange(
+            num_scheduled_tokens)
 
         positions_np = self.positions_np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
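For the same toy batch of [2, 5, 3] scheduled tokens, the request-index expansion and the helper's outputs line up as follows (illustrative snippet, not part of the patch):

import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
req_indices = np.repeat(np.arange(3), num_scheduled_tokens)
# req_indices -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# With arange -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2], flat token t belongs to
# request req_indices[t] and is the arange[t]-th token scheduled for it.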
@@ -988,50 +1009,73 @@ def _prepare_inputs(
             self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
             non_blocking=True)
 
-        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:num_input_tokens].copy_(
-            self.positions_cpu[:num_input_tokens], non_blocking=True)
-        positions_cpu = self.positions_cpu[:num_input_tokens]
-        positions = self.positions[:num_input_tokens]
-        self.query_lens = torch.from_numpy(num_scheduled_tokens)
+        # Get token indices.
+        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+        # where M is the max_model_len.
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+
+        # Prepare input_ids.
+        # NOTE(woosuk): We use torch.index_select instead of np.take here
+        # because torch.index_select is much faster than np.take for large
+        # tensors.
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+
+        # Prepare some information for building the attention metadata.
+        # Compute and commit the slot mapping.
+        self.input_batch.block_table.compute_slot_mapping(
+            req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
+        self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
+            self.input_batch.block_table[0].
+            slot_mapping_cpu[:total_num_scheduled_tokens])
+
+        self.query_start_loc_np[0] = 0
+        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
 
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions_np // self.block_size)
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
 
-        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
-        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-        block_offsets = positions_np % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+        self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
+        # Copy the tensors to the NPU.
+        self.input_ids[:total_num_scheduled_tokens].copy_(
+            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+
+        # Make the attention metadata.
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                    position=positions_cpu,
                                                    attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
 
-        self.query_start_loc_np[0] = 0
-        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
-        self.query_start_loc[:num_reqs + 1].copy_(
-            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
-                                       non_blocking=True)
-
-        # Fill unused with -1. Needed for reshape_and_cache
-        self.seq_lens[num_reqs:].fill_(0)
-        self.query_start_loc[num_reqs + 1:].fill_(-1)
-
         self.with_prefill = with_prefill
         self.num_tokens_across_dp = num_tokens_across_dp
         self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
+
+        # Make AscendCommonAttentionMetadata.
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -1057,19 +1101,8 @@ def _prepare_inputs(
         if self.vllm_config.model_config.use_mla:
             attn_metadata.num_input_tokens = num_input_tokens
 
-        # Prepare input_ids
-        token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the NPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-
-        # _prepare_inputs may reorder the batch, so we must gather multi
-        # modal outputs after that to ensure the correct order
+        # _prepare_inputs may reorder the batch, so we must gather
+        # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
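To see why the token_indices arithmetic addresses a flattened [num_reqs, max_model_len] token table, here is a toy sketch with made-up sizes (M = 8 is hypothetical):

import numpy as np
import torch

max_model_len = 8  # M
token_ids_cpu = torch.arange(3 * max_model_len).reshape(3, max_model_len)
positions_np = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2])

# Row-major flattening puts token (req, pos) at flat index pos + req * M.
token_indices = positions_np + req_indices * max_model_len
gathered = torch.index_select(token_ids_cpu.flatten(), 0,
                              torch.from_numpy(token_indices))
# gathered equals token_ids_cpu[req_indices, positions_np]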