Skip to content
Merged
Changes from 4 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
8852bd2
Set tokenizer attribute in OnlineDPO
albertvillanova Apr 16, 2026
ee2fd2d
Replace self.pad_token_id and self.eos_token_id
albertvillanova Apr 16, 2026
5375253
Use self._tokenizer
albertvillanova Apr 16, 2026
995558d
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 17, 2026
ddb0457
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 23, 2026
9c535e0
Set tokenizer attribute in SDFT
albertvillanova Apr 23, 2026
0f85a79
Set tokenizer attribute in SelfDistillation
albertvillanova Apr 23, 2026
a355438
Set tokenizer attribute in SSD
albertvillanova Apr 23, 2026
4e8448b
Set tokenizer attribute in TPO
albertvillanova Apr 23, 2026
4cec956
Replace self.pad_token_id and self.eos_token_id in SDFT
albertvillanova Apr 23, 2026
7916f77
Remove self.pad_token_id and self.eos_token_id in BaseSelfDistillation
albertvillanova Apr 23, 2026
3204efb
Replace self.pad_token_id in SDPO SuccessfulRolloutTeacherContextBuilder
albertvillanova Apr 23, 2026
6d8a23f
Replace self.pad_token_id and self.eos_token_id in SSD
albertvillanova Apr 23, 2026
09b0a4a
Replace self.pad_token_id and self.eos_token_id in OnlineRolloutMixin
albertvillanova Apr 23, 2026
8a81c5d
Use self._tokenizer in PromptTokenizer
albertvillanova Apr 23, 2026
d0911b9
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 23, 2026
ad61f84
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 27, 2026
0161f79
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 28, 2026
d335ec1
Set tokenizer attribute in KTO
albertvillanova Apr 28, 2026
ba22aad
Use self._tokenizer in _prepare_dataset
albertvillanova Apr 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 15 additions & 20 deletions trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,17 +349,14 @@ def __init__(

# Handle pad token for processors or tokenizers
if isinstance(processing_class, ProcessorMixin):
tokenizer = processing_class.tokenizer
self._tokenizer = processing_class.tokenizer
elif isinstance(processing_class, PreTrainedTokenizerBase):
tokenizer = processing_class
self._tokenizer = processing_class
else:
raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

self.pad_token_id = tokenizer.pad_token_id
self.eos_token_id = tokenizer.eos_token_id
if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token
Comment thread
cursor[bot] marked this conversation as resolved.

# Vision tokens for VLM support
self.image_token_id = getattr(processing_class, "image_token_id", None)
Expand All @@ -368,11 +365,11 @@ def __init__(
# Get the image token string for token collapsing
self.image_token = None
if self.image_token_id is not None:
self.image_token = tokenizer.decode([self.image_token_id])
self.image_token = self._tokenizer.decode([self.image_token_id])

# Define the collator if not provided
if data_collator is None:
data_collator = DPODataCollatorWithPadding(pad_token_id=self.pad_token_id)
data_collator = DPODataCollatorWithPadding(pad_token_id=self._tokenizer.pad_token_id)

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
Expand Down Expand Up @@ -505,9 +502,9 @@ def __init__(
generation_kwargs = {
"max_new_tokens": args.max_new_tokens,
"do_sample": True,
"pad_token_id": self.pad_token_id,
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": self.eos_token_id,
"pad_token_id": self._tokenizer.pad_token_id,
"bos_token_id": self._tokenizer.bos_token_id,
"eos_token_id": self._tokenizer.eos_token_id,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
Expand Down Expand Up @@ -583,8 +580,8 @@ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPO
return model

def _generate_vllm(self, prompts, images=None):
eos_token_id = self.eos_token_id
pad_token_id = self.pad_token_id
eos_token_id = self._tokenizer.eos_token_id
pad_token_id = self._tokenizer.pad_token_id

# Generate completion_ids and prompt_ids based on mode
if self.vllm_mode == "server":
Expand Down Expand Up @@ -893,8 +890,8 @@ def process_vision_row(
def _generate(self, model, prompts, images=None):
"""Generate completions using the model"""
device = next(model.parameters()).device
eos_token_id = self.eos_token_id
pad_token_id = self.pad_token_id
eos_token_id = self._tokenizer.eos_token_id
pad_token_id = self._tokenizer.pad_token_id

# Apply chat template and tokenize the input
inputs = [{"prompt": prompt} for prompt in prompts]
Expand Down Expand Up @@ -923,9 +920,7 @@ def _generate(self, model, prompts, images=None):
else:
# If the chat template doesn't use the image token, remove all instances
if self.vision_end_token_id is not None:
escaped_eoi_token = re.escape(
self.processing_class.tokenizer.decode([self.vision_end_token_id])
)
escaped_eoi_token = re.escape(self._tokenizer.decode([self.vision_end_token_id]))
prompts_text = [
re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
]
Expand Down Expand Up @@ -1118,7 +1113,7 @@ def training_step(
else:
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts, images)

contain_eos_token = torch.any(completion_ids == self.eos_token_id, dim=-1)
contain_eos_token = torch.any(completion_ids == self._tokenizer.eos_token_id, dim=-1)

# Extract vision inputs if available for VLM support
vision_inputs = None
Expand Down
Loading