diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py
index a530e38b0c4..9885c872d97 100644
--- a/trl/trainer/callbacks.py
+++ b/trl/trainer/callbacks.py
@@ -66,6 +66,7 @@ def _generate_completions(
     accelerator: Accelerator,
     generation_config: GenerationConfig | None,
     batch_size: int = 1,
+    **generation_kwargs,
 ) -> list[str]:
     """
     Generates completions for a list of pre-formatted prompts from the given model.
@@ -77,12 +78,14 @@ def _generate_completions(
         accelerator (Accelerator): The accelerator to be used for model execution.
         generation_config (GenerationConfig): Configuration for text generation.
         batch_size (int, optional): The number of prompts to process in each batch. Default is 1.
+        **generation_kwargs: Additional keyword arguments forwarded to `model.generate`. These
+            override any matching fields in `generation_config` (e.g. `max_new_tokens=50`,
+            `temperature=0.8`).
 
     Returns:
         list[str]: A list of generated text completions corresponding to the input prompts.
     """
     completions = []
-    # TODO: Override model.generation_config with generation_kwargs
     with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
         for idx in range(0, len(prompts), batch_size):
             batch = prompts[idx : idx + batch_size]
@@ -90,6 +93,7 @@ def _generate_completions(
             generations = unwrapped_model.generate(
                 **tokenized_batch,
                 generation_config=generation_config,
+                **generation_kwargs,
             )
             for prompt, generation in zip(tokenized_batch.input_ids, generations, strict=True):
                 # Remove prompt from generation
@@ -258,7 +262,7 @@ class LogCompletionsCallback(TrainerCallback):
     Usage:
     ```python
     trainer = DPOTrainer(...)
-    completions_callback = LogCompletionsCallback(trainer=trainer)
+    completions_callback = LogCompletionsCallback(trainer=trainer, max_new_tokens=50)
     trainer.add_callback(completions_callback)
     ```
 
@@ -273,6 +277,9 @@ class LogCompletionsCallback(TrainerCallback):
             the evaluation dataset.
         freq (`int`, *optional*):
             The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`.
+        **generation_kwargs:
+            Additional keyword arguments forwarded to `model.generate`. These override any matching fields in
+            `generation_config` (e.g. `max_new_tokens=50`, `temperature=0.8`).
     """
 
     def __init__(
@@ -281,9 +288,11 @@ def __init__(
         generation_config: GenerationConfig | None = None,
         num_prompts: int | None = None,
         freq: int | None = None,
+        **generation_kwargs,
     ):
         self.trainer = trainer
         self.generation_config = generation_config
+        self.generation_kwargs = generation_kwargs
         self.freq = freq
         self.table = []
         self._last_logged_step = -1
@@ -319,6 +328,7 @@ def on_step_end(self, args, state, control, **kwargs):
             accelerator=accelerator,
             generation_config=self.generation_config,
             batch_size=args.per_device_eval_batch_size,
+            **self.generation_kwargs,
         )
         completions = gather_object(completions)
         prompts = gather_object(prompts)
@@ -405,6 +415,9 @@ def accuracy_scorer(prompt: str, completion: str) -> float:
             Name for the dataset metadata in Weave.
         model_name (`str`, *optional*):
             Name for the model metadata in Weave. If not provided, attempts to extract from model config.
+        **generation_kwargs:
+            Additional keyword arguments forwarded to `model.generate`. These override any matching fields in
+            `generation_config` (e.g. `max_new_tokens=50`, `temperature=0.8`).
     """
 
     def __init__(
@@ -416,11 +429,13 @@ def __init__(
         num_prompts: int | None = None,
         dataset_name: str = "eval_dataset",
         model_name: str | None = None,
+        **generation_kwargs,
     ):
         self.trainer = trainer
         self.project_name = project_name
         self.scorers = scorers or {}
         self.generation_config = generation_config
+        self.generation_kwargs = generation_kwargs
         self.dataset_name = dataset_name
         self.model_name = model_name
         self._last_logged_step = -1
@@ -503,6 +518,7 @@ def on_evaluate(self, args, state, control, **kwargs):
             accelerator=accelerator,
             generation_config=self.generation_config,
             batch_size=args.per_device_eval_batch_size,
+            **self.generation_kwargs,
         )
 
         all_prompts = gather_object(prompts)
@@ -513,6 +529,7 @@ def on_evaluate(self, args, state, control, **kwargs):
             "training_step": state.global_step,
             "model_name": self.model_name,
             "generation_config": (self.generation_config.to_dict() if self.generation_config else None),
+            "generation_kwargs": self.generation_kwargs,
         }
 
         eval_logger = self._EvaluationLogger(