add prm #399
base: main
Changes from 2 commits
Commits: aba4313, 0572275, 32a84c6, d827ce8, 4c0d80e, da790d6, 5937c04, 54ef0cb, e2c6c9b, f4284d7, 8ecbe5a
@@ -11,7 +11,7 @@

 logger = logging.getLogger("Launcher Utils")

-LOCAL_CACHE_DIR = "/tmp/areal"
+LOCAL_CACHE_DIR = "/data/yl/AReaL/tmp/areal"

Reviewer comment (on the changed line): Should revert.
Reviewer comment (on the changed line): The … [the rest of this comment was not captured]
Suggested change:
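The body of the suggestion is not preserved in this page capture; presumably it reverts the line to the original value shown in the diff above:

```python
LOCAL_CACHE_DIR = "/tmp/areal"
```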

 PYTORCH_KERNEL_CACHE_PATH = (
     f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels/"
 )
@@ -0,0 +1,158 @@
import asyncio
from dataclasses import asdict
import os
import uuid

import aiofiles
import aiofiles.os
import colorama
import torch
from transformers import PreTrainedTokenizerFast, PreTrainedModel

from areal.api.cli_args import GenerationHyperparameters, PRMRewardHyperparameters
from areal.api.engine_api import InferenceEngine
from areal.api.io_struct import ModelRequest
from areal.api.reward_api import AsyncRewardWrapper
from areal.api.workflow_api import RolloutWorkflow
from areal.utils import logging, stats_tracker
from areal.utils.data import concat_padded_tensors

logger = logging.getLogger("RLVR workflow")


class PRMRLVRWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        reward_fn_prm,
        gconfig: GenerationHyperparameters,
        prmconfig: PRMRewardHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        # prm_model: PreTrainedModel,
        # prm_tokenizer: PreTrainedTokenizerFast,
[Reviewer comment on lines +32 to +33 — the comment body is not captured in this page.]
        enable_thinking: bool,
        rollout_stat_scope: str = "rollout",
        dump_dir: str | None = None,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.prmconfig = prmconfig
        self.tokenizer = tokenizer
        # self.prm_model = prm_model
        # self.prm_tokenizer = prm_tokenizer
        self.enable_thinking = enable_thinking
        self.dump_dir = dump_dir
        self.rollout_stat_scope = rollout_stat_scope
        self.async_reward_fn = AsyncRewardWrapper(reward_fn)
        self.async_reward_fn_prm = AsyncRewardWrapper(reward_fn_prm)
        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
            os.makedirs(self.dump_dir, exist_ok=True)

    async def arun_episode(self, engine: InferenceEngine, data):
        input_ids = self.tokenizer.apply_chat_template(
            data["messages"],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )

        n_samples = self.gconfig.n_samples
        req = ModelRequest(
            rid=uuid.uuid4().hex,
            input_ids=input_ids,
            gconfig=self.gconfig.new(n_samples=1),
            tokenizer=self.tokenizer,
        )
        resps = await asyncio.gather(*[engine.agenerate(req) for _ in range(n_samples)])

        version = engine.get_version()
        prompt_strs = []
        completions_strs = []
        rewards = []
        prm_rewards = []
        seqlens = []

        results = []
        for resp in resps:
            seq = resp.input_tokens + resp.output_tokens
            logprobs = [0.0] * resp.input_len + resp.output_logprobs
            loss_mask = [0] * resp.input_len + [1] * resp.output_len
            versions = [-1] * resp.input_len + resp.output_versions

            prompt_str = self.tokenizer.decode(input_ids)
            completions_str = self.tokenizer.decode(resp.output_tokens)
            prompt_strs.append(prompt_str)
            completions_strs.append(completions_str)
            seqlens.append(len(seq))
            result_reward = await self.async_reward_fn(
                prompt_str,
                completions_str,
                resp.input_tokens,
                resp.output_tokens,
                **data,
            )
            prm_reward = await self.async_reward_fn_prm(
                prompt_str,
                completions_str,
                resp.input_tokens,
                resp.output_tokens,
                # self.prm_model,
                # self.prm_tokenizer,
                **data,
            )
[garrett4wade marked a review conversation here as resolved. (Outdated, hidden.)]
            reward = self.prmconfig.reward_shaping_alpha * prm_reward + result_reward

            # Log reward.
            stats_tracker.get(self.rollout_stat_scope).scalar(reward=reward)

            rewards.append(reward)
            prm_rewards.append(prm_reward)

            res = dict(
                # unsqueeze to add an additional batch dimension
                input_ids=torch.tensor(seq).unsqueeze(0),
                loss_mask=torch.tensor(loss_mask).unsqueeze(0),
                logprobs=torch.tensor(logprobs).unsqueeze(0),
                versions=torch.tensor(versions).unsqueeze(0),
                attention_mask=torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
                # reward
                rewards=torch.tensor([float(reward)]),
            )
            results.append(res)

        # clip mechanism
        avg_prm_reward = sum(prm_rewards) / len(prm_rewards)
        for i, val in enumerate(prm_rewards):
            if val > avg_prm_reward:
                rewards[i] = 0
        for res, r in zip(results, rewards):
            res["rewards"] = torch.tensor([float(r)])

        if self.dump_dir is not None:
            dump_path = os.path.join(self.dump_dir, str(version))
            await aiofiles.os.makedirs(dump_path, exist_ok=True)
            # Get the unique identifier for this prompt
            qid = None
            for key in ["query_id", "id", "qid"]:
                qid = data.get(key, None)
                if qid is not None:
                    break
            qid = qid or uuid.uuid4().hex

            # Dump rollout to file
            file_path = os.path.join(dump_path, f"{qid}.txt")
            async with aiofiles.open(file_path, "a") as f:
                n_samples = self.gconfig.n_samples
                for i, (p, c, r, sl) in enumerate(
                    zip(prompt_strs, completions_strs, rewards, seqlens)
                ):
                    info = "\n".join(
                        [
                            f"idx: {i + 1} / {n_samples}, seqlen: {sl}, reward is {r}.",
                            f"prompt is \n{colorama.Fore.YELLOW + colorama.Style.DIM}{p}{colorama.Style.RESET_ALL}",
                            f"sequence is: \n{colorama.Fore.YELLOW + colorama.Style.DIM}{c}{colorama.Style.RESET_ALL}",
                        ]
                    )
                    await f.write(info + "\n")

        return concat_padded_tensors(results)
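For context only (not part of the PR): a minimal sketch of how this workflow might be instantiated, based solely on the constructor signature above. The reward functions, hyperparameter values, tokenizer name, and dump path are hypothetical placeholders.

```python
from transformers import AutoTokenizer

from areal.api.cli_args import GenerationHyperparameters, PRMRewardHyperparameters


def outcome_reward(prompt, completion, prompt_ids, completion_ids, **data):
    # Hypothetical outcome-level reward: 1.0 if the reference answer carried in
    # the dataset fields appears in the completion, else 0.0.
    answer = str(data.get("answer", ""))
    return 1.0 if answer and answer in completion else 0.0


def process_reward(prompt, completion, prompt_ids, completion_ids, **data):
    # Hypothetical process-level (PRM) reward; a real implementation would score
    # intermediate reasoning steps, e.g. with a separate PRM model.
    return 0.0


workflow = PRMRLVRWorkflow(
    reward_fn=outcome_reward,
    reward_fn_prm=process_reward,
    gconfig=GenerationHyperparameters(n_samples=4),  # field name taken from the code above; kwarg assumed
    prmconfig=PRMRewardHyperparameters(reward_shaping_alpha=0.1),  # field name taken from the code above; kwarg assumed
    tokenizer=AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct"),  # any HF tokenizer
    enable_thinking=False,
    dump_dir="/tmp/areal/rollout_dumps",  # hypothetical path
)
```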
Reviewer comment: Looks like we can just inherit `GRPOConfig` and add two new fields, `prm_path` and `reward_shaping_alpha`? BTW, if you are referring to reward scaling, you can use `actor.reward_scaling` rather than creating a new field.
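A rough sketch of what the reviewer's suggestion might look like, assuming `GRPOConfig` is the existing experiment config dataclass in `areal.api.cli_args` (the import path, class name, and field defaults are assumptions, not part of the PR):

```python
from dataclasses import dataclass

from areal.api.cli_args import GRPOConfig  # assumed import path for the existing config


@dataclass
class PRMGRPOConfig(GRPOConfig):
    # Path to the process reward model (PRM) checkpoint.
    prm_path: str = ""
    # Weight of the PRM score in the shaped reward:
    #   reward = reward_shaping_alpha * prm_reward + result_reward
    reward_shaping_alpha: float = 0.05
```

As the reviewer notes, if the intent is plain reward scaling rather than PRM-based shaping, the existing `actor.reward_scaling` field would already cover it.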