support prefetch/prefetch_depth for async GRPO for ~5% speedups #5602
winglian wants to merge 1 commit into huggingface:main from
Conversation
Cursor Bugbot has reviewed your changes and found 2 potential issues.
Reviewed by Cursor Bugbot for commit 0b8e30a.
```python
    return super()._inner_training_loop(*args, **kwargs)
finally:
    if hasattr(self, "_prefetch_dataset"):
        self._prefetch_dataset.stop()
```
Use of hasattr violates project simplicity rules
Low Severity
`hasattr(self, "_prefetch_dataset")` violates the AGENTS.md rule that explicitly says to avoid `hasattr` and `getattr`, calling them "almost always a symptom of overly defensive programming." The cleaner alternative is to initialize `self._prefetch_dataset = None` in `__init__` (near line 442, where other instance variables are set) and then check `if self._prefetch_dataset is not None` in the `finally` block.
Triggered by project rule: ../.ai/AGENTS.md
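A minimal sketch of that alternative, with `BaseTrainer` standing in for the real parent class and the loop body elided:

```python
class BaseTrainer:  # stand-in for the real parent class
    def _inner_training_loop(self, *args, **kwargs):
        pass  # the actual training loop lives here

class AsyncGRPOTrainer(BaseTrainer):
    def __init__(self):
        # Initialize alongside the other instance variables so the
        # attribute always exists and no hasattr check is needed.
        self._prefetch_dataset = None

    def _inner_training_loop(self, *args, **kwargs):
        try:
            return super()._inner_training_loop(*args, **kwargs)
        finally:
            # Explicit None check instead of hasattr, per AGENTS.md.
            if self._prefetch_dataset is not None:
                self._prefetch_dataset.stop()
```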
```python
try:
    self._prefetch_queue.put(batch, timeout=5.0)
except queue.Full:
    pass
```
Prefetch loop silently drops collected batches without logging
Medium Severity
In `_prefetch_loop`, when `_prefetch_queue.put` raises `queue.Full`, the collected batch is silently discarded via `pass`. These samples were already consumed from the rollout queue (which required GPU compute for generation and scoring) and are permanently lost. At minimum this warrants a `logger.warning` so users have visibility into data being dropped, especially since this can mask configuration issues with `prefetch_depth`.
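A sketch of the suggested change, assuming a module-level `logger`; `put_with_warning` is a hypothetical helper extracted for illustration, not code from the PR:

```python
import logging
import queue

logger = logging.getLogger(__name__)

def put_with_warning(prefetch_queue: queue.Queue, batch: list) -> None:
    """Enqueue a collected batch, making any drop visible instead of silent."""
    try:
        prefetch_queue.put(batch, timeout=5.0)
    except queue.Full:
        # These samples already cost GPU compute to generate and score,
        # so surface the loss rather than masking it with `pass`.
        logger.warning(
            "Prefetch queue full; dropping %d collected samples. "
            "Consider increasing prefetch_depth.",
            len(batch),
        )
```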


What does this PR do?
Rather than having the async loader grab only one step ahead of the trainer, we can prefetch further and run slightly off-policy. For Qwen3 1.7B, this reduces queue latency by ~50% and improves per-step time for that model size at completion_len=512 by ~4.5%. A hedged usage sketch follows, based on the flags named in the overview below; the import path and surrounding arguments are assumptions, not verified against the diff.
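```python
from trl import AsyncGRPOConfig  # import path assumed, not verified

config = AsyncGRPOConfig(
    use_prefetch=True,   # opt in to the background prefetch thread
    prefetch_depth=2,    # how many batches ahead the loader may buffer
    # ...the usual async GRPO arguments...
)
```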
Fixes # (issue)
Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
Medium Risk
Adds a background thread and new buffering behavior to the async GRPO training dataloader, which can impact training stability and sample staleness if misconfigured. Thread lifecycle and cleanup are handled, but concurrency and timeout edge cases increase risk versus a purely synchronous queue read.
Overview
Adds optional rollout-batch prefetching to `AsyncGRPOTrainer` to overlap rollout-sample collection with training and reduce queue wait time. Introduces new `AsyncGRPOConfig` flags `use_prefetch` and `prefetch_depth`, plus a `PrefetchRolloutDataset` that uses a background thread to pre-collect `samples_per_step` samples into a bounded prefetch queue; the trainer selects this dataset when enabled and stops the thread on training teardown. Also refactors `RolloutQueueDataset` to use a reusable `_pull_one()` with deadline-based timeouts and quieter stale-sample dropping.
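Since the implementation itself isn't shown on this page, here is a minimal sketch of the described design under stated assumptions: a background thread fills a bounded queue, and `_pull_one()` uses a deadline-based timeout. Every name beyond those in the overview is illustrative, not taken from the PR.

```python
import queue
import threading
import time

class PrefetchRolloutDataset:
    """Background thread pre-collects rollout samples into a bounded queue
    so the trainer rarely blocks on the rollout workers."""

    def __init__(self, rollout_queue: queue.Queue, samples_per_step: int,
                 prefetch_depth: int = 2, pull_timeout: float = 30.0):
        self._rollout_queue = rollout_queue
        self._samples_per_step = samples_per_step
        self._pull_timeout = pull_timeout
        # Bounded: the loader can run at most prefetch_depth batches ahead.
        self._prefetch_queue: queue.Queue = queue.Queue(maxsize=prefetch_depth)
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._prefetch_loop, daemon=True)
        self._thread.start()

    def _pull_one(self, deadline: float):
        """Pull a single sample, honoring an absolute monotonic deadline."""
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise queue.Empty
        return self._rollout_queue.get(timeout=remaining)

    def _prefetch_loop(self):
        batch = []
        while not self._stop_event.is_set():
            deadline = time.monotonic() + self._pull_timeout
            try:
                while len(batch) < self._samples_per_step:
                    batch.append(self._pull_one(deadline))
            except queue.Empty:
                continue  # deadline hit; keep the partial batch and retry
            try:
                self._prefetch_queue.put(batch, timeout=5.0)
                batch = []
            except queue.Full:
                pass  # still full; retry next loop (the diff drops here instead)

    def __iter__(self):
        while not self._stop_event.is_set():
            yield self._prefetch_queue.get()

    def stop(self):
        self._stop_event.set()
        self._thread.join(timeout=5.0)
```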