
support prefetch/prefetch_depth for async GRPO for ~5% speedups #5602

Open
winglian wants to merge 1 commit into huggingface:main from winglian:feat/prefetch-rollout-dataset

Conversation

Contributor

@winglian commented Apr 20, 2026

What does this PR do?

Rather than having the async loader grab only one step ahead of the trainer, we can prefetch further and run slightly off-policy. For Qwen3 1.7B, this reduces queue latency by ~50% and improves per-step time at completion_len=512 by ~4.5%.
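
A hedged usage sketch: the use_prefetch and prefetch_depth flags are the ones introduced by this PR, while the import path and the remaining arguments are assumptions for illustration.

    # Hedged sketch: use_prefetch / prefetch_depth come from this PR; the import path
    # and other arguments are assumptions.
    from trl import AsyncGRPOConfig

    config = AsyncGRPOConfig(
        output_dir="qwen3-1.7b-async-grpo",
        max_completion_length=512,  # matches the completion_len=512 setting above
        use_prefetch=True,          # opt in to the prefetching rollout dataset
        prefetch_depth=2,           # how many batches to buffer ahead of the trainer
    )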

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

Medium Risk
Adds a background thread and new buffering behavior to the async GRPO training dataloader, which can impact training stability and sample staleness if misconfigured. Thread lifecycle/cleanup is handled but concurrency and timeout edge cases increase risk vs. a purely synchronous queue read.

Overview
Adds optional rollout batch prefetching to AsyncGRPOTrainer to overlap rollout-sample collection with training and reduce queue wait time.

Introduces new AsyncGRPOConfig flags use_prefetch and prefetch_depth, plus a PrefetchRolloutDataset that uses a background thread to pre-collect samples_per_step samples into a bounded prefetch queue; the trainer selects this dataset when enabled and stops the thread on training teardown. Also refactors RolloutQueueDataset to use a reusable _pull_one() with deadline-based timeouts and quieter stale-sample dropping.
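
A minimal sketch of the prefetch pattern this describes (PrefetchRolloutDataset, stop(), _pull_one(), and the bounded queue come from the overview; all other names, signatures, and timeouts are assumptions, not the PR's actual code):

    import queue
    import threading
    import time


    class PrefetchRolloutDataset:
        """Pre-collects rollout samples in a background thread into a bounded queue."""

        def __init__(self, rollout_queue, samples_per_step, prefetch_depth=2, pull_timeout=60.0):
            self._rollout_queue = rollout_queue              # fed by the async rollout workers
            self._samples_per_step = samples_per_step
            self._prefetch_queue = queue.Queue(maxsize=prefetch_depth)
            self._pull_timeout = pull_timeout
            self._stop_event = threading.Event()
            self._thread = threading.Thread(target=self._prefetch_loop, daemon=True)
            self._thread.start()

        def _pull_one(self, deadline):
            # Deadline-based timeout: each pull only waits as long as the step budget allows.
            remaining = max(0.0, deadline - time.monotonic())
            return self._rollout_queue.get(timeout=remaining)

        def _prefetch_loop(self):
            while not self._stop_event.is_set():
                deadline = time.monotonic() + self._pull_timeout
                try:
                    batch = [self._pull_one(deadline) for _ in range(self._samples_per_step)]
                except queue.Empty:
                    continue  # rollout workers were idle; retry with a fresh deadline
                try:
                    self._prefetch_queue.put(batch, timeout=5.0)
                except queue.Full:
                    pass  # trainer fell behind; see the Bugbot comment below about logging this

        def __iter__(self):
            while not self._stop_event.is_set():
                try:
                    yield self._prefetch_queue.get(timeout=1.0)
                except queue.Empty:
                    continue

        def stop(self):
            # Called by the trainer on teardown to shut the background thread down.
            self._stop_event.set()
            self._thread.join(timeout=5.0)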

Reviewed by Cursor Bugbot for commit 0b8e30a. Bugbot is set up for automated code reviews on this repo.


@cursor (bot) left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.



    return super()._inner_training_loop(*args, **kwargs)
finally:
    if hasattr(self, "_prefetch_dataset"):
        self._prefetch_dataset.stop()

Use of hasattr violates project simplicity rules

Low Severity

hasattr(self, "_prefetch_dataset") violates the AGENTS.md rule that explicitly says to avoid hasattr and getattr, calling them "almost always a symptom of overly defensive programming." The cleaner alternative is to initialize self._prefetch_dataset = None in __init__ (near line 442 where other instance variables are set) and then check if self._prefetch_dataset is not None in the finally block.
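
A hedged sketch of the suggested pattern; the stand-in class below is purely illustrative and elides the real trainer's __init__ and training loop.

    class _TrainerSketch:
        def __init__(self):
            self._prefetch_dataset = None  # always defined; assigned only when use_prefetch is on

        def _inner_training_loop(self):
            try:
                pass  # training loop body elided
            finally:
                if self._prefetch_dataset is not None:  # explicit None check instead of hasattr
                    self._prefetch_dataset.stop()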


Triggered by project rule: ../.ai/AGENTS.md


try:
    self._prefetch_queue.put(batch, timeout=5.0)
except queue.Full:
    pass

Prefetch loop silently drops collected batches without logging

Medium Severity

In _prefetch_loop, when _prefetch_queue.put raises queue.Full, the collected batch is silently discarded via pass. These samples were already consumed from the rollout queue (which required GPU compute for generation and scoring) and are permanently lost. At minimum this warrants a logger.warning so users have visibility into data being dropped, especially since this can mask configuration issues with prefetch_depth.
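
A hedged sketch of the suggested change; the helper name and logger setup are illustrative, not the PR's actual code.

    import logging
    import queue

    logger = logging.getLogger(__name__)

    def put_or_warn(prefetch_queue, batch, timeout=5.0):
        # Enqueue a prefetched batch, warning instead of silently dropping it on a full queue.
        try:
            prefetch_queue.put(batch, timeout=timeout)
        except queue.Full:
            logger.warning(
                "Prefetch queue full after %.1fs; dropping a collected rollout batch. "
                "Consider raising prefetch_depth or checking trainer throughput.",
                timeout,
            )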


