2 changes: 1 addition & 1 deletion conf/base.yaml
@@ -15,6 +15,7 @@ actor:
llm_max_rollouts: 64
rollout_workers: 1
discount_factor: 1
pause_training_during_eval: true
problem_queue_size: 64
result_queue_size: 64
throughput_window_size: 50
@@ -140,4 +141,3 @@ wandb:
wandb_dir: null
# Comma-separated list of keywords to tag the run.
tags: []

368 changes: 332 additions & 36 deletions pipelinerl/actor.py

Large diffs are not rendered by default.

143 changes: 122 additions & 21 deletions pipelinerl/domains/math/load_datasets.py
@@ -2,7 +2,8 @@
import logging
import random
import re
from typing import Dict, List, Tuple
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple

import datasets
import hydra
@@ -190,26 +191,6 @@ def _load_aime_dataset(year: int, upsample_factor: int = 0) -> list[dict]:
return add_ids(samples)


def _load_aime_2025_opencompass(upsample_factor: int = 0) -> list[dict]:
configs = ["AIME2025-I", "AIME2025-II"]
dataset_name = "aime_2025" + ("" if upsample_factor > 0 else "_original")

samples: list[dict] = []
for config_name in configs:
ds = load_dataset("opencompass/AIME2025", config_name, split="test")
samples.extend([s for s in process_math(ds, dataset_name) if s is not None])

original_size = len(samples)
if upsample_factor > 0:
samples *= upsample_factor

logger.info(
f"Loading aime 2025 (OpenCompass) dataset: {len(samples)} samples"
+ (f" (upsampled from {original_size})" if upsample_factor > 0 else "")
)
return add_ids(samples)


def _load_amc_dataset(year: int, upsample_factor: int = 0) -> list[dict]:
amc_dataset = load_dataset("AI-MO/aimo-validation-amc", split="train", trust_remote_code=True)
amc_dataset = amc_dataset.filter(lambda x: str(year) in x["url"])
@@ -234,32 +215,109 @@ def add_ids(dataset: list[dict]):
return dataset


def _resolve_custom_path(relative_paths: str | Sequence[str]) -> Path:
"""
Resolve a path for locally generated datasets.

Hydra jobs may change the working directory, so we check both the current
directory and the repository root.
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]

resolved = Path(__file__).resolve()
base_candidates = [Path.cwd()]
if len(resolved.parents) >= 5:
base_candidates.append(resolved.parents[4])

candidates: List[Path] = []
for rel in relative_paths:
rel_path = Path(rel)
candidates.append(rel_path)
for base in base_candidates:
if base == Path.cwd():
continue
candidates.append(base / rel_path)

for candidate in candidates:
if candidate.exists():
return candidate
raise FileNotFoundError(
f"Custom dataset not found. Tried: {[str(path) for path in candidates]}"
)
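For context only (not part of this PR): the resolution order above tries each relative path against the current working directory first and then against the inferred repository root. A minimal sketch, assuming the module is importable as `pipelinerl.domains.math.load_datasets`; the file name is hypothetical.

# Illustrative only; the path below is made up.
from pipelinerl.domains.math.load_datasets import _resolve_custom_path

# Checks "datasets/custom/my_run.jsonl" relative to the CWD, then under the
# repository root, and raises FileNotFoundError listing all candidates if none exist.
path = _resolve_custom_path("datasets/custom/my_run.jsonl")
print(path)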


def _load_custom_dataset(dataset_name: str) -> list[dict]:
"""
Load a locally generated dataset by name.

The loader searches under `datasets/custom/` and `datasets/custom_runs/` for either
`<dataset_name>` or `<dataset_name>.jsonl`, and also tries the name itself as a path.
"""
candidate_names: List[str] = []
if dataset_name.endswith(".jsonl"):
candidate_names.append(dataset_name)
else:
candidate_names.extend([dataset_name, f"{dataset_name}.jsonl"])

search_paths: List[str] = []
for name in candidate_names:
search_paths.extend(
[
f"datasets/custom/{name}",
f"datasets/custom_runs/{name}",
name,
]
)

dataset_path = _resolve_custom_path(search_paths)
with dataset_path.open("r", encoding="utf-8") as handle:
samples = [json.loads(line) for line in handle if line.strip()]

dataset_label = dataset_name[:-6] if dataset_name.endswith(".jsonl") else dataset_name

for idx, sample in enumerate(samples):
sample.setdefault("source_dataset", sample.get("dataset", dataset_label))
sample.setdefault("source_id", sample.get("id"))
sample["dataset"] = dataset_label
sample["id"] = idx

logger.info(f"Loading custom dataset {dataset_name}: {len(samples)} samples from {dataset_path}")
return samples
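
For illustration (not part of this PR): a custom dataset is a JSONL file with one JSON object per line; any keys beyond `dataset` and `id` are whatever the generating script wrote, so the field names below are hypothetical. A sketch of writing such a file and what the loader then does with it:

import json
from pathlib import Path

# Hypothetical toy dataset; real custom datasets are produced by other tooling.
out = Path("datasets/custom/toy_addition.jsonl")
out.parent.mkdir(parents=True, exist_ok=True)
with out.open("w", encoding="utf-8") as handle:
    for a, b in [(2, 3), (7, 8)]:
        handle.write(json.dumps({"task": f"{a}+{b}", "answer": str(a + b)}) + "\n")

# _load_custom_dataset("toy_addition") would return two samples, each with
# dataset="toy_addition", a sequential integer id, and source_dataset/source_id
# filled from the original record when present.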


def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None) -> List[Tuple[str, Dict]]:
if dataset_names is None:
return []

if isinstance(dataset_names, str):
dataset_names = [dataset_names]
# Preserve order while de-duplicating
dataset_names = list(dict.fromkeys(dataset_names))
datasets = []
remaining = set(dataset_names)
if "eurus_train" in dataset_names:
dataset = load_dataset("PRIME-RL/Eurus-2-RL-Data", split="train", trust_remote_code=True)
samples = [s for s in process_eurus(dataset) if s is not None]
logger.info(f"Loading eurus train dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("eurus_train")

# great for debugging since it's much smaller than eurus train
if "eurus_validation" in dataset_names:
dataset = load_dataset("PRIME-RL/Eurus-2-RL-Data", split="validation", trust_remote_code=True)
samples = [s for s in process_eurus(dataset) if s is not None]
logger.info(f"Loading eurus validation dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("eurus_validation")

if "math_train" in dataset_names:
# math_dataset = load_math("train")
dataset = load_dataset("hendrycks/competition_math", split="train", trust_remote_code=True)
samples = [s for s in process_math(dataset, "math_train") if s is not None]
logger.info(f"Loading math train dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("math_train")

if "math_simplerl_train" in dataset_names:
# SimpleRL MATH dataset
@@ -274,6 +332,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_math(dataset, "math_simplerl_train") if s is not None]
logger.info(f"Loading math simplerl train dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("math_simplerl_train")

if "simplerl_math_subset_1000" in dataset_names:
# SimpleRL MATH dataset subset
@@ -292,49 +351,57 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = samples[:1000]
logger.info(f"Loading math simplerl subset test dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("simplerl_math_subset_1000")

if "deepscaler_preview" in dataset_names:
dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", trust_remote_code=True)
samples = [s for s in process_math(dataset, "deepscaler") if s is not None]
logger.info(f"Loading deepscaler preview train dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("deepscaler_preview")

if "math_test" in dataset_names:
# math_dataset = load_math("test")
dataset = load_dataset("hendrycks/competition_math", split="test", trust_remote_code=True)
samples = [s for s in process_math(dataset, "math_test") if s is not None]
logger.info(f"Loading math test dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("math_test")

if "omni_math_500" in dataset_names:
dataset = load_dataset("reliable-agents/Omni-MATH-500", split="test", trust_remote_code=True)
samples = [s for s in process_math(dataset, "omni_math_500") if s is not None]
logger.info(f"Loading omni math 500 dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("omni_math_500")

if "math_500" in dataset_names:
dataset = load_dataset("HuggingFaceH4/MATH-500", split="test", trust_remote_code=True)
samples = [s for s in process_math(dataset, "math_500") if s is not None]
logger.info(f"Loading math 500 dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("math_500")

if "open_r1_math_220k" in dataset_names:
dataset = load_dataset("open-r1/OpenR1-Math-220k", split="default", trust_remote_code=True)
samples = [s for s in process_math(dataset, "open_r1_math_220k") if s is not None]
logger.info(f"Loading open r1 math 220k dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("open_r1_math_220k")

if "gpqa_main" in dataset_names:
dataset = load_dataset("hendrydong/gpqa_main", split="test", trust_remote_code=True)
samples = [s for s in process_gpqa(dataset, "gpqa_main") if s is not None]
logger.info(f"Loading gpqa main dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("gpqa_main")

if "gpqa_diamond" in dataset_names:
dataset = load_dataset("hendrydong/gpqa_diamond", split="test", trust_remote_code=True)
samples = [s for s in process_gpqa(dataset, "gpqa_diamond") if s is not None]
logger.info(f"Loading gpqa diamond dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("gpqa_diamond")

if "gpqa_diamond" in dataset_names:
pass
@@ -344,62 +411,78 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_gsm8k(dataset, "gsm8k_train") if s is not None]
logger.info(f"Loading gsm8k train dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("gsm8k_train")

if "gsm8k_test" in dataset_names:
dataset = load_dataset("openai/gsm8k", "main", split="test", trust_remote_code=True)
samples = [s for s in process_gsm8k(dataset, "gsm8k_test") if s is not None]
logger.info(f"Loading gsm8k test dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("gsm8k_test")

if "limo" in dataset_names:
dataset = load_dataset("GAIR/LIMO", split="train", trust_remote_code=True)
samples = [s for s in process_limo(dataset) if s is not None]
logger.info(f"Loading limo dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("limo")

if "aime_2022" in dataset_names:
datasets += _load_aime_dataset(2022, upsample_factor=16)
remaining.discard("aime_2022")

if "aime_2022_original" in dataset_names:
datasets += _load_aime_dataset(2022)
remaining.discard("aime_2022_original")

if "aime_2023" in dataset_names:
datasets += _load_aime_dataset(2023, upsample_factor=16)
remaining.discard("aime_2023")

if "aime_2023_original" in dataset_names:
datasets += _load_aime_dataset(2023)
remaining.discard("aime_2023_original")

if "aime_2024" in dataset_names:
datasets += _load_aime_dataset(2024, upsample_factor=16)
remaining.discard("aime_2024")

if "aime_2024_original" in dataset_names:
datasets += _load_aime_dataset(2024)
remaining.discard("aime_2024_original")

if "aime_2025" in dataset_names:
datasets += _load_aime_2025_opencompass_dataset(upsample_factor=16)
remaining.discard("aime_2025")

if "aime_2025_original" in dataset_names:
datasets += _load_aime_2025_opencompass_dataset()
remaining.discard("aime_2025_original")

if "amc_2022" in dataset_names:
# TODO: AMC 2022 is 43 problems, is that to be expected?
datasets += _load_amc_dataset(2022, upsample_factor=16)
remaining.discard("amc_2022")

if "amc_2022_original" in dataset_names:
datasets += _load_amc_dataset(2022)
remaining.discard("amc_2022_original")

if "amc_2023" in dataset_names:
datasets += _load_amc_dataset(2023, upsample_factor=16)
remaining.discard("amc_2023")

if "amc_2023_original" in dataset_names:
datasets += _load_amc_dataset(2023)
remaining.discard("amc_2023_original")

if "sometimes_success_data" in dataset_names:
PATH = "data/sometimes_success_data/data.jsonl"
with open(PATH, "r") as f:
samples = [json.loads(line) for line in f]
logger.info(f"Loading easy data dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("sometimes_success_data")

if "open_reasoner_zero_57k" in dataset_names:
dataset = load_dataset(
@@ -411,6 +494,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_57k") if s is not None]
logger.info(f"Loading Open Reasoner Zero dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("open_reasoner_zero_57k")

if "open_reasoner_zero_extended_72k" in dataset_names:
dataset = load_dataset(
@@ -422,6 +506,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_extended_72k") if s is not None]
logger.info(f"Loading Open Reasoner Zero extended dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("open_reasoner_zero_extended_72k")

if "open_reasoner_zero_hard_13k" in dataset_names:
dataset = load_dataset(
@@ -433,6 +518,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_hard_13k") if s is not None]
logger.info(f"Loading Open Reasoner Zero hard dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard("open_reasoner_zero_hard_13k")

for dataset_name in dataset_names:
test_matched = re.match(r"multiplication_(\d+)_by_(\d+)_(\d+)_test", dataset_name)
@@ -453,6 +539,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
]
logger.info(f"Loading multiplication {num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples")
datasets += add_ids(samples)
remaining.discard(dataset_name)
elif train_matched:
upto_prefix = train_matched.group(1) or ""
num_digits_1 = int(train_matched.group(2))
@@ -474,6 +561,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
f"Loading multiplication {upto_prefix}_{num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples"
)
datasets += add_ids(samples)
remaining.discard(dataset_name)

if "countdown" in dataset_names:
dataset = load_dataset(
@@ -482,6 +570,19 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_countdown(dataset) if s is not None]
logger.info(f"Loading countdown dataset: {len(samples)} samples")
datasets += samples
remaining.discard("countdown")

# Resolve any remaining names as local custom datasets.
unresolved: List[str] = []
for dataset_name in list(remaining):
try:
datasets += _load_custom_dataset(dataset_name)
remaining.discard(dataset_name)
except FileNotFoundError:
unresolved.append(dataset_name)

if unresolved:
raise ValueError(f"Unknown dataset(s): {unresolved}")

if len(datasets) == 0:
raise ValueError("No datasets loaded")
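
Taken together, a usage sketch (not part of this PR; the custom name is hypothetical and assumes a matching file under `datasets/custom/`): built-in names are handled by their dedicated branches, and anything still in `remaining` falls through to `_load_custom_dataset`, so built-in and custom datasets can be mixed in one list.

from pipelinerl.domains.math.load_datasets import load_datasets

# "math_500" is loaded from the Hugging Face Hub; "toy_addition" is resolved as
# a local custom JSONL. A leftover name with no matching file raises
# ValueError("Unknown dataset(s): ...").
samples = load_datasets(["math_500", "toy_addition"], seed=0)
print(len(samples))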