@@ -975,7 +975,7 @@ def concatenate(
             valid_indices=None,
         )
 
-    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
+    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
         block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
         block_tables = []
         for i, bt in enumerate(self.block_tables):
@@ -998,7 +998,7 @@ def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
             bucketing_ctx,
         )
         self.input_ids = F.pad(
-            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
+            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
         )
 
         if self.position_ids.dim() == 2:
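
Note: a minimal sketch of what the padding change above does, with made-up tensor contents and a made-up `pad_token_id`. `F.pad` grows the 1-D batch of next-token ids to the bucketed batch size, and the tail now holds an explicit pad token instead of token id 0, which is a real vocabulary entry in most tokenizers.

```python
import torch
import torch.nn.functional as F

# Illustrative values: a decode batch of 3 sequences bucketed up to size 4.
input_ids = torch.tensor([17, 29, 803], dtype=torch.int64)
padded_bs, pad_token_id = 4, 128001

# Right-pad the batch dimension, filling with the pad token rather than 0.
padded = F.pad(input_ids, (0, padded_bs - input_ids.shape[0]), value=pad_token_id)
print(padded)  # tensor([    17,     29,    803, 128001])
```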
@@ -1040,7 +1040,7 @@ def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
         )
 
     def prepare_for_prefill(
-        self, max_padded_input_len, max_padded_bs, max_total_tokens
+        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
     ):
         # Prepare values if we need to continue prefilling
         # Speculation must be ignored while we prefill even with chunking
@@ -1064,18 +1064,23 @@ def prepare_for_prefill(
             for input_id in self.input_ids:
                 padded = self.max_input_length - len(input_id) + extra_pad
                 if padded > 0:
-                    input_id = [0] * padded + input_id
+                    input_id = [pad_token_id] * padded + input_id
                 input_ids.append(input_id)
                 input_ids_padded_length.append(padded)
             input_ids = np.concatenate(input_ids, dtype=np.int64)
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         elif isinstance(self.input_ids, list):
             input_ids = self.input_ids[0]
             input_ids_padded_length.append(extra_pad)
-            input_ids = [0] * extra_pad + input_ids
+            input_ids = [pad_token_id] * extra_pad + input_ids
             self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         else:
-            input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
+            input_ids = torch.full(
+                (max_padded_input_len * len(self),),
+                pad_token_id,
+                dtype=torch.int64,
+                device=self.input_ids.device,
+            )
             src_pos = 0
             for i in range(len(self)):
                 end_pos = (i + 1) * max_padded_input_len
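
Note: the switch from `new_zeros` to `torch.full` matters because every slot the copy loop in the hunk above does not overwrite keeps its initial value, and with zeros those slots held token id 0. A small sketch of the same pattern, assuming made-up lengths and a made-up pad id:

```python
import torch

# Illustrative values only.
max_padded_input_len, pad_token_id = 6, 128001
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Pre-fill the flat buffer with the pad token, then right-align each sequence
# into its slot; untouched positions stay genuine padding.
flat = torch.full((max_padded_input_len * len(seqs),), pad_token_id, dtype=torch.int64)
for i, seq in enumerate(seqs):
    end_pos = (i + 1) * max_padded_input_len
    flat[end_pos - len(seq) : end_pos] = seq
# flat -> [pad, pad, pad, 5, 6, 7, pad, pad, pad, pad, 8, 9]
```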
@@ -1090,7 +1095,7 @@ def prepare_for_prefill(
             self.input_ids = input_ids
 
         self.input_ids = F.pad(
-            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
+            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
         )
 
         self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
@@ -1312,8 +1317,9 @@ def prepare_for_prefill(
         self.prefill_next_token_indices = (
             self.prefill_next_token_indices + input_ids_padded_length_tensor
         )
-        all_input_ids_tensor = torch.zeros(
+        all_input_ids_tensor = torch.full(
             (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+            pad_token_id,
             dtype=torch.int64,
             device="hpu",
         )
@@ -1502,6 +1508,19 @@ def __init__(
         )
         self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
+        if tokenizer.pad_token_id is None:
+            if config.pad_token_id is not None:
+                tokenizer.pad_token_id = config.pad_token_id
+            elif config.eos_token_id is not None:
+                tokenizer.pad_token_id = (
+                    config.eos_token_id[0]
+                    if isinstance(config.eos_token_id, list)
+                    else config.eos_token_id
+                )
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.pad_token_id = 0
         super().__init__(
             model_id=model_id,
             model=model,
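
Note: the fallback chain added to `__init__` guarantees the padding code above always has a concrete id to work with. As a standalone sketch of the same resolution order (the helper name is hypothetical; `tokenizer` and `config` stand in for the HF-style objects in the diff):

```python
def resolve_pad_token_id(tokenizer, config):
    """Hypothetical helper mirroring the fallback order in the hunk above."""
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    if config.pad_token_id is not None:
        return config.pad_token_id
    if config.eos_token_id is not None:
        eos = config.eos_token_id
        # Some configs store a list of EOS ids; the first entry is used.
        return eos[0] if isinstance(eos, list) else eos
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id
    return 0  # last resort when neither object defines a usable special token
```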
@@ -2274,14 +2293,21 @@ def generate_token(
                     ),
                     self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
                     self.max_total_tokens,
+                    self.tokenizer.pad_token_id,
                 )
             else:
                 batch.prepare_for_prefill(
-                    batch.max_input_length, len(batch), self.max_total_tokens
+                    batch.max_input_length,
+                    len(batch),
+                    self.max_total_tokens,
+                    self.tokenizer.pad_token_id,
                 )
         else:
             batch.prepare_for_decode(
-                self.dtype, self.use_contiguous_pa, self.bucketing_ctx
+                self.dtype,
+                self.use_contiguous_pa,
+                self.bucketing_ctx,
+                self.tokenizer.pad_token_id,
             )
         if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
             self.set_inputs_embeds(batch)