@@ -91,19 +91,25 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
                 images=images,
                 return_tensors="pt",
             )
+            # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
             # workaround since processor works in batches instead of single examples
             for k, val in batch.items():
                 if k in ["pixel_values"]:
                     batch[k] = val.tolist()
                 else:
                     batch[k] = val.squeeze().tolist()
+            batch["num_tokens_pre_truncation"] = len(batch["input_ids"])
             return batch
 
-        return self.tokenizer.apply_chat_template(
+        input_ids = self.tokenizer.apply_chat_template(
             conversation,
             add_generation_prompt=add_generation_prompt,
             chat_template=self.chat_template,
         )
+        return {
+            "input_ids": input_ids,
+            "num_tokens_pre_truncation": len(input_ids),
+        }
 
     def get_offsets_for_train_detail(
         self, text: str, train_details: List[Dict], mask_untrainable: bool = True
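Both branches of `build_prompt` now return a mapping instead of a bare token list: the processor branch records `num_tokens_pre_truncation` on the batch dict, and the tokenizer branch wraps its `apply_chat_template` output the same way. A minimal standalone sketch of the resulting contract, with a stand-in tokenizer (`FakeTokenizer` is illustrative, not part of the patch):

```python
# Sketch of the new build_prompt return shape; FakeTokenizer is a stand-in.
from typing import Dict, List, Union


class FakeTokenizer:
    def apply_chat_template(self, conversation, **kwargs) -> List[int]:
        # Pretend each message tokenizes to two ids, bracketed by BOS/EOS.
        return [1] + [7] * (2 * len(conversation)) + [2]


def build_prompt(tokenizer, conversation) -> Dict[str, Union[List[int], int]]:
    input_ids = tokenizer.apply_chat_template(conversation)
    # Capture the length before any later truncation shortens input_ids.
    return {"input_ids": input_ids, "num_tokens_pre_truncation": len(input_ids)}


result = build_prompt(FakeTokenizer(), [{"role": "user", "content": "hi"}])
assert result["num_tokens_pre_truncation"] == len(result["input_ids"])
```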
@@ -373,21 +379,25 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
         ):
             turns = self.get_conversation_thread(prompt)
             images = self.get_images(prompt)
-            prompt_ids = self.prompter.build_prompt(  # type: ignore
+            # We get back {"input_ids": [...], "num_tokens_pre_truncation": ...}
+            _prompt_ids = self.prompter.build_prompt(
                 turns[:-1],
                 add_generation_prompt=True,
                 images=images,
             )
+            prompt_ids = _prompt_ids["input_ids"]
             tokenized_res = self.prompter.build_prompt(
                 turns, images=images
             )  # type: ignore
             tokenized_prompt = {}
-            if isinstance(tokenized_res, list):
-                input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
+            if "attention_mask" not in tokenized_res:
+                input_ids = prompt_ids + tokenized_res["input_ids"][len(prompt_ids) :]
                 tokenized_prompt["input_ids"] = input_ids
+                num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
                 tokenized_prompt["attention_mask"] = [1] * len(input_ids)
             else:
                 input_ids = tokenized_res["input_ids"]
+                num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
                 tokenized_prompt = tokenized_res
 
             if not self.train_on_inputs:
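Since both paths now return dicts, the old `isinstance(tokenized_res, list)` test can no longer tell them apart; the new check keys off `attention_mask`, which only the processor batch carries. A hedged sketch of that dispatch (the helper name is illustrative, not library code):

```python
# Shape-based dispatch mirroring the patch; merge_tokenized itself is illustrative.
def merge_tokenized(prompt_ids: list, tokenized_res: dict) -> dict:
    if "attention_mask" not in tokenized_res:
        # Tokenizer path: only input_ids plus the pre-truncation count,
        # so splice in the prompt ids and build the mask by hand.
        input_ids = prompt_ids + tokenized_res["input_ids"][len(prompt_ids) :]
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "num_tokens_pre_truncation": tokenized_res["num_tokens_pre_truncation"],
        }
    # Processor path: the batch already has attention_mask (and pixel_values
    # for multimodal inputs), so pass it through unchanged.
    return tokenized_res
```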
@@ -397,11 +407,14 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
                 labels = input_ids
 
             tokenized_prompt["labels"] = labels
+            tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation
 
             return tokenized_prompt
 
         turns = self.get_conversation_thread(prompt)
-        input_ids = self.prompter.build_prompt(turns)  # type: ignore
+        tokenized_res = self.prompter.build_prompt(turns)
+        input_ids = tokenized_res["input_ids"]
+        num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
         labels = [IGNORE_TOKEN_ID] * len(input_ids)
 
         last_eos_idx = -1
@@ -514,6 +527,7 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
514527 "input_ids" : input_ids ,
515528 "labels" : labels ,
516529 "attention_mask" : [1 ] * len (input_ids ),
530+ "num_tokens_pre_truncation" : num_tokens_pre_truncation ,
517531 }
518532
519533 def find_first_eos_token (self , input_ids , start_idx ):
@@ -573,10 +587,10 @@ def find_turn(self, turns: list[dict], turn_idx: int):
         turns_with_content = turns[: turn_idx + 1]
 
         # Generate the conversation up to the turn, with final turn replaced with dummy content
-        dummy_ids = self.prompter.build_prompt(turns_with_empty)  # type: ignore
+        dummy_ids = self.prompter.build_prompt(turns_with_empty)["input_ids"]  # type: ignore
 
         # Generate the conversation up to the turn, with final turn included
-        full_ids = self.prompter.build_prompt(turns_with_content)  # type: ignore
+        full_ids = self.prompter.build_prompt(turns_with_content)["input_ids"]  # type: ignore
 
         if not full_ids or not dummy_ids:
             LOG.warning(f"Empty template generated for turn {turn_idx}")
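Every tokenized sample now carries `num_tokens_pre_truncation`, the token count before any max-length cut, so downstream code can spot examples that would otherwise be silently truncated. A hypothetical consumer (the `sequence_len` comparison below is an assumption, not shown in this patch):

```python
# Hypothetical downstream check; comparing against sequence_len is an
# assumption about how the counter gets used, not part of this patch.
import logging

LOG = logging.getLogger(__name__)


def warn_if_truncated(sample: dict, sequence_len: int) -> dict:
    # num_tokens_pre_truncation records the length before any max-length cut,
    # so exceeding sequence_len means tokens were (or would be) dropped.
    if sample["num_tokens_pre_truncation"] > sequence_len:
        LOG.warning(
            "sample had %d tokens before truncation (limit %d)",
            sample["num_tokens_pre_truncation"],
            sequence_len,
        )
    return sample
```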