21 changes: 21 additions & 0 deletions trl/experimental/async_grpo/async_grpo_config.py
@@ -73,6 +73,13 @@ class AsyncGRPOConfig(_BaseConfig):
Maximum number of rollout samples to buffer in the rollout queue.
weight_sync_steps (`int`, *optional*, defaults to `1`):
Number of training steps between weight synchronizations to the vLLM server.
use_prefetch (`bool`, *optional*, defaults to `False`):
Whether to prefetch rollout samples in a background thread. When enabled, the thread
batch-collects `samples_per_step` samples while the previous training step runs, removing
queue wait time from the critical path.
prefetch_depth (`int`, *optional*, defaults to `1`):
Number of batches to prefetch ahead when `use_prefetch=True`. Higher values keep training
saturated but increase off-policy staleness.

> Parameters that control the logging

@@ -184,6 +191,20 @@ class AsyncGRPOConfig(_BaseConfig):
default=1,
metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."},
)
use_prefetch: bool = field(
default=False,
metadata={
"help": "Use background-thread prefetch for the rollout queue (no DataProducer dependency). "
"Batch-collects samples_per_step samples in a background thread while training."
},
)
prefetch_depth: int = field(
default=1,
metadata={
"help": "Number of batches to prefetch ahead when use_prefetch=True. "
"Higher values keep training saturated but increase off-policy staleness."
},
)

# Parameters that control the logging
log_completions: bool = field(
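For illustration, enabling the new options could look like the following minimal sketch. The import path mirrors the file location above and assumes the package re-exports the class; `output_dir` stands in for whatever base-config arguments apply:

```python
from trl.experimental.async_grpo import AsyncGRPOConfig

# Illustrative values; only use_prefetch and prefetch_depth are new here.
config = AsyncGRPOConfig(
    output_dir="outputs",  # placeholder base-config argument
    use_prefetch=True,     # collect rollout batches in a background thread
    prefetch_depth=2,      # prefetch up to two batches ahead of training
)
```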
153 changes: 133 additions & 20 deletions trl/experimental/async_grpo/async_grpo_trainer.py
@@ -16,6 +16,7 @@
import math
import queue
import textwrap
import threading
import time
from collections import defaultdict
from collections.abc import Callable, Iterator
@@ -84,34 +85,128 @@ def __init__(self, rollout_queue, model_version_fn, max_staleness=3, timeout=120
self.max_staleness = max_staleness
self.timeout = timeout

def __iter__(self):
def _pull_one(self, deadline=None):
"""Pull a single non-stale sample from the rollout queue.

Blocks until a valid sample is available or the deadline expires. Stale samples (whose model
version lags the current version by more than ``max_staleness``) are silently dropped.

Args:
deadline (`float`, *optional*):
Absolute ``time.time()`` cutoff. When ``None``, a single ``queue.get`` attempt (capped at
2 seconds) is made instead.

Returns:
A dict with keys ``input_ids``, ``completion_mask``, ``old_log_probs``, ``advantage``, and
``metrics``, or ``None`` if the deadline/timeout expired.
"""
while True:
remaining = max(1.0, deadline - time.time()) if deadline else self.timeout
t0 = time.time()
qsize = self.queue.qsize()
if qsize == 0:
logger.info("queue empty, waiting for rollout samples...")
try:
sample = self.queue.get(timeout=self.timeout)
sample = self.queue.get(timeout=min(remaining, 2.0))
except queue.Empty:
logger.warning(f"Rollout queue empty for {self.timeout}s, stopping epoch")
return # StopIteration ends epoch
queue_wait_time_s = time.time() - t0
if queue_wait_time_s > 1.0:
logger.info(f"waited {queue_wait_time_s:.1f}s for sample (qsize={self.queue.qsize()})")
if deadline and time.time() >= deadline:
return None
if not deadline:
return None
continue
wait = time.time() - t0

staleness = self.model_version_fn() - sample.model_version
if staleness > self.max_staleness:
logger.info(f"dropping stale sample (staleness={staleness}, max={self.max_staleness})")
continue # drop stale, pull next
logger.debug(f"dropping stale sample (staleness={staleness})")
continue

yield {
return {
"input_ids": sample.input_ids,
"completion_mask": sample.completion_mask,
"old_log_probs": sample.old_log_probs,
"advantage": sample.advantage,
"metrics": {**sample.metrics, "queue_wait_time_s": queue_wait_time_s},
"metrics": {**sample.metrics, "queue_wait_time_s": wait},
}

def __iter__(self):
while True:
sample = self._pull_one(deadline=time.time() + self.timeout)
if sample is None:
logger.warning(f"Rollout queue empty for {self.timeout}s, stopping epoch")
return
yield sample


class PrefetchRolloutDataset(RolloutQueueDataset):
"""Extends ``RolloutQueueDataset`` with background-thread batch prefetching.

A background thread continuously calls ``_pull_one`` to batch-collect ``samples_per_step``
samples and places them in a bounded prefetch queue. The training loop then drains pre-collected
batches with near-zero wait, overlapping queue collection with gradient computation.

Args:
rollout_queue: The queue of scored rollout samples from the ``AsyncRolloutWorker``.
model_version_fn: Callable returning the current model version for staleness filtering.
samples_per_step (`int`):
Number of samples to collect per training step
(``per_device_train_batch_size * gradient_accumulation_steps * num_processes``).
max_staleness (`int`, *optional*, defaults to `3`):
Maximum model version lag before a sample is dropped.
timeout (`float`, *optional*, defaults to `120.0`):
Seconds to wait for each individual sample.
prefetch_depth (`int`, *optional*, defaults to `1`):
Number of batches to prefetch ahead.
"""

def __init__(
self, rollout_queue, model_version_fn, samples_per_step, max_staleness=3, timeout=120.0, prefetch_depth=1
):
super().__init__(rollout_queue, model_version_fn, max_staleness, timeout)
self.samples_per_step = samples_per_step
self.prefetch_depth = prefetch_depth
self._prefetch_queue = queue.Queue(maxsize=prefetch_depth)
self._stop_event = threading.Event()
self._thread = threading.Thread(target=self._prefetch_loop, daemon=True)
self._thread.start()

def _collect_batch(self):
"""Collect ``samples_per_step`` samples from the rollout queue, filtering stale ones."""
samples = []
deadline = time.time() + self.timeout
while len(samples) < self.samples_per_step:
if self._stop_event.is_set():
return None
sample = self._pull_one(deadline=deadline)
if sample is None:
break
samples.append(sample)
return samples or None

def _prefetch_loop(self):
"""Background thread: continuously collect batches and enqueue them."""
while not self._stop_event.is_set():
batch = self._collect_batch()
if batch is None:
if self._stop_event.is_set():
break
continue
try:
self._prefetch_queue.put(batch, timeout=5.0)
except queue.Full:
pass

Prefetch loop silently drops collected batches without logging

Medium Severity

In _prefetch_loop, when _prefetch_queue.put raises queue.Full, the collected batch is silently discarded via pass. These samples were already consumed from the rollout queue (which required GPU compute for generation and scoring) and are permanently lost. At minimum this warrants a logger.warning so users have visibility into data being dropped, especially since this can mask configuration issues with prefetch_depth.
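A minimal sketch of the suggested fix, reusing only names already present in the diff:

```python
try:
    self._prefetch_queue.put(batch, timeout=5.0)
except queue.Full:
    # Make the drop visible: these samples already cost rollout GPU time.
    logger.warning(
        f"Prefetch queue full (depth={self.prefetch_depth}); dropping {len(batch)} collected samples"
    )
```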


Reviewed by Cursor Bugbot for commit 0b8e30a.


def __iter__(self):
while True:
try:
batch = self._prefetch_queue.get(timeout=self.timeout)
except queue.Empty:
logger.warning("Prefetch queue empty, stopping epoch")
return
yield from batch

def stop(self):
"""Stop the prefetch background thread."""
self._stop_event.set()
self._thread.join(timeout=5.0)
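
For orientation, constructing the new dataset standalone might look like this sketch; the queue, version counter, and numeric values are hypothetical stand-ins for what `get_train_dataloader` wires up below:

```python
import queue

rollout_queue = queue.Queue(maxsize=64)  # hypothetical; normally filled by the AsyncRolloutWorker
model_version = 0                        # hypothetical counter, bumped on each weight sync

dataset = PrefetchRolloutDataset(
    rollout_queue=rollout_queue,
    model_version_fn=lambda: model_version,
    samples_per_step=32,  # per_device_train_batch_size * gradient_accumulation_steps * num_processes
    prefetch_depth=2,
)
try:
    for sample in dataset:  # dicts with input_ids, completion_mask, old_log_probs, advantage, metrics
        ...
finally:
    dataset.stop()  # always stop the background prefetch thread
```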


class _EmptyIterableDataset(torch.utils.data.IterableDataset):
"""Placeholder for non-rank-0 processes. Never actually iterated."""
@@ -395,12 +490,28 @@ def __init__(

def get_train_dataloader(self) -> DataLoader:
if self.accelerator.is_main_process:
dataset = RolloutQueueDataset(
rollout_queue=self.rollout_queue,
model_version_fn=lambda: self.model_version,
max_staleness=self.args.max_staleness,
timeout=self.args.vllm_server_timeout,
)
if self.args.use_prefetch:
samples_per_step = (
self.args.per_device_train_batch_size
* self.args.gradient_accumulation_steps
* self.accelerator.num_processes
)
dataset = PrefetchRolloutDataset(
rollout_queue=self.rollout_queue,
model_version_fn=lambda: self.model_version,
samples_per_step=samples_per_step,
max_staleness=self.args.max_staleness,
timeout=self.args.vllm_server_timeout,
prefetch_depth=self.args.prefetch_depth,
)
self._prefetch_dataset = dataset
else:
dataset = RolloutQueueDataset(
rollout_queue=self.rollout_queue,
model_version_fn=lambda: self.model_version,
max_staleness=self.args.max_staleness,
timeout=self.args.vllm_server_timeout,
)
else:
dataset = _EmptyIterableDataset()
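
As a concrete check of the `samples_per_step` computation above: with `per_device_train_batch_size=4`, `gradient_accumulation_steps=2`, and 4 processes (illustrative values), each prefetched batch holds 4 × 2 × 4 = 32 samples.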

@@ -619,5 +730,7 @@ def _inner_training_loop(self, *args, **kwargs):
try:
return super()._inner_training_loop(*args, **kwargs)
finally:
if hasattr(self, "_prefetch_dataset"):
self._prefetch_dataset.stop()

Use of hasattr violates project simplicity rules

Low Severity

hasattr(self, "_prefetch_dataset") violates the AGENTS.md rule that explicitly says to avoid hasattr and getattr, calling them "almost always a symptom of overly defensive programming." The cleaner alternative is to initialize self._prefetch_dataset = None in __init__ (near line 442 where other instance variables are set) and then check if self._prefetch_dataset is not None in the finally block.
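A sketch of the proposed fix (exact `__init__` placement approximate):

```python
# In __init__, alongside the other instance attributes:
self._prefetch_dataset = None

# In _inner_training_loop's finally block:
if self._prefetch_dataset is not None:
    self._prefetch_dataset.stop()
```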


Triggered by project rule: ../.ai/AGENTS.md

Reviewed by Cursor Bugbot for commit 0b8e30a.

if self.accelerator.is_main_process and self.rollout_worker:
self.rollout_worker.stop()