-
Notifications
You must be signed in to change notification settings - Fork 2.7k
support prefetch/prefetch_depth for async GRPO for ~5% speedups #5602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| import math | ||
| import queue | ||
| import textwrap | ||
| import threading | ||
| import time | ||
| from collections import defaultdict | ||
| from collections.abc import Callable, Iterator | ||
|
|
@@ -84,34 +85,128 @@ def __init__(self, rollout_queue, model_version_fn, max_staleness=3, timeout=120 | |
| self.max_staleness = max_staleness | ||
| self.timeout = timeout | ||
|
|
||
| def __iter__(self): | ||
| def _pull_one(self, deadline=None): | ||
| """Pull a single non-stale sample from the rollout queue. | ||
|
|
||
| Blocks until a valid sample is available or the deadline expires. Stale samples (whose model | ||
| version lags the current version by more than ``max_staleness``) are silently dropped. | ||
|
|
||
| Args: | ||
| deadline (`float`, *optional*): | ||
| Absolute ``time.time()`` cutoff. When ``None``, a single attempt with ``self.timeout`` | ||
| is made instead. | ||
|
|
||
| Returns: | ||
| A dict with keys ``input_ids``, ``completion_mask``, ``old_log_probs``, ``advantage``, and | ||
| ``metrics``, or ``None`` if the deadline/timeout expired. | ||
| """ | ||
| while True: | ||
| remaining = max(1.0, deadline - time.time()) if deadline else self.timeout | ||
| t0 = time.time() | ||
| qsize = self.queue.qsize() | ||
| if qsize == 0: | ||
| logger.info("queue empty, waiting for rollout samples...") | ||
| try: | ||
| sample = self.queue.get(timeout=self.timeout) | ||
| sample = self.queue.get(timeout=min(remaining, 2.0)) | ||
| except queue.Empty: | ||
| logger.warning(f"Rollout queue empty for {self.timeout}s, stopping epoch") | ||
| return # StopIteration ends epoch | ||
| queue_wait_time_s = time.time() - t0 | ||
| if queue_wait_time_s > 1.0: | ||
| logger.info(f"waited {queue_wait_time_s:.1f}s for sample (qsize={self.queue.qsize()})") | ||
| if deadline and time.time() >= deadline: | ||
| return None | ||
| if not deadline: | ||
| return None | ||
| continue | ||
| wait = time.time() - t0 | ||
|
|
||
| staleness = self.model_version_fn() - sample.model_version | ||
| if staleness > self.max_staleness: | ||
| logger.info(f"dropping stale sample (staleness={staleness}, max={self.max_staleness})") | ||
| continue # drop stale, pull next | ||
| logger.debug(f"dropping stale sample (staleness={staleness})") | ||
| continue | ||
|
|
||
| yield { | ||
| return { | ||
| "input_ids": sample.input_ids, | ||
| "completion_mask": sample.completion_mask, | ||
| "old_log_probs": sample.old_log_probs, | ||
| "advantage": sample.advantage, | ||
| "metrics": {**sample.metrics, "queue_wait_time_s": queue_wait_time_s}, | ||
| "metrics": {**sample.metrics, "queue_wait_time_s": wait}, | ||
| } | ||
|
|
||
| def __iter__(self): | ||
| while True: | ||
| sample = self._pull_one(deadline=time.time() + self.timeout) | ||
| if sample is None: | ||
| logger.warning(f"Rollout queue empty for {self.timeout}s, stopping epoch") | ||
| return | ||
| yield sample | ||
|
|
||
|
|
||
| class PrefetchRolloutDataset(RolloutQueueDataset): | ||
| """Extends ``RolloutQueueDataset`` with background-thread batch prefetching. | ||
|
|
||
| A background thread continuously calls ``_pull_one`` to batch-collect ``samples_per_step`` | ||
| samples and places them in a bounded prefetch queue. The training loop then drains pre-collected | ||
| batches with near-zero wait, overlapping queue collection with gradient computation. | ||
|
|
||
| Args: | ||
| rollout_queue: The queue of scored rollout samples from the ``AsyncRolloutWorker``. | ||
| model_version_fn: Callable returning the current model version for staleness filtering. | ||
| samples_per_step (`int`): | ||
| Number of samples to collect per training step | ||
| (``per_device_train_batch_size * gradient_accumulation_steps * num_processes``). | ||
| max_staleness (`int`, *optional*, defaults to `3`): | ||
| Maximum model version lag before a sample is dropped. | ||
| timeout (`float`, *optional*, defaults to `120.0`): | ||
| Seconds to wait for each individual sample. | ||
| prefetch_depth (`int`, *optional*, defaults to `1`): | ||
| Number of batches to prefetch ahead. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, rollout_queue, model_version_fn, samples_per_step, max_staleness=3, timeout=120.0, prefetch_depth=1 | ||
| ): | ||
| super().__init__(rollout_queue, model_version_fn, max_staleness, timeout) | ||
| self.samples_per_step = samples_per_step | ||
| self.prefetch_depth = prefetch_depth | ||
| self._prefetch_queue = queue.Queue(maxsize=prefetch_depth) | ||
| self._stop_event = threading.Event() | ||
| self._thread = threading.Thread(target=self._prefetch_loop, daemon=True) | ||
| self._thread.start() | ||
|
|
||
| def _collect_batch(self): | ||
| """Collect ``samples_per_step`` samples from the rollout queue, filtering stale ones.""" | ||
| samples = [] | ||
| deadline = time.time() + self.timeout | ||
| while len(samples) < self.samples_per_step: | ||
| if self._stop_event.is_set(): | ||
| return None | ||
| sample = self._pull_one(deadline=deadline) | ||
| if sample is None: | ||
| break | ||
| samples.append(sample) | ||
| return samples or None | ||
|
|
||
| def _prefetch_loop(self): | ||
| """Background thread: continuously collect batches and enqueue them.""" | ||
| while not self._stop_event.is_set(): | ||
| batch = self._collect_batch() | ||
| if batch is None: | ||
| if self._stop_event.is_set(): | ||
| break | ||
| continue | ||
| try: | ||
| self._prefetch_queue.put(batch, timeout=5.0) | ||
| except queue.Full: | ||
| pass | ||
|
|
||
| def __iter__(self): | ||
| while True: | ||
| try: | ||
| batch = self._prefetch_queue.get(timeout=self.timeout) | ||
| except queue.Empty: | ||
| logger.warning("Prefetch queue empty, stopping epoch") | ||
| return | ||
| yield from batch | ||
|
|
||
| def stop(self): | ||
| """Stop the prefetch background thread.""" | ||
| self._stop_event.set() | ||
| self._thread.join(timeout=5.0) | ||
|
|
||
|
|
||
| class _EmptyIterableDataset(torch.utils.data.IterableDataset): | ||
| """Placeholder for non-rank-0 processes. Never actually iterated.""" | ||
|
|
@@ -395,12 +490,28 @@ def __init__( | |
|
|
||
| def get_train_dataloader(self) -> DataLoader: | ||
| if self.accelerator.is_main_process: | ||
| dataset = RolloutQueueDataset( | ||
| rollout_queue=self.rollout_queue, | ||
| model_version_fn=lambda: self.model_version, | ||
| max_staleness=self.args.max_staleness, | ||
| timeout=self.args.vllm_server_timeout, | ||
| ) | ||
| if self.args.use_prefetch: | ||
| samples_per_step = ( | ||
| self.args.per_device_train_batch_size | ||
| * self.args.gradient_accumulation_steps | ||
| * self.accelerator.num_processes | ||
| ) | ||
| dataset = PrefetchRolloutDataset( | ||
| rollout_queue=self.rollout_queue, | ||
| model_version_fn=lambda: self.model_version, | ||
| samples_per_step=samples_per_step, | ||
| max_staleness=self.args.max_staleness, | ||
| timeout=self.args.vllm_server_timeout, | ||
| prefetch_depth=self.args.prefetch_depth, | ||
| ) | ||
| self._prefetch_dataset = dataset | ||
| else: | ||
| dataset = RolloutQueueDataset( | ||
| rollout_queue=self.rollout_queue, | ||
| model_version_fn=lambda: self.model_version, | ||
| max_staleness=self.args.max_staleness, | ||
| timeout=self.args.vllm_server_timeout, | ||
| ) | ||
| else: | ||
| dataset = _EmptyIterableDataset() | ||
|
|
||
|
|
@@ -619,5 +730,7 @@ def _inner_training_loop(self, *args, **kwargs): | |
| try: | ||
| return super()._inner_training_loop(*args, **kwargs) | ||
| finally: | ||
| if hasattr(self, "_prefetch_dataset"): | ||
| self._prefetch_dataset.stop() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use of
|
||
| if self.accelerator.is_main_process and self.rollout_worker: | ||
| self.rollout_worker.stop() | ||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefetch loop silently drops collected batches without logging
Medium Severity
In
_prefetch_loop, when_prefetch_queue.putraisesqueue.Full, the collected batch is silently discarded viapass. These samples were already consumed from the rollout queue (which required GPU compute for generation and scoring) and are permanently lost. At minimum this warrants alogger.warningso users have visibility into data being dropped, especially since this can mask configuration issues withprefetch_depth.Reviewed by Cursor Bugbot for commit 0b8e30a. Configure here.