diff --git a/trl/experimental/kto/kto_trainer.py b/trl/experimental/kto/kto_trainer.py
index 02d58bcaa61..05f74a68a42 100644
--- a/trl/experimental/kto/kto_trainer.py
+++ b/trl/experimental/kto/kto_trainer.py
@@ -270,13 +270,13 @@ def __init__(
         if processing_class is None:
             processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config))
         if isinstance(processing_class, ProcessorMixin):
-            tokenizer = processing_class.tokenizer
+            self._tokenizer = processing_class.tokenizer
         elif isinstance(processing_class, PreTrainedTokenizerBase):
-            tokenizer = processing_class
+            self._tokenizer = processing_class
         else:
             raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
 
         # PEFT
         if peft_config is not None:
@@ -345,7 +345,7 @@ def __init__(
 
         if data_collator is None:
             data_collator = DataCollatorForUnpairedPreference(
-                pad_token_id=tokenizer.pad_token_id,
+                pad_token_id=self._tokenizer.pad_token_id,
                 max_length=max_length,
             )
 
@@ -569,8 +569,6 @@ def _prepare_dataset(
                 map_kwargs["desc"] = f"Unpairing {dataset_name} dataset"
             dataset = unpair_preference_dataset(dataset, **map_kwargs)
 
-        tokenizer = getattr(processing_class, "tokenizer", processing_class)
-
         # Add EOS token if needed: non-conversational only
         first_example = next(iter(dataset))
         if not is_conversational(first_example):
@@ -582,7 +580,7 @@ def add_eos(example, eos_token):
                     example["completion"] = example["completion"] + eos_token
                 return example
 
-            dataset = dataset.map(add_eos, fn_kwargs={"eos_token": tokenizer.eos_token}, **map_kwargs)
+            dataset = dataset.map(add_eos, fn_kwargs={"eos_token": self._tokenizer.eos_token}, **map_kwargs)
 
         # Tokenize dataset
         if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py
index 5cb73fbcc83..cb57eba41fa 100644
--- a/trl/experimental/online_dpo/online_dpo_trainer.py
+++ b/trl/experimental/online_dpo/online_dpo_trainer.py
@@ -361,17 +361,14 @@ def __init__(
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
-            tokenizer = processing_class.tokenizer
+            self._tokenizer = processing_class.tokenizer
         elif isinstance(processing_class, PreTrainedTokenizerBase):
-            tokenizer = processing_class
+            self._tokenizer = processing_class
         else:
             raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
 
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
-        self.pad_token_id = tokenizer.pad_token_id
-        self.eos_token_id = tokenizer.eos_token_id
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
 
         # Vision tokens for VLM support
         self.image_token_id = getattr(processing_class, "image_token_id", None)
@@ -380,11 +377,11 @@ def __init__(
         # Get the image token string for token collapsing
         self.image_token = None
         if self.image_token_id is not None:
-            self.image_token = tokenizer.decode([self.image_token_id])
+            self.image_token = self._tokenizer.decode([self.image_token_id])
 
         # Define the collator if not provided
         if data_collator is None:
-            data_collator = DPODataCollatorWithPadding(pad_token_id=self.pad_token_id)
+            data_collator = DPODataCollatorWithPadding(pad_token_id=self._tokenizer.pad_token_id)
 
         # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
         # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
@@ -517,9 +514,9 @@ def __init__(
             generation_kwargs = {
                 "max_new_tokens": args.max_new_tokens,
                 "do_sample": True,
-                "pad_token_id": self.pad_token_id,
-                "bos_token_id": tokenizer.bos_token_id,
-                "eos_token_id": self.eos_token_id,
+                "pad_token_id": self._tokenizer.pad_token_id,
+                "bos_token_id": self._tokenizer.bos_token_id,
+                "eos_token_id": self._tokenizer.eos_token_id,
                 "temperature": self.temperature,
                 "top_k": self.top_k,
                 "top_p": self.top_p,
@@ -595,8 +592,8 @@ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPO
         return model
 
     def _generate_vllm(self, prompts, images=None):
-        eos_token_id = self.eos_token_id
-        pad_token_id = self.pad_token_id
+        eos_token_id = self._tokenizer.eos_token_id
+        pad_token_id = self._tokenizer.pad_token_id
 
         # Generate completion_ids and prompt_ids based on mode
         if self.vllm_mode == "server":
@@ -905,8 +902,8 @@ def process_vision_row(
     def _generate(self, model, prompts, images=None):
         """Generate completions using the model"""
         device = next(model.parameters()).device
-        eos_token_id = self.eos_token_id
-        pad_token_id = self.pad_token_id
+        eos_token_id = self._tokenizer.eos_token_id
+        pad_token_id = self._tokenizer.pad_token_id
 
         # Apply chat template and tokenize the input
         inputs = [{"prompt": prompt} for prompt in prompts]
@@ -935,9 +932,7 @@ def _generate(self, model, prompts, images=None):
             else:
                 # If the chat template doesn't use the image token, remove all instances
                 if self.vision_end_token_id is not None:
-                    escaped_eoi_token = re.escape(
-                        self.processing_class.tokenizer.decode([self.vision_end_token_id])
-                    )
+                    escaped_eoi_token = re.escape(self._tokenizer.decode([self.vision_end_token_id]))
                     prompts_text = [
                         re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
                     ]
@@ -1130,7 +1125,7 @@ def training_step(
         else:
             prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts, images)
 
-        contain_eos_token = torch.any(completion_ids == self.eos_token_id, dim=-1)
+        contain_eos_token = torch.any(completion_ids == self._tokenizer.eos_token_id, dim=-1)
 
         # Extract vision inputs if available for VLM support
         vision_inputs = None
diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py
index a4be9b395fe..f87a38ee673 100644
--- a/trl/experimental/sdft/sdft_trainer.py
+++ b/trl/experimental/sdft/sdft_trainer.py
@@ -221,17 +221,15 @@ def __init__(
         )
 
         if isinstance(processing_class, ProcessorMixin):
-            tokenizer = processing_class.tokenizer
+            self._tokenizer = processing_class.tokenizer
         elif isinstance(processing_class, PreTrainedTokenizerBase):
-            tokenizer = processing_class
+            self._tokenizer = processing_class
         else:
             raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
 
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
 
-        self.pad_token_id = tokenizer.pad_token_id
-        self.eos_token_id = tokenizer.eos_token_id
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length
         self.num_generations = args.num_generations
@@ -251,9 +249,9 @@ def __init__(
         generation_kwargs = {
             "max_new_tokens": self.max_completion_length,
             "do_sample": True,
-            "pad_token_id": tokenizer.pad_token_id,
-            "bos_token_id": tokenizer.bos_token_id,
-            "eos_token_id": tokenizer.eos_token_id,
+            "pad_token_id": self._tokenizer.pad_token_id,
+            "bos_token_id": self._tokenizer.bos_token_id,
+            "eos_token_id": self._tokenizer.eos_token_id,
             "temperature": args.temperature,
             "top_p": args.top_p,
             "top_k": args.top_k,
@@ -410,7 +408,7 @@ def _generate_completion_ids(self, prompts: list[Any]) -> tuple[torch.Tensor, to
         prompt_length = generate_inputs["input_ids"].size(1)
         completion_ids = prompt_completion_ids[:, prompt_length:]
 
-        is_eos = completion_ids == self.eos_token_id
+        is_eos = completion_ids == self._tokenizer.eos_token_id
         eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
         eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
         seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
@@ -420,7 +418,7 @@ def _generate_completion_ids(self, prompts: list[Any]) -> tuple[torch.Tensor, to
         completion_ids = [torch.tensor(ids, device=self.accelerator.device) for ids in completion_ids_list]
         completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
         return (
-            pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"),
+            pad(completion_ids, padding_value=self._tokenizer.pad_token_id, padding_side="right"),
             pad(completion_mask, padding_value=0, padding_side="right"),
         )
 
diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py
index ef84a17a44c..66195da4c58 100644
--- a/trl/experimental/sdpo/sdpo_trainer.py
+++ b/trl/experimental/sdpo/sdpo_trainer.py
@@ -87,7 +87,9 @@ def _tokenize_teacher_messages(
         teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list]
         teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids]
         return TokenizedPromptBatch(
-            prompt_ids=pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"),
+            prompt_ids=pad(
+                teacher_prompt_ids, padding_value=self.trainer._tokenizer.pad_token_id, padding_side="left"
+            ),
             prompt_mask=pad(teacher_prompt_mask, padding_value=0, padding_side="left"),
         )
 
@@ -115,7 +117,7 @@ def build(
         # Use separate variables so the original completion_ids/completion_mask stay unpadded for the
         # teacher concat (they must match the student's sequence length for logits_to_keep alignment).
         padded_completion_ids = self.trainer.accelerator.pad_across_processes(
-            completion_ids, dim=1, pad_index=self.trainer.pad_token_id
+            completion_ids, dim=1, pad_index=self.trainer._tokenizer.pad_token_id
         )
         all_completion_ids = self.trainer.accelerator.gather(padded_completion_ids)
         all_prompts = gather_object(prompts)
@@ -193,7 +195,7 @@ def build(
             if demo_idx is None:
                 raise RuntimeError("Expected a successful demonstration index for an active SDPO teacher prompt.")
             demo_ids = all_completion_ids[demo_idx]
-            demo_ids = demo_ids[demo_ids != self.trainer.processing_class.pad_token_id]
+            demo_ids = demo_ids[demo_ids != self.trainer._tokenizer.pad_token_id]
             demo_text = self.trainer.processing_class.decode(demo_ids, skip_special_tokens=True)
 
             if self.trainer.args.remove_thinking_from_demonstration:
diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py
index 5484a2efb66..a332c7caf8a 100644
--- a/trl/experimental/self_distillation/base_self_distillation_trainer.py
+++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py
@@ -125,17 +125,15 @@ def __init__(
         )
 
         if isinstance(processing_class, ProcessorMixin):
-            tokenizer = processing_class.tokenizer
+            self._tokenizer = processing_class.tokenizer
         elif isinstance(processing_class, PreTrainedTokenizerBase):
-            tokenizer = processing_class
+            self._tokenizer = processing_class
         else:
             raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
 
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
 
-        self.pad_token_id = tokenizer.pad_token_id
-        self.eos_token_id = tokenizer.eos_token_id
        self.temperature = args.temperature
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length
@@ -163,9 +161,9 @@ def __init__(
         generation_kwargs = {
             "max_new_tokens": self.max_completion_length,
             "do_sample": True,
-            "pad_token_id": tokenizer.pad_token_id,
-            "bos_token_id": tokenizer.bos_token_id,
-            "eos_token_id": tokenizer.eos_token_id,
+            "pad_token_id": self._tokenizer.pad_token_id,
+            "bos_token_id": self._tokenizer.bos_token_id,
+            "eos_token_id": self._tokenizer.eos_token_id,
             "temperature": args.temperature,
             "top_p": args.top_p,
             "top_k": args.top_k,
diff --git a/trl/experimental/self_distillation/online_rollout_mixin.py b/trl/experimental/self_distillation/online_rollout_mixin.py
index 490724582dc..93caf5e2eeb 100644
--- a/trl/experimental/self_distillation/online_rollout_mixin.py
+++ b/trl/experimental/self_distillation/online_rollout_mixin.py
@@ -110,7 +110,7 @@ def _generate_transformers(self, prompts):
         prompt_mask = generate_inputs["attention_mask"]
         prompt_length = prompt_ids.size(1)
         completion_ids = prompt_completion_ids[:, prompt_length:]
-        is_eos = completion_ids == self.eos_token_id
+        is_eos = completion_ids == self._tokenizer.eos_token_id
         eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
         eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
         seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
@@ -173,15 +173,17 @@ def _generate_and_score_completions(self, inputs):
 
         prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list]
         prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
-        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device)
+        prompt_ids = pad(prompt_ids, padding_value=self._tokenizer.pad_token_id, padding_side="left").to(device=device)
         prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device)
         completion_ids = [torch.tensor(ids) for ids in completion_ids_list]
         completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
-        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device)
+        completion_ids = pad(completion_ids, padding_value=self._tokenizer.pad_token_id, padding_side="right").to(
+            device=device
+        )
         completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device)
 
         if self.mask_truncated_completions:
-            eos_and_pad = [self.eos_token_id, self.pad_token_id]
+            eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id]
             is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
             completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
 
@@ -240,7 +242,7 @@ def _generate_and_score_completions(self, inputs):
         self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
 
-        eos_and_pad = [self.eos_token_id, self.pad_token_id]
+        eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id]
         is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
         agg_is_truncated = self.accelerator.gather(is_truncated)
         self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
diff --git a/trl/experimental/self_distillation/teacher_context.py b/trl/experimental/self_distillation/teacher_context.py
index 5e1020c91a7..2448b78a712 100644
--- a/trl/experimental/self_distillation/teacher_context.py
+++ b/trl/experimental/self_distillation/teacher_context.py
@@ -80,6 +80,6 @@ def tokenize_prompts(self, prompts: list[Any]) -> TokenizedPromptBatch:
         prompt_ids = [torch.tensor(ids, device=self.trainer.accelerator.device) for ids in prompt_ids]
         prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
         return TokenizedPromptBatch(
-            prompt_ids=pad(prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"),
+            prompt_ids=pad(prompt_ids, padding_value=self.trainer._tokenizer.pad_token_id, padding_side="left"),
             prompt_mask=pad(prompt_mask, padding_value=0, padding_side="left"),
         )
diff --git a/trl/experimental/ssd/ssd_trainer.py b/trl/experimental/ssd/ssd_trainer.py
index 91a0275047a..7246e1a616e 100644
--- a/trl/experimental/ssd/ssd_trainer.py
+++ b/trl/experimental/ssd/ssd_trainer.py
@@ -161,17 +161,15 @@ def __init__(
         )
 
         if isinstance(processing_class, ProcessorMixin):
-            tokenizer = processing_class.tokenizer
+            self._tokenizer = processing_class.tokenizer
         elif isinstance(processing_class, PreTrainedTokenizerBase):
-            tokenizer = processing_class
+            self._tokenizer = processing_class
         else:
             raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
 
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
 
-        self.pad_token_id = tokenizer.pad_token_id
-        self.eos_token_id = tokenizer.eos_token_id
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length
         # SSD always samples a single completion per prompt (N=1 in the paper).
@@ -189,9 +187,9 @@ def __init__(
         generation_kwargs = {
             "max_new_tokens": self.max_completion_length,
             "do_sample": True,
-            "pad_token_id": tokenizer.pad_token_id,
-            "bos_token_id": tokenizer.bos_token_id,
-            "eos_token_id": tokenizer.eos_token_id,
+            "pad_token_id": self._tokenizer.pad_token_id,
+            "bos_token_id": self._tokenizer.bos_token_id,
+            "eos_token_id": self._tokenizer.eos_token_id,
             "temperature": args.temperature,
             "top_p": args.top_p,
             "top_k": args.top_k,
@@ -374,7 +372,7 @@ def _generate_completion_ids_vllm(self, prompts: list[Any]) -> tuple[torch.Tenso
         completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
         completion_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in completion_ids_list]
         return (
-            pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"),
+            pad(completion_ids, padding_value=self._tokenizer.pad_token_id, padding_side="right"),
             pad(completion_mask, padding_value=0, padding_side="right"),
         )
 
@@ -406,7 +404,7 @@ def _generate_completion_ids_transformers(self, prompts: list[Any]) -> tuple[tor
         prompt_length = generate_inputs["input_ids"].size(1)
         completion_ids = prompt_completion_ids[:, prompt_length:]
 
-        is_eos = completion_ids == self.eos_token_id
+        is_eos = completion_ids == self._tokenizer.eos_token_id
         eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
         eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
         seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
@@ -416,7 +414,7 @@ def _generate_completion_ids_transformers(self, prompts: list[Any]) -> tuple[tor
         completion_ids = [torch.tensor(ids, device=self.accelerator.device) for ids in completion_ids_list]
         completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
         return (
-            pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"),
+            pad(completion_ids, padding_value=self._tokenizer.pad_token_id, padding_side="right"),
             pad(completion_mask, padding_value=0, padding_side="right"),
         )
 
diff --git a/trl/experimental/tpo/tpo_trainer.py b/trl/experimental/tpo/tpo_trainer.py
index 7195596e15c..b012a5653d8 100644
--- a/trl/experimental/tpo/tpo_trainer.py
+++ b/trl/experimental/tpo/tpo_trainer.py
@@ -345,10 +345,10 @@ def __init__(
                 "The `processing_class` must be a `PreTrainedTokenizerBase`. `TPOTrainer` does not currently "
                 "support vision-language models."
             )
-        tokenizer = processing_class
+        self._tokenizer = processing_class
 
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
 
         # PEFT
         if peft_config is not None:
@@ -383,7 +383,7 @@ def __init__(
         # each step.
         if data_collator is None:
             data_collator = DataCollatorForTriplePreference(
-                pad_token_id=tokenizer.pad_token_id,
+                pad_token_id=self._tokenizer.pad_token_id,
                 max_length=args.max_length,
                 truncation_mode=args.truncation_mode,
                 pad_to_multiple_of=args.pad_to_multiple_of,
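
Note on the shared pattern (not part of the patch): every trainer touched above previously resolved a throwaway local `tokenizer` from `processing_class`, and several also cached `self.pad_token_id` / `self.eos_token_id` copies. The diff replaces all of that with a single `self._tokenizer` attribute that call sites query directly. A minimal sketch of the resolution logic, assuming only the public `transformers` classes already used in the diff; the standalone `resolve_tokenizer` helper is hypothetical — the patch inlines this block in each `__init__`:

from transformers import PreTrainedTokenizerBase, ProcessorMixin


def resolve_tokenizer(processing_class) -> PreTrainedTokenizerBase:
    """Hypothetical helper mirroring the block each trainer inlines above."""
    if isinstance(processing_class, ProcessorMixin):
        # Processors (the VLM case) wrap the underlying tokenizer.
        tokenizer = processing_class.tokenizer
    elif isinstance(processing_class, PreTrainedTokenizerBase):
        tokenizer = processing_class
    else:
        raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


# Each __init__ in the patch effectively does:
#     self._tokenizer = resolve_tokenizer(processing_class)
# after which call sites read `self._tokenizer.pad_token_id` and
# `self._tokenizer.eos_token_id` instead of cached copies, so the values stay
# in sync if the tokenizer's special tokens are changed after construction.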