Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions recipe/dapo/config/dapo_fsdp_config_with_resampling.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
hydra:
searchpath:
- file://verl/trainer/config

defaults:
- ppo_trainer
- _self_

# parameters added to enable PassRateWeightedSampler with DAPO; override parameters in verl/trainer/config/data/legacy_data.yaml
data:
gen_batch_size: ${data.train_batch_size}
dataloader_num_workers: 0 # Recommended to set to 0 when using curriculum learning samplers (e.g., PassRateWeightedSampler) to prevent data caching before batches are reordered.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious to know if it works when it is non-zero

sampler:
pass_rate_temperature: 1.0 # temperature parameter for PassRateWeightedSampler, controls sharpness of weighting distribution
use_ema: False # whether to use EMA smoothed pass rates for weighting
ema_alpha: 0.1 # alpha parameter for EMA smoothing of pass rates

reward_model:
reward_manager: dapo
overlong_buffer:
enable: False # We try to avoid forgetting to set enable
len: 0
penalty_factor: 0.0
log: False

algorithm:
filter_groups:
_target_: verl.trainer.config.FilterGroupsConfig
enable: False # We try to avoid forgetting to set enable
metric: null # acc / score / seq_reward / seq_final_reward / ...
max_num_gen_batches: 0 # Non-positive values mean no upper limit


83 changes: 77 additions & 6 deletions recipe/dapo/dapo_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from verl.utils.profiler import marked_timer
from verl.utils.rollout_skip import RolloutSkip

from verl.utils.pass_rate_weighted_sampler import PassRateWeightedSampler

class RayDAPOTrainer(RayPPOTrainer):
"""
Expand All @@ -68,12 +68,19 @@ def fit(self):
config=OmegaConf.to_container(self.config, resolve=True),
)

self.global_steps = 0
self.global_steps = 0
self.gen_steps = 0

# load checkpoint before doing anything
self._load_checkpoint()

# Extract pass rate tracker from sampler if using curriculum learning
# The PassRateWeightedSampler owns the tracker internally but we need to manually update it during training
# Currently, we only support PassRateWeightedSampler for curriculum learning
self.pass_rate_tracker = None
self.data_sampler = self.train_dataloader.sampler # train_dataloader is created in `RayPPOTrainer._create_dataloader()` and always has a sampler
if isinstance(self.data_sampler, PassRateWeightedSampler):
self.pass_rate_tracker = self.data_sampler.pass_rate_tracker

# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
Expand Down Expand Up @@ -135,7 +142,6 @@ def fit(self):
non_tensor_batch_keys=["raw_prompt_ids"],
)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)

is_last_step = self.global_steps >= self.total_training_steps

with marked_timer("step", timing_raw):
Expand Down Expand Up @@ -189,7 +195,6 @@ def fit(self):
reward_extra_infos_dict = {}

new_batch.batch["token_level_scores"] = reward_tensor

if reward_extra_infos_dict:
new_batch.non_tensor_batch.update(
{k: np.array(v) for k, v in reward_extra_infos_dict.items()}
Expand All @@ -206,6 +211,47 @@ def fit(self):
else:
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]

# === Curriculum Learning: Update pass rate tracker for weighted resampling ===
# When using PassRateWeightedSampler, track per-sample success rates to enable dynamic curriculum learning.
# The sampler uses these pass rates to adjust sampling probabilities in the next epoch.

# Note: factor the pass-rate-tracker update out into a utility function later
# 1. if sampler is an instance of PassRateWeightedSampler, self.pass_rate_tracker is not None
# 2. `dataset_index` field is added to the RL dataset to identify samples
if "dataset_index" in new_batch.non_tensor_batch and self.pass_rate_tracker is not None:
dataset_indices = new_batch.non_tensor_batch["dataset_index"]
# Sum token-level rewards to get sequence-level reward
seq_rewards = new_batch.batch["token_level_rewards"].sum(dim=-1).cpu().numpy()
# Success is 1 if sequence reward > 0, else 0
successes = (seq_rewards > 0).astype(float)

# Deduplicate: batch was repeated n times (interleaved), so we need to aggregate
unique_indices, inverse_indices = np.unique(dataset_indices, return_inverse=True)

assert len(unique_indices) > 0, "No unique samples found in batch. Check data pipeline configuration."
# Aggregate successes: take mean across rollouts for each sample
aggregated_successes = np.zeros(len(unique_indices), dtype=float)
for i, _ in enumerate(unique_indices):
mask = inverse_indices == i # boolean array to indicate positions of unique index i
aggregated_successes[i] = np.mean(successes[mask]) # take average success across rollouts for sample i

pass_rates = self.pass_rate_tracker.get_pass_rates()

# Log curriculum metrics BEFORE updating tracker
# Track improvement of hardest samples (across all samples, not just attempted)
metrics['curriculum/hardest_10pct_pass_rate'] = float(np.percentile(pass_rates, 10))
metrics['curriculum/hardest_25pct_pass_rate'] = float(np.percentile(pass_rates, 25))
metrics['curriculum/hardest_50pct_pass_rate'] = float(np.percentile(pass_rates, 50))
metrics['curriculum/hardest_75pct_pass_rate'] = float(np.percentile(pass_rates, 75))

# Batch-level statistics
metrics['curriculum/min_batch_pass_rate'] = float(np.min(aggregated_successes))
metrics['curriculum/mean_batch_pass_rate'] = float(np.mean(aggregated_successes))
metrics['curriculum/effective_batch_size'] = np.sum(aggregated_successes > 0)/len(unique_indices)

# Update tracker with current batch results
self.pass_rate_tracker.update(sample_indices=unique_indices.astype(int), batch_pass_rate=aggregated_successes)

if not self.config.algorithm.filter_groups.enable:
batch = new_batch
else: # NOTE: When prompts after filtering is less than train batch size,
Expand Down Expand Up @@ -280,7 +326,6 @@ def fit(self):
# === Updating ===

batch.batch["response_mask"] = compute_response_mask(batch)

# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
Expand Down Expand Up @@ -342,6 +387,7 @@ def fit(self):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
print("in critic warmup loop")

# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
Expand Down Expand Up @@ -430,6 +476,31 @@ def _to_sequence(value):
num_total_prompts = 0
num_gen_batches = 0

# Add curriculum learning metrics to W&B
if isinstance(self.data_sampler, PassRateWeightedSampler):
# Add 3D plot data for weight and count distributions (percentile-based)
try:
import wandb
import pandas as pd

weight_3d_data = self.data_sampler.get_wandb_3d_plot_data(metric_type='weight')
count_3d_data = self.data_sampler.get_wandb_3d_plot_data(metric_type='count')

# Add step to each data point for 3D visualization
for point in weight_3d_data:
point['step'] = self.global_steps
for point in count_3d_data:
point['step'] = self.global_steps

metrics['curriculum/weight_distribution_3d'] = wandb.Table(
dataframe=pd.DataFrame(weight_3d_data)
) if weight_3d_data else None
metrics['curriculum/count_distribution_3d'] = wandb.Table(
dataframe=pd.DataFrame(count_3d_data)
) if count_3d_data else None
except ImportError:
pass # wandb or pandas not available

# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)

Expand Down
Loading