diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index 2afd760e7fc..04c7e38533f 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -73,6 +73,13 @@ class AsyncGRPOConfig(_BaseConfig): Maximum number of rollout samples to buffer in the rollout queue. weight_sync_steps (`int`, *optional*, defaults to `1`): Number of training steps between weight synchronizations to the vLLM server. + use_prefetch (`bool`, *optional*, defaults to `False`): + Use a background-thread prefetch for the rollout queue. When enabled, a background thread + batch-collects `samples_per_step` samples while the previous training step runs, eliminating + queue wait time from the critical path. + prefetch_depth (`int`, *optional*, defaults to `1`): + Number of batches to prefetch ahead when `use_prefetch=True`. Higher values keep training + saturated but increase off-policy staleness. > Parameters that control the logging @@ -184,6 +191,20 @@ class AsyncGRPOConfig(_BaseConfig): default=1, metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."}, ) + use_prefetch: bool = field( + default=False, + metadata={ + "help": "Use background-thread prefetch for the rollout queue (no DataProducer dependency). " + "Batch-collects samples_per_step samples in a background thread while training." + }, + ) + prefetch_depth: int = field( + default=1, + metadata={ + "help": "Number of batches to prefetch ahead when use_prefetch=True. " + "Higher values keep training saturated but increase off-policy staleness." + }, + ) # Parameters that control the logging log_completions: bool = field( diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index a81dad5639f..a004510fa23 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -16,6 +16,7 @@ import math import queue import textwrap +import threading import time from collections import defaultdict from collections.abc import Callable, Iterator @@ -84,34 +85,128 @@ def __init__(self, rollout_queue, model_version_fn, max_staleness=3, timeout=120 self.max_staleness = max_staleness self.timeout = timeout - def __iter__(self): + def _pull_one(self, deadline=None): + """Pull a single non-stale sample from the rollout queue. + + Blocks until a valid sample is available or the deadline expires. Stale samples (whose model + version lags the current version by more than ``max_staleness``) are silently dropped. + + Args: + deadline (`float`, *optional*): + Absolute ``time.time()`` cutoff. When ``None``, a single attempt with ``self.timeout`` + is made instead. + + Returns: + A dict with keys ``input_ids``, ``completion_mask``, ``old_log_probs``, ``advantage``, and + ``metrics``, or ``None`` if the deadline/timeout expired. + """ while True: + remaining = max(1.0, deadline - time.time()) if deadline else self.timeout t0 = time.time() - qsize = self.queue.qsize() - if qsize == 0: - logger.info("queue empty, waiting for rollout samples...") try: - sample = self.queue.get(timeout=self.timeout) + sample = self.queue.get(timeout=min(remaining, 2.0)) except queue.Empty: - logger.warning(f"Rollout queue empty for {self.timeout}s, stopping epoch") - return # StopIteration ends epoch - queue_wait_time_s = time.time() - t0 - if queue_wait_time_s > 1.0: - logger.info(f"waited {queue_wait_time_s:.1f}s for sample (qsize={self.queue.qsize()})") + if deadline and time.time() >= deadline: + return None + if not deadline: + return None + continue + wait = time.time() - t0 staleness = self.model_version_fn() - sample.model_version if staleness > self.max_staleness: - logger.info(f"dropping stale sample (staleness={staleness}, max={self.max_staleness})") - continue # drop stale, pull next + logger.debug(f"dropping stale sample (staleness={staleness})") + continue - yield { + return { "input_ids": sample.input_ids, "completion_mask": sample.completion_mask, "old_log_probs": sample.old_log_probs, "advantage": sample.advantage, - "metrics": {**sample.metrics, "queue_wait_time_s": queue_wait_time_s}, + "metrics": {**sample.metrics, "queue_wait_time_s": wait}, } + def __iter__(self): + while True: + sample = self._pull_one(deadline=time.time() + self.timeout) + if sample is None: + logger.warning(f"Rollout queue empty for {self.timeout}s, stopping epoch") + return + yield sample + + +class PrefetchRolloutDataset(RolloutQueueDataset): + """Extends ``RolloutQueueDataset`` with background-thread batch prefetching. + + A background thread continuously calls ``_pull_one`` to batch-collect ``samples_per_step`` + samples and places them in a bounded prefetch queue. The training loop then drains pre-collected + batches with near-zero wait, overlapping queue collection with gradient computation. + + Args: + rollout_queue: The queue of scored rollout samples from the ``AsyncRolloutWorker``. + model_version_fn: Callable returning the current model version for staleness filtering. + samples_per_step (`int`): + Number of samples to collect per training step + (``per_device_train_batch_size * gradient_accumulation_steps * num_processes``). + max_staleness (`int`, *optional*, defaults to `3`): + Maximum model version lag before a sample is dropped. + timeout (`float`, *optional*, defaults to `120.0`): + Seconds to wait for each individual sample. + prefetch_depth (`int`, *optional*, defaults to `1`): + Number of batches to prefetch ahead. + """ + + def __init__( + self, rollout_queue, model_version_fn, samples_per_step, max_staleness=3, timeout=120.0, prefetch_depth=1 + ): + super().__init__(rollout_queue, model_version_fn, max_staleness, timeout) + self.samples_per_step = samples_per_step + self.prefetch_depth = prefetch_depth + self._prefetch_queue = queue.Queue(maxsize=prefetch_depth) + self._stop_event = threading.Event() + self._thread = threading.Thread(target=self._prefetch_loop, daemon=True) + self._thread.start() + + def _collect_batch(self): + """Collect ``samples_per_step`` samples from the rollout queue, filtering stale ones.""" + samples = [] + deadline = time.time() + self.timeout + while len(samples) < self.samples_per_step: + if self._stop_event.is_set(): + return None + sample = self._pull_one(deadline=deadline) + if sample is None: + break + samples.append(sample) + return samples or None + + def _prefetch_loop(self): + """Background thread: continuously collect batches and enqueue them.""" + while not self._stop_event.is_set(): + batch = self._collect_batch() + if batch is None: + if self._stop_event.is_set(): + break + continue + try: + self._prefetch_queue.put(batch, timeout=5.0) + except queue.Full: + pass + + def __iter__(self): + while True: + try: + batch = self._prefetch_queue.get(timeout=self.timeout) + except queue.Empty: + logger.warning("Prefetch queue empty, stopping epoch") + return + yield from batch + + def stop(self): + """Stop the prefetch background thread.""" + self._stop_event.set() + self._thread.join(timeout=5.0) + class _EmptyIterableDataset(torch.utils.data.IterableDataset): """Placeholder for non-rank-0 processes. Never actually iterated.""" @@ -395,12 +490,28 @@ def __init__( def get_train_dataloader(self) -> DataLoader: if self.accelerator.is_main_process: - dataset = RolloutQueueDataset( - rollout_queue=self.rollout_queue, - model_version_fn=lambda: self.model_version, - max_staleness=self.args.max_staleness, - timeout=self.args.vllm_server_timeout, - ) + if self.args.use_prefetch: + samples_per_step = ( + self.args.per_device_train_batch_size + * self.args.gradient_accumulation_steps + * self.accelerator.num_processes + ) + dataset = PrefetchRolloutDataset( + rollout_queue=self.rollout_queue, + model_version_fn=lambda: self.model_version, + samples_per_step=samples_per_step, + max_staleness=self.args.max_staleness, + timeout=self.args.vllm_server_timeout, + prefetch_depth=self.args.prefetch_depth, + ) + self._prefetch_dataset = dataset + else: + dataset = RolloutQueueDataset( + rollout_queue=self.rollout_queue, + model_version_fn=lambda: self.model_version, + max_staleness=self.args.max_staleness, + timeout=self.args.vllm_server_timeout, + ) else: dataset = _EmptyIterableDataset() @@ -619,5 +730,7 @@ def _inner_training_loop(self, *args, **kwargs): try: return super()._inner_training_loop(*args, **kwargs) finally: + if hasattr(self, "_prefetch_dataset"): + self._prefetch_dataset.stop() if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.stop()