@@ -88,19 +88,25 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
8888 images = images ,
8989 return_tensors = "pt" ,
9090 )
91+ # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
9192 # workaround since processor works in batches instead of single examples
9293 for k , val in batch .items ():
9394 if k in ["pixel_values" ]:
9495 batch [k ] = val .tolist ()
9596 else :
9697 batch [k ] = val .squeeze ().tolist ()
98+ batch ["num_tokens_pre_truncation" ] = len (batch ["input_ids" ])
9799 return batch
98100
99- return self .tokenizer .apply_chat_template (
101+ input_ids = self .tokenizer .apply_chat_template (
100102 conversation ,
101103 add_generation_prompt = add_generation_prompt ,
102104 chat_template = self .chat_template ,
103105 )
106+ return {
107+ "input_ids" : input_ids ,
108+ "num_tokens_pre_truncation" : len (input_ids ),
109+ }
104110
105111 def get_offsets_for_train_detail (
106112 self , text : str , train_details : List [Dict ], mask_untrainable : bool = True
@@ -290,20 +296,29 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
290296 ):
291297 turns = self .get_conversation_thread (prompt )
292298 images = self .get_images (prompt )
293- prompt_ids = self .prompter .build_prompt ( # type: ignore
299+ prompt_tokenized = self .prompter .build_prompt (
294300 turns [:- 1 ],
295301 add_generation_prompt = True ,
296302 images = images ,
297303 )
298- tokenized_res = self .prompter .build_prompt (turns , images = images ) # type: ignore
304+ all_turns_tokenized = self .prompter .build_prompt (turns , images = images )
299305 tokenized_prompt = {}
300- if isinstance (tokenized_res , list ):
301- input_ids = prompt_ids + tokenized_res [len (prompt_ids ) :]
306+ if "attention_mask" not in all_turns_tokenized :
307+ prompt_ids = prompt_tokenized ["input_ids" ]
308+ input_ids = (
309+ prompt_ids + all_turns_tokenized ["input_ids" ][len (prompt_ids ) :]
310+ )
302311 tokenized_prompt ["input_ids" ] = input_ids
312+ num_tokens_pre_truncation = all_turns_tokenized [
313+ "num_tokens_pre_truncation"
314+ ]
303315 tokenized_prompt ["attention_mask" ] = [1 ] * len (input_ids )
304316 else :
305- input_ids = tokenized_res ["input_ids" ]
306- tokenized_prompt = tokenized_res
317+ input_ids = all_turns_tokenized ["input_ids" ]
318+ num_tokens_pre_truncation = all_turns_tokenized [
319+ "num_tokens_pre_truncation"
320+ ]
321+ tokenized_prompt = all_turns_tokenized
307322
308323 if not self .train_on_inputs :
309324 user_prompt_len = len (prompt_ids )
@@ -312,11 +327,14 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
312327 labels = input_ids
313328
314329 tokenized_prompt ["labels" ] = labels
330+ tokenized_prompt ["num_tokens_pre_truncation" ] = num_tokens_pre_truncation
315331
316332 return tokenized_prompt
317333
318334 turns = self .get_conversation_thread (prompt )
319- input_ids = self .prompter .build_prompt (turns ) # type: ignore
335+ tokenized_res = self .prompter .build_prompt (turns )
336+ input_ids = tokenized_res ["input_ids" ]
337+ num_tokens_pre_truncation = tokenized_res ["num_tokens_pre_truncation" ]
320338 labels = [IGNORE_TOKEN_ID ] * len (input_ids )
321339
322340 last_eos_idx = - 1
@@ -393,6 +411,7 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
393411 "input_ids" : input_ids ,
394412 "labels" : labels ,
395413 "attention_mask" : [1 ] * len (input_ids ),
414+ "num_tokens_pre_truncation" : num_tokens_pre_truncation ,
396415 }
397416
398417 def find_first_eos_token (self , input_ids , start_idx ):
@@ -433,10 +452,10 @@ def find_turn(self, turns: list[dict], turn_idx: int):
433452 turns_with_content = turns [: turn_idx + 1 ]
434453
435454 # Generate the conversation up to the turn, with final turn replaced with dummy content
436- dummy_ids = self .prompter .build_prompt (turns_with_empty ) # type: ignore
455+ dummy_ids = self .prompter .build_prompt (turns_with_empty )[ "input_ids" ] # type: ignore
437456
438457 # Generate the conversation up to the turn, with final turn included
439- full_ids = self .prompter .build_prompt (turns_with_content ) # type: ignore
458+ full_ids = self .prompter .build_prompt (turns_with_content )[ "input_ids" ] # type: ignore
440459
441460 if not full_ids or not dummy_ids :
442461 LOG .warning (f"Empty template generated for turn { turn_idx } " )
0 commit comments