@@ -94,20 +94,26 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
9494 images = images ,
9595 return_tensors = "pt" ,
9696 )
97+ # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
9798 # workaround since processor works in batches instead of single examples
9899 for k , val in batch .items ():
99100 if k in ["pixel_values" ]:
100101 batch [k ] = val .tolist ()
101102 else :
102103 batch [k ] = val .squeeze ().tolist ()
104+ batch ["num_tokens_pre_truncation" ] = len (batch ["input_ids" ])
103105 return batch
104106
105- return self .tokenizer .apply_chat_template (
107+ input_ids = self .tokenizer .apply_chat_template (
106108 conversation ,
107109 add_generation_prompt = add_generation_prompt ,
108110 chat_template = self .chat_template ,
109111 ** self .chat_template_kwargs ,
110112 )
113+ return {
114+ "input_ids" : input_ids ,
115+ "num_tokens_pre_truncation" : len (input_ids ),
116+ }
111117
112118 def get_offsets_for_train_detail (
113119 self , text : str , train_details : List [Dict ], mask_untrainable : bool = True
@@ -377,21 +383,25 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
377383 ):
378384 turns = self .get_conversation_thread (prompt )
379385 images = self .get_images (prompt )
380- prompt_ids = self .prompter .build_prompt ( # type: ignore
386+ # We get back {"input_ids": [...], "num_tokens_pre_truncation": ...}
387+ _prompt_ids = self .prompter .build_prompt (
381388 turns [:- 1 ],
382389 add_generation_prompt = True ,
383390 images = images ,
384391 )
392+ prompt_ids = _prompt_ids ["input_ids" ]
385393 tokenized_res = self .prompter .build_prompt (
386394 turns , images = images
387395 ) # type: ignore
388396 tokenized_prompt = {}
389- if isinstance ( tokenized_res , list ) :
390- input_ids = prompt_ids + tokenized_res [len (prompt_ids ) :]
397+ if "attention_mask" not in tokenized_res :
398+ input_ids = prompt_ids + tokenized_res ["input_ids" ][ len (prompt_ids ) :]
391399 tokenized_prompt ["input_ids" ] = input_ids
400+ num_tokens_pre_truncation = tokenized_res ["num_tokens_pre_truncation" ]
392401 tokenized_prompt ["attention_mask" ] = [1 ] * len (input_ids )
393402 else :
394403 input_ids = tokenized_res ["input_ids" ]
404+ num_tokens_pre_truncation = tokenized_res ["num_tokens_pre_truncation" ]
395405 tokenized_prompt = tokenized_res
396406
397407 if not self .train_on_inputs :
@@ -401,11 +411,14 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
401411 labels = input_ids
402412
403413 tokenized_prompt ["labels" ] = labels
414+ tokenized_prompt ["num_tokens_pre_truncation" ] = num_tokens_pre_truncation
404415
405416 return tokenized_prompt
406417
407418 turns = self .get_conversation_thread (prompt )
408- input_ids = self .prompter .build_prompt (turns ) # type: ignore
419+ tokenized_res = self .prompter .build_prompt (turns )
420+ input_ids = tokenized_res ["input_ids" ]
421+ num_tokens_pre_truncation = tokenized_res ["num_tokens_pre_truncation" ]
409422 labels = [IGNORE_TOKEN_ID ] * len (input_ids )
410423
411424 last_eos_idx = - 1
@@ -518,6 +531,7 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
518531 "input_ids" : input_ids ,
519532 "labels" : labels ,
520533 "attention_mask" : [1 ] * len (input_ids ),
534+ "num_tokens_pre_truncation" : num_tokens_pre_truncation ,
521535 }
522536
523537 def find_first_eos_token (self , input_ids , start_idx ):
@@ -577,10 +591,10 @@ def find_turn(self, turns: list[dict], turn_idx: int):
577591 turns_with_content = turns [: turn_idx + 1 ]
578592
579593 # Generate the conversation up to the turn, with final turn replaced with dummy content
580- dummy_ids = self .prompter .build_prompt (turns_with_empty ) # type: ignore
594+ dummy_ids = self .prompter .build_prompt (turns_with_empty )[ "input_ids" ] # type: ignore
581595
582596 # Generate the conversation up to the turn, with final turn included
583- full_ids = self .prompter .build_prompt (turns_with_content ) # type: ignore
597+ full_ids = self .prompter .build_prompt (turns_with_content )[ "input_ids" ] # type: ignore
584598
585599 if not full_ids or not dummy_ids :
586600 LOG .warning (f"Empty template generated for turn { turn_idx } " )
0 commit comments