feat: Add generation_kwargs support to LogCompletionsCallback and Wea…#5625

Open
LhaseParth2610 wants to merge 3 commits into huggingface:main from LhaseParth2610:add-generation-kwargs-to-callbacks

LhaseParth2610 commented Apr 22, 2026

What does this PR do?

Resolves a TODO in _generate_completions (callbacks.py) that originally requested generation_kwargs support.

Before this PR, LogCompletionsCallback and WeaveCallback only accepted a full GenerationConfig object for controlling generation. This meant users couldn't pass simple, common keyword arguments like max_new_tokens=50 or temperature=0.8 directly to the callbacks.

This PR introduces **generation_kwargs to the following methods:

  • _generate_completions() (the core utility)
  • LogCompletionsCallback.__init__()
  • WeaveCallback.__init__()

These kwargs are stored on the callback and forwarded to model.generate(). Before forwarding, any keys that conflict with explicitly passed arguments (such as generation_config or the tokenized batch keys) are filtered out. This prevents a TypeError for duplicate keyword arguments ("got multiple values for argument") and keeps the call consistent with standard transformers behavior.
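The filtering step described above could look roughly like the following. This is a minimal sketch, not the code in this PR: `drop_conflicting_keys` and `fake_generate` are hypothetical names introduced for illustration, and `fake_generate` is only a stand-in for model.generate() so the example runs without loading a model.

```python
from typing import Any


def drop_conflicting_keys(
    generation_kwargs: dict[str, Any],
    reserved_keys: set[str],
) -> dict[str, Any]:
    """Return a copy of generation_kwargs without keys that are passed explicitly elsewhere.

    Hypothetical helper for illustration; not part of TRL or transformers.
    """
    return {k: v for k, v in generation_kwargs.items() if k not in reserved_keys}


def fake_generate(input_ids, attention_mask=None, generation_config=None, **kwargs):
    """Stand-in for model.generate(); it only echoes the extra kwargs it received."""
    return kwargs


tokenized_batch = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
# The user accidentally includes "input_ids", which is already in the tokenized batch.
generation_kwargs = {"max_new_tokens": 50, "temperature": 0.8, "input_ids": [[9]]}

# Keys supplied explicitly at the call site must not appear a second time.
reserved = set(tokenized_batch) | {"generation_config"}
safe_kwargs = drop_conflicting_keys(generation_kwargs, reserved)

received = fake_generate(**tokenized_batch, generation_config=None, **safe_kwargs)
print(received)  # {'max_new_tokens': 50, 'temperature': 0.8}
```

Silently dropping a conflicting key is one possible design. Raising an error or emitting a warning instead would make a misconfigured callback easier to notice.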

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

Low Risk
Low risk: this only threads optional **generation_kwargs through completion generation and logging, but could change logged outputs if users pass overrides that differ from the provided generation_config.

Overview
Adds support for passing **generation_kwargs through _generate_completions() and into model.generate(), allowing callers to override GenerationConfig fields via simple keyword args.

Extends LogCompletionsCallback and WeaveCallback to accept and store these kwargs, forwards them during completion generation, and records them in Weave eval_attributes for traceability.

Reviewed by Cursor Bugbot for commit 61468de. Bugbot is set up for automated code reviews on this repo.

…veCallback

Resolves an outstanding TODO in _generate_completions. This change allows users to pass simple generation keyword arguments (like max_new_tokens=50 or temperature=0.8) directly to the callbacks without needing to construct a full GenerationConfig object.
Comment thread trl/trainer/callbacks.py Outdated
Comment thread trl/trainer/callbacks.py
LhaseParth2610 force-pushed the add-generation-kwargs-to-callbacks branch from 2b2f461 to 08b1df0 on April 22, 2026 13:59

cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Comment thread trl/trainer/callbacks.py
generations = unwrapped_model.generate(
    **tokenized_batch,
    generation_config=generation_config,
    **generation_kwargs,

Unfiltered kwargs risk TypeError in model.generate call

Low Severity

**generation_kwargs and **tokenized_batch are both unpacked into model.generate() without any key-conflict filtering. If generation_kwargs contains a key also present in tokenized_batch (e.g., input_ids, attention_mask), Python raises a TypeError for duplicate keyword arguments at the call site. The rest of the codebase avoids this by folding generation kwargs into a GenerationConfig object (via GenerationConfig(**generation_kwargs)) rather than passing them as raw **kwargs to generate().
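The collision described in this comment can be reproduced without a model. The sketch below assumes nothing about TRL internals: `stub_generate` is a hypothetical stand-in for model.generate(), and the dictionaries are made-up example values.

```python
def stub_generate(input_ids=None, attention_mask=None, generation_config=None, **kwargs):
    """Hypothetical stand-in for model.generate(); returns the extra kwargs it received."""
    return kwargs


tokenized_batch = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
# "attention_mask" also appears in tokenized_batch, so both unpackings supply it.
generation_kwargs = {"max_new_tokens": 50, "attention_mask": [[0, 0, 0]]}

collided = False
message = ""
try:
    stub_generate(**tokenized_batch, generation_config=None, **generation_kwargs)
except TypeError as err:
    # Python rejects the call before the function body ever runs.
    collided = True
    message = str(err)

print(collided)
print(message)
```

The error is raised at the call site, which is why the guard has to run before the kwargs are unpacked rather than inside generate().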
