From e584b51717e9294e4eb84ed41740cd4e0561759d Mon Sep 17 00:00:00 2001
From: rafapi
Date: Thu, 16 Oct 2025 18:30:22 +0000
Subject: [PATCH 01/20] Simple curriculum

---
 pipelinerl/actor.py | 90 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 88 insertions(+), 2 deletions(-)

diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index 9ffabe7e..bd6badce 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -14,7 +14,7 @@
 import aiohttp
 import hydra
 import uvloop
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from pydantic import BaseModel, Field
 from tapeagents.llms import TrainableLLM
 from typing import Dict, List
@@ -260,6 +260,80 @@ def random_iter(problems: list):
         yield random.sample(problems, 1)[0]


+def curriculum_iter(
+    problems: list,
+    trainer_state: TrainerState,
+    curriculum_cfg: DictConfig,
+    logger: logging.Logger | None = None,
+):
+    curriculum_obj = OmegaConf.to_container(curriculum_cfg, resolve=True) if isinstance(curriculum_cfg, DictConfig) else curriculum_cfg
+    base_names = set(curriculum_obj.get("base_datasets", []))
+    hard_names = set(curriculum_obj.get("hard_datasets", []))
+    if hard_names and not base_names:
+        base_names = {problem.get("dataset") for problem in problems if problem.get("dataset") not in hard_names}
+
+    base_pool = [
+        problem
+        for problem in problems
+        if (problem.get("dataset") in base_names) or (not base_names and problem.get("dataset") not in hard_names)
+    ]
+    hard_pool = [problem for problem in problems if problem.get("dataset") in hard_names]
+
+    if not hard_pool:
+        if logger:
+            logger.warning(
+                "Curriculum enabled but no problems matched hard_datasets list; falling back to base sampling"
+            )
+        yield from random_iter(problems)
+        return
+
+    if not base_pool:
+        if logger:
+            logger.warning("Curriculum enabled but base pool is empty; sampling exclusively from hard dataset")
+        base_pool = hard_pool
+
+    schedule_cfg = curriculum_obj.get("schedule", [])
+    if not schedule_cfg:
+        schedule = [(0, 0.0)]
+    else:
+        if not isinstance(schedule_cfg, list):
+            schedule_cfg = [schedule_cfg]
+        schedule = []
+        for entry in schedule_cfg:
+            step = int(entry.get("step", 0))
+            hard_weight = float(entry.get("hard_weight", 0.0))
+            hard_weight = max(0.0, min(1.0, hard_weight))
+            schedule.append((step, hard_weight))
+        schedule.sort(key=lambda item: item[0])
+
+    current_stage = -1
+
+    while True:
+        samples_processed = trainer_state.samples_processed or 0
+        hard_weight = schedule[0][1]
+        stage_index = 0
+        for idx, (step, weight) in enumerate(schedule):
+            if samples_processed >= step:
+                hard_weight = weight
+                stage_index = idx
+            else:
+                break
+
+        if logger and stage_index != current_stage:
+            logger.info(
+                "Curriculum stage %d active (samples_processed=%d, hard_weight=%.3f)",
+                stage_index,
+                samples_processed,
+                hard_weight,
+            )
+            current_stage = stage_index
+
+        if hard_pool and random.random() < hard_weight:
+            yield random.choice(hard_pool)
+        else:
+            yield random.choice(base_pool)
+
+
 def sequential_iter(problems: list):
     for problem in problems:
         yield problem
@@ -384,7 +458,19 @@ def run(self, dataset: list[tuple[str, dict]]):
         # for train sample, sample random batches infinitely
         # for test samples, loop through the dataset once
         if self.is_training:
-            problem_iter = random_iter(dataset)
+            curriculum_cfg = getattr(self.cfg.actor, "curriculum", None)
+            use_curriculum = bool(
+                curriculum_cfg and getattr(curriculum_cfg, "enabled", False) and dataset
+            )
+            if use_curriculum:
+                problem_iter = curriculum_iter(
+                    dataset,
+
trainer_state=self.trainer_state, + curriculum_cfg=curriculum_cfg, + logger=logger, + ) + else: + problem_iter = random_iter(dataset) else: problem_iter = sequential_iter(dataset) assert self.trainer_state.propagated_weight_version is not None From efa0d02d1a5e67974dbaa7c283d175eece502f66 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 17 Oct 2025 17:32:51 +0000 Subject: [PATCH 02/20] Adaptive curriculum schedule --- pipelinerl/actor.py | 216 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 195 insertions(+), 21 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index bd6badce..39ce5bb7 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -7,7 +7,7 @@ from queue import Empty import random import time -from collections import defaultdict +from collections import defaultdict, deque from multiprocessing.managers import SharedMemoryManager from pathlib import Path @@ -17,7 +17,7 @@ from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field from tapeagents.llms import TrainableLLM -from typing import Dict, List +from typing import Dict, List, Optional import wandb from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb @@ -103,10 +103,100 @@ def get_stats(self): +class CurriculumSuccessTracker: + def __init__(self) -> None: + self._buffers: dict[str, deque[int]] = {} + self._max_windows: dict[str, int] = {} + self._total_counts: defaultdict[str, int] = defaultdict(int) + + def ensure_window(self, dataset: str, window: int) -> None: + if window <= 0: + window = 1 + current = self._max_windows.get(dataset, 0) + if window <= current: + return + existing = self._buffers.get(dataset, deque(maxlen=window)) + if existing.maxlen != window: + new_buffer = deque(existing, maxlen=window) + else: + new_buffer = existing + self._buffers[dataset] = new_buffer + self._max_windows[dataset] = window + + def update(self, dataset: str, success_values: list[int | bool]) -> None: + if not success_values: + return + buffer = self._buffers.get(dataset) + if buffer is None: + maxlen = self._max_windows.get(dataset, max(1, len(success_values))) + buffer = deque(maxlen=maxlen) + self._buffers[dataset] = buffer + self._max_windows[dataset] = maxlen + for value in success_values: + buffer.append(1 if bool(value) else 0) + self._total_counts[dataset] += 1 + + def success_mean(self, dataset: str, window: Optional[int] = None) -> Optional[float]: + buffer = self._buffers.get(dataset) + if buffer is None or not buffer: + return None + if window is None or window <= 0 or window >= len(buffer): + values = list(buffer) + else: + if len(buffer) < window: + return None + values = list(buffer)[-window:] + if not values: + return None + return sum(values) / len(values) + + def total_samples(self, dataset: str) -> int: + return self._total_counts.get(dataset, 0) + + + def make_stats_dict() -> dict: return defaultdict(lambda: defaultdict(list)) +def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: + raw_schedule = curriculum_cfg.get("schedule", []) + if not raw_schedule: + return [{"step": 0, "hard_weight": 0.0, "thresholds": []}] + if not isinstance(raw_schedule, list): + raw_schedule = [raw_schedule] + parsed_schedule: list[dict] = [] + for entry in raw_schedule: + step = int(entry.get("step", 0)) + hard_weight = float(entry.get("hard_weight", 0.0)) + hard_weight = max(0.0, min(1.0, hard_weight)) + thresholds_cfg = entry.get("success_thresholds", []) or [] + if not isinstance(thresholds_cfg, list): + thresholds_cfg = [thresholds_cfg] + thresholds: 
list[dict] = [] + for threshold_entry in thresholds_cfg: + dataset = threshold_entry.get("dataset") + if not dataset: + continue + threshold_value = float(threshold_entry.get("threshold", 1.0)) + window = int(threshold_entry.get("window", threshold_entry.get("window_size", 0) or 1)) + if window <= 0: + window = 1 + min_samples_value = threshold_entry.get("min_samples") + min_samples = int(min_samples_value) if min_samples_value is not None else None + thresholds.append( + { + "dataset": dataset, + "threshold": threshold_value, + "window": window, + "min_samples": min_samples, + } + ) + parsed_schedule.append({"step": step, "hard_weight": hard_weight, "thresholds": thresholds}) + parsed_schedule.sort(key=lambda item: item["step"]) + return parsed_schedule + + async def schedule_rollouts( cfg: DictConfig, attempts: int, @@ -265,8 +355,15 @@ def curriculum_iter( trainer_state: TrainerState, curriculum_cfg: DictConfig, logger: logging.Logger | None = None, + success_tracker: CurriculumSuccessTracker | None = None, + stage_state: Optional[dict] = None, + parsed_schedule: Optional[list[dict]] = None, ): - curriculum_obj = OmegaConf.to_container(curriculum_cfg, resolve=True) if isinstance(curriculum_cfg, DictConfig) else curriculum_cfg + curriculum_obj = ( + OmegaConf.to_container(curriculum_cfg, resolve=True) + if isinstance(curriculum_cfg, DictConfig) + else curriculum_cfg + ) base_names = set(curriculum_obj.get("base_datasets", [])) hard_names = set(curriculum_obj.get("hard_datasets", [])) if hard_names and not base_names: @@ -292,33 +389,82 @@ def curriculum_iter( logger.warning("Curriculum enabled but base pool is empty; sampling exclusively from hard dataset") base_pool = hard_pool - schedule_cfg = curriculum_obj.get("schedule", []) - if not schedule_cfg: - schedule = [(0, 0.0)] - else: - if not isinstance(schedule_cfg, list): - schedule_cfg = [schedule_cfg] - schedule = [] - for entry in schedule_cfg: - step = int(entry.get("step", 0)) - hard_weight = float(entry.get("hard_weight", 0.0)) - hard_weight = max(0.0, min(1.0, hard_weight)) - schedule.append((step, hard_weight)) - schedule.sort(key=lambda item: item[0]) + schedule = parsed_schedule or parse_curriculum_schedule(curriculum_obj) + if success_tracker: + for stage in schedule: + for threshold in stage["thresholds"]: + success_tracker.ensure_window(threshold["dataset"], threshold["window"]) + + def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: + if not stage_cfg["thresholds"] or success_tracker is None: + return True, [] + blockers: list[str] = [] + for threshold in stage_cfg["thresholds"]: + dataset = threshold["dataset"] + threshold_value = threshold["threshold"] + window = threshold["window"] + min_samples = threshold.get("min_samples") + if min_samples is not None: + total_samples = success_tracker.total_samples(dataset) + if total_samples < min_samples: + blockers.append( + f"{dataset}: waiting for {min_samples} samples (have {total_samples})" + ) + continue + success_mean_value = success_tracker.success_mean(dataset, window) + if success_mean_value is None: + blockers.append(f"{dataset}: insufficient window data (need {window})") + continue + if success_mean_value < threshold_value: + blockers.append( + f"{dataset}: success_mean {success_mean_value:.3f} < {threshold_value:.3f} (window={window})" + ) + return (len(blockers) == 0), blockers current_stage = -1 + last_block_log: tuple[int, tuple[str, ...]] | None = None + if stage_state is None: + stage_state = {"index": 0} while True: samples_processed = 
trainer_state.samples_processed or 0 - hard_weight = schedule[0][1] - stage_index = 0 - for idx, (step, weight) in enumerate(schedule): + desired_stage_index = 0 + hard_weight = schedule[0]["hard_weight"] + + for idx, stage_cfg in enumerate(schedule): + step = stage_cfg["step"] if samples_processed >= step: - hard_weight = weight - stage_index = idx + desired_stage_index = idx + hard_weight = stage_cfg["hard_weight"] else: break + stage_index = desired_stage_index + blocker_messages: list[str] = [] + while stage_index >= 0: + ready, blockers = stage_ready(schedule[stage_index]) + if ready: + blocker_messages = [] + break + blocker_messages = blockers + stage_index -= 1 + + if stage_index < 0: + stage_index = 0 + hard_weight = schedule[0]["hard_weight"] + else: + hard_weight = schedule[stage_index]["hard_weight"] + + if logger and desired_stage_index != stage_index and blocker_messages: + block_signature = (desired_stage_index, tuple(blocker_messages)) + if block_signature != last_block_log: + logger.info( + "Curriculum stage %d gated by: %s", + desired_stage_index, + "; ".join(blocker_messages), + ) + last_block_log = block_signature + if logger and stage_index != current_stage: logger.info( "Curriculum stage %d active (samples_processed=%d, hard_weight=%.3f)", @@ -328,6 +474,8 @@ def curriculum_iter( ) current_stage = stage_index + stage_state["index"] = stage_index + if hard_pool and random.random() < hard_weight: yield random.choice(hard_pool) else: @@ -359,6 +507,8 @@ def __init__( self.is_training = is_training self.is_scheduling_paused = False self.debug_mode = bool(cfg.debug.mode) + self.curriculum_tracker: CurriculumSuccessTracker | None = None + self.curriculum_stage_state: dict | None = None # Determine the number of processes to use num_processes = min(self.cfg.actor.rollout_workers, len(self.llms)) @@ -426,8 +576,12 @@ def update_stats(self, rollout_results: List[RolloutResult]): for k, v in all_metrics.items(): if isinstance(v, list): self.stats[k][dataset_name][group_id] += v + if k == "success" and self.curriculum_tracker: + self.curriculum_tracker.update(dataset_name, v) elif isinstance(v, float) | isinstance(v, bool) | isinstance(v, int): self.stats[k][dataset_name][group_id].append(v) + if k == "success" and self.curriculum_tracker: + self.curriculum_tracker.update(dataset_name, [v]) else: raise ValueError(f"Unsupported metric type: {type(v)} for key {k}") @@ -463,16 +617,34 @@ def run(self, dataset: list[tuple[str, dict]]): curriculum_cfg and getattr(curriculum_cfg, "enabled", False) and dataset ) if use_curriculum: + curriculum_obj = ( + OmegaConf.to_container(curriculum_cfg, resolve=True) + if isinstance(curriculum_cfg, DictConfig) + else curriculum_cfg + ) + parsed_schedule = parse_curriculum_schedule(curriculum_obj) + self.curriculum_tracker = CurriculumSuccessTracker() + for stage in parsed_schedule: + for threshold in stage["thresholds"]: + self.curriculum_tracker.ensure_window(threshold["dataset"], threshold["window"]) + self.curriculum_stage_state = {"index": 0} problem_iter = curriculum_iter( dataset, trainer_state=self.trainer_state, curriculum_cfg=curriculum_cfg, logger=logger, + success_tracker=self.curriculum_tracker, + stage_state=self.curriculum_stage_state, + parsed_schedule=parsed_schedule, ) else: problem_iter = random_iter(dataset) + self.curriculum_tracker = None + self.curriculum_stage_state = None else: problem_iter = sequential_iter(dataset) + self.curriculum_tracker = None + self.curriculum_stage_state = None assert 
self.trainer_state.propagated_weight_version is not None last_trainer_version = self.trainer_state.propagated_weight_version @@ -633,6 +805,8 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): stats |= loop_stats for k, v in self.sliding_stats.items(): stats[k] = sum(v) / len(v) if v else 0 + if self.curriculum_stage_state is not None: + stats["curriculum_stage_active"] = self.curriculum_stage_state.get("index", 0) if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) From edb72d5c06bae33526c5db5036a517f9b535dff3 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 17 Oct 2025 19:00:22 +0000 Subject: [PATCH 03/20] Fix staging --- pipelinerl/actor.py | 59 ++++++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 39ce5bb7..7d9100e1 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -429,41 +429,66 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: while True: samples_processed = trainer_state.samples_processed or 0 desired_stage_index = 0 - hard_weight = schedule[0]["hard_weight"] for idx, stage_cfg in enumerate(schedule): step = stage_cfg["step"] if samples_processed >= step: desired_stage_index = idx - hard_weight = stage_cfg["hard_weight"] else: break - stage_index = desired_stage_index - blocker_messages: list[str] = [] - while stage_index >= 0: - ready, blockers = stage_ready(schedule[stage_index]) + current_stage = int(stage_state.get("index", 0)) + if current_stage < 0: + current_stage = 0 + if current_stage >= len(schedule): + current_stage = len(schedule) - 1 + + stage_index = min(current_stage, desired_stage_index) + promotion_blockers: list[str] = [] + + # Walk backwards until the current stage is ready (or we reach stage 0) + while stage_index > 0: + ready, _ = stage_ready(schedule[stage_index]) if ready: - blocker_messages = [] break - blocker_messages = blockers stage_index -= 1 - if stage_index < 0: - stage_index = 0 - hard_weight = schedule[0]["hard_weight"] - else: - hard_weight = schedule[stage_index]["hard_weight"] + ready, current_blockers = stage_ready(schedule[stage_index]) + if not ready and stage_index > 0: + # If even after walking back we are not ready, fall back further until 0 + while stage_index > 0 and not ready: + stage_index -= 1 + ready, current_blockers = stage_ready(schedule[stage_index]) + + # Attempt to promote by at most one stage towards the desired stage + if stage_index < desired_stage_index: + next_index = stage_index + 1 + next_ready, blockers = stage_ready(schedule[next_index]) + if next_ready: + stage_index = next_index + current_blockers = [] + else: + promotion_blockers = blockers + + hard_weight = schedule[stage_index]["hard_weight"] + + blockers_for_log: list[str] = [] + block_stage: int | None = None + if stage_index < desired_stage_index: + blockers_for_log = promotion_blockers or current_blockers + block_stage = stage_index + 1 - if logger and desired_stage_index != stage_index and blocker_messages: - block_signature = (desired_stage_index, tuple(blocker_messages)) + if logger and block_stage is not None and blockers_for_log: + block_signature = (block_stage, tuple(blockers_for_log)) if block_signature != last_block_log: logger.info( "Curriculum stage %d gated by: %s", - desired_stage_index, - "; ".join(blocker_messages), + block_stage, + "; ".join(blockers_for_log), ) last_block_log = block_signature + elif stage_index >= desired_stage_index: + 
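+            # back at (or beyond) the desired stage: clear the remembered blocker signature so any future gating is logged again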
last_block_log = None if logger and stage_index != current_stage: logger.info( From 2438af8264e1822ff3eb0e877ca476df7976ea2c Mon Sep 17 00:00:00 2001 From: rafapi Date: Sat, 18 Oct 2025 18:46:17 +0000 Subject: [PATCH 04/20] Fix overlong_penalty call --- pipelinerl/domains/math/rollouts.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 4da0b753..862375d0 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -192,7 +192,11 @@ async def generate_math_rollout( reward *= discount_factor**llm_call.output_length_tokens overlong_penalty = 0 if reward_table.buffer_tokens > 0: - overlong_penalty = length_penalty(llm.parameters['max_tokens'], llm_call.output_length_tokens, rewards.buffer_tokens) + overlong_penalty = length_penalty( + llm.parameters["max_tokens"], + llm_call.output_length_tokens, + reward_table.buffer_tokens, + ) reward += overlong_penalty trace.reward = reward From 984c28491d11e02c31e45698e73f3f25e6362b60 Mon Sep 17 00:00:00 2001 From: rafapi Date: Sat, 18 Oct 2025 18:49:50 +0000 Subject: [PATCH 05/20] Compute num tokens in result --- pipelinerl/actor.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 93d083de..be5a5b58 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -52,6 +52,22 @@ def save_debug_line(data:dict): with open(fname, "a") as f: f.write(json.dumps(data, ensure_ascii=False) + "\n") +def get_number_of_tokens_in_result(result: RolloutResult) -> int: + """Aggregate prompt + output tokens for all training texts in a rollout result.""" + total_tokens = 0 + for training_text in result.training_texts: + prompt_tokens = getattr(training_text, "prompt_tokens", 0) or 0 + output_tokens = getattr(training_text, "output_tokens", 0) or 0 + if prompt_tokens or output_tokens: + total_tokens += prompt_tokens + output_tokens + continue + input_ids = getattr(training_text, "input_ids", None) + if input_ids: + total_tokens += len(input_ids) + continue + total_tokens += getattr(training_text, "n_predicted", 0) or 0 + return total_tokens + class SlidingWindowData(BaseModel): prompt_tokens_window: list[list[int]] = Field( default_factory=list, @@ -153,12 +169,12 @@ def success_mean(self, dataset: str, window: Optional[int] = None) -> Optional[f buffer = self._buffers.get(dataset) if buffer is None or not buffer: return None - if window is None or window <= 0 or window >= len(buffer): + if window is None or window <= 0: values = list(buffer) else: - if len(buffer) < window: - return None values = list(buffer)[-window:] + if len(values) < window: + return None if not values: return None return sum(values) / len(values) From 72baea6f82a9134b4ec186df1141b217af64e3a0 Mon Sep 17 00:00:00 2001 From: rafapi Date: Sun, 19 Oct 2025 11:28:18 +0000 Subject: [PATCH 06/20] Implement sliding stats for smooth harden, just during training --- pipelinerl/actor.py | 147 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 121 insertions(+), 26 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index be5a5b58..64fbcf2c 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -199,6 +199,19 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: step = int(entry.get("step", 0)) hard_weight = float(entry.get("hard_weight", 0.0)) hard_weight = max(0.0, min(1.0, hard_weight)) + medium_weight_value = entry.get("medium_weight", 0.0) + try: 
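+            # medium_weight may be missing or non-numeric; the except branch below falls back to 0.0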
+ medium_weight = float(medium_weight_value) + except (TypeError, ValueError): + medium_weight = 0.0 + medium_weight = max(0.0, min(1.0, medium_weight)) + demotion_patience_value = entry.get("demotion_patience", 1) + try: + demotion_patience = int(demotion_patience_value) + except (TypeError, ValueError): + demotion_patience = 1 + if demotion_patience < 1: + demotion_patience = 1 thresholds_cfg = entry.get("success_thresholds", []) or [] if not isinstance(thresholds_cfg, list): thresholds_cfg = [thresholds_cfg] @@ -221,7 +234,15 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: "min_samples": min_samples, } ) - parsed_schedule.append({"step": step, "hard_weight": hard_weight, "thresholds": thresholds}) + parsed_schedule.append( + { + "step": step, + "hard_weight": hard_weight, + "medium_weight": medium_weight, + "thresholds": thresholds, + "demotion_patience": demotion_patience, + } + ) parsed_schedule.sort(key=lambda item: item["step"]) return parsed_schedule @@ -404,15 +425,26 @@ def curriculum_iter( else curriculum_cfg ) base_names = set(curriculum_obj.get("base_datasets", [])) + medium_names = set(curriculum_obj.get("medium_datasets", [])) hard_names = set(curriculum_obj.get("hard_datasets", [])) if hard_names and not base_names: - base_names = {problem.get("dataset") for problem in problems if problem.get("dataset") not in hard_names} + base_names = { + problem.get("dataset") + for problem in problems + if problem.get("dataset") not in hard_names and problem.get("dataset") not in medium_names + } base_pool = [ problem for problem in problems - if (problem.get("dataset") in base_names) or (not base_names and problem.get("dataset") not in hard_names) + if (problem.get("dataset") in base_names) + or ( + not base_names + and problem.get("dataset") not in hard_names + and (not medium_names or problem.get("dataset") not in medium_names) + ) ] + medium_pool = [problem for problem in problems if problem.get("dataset") in medium_names] hard_pool = [problem for problem in problems if problem.get("dataset") in hard_names] if not hard_pool: @@ -423,10 +455,19 @@ def curriculum_iter( yield from random_iter(problems) return + if medium_names and not medium_pool and logger: + logger.warning( + "Curriculum medium_datasets specified but no problems matched; medium weighting will be ignored" + ) + if not base_pool: if logger: - logger.warning("Curriculum enabled but base pool is empty; sampling exclusively from hard dataset") - base_pool = hard_pool + logger.warning("Curriculum enabled but base pool is empty; falling back to medium or hard datasets") + if medium_pool: + base_pool = list(medium_pool) + medium_pool = [] + else: + base_pool = hard_pool schedule = parsed_schedule or parse_curriculum_schedule(curriculum_obj) if success_tracker: @@ -434,10 +475,11 @@ def curriculum_iter( for threshold in stage["thresholds"]: success_tracker.ensure_window(threshold["dataset"], threshold["window"]) - def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: + def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: if not stage_cfg["thresholds"] or success_tracker is None: - return True, [] + return True, [], False blockers: list[str] = [] + threshold_blocked = False for threshold in stage_cfg["thresholds"]: dataset = threshold["dataset"] threshold_value = threshold["threshold"] @@ -455,15 +497,17 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: blockers.append(f"{dataset}: insufficient window data (need {window})") continue if success_mean_value < threshold_value: + 
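+                # a genuine success-rate shortfall (not just missing warm-up data); this flag also feeds the demotion-patience logic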
threshold_blocked = True blockers.append( f"{dataset}: success_mean {success_mean_value:.3f} < {threshold_value:.3f} (window={window})" ) - return (len(blockers) == 0), blockers + return (len(blockers) == 0), blockers, threshold_blocked current_stage = -1 last_block_log: tuple[int, tuple[str, ...]] | None = None if stage_state is None: stage_state = {"index": 0} + stage_state.setdefault("consecutive_failures", {}) while True: samples_processed = trainer_state.samples_processed or 0 @@ -481,41 +525,76 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: current_stage = 0 if current_stage >= len(schedule): current_stage = len(schedule) - 1 + prev_stage = current_stage stage_index = min(current_stage, desired_stage_index) promotion_blockers: list[str] = [] # Walk backwards until the current stage is ready (or we reach stage 0) while stage_index > 0: - ready, _ = stage_ready(schedule[stage_index]) + ready, _, _ = stage_ready(schedule[stage_index]) if ready: break stage_index -= 1 - ready, current_blockers = stage_ready(schedule[stage_index]) + ready, current_blockers, _ = stage_ready(schedule[stage_index]) if not ready and stage_index > 0: # If even after walking back we are not ready, fall back further until 0 while stage_index > 0 and not ready: stage_index -= 1 - ready, current_blockers = stage_ready(schedule[stage_index]) + ready, current_blockers, _ = stage_ready(schedule[stage_index]) # Attempt to promote by at most one stage towards the desired stage if stage_index < desired_stage_index: next_index = stage_index + 1 - next_ready, blockers = stage_ready(schedule[next_index]) + next_ready, blockers, _ = stage_ready(schedule[next_index]) if next_ready: stage_index = next_index current_blockers = [] else: promotion_blockers = blockers - - hard_weight = schedule[stage_index]["hard_weight"] + promotion_block_stage: int | None = None + promotion_blockers_for_log: list[str] = [] + if stage_index < desired_stage_index: + promotion_block_stage = stage_index + 1 + promotion_blockers_for_log = promotion_blockers or current_blockers blockers_for_log: list[str] = [] block_stage: int | None = None - if stage_index < desired_stage_index: - blockers_for_log = promotion_blockers or current_blockers - block_stage = stage_index + 1 + failure_counts: dict[int, int] = stage_state.setdefault("consecutive_failures", {}) + demotion_cancelled = False + if prev_stage > stage_index: + _, prev_blockers, prev_threshold_blocked = stage_ready(schedule[prev_stage]) + patience = schedule[prev_stage].get("demotion_patience", 1) + if prev_threshold_blocked and patience > 1: + failures = failure_counts.get(prev_stage, 0) + 1 + if failures < patience: + failure_counts[prev_stage] = failures + stage_index = prev_stage + demotion_cancelled = True + block_stage = prev_stage + blockers_for_log = prev_blockers + else: + failure_counts[prev_stage] = 0 + else: + failure_counts[prev_stage] = 0 + else: + failure_counts.setdefault(prev_stage, 0) + failure_counts[prev_stage] = 0 + + if not demotion_cancelled: + failure_counts.setdefault(stage_index, 0) + if stage_index != prev_stage: + failure_counts[stage_index] = 0 + + hard_weight = schedule[stage_index]["hard_weight"] + medium_weight = schedule[stage_index].get("medium_weight", 0.0) + if not medium_pool: + medium_weight = 0.0 + + if block_stage is None and promotion_block_stage is not None: + block_stage = promotion_block_stage + blockers_for_log = promotion_blockers_for_log if logger and block_stage is not None and blockers_for_log: block_signature = (block_stage, 
tuple(blockers_for_log)) @@ -540,8 +619,11 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: stage_state["index"] = stage_index - if hard_pool and random.random() < hard_weight: + choice = random.random() + if hard_pool and choice < hard_weight: yield random.choice(hard_pool) + elif medium_pool and choice < hard_weight + medium_weight: + yield random.choice(medium_pool) else: yield random.choice(base_pool) @@ -564,7 +646,11 @@ def __init__( self.data_stream = data_stream self.trainer_state = trainer_state self.stats_stream = stats_stream - self.sliding_aggregator = SlidingWindowAggregator(window_size=cfg.actor.throughput_window_size) + self.sliding_aggregator = None + if is_training: + self.sliding_aggregator = SlidingWindowAggregator( + window_size=cfg.actor.throughput_window_size + ) self.llms = llms self.loop_start_time = -1 self.cfg: DictConfig = cfg @@ -655,13 +741,22 @@ def update_stats(self, rollout_results: List[RolloutResult]): else: raise ValueError(f"Unsupported metric type: {type(v)} for key {k}") - prompt_length_tokens = [training_text.prompt_tokens for result in rollout_results for training_text in result.training_texts] - output_length_tokens = [training_text.output_tokens for result in rollout_results for training_text in result.training_texts] - self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens) - sliding_window_stats = self.sliding_aggregator.get_stats() - if sliding_window_stats is not None: - for k, v in sliding_window_stats.items(): - self.sliding_stats[k].append(v) + if self.sliding_aggregator: + prompt_length_tokens = [ + training_text.prompt_tokens + for result in rollout_results + for training_text in result.training_texts + ] + output_length_tokens = [ + training_text.output_tokens + for result in rollout_results + for training_text in result.training_texts + ] + self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens) + sliding_window_stats = self.sliding_aggregator.get_stats() + if sliding_window_stats is not None: + for k, v in sliding_window_stats.items(): + self.sliding_stats[k].append(v) From e7c46d543be916f6aac92457448c3a325ed7a018 Mon Sep 17 00:00:00 2001 From: rafapi Date: Mon, 20 Oct 2025 12:17:10 +0000 Subject: [PATCH 07/20] Improve cooling --- pipelinerl/actor.py | 187 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 174 insertions(+), 13 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 64fbcf2c..e309bc20 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -194,6 +194,22 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: return [{"step": 0, "hard_weight": 0.0, "thresholds": []}] if not isinstance(raw_schedule, list): raw_schedule = [raw_schedule] + default_cooldown = curriculum_cfg.get("default_promotion_cooldown_samples", 8000) + try: + default_cooldown = int(default_cooldown) + except (TypeError, ValueError): + default_cooldown = 8000 + if default_cooldown < 0: + default_cooldown = 0 + + default_hysteresis = curriculum_cfg.get("default_threshold_hysteresis", 0.02) + try: + default_hysteresis = float(default_hysteresis) + except (TypeError, ValueError): + default_hysteresis = 0.02 + if default_hysteresis < 0.0: + default_hysteresis = 0.0 + parsed_schedule: list[dict] = [] for entry in raw_schedule: step = int(entry.get("step", 0)) @@ -205,6 +221,12 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: except (TypeError, ValueError): medium_weight = 0.0 medium_weight = max(0.0, min(1.0, medium_weight)) + weight_sum = 
medium_weight + hard_weight + max_non_base = 0.85 + if weight_sum > max_non_base and weight_sum > 0: + scale = max_non_base / weight_sum + medium_weight *= scale + hard_weight *= scale demotion_patience_value = entry.get("demotion_patience", 1) try: demotion_patience = int(demotion_patience_value) @@ -212,6 +234,21 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: demotion_patience = 1 if demotion_patience < 1: demotion_patience = 1 + cooldown_value = entry.get( + "promotion_cooldown_samples", entry.get("cooldown_samples", default_cooldown) + ) + try: + promotion_cooldown_samples = int(cooldown_value) + except (TypeError, ValueError): + promotion_cooldown_samples = default_cooldown + if promotion_cooldown_samples < 0: + promotion_cooldown_samples = 0 + hysteresis_value = entry.get("threshold_hysteresis", default_hysteresis) + try: + threshold_hysteresis = float(hysteresis_value) + except (TypeError, ValueError): + threshold_hysteresis = default_hysteresis + threshold_hysteresis = max(0.0, threshold_hysteresis) thresholds_cfg = entry.get("success_thresholds", []) or [] if not isinstance(thresholds_cfg, list): thresholds_cfg = [thresholds_cfg] @@ -241,6 +278,8 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: "medium_weight": medium_weight, "thresholds": thresholds, "demotion_patience": demotion_patience, + "promotion_cooldown_samples": promotion_cooldown_samples, + "threshold_hysteresis": threshold_hysteresis, } ) parsed_schedule.sort(key=lambda item: item["step"]) @@ -475,39 +514,91 @@ def curriculum_iter( for threshold in stage["thresholds"]: success_tracker.ensure_window(threshold["dataset"], threshold["window"]) - def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: + def stage_ready(stage_cfg: dict, relaxation: float = 0.0) -> tuple[bool, list[str], bool, list[dict]]: if not stage_cfg["thresholds"] or success_tracker is None: - return True, [], False + return True, [], False, [] blockers: list[str] = [] threshold_blocked = False + stats: list[dict] = [] for threshold in stage_cfg["thresholds"]: dataset = threshold["dataset"] threshold_value = threshold["threshold"] window = threshold["window"] min_samples = threshold.get("min_samples") + total_samples = success_tracker.total_samples(dataset) if min_samples is not None: - total_samples = success_tracker.total_samples(dataset) if total_samples < min_samples: blockers.append( f"{dataset}: waiting for {min_samples} samples (have {total_samples})" ) + stats.append( + { + "dataset": dataset, + "success_mean": None, + "threshold": threshold_value, + "relaxation": relaxation, + "window": window, + "min_samples": min_samples, + "total_samples": total_samples, + "status": "min_samples", + } + ) continue success_mean_value = success_tracker.success_mean(dataset, window) if success_mean_value is None: blockers.append(f"{dataset}: insufficient window data (need {window})") + stats.append( + { + "dataset": dataset, + "success_mean": None, + "threshold": threshold_value, + "relaxation": relaxation, + "window": window, + "min_samples": min_samples, + "total_samples": total_samples, + "status": "insufficient_window", + } + ) continue - if success_mean_value < threshold_value: + adjusted_threshold = threshold_value - relaxation + if success_mean_value < adjusted_threshold: threshold_blocked = True blockers.append( - f"{dataset}: success_mean {success_mean_value:.3f} < {threshold_value:.3f} (window={window})" + f"{dataset}: success_mean {success_mean_value:.3f} < {adjusted_threshold:.3f} 
(threshold={threshold_value:.3f}, relaxation={relaxation:.3f}, window={window})" ) - return (len(blockers) == 0), blockers, threshold_blocked + stats.append( + { + "dataset": dataset, + "success_mean": success_mean_value, + "threshold": threshold_value, + "relaxation": relaxation, + "window": window, + "min_samples": min_samples, + "total_samples": total_samples, + "status": "threshold", + } + ) + continue + stats.append( + { + "dataset": dataset, + "success_mean": success_mean_value, + "threshold": threshold_value, + "relaxation": relaxation, + "window": window, + "min_samples": min_samples, + "total_samples": total_samples, + "status": "ok", + } + ) + return (len(blockers) == 0), blockers, threshold_blocked, stats current_stage = -1 last_block_log: tuple[int, tuple[str, ...]] | None = None if stage_state is None: stage_state = {"index": 0} stage_state.setdefault("consecutive_failures", {}) + stage_state.setdefault("last_promotion_samples", -math.inf) while True: samples_processed = trainer_state.samples_processed or 0 @@ -529,42 +620,86 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: stage_index = min(current_stage, desired_stage_index) promotion_blockers: list[str] = [] + promotion_stats_for_log: list[dict] = [] # Walk backwards until the current stage is ready (or we reach stage 0) while stage_index > 0: - ready, _, _ = stage_ready(schedule[stage_index]) + ready, _, _, _ = stage_ready(schedule[stage_index], relaxation=0.0) if ready: break stage_index -= 1 - ready, current_blockers, _ = stage_ready(schedule[stage_index]) + ready, current_blockers, _, current_stats = stage_ready( + schedule[stage_index], relaxation=0.0 + ) if not ready and stage_index > 0: # If even after walking back we are not ready, fall back further until 0 while stage_index > 0 and not ready: stage_index -= 1 - ready, current_blockers, _ = stage_ready(schedule[stage_index]) + ready, current_blockers, _, current_stats = stage_ready( + schedule[stage_index], relaxation=0.0 + ) # Attempt to promote by at most one stage towards the desired stage if stage_index < desired_stage_index: next_index = stage_index + 1 - next_ready, blockers, _ = stage_ready(schedule[next_index]) + next_ready, blockers, _, next_stats = stage_ready( + schedule[next_index], relaxation=0.0 + ) if next_ready: stage_index = next_index current_blockers = [] + current_stats = next_stats else: promotion_blockers = blockers + promotion_stats_for_log = next_stats promotion_block_stage: int | None = None promotion_blockers_for_log: list[str] = [] if stage_index < desired_stage_index: promotion_block_stage = stage_index + 1 promotion_blockers_for_log = promotion_blockers or current_blockers + if promotion_blockers_for_log: + if not promotion_stats_for_log: + promotion_stats_for_log = current_stats + else: + promotion_block_stage = None + + candidate_stage_index = stage_index + if candidate_stage_index > prev_stage + 1: + candidate_stage_index = prev_stage + 1 + + cooldown_blockers: list[str] = [] + cooldown_stats: list[dict] = [] + last_promotion_samples = stage_state.get("last_promotion_samples", -math.inf) + if candidate_stage_index > prev_stage: + cooldown_required = schedule[candidate_stage_index].get( + "promotion_cooldown_samples", 0 + ) + samples_since_promotion = samples_processed - last_promotion_samples + if samples_since_promotion < cooldown_required: + cooldown_blockers = [ + ( + f"promotion cooldown active: {samples_since_promotion} / " + f"{cooldown_required} samples since last promotion" + ) + ] + cooldown_stats = 
current_stats + candidate_stage_index = prev_stage + else: + stage_state["last_promotion_samples"] = samples_processed + + stage_index = candidate_stage_index blockers_for_log: list[str] = [] block_stage: int | None = None failure_counts: dict[int, int] = stage_state.setdefault("consecutive_failures", {}) + demotion_stats_for_log: list[dict] = [] demotion_cancelled = False if prev_stage > stage_index: - _, prev_blockers, prev_threshold_blocked = stage_ready(schedule[prev_stage]) + _, prev_blockers, prev_threshold_blocked, prev_stats = stage_ready( + schedule[prev_stage], + relaxation=schedule[prev_stage].get("threshold_hysteresis", 0.0), + ) patience = schedule[prev_stage].get("demotion_patience", 1) if prev_threshold_blocked and patience > 1: failures = failure_counts.get(prev_stage, 0) + 1 @@ -574,6 +709,8 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: demotion_cancelled = True block_stage = prev_stage blockers_for_log = prev_blockers + promotion_stats_for_log = prev_stats + demotion_stats_for_log = prev_stats else: failure_counts[prev_stage] = 0 else: @@ -592,27 +729,51 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: if not medium_pool: medium_weight = 0.0 + stats_for_log: list[dict] = [] if block_stage is None and promotion_block_stage is not None: block_stage = promotion_block_stage blockers_for_log = promotion_blockers_for_log + stats_for_log = promotion_stats_for_log + if block_stage is None and cooldown_blockers: + block_stage = prev_stage + 1 if prev_stage + 1 < len(schedule) else prev_stage + blockers_for_log = cooldown_blockers + stats_for_log = cooldown_stats + if not stats_for_log and demotion_stats_for_log: + stats_for_log = demotion_stats_for_log if logger and block_stage is not None and blockers_for_log: block_signature = (block_stage, tuple(blockers_for_log)) if block_signature != last_block_log: + stats_desc = "" + if stats_for_log: + formatted = [] + for stat in stats_for_log: + mean_val = stat.get("success_mean") + mean_str = f"{mean_val:.3f}" if mean_val is not None else "n/a" + formatted.append( + f"{stat.get('dataset')}: mean={mean_str}, thr={stat.get('threshold', 0.0):.3f}, " + f"rel={stat.get('relaxation', 0.0):.3f}, window={stat.get('window')}, " + f"samples={stat.get('total_samples')}, status={stat.get('status', 'n/a')}" + ) + stats_desc = " | stats: " + "; ".join(formatted) logger.info( - "Curriculum stage %d gated by: %s", + "Curriculum stage %d gated by: %s%s", block_stage, "; ".join(blockers_for_log), + stats_desc, ) last_block_log = block_signature elif stage_index >= desired_stage_index: last_block_log = None if logger and stage_index != current_stage: + base_weight = max(0.0, 1.0 - hard_weight - medium_weight) logger.info( - "Curriculum stage %d active (samples_processed=%d, hard_weight=%.3f)", + "Curriculum stage %d active (samples_processed=%d, base=%.3f, medium=%.3f, hard=%.3f)", stage_index, samples_processed, + base_weight, + medium_weight, hard_weight, ) current_stage = stage_index From a9d3c70b1c961b6be34defb77a0a8e9162b6a8b8 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 21 Oct 2025 11:40:17 +0000 Subject: [PATCH 08/20] simplify stage tracking --- pipelinerl/actor.py | 567 ++++++++++++-------------------------------- 1 file changed, 156 insertions(+), 411 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index e309bc20..88bfc193 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -7,11 +7,11 @@ import queue import random import time -from collections import defaultdict, deque 
+from collections import Counter, defaultdict from multiprocessing.managers import SharedMemoryManager from pathlib import Path from queue import Empty -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional import aiohttp import hydra @@ -21,8 +21,6 @@ from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field from tapeagents.llms import TrainableLLM -from typing import Dict, List, Optional - import wandb from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb from pipelinerl.rollouts import BaseMetrics, RolloutResult @@ -132,58 +130,6 @@ def get_stats(self): -class CurriculumSuccessTracker: - def __init__(self) -> None: - self._buffers: dict[str, deque[int]] = {} - self._max_windows: dict[str, int] = {} - self._total_counts: defaultdict[str, int] = defaultdict(int) - - def ensure_window(self, dataset: str, window: int) -> None: - if window <= 0: - window = 1 - current = self._max_windows.get(dataset, 0) - if window <= current: - return - existing = self._buffers.get(dataset, deque(maxlen=window)) - if existing.maxlen != window: - new_buffer = deque(existing, maxlen=window) - else: - new_buffer = existing - self._buffers[dataset] = new_buffer - self._max_windows[dataset] = window - - def update(self, dataset: str, success_values: list[int | bool]) -> None: - if not success_values: - return - buffer = self._buffers.get(dataset) - if buffer is None: - maxlen = self._max_windows.get(dataset, max(1, len(success_values))) - buffer = deque(maxlen=maxlen) - self._buffers[dataset] = buffer - self._max_windows[dataset] = maxlen - for value in success_values: - buffer.append(1 if bool(value) else 0) - self._total_counts[dataset] += 1 - - def success_mean(self, dataset: str, window: Optional[int] = None) -> Optional[float]: - buffer = self._buffers.get(dataset) - if buffer is None or not buffer: - return None - if window is None or window <= 0: - values = list(buffer) - else: - values = list(buffer)[-window:] - if len(values) < window: - return None - if not values: - return None - return sum(values) / len(values) - - def total_samples(self, dataset: str) -> int: - return self._total_counts.get(dataset, 0) - - - def make_stats_dict() -> dict: return defaultdict(lambda: defaultdict(list)) @@ -191,28 +137,14 @@ def make_stats_dict() -> dict: def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: raw_schedule = curriculum_cfg.get("schedule", []) if not raw_schedule: - return [{"step": 0, "hard_weight": 0.0, "thresholds": []}] + return [{"step": 0, "medium_weight": 0.0, "hard_weight": 0.0}] if not isinstance(raw_schedule, list): raw_schedule = [raw_schedule] - default_cooldown = curriculum_cfg.get("default_promotion_cooldown_samples", 8000) - try: - default_cooldown = int(default_cooldown) - except (TypeError, ValueError): - default_cooldown = 8000 - if default_cooldown < 0: - default_cooldown = 0 - - default_hysteresis = curriculum_cfg.get("default_threshold_hysteresis", 0.02) - try: - default_hysteresis = float(default_hysteresis) - except (TypeError, ValueError): - default_hysteresis = 0.02 - if default_hysteresis < 0.0: - default_hysteresis = 0.0 - parsed_schedule: list[dict] = [] for entry in raw_schedule: step = int(entry.get("step", 0)) + if step < 0: + step = 0 hard_weight = float(entry.get("hard_weight", 0.0)) hard_weight = max(0.0, min(1.0, hard_weight)) medium_weight_value = entry.get("medium_weight", 0.0) @@ -222,70 +154,106 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: medium_weight = 0.0 
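        # clamp medium_weight to [0, 1]; if medium + hard sum above 1.0 they are rescaled proportionally below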
medium_weight = max(0.0, min(1.0, medium_weight)) weight_sum = medium_weight + hard_weight - max_non_base = 0.85 - if weight_sum > max_non_base and weight_sum > 0: - scale = max_non_base / weight_sum + if weight_sum > 1.0 and weight_sum > 0.0: + scale = 1.0 / weight_sum medium_weight *= scale hard_weight *= scale - demotion_patience_value = entry.get("demotion_patience", 1) - try: - demotion_patience = int(demotion_patience_value) - except (TypeError, ValueError): - demotion_patience = 1 - if demotion_patience < 1: - demotion_patience = 1 - cooldown_value = entry.get( - "promotion_cooldown_samples", entry.get("cooldown_samples", default_cooldown) - ) - try: - promotion_cooldown_samples = int(cooldown_value) - except (TypeError, ValueError): - promotion_cooldown_samples = default_cooldown - if promotion_cooldown_samples < 0: - promotion_cooldown_samples = 0 - hysteresis_value = entry.get("threshold_hysteresis", default_hysteresis) - try: - threshold_hysteresis = float(hysteresis_value) - except (TypeError, ValueError): - threshold_hysteresis = default_hysteresis - threshold_hysteresis = max(0.0, threshold_hysteresis) - thresholds_cfg = entry.get("success_thresholds", []) or [] - if not isinstance(thresholds_cfg, list): - thresholds_cfg = [thresholds_cfg] - thresholds: list[dict] = [] - for threshold_entry in thresholds_cfg: - dataset = threshold_entry.get("dataset") - if not dataset: + ready_success_cfg = entry.get("ready_success") or [] + if not isinstance(ready_success_cfg, list): + ready_success_cfg = [ready_success_cfg] + ready_success: list[dict] = [] + for cond in ready_success_cfg: + if not isinstance(cond, dict): continue - threshold_value = float(threshold_entry.get("threshold", 1.0)) - window = int(threshold_entry.get("window", threshold_entry.get("window_size", 0) or 1)) - if window <= 0: - window = 1 - min_samples_value = threshold_entry.get("min_samples") - min_samples = int(min_samples_value) if min_samples_value is not None else None - thresholds.append( + dataset = cond.get("dataset") + metric = cond.get("metric", "success_mean") + try: + threshold = float(cond.get("threshold", 1.0)) + except (TypeError, ValueError): + threshold = 1.0 + ready_success.append( { "dataset": dataset, - "threshold": threshold_value, - "window": window, - "min_samples": min_samples, + "metric": metric, + "threshold": threshold, } ) + patience_value = entry.get("ready_patience", 1) + try: + ready_patience = max(1, int(patience_value)) + except (TypeError, ValueError): + ready_patience = 1 parsed_schedule.append( { "step": step, - "hard_weight": hard_weight, "medium_weight": medium_weight, - "thresholds": thresholds, - "demotion_patience": demotion_patience, - "promotion_cooldown_samples": promotion_cooldown_samples, - "threshold_hysteresis": threshold_hysteresis, + "hard_weight": hard_weight, + "ready_success": ready_success, + "ready_patience": ready_patience, } ) parsed_schedule.sort(key=lambda item: item["step"]) return parsed_schedule +def advance_curriculum_stage( + schedule: list[dict], + stage_state: dict, + samples_processed: int, + stats: dict, + logger: logging.Logger | None = None, +) -> None: + if not schedule or stage_state is None: + return + current_idx = int(stage_state.get("index", 0)) + ready_counts = stage_state.setdefault("ready_counts", {}) + advanced = False + + while current_idx + 1 < len(schedule): + next_idx = current_idx + 1 + stage_cfg = schedule[next_idx] + min_step = stage_cfg.get("step", 0) + if samples_processed < min_step: + break + + ready_conditions: list[dict] = 
stage_cfg.get("ready_success") or [] + patience = max(1, int(stage_cfg.get("ready_patience", 1))) + + if ready_conditions: + all_pass = True + for cond in ready_conditions: + dataset = cond.get("dataset") + metric = cond.get("metric", "success_mean") + threshold = float(cond.get("threshold", 1.0)) + if dataset: + metric_key = f"{dataset}/{metric}" + else: + metric_key = metric + value = stats.get(metric_key) + if value is None or value < threshold: + all_pass = False + break + if all_pass: + ready_counts[next_idx] = ready_counts.get(next_idx, 0) + 1 + else: + ready_counts[next_idx] = 0 + break + if ready_counts[next_idx] < patience: + break + current_idx = next_idx + stage_state["index"] = current_idx + ready_counts.pop(next_idx, None) + advanced = True + if logger: + logger.info( + "Curriculum stage %d activated (samples_processed=%d)", + current_idx, + samples_processed, + ) + if not advanced: + stage_state["index"] = current_idx + + async def schedule_rollouts( cfg: DictConfig, attempts: int, @@ -329,9 +297,9 @@ async def rollout_and_maybe_produce_result( llm = llms[llm_index] model_version = trainer_state.propagated_weight_version assert model_version is not None - logger.info(f"Starting rollout policy for problem {problem['id']}") + logger.debug(f"Starting rollout policy for problem {problem['id']}") rollout_result: RolloutResult = await rollout_policy(cfg, llm, problem, session) - logger.info(f"Finished rollout policy for problem {problem['id']}") + logger.debug(f"Finished rollout policy for problem {problem['id']}") rollout_result.model_version = model_version token_count += get_number_of_tokens_in_result(rollout_result) # Make a group id that will be different from groups made by another rollout maker @@ -378,7 +346,7 @@ async def rollout_and_maybe_produce_result( if finished_rollouts > old_finished_rollouts: old_finished_rollouts = finished_rollouts save_debug_line({"rollouts_finished": finished_rollouts, "tokens_produced": token_count, "dt": time.time() - start_time, "token_speed": token_count / (time.time() - start_time)}) - logger.info( + logger.debug( f"{scheduler_name}: " f"rollouts in progress: {sum(active_rollouts)}, " f"groups in progress: {len(group_rollouts)}, " @@ -454,9 +422,8 @@ def curriculum_iter( trainer_state: TrainerState, curriculum_cfg: DictConfig, logger: logging.Logger | None = None, - success_tracker: CurriculumSuccessTracker | None = None, - stage_state: Optional[dict] = None, parsed_schedule: Optional[list[dict]] = None, + stage_state: Optional[dict] = None, ): curriculum_obj = ( OmegaConf.to_container(curriculum_cfg, resolve=True) @@ -486,304 +453,72 @@ def curriculum_iter( medium_pool = [problem for problem in problems if problem.get("dataset") in medium_names] hard_pool = [problem for problem in problems if problem.get("dataset") in hard_names] - if not hard_pool: - if logger: - logger.warning( - "Curriculum enabled but no problems matched hard_datasets list; falling back to base sampling" - ) - yield from random_iter(problems) - return - - if medium_names and not medium_pool and logger: - logger.warning( - "Curriculum medium_datasets specified but no problems matched; medium weighting will be ignored" - ) + if not base_pool and medium_pool: + base_pool = list(medium_pool) + medium_pool = [] if not base_pool: if logger: - logger.warning("Curriculum enabled but base pool is empty; falling back to medium or hard datasets") - if medium_pool: - base_pool = list(medium_pool) - medium_pool = [] - else: - base_pool = hard_pool + logger.warning("Curriculum 
enabled but no matching datasets were found; falling back to random sampling") + yield from random_iter(problems) + return schedule = parsed_schedule or parse_curriculum_schedule(curriculum_obj) - if success_tracker: - for stage in schedule: - for threshold in stage["thresholds"]: - success_tracker.ensure_window(threshold["dataset"], threshold["window"]) - - def stage_ready(stage_cfg: dict, relaxation: float = 0.0) -> tuple[bool, list[str], bool, list[dict]]: - if not stage_cfg["thresholds"] or success_tracker is None: - return True, [], False, [] - blockers: list[str] = [] - threshold_blocked = False - stats: list[dict] = [] - for threshold in stage_cfg["thresholds"]: - dataset = threshold["dataset"] - threshold_value = threshold["threshold"] - window = threshold["window"] - min_samples = threshold.get("min_samples") - total_samples = success_tracker.total_samples(dataset) - if min_samples is not None: - if total_samples < min_samples: - blockers.append( - f"{dataset}: waiting for {min_samples} samples (have {total_samples})" - ) - stats.append( - { - "dataset": dataset, - "success_mean": None, - "threshold": threshold_value, - "relaxation": relaxation, - "window": window, - "min_samples": min_samples, - "total_samples": total_samples, - "status": "min_samples", - } - ) - continue - success_mean_value = success_tracker.success_mean(dataset, window) - if success_mean_value is None: - blockers.append(f"{dataset}: insufficient window data (need {window})") - stats.append( - { - "dataset": dataset, - "success_mean": None, - "threshold": threshold_value, - "relaxation": relaxation, - "window": window, - "min_samples": min_samples, - "total_samples": total_samples, - "status": "insufficient_window", - } - ) - continue - adjusted_threshold = threshold_value - relaxation - if success_mean_value < adjusted_threshold: - threshold_blocked = True - blockers.append( - f"{dataset}: success_mean {success_mean_value:.3f} < {adjusted_threshold:.3f} (threshold={threshold_value:.3f}, relaxation={relaxation:.3f}, window={window})" - ) - stats.append( - { - "dataset": dataset, - "success_mean": success_mean_value, - "threshold": threshold_value, - "relaxation": relaxation, - "window": window, - "min_samples": min_samples, - "total_samples": total_samples, - "status": "threshold", - } - ) - continue - stats.append( - { - "dataset": dataset, - "success_mean": success_mean_value, - "threshold": threshold_value, - "relaxation": relaxation, - "window": window, - "min_samples": min_samples, - "total_samples": total_samples, - "status": "ok", - } - ) - return (len(blockers) == 0), blockers, threshold_blocked, stats + if not schedule: + schedule = [{"step": 0, "medium_weight": 0.0, "hard_weight": 0.0}] - current_stage = -1 - last_block_log: tuple[int, tuple[str, ...]] | None = None - if stage_state is None: - stage_state = {"index": 0} - stage_state.setdefault("consecutive_failures", {}) - stage_state.setdefault("last_promotion_samples", -math.inf) + current_stage_index = stage_state.get("index", 0) if stage_state is not None else 0 + last_logged_stage: Optional[int] = None while True: samples_processed = trainer_state.samples_processed or 0 - desired_stage_index = 0 - - for idx, stage_cfg in enumerate(schedule): - step = stage_cfg["step"] - if samples_processed >= step: - desired_stage_index = idx - else: - break - - current_stage = int(stage_state.get("index", 0)) - if current_stage < 0: - current_stage = 0 - if current_stage >= len(schedule): - current_stage = len(schedule) - 1 - prev_stage = current_stage - - 
stage_index = min(current_stage, desired_stage_index) - promotion_blockers: list[str] = [] - promotion_stats_for_log: list[dict] = [] - - # Walk backwards until the current stage is ready (or we reach stage 0) - while stage_index > 0: - ready, _, _, _ = stage_ready(schedule[stage_index], relaxation=0.0) - if ready: - break - stage_index -= 1 - - ready, current_blockers, _, current_stats = stage_ready( - schedule[stage_index], relaxation=0.0 - ) - if not ready and stage_index > 0: - # If even after walking back we are not ready, fall back further until 0 - while stage_index > 0 and not ready: - stage_index -= 1 - ready, current_blockers, _, current_stats = stage_ready( - schedule[stage_index], relaxation=0.0 - ) + max_stage_allowed = current_stage_index + while ( + max_stage_allowed + 1 < len(schedule) + and samples_processed >= schedule[max_stage_allowed + 1]["step"] + ): + max_stage_allowed += 1 - # Attempt to promote by at most one stage towards the desired stage - if stage_index < desired_stage_index: - next_index = stage_index + 1 - next_ready, blockers, _, next_stats = stage_ready( - schedule[next_index], relaxation=0.0 - ) - if next_ready: - stage_index = next_index - current_blockers = [] - current_stats = next_stats - else: - promotion_blockers = blockers - promotion_stats_for_log = next_stats - promotion_block_stage: int | None = None - promotion_blockers_for_log: list[str] = [] - if stage_index < desired_stage_index: - promotion_block_stage = stage_index + 1 - promotion_blockers_for_log = promotion_blockers or current_blockers - if promotion_blockers_for_log: - if not promotion_stats_for_log: - promotion_stats_for_log = current_stats - else: - promotion_block_stage = None - - candidate_stage_index = stage_index - if candidate_stage_index > prev_stage + 1: - candidate_stage_index = prev_stage + 1 - - cooldown_blockers: list[str] = [] - cooldown_stats: list[dict] = [] - last_promotion_samples = stage_state.get("last_promotion_samples", -math.inf) - if candidate_stage_index > prev_stage: - cooldown_required = schedule[candidate_stage_index].get( - "promotion_cooldown_samples", 0 - ) - samples_since_promotion = samples_processed - last_promotion_samples - if samples_since_promotion < cooldown_required: - cooldown_blockers = [ - ( - f"promotion cooldown active: {samples_since_promotion} / " - f"{cooldown_required} samples since last promotion" - ) - ] - cooldown_stats = current_stats - candidate_stage_index = prev_stage - else: - stage_state["last_promotion_samples"] = samples_processed - - stage_index = candidate_stage_index - - blockers_for_log: list[str] = [] - block_stage: int | None = None - failure_counts: dict[int, int] = stage_state.setdefault("consecutive_failures", {}) - demotion_stats_for_log: list[dict] = [] - demotion_cancelled = False - if prev_stage > stage_index: - _, prev_blockers, prev_threshold_blocked, prev_stats = stage_ready( - schedule[prev_stage], - relaxation=schedule[prev_stage].get("threshold_hysteresis", 0.0), - ) - patience = schedule[prev_stage].get("demotion_patience", 1) - if prev_threshold_blocked and patience > 1: - failures = failure_counts.get(prev_stage, 0) + 1 - if failures < patience: - failure_counts[prev_stage] = failures - stage_index = prev_stage - demotion_cancelled = True - block_stage = prev_stage - blockers_for_log = prev_blockers - promotion_stats_for_log = prev_stats - demotion_stats_for_log = prev_stats - else: - failure_counts[prev_stage] = 0 - else: - failure_counts[prev_stage] = 0 + if stage_state is not None: + desired_index = 
int(stage_state.get("index", 0)) + current_stage_index = max(0, min(desired_index, max_stage_allowed)) else: - failure_counts.setdefault(prev_stage, 0) - failure_counts[prev_stage] = 0 - - if not demotion_cancelled: - failure_counts.setdefault(stage_index, 0) - if stage_index != prev_stage: - failure_counts[stage_index] = 0 + current_stage_index = max_stage_allowed - hard_weight = schedule[stage_index]["hard_weight"] - medium_weight = schedule[stage_index].get("medium_weight", 0.0) + stage_cfg = schedule[current_stage_index] + medium_weight = stage_cfg.get("medium_weight", 0.0) + hard_weight = stage_cfg.get("hard_weight", 0.0) if not medium_pool: medium_weight = 0.0 + if not hard_pool: + hard_weight = 0.0 + base_weight = max(0.0, 1.0 - medium_weight - hard_weight) + weight_sum = base_weight + medium_weight + hard_weight + if weight_sum == 0.0: + base_weight = 1.0 + medium_weight = 0.0 + hard_weight = 0.0 - stats_for_log: list[dict] = [] - if block_stage is None and promotion_block_stage is not None: - block_stage = promotion_block_stage - blockers_for_log = promotion_blockers_for_log - stats_for_log = promotion_stats_for_log - if block_stage is None and cooldown_blockers: - block_stage = prev_stage + 1 if prev_stage + 1 < len(schedule) else prev_stage - blockers_for_log = cooldown_blockers - stats_for_log = cooldown_stats - if not stats_for_log and demotion_stats_for_log: - stats_for_log = demotion_stats_for_log - - if logger and block_stage is not None and blockers_for_log: - block_signature = (block_stage, tuple(blockers_for_log)) - if block_signature != last_block_log: - stats_desc = "" - if stats_for_log: - formatted = [] - for stat in stats_for_log: - mean_val = stat.get("success_mean") - mean_str = f"{mean_val:.3f}" if mean_val is not None else "n/a" - formatted.append( - f"{stat.get('dataset')}: mean={mean_str}, thr={stat.get('threshold', 0.0):.3f}, " - f"rel={stat.get('relaxation', 0.0):.3f}, window={stat.get('window')}, " - f"samples={stat.get('total_samples')}, status={stat.get('status', 'n/a')}" - ) - stats_desc = " | stats: " + "; ".join(formatted) - logger.info( - "Curriculum stage %d gated by: %s%s", - block_stage, - "; ".join(blockers_for_log), - stats_desc, - ) - last_block_log = block_signature - elif stage_index >= desired_stage_index: - last_block_log = None - - if logger and stage_index != current_stage: - base_weight = max(0.0, 1.0 - hard_weight - medium_weight) + if logger and last_logged_stage != current_stage_index: logger.info( "Curriculum stage %d active (samples_processed=%d, base=%.3f, medium=%.3f, hard=%.3f)", - stage_index, + current_stage_index, samples_processed, base_weight, medium_weight, hard_weight, ) - current_stage = stage_index + last_logged_stage = current_stage_index - stage_state["index"] = stage_index + if stage_state is not None: + stage_state["index"] = current_stage_index choice = random.random() - if hard_pool and choice < hard_weight: + hard_cutoff = hard_weight + medium_cutoff = hard_cutoff + medium_weight + if hard_pool and choice < hard_cutoff: yield random.choice(hard_pool) - elif medium_pool and choice < hard_weight + medium_weight: + elif medium_pool and choice < medium_cutoff: yield random.choice(medium_pool) else: yield random.choice(base_pool) @@ -818,7 +553,7 @@ def __init__( self.is_training = is_training self.is_scheduling_paused = False self.debug_mode = bool(cfg.debug.mode) - self.curriculum_tracker: CurriculumSuccessTracker | None = None + self.curriculum_schedule: list[dict] | None = None self.curriculum_stage_state: dict | 
None = None self.smm: SharedMemoryManager | None = None @@ -867,6 +602,7 @@ def init_stats(self): self.latency_list = [] self.model_versions_list = [] self.sliding_stats = defaultdict(list) + self.answer_status_counts = Counter() def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} @@ -893,15 +629,15 @@ def update_stats(self, rollout_results: List[RolloutResult]): for k, v in all_metrics.items(): if isinstance(v, list): self.stats[k][dataset_name][group_id] += v - if k == "success" and self.curriculum_tracker: - self.curriculum_tracker.update(dataset_name, v) elif isinstance(v, float) | isinstance(v, bool) | isinstance(v, int): self.stats[k][dataset_name][group_id].append(v) - if k == "success" and self.curriculum_tracker: - self.curriculum_tracker.update(dataset_name, [v]) else: raise ValueError(f"Unsupported metric type: {type(v)} for key {k}") + status = getattr(result, "answer_status", None) + if status in {"correct", "wrong", "unparsable", "no_answer"}: + self.answer_status_counts[status] += 1 + if self.sliding_aggregator: prompt_length_tokens = [ training_text.prompt_tokens @@ -949,27 +685,23 @@ def run(self, dataset: list[tuple[str, dict]]): else curriculum_cfg ) parsed_schedule = parse_curriculum_schedule(curriculum_obj) - self.curriculum_tracker = CurriculumSuccessTracker() - for stage in parsed_schedule: - for threshold in stage["thresholds"]: - self.curriculum_tracker.ensure_window(threshold["dataset"], threshold["window"]) + self.curriculum_schedule = parsed_schedule self.curriculum_stage_state = {"index": 0} problem_iter = curriculum_iter( dataset, trainer_state=self.trainer_state, curriculum_cfg=curriculum_cfg, logger=logger, - success_tracker=self.curriculum_tracker, - stage_state=self.curriculum_stage_state, parsed_schedule=parsed_schedule, + stage_state=self.curriculum_stage_state, ) else: problem_iter = random_iter(dataset) - self.curriculum_tracker = None + self.curriculum_schedule = None self.curriculum_stage_state = None else: problem_iter = sequential_iter(dataset) - self.curriculum_tracker = None + self.curriculum_schedule = None self.curriculum_stage_state = None assert self.trainer_state.propagated_weight_version is not None @@ -1136,8 +868,21 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): stats |= loop_stats for k, v in self.sliding_stats.items(): stats[k] = sum(v) / len(v) if v else 0 - if self.curriculum_stage_state is not None: + if self.curriculum_schedule and self.curriculum_stage_state is not None: + advance_curriculum_stage( + self.curriculum_schedule, + self.curriculum_stage_state, + self.trainer_state.samples_processed or 0, + stats, + logger, + ) stats["curriculum_stage_active"] = self.curriculum_stage_state.get("index", 0) + + total_status = sum(self.answer_status_counts.values()) + if total_status: + for status, count in self.answer_status_counts.items(): + stats[f"{split_name}answer_status_{status}_count"] = count + stats[f"{split_name}answer_status_{status}_ratio"] = count / total_status if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) From 3e06b59c10d5c93f55252f18c3f27fbbace1083c Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 21 Oct 2025 11:42:32 +0000 Subject: [PATCH 09/20] Track answer status --- pipelinerl/domains/math/rollouts.py | 35 ++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 
862375d0..1f70c11d 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -1,7 +1,9 @@ +import json +import logging +import random import re import time -import random -import logging +from pathlib import Path import aiohttp from omegaconf import DictConfig @@ -14,6 +16,7 @@ from pipelinerl.async_llm import llm_async_generate, make_training_text from .verifier_api import verify_answer_rpc +logger = logging.getLogger(__name__) class Metrics(BaseMetrics): penalty: float @@ -139,6 +142,30 @@ def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> return ((max_length - buffer_tokens) - sequence_length) / buffer_tokens return 0. + +def log_answer_status(cfg: DictConfig, problem: dict, answer_status: str, reward: float, latency: float) -> None: + """ + Metric logging for answer status - correct, wrong, no_answer, unparsable + """ + try: + log_dir = Path(cfg.output_dir) if cfg.output_dir else None + if not log_dir: + return + log_path = log_dir / "answer_status.jsonl" + record = { + "t": time.time(), + "problem_id": problem.get("id"), + "dataset": problem.get("dataset"), + "answer_status": answer_status, + "reward": reward, + "latency": latency, + } + with log_path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(record)) + handle.write("\n") + except Exception: + logger.debug("Failed to append answer status log", exc_info=True) + async def generate_math_rollout( cfg: DictConfig, llm: TrainableLLM, @@ -199,6 +226,7 @@ async def generate_math_rollout( ) reward += overlong_penalty trace.reward = reward + log_answer_status(cfg, problem, answer_status, reward, latency) # Prefer backend-provided finish reason if available; normalize for comparisons if isinstance(trace.metadata, dict): @@ -249,6 +277,7 @@ async def generate_math_rollout( return RolloutResult( training_texts=[trace], metrics=metrics, - latency=latency, + latency=latency, dataset_name=problem.get("dataset"), + answer_status=answer_status, ) From 7a63df9d444ce75bf96a5ae2346748009a407a34 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 21 Oct 2025 11:48:07 +0000 Subject: [PATCH 10/20] Remove delimiter tags --- pipelinerl/domains/math/verifier_api.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/pipelinerl/domains/math/verifier_api.py b/pipelinerl/domains/math/verifier_api.py index 9fa9b4ef..59f66dfc 100644 --- a/pipelinerl/domains/math/verifier_api.py +++ b/pipelinerl/domains/math/verifier_api.py @@ -1,6 +1,7 @@ import time import requests import asyncio +import re from concurrent.futures import ProcessPoolExecutor import aiohttp import uvicorn @@ -61,6 +62,18 @@ def timeout_handler(signum, frame): signal.signal(signal.SIGALRM, original_handler) +DELIMITER_STR = re.compile(r"\[END FINAL RESPONSE\]", flags=re.IGNORECASE) + + +def strip_delimiter_strings(text: str) -> str: + if not text: + return text + stripped = DELIMITER_STR.sub("", text) + # Remove lines that became empty after sentinel stripping to avoid parsing noise + cleaned_lines = [line for line in stripped.splitlines() if line.strip()] + return "\n".join(cleaned_lines) + + def verify_answer(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: """ Checks if a predicted answer matches a gold (correct) answer by making a request to the math_verify package. 
@@ -88,13 +101,13 @@ def verify_answer(prediction: str, gold: str, strict: bool = True, max_predictio def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: - import re - try: # Input Sanitization / Validation if not isinstance(prediction, str) or not isinstance(gold, str): raise ValueError("Prediction and gold must be strings") + prediction = strip_delimiter_strings(prediction) + # Try extracting from \boxed{...} first boxed_start = prediction.rfind("\\boxed{") @@ -109,7 +122,7 @@ def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_ # Fallback: look for ... tags answer_match = re.findall(r"(.*?)", prediction, re.DOTALL) if answer_match: - extracted_prediction = answer_match[-1].strip() # last one if multiple + extracted_prediction = strip_delimiter_strings(answer_match[-1].strip()) # last one else: raise NoAnswerException() @@ -225,5 +238,3 @@ async def health(): return JSONResponse(content={"status": "ok"}) uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=60) - - From fc1ff86dfcc0cb24a96b64734975ba2557850e77 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 21 Oct 2025 11:48:47 +0000 Subject: [PATCH 11/20] Track ans status in rollout --- pipelinerl/rollouts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py index dcb27f2d..6ee2b6c2 100644 --- a/pipelinerl/rollouts.py +++ b/pipelinerl/rollouts.py @@ -64,3 +64,4 @@ class RolloutResult(BaseModel): model_version: int | None = None dataset_name: str | None = None group_id: str | None = None + answer_status: str | None = None From e6c9463cedb26b110c639e58858f5a3f735970fb Mon Sep 17 00:00:00 2001 From: rafapi Date: Mon, 27 Oct 2025 16:42:22 +0000 Subject: [PATCH 12/20] rollout counter --- pipelinerl/actor.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 88bfc193..5e7eeb18 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -603,6 +603,7 @@ def init_stats(self): self.model_versions_list = [] self.sliding_stats = defaultdict(list) self.answer_status_counts = Counter() + self.dataset_sample_counts = Counter() def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} @@ -637,6 +638,8 @@ def update_stats(self, rollout_results: List[RolloutResult]): status = getattr(result, "answer_status", None) if status in {"correct", "wrong", "unparsable", "no_answer"}: self.answer_status_counts[status] += 1 + if dataset_name: + self.dataset_sample_counts[str(dataset_name)] += 1 if self.sliding_aggregator: prompt_length_tokens = [ @@ -883,6 +886,12 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): for status, count in self.answer_status_counts.items(): stats[f"{split_name}answer_status_{status}_count"] = count stats[f"{split_name}answer_status_{status}_ratio"] = count / total_status + total_rollouts = sum(self.dataset_sample_counts.values()) + if total_rollouts: + stats["dataset_rollouts_total"] = total_rollouts + for dataset, count in self.dataset_sample_counts.items(): + stats[f"{dataset}/rollout_count"] = count + stats[f"{dataset}/rollout_ratio"] = count / total_rollouts if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) From 840a65e394c6e4ee5db07a5c2e43c7553f26ffa1 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 4 Nov 2025 12:29:04 +0000 Subject: [PATCH 13/20] Allow to continue training during eval --- pipelinerl/actor.py | 9 
+++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 5e7eeb18..7a152a4b 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -1160,6 +1160,9 @@ def run_actor_loop(cfg: DictConfig): ) test_loop_run = None + pause_training_during_eval = bool( + getattr(cfg.actor, "pause_training_during_eval", True) + ) last_regular_eval = -1 current_eval = -1 while True: @@ -1183,7 +1186,8 @@ def run_actor_loop(cfg: DictConfig): test_loop_run = test_loop.run( dataset=test_dataset, ) - train_loop.is_scheduling_paused = True + if pause_training_during_eval: + train_loop.is_scheduling_paused = True current_eval = next_regular_eval # 2. If there is an active test loop, keep it running @@ -1194,7 +1198,8 @@ def run_actor_loop(cfg: DictConfig): # 2.1 If the test loop is finished, resume scheduling the training loop test_loop_run = None last_regular_eval = current_eval - train_loop.is_scheduling_paused = False + if pause_training_during_eval: + train_loop.is_scheduling_paused = False logger.info("Test loop finished") # 3. Keep running the training loop From dc4b346523af37f69e27ffb53acc294292d4ea5f Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 4 Nov 2025 12:30:39 +0000 Subject: [PATCH 14/20] load custom datasets --- pipelinerl/domains/math/load_datasets.py | 143 +++++++++++++++++++---- 1 file changed, 122 insertions(+), 21 deletions(-) diff --git a/pipelinerl/domains/math/load_datasets.py b/pipelinerl/domains/math/load_datasets.py index 2eec4a9b..fbfaea97 100644 --- a/pipelinerl/domains/math/load_datasets.py +++ b/pipelinerl/domains/math/load_datasets.py @@ -2,7 +2,8 @@ import logging import random import re -from typing import Dict, List, Tuple +from pathlib import Path +from typing import Dict, Iterable, List, Sequence, Tuple import datasets import hydra @@ -190,26 +191,6 @@ def _load_aime_dataset(year: int, upsample_factor: int = 0) -> list[dict]: return add_ids(samples) -def _load_aime_2025_opencompass(upsample_factor: int = 0) -> list[dict]: - configs = ["AIME2025-I", "AIME2025-II"] - dataset_name = "aime_2025" + ("" if upsample_factor > 0 else "_original") - - samples: list[dict] = [] - for config_name in configs: - ds = load_dataset("opencompass/AIME2025", config_name, split="test") - samples.extend([s for s in process_math(ds, dataset_name) if s is not None]) - - original_size = len(samples) - if upsample_factor > 0: - samples *= upsample_factor - - logger.info( - f"Loading aime 2025 (OpenCompass) dataset: {len(samples)} samples" - + (f" (upsampled from {original_size})" if upsample_factor > 0 else "") - ) - return add_ids(samples) - - def _load_amc_dataset(year: int, upsample_factor: int = 0) -> list[dict]: amc_dataset = load_dataset("AI-MO/aimo-validation-amc", split="train", trust_remote_code=True) amc_dataset = amc_dataset.filter(lambda x: str(year) in x["url"]) @@ -234,18 +215,93 @@ def add_ids(dataset: list[dict]): return dataset +def _resolve_custom_path(relative_paths: str | Sequence[str]) -> Path: + """ + Resolve a path for locally generated datasets. + + Hydra jobs may change the working directory, so we check both the current + directory and the repository root. 
+ """ + if isinstance(relative_paths, str): + relative_paths = [relative_paths] + + resolved = Path(__file__).resolve() + base_candidates = [Path.cwd()] + if len(resolved.parents) >= 5: + base_candidates.append(resolved.parents[4]) + + candidates: List[Path] = [] + for rel in relative_paths: + rel_path = Path(rel) + candidates.append(rel_path) + for base in base_candidates: + if base == Path.cwd(): + continue + candidates.append(base / rel_path) + + for candidate in candidates: + if candidate.exists(): + return candidate + raise FileNotFoundError( + f"Custom dataset not found. Tried: {[str(path) for path in candidates]}" + ) + + +def _load_custom_dataset(dataset_name: str) -> list[dict]: + """ + Load a locally generated dataset by name. + + The loader searches under `datasets/custom/` and `datasets/custom_runs/` for either + `` or `.jsonl`. + """ + candidate_names: List[str] = [] + if dataset_name.endswith(".jsonl"): + candidate_names.append(dataset_name) + else: + candidate_names.extend([dataset_name, f"{dataset_name}.jsonl"]) + + search_paths: List[str] = [] + for name in candidate_names: + search_paths.extend( + [ + f"datasets/custom/{name}", + f"datasets/custom_runs/{name}", + name, + ] + ) + + dataset_path = _resolve_custom_path(search_paths) + with dataset_path.open("r", encoding="utf-8") as handle: + samples = [json.loads(line) for line in handle if line.strip()] + + dataset_label = dataset_name[:-6] if dataset_name.endswith(".jsonl") else dataset_name + + for idx, sample in enumerate(samples): + sample.setdefault("source_dataset", sample.get("dataset", dataset_label)) + sample.setdefault("source_id", sample.get("id")) + sample["dataset"] = dataset_label + sample["id"] = idx + + logger.info(f"Loading custom dataset {dataset_name}: {len(samples)} samples from {dataset_path}") + return samples + + def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None) -> List[Tuple[str, Dict]]: if dataset_names is None: return [] if isinstance(dataset_names, str): dataset_names = [dataset_names] + # Preserve order while de-duplicating + dataset_names = list(dict.fromkeys(dataset_names)) datasets = [] + remaining = set(dataset_names) if "eurus_train" in dataset_names: dataset = load_dataset("PRIME-RL/Eurus-2-RL-Data", split="train", trust_remote_code=True) samples = [s for s in process_eurus(dataset) if s is not None] logger.info(f"Loading eurus train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("eurus_train") # great for debugging since its much smaller than eurus train if "eurus_validation" in dataset_names: @@ -253,6 +309,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_eurus(dataset) if s is not None] logger.info(f"Loading eurus validation dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("eurus_validation") if "math_train" in dataset_names: # math_dataset = load_math("train") @@ -260,6 +317,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_train") if s is not None] logger.info(f"Loading math train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_train") if "math_simplerl_train" in dataset_names: # SimpleRL MATH dataset @@ -274,6 +332,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_simplerl_train") if s is not None] 
logger.info(f"Loading math simplerl train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_simplerl_train") if "simplerl_math_subset_1000" in dataset_names: # SimpleRL MATH dataset subset @@ -292,12 +351,14 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = samples[:1000] logger.info(f"Loading math simplerl subset test dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("simplerl_math_subset_1000") if "deepscaler_preview" in dataset_names: dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", trust_remote_code=True) samples = [s for s in process_math(dataset, "deepscaler") if s is not None] logger.info(f"Loading deepscaler preview train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("deepscaler_preview") if "math_test" in dataset_names: # math_dataset = load_math("test") @@ -305,36 +366,42 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_test") if s is not None] logger.info(f"Loading math test dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_test") if "omni_math_500" in dataset_names: dataset = load_dataset("reliable-agents/Omni-MATH-500", split="test", trust_remote_code=True) samples = [s for s in process_math(dataset, "omni_math_500") if s is not None] logger.info(f"Loading omni math 500 dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("omni_math_500") if "math_500" in dataset_names: dataset = load_dataset("HuggingFaceH4/MATH-500", split="test", trust_remote_code=True) samples = [s for s in process_math(dataset, "math_500") if s is not None] logger.info(f"Loading math 500 dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_500") if "open_r1_math_220k" in dataset_names: dataset = load_dataset("open-r1/OpenR1-Math-220k", split="default", trust_remote_code=True) samples = [s for s in process_math(dataset, "open_r1_math_220k") if s is not None] logger.info(f"Loading open r1 math 220k dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_r1_math_220k") if "gpqa_main" in dataset_names: dataset = load_dataset("hendrydong/gpqa_main", split="test", trust_remote_code=True) samples = [s for s in process_gpqa(dataset, "gpqa_main") if s is not None] logger.info(f"Loading gpqa main dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gpqa_main") if "gpqa_diamond" in dataset_names: dataset = load_dataset("hendrydong/gpqa_diamond", split="test", trust_remote_code=True) samples = [s for s in process_gpqa(dataset, "gpqa_diamond") if s is not None] logger.info(f"Loading gpqa diamond dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gpqa_diamond") if "gpqa_diamond" in dataset_names: pass @@ -344,55 +411,70 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_gsm8k(dataset, "gsm8k_train") if s is not None] logger.info(f"Loading gsm8k train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gsm8k_train") if "gsm8k_test" in dataset_names: dataset = load_dataset("openai/gsm8k", "main", split="test", trust_remote_code=True) samples = [s for s in process_gsm8k(dataset, "gsm8k_test") if s is not None] logger.info(f"Loading gsm8k test dataset: {len(samples)} 
samples") datasets += add_ids(samples) + remaining.discard("gsm8k_test") if "limo" in dataset_names: dataset = load_dataset("GAIR/LIMO", split="train", trust_remote_code=True) samples = [s for s in process_limo(dataset) if s is not None] logger.info(f"Loading limo dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("limo") if "aime_2022" in dataset_names: datasets += _load_aime_dataset(2022, upsample_factor=16) + remaining.discard("aime_2022") if "aime_2022_original" in dataset_names: datasets += _load_aime_dataset(2022) + remaining.discard("aime_2022_original") if "aime_2023" in dataset_names: datasets += _load_aime_dataset(2023, upsample_factor=16) + remaining.discard("aime_2023") if "aime_2023_original" in dataset_names: datasets += _load_aime_dataset(2023) + remaining.discard("aime_2023_original") if "aime_2024" in dataset_names: datasets += _load_aime_dataset(2024, upsample_factor=16) + remaining.discard("aime_2024") if "aime_2024_original" in dataset_names: datasets += _load_aime_dataset(2024) + remaining.discard("aime_2024_original") if "aime_2025" in dataset_names: datasets += _load_aime_2025_opencompass_dataset(upsample_factor=16) + remaining.discard("aime_2025") if "aime_2025_original" in dataset_names: datasets += _load_aime_2025_opencompass_dataset() + remaining.discard("aime_2025_original") if "amc_2022" in dataset_names: # TODO: AMC 2022 is 43 problems, is that to be expected? datasets += _load_amc_dataset(2022, upsample_factor=16) + remaining.discard("amc_2022") if "amc_2022_original" in dataset_names: datasets += _load_amc_dataset(2022) + remaining.discard("amc_2022_original") if "amc_2023" in dataset_names: datasets += _load_amc_dataset(2023, upsample_factor=16) + remaining.discard("amc_2023") if "amc_2023_original" in dataset_names: datasets += _load_amc_dataset(2023) + remaining.discard("amc_2023_original") if "sometimes_success_data" in dataset_names: PATH = "data/sometimes_success_data/data.jsonl" @@ -400,6 +482,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [json.loads(line) for line in f] logger.info(f"Loading easy data dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("sometimes_success_data") if "open_reasoner_zero_57k" in dataset_names: dataset = load_dataset( @@ -411,6 +494,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_57k") if s is not None] logger.info(f"Loading Open Reasoner Zero dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_reasoner_zero_57k") if "open_reasoner_zero_extended_72k" in dataset_names: dataset = load_dataset( @@ -422,6 +506,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_extended_72k") if s is not None] logger.info(f"Loading Open Reasoner Zero extended dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_reasoner_zero_extended_72k") if "open_reasoner_zero_hard_13k" in dataset_names: dataset = load_dataset( @@ -433,6 +518,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_hard_13k") if s is not None] logger.info(f"Loading Open Reasoner Zero hard dataset: {len(samples)} samples") datasets += add_ids(samples) + 
remaining.discard("open_reasoner_zero_hard_13k") for dataset_name in dataset_names: test_matched = re.match(r"multiplication_(\d+)_by_(\d+)_(\d+)_test", dataset_name) @@ -453,6 +539,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None ] logger.info(f"Loading multiplication {num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard(dataset_name) elif train_matched: upto_prefix = train_matched.group(1) or "" num_digits_1 = int(train_matched.group(2)) @@ -474,6 +561,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None f"Loading multiplication {upto_prefix}_{num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples" ) datasets += add_ids(samples) + remaining.discard(dataset_name) if "countdown" in dataset_names: dataset = load_dataset( @@ -482,6 +570,19 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_countdown(dataset) if s is not None] logger.info(f"Loading countdown dataset: {len(samples)} samples") datasets += samples + remaining.discard("countdown") + + # resolve any remaining names as local custom datasets. + unresolved: List[str] = [] + for dataset_name in list(remaining): + try: + datasets += _load_custom_dataset(dataset_name) + remaining.discard(dataset_name) + except FileNotFoundError: + unresolved.append(dataset_name) + + if unresolved: + raise ValueError(f"Unknown dataset(s): {unresolved}") if len(datasets) == 0: raise ValueError("No datasets loaded") From 3e17f755387a0dd3e1957432c561589c102ff32b Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 4 Nov 2025 12:32:04 +0000 Subject: [PATCH 15/20] Align verifier to external benchmarks --- pipelinerl/domains/math/verifier_api.py | 352 +++++++++++++++++------- 1 file changed, 247 insertions(+), 105 deletions(-) diff --git a/pipelinerl/domains/math/verifier_api.py b/pipelinerl/domains/math/verifier_api.py index 59f66dfc..84548f15 100644 --- a/pipelinerl/domains/math/verifier_api.py +++ b/pipelinerl/domains/math/verifier_api.py @@ -1,29 +1,25 @@ -import time -import requests import asyncio -import re -from concurrent.futures import ProcessPoolExecutor -import aiohttp -import uvicorn import logging +import re import signal +import time +from concurrent.futures import ProcessPoolExecutor from contextlib import contextmanager +from functools import partial -from omegaconf import DictConfig -import math_verify # Ensure math_verify is installed - +import aiohttp +import math_verify +import requests # noqa: F401 - retained for parity with upstream +import uvicorn from fastapi import FastAPI from fastapi.responses import JSONResponse -from functools import partial import pipelinerl.countdown_utils logging.basicConfig( - level=logging.DEBUG, # Or INFO, WARNING, etc. 
+ level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[ - logging.StreamHandler(), # Logs to console - ], + handlers=[logging.StreamHandler()], ) @@ -47,105 +43,257 @@ class EmptyBoxedException(Exception): @contextmanager -def timeout(seconds=1): +def timeout(seconds: int = 1): def timeout_handler(signum, frame): raise TimeoutException("Computation timed out") - # Set the timeout handler original_handler = signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(seconds) try: - yield # This is the key addition - context managers must yield + yield finally: - # Restore the original handler and disable the alarm signal.alarm(0) signal.signal(signal.SIGALRM, original_handler) -DELIMITER_STR = re.compile(r"\[END FINAL RESPONSE\]", flags=re.IGNORECASE) +ANSWER_PREFIX_RE = re.compile( + r"^(final answer|answer|ans\.?|thus.*?is|therefore.*?is|so the answer is)[:=\-\s]*", + re.IGNORECASE, +) -def strip_delimiter_strings(text: str) -> str: - if not text: - return text - stripped = DELIMITER_STR.sub("", text) - # Remove lines that became empty after sentinel stripping to avoid parsing noise - cleaned_lines = [line for line in stripped.splitlines() if line.strip()] - return "\n".join(cleaned_lines) +def _strip_answer_prefix(line: str) -> str: + return ANSWER_PREFIX_RE.sub("", line).strip() -def verify_answer(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: - """ - Checks if a predicted answer matches a gold (correct) answer by making a request to the math_verify package. - - Args: - prediction (str): The predicted answer to validate - gold (str): The gold (correct) answer to compare against - strict (bool): Whether to enforce strict comparison mode. - - In strict mode: Variables matter and sets are not comparable with tuples - - In non-strict mode: Variables are matched by position and sets can be compared with tuples - url (str): URL of the validation service endpoint - - Returns: - str: The status of the answer, which can be one of the following: - - "correct": The prediction is correct - - "wrong": The prediction is incorrect - - "no_answer": The prediction is empty - - "unparsable": The prediction cannot be parsed - - """ - if prediction.startswith("countdown"): - return verify_countdown(prediction, gold) - else: - return verify_math(prediction, gold, strict=strict, max_prediction_length=max_prediction_length) +def _extract_fallback_expression(text: str) -> str | None: + if not text: + return None + for raw_line in reversed(text.strip().splitlines()): + cleaned = _strip_answer_prefix(raw_line.strip()).rstrip(".;!") + if cleaned and (any(ch.isdigit() for ch in cleaned) or "\\" in cleaned): + return cleaned + return None + + +def remove_boxed(s: str) -> str: + if "\\boxed " in s: + left = "\\boxed " + if not s.startswith(left): + raise UnparsableException() + return s[len(left) :] + + left = "\\boxed{" + if not s.startswith(left) or not s.endswith("}"): + raise UnparsableException() + return s[len(left) : -1] + + +def last_boxed_only_string(text: str) -> str | None: + idx = text.rfind("\\boxed") + if "\\boxed " in text: + return "\\boxed " + text.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = text.rfind("\\fbox") + if idx < 0: + return None + + right_brace_idx = None + num_left_braces_open = 0 + i = idx + while i < len(text): + if text[i] == "{": + num_left_braces_open += 1 + if text[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + if 
right_brace_idx is None: + return None + return text[idx : right_brace_idx + 1] + + +def fix_fracs(string: str) -> str: + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + if len(substr) < 2: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + return new_str + + +def fix_a_slash_b(string: str) -> str: + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + return f"\\frac{{{a}}}{{{b}}}" + except (AssertionError, ValueError): + return string + + +def remove_right_units(string: str) -> str: + if "\\text{ " in string: + splits = string.split("\\text{ ") + if len(splits) == 2: + return splits[0] + return string + + +def fix_sqrt(string: str) -> str: + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split and split[0] != "{": + a = split[0] + new_substr = f"\\sqrt{{{a}}}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def strip_string(string: str) -> str: + string = string.replace("\n", "").replace("\\!", "").replace("\\\\", "\\") + string = string.replace("tfrac", "frac").replace("dfrac", "frac") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("^{\\circ}", "").replace("^\\circ", "") + string = string.replace("\\$", "") + string = remove_right_units(string) + string = string.replace("\\%", "") + string = string.replace(" .", " 0.").replace("{.", "{0.") + if not string: + return string + if string[0] == ".": + string = "0" + string + if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + string = fix_sqrt(string) + string = string.replace(" ", "") + string = fix_fracs(string) + if string == "0.5": + string = "\\frac{1}{2}" + string = fix_a_slash_b(string) + return string + + +def is_equiv(str1: str, str2: str) -> bool: + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + return ss1 == ss2 + except Exception: + return str1 == str2 def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: try: - # Input Sanitization / Validation if not isinstance(prediction, str) or not isinstance(gold, str): raise ValueError("Prediction and gold must be strings") - prediction = strip_delimiter_strings(prediction) + prediction = prediction.strip() + if not prediction: + raise NoAnswerException() - # Try extracting from \boxed{...} first - boxed_start = prediction.rfind("\\boxed{") + extracted_prediction: str | None = None - if boxed_start >= 0: - boxed_prediction = prediction[boxed_start:] - if "\\boxed{}" in boxed_prediction: - raise EmptyBoxedException() - if len(boxed_prediction) > max_prediction_length: - raise UnparsableException() - extracted_prediction = boxed_prediction - else: - # Fallback: look for ... 
tags + boxed_prediction = last_boxed_only_string(prediction) + if boxed_prediction is not None: + try: + extracted_prediction = remove_boxed(boxed_prediction).strip() + except UnparsableException as exc: + logger.debug("Failed to remove boxed expression", exc_info=exc) + extracted_prediction = None + + if not extracted_prediction: answer_match = re.findall(r"(.*?)", prediction, re.DOTALL) if answer_match: - extracted_prediction = strip_delimiter_strings(answer_match[-1].strip()) # last one + extracted_prediction = answer_match[-1].strip() else: - raise NoAnswerException() + fallback_expression = _extract_fallback_expression(prediction) + if fallback_expression: + extracted_prediction = fallback_expression.strip() + else: + raise NoAnswerException() + + if not extracted_prediction: + raise EmptyBoxedException() + + if 0 < max_prediction_length < len(extracted_prediction): + raise UnparsableException() + + if is_equiv(gold, extracted_prediction): + return "correct" + + try: + target_boxed = last_boxed_only_string(f"\\boxed{{{gold}}}") or f"\\boxed{{{gold}}}" + pred_boxed = last_boxed_only_string(f"\\boxed{{{extracted_prediction}}}") or f"\\boxed{{{extracted_prediction}}}" + gold_parsed = math_verify.parse(target_boxed) + pred_parsed = math_verify.parse(pred_boxed) + except Exception as parse_exc: + logger.debug("math_verify.parse failed", exc_info=parse_exc) + raise UnparsableException() from parse_exc - # Parse and verify - gold_parsed = math_verify.parse(gold) - pred_parsed = math_verify.parse(extracted_prediction) if not pred_parsed: - raise ValueError("Failed to parse prediction.") + raise UnparsableException("Prediction parsed to empty result.") - with timeout(1): - equivalent = math_verify.verify(gold_parsed, pred_parsed, strict=strict, timeout_seconds=1) + try: + with timeout(1): + equivalent = math_verify.verify(gold_parsed, pred_parsed, strict=strict, timeout_seconds=1) + except TimeoutException as timeout_exc: + logger.debug("math_verify.verify timed out; treating as wrong", exc_info=timeout_exc) + return "wrong" + except (ValueError, TypeError, NotImplementedError) as verify_exc: + logger.debug("math_verify.verify raised recoverable error; treating as wrong", exc_info=verify_exc) + return "wrong" + except Exception as verify_exc: + logger.debug("math_verify.verify failed unexpectedly", exc_info=verify_exc) + raise - answer_status = "correct" if equivalent else "wrong" + return "correct" if equivalent else "wrong" - except Exception as e: - match e: + except Exception as error: + match error: case NoAnswerException(): answer_status = "no_answer" + case (EmptyBoxedException() | UnparsableException()): + answer_status = "unparsable" case _: + logger.debug("Unexpected verifier error", exc_info=error) answer_status = "unparsable" - - return answer_status - + return answer_status def verify_countdown(prediction: str, gold: str) -> str: @@ -157,28 +305,35 @@ def verify_countdown(prediction: str, gold: str) -> str: if equation is None: return "no_answer" - format_correct = pipelinerl.countdown_utils.validate_format(prediction) - if not format_correct: + if not pipelinerl.countdown_utils.validate_format(prediction): return "unparsable" - # Validate equation uses correct numbers if not pipelinerl.countdown_utils.validate_equation(equation, numbers): return "wrong" - # Evaluate equation try: result = pipelinerl.countdown_utils.evaluate_equation(equation) if result is None: return "wrong" - - if abs(result - target) < 1e-5: # Account for floating point precision - return "correct" - else: - 
return "wrong" - except Exception as _: + return "correct" if abs(result - target) < 1e-5 else "wrong" + except Exception: return "wrong" +def verify_answer(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: + try: + if prediction.startswith("countdown"): + return verify_countdown(prediction, gold) + return verify_math(prediction, gold, strict=strict, max_prediction_length=max_prediction_length) + except NoAnswerException: + return "no_answer" + except UnparsableException: + return "unparsable" + except Exception as exc: + logger.debug("verify_answer unexpected failure", exc_info=exc) + return "unparsable" + + async def verify_answer_rpc( session: aiohttp.ClientSession, host: str, @@ -188,36 +343,24 @@ async def verify_answer_rpc( strict: bool = True, max_prediction_length: int = 1000, ): - """ - Verify the answer using the verifier API. - """ - json = { + payload = { "prediction": prediction, "gold": gold, "strict": strict, "max_prediction_length": max_prediction_length, } - async with session.post( - f"http://{host}:{port}/verify_answer", - json=json, - ) as response: + async with session.post(f"http://{host}:{port}/verify_answer", json=payload) as response: if response.status == 200: data = await response.json() return data["answer_status"] - else: - logger.error(f"Error verifying answer: {response.status}") - logger.error(f"Response: {await response.text()}") - raise ValueError("Error verifying answer") + logger.error("Error verifying answer: %s", response.status) + logger.error("Response: %s", await response.text()) + raise ValueError("Error verifying answer") class MathEnvironment: - def launch(self, port: int): - """ - Serve the verification API using FastAPI. - """ app = FastAPI() - # Create a process pool with 4 workers with ProcessPoolExecutor(max_workers=4) as process_pool: @app.post("/verify_answer") async def verify(request: dict): @@ -226,7 +369,6 @@ async def verify(request: dict): strict = request["strict"] max_prediction_length = request["max_prediction_length"] - # Run verification in the process pool to avoid blocking the main thread loop = asyncio.get_event_loop() answer_status = await loop.run_in_executor( process_pool, partial(verify_answer, prediction, gold, strict, max_prediction_length) From 3534f3caefcf62cb6c1e9c02d741d4f165690262 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 4 Nov 2025 12:35:36 +0000 Subject: [PATCH 16/20] Add helpers for rollout tracking --- pipelinerl/preprocess_helpers.py | 34 ++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 pipelinerl/preprocess_helpers.py diff --git a/pipelinerl/preprocess_helpers.py b/pipelinerl/preprocess_helpers.py new file mode 100644 index 00000000..88ee81a2 --- /dev/null +++ b/pipelinerl/preprocess_helpers.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Iterable + + +def group_rollout_idx(group: Iterable[dict]) -> set[int] | None: + """Extract rollout idx from a rollout group.""" + rollout_indices: set[int] = set() + for text in group: + metadata = text.get("metadata") + if not isinstance(metadata, dict): + return None + rollout_index = metadata.get("rollout_index") + if rollout_index is None: + return None + rollout_indices.add(rollout_index) + return rollout_indices + + +def validate_rollout_group(group: Iterable[dict], group_size: int) -> tuple[bool, list[int], list[int]]: + """Return whether a group is complete and any missing or extra rollout indices.""" + rollout_indices = group_rollout_idx(group) + if 
rollout_indices is None: + return False, [], [] + if len(rollout_indices) != group_size: + expected_indices = set(range(group_size)) + if rollout_indices.issubset(expected_indices): + missing = sorted(expected_indices - rollout_indices) + extra: list[int] = [] + else: + missing = sorted(expected_indices - rollout_indices) + extra = sorted(rollout_indices - expected_indices) + return False, missing, extra + return True, [], [] From 8d46aab0f04e2ea502d8463cd9d21d5c1759e435 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 4 Nov 2025 12:36:07 +0000 Subject: [PATCH 17/20] Remove group if rollout is skipped --- pipelinerl/preprocess.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 9a1e4773..8a57c567 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -22,6 +22,7 @@ from litellm import BaseModel, Field from pipelinerl.finetune.logging_ import flatten_dict_config +from pipelinerl.preprocess_helpers import group_rollout_idx, validate_rollout_group from pipelinerl.shared_memory_array import SharedMemoryArray, SharedMemoryQueue from pipelinerl.state import TrainerState from pipelinerl.utils import setup_logging, wait_for_inference_servers, init_wandb @@ -196,10 +197,27 @@ def run_dataset_loader( buffer = [] n_groups = 0 for group in reader.read(): + if not group: + continue + is_complete, missing, extra = validate_rollout_group(group, check_group_size) + if not is_complete: + group_name = group[0].get("group_id") if group else "" + if not missing and not extra: + logger.warning("Skipping group %s without rollout metadata", group_name) + else: + logger.warning( + "Skipping incomplete group %s: missing rollouts %s extra %s", + group_name, + missing, + extra, + ) + continue buffer.extend(group) n_groups += 1 if n_groups == chunk_n_groups: break + if not buffer: + continue if not _check_group_sizes(buffer, check_group_size): raise ValueError("Invalid group sizes in data") try: From 468deea37ca6d8b85770feaa4026d5bf9c1503f6 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 4 Nov 2025 12:38:12 +0000 Subject: [PATCH 18/20] Skip rollout if timeout --- pipelinerl/domains/math/rollouts.py | 37 ++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 1f70c11d..09c879a6 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import random @@ -182,7 +183,16 @@ async def generate_math_rollout( prompt = Prompt(messages=messages) time_start = time.time() - llm_call = await llm_async_generate(llm, prompt, session) + try: + llm_call = await llm_async_generate(llm, prompt, session) + except (asyncio.TimeoutError, aiohttp.client_exceptions.ServerTimeoutError) as exc: + latency = time.time() - time_start + logger.warning( + "LLM request timed out for problem %s. 
Skipping sample.", + problem.get("id"), + exc_info=exc, + ) + return create_timeout_rollout_result(cfg, problem, latency) latency = time.time() - time_start assert llm_call.output.content is not None @@ -281,3 +291,28 @@ async def generate_math_rollout( dataset_name=problem.get("dataset"), answer_status=answer_status, ) + + +def create_timeout_rollout_result( + cfg: DictConfig, + problem: dict, + latency: float, +) -> RolloutResult: + answer_status = "timeout" + metrics = Metrics( + reward=0.0, + success=False, + no_error=False, + no_answer=True, + penalty=0.0, + overflow=False, + auto_boxed=False, + ) + log_answer_status(cfg, problem, answer_status, metrics.reward, latency) + return RolloutResult( + training_texts=[], + metrics=metrics, + latency=latency, + dataset_name=problem.get("dataset"), + answer_status=answer_status, + ) From 57e5264ea7f7535dbc179fda0259c748f0090f6d Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 4 Nov 2025 12:46:06 +0000 Subject: [PATCH 19/20] Allow to continue training during eval --- conf/base.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conf/base.yaml b/conf/base.yaml index 6be2b593..c29092f5 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -15,6 +15,7 @@ actor: llm_max_rollouts: 64 rollout_workers: 1 discount_factor: 1 + pause_training_during_eval: true problem_queue_size: 64 result_queue_size: 64 throughput_window_size: 50 @@ -140,4 +141,3 @@ wandb: wandb_dir: null # Comma-separated list of keywords to tag the run. tags: [] - From 5892fd36ad7c1231aa4e58d03c9563f84e02f826 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 4 Nov 2025 12:46:34 +0000 Subject: [PATCH 20/20] remove test code --- pipelinerl/domains/math/minimal_rollout.py | 72 ---------------------- 1 file changed, 72 deletions(-) delete mode 100644 pipelinerl/domains/math/minimal_rollout.py diff --git a/pipelinerl/domains/math/minimal_rollout.py b/pipelinerl/domains/math/minimal_rollout.py deleted file mode 100644 index bc46d2a1..00000000 --- a/pipelinerl/domains/math/minimal_rollout.py +++ /dev/null @@ -1,72 +0,0 @@ -import time -import random - -import aiohttp -from omegaconf import DictConfig -from pydantic import BaseModel -from pipelinerl.rollouts import RolloutResult, BaseMetrics -from pipelinerl.world import Job -from tapeagents.core import Prompt -from tapeagents.llms.trainable import TrainableLLM - -from pipelinerl.async_llm import llm_async_generate, make_training_text -from .verifier_api import verify_answer_rpc - -class Metrics(BaseMetrics): - pass - -class RewardTable(BaseModel): - wrong_answer_not_finished: float - wrong_answer_finished: float - no_answer_not_finished: float - no_answer_finished: float - unparsable_not_finished: float - unparsable_finished: float - correct_answer_not_finished: float - correct_answer_finished: float - buffer_tokens: int = 0 # 0 means no overlong reward shaping - -def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float: - """ - Compute the overlong penalty - """ - if sequence_length > (max_length - buffer_tokens) and sequence_length <= max_length: - return ((max_length - buffer_tokens) - sequence_length) / buffer_tokens - return 0. 
- -def get_reward(trace, answer_status: str, rewards: RewardTable) -> float: - pass - - -async def generate_math_rollout( - cfg: DictConfig, - llm: TrainableLLM, - problem: dict, - session: aiohttp.ClientSession, -) -> RolloutResult: - messages = [] - if cfg.actor.system_prompt: - messages.append({"role": "system", "content": cfg.actor.system_prompt}) - messages.append({"role": "user", "content": f"{problem['task']} \n{cfg.actor.task_prompt}"}) - prompt = Prompt(messages=messages) - - time_start = time.time() - llm_call = await llm_async_generate(llm, prompt, session) - latency = time.time() - time_start - - assert llm_call.output.content is not None - rewards = RewardTable(**dict(cfg.rewards)) - - env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"] - env_job = random.choice(env_jobs) - assert env_job.port is not None - answer_status = await verify_answer_rpc(session=session, host=env_job.hostname, port=env_job.port, prediction=llm_call.output.content, gold=problem["answer"]) - - trace = make_training_text(llm, llm_call) - reward = get_reward(trace, answer_status, rewards) - trace.reward = reward - - metrics = Metrics(reward=reward, success=answer_status == "correct", no_error=answer_status != "unparsable", no_answer=answer_status == "no_answer") - - - return RolloutResult(training_texts=[trace], metrics=metrics, latency=latency, dataset_name=problem.get("dataset"))
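
# --- Illustrative usage sketch (not part of the patch series) ---
# Shows how the rollout-group helpers introduced in pipelinerl/preprocess_helpers.py
# (PATCH 16/20) are expected to behave. The group dicts below are hypothetical
# stand-ins for serialized training texts; only metadata["rollout_index"] matters here.

from pipelinerl.preprocess_helpers import group_rollout_idx, validate_rollout_group

incomplete_group = [
    {"group_id": "g0", "metadata": {"rollout_index": 0}},
    {"group_id": "g0", "metadata": {"rollout_index": 2}},
]
print(group_rollout_idx(incomplete_group))                     # {0, 2}
print(validate_rollout_group(incomplete_group, group_size=4))  # (False, [1, 3], [])

complete_group = [{"metadata": {"rollout_index": i}} for i in range(4)]
print(validate_rollout_group(complete_group, group_size=4))    # (True, [], [])

# A group whose texts lack rollout metadata is reported as incomplete with no
# missing/extra indices, which run_dataset_loader (PATCH 17/20) logs as a group
# "without rollout metadata" and skips.
print(validate_rollout_group([{"group_id": "g1"}], group_size=4))  # (False, [], [])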