@@ -37,6 +37,7 @@ class CausalLMBatch(Batch):
     # Metadata used for padding
     size: int
     max_sequence_length: int
+    padding_right_offset: int
 
     # Past metadata
     keys_head_dim_last: bool = True
@@ -61,47 +62,67 @@ def from_pb(
         input_lengths = []
 
         # Parse batch
+        max_sequence_length = 0
+        padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criterias.append(
-                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            max_sequence_length = max(max_sequence_length, r.input_length)
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
             )
 
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
+
+        input_ids = tokenized_inputs["input_ids"]
+        # Allocate maximum attention_mask
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
 
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
-            input_ids=tokenized_inputs["input_ids"],
-            attention_mask=tokenized_inputs["attention_mask"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max(input_lengths),
+            max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
         )
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_sequence_length = max(batch.max_sequence_length for batch in batches)
+        total_batch_size = 0
+        max_sequence_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += batch.size
+            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
 
         # Batch attributes
         requests = []
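
The `from_pb` hunk above allocates the attention mask at its final width up front (prompt columns plus one column per potential new token) instead of growing it later. A minimal standalone sketch of the same idea, with made-up prompt lengths and `max_new_tokens` budgets:

```python
import torch

# Illustrative values only: two requests with different prompt lengths
# and different max_new_tokens budgets.
input_lengths = [3, 5]
max_new_tokens = [4, 2]

max_sequence_length = max(input_lengths)      # 5
padding_right_offset = max(max_new_tokens)    # 4

# Left-padded mask, as a tokenizer called with padding=True would produce it.
tokenizer_mask = torch.zeros(len(input_lengths), max_sequence_length, dtype=torch.long)
for i, length in enumerate(input_lengths):
    tokenizer_mask[i, max_sequence_length - length:] = 1

# Allocate the full-width mask once and copy the tokenizer mask into the
# left block; the reserved right block stays zero until tokens are generated.
attention_mask = torch.zeros(
    len(input_lengths), max_sequence_length + padding_right_offset, dtype=torch.long
)
attention_mask[:, :max_sequence_length] = tokenizer_mask
# attention_mask is now [2, 9]: 5 prompt columns + 4 reserved columns.
```
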
@@ -144,13 +165,22 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_sequence_length),
+                    (total_batch_size, max_sequence_length + padding_right_offset),
                 )
 
             # We need to slice the attention mask to remove padding from previous steps
+            # and to remove unused allocated space
+            left_offset = max_sequence_length - batch.max_sequence_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+            )
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length:
-            ] = batch.attention_mask[:, -batch.max_sequence_length:]
+                start_index:end_index,
+                left_offset:-padding_right_offset,
+            ] = batch.attention_mask[
+                :,
+                batch_left_offset : -batch.padding_right_offset,
+            ]
 
             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
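
The two offsets introduced above locate each source batch's live region inside buffers of different widths: `left_offset` right-aligns it against the merged `max_sequence_length`, while `batch_left_offset` skips whatever extra left padding and unused right allocation the source buffer carries. A hedged sketch of the same copy as a standalone helper (`merge_masks` is hypothetical, not part of the codebase):

```python
import torch
from typing import List

def merge_masks(
    masks: List[torch.Tensor],      # per-batch attention_mask buffers
    seq_lens: List[int],            # per-batch max_sequence_length
    right_offsets: List[int],       # per-batch padding_right_offset
) -> torch.Tensor:
    # Hypothetical helper mirroring the concatenate() slicing above.
    total_batch_size = sum(m.shape[0] for m in masks)
    max_sequence_length = max(seq_lens)
    padding_right_offset = max(right_offsets)

    out = masks[0].new_zeros(
        total_batch_size, max_sequence_length + padding_right_offset
    )

    start_index = 0
    for mask, seq_len, right in zip(masks, seq_lens, right_offsets):
        end_index = start_index + mask.shape[0]
        # Destination: right-align this batch's seq_len columns against the
        # merged prompt region.
        left_offset = max_sequence_length - seq_len
        # Source: skip extra left padding plus the batch's own unused right
        # allocation.
        batch_left_offset = mask.shape[1] - seq_len - right
        out[start_index:end_index, left_offset:max_sequence_length] = mask[
            :, batch_left_offset : mask.shape[1] - right
        ]
        start_index = end_index
    return out
```

Here the destination slice is written as `left_offset:max_sequence_length`, which is the same span the diff expresses as `left_offset:-padding_right_offset`.
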
@@ -228,6 +258,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )
 
@@ -294,9 +325,12 @@ def forward(
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+        # slice the attention mask to the correct shape
+        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+
         logits, past = self.forward(
             batch.input_ids,
-            batch.attention_mask,
+            attention_mask,
             batch.position_ids,
             batch.past_key_values,
         )
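
Because the reserved right-padding columns stay zero until their tokens exist, they are trimmed off before each forward pass. A small illustration with made-up shapes:

```python
import torch

# Illustrative buffer: 2 sequences, 5 prompt columns, 4 reserved columns.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 0, 0, 0, 0],
])
padding_right_offset = 4

# Drop the still-unused right padding so the model only sees real positions.
model_mask = attention_mask[:, :-padding_right_offset]
print(model_mask.shape)  # torch.Size([2, 5])
```
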
@@ -448,14 +482,8 @@ def generate_token(
         next_batch_next_token_choosers = batch.next_token_choosers
         next_batch_stopping_criterias = batch.stopping_criterias
 
-        # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_attention_mask = torch.cat(
-            [
-                next_batch_attention_mask,
-                next_batch_attention_mask.new_ones(next_batch_size, 1),
-            ],
-            dim=1,
-        )
+        # Update attention_mask as we added a new token to input_ids
+        next_batch_attention_mask[:, -batch.padding_right_offset] = 1
 
         # Update position_ids
         next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
@@ -473,6 +501,7 @@ def generate_token(
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
+            padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
         return generations, next_batch
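
Together with the slice before the forward pass, each generation step now writes a 1 into the first still-reserved column and passes `padding_right_offset - 1` to the next batch, so the mask tensor is never reallocated the way the old `torch.cat` version did. A hedged sketch of a few steps on the same illustrative buffer (not the actual serving loop):

```python
import torch

# Illustrative buffer: 2 sequences, 5 prompt columns, 4 reserved columns.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 0, 0, 0, 0],
])
padding_right_offset = 4

for step in range(3):
    # View used for the forward pass: everything up to the reserved columns.
    model_mask = attention_mask[:, :-padding_right_offset]
    # ... run the model with model_mask and sample one token per sequence ...
    # Mark the first reserved column as attended, in place.
    attention_mask[:, -padding_right_offset] = 1
    padding_right_offset -= 1   # the next step reserves one column fewer

print(attention_mask)
# tensor([[0, 0, 1, 1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1, 1, 1, 0]])
```
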