diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 2cd776e31..7a13c6b9a 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -148,7 +148,20 @@ def new(self, **kwargs): args.update(kwargs) return GenerationHyperparameters(**args) - +@dataclass +class PRMRewardHyperparameters: + reward_shaping_alpha: float = field( + default=0.02, + metadata={"help": "reward shaping alpha"}, + ) + use_clip: bool = field( + default=True, + metadata={"help": "Whether to use clip mechanism."}, + ) + use_delta: bool = field( + default=True, + metadata={"help": "Whether to use delta mechanism."}, + ) # Train Engine Configs @@ -577,7 +590,7 @@ class SGLangConfig: max_lora_rank: int | None = None lora_target_modules: List[str] | None = None lora_paths: List[str] | None = None - max_loaded_loras: int = 1 + # max_loaded_loras: int = 1 max_loras_per_batch: int = 1 lora_backend: str = "triton" # logging @@ -1118,6 +1131,19 @@ class GRPOConfig(BaseExperimentConfig): actor: PPOActorConfig = field(default_factory=PPOActorConfig) ref: PPOActorConfig = field(default_factory=PPOActorConfig) +@dataclass +class PRMConfig(BaseExperimentConfig): + async_training: bool = field(default=True) + prm_path: str = field(default="") + gconfig: GenerationHyperparameters = field( + default_factory=GenerationHyperparameters + ) + prmconfig: PRMRewardHyperparameters = field( + default_factory=PRMRewardHyperparameters + ) + rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig) + actor: PPOActorConfig = field(default_factory=PPOActorConfig) + ref: PPOActorConfig = field(default_factory=PPOActorConfig) @dataclass class PPOConfig(GRPOConfig): diff --git a/areal/api/reward_api.py b/areal/api/reward_api.py index 5409a4a50..6a239feae 100644 --- a/areal/api/reward_api.py +++ b/areal/api/reward_api.py @@ -103,7 +103,7 @@ def _recreate_executor(cls, executor_key, max_workers): return cls._executors[executor_key] return None - async def __call__(self, *args, **kwargs) -> float: + async def __call__(self, *args, **kwargs): last_exception = None for attempt in range(self.max_retries + 1): diff --git a/areal/engine/ppo/actor.py b/areal/engine/ppo/actor.py index 6f8e82066..2f7600543 100644 --- a/areal/engine/ppo/actor.py +++ b/areal/engine/ppo/actor.py @@ -1,5 +1,6 @@ import functools -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple +import warnings import torch @@ -287,6 +288,224 @@ def compute_advantages(self, *args, **kwargs) -> None: def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]: return self.actor.ppo_update(*args, **kwargs) +class FSDPPPOActorDense(FSDPPPOActor): + + def __init__(self, config: PPOActorConfig): + super().__init__(config) + self.actor = DensePPOActor(config, self) + +class DensePPOActor(PPOActor): + def __init__(self, config: PPOActorConfig, engine: TrainEngine): + super().__init__(config, engine) + def compute_advantages(self, data: Dict[str, Any]) -> None: + bs = data["input_ids"].shape[0] + max_seqlen = data["input_ids"].shape[1] + batch_indices = torch.arange( + bs, device=data["input_ids"].device, dtype=torch.long + ) + + # Reward Penalty on length + if self.config.overlong_reward_penalty: + + overlong_tokens = self.config.overlong_tokens + overlong_penalty_factor = self.config.overlong_penalty_factor + + data = reward_overlong_penalty( + data, + overlong_tokens=overlong_tokens, + overlong_penalty_factor=overlong_penalty_factor, + max_response_length=self.config.max_new_tokens, + ) + + # Reward Scaling + 
reward_score = data["rewards"] + reward_score = (reward_score + self.reward_bias) * self.reward_scaling + reward_score = torch.clip( + reward_score, max=self.reward_clip, min=-self.reward_clip + ) + if self.reward_norm: + reward_score = self.reward_norm(reward_score) + + loss_mask = data["loss_mask"].float() + loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1) + # Apply the mask to log probabilities. + if not self.config.use_decoupled_loss and self.config.recompute_logprob: + # Overwrite logprobs produced by the inference engine + old_logp = data["logprobs"] = data["prox_logp"] + else: + old_logp = torch.roll(data["logprobs"], shifts=-1, dims=-1) + if not self.config.use_decoupled_loss: + # prox logp not available, use inferenced logp + data["prox_logp"] = old_logp + ref_logp = data.get("ref_logp", torch.zeros_like(old_logp)) + ref_logp *= loss_mask + old_logp *= loss_mask + + # Compute KL-regularized rewards. + attn_mask = data["attention_mask"] + seqlens = attn_mask.sum(-1).long() + seq_no_eos_mask = seqlens == attn_mask.shape[1] + rewards = -self.kl_ctl * (old_logp - ref_logp) + kl_rewards = rewards.clone() + # KL rewards at the next token after eos is zero. + rewards[batch_indices, seqlens - 1] = 0 + indices = torch.clip(seqlens - 2, min=0) + # print(f"reward_score: {reward_score.shape}, {reward_score}") + # print(f"rewards before: {rewards.shape}, {rewards}") + if self.mask_no_eos_with_zero: + rewards[batch_indices, :] += torch.where( + seq_no_eos_mask, 0, reward_score + ) + else: + rewards[batch_indices, :] += reward_score + # print(f"rewards after: {rewards}") + # Compute GAE. + if "values" not in data: + values = torch.zeros_like(rewards) + else: + values = data["values"] + advantages_reversed = [ + torch.zeros(bs, dtype=torch.float32, device=values.device) + ] + lastgaelam = 0 + nextvalues = values[:, max_seqlen - 1] * seq_no_eos_mask + for t in reversed(range(max_seqlen - 1)): + delta = rewards[:, t] + self.discount * nextvalues - values[:, t] + newgaelam = delta + self.discount * self.gae_lambda * lastgaelam + + # Skip tokens that do not contribute to the loss + mask = loss_mask[:, t] + nextvalues = nextvalues * (1 - mask) + values[:, t] * mask + lastgaelam = lastgaelam * (1 - mask) + newgaelam * mask + advantages_reversed.append(lastgaelam) + + advantages = torch.stack(advantages_reversed[::-1], dim=1) + data["returns"] = advantages + values + + # Optionally perform advantage normalization. + if self.adv_norm is not None: + advantages = self.adv_norm(advantages, loss_mask) + + # Store data in the dict. 
+ data["advantages"] = advantages + data["kl_rewards"] = kl_rewards + data["tot_rewards"] = rewards + data["loss_mask"] = loss_mask + # because we have rolled old_logp by -1 + data["logprobs"] = old_logp + + def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]: + + if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0: + data, sampling_stat = dynamic_sampling_dense_reward(data, self.group_size) + + attn_mask = data["attention_mask"] + loss_mask = data["loss_mask"] + reward_score = data["rewards"] + seqlens = attn_mask.sum(-1) + + all_stats = [] + ########## Logging code starts ########## + result_denominators = { + "correct_n_seqs": (reward_score[:, -1] > 0).bool(), + "incorrect_n_seqs": (reward_score[:, -1] <= 0).bool(), + } + if self.config.log_agent_stats: + assert ( + "begin_of_trajectory" in data + ), "'begin_of_trajectory' is expected to log agent statistics" + assert ( + len(self.config.log_agent_stats_keys) > 0 + ), "`log_agent_stats_keys` should not be empty when log_agent_stats=True" + agent_denominator = (data["begin_of_trajectory"] > 0).bool() + result_denominators["agent"] = agent_denominator + global_denominators = dict( + n_seqs=torch.ones_like(reward_score[:, 0], dtype=torch.bool), + n_tokens=torch.ones_like(loss_mask, dtype=torch.bool), + n_valid_tokens=loss_mask.bool(), + **result_denominators, + ) + stats_tracker.denominator(**global_denominators) + stats_tracker.stat( + correct_seq_len=seqlens.float(), denominator="correct_n_seqs" + ) + stats_tracker.stat( + incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs" + ) + + stats = dict( + advantages=data["advantages"], + kl_rewards=data["kl_rewards"], + final_reward=data["tot_rewards"], + ) + stats_tracker.stat(**stats, denominator="n_valid_tokens") + + prompt_lens = [] + prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1) + seq_stats = dict( + no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(), + task_reward=reward_score[:, -2].float(), + prompt_len=prompt_lens.float(), + seq_len=seqlens.float(), + ) + stats_tracker.stat(**seq_stats, denominator="n_seqs") + scalars = dict( + mask_no_eos_with_zero=self.config.mask_no_eos_with_zero, + eps_clip=self.config.eps_clip, + ) + if self.config.c_clip is not None: + scalars["c_clip"] = self.config.c_clip + scalars["use_dual_clip"] = 1 + else: + scalars["use_dual_clip"] = 0 + if self.config.behav_imp_weight_cap is not None: + scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap + stats_tracker.scalar(**scalars) + + if self.config.log_agent_stats: + stats_tracker.stat( + **{k: data[k].float() for k in self.config.log_agent_stats_keys}, + denominator="agent", + ) + + global_stats = stats_tracker.export( + reduce_group=self.engine.data_parallel_group + ) + for k in global_denominators: + keys = list(global_stats.keys()) + for k2 in keys: + if k2.endswith(k): + global_stats.pop(k2) + ########## Logging code ends ########## + + for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]: + data.pop(key, None) + # NOTE: calling engine.train() is critical to enabling gradient checkpointing + self.engine.train() + mb_inputs = split_padded_tensor_dict_into_mb_list( + data, + mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches), + ) + for mb in mb_inputs.mbs: + train_stat = self.engine.train_batch( + mb, + loss_fn=functools.partial( + grpo_loss_fn, + temperature=self.temperature, + eps_clip=self.config.eps_clip, + eps_clip_higher=self.config.eps_clip_higher, + c_clip=self.config.c_clip, + 
behav_imp_weight_cap=self.config.behav_imp_weight_cap, + ), + loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), + ) + stats_tracker.scalar(**train_stat) + all_stats.append( + stats_tracker.export(reduce_group=self.engine.data_parallel_group) + ) + all_stats[0].update(global_stats) + return all_stats + def grpo_loss_fn( logits: torch.Tensor, @@ -364,3 +583,63 @@ def grpo_loss_fn( denominator="clipped_tokens", ) return loss + +def dynamic_sampling_dense_reward( + data: Dict[str, Any], group_size: int +) -> Tuple[Dict[str, Any], Dict[str, int]]: + """Filter samples by group when all rewards in a group are equal. + + Assumes samples of the same group are adjacent in the batch. + + Returns a new dict containing only kept samples (mask applied on batch dim + for all tensor values whose first dimension equals batch size), and a small + stats dict. + """ + rewards = data["rewards"] + if not torch.is_tensor(rewards): + raise TypeError("data['rewards'] must be a torch.Tensor") + batch_size = rewards.shape[0] + + if group_size <= 0: + warnings.warn("group_size <= 0; returning original data") + return data, dict(n_group_kept=0, n_group_filtered=0) + + if batch_size % group_size != 0: + warnings.warn( + "The group size is not divisible by the batch size. Return the original data" + ) + return data, dict( + n_group_kept=batch_size // max(group_size, 1), n_group_filtered=0 + ) + + # Calculate number of groups (must be divisible) + num_groups = batch_size // group_size + + # Reshape rewards to (num_groups, group_size) for group-wise operations + rewards_reshaped = rewards.view(num_groups, group_size * rewards.shape[1]) + + # Check if all elements in each group are equal to the first element + all_equal = (rewards_reshaped == rewards_reshaped[:, 0:1]).all(dim=1) + + # Create mask for groups to keep (where not all rewards are equal) + valid_groups = ~all_equal + + # Expand the group mask to individual samples + mask = valid_groups.repeat_interleave(group_size) + + # In case all group is filtered out, return the original data (although not gradient in this case) + if not mask.any(): + return data, dict(n_group_kept=0, n_group_filtered=num_groups) + + n_group_kept = int(valid_groups.sum().item()) + n_group_filtered = int(num_groups - n_group_kept) + + # Apply mask row-wise across tensors that share the same batch dimension + filtered: Dict[str, Any] = {} + for k, v in data.items(): + if torch.is_tensor(v) and v.shape[:1] == (batch_size,): + filtered[k] = v[mask] + else: + # keep untouched (e.g., scalars, metadata); caller should ensure consistency + filtered[k] = v + return filtered, dict(n_group_kept=n_group_kept, n_group_filtered=n_group_filtered) \ No newline at end of file diff --git a/areal/utils/launcher.py b/areal/utils/launcher.py index 73010e1c8..5c24f0584 100644 --- a/areal/utils/launcher.py +++ b/areal/utils/launcher.py @@ -11,7 +11,7 @@ logger = logging.getLogger("Launcher Utils") -LOCAL_CACHE_DIR = "/tmp/areal" +LOCAL_CACHE_DIR = "/data/yl/AReaL/tmp/areal" PYTORCH_KERNEL_CACHE_PATH = ( f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels/" ) diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index 321bc4991..e4800c028 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -20,11 +20,11 @@ class StatsLogger: def __init__(self, config: BaseExperimentConfig, ft_spec: FinetuneSpec): - if not isinstance(config, StatsLoggerConfig): - raise ValueError( - "Passing config.stats_logger as the config is deprecated. 
" - "Please pass the full config instead." - ) + # if not isinstance(config, StatsLoggerConfig): + # raise ValueError( + # "Passing config.stats_logger as the config is deprecated. " + # "Please pass the full config instead." + # ) self.exp_config = config self.config = config.stats_logger self.ft_spec = ft_spec diff --git a/areal/workflow/rlvr_prm.py b/areal/workflow/rlvr_prm.py new file mode 100644 index 000000000..1e504cc3e --- /dev/null +++ b/areal/workflow/rlvr_prm.py @@ -0,0 +1,237 @@ +import asyncio +from dataclasses import asdict +import os +import re +import uuid + +import aiofiles +import aiofiles.os +import colorama +import torch +from transformers import PreTrainedTokenizerFast, PreTrainedModel + +from areal.api.cli_args import GenerationHyperparameters, PRMRewardHyperparameters +from areal.api.engine_api import InferenceEngine +from areal.api.io_struct import ModelRequest +from areal.api.reward_api import AsyncRewardWrapper +from areal.api.workflow_api import RolloutWorkflow +from areal.utils import logging, stats_tracker +from areal.utils.data import concat_padded_tensors + +logger = logging.getLogger("RLVR workflow") + + +class PRMRLVRWorkflow(RolloutWorkflow): + def __init__( + self, + reward_fn, + reward_fn_prm, + gconfig: GenerationHyperparameters, + prmconfig: PRMRewardHyperparameters, + tokenizer: PreTrainedTokenizerFast, + # prm_model: PreTrainedModel, + # prm_tokenizer: PreTrainedTokenizerFast, + enable_thinking: bool, + rollout_stat_scope: bool = "rollout", + dump_dir: str | None = None, + ): + self.reward_fn = reward_fn + self.gconfig = gconfig + self.prmconfig = prmconfig + self.tokenizer = tokenizer + # self.prm_model = prm_model + # self.prm_tokenizer = prm_tokenizer + self.enable_thinking = enable_thinking + self.dump_dir = dump_dir + self.rollout_stat_scope = rollout_stat_scope + self.async_reward_fn = AsyncRewardWrapper(reward_fn) + self.async_reward_fn_prm = AsyncRewardWrapper(reward_fn_prm, timeout_seconds=100) + if self.dump_dir is not None and not os.path.exists(self.dump_dir): + os.makedirs(self.dump_dir, exist_ok=True) + + async def arun_episode(self, engine: InferenceEngine, data): + input_ids = self.tokenizer.apply_chat_template( + data["messages"], + tokenize=True, + add_generation_prompt=True, + enable_thinking=self.enable_thinking, + ) + + n_samples = self.gconfig.n_samples + req = ModelRequest( + rid=uuid.uuid4().hex, + input_ids=input_ids, + gconfig=self.gconfig.new(n_samples=1), + tokenizer=self.tokenizer, + ) + resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)]) + + version = engine.get_version() + prompt_strs = [] + completions_strs = [] + rewards = [] + result_rewards = [] + prm_rewards = [] + reward_masks = [] + seqlens = [] + + results = [] + for resp in resps: + seq = resp.input_tokens + resp.output_tokens + logprobs = [0.0] * resp.input_len + resp.output_logprobs + loss_mask = [0] * resp.input_len + [1] * resp.output_len + versions = [-1] * resp.input_len + resp.output_versions + + prompt_str = self.tokenizer.decode(input_ids) + completions_str = self.tokenizer.decode(resp.output_tokens) + prompt_strs.append(prompt_str) + completions_strs.append(completions_str) + seqlens.append(len(seq)) + result_reward = await self.async_reward_fn( + prompt_str, + completions_str, + resp.input_tokens, + resp.output_tokens, + **data, + ) + + # separate steps + full_str = self.tokenizer.decode(resp.output_tokens, clean_up_tokenization_spaces=False) + raw_lines = full_str.split("\n") + lines = [line for line in raw_lines if 
line.strip() != ""] + ends = [] + pos = 0 + line_i = 0 + for raw_line in raw_lines: + if raw_line.strip() == "": + pos += len(raw_line) + 1 + continue + pos += len(raw_line) + ends.append(pos) + pos += 1 + line_i += 1 + last_indices = [None] * len(lines) + cur_len = 0 + seg_i = 0 + for idx, tok in enumerate(resp.output_tokens): + piece = self.tokenizer.decode([tok], clean_up_tokenization_spaces=False) + cur_len += len(piece) + while seg_i < len(ends) and cur_len >= ends[seg_i]: + last_indices[seg_i] = idx + seg_i += 1 + if seg_i >= len(ends): + break + if last_indices and last_indices[-1] != len(resp.output_tokens) - 2: + last_indices[-1] = len(resp.output_tokens) - 2 + + steps_str = "".join([line_text for line_text in lines]) + cr_pos = [resp.input_len+last_indice for last_indice in last_indices] + + prm_reward = await self.async_reward_fn_prm( + prompt_str, + steps_str, + resp.input_tokens, + resp.output_tokens, + # self.prm_model, + # self.prm_tokenizer, + **data, + ) + if not isinstance(prm_reward, list): + prm_reward = [prm_reward] * len(cr_pos) + if len(prm_reward) != len(cr_pos): + # print(f"Mismatch: prm_reward={len(prm_reward)}, cr_pos={len(cr_pos)}") + continue + # Log reward. + stats_tracker.get(self.rollout_stat_scope).scalar(reward=result_reward) + + rewards.append(prm_reward) + prm_rewards.append(prm_reward) + result_rewards.append(result_reward) + + # step reward + dense_reward = torch.zeros(len(seq), dtype=torch.float) + # print(f"cr_pos: {cr_pos}, len(seq): {len(seq)}") + assert len(prm_reward) == len(cr_pos), f"Mismatch: prm_reward={len(prm_reward)}, cr_pos={len(cr_pos)}, len(seq)={len(seq)}, {prm_reward}, \n{steps_str}" + dense_reward[cr_pos] = torch.tensor(prm_reward, dtype=torch.float) + reward_mask = torch.zeros(len(seq), dtype=torch.bool) + reward_mask[cr_pos] = True + reward_masks.append(reward_mask) + + res = dict( + # unsqueeze to add an additional batch dimension + input_ids=torch.tensor(seq).unsqueeze(0), + loss_mask=torch.tensor(loss_mask).unsqueeze(0), + logprobs=torch.tensor(logprobs).unsqueeze(0), + versions=torch.tensor(versions).unsqueeze(0), + attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0), + # reward + rewards=dense_reward.unsqueeze(0), + ) + results.append(res) + # print(f"original rewards: {results[0]["rewards"]}") + # print(f"avg_prm_reward: {sum(prm_rewards[0]) / len(prm_rewards[0])}") + # print(f"prm_reward: {prm_rewards[0]}, result_reward: {result_rewards[0]}") + + # clip mechanism + if self.prmconfig.use_clip: + for res, reward_mask, prm_reward in zip(results, reward_masks, prm_rewards): + dense_reward = res["rewards"] + if isinstance(prm_reward, list): + avg_prm_reward = sum(prm_reward) / len(prm_reward) + else: + avg_prm_reward = prm_reward + gt_mean = (dense_reward > avg_prm_reward) & reward_mask + ls_mean = (dense_reward <= avg_prm_reward) & reward_mask + res["rewards"][gt_mean] = 0 + res["rewards"][ls_mean] -= avg_prm_reward + # print(f"rewards after clip: {results[0]["rewards"]}") + + # delta mechanism + if self.prmconfig.use_delta: + for res, reward_mask in zip(results, reward_masks): + rewards_1d = res["rewards"].squeeze(0) + valid = rewards_1d[reward_mask] + new_v = valid.clone() + K = new_v.numel() + new_v[-1] = 0 + if K > 1: + new_v[:-2] = valid[:-2] - valid[1:-1] + out = rewards_1d.clone() + out[reward_mask] = new_v + res["rewards"] = out.unsqueeze(0) + # print(f"rewards after delta: {results[0]["rewards"]}") + + # success reward + for res, result_reward in zip(results, result_rewards): + seq_len = 
res["rewards"].shape[1] + res["rewards"][:, seq_len-2] = result_reward + # print(f"rewards add success reward: {results[0]["rewards"]}") + + if self.dump_dir is not None: + dump_path = os.path.join(self.dump_dir, str(version)) + await aiofiles.os.makedirs(dump_path, exist_ok=True) + # Get the unique identifier for this prompt + qid = None + for key in ["query_id", "id", "qid"]: + qid = data.get(key, None) + if qid is not None: + break + qid = qid or uuid.uuid4().hex + + # Dump rollout to file + file_path = os.path.join(dump_path, f"{qid}.txt") + async with aiofiles.open(file_path, "a") as f: + n_samples = self.gconfig.n_samples + for i, (p, c, r, sl) in enumerate( + zip(prompt_strs, completions_strs, rewards, seqlens) + ): + info = "\n".join( + [ + f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.", + f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}", + f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}", + ] + ) + await f.write(info + "\n") + + return concat_padded_tensors(results) diff --git a/examples/experimental/dapo/greso_dapo.py b/examples/experimental/dapo/greso_dapo.py new file mode 100644 index 000000000..e7a61ce2d --- /dev/null +++ b/examples/experimental/dapo/greso_dapo.py @@ -0,0 +1,348 @@ +import os +import sys +from copy import deepcopy + +import torch.distributed as dist +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import GRPOConfig, load_expr_config +from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.platforms import current_platform +from areal.utils import seeding, stats_tracker +from areal.utils.data import ( + broadcast_tensor_container, + cycle_dataloader, + tensor_container_to, +) +from areal.utils.device import log_gpu_stats +from areal.utils.evaluator import Evaluator +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.recover import RecoverHandler +from areal.utils.saver import Saver +from areal.utils.stats_logger import StatsLogger +from areal.workflow.rlvr import RLVRWorkflow + +from typing import TYPE_CHECKING, Optional +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node +if TYPE_CHECKING: + from datasets import Dataset + from transformers.processing_utils import ProcessorMixin + from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): + from areal.reward.math_parser import process_results + + return int(process_results(completions, answer)[0]) + +def load_greso_dataset( + path: str, + rank: int, + world_size: int, + type: str = "sft", + split: Optional[str] = None, + max_length: Optional[int] = None, + tokenizer: Optional["PreTrainedTokenizerFast"] = None, + processor: Optional["ProcessorMixin"] = None, + **kwargs, +) -> "Dataset": + dataset = load_dataset("parquet", data_dir=path, split=split) + + def process(sample): + return {"messages": sample["messages"], "answer": sample["answer"]} + + dataset = dataset.map(process) + + # Filter out sequences longer than max_length if tokenizer and max_length are provided + if max_length is not None: + + def filter_length(sample): + # Tokenize the user content to check length + content = 
sample["messages"][0]["content"] + tokens = tokenizer.encode(content) + return len(tokens) <= max_length + + dataset = dataset.filter(filter_length) + + dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + return dataset + + +def main(args): + config, _ = load_expr_config(args, GRPOConfig) + config: GRPOConfig + + rank = int(os.getenv("RANK")) + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + seeding.set_random_seed(config.seed, key=f"trainer{rank}") + allocation_mode = AllocationMode.from_str(config.allocation_mode) + parallel_strategy = allocation_mode.train + assert parallel_strategy is not None + + # Initialize train engine + actor = FSDPPPOActor(config=config.actor) + actor.create_process_group(parallel_strategy=parallel_strategy) + + train_dataset = load_greso_dataset( + path=config.train_dataset.path, + rank=actor.data_parallel_rank, + world_size=actor.data_parallel_world_size, + split="train", + max_length=config.train_dataset.max_length, + type=config.train_dataset.type, + tokenizer=tokenizer, + ) + valid_dataset = load_greso_dataset( + path=config.valid_dataset.path, + rank=actor.data_parallel_rank, + world_size=actor.data_parallel_world_size, + split="test", + max_length=config.valid_dataset.max_length, + type=config.valid_dataset.type, + tokenizer=tokenizer, + ) + + # Create dataset and dataloaders + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size // actor.data_parallel_world_size, + shuffle=config.train_dataset.shuffle, + num_workers=config.train_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.train_dataset.drop_last, + ) + valid_dataloader = StatefulDataLoader( + valid_dataset, + batch_size=config.valid_dataset.batch_size // actor.data_parallel_world_size, + shuffle=config.valid_dataset.shuffle, + num_workers=config.valid_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.valid_dataset.drop_last, + ) + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + + # Initialize inference engine + rollout = RemoteSGLangEngine(config.rollout) + rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size) + eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout)) + # NOTE: eval does not have any offpolicyness control + eval_rollout.config.max_head_offpolicyness = int(1e12) + eval_rollout.initialize() + + actor.initialize(None, ft_spec) + ref = None + if config.actor.kl_ctl > 0 and config.ref is not None: + ref = FSDPPPOActor(config=config.ref) + ref.create_process_group(parallel_strategy=parallel_strategy) + ref.initialize(None, ft_spec) + + # NOTE: Weight update meta only requires address and free port of rank 0, + # but `WeightUpdateMeta.from_fsdp_xccl` has to be executed on all ranks + # due to `engine.get_param_specs()`. + # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0. 
+ weight_update_meta = [ + WeightUpdateMeta.from_fsdp_xccl( + AllocationMode.from_str(config.allocation_mode), actor + ) + ] + dist.broadcast_object_list(weight_update_meta, src=0) + weight_update_meta = weight_update_meta[0] + + # Create rollout workflow + if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: + config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) + if tokenizer.eos_token_id not in config.gconfig.stop_token_ids: + config.gconfig.stop_token_ids.append(tokenizer.eos_token_id) + workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig, + tokenizer=tokenizer, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated" + ), + ) + eval_workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig.new(temperature=0.6), + tokenizer=tokenizer, + enable_thinking=False, + rollout_stat_scope="eval-rollout", + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated-eval" + ), + ) + + # Run training. + saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + recover_info = recover_handler.load( + actor, + saver, + evaluator, + stats_logger, + train_dataloader, + inference_engine=rollout, + weight_update_meta=weight_update_meta, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + steps_per_epoch = len(train_dataloader) + max_steps = total_epochs * steps_per_epoch + + data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): + epoch = global_step // steps_per_epoch + step = global_step % steps_per_epoch + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=steps_per_epoch, + ) + + with stats_tracker.record_timing("rollout"): + batch = None + if actor.is_data_parallel_head(): + if config.async_training: + batch = rollout.prepare_batch( + train_dataloader, + workflow=workflow, + should_accept=lambda sample: True, + ) + else: + batch = rollout.rollout_batch( + next(data_generator), + workflow=workflow, + should_accept=lambda sample: True, + ) + batch = tensor_container_to(batch, actor.device) + batch = broadcast_tensor_container( + batch, + src_rank=actor.current_data_parallel_head(), + group=actor.context_and_model_parallel_group, + ) + # Create barrier to synchronize all rollout processes. 
+ dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + if config.actor.recompute_logprob or config.actor.use_decoupled_loss: + with stats_tracker.record_timing("recompute_logp"): + logp = actor.compute_logp(batch) + batch["prox_logp"] = logp + log_gpu_stats("recompute logp") + + if ref is not None: + with stats_tracker.record_timing("ref_logp"): + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") + + with stats_tracker.record_timing("compute_advantage"): + actor.compute_advantages(batch) + log_gpu_stats("compute advantages") + + with ( + stats_tracker.record_timing("train_step"), + stats_tracker.scope("grpo_actor"), + ): + stats = actor.ppo_update(batch) + actor.step_lr_scheduler() + log_gpu_stats("ppo update") + + # pause inference for updating weights, save, and evaluation + rollout.pause() + + with stats_tracker.record_timing("update_weights"): + if dist.get_rank() == 0: + future = rollout.update_weights(weight_update_meta) + actor.upload_weights(weight_update_meta) + if dist.get_rank() == 0: + future.result() + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) + eval_rollout.set_version(global_step + 1) + + with stats_tracker.record_timing("save"): + saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + actor, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + with stats_tracker.record_timing("eval"): + + def evaluate_fn(): + if actor.is_data_parallel_head(): + # Stats are logged in workflow + # and will be exported later + cnt = 0 + for data in valid_dataloader: + for item in data: + eval_rollout.submit(item, eval_workflow) + cnt += 1 + eval_rollout.wait(cnt, timeout=None) + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + evaluator.evaluate( + evaluate_fn, + epoch, + step, + global_step, + ) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + # Upload statistics to the logger (e.g., wandb) + stats[0].update( + stats_tracker.export_all(reduce_group=actor.data_parallel_group) + ) + stats_logger.commit(epoch, step, global_step, stats) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + # Resume rollout + rollout.resume() + + stats_logger.close() + eval_rollout.destroy() + rollout.destroy() + if ref is not None: + ref.destroy() + actor.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) \ No newline at end of file diff --git a/examples/prm/greso_dapo_prm.py b/examples/prm/greso_dapo_prm.py new file mode 100644 index 000000000..715997728 --- /dev/null +++ b/examples/prm/greso_dapo_prm.py @@ -0,0 +1,364 @@ +import os +import sys +from copy import deepcopy + +import torch.distributed as dist +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import GRPOConfig, load_expr_config, PRMConfig +from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor, FSDPPPOActorDense +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.platforms import current_platform +from areal.utils import 
seeding, stats_tracker +from areal.utils.data import ( + broadcast_tensor_container, + cycle_dataloader, + tensor_container_to, +) +from areal.utils.device import log_gpu_stats +from areal.utils.evaluator import Evaluator +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.recover import RecoverHandler +from areal.utils.saver import Saver +from areal.utils.stats_logger import StatsLogger +from areal.workflow.rlvr import RLVRWorkflow +from areal.workflow.rlvr_prm import PRMRLVRWorkflow + +import requests +import aiohttp + +from typing import TYPE_CHECKING, Optional +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node +if TYPE_CHECKING: + from datasets import Dataset + from transformers.processing_utils import ProcessorMixin + from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): + from areal.reward.math_parser import process_results + + return int(process_results(completions, answer)[0]) + +def gsm8k_reward_fn_prm(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): + conversation_str = f"{prompt}"[:-len("<|endoftext|>")] + f"<|im_start|>assistant\n{completions}"[:-len("<|im_end|>")] + "<|im_end|><|endoftext|>" + # print(f"conversation str: {conversation_str}") + resp = requests.post("http://localhost:8001/score", json={"text": conversation_str}) + # print(f"prm_reward: {resp.json()["reward"]}") + prm_reward = resp.json()["reward"] + return prm_reward + +def load_greso_dataset( + path: str, + rank: int, + world_size: int, + type: str = "sft", + split: Optional[str] = None, + max_length: Optional[int] = None, + tokenizer: Optional["PreTrainedTokenizerFast"] = None, + processor: Optional["ProcessorMixin"] = None, + **kwargs, +) -> "Dataset": + dataset = load_dataset("parquet", data_dir=path, split=split) + + def process(sample): + return {"messages": sample["messages"], "answer": sample["answer"]} + + dataset = dataset.map(process) + + # Filter out sequences longer than max_length if tokenizer and max_length are provided + if max_length is not None: + + def filter_length(sample): + # Tokenize the user content to check length + content = sample["messages"][0]["content"] + tokens = tokenizer.encode(content) + return len(tokens) <= max_length + + dataset = dataset.filter(filter_length) + + dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + return dataset + + +def main(args): + config, _ = load_expr_config(args, PRMConfig) + config: PRMConfig + + rank = int(os.getenv("RANK")) + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + seeding.set_random_seed(config.seed, key=f"trainer{rank}") + allocation_mode = AllocationMode.from_str(config.allocation_mode) + parallel_strategy = allocation_mode.train + assert parallel_strategy is not None + + # Initialize train engine + actor = FSDPPPOActorDense(config=config.actor) + actor.create_process_group(parallel_strategy=parallel_strategy) + + train_dataset = load_greso_dataset( + path=config.train_dataset.path, + rank=actor.data_parallel_rank, + world_size=actor.data_parallel_world_size, + split="train", + max_length=config.train_dataset.max_length, + type=config.train_dataset.type, + tokenizer=tokenizer, + ) + valid_dataset = load_greso_dataset( + path=config.valid_dataset.path, + rank=actor.data_parallel_rank, + world_size=actor.data_parallel_world_size, + split="test", + max_length=config.valid_dataset.max_length, + type=config.valid_dataset.type, 
+ tokenizer=tokenizer, + ) + + # Create dataset and dataloaders + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size // actor.data_parallel_world_size, + shuffle=config.train_dataset.shuffle, + num_workers=config.train_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.train_dataset.drop_last, + ) + valid_dataloader = StatefulDataLoader( + valid_dataset, + batch_size=config.valid_dataset.batch_size // actor.data_parallel_world_size, + shuffle=config.valid_dataset.shuffle, + num_workers=config.valid_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.valid_dataset.drop_last, + ) + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + + # Initialize inference engine + rollout = RemoteSGLangEngine(config.rollout) + rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size) + eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout)) + # NOTE: eval does not have any offpolicyness control + eval_rollout.config.max_head_offpolicyness = int(1e12) + eval_rollout.initialize() + + actor.initialize(None, ft_spec) + ref = None + if config.actor.kl_ctl > 0 and config.ref is not None: + ref = FSDPPPOActor(config=config.ref) + ref.create_process_group(parallel_strategy=parallel_strategy) + ref.initialize(None, ft_spec) + + # NOTE: Weight update meta only requires address and free port of rank 0, + # but `WeightUpdateMeta.from_fsdp_xccl` has to be executed on all ranks + # due to `engine.get_param_specs()`. + # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0. + weight_update_meta = [ + WeightUpdateMeta.from_fsdp_xccl( + AllocationMode.from_str(config.allocation_mode), actor + ) + ] + dist.broadcast_object_list(weight_update_meta, src=0) + weight_update_meta = weight_update_meta[0] + + # Create rollout workflow + if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: + config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) + if tokenizer.eos_token_id not in config.gconfig.stop_token_ids: + config.gconfig.stop_token_ids.append(tokenizer.eos_token_id) + workflow = PRMRLVRWorkflow( + reward_fn=gsm8k_reward_fn, + reward_fn_prm=gsm8k_reward_fn_prm, + gconfig=config.gconfig, + prmconfig=config.prmconfig, + tokenizer=tokenizer, + # prm_model=prm_model, + # prm_tokenizer=prm_tokenizer, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated" + ), + ) + eval_workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig, + tokenizer=tokenizer, + enable_thinking=False, + rollout_stat_scope="eval-rollout", + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated-eval" + ), + ) + + # Run training. 
+ saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + recover_info = recover_handler.load( + actor, + saver, + evaluator, + stats_logger, + train_dataloader, + inference_engine=rollout, + weight_update_meta=weight_update_meta, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + steps_per_epoch = len(train_dataloader) + max_steps = total_epochs * steps_per_epoch + + data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): + epoch = global_step // steps_per_epoch + step = global_step % steps_per_epoch + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=steps_per_epoch, + ) + + with stats_tracker.record_timing("rollout"): + batch = None + if actor.is_data_parallel_head(): + if config.async_training: + batch = rollout.prepare_batch( + train_dataloader, + workflow=workflow, + should_accept=lambda sample: True, + ) + else: + batch = rollout.rollout_batch( + next(data_generator), + workflow=workflow, + should_accept=lambda sample: True, + ) + batch = tensor_container_to(batch, actor.device) + batch = broadcast_tensor_container( + batch, + src_rank=actor.current_data_parallel_head(), + group=actor.context_and_model_parallel_group, + ) + # Create barrier to synchronize all rollout processes. + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + if config.actor.recompute_logprob or config.actor.use_decoupled_loss: + with stats_tracker.record_timing("recompute_logp"): + logp = actor.compute_logp(batch) + batch["prox_logp"] = logp + log_gpu_stats("recompute logp") + + if ref is not None: + with stats_tracker.record_timing("ref_logp"): + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") + + with stats_tracker.record_timing("compute_advantage"): + actor.compute_advantages(batch) + log_gpu_stats("compute advantages") + + with ( + stats_tracker.record_timing("train_step"), + stats_tracker.scope("grpo_actor"), + ): + stats = actor.ppo_update(batch) + actor.step_lr_scheduler() + log_gpu_stats("ppo update") + + # pause inference for updating weights, save, and evaluation + rollout.pause() + + with stats_tracker.record_timing("update_weights"): + if dist.get_rank() == 0: + future = rollout.update_weights(weight_update_meta) + actor.upload_weights(weight_update_meta) + if dist.get_rank() == 0: + future.result() + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) + eval_rollout.set_version(global_step + 1) + + with stats_tracker.record_timing("save"): + saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + actor, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + with stats_tracker.record_timing("eval"): + + def evaluate_fn(): + if actor.is_data_parallel_head(): + # Stats are logged in workflow + # and will be exported later + cnt = 0 + for data in valid_dataloader: + for item in data: + eval_rollout.submit(item, eval_workflow) + cnt += 1 + eval_rollout.wait(cnt, timeout=None) + 
dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + evaluator.evaluate( + evaluate_fn, + epoch, + step, + global_step, + ) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + # Upload statistics to the logger (e.g., wandb) + stats[0].update( + stats_tracker.export_all(reduce_group=actor.data_parallel_group) + ) + stats_logger.commit(epoch, step, global_step, stats) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + # Resume rollout + rollout.resume() + + stats_logger.close() + eval_rollout.destroy() + rollout.destroy() + if ref is not None: + ref.destroy() + actor.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) \ No newline at end of file diff --git a/examples/prm/gsm8k_dapo_prm.py b/examples/prm/gsm8k_dapo_prm.py new file mode 100644 index 000000000..acef0d038 --- /dev/null +++ b/examples/prm/gsm8k_dapo_prm.py @@ -0,0 +1,345 @@ +import os +import sys +from copy import deepcopy + +import torch.distributed as dist +from torchdata.stateful_dataloader import StatefulDataLoader + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import GRPOConfig, load_expr_config, PRMConfig +from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor +from areal.engine.ppo.prm import FSDPPPOPrm +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.platforms import current_platform +from areal.utils import seeding, stats_tracker +from areal.utils.data import ( + broadcast_tensor_container, + cycle_dataloader, + tensor_container_to, +) +from areal.utils.device import log_gpu_stats +from areal.utils.evaluator import Evaluator +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.recover import RecoverHandler +from areal.utils.saver import Saver +from areal.utils.stats_logger import StatsLogger +from areal.workflow.rlvr import RLVRWorkflow +from areal.workflow.rlvr_prm import PRMRLVRWorkflow + +# from transformers import AutoModel, AutoTokenizer +# import torch +# import torch.nn.functional as F +import requests + +def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): + from areal.reward.math_parser import process_results + + return int(process_results(completions, answer)[0]) + +# def gsm8k_reward_fn_prm(prompt, completions, prompt_ids, completion_ids, answer, prm_model, prm_tokenizer, **kwargs): +def gsm8k_reward_fn_prm(prompt, completions, prompt_ids, completion_ids, answer, **kwargs): + conversation_str = f"{prompt}<|im_start|>assistant\n{completions}<|im_end|><|endoftext|>" + print(f"conversation str: {conversation_str}") + # prm_input_ids = prm_tokenizer.encode( + # conversation_str, + # return_tensors="pt", + # ).to(prm_model.device) + # prm_outputs = prm_model(input_ids=prm_input_ids) + # step_sep_id = prm_tokenizer.encode("")[0] + # token_masks = (prm_input_ids == step_sep_id) + # probabilities = F.softmax(prm_outputs[0], dim=-1)* token_masks.unsqueeze(-1) + # sample = probabilities[0] + # prm_reward = sample[sample != 0].view(-1, 2)[:, 1][0].item() + resp = requests.post("http://localhost:8001/score", json={"text": conversation_str}) + prm_reward = resp.json()["reward"] + print(f"prm_reward: {prm_reward}") + return prm_reward + +def main(args): + config, _ = load_expr_config(args, PRMConfig) + config: PRMConfig + + rank = int(os.getenv("RANK")) + tokenizer = 
load_hf_tokenizer(config.tokenizer_path) + + # prm + # prm_tokenizer = AutoTokenizer.from_pretrained(config.prm_path, local_files_only=True, trust_remote_code=True) + # prm_model = AutoModel.from_pretrained( + # config.prm_path, + # torch_dtype=torch.bfloat16, + # local_files_only=True, + # trust_remote_code=True, + # ).eval() + + seeding.set_random_seed(config.seed, key=f"trainer{rank}") + allocation_mode = AllocationMode.from_str(config.allocation_mode) + parallel_strategy = allocation_mode.train + + # Initialize train engine + actor = FSDPPPOActor(config=config.actor) + actor.create_process_group(parallel_strategy=parallel_strategy) + + train_dataset = get_custom_dataset( + path=config.train_dataset.path, + rank=actor.data_parallel_rank, + world_size=actor.data_parallel_world_size, + split="train", + max_length=config.train_dataset.max_length, + type=config.train_dataset.type, + tokenizer=tokenizer, + ) + valid_dataset = get_custom_dataset( + path=config.valid_dataset.path, + rank=actor.data_parallel_rank, + world_size=actor.data_parallel_world_size, + split="test", + max_length=config.valid_dataset.max_length, + type=config.valid_dataset.type, + tokenizer=tokenizer, + ) + + # Create dataset and dataloaders + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=config.train_dataset.batch_size // actor.data_parallel_world_size, + shuffle=config.train_dataset.shuffle, + num_workers=config.train_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.train_dataset.drop_last, + ) + valid_dataloader = StatefulDataLoader( + valid_dataset, + batch_size=config.valid_dataset.batch_size // actor.data_parallel_world_size, + shuffle=config.valid_dataset.shuffle, + num_workers=config.valid_dataset.num_workers, + collate_fn=lambda x: x, + drop_last=config.valid_dataset.drop_last, + ) + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + + # Initialize inference engine + rollout = RemoteSGLangEngine(config.rollout) + rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size) + eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout)) + # NOTE: eval does not have any offpolicyness control + eval_rollout.config.max_head_offpolicyness = int(1e12) + eval_rollout.initialize() + + actor.initialize(None, ft_spec) + ref = None + if config.actor.kl_ctl > 0 and config.ref is not None: + ref = FSDPPPOActor(config=config.ref) + ref.create_process_group(parallel_strategy=parallel_strategy) + ref.initialize(None, ft_spec) + + # NOTE: Weight update meta only requires address and free port of rank 0, + # but `WeightUpdateMeta.from_fsdp_nccl` has to be executed on all ranks + # due to `engine.get_param_specs()`. + # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0. 
+ weight_update_meta = [ + WeightUpdateMeta.from_fsdp_nccl( + AllocationMode.from_str(config.allocation_mode), actor + ) + ] + dist.broadcast_object_list(weight_update_meta, src=0) + weight_update_meta = weight_update_meta[0] + + # Create rollout workflow + if tokenizer.pad_token_id not in config.gconfig.stop_token_ids: + config.gconfig.stop_token_ids.append(tokenizer.pad_token_id) + if tokenizer.eos_token_id not in config.gconfig.stop_token_ids: + config.gconfig.stop_token_ids.append(tokenizer.eos_token_id) + workflow = PRMRLVRWorkflow( + reward_fn=gsm8k_reward_fn, + reward_fn_prm=gsm8k_reward_fn_prm, + gconfig=config.gconfig, + prmconfig=config.prmconfig, + tokenizer=tokenizer, + # prm_model=prm_model, + # prm_tokenizer=prm_tokenizer, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated" + ), + ) + eval_workflow = RLVRWorkflow( + reward_fn=gsm8k_reward_fn, + gconfig=config.gconfig.new(temperature=0.6), + tokenizer=tokenizer, + enable_thinking=False, + rollout_stat_scope="eval-rollout", + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated-eval" + ), + ) + + # Run training. + saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config.stats_logger, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + recover_info = recover_handler.load( + actor, + saver, + evaluator, + stats_logger, + train_dataloader, + inference_engine=rollout, + weight_update_meta=weight_update_meta, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + steps_per_epoch = len(train_dataloader) + max_steps = total_epochs * steps_per_epoch + + data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): + epoch = global_step // steps_per_epoch + step = global_step % steps_per_epoch + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=steps_per_epoch, + ) + + with stats_tracker.record_timing("rollout"): + batch = None + if actor.is_data_parallel_head(): + if config.async_training: + batch = rollout.prepare_batch( + train_dataloader, + workflow=workflow, + should_accept=lambda sample: True, + ) + else: + batch = rollout.rollout_batch( + next(data_generator), + workflow=workflow, + should_accept=lambda sample: True, + ) + batch = tensor_container_to(batch, actor.device) + batch = broadcast_tensor_container( + batch, + src_rank=actor.current_data_parallel_head(), + group=actor.context_and_model_parallel_group, + ) + # Create barrier to synchronize all rollout processes. 
+ dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + if config.actor.recompute_logprob or config.actor.use_decoupled_loss: + with stats_tracker.record_timing("recompute_logp"): + logp = actor.compute_logp(batch) + batch["prox_logp"] = logp + log_gpu_stats("recompute logp") + + if ref is not None: + with stats_tracker.record_timing("ref_logp"): + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") + + with stats_tracker.record_timing("compute_advantage"): + actor.compute_advantages(batch) + log_gpu_stats("compute advantages") + + with ( + stats_tracker.record_timing("train_step"), + stats_tracker.scope("grpo_actor"), + ): + stats = actor.ppo_update(batch) + actor.step_lr_scheduler() + log_gpu_stats("ppo update") + + # pause inference for updating weights, save, and evaluation + rollout.pause() + + with stats_tracker.record_timing("update_weights"): + if dist.get_rank() == 0: + future = rollout.update_weights(weight_update_meta) + actor.upload_weights(weight_update_meta) + if dist.get_rank() == 0: + future.result() + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) + eval_rollout.set_version(global_step + 1) + + with stats_tracker.record_timing("save"): + saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + actor, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + with stats_tracker.record_timing("eval"): + + def evaluate_fn(): + if actor.is_data_parallel_head(): + # Stats are logged in workflow + # and will be exported later + cnt = 0 + for data in valid_dataloader: + for item in data: + eval_rollout.submit(item, eval_workflow) + cnt += 1 + eval_rollout.wait(cnt, timeout=None) + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + evaluator.evaluate( + evaluate_fn, + epoch, + step, + global_step, + ) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + # Upload statistics to the logger (e.g., wandb) + stats[0].update( + stats_tracker.export_all(reduce_group=actor.data_parallel_group) + ) + stats_logger.commit(epoch, step, global_step, stats) + + dist.barrier(device_ids=[actor.device.index]) + current_platform.synchronize() + + # Resume rollout + rollout.resume() + + stats_logger.close() + eval_rollout.destroy() + rollout.destroy() + if ref is not None: + ref.destroy() + actor.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/prm/prm_service.py b/examples/prm/prm_service.py new file mode 100644 index 000000000..76e81b6f1 --- /dev/null +++ b/examples/prm/prm_service.py @@ -0,0 +1,54 @@ +import torch +from transformers import AutoModel, AutoTokenizer +from fastapi import FastAPI +from pydantic import BaseModel +import uvicorn +import torch.nn.functional as F +# import sys, os +# logfile = open("prm_server.log", "a", buffering=1) +# sys.stdout = logfile +# sys.stderr = logfile + +# 配置 +MODEL_PATH = "/data/yanglu/model/Qwen/Qwen2.5-Math-PRM-7B" +DEVICE = "cuda:3" # 固定 PRM 用的卡 + +# 加载模型 +print("Loading PRM model...") +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True, trust_remote_code=True) +model = AutoModel.from_pretrained( + MODEL_PATH, + torch_dtype=torch.bfloat16, + 
local_files_only=True, + trust_remote_code=True +).to(DEVICE).eval() +max_pos = model.config.max_position_embeddings +end_tokens = ["", "<|im_end|>", "<|endoftext|>"] +end_ids = [tokenizer.convert_tokens_to_ids(t) for t in end_tokens] +allowed_txt_len = max_pos - len(end_ids) + +# 定义 API +app = FastAPI() + +class PRMRequest(BaseModel): + text: str + +@app.post("/score") +def score(req: PRMRequest): + # print(f"req.text: {req.text}") + input_ids = tokenizer.encode(req.text, return_tensors="pt").to(DEVICE) + if input_ids.shape[1] >= max_pos: + input_ids = input_ids.cpu().squeeze(0).tolist() + truncated_ids = input_ids[:allowed_txt_len] + input_ids = torch.tensor([truncated_ids+end_ids], device=DEVICE, dtype=torch.long) + with torch.no_grad(): + outputs = model(input_ids=input_ids) + step_sep_id = tokenizer.encode("")[0] + token_masks = (input_ids == step_sep_id) + probabilities = F.softmax(outputs[0], dim=-1)* token_masks.unsqueeze(-1) + sample = probabilities[0] + prm_reward = sample[sample != 0].view(-1, 2)[:, 1].cpu().tolist() # list + return {"reward": prm_reward} + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/scripts/dapo.sh b/scripts/dapo.sh new file mode 100644 index 000000000..5a8e7c655 --- /dev/null +++ b/scripts/dapo.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +set -euo pipefail +export CUDA_VISIBLE_DEVICES=0,1 +N_GPU=2 +EXP_NAME=greso-dapo +TRIAL_NAME=trial0 +FILE_ROOT=/data/yanglu/AReaL/tmp/areal/experiments +# ACTOR_PATH=/data/yanglu/model/Qwen/Qwen2.5-Math-7B +ACTOR_PATH=/data/yanglu/model/Qwen/Qwen2.5-1.5B-Instruct +TRAIN_DATASET_PATH=/data/yanglu/dataset/greso +VALID_DATASET_PATH=/data/yanglu/dataset/greso + +TOTAL_TRAIN_EPOCHS=1 + +python3 -m areal.launcher.local \ + examples/experimental/dapo/greso_dapo.py \ + --config examples/experimental/dapo/gsm8k_dapo.yaml \ + experiment_name="$EXP_NAME" \ + trial_name="$TRIAL_NAME" \ + total_train_epochs="$TOTAL_TRAIN_EPOCHS" \ + allocation_mode=sglang.d1p1t1+d1p1t1 \ + cluster.n_nodes=1 \ + cluster.n_gpus_per_node="$N_GPU" \ + cluster.fileroot="$FILE_ROOT" \ + +gconfig.top_p=0.7 \ + actor.path="$ACTOR_PATH" \ + actor.optimizer.lr=1e-6 \ + actor.optimizer.weight_decay=0.01 \ + actor.overlong_reward_penalty=false \ + actor.ppo_n_minibatches=64 \ + +actor.c_clip=10.0 \ + train_dataset.path="$TRAIN_DATASET_PATH" \ + valid_dataset.path="$VALID_DATASET_PATH" \ No newline at end of file diff --git a/scripts/dapo_prm.sh b/scripts/dapo_prm.sh new file mode 100644 index 000000000..2634ae75a --- /dev/null +++ b/scripts/dapo_prm.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -euo pipefail +export CUDA_VISIBLE_DEVICES=0,1 +N_GPU=2 +EXP_NAME=greso-dapo-clip-delta +TRIAL_NAME=trial0 +FILE_ROOT=/data/yanglu/AReaL/tmp/areal/experiments +# ACTOR_PATH=/data/yanglu/model/Qwen/Qwen2.5-Math-7B +ACTOR_PATH=/data/yanglu/model/Qwen/Qwen2.5-1.5B-Instruct +PRM_PATH=/data/yanglu/model/Qwen/Qwen2.5-Math-PRM-7B +TRAIN_DATASET_PATH=/data/yanglu/dataset/greso +VALID_DATASET_PATH=/data/yanglu/dataset/greso + +TOTAL_TRAIN_EPOCHS=1 + +python3 -m areal.launcher.local \ + examples/prm/greso_dapo_prm.py \ + --config examples/experimental/dapo/gsm8k_dapo.yaml \ + experiment_name="$EXP_NAME" \ + trial_name="$TRIAL_NAME" \ + tokenizer_path="$PRM_PATH" \ + +prm_path="$PRM_PATH" \ + +prmconfig.reward_shaping_alpha=0.02 \ + total_train_epochs="$TOTAL_TRAIN_EPOCHS" \ + allocation_mode=sglang.d1p1t1+d1p1t1 \ + cluster.n_nodes=1 \ + cluster.n_gpus_per_node="$N_GPU" \ + cluster.fileroot="$FILE_ROOT" \ + actor.path="$ACTOR_PATH" \ + 
actor.optimizer.lr=1e-6 \ + actor.optimizer.weight_decay=0.01 \ + actor.overlong_reward_penalty=false \ + actor.ppo_n_minibatches=64 \ + +actor.c_clip=10.0 \ + train_dataset.path="$TRAIN_DATASET_PATH" \ + valid_dataset.path="$VALID_DATASET_PATH" \ No newline at end of file diff --git a/scripts/dataset/greso/test/data.parquet b/scripts/dataset/greso/test/data.parquet new file mode 100644 index 000000000..aba706d44 Binary files /dev/null and b/scripts/dataset/greso/test/data.parquet differ diff --git a/scripts/dataset/greso/train/data.parquet b/scripts/dataset/greso/train/data.parquet new file mode 100644 index 000000000..9fa5efe26 Binary files /dev/null and b/scripts/dataset/greso/train/data.parquet differ
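For reference, a minimal standalone sketch of the token-level GAE recursion implemented by DensePPOActor.compute_advantages above, which feeds it the KL-regularized rewards plus the dense per-token reward_score produced by PRMRLVRWorkflow. Tensor shapes are illustrative ([bs, max_seqlen] floats) and the helper name dense_gae is hypothetical:

import torch

def dense_gae(rewards, values, loss_mask, discount=1.0, gae_lambda=1.0, seq_no_eos_mask=None):
    # rewards/values/loss_mask: [bs, max_seqlen]; loss_mask is 1.0 on response tokens, 0.0 elsewhere.
    bs, max_seqlen = rewards.shape
    if seq_no_eos_mask is None:
        seq_no_eos_mask = torch.zeros(bs, dtype=torch.bool, device=rewards.device)
    advantages_reversed = [torch.zeros(bs, dtype=torch.float32, device=rewards.device)]
    lastgaelam = 0.0
    nextvalues = values[:, max_seqlen - 1] * seq_no_eos_mask
    for t in reversed(range(max_seqlen - 1)):
        delta = rewards[:, t] + discount * nextvalues - values[:, t]
        newgaelam = delta + discount * gae_lambda * lastgaelam
        # Tokens outside the loss mask pass the running value/advantage through unchanged.
        mask = loss_mask[:, t]
        nextvalues = nextvalues * (1 - mask) + values[:, t] * mask
        lastgaelam = lastgaelam * (1 - mask) + newgaelam * mask
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values
    return advantages, returns

When no critic output is present, the actor falls back to zero values (the `"values" not in data` branch), so the recursion reduces to a discounted, gae_lambda-weighted reward-to-go accumulated over the response tokens only.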
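The step segmentation in PRMRLVRWorkflow.arun_episode assigns each newline-delimited reasoning step the index of its last output token by accumulating per-token decode lengths against character offsets; a condensed sketch of that logic (function name hypothetical):

def step_end_token_indices(tokenizer, output_tokens):
    full_str = tokenizer.decode(output_tokens, clean_up_tokenization_spaces=False)
    raw_lines = full_str.split("\n")
    ends, pos = [], 0
    for raw_line in raw_lines:
        if raw_line.strip() == "":
            pos += len(raw_line) + 1          # skip blank lines (and their newline)
            continue
        pos += len(raw_line)
        ends.append(pos)                      # character offset where this step ends
        pos += 1                              # the trailing newline
    last_indices, cur_len, seg_i = [None] * len(ends), 0, 0
    for idx, tok in enumerate(output_tokens):
        cur_len += len(tokenizer.decode([tok], clean_up_tokenization_spaces=False))
        while seg_i < len(ends) and cur_len >= ends[seg_i]:
            last_indices[seg_i] = idx         # last token whose decoded text reaches this step's end
            seg_i += 1
        if seg_i >= len(ends):
            break
    return last_indices

The workflow then pins the final index to the second-to-last output token and offsets every index by resp.input_len to obtain cr_pos on the full sequence, where the per-step PRM rewards are written into the dense reward tensor.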
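The clip and delta mechanisms toggled by PRMRewardHyperparameters reshape the per-step PRM rewards before the outcome reward from reward_fn is written at the second-to-last position of each sequence; a single-sequence sketch of the same arithmetic, assuming 1-D tensors and a boolean reward_mask over step-end tokens (helper name hypothetical):

import torch

def shape_prm_rewards(dense_reward, reward_mask, use_clip=True, use_delta=True):
    dense_reward = dense_reward.clone()
    step_rewards = dense_reward[reward_mask]
    if step_rewards.numel() == 0:
        return dense_reward                  # no step boundaries found; nothing to shape
    if use_clip:
        avg = step_rewards.mean()
        gt_mean = (dense_reward > avg) & reward_mask
        le_mean = (dense_reward <= avg) & reward_mask
        dense_reward[gt_mean] = 0            # above-average steps contribute nothing
        dense_reward[le_mean] -= avg         # below-average steps are penalized relative to the mean
    if use_delta:
        valid = dense_reward[reward_mask]
        new_v = valid.clone()
        new_v[-1] = 0                        # the last step's shaped reward is zeroed
        if new_v.numel() > 1:
            # each earlier step is rewarded by its value minus the next step's value;
            # the second-to-last step keeps its (clipped) value, mirroring the PR code
            new_v[:-2] = valid[:-2] - valid[1:-1]
        dense_reward[reward_mask] = new_v
    return dense_reward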