trl/trainer/callbacks.py (21 changes: 19 additions & 2 deletions)
@@ -66,6 +66,7 @@ def _generate_completions(
     accelerator: Accelerator,
     generation_config: GenerationConfig | None,
     batch_size: int = 1,
+    **generation_kwargs,
 ) -> list[str]:
     """
     Generates completions for a list of pre-formatted prompts from the given model.
@@ -77,19 +78,22 @@ def _generate_completions(
         accelerator (Accelerator): The accelerator to be used for model execution.
         generation_config (GenerationConfig): Configuration for text generation.
         batch_size (int, optional): The number of prompts to process in each batch. Default is 1.
+        **generation_kwargs: Additional keyword arguments forwarded to `model.generate`. These
+            override any matching fields in `generation_config` (e.g. `max_new_tokens=50`,
+            `temperature=0.8`).

     Returns:
         list[str]: A list of generated text completions corresponding to the input prompts.
     """
     completions = []
-    # TODO: Override model.generation_config with generation_kwargs
     with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
         for idx in range(0, len(prompts), batch_size):
             batch = prompts[idx : idx + batch_size]
             tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device)
             generations = unwrapped_model.generate(
                 **tokenized_batch,
                 generation_config=generation_config,
+                **generation_kwargs,
Bugbot review comment (low severity): Unfiltered kwargs risk TypeError in model.generate call

**generation_kwargs and **tokenized_batch are both unpacked into model.generate() without any key-conflict filtering. If generation_kwargs contains a key also present in tokenized_batch (e.g., input_ids, attention_mask), Python raises a TypeError for duplicate keyword arguments at the call site. The rest of the codebase avoids this by folding generation kwargs into a GenerationConfig object (via GenerationConfig(**generation_kwargs)) rather than passing them as raw **kwargs to generate().

Reviewed by Cursor Bugbot for commit 61468de.
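A minimal sketch of the fix direction the reviewer describes, folding the raw kwargs into a GenerationConfig copy so they can never collide with the tokenized inputs. The helper name `_merge_generation_kwargs` is hypothetical, not part of the PR:

```python
from transformers import GenerationConfig


def _merge_generation_kwargs(
    generation_config: GenerationConfig | None,
    **generation_kwargs,
) -> GenerationConfig:
    """Fold raw generation kwargs into a fresh GenerationConfig (hypothetical helper)."""
    # Start from the caller's config (or library defaults) and let the explicit
    # kwargs win, e.g. max_new_tokens=50, temperature=0.8.
    merged = generation_config.to_dict() if generation_config is not None else {}
    merged.update(generation_kwargs)
    return GenerationConfig(**merged)


# The generate call would then only unpack the tokenized inputs:
#     unwrapped_model.generate(
#         **tokenized_batch,
#         generation_config=_merge_generation_kwargs(generation_config, **generation_kwargs),
#     )
```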

             )
             for prompt, generation in zip(tokenized_batch.input_ids, generations, strict=True):
                 # Remove prompt from generation
@@ -258,7 +262,7 @@ class LogCompletionsCallback(TrainerCallback):
     Usage:
     ```python
     trainer = DPOTrainer(...)
-    completions_callback = LogCompletionsCallback(trainer=trainer)
+    completions_callback = LogCompletionsCallback(trainer=trainer, max_new_tokens=50)
     trainer.add_callback(completions_callback)
     ```

@@ -273,6 +277,9 @@ class LogCompletionsCallback(TrainerCallback):
             the evaluation dataset.
         freq (`int`, *optional*):
             The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`.
+        **generation_kwargs:
+            Additional keyword arguments forwarded to `model.generate`. These override any matching fields in
+            `generation_config` (e.g. `max_new_tokens=50`, `temperature=0.8`).
     """

     def __init__(
@@ -281,9 +288,11 @@ def __init__(
         generation_config: GenerationConfig | None = None,
         num_prompts: int | None = None,
         freq: int | None = None,
+        **generation_kwargs,
     ):
         self.trainer = trainer
         self.generation_config = generation_config
+        self.generation_kwargs = generation_kwargs
         self.freq = freq
         self.table = []
         self._last_logged_step = -1
@@ -319,6 +328,7 @@ def on_step_end(self, args, state, control, **kwargs):
                 accelerator=accelerator,
                 generation_config=self.generation_config,
                 batch_size=args.per_device_eval_batch_size,
+                **self.generation_kwargs,
             )
             completions = gather_object(completions)
             prompts = gather_object(prompts)
@@ -405,6 +415,9 @@ def accuracy_scorer(prompt: str, completion: str) -> float:
             Name for the dataset metadata in Weave.
         model_name (`str`, *optional*):
             Name for the model metadata in Weave. If not provided, attempts to extract from model config.
+        **generation_kwargs:
+            Additional keyword arguments forwarded to `model.generate`. These override any matching fields in
+            `generation_config` (e.g. `max_new_tokens=50`, `temperature=0.8`).
     """

     def __init__(
@@ -416,11 +429,13 @@ def __init__(
         num_prompts: int | None = None,
         dataset_name: str = "eval_dataset",
         model_name: str | None = None,
+        **generation_kwargs,
     ):
         self.trainer = trainer
         self.project_name = project_name
         self.scorers = scorers or {}
         self.generation_config = generation_config
+        self.generation_kwargs = generation_kwargs
         self.dataset_name = dataset_name
         self.model_name = model_name
         self._last_logged_step = -1
@@ -503,6 +518,7 @@ def on_evaluate(self, args, state, control, **kwargs):
                 accelerator=accelerator,
                 generation_config=self.generation_config,
                 batch_size=args.per_device_eval_batch_size,
+                **self.generation_kwargs,
             )

             all_prompts = gather_object(prompts)
@@ -513,6 +529,7 @@ def on_evaluate(self, args, state, control, **kwargs):
"training_step": state.global_step,
"model_name": self.model_name,
"generation_config": (self.generation_config.to_dict() if self.generation_config else None),
"generation_kwargs": self.generation_kwargs,
}

             eval_logger = self._EvaluationLogger(
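Beyond the docstring snippet in the diff, a sketch of how the new parameter is meant to be used end to end: kwargs passed to the callback constructor flow through to `model.generate` and override matching `generation_config` fields. The trainer wiring is assumed; `LogCompletionsCallback` and its parameters come from this diff.

```python
from transformers import GenerationConfig
from trl import LogCompletionsCallback

# Assumed: `trainer` is an existing TRL trainer instance, e.g. DPOTrainer(...).
generation_config = GenerationConfig(max_new_tokens=128, temperature=1.0)

completions_callback = LogCompletionsCallback(
    trainer=trainer,
    generation_config=generation_config,
    num_prompts=8,       # sample 8 prompts from the eval dataset
    freq=500,            # log every 500 steps instead of the trainer's eval_steps
    max_new_tokens=50,   # generation kwarg: overrides the config's 128
    temperature=0.8,     # generation kwarg: overrides the config's 1.0
)
trainer.add_callback(completions_callback)
```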