diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 91c0d21f2a3..dcc710d2270 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -6,7 +6,7 @@ - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off` + - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `vllm_omni`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. 
diff --git a/.github/workflows/vllm_omni.yml b/.github/workflows/vllm_omni.yml index 7ad8254bd11..1cb34a22f21 100644 --- a/.github/workflows/vllm_omni.yml +++ b/.github/workflows/vllm_omni.yml @@ -61,7 +61,12 @@ on: - ".github/workflows/vllm_omni.yml" - "tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py" - "tests/experimental/agent_loop/test_diffusion_agent_loop.py" + - "tests/special_e2e/run_flowgrpo_trainer_diffusers.sh" + - "tests/special_e2e/create_dummy_diffusion_data.py" - "verl/workers/rollout/vllm_rollout/vllm_omni_async_server.py" + - "verl/trainer/diffusion/ray_diffusion_trainer.py" + - "verl/trainer/main_flowgrpo.py" + - "verl/trainer/config/diffusion_trainer.yaml" # Cancel jobs on the same ref if a new one is triggered concurrency: @@ -95,7 +100,7 @@ jobs: vllm_omni: needs: setup runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] - timeout-minutes: 35 # Increase this timeout value as needed + timeout-minutes: 50 # Increased to accommodate e2e FlowGRPO training env: HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" @@ -120,6 +125,16 @@ jobs: run: | ray stop --force pytest tests/experimental/agent_loop/test_diffusion_agent_loop.py -v -s + - name: Install diffusers for e2e training + run: | + pip3 install diffusers==0.37.0 + - name: Prepare dummy diffusion dataset + run: | + python3 tests/special_e2e/create_dummy_diffusion_data.py + - name: E2E FlowGRPO diffusion training + run: | + ray stop --force + bash tests/special_e2e/run_flowgrpo_trainer_diffusers.sh cleanup: runs-on: ubuntu-latest diff --git a/examples/flowgrpo_trainer/data_process/qwenimage_ocr.py b/examples/flowgrpo_trainer/data_process/qwenimage_ocr.py new file mode 100644 index 00000000000..4df1a055072 --- /dev/null +++ b/examples/flowgrpo_trainer/data_process/qwenimage_ocr.py @@ -0,0 +1,104 @@ +# Copyright 2026 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the OCR dataset to parquet format (for Qwen-Image training). +You can obtain the raw dataset from https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr +""" + +import argparse +import os + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + + +def extract_solution(solution_str): + # The solution is stored in the format: 'The image displays "xxx".' + return solution_str.split('"')[1] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default=None) + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument( + "--local_dataset_path", default="~/dataset/ocr/", help="The local path to the raw dataset, if it exists." + ) + parser.add_argument( + "--local_save_dir", default="~/data/ocr", help="The save directory for the preprocessed dataset." + ) + + args = parser.parse_args() + if args.local_dataset_path is not None: + local_dataset_path = os.path.expanduser(args.local_dataset_path) + + data_source = "flow_grpo/ocr" + + if local_dataset_path is not None: + dataset = datasets.load_dataset(local_dataset_path) + else: + raise NotImplementedError( + "It is not existed in huggingface hub. 
" + "Please get dataset from https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr" + ) + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + system_prompt = ( + "Describe the image by detailing the color, shape, size, " + "texture, quantity, text, spatial relationships of the objects and background:" + ) + negative_user_prompt = " " + + def make_map_fn(split): + def process_fn(example, idx): + text = example.pop("text") + solution = extract_solution(text) + data = { + "data_source": data_source, + "prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": text}, + ], + "negative_prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": negative_user_prompt}, + ], + "ability": "ocr", + "reward_model": {"style": "model", "ground_truth": solution}, + "extra_info": {"split": split, "index": idx}, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + hdfs_dir = args.hdfs_dir + local_save_dir = args.local_dir + if local_save_dir is not None: + print("Warning: Argument 'local_dir' is deprecated. 
Please use 'local_save_dir' instead.") + else: + local_save_dir = args.local_save_dir + + local_save_dir = os.path.expanduser(local_save_dir) + train_dataset.to_parquet(os.path.join(local_save_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_save_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + copy(src=local_save_dir, dst=hdfs_dir) diff --git a/examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh b/examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh new file mode 100644 index 00000000000..e6f9d08063d --- /dev/null +++ b/examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh @@ -0,0 +1,71 @@ +# Qwen-Image lora RL, vllm_omni rollout +set -x + +ocr_train_path=$HOME/data/ocr/train.parquet +ocr_test_path=$HOME/data/ocr/test.parquet + +ENGINE=vllm_omni +REWARD_ENGINE=vllm + +reward_path=examples/flowgrpo_trainer/reward_fn.py +reward_model_name=$HOME/models/Qwen/Qwen3-VL-8B-Instruct + + +python3 -m verl.trainer.main_flowgrpo \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=$ocr_train_path \ + data.val_files=$ocr_test_path \ + data.train_batch_size=32 \ + data.max_prompt_length=256 \ + actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen-Image \ + actor_rollout_ref.model.tokenizer_path=$HOME/models/Qwen/Qwen-Image/tokenizer \ + actor_rollout_ref.model.external_lib="examples.flowgrpo_trainer.diffusers.qwen_image" \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=128 \ + actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \ + actor_rollout_ref.actor.optim.lr=3e-4 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ 
+ actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \ + actor_rollout_ref.rollout.agent.num_workers=4 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.rollout.val_kwargs.num_inference_steps=50 \ + +actor_rollout_ref.rollout.extra_configs.true_cfg_scale=4.0 \ + +actor_rollout_ref.rollout.extra_configs.noise_level=1.2 \ + +actor_rollout_ref.rollout.extra_configs.sde_type="sde" \ + +actor_rollout_ref.rollout.extra_configs.sde_window_size=2 \ + +actor_rollout_ref.rollout.extra_configs.sde_window_range="[0,5]" \ + +actor_rollout_ref.rollout.extra_configs.max_sequence_length=256 \ + +actor_rollout_ref.rollout.val_kwargs.extra_configs.noise_level=0.0 \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=examples.flowgrpo_trainer.vllm_omni.pipeline_qwenimage.QwenImagePipelineWithLogProb \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + reward.num_workers=4 \ + reward.reward_manager.name=visual \ + reward.reward_model.enable=True \ + reward.reward_model.model_path=$reward_model_name \ + reward.reward_model.rollout.name=$REWARD_ENGINE \ + reward.reward_model.rollout.tensor_model_parallel_size=4 \ + reward.custom_reward_function.path=$reward_path \ + reward.custom_reward_function.name=compute_score_ocr \ + trainer.use_legacy_worker_impl=disable \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=flow_grpo \ + trainer.experiment_name=qwen_image_ocr_lora \ + trainer.log_val_generations=8 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=30 \ + 
trainer.test_freq=30 \ + trainer.total_epochs=15 \ + trainer.total_training_steps=300 $@ diff --git a/examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py b/examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py index 0b43c7c4937..8a4b9bf7d11 100644 --- a/examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py +++ b/examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py @@ -29,6 +29,10 @@ def _maybe_to_cpu(v): return v +def _coalesce_not_none(value, default): + return default if value is None else value + + # Custom pipeline class for QwenImage that returns log probabilities during the diffusion process. # This is compatible with API of vllm-omni custom pipeline class QwenImagePipelineWithLogProb(QwenImagePipeline): @@ -188,6 +192,7 @@ def diffuse( generator=generator, noise_level=cur_noise_level, sde_type=sde_type, + return_logprobs=logprobs, return_dict=False, ) @@ -252,16 +257,16 @@ def forward( num_inference_steps = sp.num_inference_steps or num_inference_steps max_sequence_length = sp.max_sequence_length or max_sequence_length - noise_level = sp.extra_args.get("noise_level", None) or noise_level - sde_window_size = sp.extra_args.get("sde_window_size", None) or sde_window_size - sde_window_range = sp.extra_args.get("sde_window_range", None) or sde_window_range - sde_type = sp.extra_args.get("sde_type", None) or sde_type - logprobs = sp.extra_args.get("logprobs", None) + noise_level = _coalesce_not_none(sp.extra_args.get("noise_level", None), noise_level) + sde_window_size = _coalesce_not_none(sp.extra_args.get("sde_window_size", None), sde_window_size) + sde_window_range = _coalesce_not_none(sp.extra_args.get("sde_window_range", None), sde_window_range) + sde_type = _coalesce_not_none(sp.extra_args.get("sde_type", None), sde_type) + logprobs = _coalesce_not_none(sp.extra_args.get("logprobs", None), logprobs) generator = sp.generator or generator if generator is None and sp.seed is not None: generator = 
torch.Generator(device=self.device).manual_seed(sp.seed) - true_cfg_scale = sp.true_cfg_scale or true_cfg_scale + true_cfg_scale = _coalesce_not_none(sp.true_cfg_scale, true_cfg_scale) req_num_outputs = getattr(sp, "num_outputs_per_prompt", None) if req_num_outputs and req_num_outputs > 0: num_images_per_prompt = req_num_outputs diff --git a/tests/experimental/agent_loop/test_diffusion_agent_loop.py b/tests/experimental/agent_loop/test_diffusion_agent_loop.py index 5dfde636dd3..17c1044b8ca 100644 --- a/tests/experimental/agent_loop/test_diffusion_agent_loop.py +++ b/tests/experimental/agent_loop/test_diffusion_agent_loop.py @@ -79,12 +79,12 @@ def init_config() -> DictConfig: prompt_template_encode_start_idx = 34 max_length = tokenizer_max_length + prompt_template_encode_start_idx - with open_dict(config.actor_rollout_ref.model.extra_configs): - config.actor_rollout_ref.model.extra_configs.true_cfg_scale = 4.0 - config.actor_rollout_ref.model.extra_configs.max_sequence_length = max_length - config.actor_rollout_ref.model.extra_configs.noise_level = 1.0 - config.actor_rollout_ref.model.extra_configs.sde_window_size = 2 - config.actor_rollout_ref.model.extra_configs.sde_window_range = [0, 5] + with open_dict(config.actor_rollout_ref.rollout.extra_configs): + config.actor_rollout_ref.rollout.extra_configs.true_cfg_scale = 4.0 + config.actor_rollout_ref.rollout.extra_configs.max_sequence_length = max_length + config.actor_rollout_ref.rollout.extra_configs.noise_level = 1.0 + config.actor_rollout_ref.rollout.extra_configs.sde_window_size = 2 + config.actor_rollout_ref.rollout.extra_configs.sde_window_range = [0, 5] config.actor_rollout_ref.rollout.nnodes = 1 diff --git a/tests/models/test_diffusers_fsdp_engine.py b/tests/models/test_diffusers_fsdp_engine.py index b7e7681def5..547a099413e 100644 --- a/tests/models/test_diffusers_fsdp_engine.py +++ b/tests/models/test_diffusers_fsdp_engine.py @@ -140,7 +140,6 @@ def create_data_samples(num_device: int, model_config: 
DiffusionModelConfig) -> data.meta_info["height"] = height data.meta_info["width"] = width data.meta_info["vae_scale_factor"] = vae_scale_factor - data.meta_info["gradient_accumulation_steps"] = 1 return data diff --git a/tests/special_e2e/create_dummy_diffusion_data.py b/tests/special_e2e/create_dummy_diffusion_data.py new file mode 100644 index 00000000000..fd962a61fb8 --- /dev/null +++ b/tests/special_e2e/create_dummy_diffusion_data.py @@ -0,0 +1,93 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Create a small synthetic parquet dataset for FlowGRPO diffusion e2e testing. + +The dataset uses the jpeg_compressibility reward (a self-contained rule-based +reward that needs no external reward model) so the e2e test can run without +spinning up a separate vLLM reward server. 
+""" + +import argparse +import os + +import pandas as pd + +SYSTEM_PROMPT = ( + "Describe the image by detailing the color, shape, size, " + "texture, quantity, text, spatial relationships of the objects and background:" +) + +USER_PROMPTS = [ + "A red circle on a white background", + "A blue square on a black background", + "A green triangle next to an orange rectangle", + "The word HELLO written in bold letters", + "A yellow star above a purple crescent moon", + "Two overlapping circles, one red and one blue", + "A gradient from dark blue to light blue", + "A checkerboard pattern of black and white squares", +] + + +def build_rows(split: str, n: int): + rows = [] + for i in range(n): + prompt_text = USER_PROMPTS[i % len(USER_PROMPTS)] + rows.append( + { + "data_source": "jpeg_compressibility", + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt_text}, + ], + "negative_prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": " "}, + ], + "reward_model": {"style": "rule", "ground_truth": ""}, + "extra_info": {"split": split, "index": i}, + } + ) + return rows + + +def main(): + parser = argparse.ArgumentParser(description="Generate dummy diffusion parquet data for e2e testing") + parser.add_argument( + "--local_save_dir", + default=os.path.expanduser("~/data/dummy_diffusion"), + help="Directory to write train.parquet and test.parquet", + ) + parser.add_argument("--train_size", type=int, default=32, help="Number of training samples") + parser.add_argument("--val_size", type=int, default=8, help="Number of validation samples") + args = parser.parse_args() + + os.makedirs(args.local_save_dir, exist_ok=True) + + train_df = pd.DataFrame(build_rows("train", args.train_size)) + val_df = pd.DataFrame(build_rows("test", args.val_size)) + + train_path = os.path.join(args.local_save_dir, "train.parquet") + val_path = os.path.join(args.local_save_dir, "test.parquet") + + 
train_df.to_parquet(train_path) + val_df.to_parquet(val_path) + + print(f"Wrote {len(train_df)} train samples to {train_path}") + print(f"Wrote {len(val_df)} val samples to {val_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/special_e2e/run_flowgrpo_trainer_diffusers.sh b/tests/special_e2e/run_flowgrpo_trainer_diffusers.sh new file mode 100644 index 00000000000..f6f1df0a746 --- /dev/null +++ b/tests/special_e2e/run_flowgrpo_trainer_diffusers.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# FlowGRPO diffusion e2e smoke test (minimal runtime), vllm_omni rollout. +# +# Exercises: parquet load -> vllm_omni rollout -> visual reward (jpeg_compressibility, +# no reward model) -> flow_grpo -> FSDP LoRA -> sync. +# +# Requires: vllm-omni, diffusers>=0.37, tiny Qwen-Image at ~/models/tiny-random/Qwen-Image +set -xeuo pipefail + +# Override via env: NUM_GPUS, MODEL_PATH, DATA_DIR, TOTAL_TRAIN_STEPS, TRAIN_FILES, VAL_FILES +NUM_GPUS=${NUM_GPUS:-4} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/tiny-random/Qwen-Image} +TOKENIZER_PATH=${TOKENIZER_PATH:-${MODEL_PATH}/tokenizer} +DATA_DIR=${DATA_DIR:-${HOME}/data/dummy_diffusion} +dummy_train_path=${TRAIN_FILES:-${DATA_DIR}/train.parquet} +dummy_test_path=${VAL_FILES:-${DATA_DIR}/test.parquet} +TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} + +ENGINE=vllm_omni +max_prompt_length=256 + +if [ ! -f "${dummy_train_path}" ] || [ ! 
-f "${dummy_test_path}" ]; then + python3 tests/special_e2e/create_dummy_diffusion_data.py \ + --local_save_dir "${DATA_DIR}" \ + --train_size 8 \ + --val_size 4 +fi + +n_resp_per_prompt=2 +micro_bsz_per_gpu=1 +micro_bsz=$((micro_bsz_per_gpu * NUM_GPUS)) +mini_bsz=${micro_bsz} +train_batch_size=$((mini_bsz * n_resp_per_prompt)) + +python3 -m verl.trainer.main_flowgrpo \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=${dummy_train_path} \ + data.val_files=${dummy_test_path} \ + data.train_batch_size=${train_batch_size} \ + data.max_prompt_length=${max_prompt_length} \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.model.tokenizer_path=${TOKENIZER_PATH} \ + actor_rollout_ref.model.external_lib="examples.flowgrpo_trainer.diffusers.qwen_image" \ + actor_rollout_ref.model.lora_rank=8 \ + actor_rollout_ref.model.lora_alpha=16 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=1e-4 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${micro_bsz_per_gpu} \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.04 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${micro_bsz_per_gpu} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=${ENGINE} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \ + actor_rollout_ref.rollout.agent.num_workers=1 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + 
actor_rollout_ref.rollout.num_inference_steps=4 \ + actor_rollout_ref.rollout.height=256 \ + actor_rollout_ref.rollout.width=256 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.val_kwargs.num_inference_steps=4 \ + +actor_rollout_ref.rollout.extra_configs.true_cfg_scale=4.0 \ + +actor_rollout_ref.rollout.extra_configs.noise_level=1.0 \ + +actor_rollout_ref.rollout.extra_configs.sde_type="sde" \ + +actor_rollout_ref.rollout.extra_configs.sde_window_size=2 \ + +actor_rollout_ref.rollout.extra_configs.sde_window_range="[0,4]" \ + +actor_rollout_ref.rollout.extra_configs.max_sequence_length=${max_prompt_length} \ + +actor_rollout_ref.rollout.val_kwargs.extra_configs.noise_level=0.0 \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=examples.flowgrpo_trainer.vllm_omni.pipeline_qwenimage.QwenImagePipelineWithLogProb \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${micro_bsz_per_gpu} \ + reward.num_workers=1 \ + reward.reward_manager.name=visual \ + reward.reward_model.enable=False \ + trainer.use_legacy_worker_impl=disable \ + trainer.logger=console \ + trainer.project_name=verl-test \ + trainer.experiment_name=flowgrpo-diffusion-e2e \ + trainer.log_val_generations=0 \ + trainer.n_gpus_per_node=${NUM_GPUS} \ + trainer.nnodes=1 \ + trainer.val_before_train=False \ + trainer.test_freq=-1 \ + trainer.save_freq=-1 \ + trainer.resume_mode=disable \ + trainer.total_training_steps=${TOTAL_TRAIN_STEPS} \ + "$@" + +echo "FlowGRPO diffusion e2e test passed (training completed successfully)." 
diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py index 3f285acc8f5..31d82100f5c 100644 --- a/tests/special_sanity/check_device_api_usage.py +++ b/tests/special_sanity/check_device_api_usage.py @@ -33,6 +33,7 @@ "verl/utils/rendezvous/ray_backend.py", # appear in cupy importance "verl/single_controller/ray/base.py", # appear in default device_name "verl/trainer/ppo/ray_trainer.py", # appear in default device_name + "verl/trainer/diffusion/ray_diffusion_trainer.py", # appear in default device_name "verl/experimental/transfer_queue/ray_trainer.py", # appear in docstring as default device_name "verl/experimental/one_step_off_policy/ray_trainer.py", # appear in docstring as default device_name "verl/utils/reward_score/sandbox_fusion/utils.py", # appear in sandbox language type diff --git a/tests/special_sanity/check_pr_title.py b/tests/special_sanity/check_pr_title.py index df316d3d080..19bb95df63d 100644 --- a/tests/special_sanity/check_pr_title.py +++ b/tests/special_sanity/check_pr_title.py @@ -19,7 +19,7 @@ pr_title = os.environ.get("PR_TITLE", "").strip() # Define rules -allowed_modules = ["fsdp", "megatron", "veomni", "sglang", "vllm", "trtllm", "rollout", "trainer"] +allowed_modules = ["fsdp", "megatron", "veomni", "sglang", "vllm", "vllm_omni", "trtllm", "rollout", "trainer"] allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"] allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"] allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg", "reward"] diff --git a/tests/trainer/diffusion/__init__.py b/tests/trainer/diffusion/__init__.py new file mode 100644 index 00000000000..d828409b82e --- /dev/null +++ b/tests/trainer/diffusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/trainer/diffusion/test_diffusion_core_algos_on_cpu.py b/tests/trainer/diffusion/test_diffusion_core_algos_on_cpu.py new file mode 100644 index 00000000000..1cf23521e4c --- /dev/null +++ b/tests/trainer/diffusion/test_diffusion_core_algos_on_cpu.py @@ -0,0 +1,81 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import numpy as np +import pytest +import torch + +from verl.trainer.diffusion import diffusion_algos + + +@pytest.mark.parametrize("norm_adv_by_std_in_grpo", [True, False]) +@pytest.mark.parametrize("global_std", [True, False]) +def test_flow_grpo_advantage_return(norm_adv_by_std_in_grpo: bool, global_std: bool) -> None: + batch_size = 8 + steps = 10 + sample_level_rewards = torch.randn((batch_size, 1), dtype=torch.float32) + response_mask = torch.ones((batch_size, steps), dtype=torch.int32) + uid = np.array([f"uid-{idx}" for idx in range(batch_size)], dtype=object) + + advantages, returns = diffusion_algos.compute_flow_grpo_outcome_advantage( + sample_level_rewards=sample_level_rewards, + response_mask=response_mask, + index=uid, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + global_std=global_std, + ) + + assert advantages.shape == returns.shape == (batch_size, steps) + + +def test_compute_policy_loss_flow_grpo() -> None: + from hydra import compose, initialize_config_dir + + from verl.utils.config import omega_conf_to_dataclass + from verl.workers.config.actor import FSDPActorConfig + + batch_size = 8 + steps = 10 + rollout_log_probs = torch.randn((batch_size, steps), dtype=torch.float32) + current_log_probs = torch.randn((batch_size, steps), dtype=torch.float32) + advantages = torch.randn((batch_size, steps), dtype=torch.float32) + response_mask = torch.ones((batch_size, steps), dtype=torch.int32) + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor"), version_base=None): + cfg = compose( + config_name="dp_actor", + overrides=[ + "strategy=fsdp", + "clip_ratio=0.0001", + "clip_ratio_high=5.0", + "ppo_micro_batch_size_per_gpu=8", + ], + ) + actor_config: FSDPActorConfig = omega_conf_to_dataclass(cfg) + + for step in range(steps): + pg_loss, pg_metrics = diffusion_algos.compute_policy_loss_flow_grpo( + old_log_prob=rollout_log_probs[:, step], + log_prob=current_log_probs[:, step], + advantages=advantages[:, 
step], + response_mask=response_mask[:, step], + loss_agg_mode="token-mean", + config=actor_config, + ) + + assert pg_loss.shape == () + assert isinstance(pg_loss.item(), float) + assert "actor/ppo_kl" in pg_metrics diff --git a/tests/trainer/ppo/test_core_algos_on_cpu.py b/tests/trainer/ppo/test_core_algos_on_cpu.py index 7f99475a738..288f28e6398 100644 --- a/tests/trainer/ppo/test_core_algos_on_cpu.py +++ b/tests/trainer/ppo/test_core_algos_on_cpu.py @@ -313,50 +313,5 @@ def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_grou assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6) -def test_compute_policy_loss_flow_grpo() -> None: - """Test flow-GRPO policy loss computation.""" - - # prepare input - batch_size = 8 - steps = 10 - rollout_log_probs = torch.randn((batch_size, steps), dtype=torch.float32) - current_log_probs = torch.randn((batch_size, steps), dtype=torch.float32) - advantages = torch.randn((batch_size, steps), dtype=torch.float32) - response_mask = torch.ones((batch_size, steps), dtype=torch.int32) - import os - - from hydra import compose, initialize_config_dir - - from verl.trainer.ppo.diffusion_algos import compute_policy_loss_flow_grpo - from verl.utils.config import omega_conf_to_dataclass - from verl.workers.config.actor import FSDPActorConfig - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config/actor")): - cfg = compose( - config_name="dp_actor", - overrides=[ - "strategy=fsdp", - "clip_ratio=0.0001", - "clip_ratio_high=5.0", - "ppo_micro_batch_size_per_gpu=8", - ], - ) - actor_config: FSDPActorConfig = omega_conf_to_dataclass(cfg) - - for step in range(steps): - pg_loss, pg_metrics = compute_policy_loss_flow_grpo( - old_log_prob=rollout_log_probs[:, step], - log_prob=current_log_probs[:, step], - advantages=advantages[:, step], - response_mask=response_mask[:, step], - loss_agg_mode="token-mean", - config=actor_config, - ) - - assert pg_loss.shape == () - assert isinstance(pg_loss.item(), 
float) - assert "actor/ppo_kl" in pg_metrics.keys() - - if __name__ == "__main__": unittest.main() diff --git a/verl/experimental/agent_loop/diffusion_agent_loop.py b/verl/experimental/agent_loop/diffusion_agent_loop.py index 1d91c6f338c..0882f444445 100644 --- a/verl/experimental/agent_loop/diffusion_agent_loop.py +++ b/verl/experimental/agent_loop/diffusion_agent_loop.py @@ -49,8 +49,6 @@ class DiffusionAgentLoopOutput(BaseModel): """Response diffusion output (torch.Tensor): image tensor (CHW) / video tensor (TCHW).""" response_logprobs: Optional[Any] = None """Log probabilities for the response tokens. (torch.Tensor)""" - multi_modal_data: Optional[dict[str, Any]] = None - """Multi-modal data for multi-modal tools.""" reward_score: Optional[float] = None """Reward score for the trajectory.""" num_turns: int = 0 @@ -75,9 +73,7 @@ class _InternalDiffusionAgentLoopOutput(DiffusionAgentLoopOutput): attention_mask: torch.Tensor """Padded attention mask.""" response_logprobs: Optional[torch.Tensor] = None - """Log probabilities for the response tokens.""" - multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None - """Multi-modal inputs for processors (e.g. 
pixel_values, image_grid_thw, video_grid_thw).""" + """Log probabilities over denoising timesteps.""" extra_fields: dict[str, Any] = {} """Extra fields for dynamic addition.""" @@ -121,7 +117,7 @@ def __init__( self.tokenizer = self.model_config.tokenizer self.processor = self.model_config.processor - self.max_prompt_embed_length = self.model_config.extra_configs.get( + self.max_prompt_embed_length = self.rollout_config.extra_configs.get( "max_sequence_length", self.rollout_config.prompt_length ) @@ -154,7 +150,7 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: """ config = self.rollout_config - sampling_params = dict(self.model_config.extra_configs) + sampling_params = dict(config.extra_configs) sampling_params.update( height=config.height, width=config.width, @@ -164,9 +160,9 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: # override sampling params for validation if batch.meta_info.get("validate", False): + sampling_params.update(config.val_kwargs.extra_configs) sampling_params["num_inference_steps"] = config.val_kwargs.num_inference_steps sampling_params["seed"] = config.val_kwargs.seed - sampling_params["noise_level"] = config.val_kwargs.noise_level # by default, we assume it's a single turn agent if "agent_name" not in batch.non_tensor_batch: @@ -209,12 +205,10 @@ async def _run_agent_loop( async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionAgentLoopOutput: """Perform post-processing operations on the output of each individual agent loop.""" - # handling extra tensor ouputs from vllm-omni, like prompt embedding, etc. + # Pad extra tensor outputs from vllm-omni (e.g. prompt embeddings). 
extra_fields = {} for k, v in output.extra_fields.items(): if isinstance(v, torch.Tensor): - # handle prompt embedding padding - # TODO (andy): reduce padding length for more effiency if k in ["prompt_embeds", "negative_prompt_embeds"]: pad_tuple = (0, 0, 0, self.max_prompt_embed_length - v.shape[0]) v = F.pad(v, pad_tuple, value=0) @@ -227,7 +221,6 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA extra_fields["raw_prompt"] = kwargs["raw_prompt"] - self.tokenizer.padding_side = "left" prompt_output = self.tokenizer.pad( {"input_ids": output.prompt_ids}, padding="max_length", @@ -266,7 +259,6 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA input_ids=input_ids, attention_mask=attention_mask, response_logprobs=response_logprobs, - multi_modal_data=output.multi_modal_data, reward_score=output.reward_score, num_turns=output.num_turns, metrics=output.metrics, @@ -352,11 +344,6 @@ def _postprocess( for key in reward_extra_keys: non_tensor_batch[key] = np.array([info[key] for info in reward_extra_infos]) - # Add multi_modal_inputs to non_tensor_batch if any samples have them - multi_modal_inputs_list = [input.multi_modal_inputs for input in inputs] - if any(mmi is not None for mmi in multi_modal_inputs_list): - non_tensor_batch["multi_modal_inputs"] = np.array(multi_modal_inputs_list, dtype=object) - metrics = [input.metrics.model_dump() for input in inputs] # Collect extra fields from all inputs and convert them to np.ndarray extra_fields = {} diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py index 27c0d46a43c..4ce42359ce6 100644 --- a/verl/experimental/agent_loop/single_turn_agent_loop.py +++ b/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -94,19 +94,6 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu class DiffusionSingleTurnAgentLoop(AgentLoopBase): """Agent loop for diffusion 
model serving.""" - # Keys from non_tensor_batch that are pipeline/dataset metadata and must - # NOT be forwarded to server_manager.generate() (which passes **kwargs - # down to the vllm-omni server that has a fixed signature). - _KEYS_EXCLUDED_FROM_GENERATE = frozenset( - { - "raw_prompt", - "raw_negative_prompt", - "data_source", - "reward_model", - "index", - } - ) - async def apply_chat_template( self, messages: list[dict], @@ -152,10 +139,8 @@ async def apply_chat_template( ) async def run(self, sampling_params: dict[str, Any], **kwargs) -> DiffusionAgentLoopOutput: - raw_prompt = kwargs.pop("raw_prompt") - raw_negative_prompt = kwargs.pop("raw_negative_prompt", None) - for key in self._KEYS_EXCLUDED_FROM_GENERATE: - kwargs.pop(key, None) + raw_prompt = kwargs["raw_prompt"] + raw_negative_prompt = kwargs.get("raw_negative_prompt") # 1. extract images and videos from messages multi_modal_data = await self.process_vision_info(raw_prompt) @@ -180,7 +165,6 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> DiffusionAgent image_data=images, video_data=videos, negative_prompt_ids=negative_prompt_ids, - **kwargs, ) if metrics.get("num_preempted") is None: metrics["num_preempted"] = output.num_preempted if output.num_preempted is not None else -1 @@ -189,7 +173,6 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> DiffusionAgent prompt_ids=prompt_ids, response_diffusion_output=output.diffusion_output, response_logprobs=output.log_probs, - multi_modal_data=multi_modal_data, num_turns=2, metrics=metrics, extra_fields=output.extra_fields, diff --git a/verl/trainer/config/_generated_diffusion_trainer.yaml b/verl/trainer/config/_generated_diffusion_trainer.yaml index 3f9ef3e41c8..986abc915ca 100644 --- a/verl/trainer/config/_generated_diffusion_trainer.yaml +++ b/verl/trainer/config/_generated_diffusion_trainer.yaml @@ -274,8 +274,8 @@ actor_rollout_ref: 'n': 1 do_sample: false num_inference_steps: 40 - noise_level: 0.0 seed: 42 + 
extra_configs: {} multi_turn: _target_: verl.workers.config.MultiTurnConfig enable: false @@ -352,6 +352,7 @@ actor_rollout_ref: height: 512 width: 512 num_inference_steps: 10 + extra_configs: {} model: _target_: verl.workers.config.DiffusionModelConfig path: ~/models/Qwen/Qwen-Image @@ -371,7 +372,7 @@ actor_rollout_ref: height: ${oc.select:actor_rollout_ref.rollout.height,512} width: ${oc.select:actor_rollout_ref.rollout.width,512} num_inference_steps: ${oc.select:actor_rollout_ref.rollout.num_inference_steps,10} - extra_configs: {} + extra_configs: ${oc.select:actor_rollout_ref.rollout.extra_configs,{}} model_type: diffusion_model hybrid_engine: true nccl_timeout: 600 @@ -415,7 +416,6 @@ data: path: null name: null apply_chat_template_kwargs: {} - data_source: prompt critic: optim: _target_: verl.workers.config.FSDPOptimizerConfig @@ -619,8 +619,10 @@ algorithm: bypass_mode: false loss_type: ppo_clip rollout_is_batch_normalize: false - _target_: verl.trainer.config.AlgoConfig + _target_: verl.trainer.config.DiffusionAlgoConfig + adv_estimator: flow_grpo norm_adv_by_std_in_grpo: true + global_std: true trainer: balance_batch: true total_epochs: 30 @@ -650,7 +652,7 @@ trainer: max_critic_ckpt_to_keep: null ray_wait_register_center_timeout: 300 device: cuda - use_legacy_worker_impl: auto + use_legacy_worker_impl: disable ray_kwargs: ray_init: num_cpus: null diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py index a150ee63394..99fb6f98b1d 100644 --- a/verl/trainer/config/algorithm.py +++ b/verl/trainer/config/algorithm.py @@ -17,7 +17,7 @@ from verl.base_config import BaseConfig -__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig", "RolloutCorrectionConfig"] +__all__ = ["AlgoConfig", "DiffusionAlgoConfig", "FilterGroupsConfig", "KLControlConfig", "RolloutCorrectionConfig"] @dataclass @@ -667,3 +667,13 @@ class AlgoConfig(BaseConfig): # gdpo_reward_weights: per-dimension weights for aggregation (default: equal weights). 
gdpo_reward_keys: Optional[list[str]] = None gdpo_reward_weights: Optional[list[float]] = None + + +@dataclass +class DiffusionAlgoConfig(BaseConfig): + """Diffusion-specific algorithm config.""" + + adv_estimator: str = "flow_grpo" + norm_adv_by_std_in_grpo: bool = True + rollout_correction: Optional[RolloutCorrectionConfig] = None + global_std: bool = True diff --git a/verl/trainer/config/diffusion_trainer.yaml b/verl/trainer/config/diffusion_trainer.yaml index 78bd357f573..656a8172ecd 100644 --- a/verl/trainer/config/diffusion_trainer.yaml +++ b/verl/trainer/config/diffusion_trainer.yaml @@ -15,7 +15,7 @@ defaults: # data: trainer/config/data/legacy_data.yaml - data@data: legacy_data - # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + # Reference model will be enabled when actor.use_kl_loss is True. - ref@actor_rollout_ref.ref: ${model_engine}_ref # Rollout model config (vLLM-Omni diffusion rollout). @@ -40,12 +40,6 @@ defaults: # self config override anything above - _self_ -# Dataset config (merges with legacy_data from defaults) -data: - - # get ground-truth based on data_source, now support ["ocr", "prompt"] - data_source: "prompt" - # config for actor, rollout and reference model actor_rollout_ref: @@ -80,11 +74,17 @@ actor_rollout_ref: algorithm: # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs - _target_: verl.trainer.config.AlgoConfig + _target_: verl.trainer.config.DiffusionAlgoConfig + + # Advantage estimator type: "flow_grpo" + adv_estimator: flow_grpo # Whether to normalize advantages by std (specific to GRPO) norm_adv_by_std_in_grpo: True + # Whether to normalize advantages using global standard deviation + global_std: True + # config for the trainer trainer: @@ -175,7 +175,7 @@ trainer: # whether to use legacy worker implementation # mode: "auto", "enable", or "disable" - use_legacy_worker_impl: auto + use_legacy_worker_impl: disable # configs related 
to ray ray_kwargs: diff --git a/verl/trainer/config/model/diffusion_model.yaml b/verl/trainer/config/model/diffusion_model.yaml index 33ed0894ac3..7544b913d22 100644 --- a/verl/trainer/config/model/diffusion_model.yaml +++ b/verl/trainer/config/model/diffusion_model.yaml @@ -58,8 +58,4 @@ width: ${oc.select:actor_rollout_ref.rollout.width,512} num_inference_steps: ${oc.select:actor_rollout_ref.rollout.num_inference_steps,10} # extra configs for algorithm specific features. -# Model-specific diffusion sampling params (e.g. true_cfg_scale, guidance_scale, -# max_sequence_length, noise_level) should be placed here so the agent loop stays -# backend-neutral. The rollout server's backend translation layer will promote -# matching keys to direct OmniDiffusionSamplingParams fields automatically. -extra_configs: {} +extra_configs: ${oc.select:actor_rollout_ref.rollout.extra_configs,{}} diff --git a/verl/trainer/config/rollout/diffusion_rollout.yaml b/verl/trainer/config/rollout/diffusion_rollout.yaml index 4e493abcaec..1b8d9d2292b 100644 --- a/verl/trainer/config/rollout/diffusion_rollout.yaml +++ b/verl/trainer/config/rollout/diffusion_rollout.yaml @@ -23,6 +23,11 @@ width: 512 # number of inference steps for diffusion model rollout num_inference_steps: 10 +# extra configs for algorithm specific features. +# Model-specific diffusion sampling params (e.g. true_cfg_scale, guidance_scale, +# max_sequence_length, noise_level) +extra_configs: {} + # Extra inference engine arguments: add vllm_omni for diffusion engine_kwargs: @@ -38,14 +43,11 @@ val_kwargs: # whether to repeat n times for validation n: 1 - # Whether to sample during training rollout. False uses greedy sampling. - do_sample: False - # number of inference steps for diffusion model rollout num_inference_steps: 40 - # noise level for diffusion model rollout - noise_level: 0.0 - # random seed for validation seed: 42 + + # extra configs for algorithm specific features during validation. 
+ extra_configs: {} diff --git a/verl/trainer/diffusion/__init__.py b/verl/trainer/diffusion/__init__.py new file mode 100644 index 00000000000..d828409b82e --- /dev/null +++ b/verl/trainer/diffusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/trainer/ppo/diffusion_algos.py b/verl/trainer/diffusion/diffusion_algos.py similarity index 53% rename from verl/trainer/ppo/diffusion_algos.py rename to verl/trainer/diffusion/diffusion_algos.py index 5992a63d949..f2d49ac9cee 100644 --- a/verl/trainer/ppo/diffusion_algos.py +++ b/verl/trainer/diffusion/diffusion_algos.py @@ -13,15 +13,107 @@ # limitations under the License. 
"""Diffusion-specific policy loss functions and KL penalties.""" +from collections import defaultdict +from enum import Enum from typing import Any, Optional +import numpy as np import torch from omegaconf import DictConfig -from verl.trainer.ppo.core_algos import register_policy_loss +from verl.trainer.ppo.core_algos import register_adv_est, register_policy_loss from verl.workers.config import ActorConfig +class DiffusionAdvantageEstimator(str, Enum): + """Advantage estimators specific to diffusion-based policy training.""" + + FLOW_GRPO = "flow_grpo" + + +@register_adv_est(DiffusionAdvantageEstimator.FLOW_GRPO) +def compute_flow_grpo_outcome_advantage( + sample_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-4, + norm_adv_by_std_in_grpo: bool = True, + global_std: bool = True, + config: Optional[DictConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). + + Args: + sample_level_rewards: `(torch.Tensor)` + shape is (bs, ), (bs, 1) or (bs, response_length) + response_mask: `(torch.Tensor)` + shape is (bs, response_length) + index: `(np.ndarray)` + index array for grouping + epsilon: `(float)` + small value to avoid division by zero + norm_adv_by_std_in_grpo: `(bool)` + whether to scale the GRPO advantage + global_std: `(bool)` + whether to use global std for advantage normalization + config: `(Optional[DictConfig])` + algorithm configuration object + + Note: + If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. + If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). 
+ + Returns: + advantages: `(torch.Tensor)` + shape is (bs, response_length) + Returns: `(torch.Tensor)` + shape is (bs, response_length) + """ + scores = sample_level_rewards + if scores.ndim == 1: + scores = scores.unsqueeze(-1) + scores = scores.expand_as(response_mask).clone() + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + if global_std: + batch_std = torch.std(scores) + else: + batch_std = None + + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = id2score[idx][0] + if global_std: + id2std[idx] = batch_std + else: + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + scores_tensor = torch.stack(id2score[idx]) + id2mean[idx] = torch.mean(scores_tensor) + if global_std: + id2std[idx] = batch_std + else: + id2std[idx] = torch.std(scores_tensor) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + if norm_adv_by_std_in_grpo: + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + else: + scores[i] = scores[i] - id2mean[index[i]] + + return scores, scores + + @register_policy_loss("flow_grpo") def compute_policy_loss_flow_grpo( old_log_prob: torch.Tensor, diff --git a/verl/trainer/diffusion/diffusion_metric_utils.py b/verl/trainer/diffusion/diffusion_metric_utils.py new file mode 100644 index 00000000000..257060bd0d9 --- /dev/null +++ b/verl/trainer/diffusion/diffusion_metric_utils.py @@ -0,0 +1,137 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Metrics for diffusion (image generation) training. +""" + +from typing import Any + +import numpy as np +import torch + +from verl import DataProto + + +def compute_data_metrics_diffusion(batch: DataProto) -> dict[str, Any]: + """ + Computes various metrics from a diffusion training batch. + + For diffusion (image generation) models, rewards and advantages are + indexed over denoising timesteps rather than output tokens. + + Args: + batch: A DataProto object containing diffusion batch data with + sample_level_rewards [B, 1], advantages [B, T], returns [B, T]. + + Returns: + A dictionary of metrics including: + - critic/rewards/mean, max, min: Per-image reward statistics + - critic/rewards/zero_std_ratio: Fraction of prompt groups whose reward std is zero + - critic/rewards/std_mean: Mean per-prompt reward standard deviation + - critic/rewards/group_size: Average number of images sampled per unique prompt + - critic/advantages/mean, max, min: Element-wise advantage statistics over B*T + - critic/returns/mean, max, min: Element-wise return statistics over B*T + """ + sequence_reward = batch.batch["sample_level_rewards"].squeeze(-1) # [B] + + # Flatten [B, T] tensors for aggregate statistics across timesteps + advantages = batch.batch["advantages"].flatten() # [B*T] + returns = batch.batch["returns"].flatten() # [B*T] + + reward_mean = torch.mean(sequence_reward).detach().item() + reward_max = torch.max(sequence_reward).detach().item() + reward_min = torch.min(sequence_reward).detach().item() + + metrics = { + # reward + "critic/rewards/mean": 
reward_mean, + "critic/rewards/max": reward_max, + "critic/rewards/min": reward_min, + # adv + "critic/advantages/mean": torch.mean(advantages).detach().item(), + "critic/advantages/max": torch.max(advantages).detach().item(), + "critic/advantages/min": torch.min(advantages).detach().item(), + # returns + "critic/returns/mean": torch.mean(returns).detach().item(), + "critic/returns/max": torch.max(returns).detach().item(), + "critic/returns/min": torch.min(returns).detach().item(), + } + + if "uid" in batch.non_tensor_batch: + rewards_np = sequence_reward.cpu().float().numpy() + uid_array = np.array(batch.non_tensor_batch["uid"]) + unique_uids = np.unique(uid_array) + + per_prompt_stds = np.array([np.std(rewards_np[uid_array == uid]) for uid in unique_uids]) + + metrics["critic/rewards/zero_std_ratio"] = float(np.mean(per_prompt_stds == 0)) + metrics["critic/rewards/std_mean"] = float(np.mean(per_prompt_stds)) + metrics["critic/rewards/group_size"] = float(len(rewards_np) / len(unique_uids)) + + return metrics + + +def compute_timing_metrics_diffusion(timing_raw: dict[str, float], num_images: int) -> dict[str, Any]: + """ + Computes timing metrics for diffusion training. + + Args: + timing_raw: A dictionary mapping stage names to their execution times in seconds. + num_images: Total number of images processed in the batch, used to compute per-image timing. + + Returns: + A dictionary containing: + - timing_s/{name}: Raw timing in seconds for each stage + - timing_per_image_ms/{name}: Per-image timing in milliseconds for core compute stages + (gen, ref, old_log_prob, adv, update_actor). Non-compute stages such as + save_checkpoint, update_weights, and testing are excluded. 
+ """ + num_images_of_section = {name: num_images for name in ["gen", "ref", "old_log_prob", "adv", "update_actor"]} + + return { + **{f"timing_s/{name}": value for name, value in timing_raw.items()}, + **{ + f"timing_per_image_ms/{name}": timing_raw[name] * 1000 / num_images_of_section[name] + for name in set(num_images_of_section.keys()) & set(timing_raw.keys()) + }, + } + + +def compute_throughput_metrics_diffusion(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]: + """ + Computes throughput metrics for diffusion (image/video generation) training. + + Unlike language model training where throughput is measured in tokens/sec, + diffusion training generates images, so throughput is reported as images + per second. + + Args: + batch: A DataProto object containing diffusion batch data. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + Must contain a "step" key with the total step time. + n_gpus: Number of GPUs used for training. + + Returns: + A dictionary containing: + - perf/total_num_images: Number of images processed in the batch + - perf/time_per_step: Time taken for the step in seconds + - perf/throughput: Images generated per second per GPU + """ + batch_size = batch.batch["advantages"].shape[0] + time = timing_raw["step"] + return { + "perf/total_num_images": batch_size, + "perf/time_per_step": time, + "perf/throughput": batch_size / (time * n_gpus), + } diff --git a/verl/trainer/diffusion/ray_diffusion_trainer.py b/verl/trainer/diffusion/ray_diffusion_trainer.py new file mode 100644 index 00000000000..fe0be9d15c1 --- /dev/null +++ b/verl/trainer/diffusion/ray_diffusion_trainer.py @@ -0,0 +1,1085 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Flow-GRPO / diffusion trainer with a Ray-based single controller. +This trainer supports model-agnostic model initialization with Hugging Face. +""" + +import json +import os +import uuid +from collections import defaultdict +from pprint import pprint +from typing import Any, Optional + +import numpy as np +import torch +from omegaconf import OmegaConf, open_dict +from PIL import Image +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.checkpoint_engine import CheckpointEngineManager +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import DiffusionAlgoConfig +from verl.trainer.diffusion.diffusion_algos import DiffusionAdvantageEstimator +from verl.trainer.diffusion.diffusion_metric_utils import ( + compute_data_metrics_diffusion, + compute_throughput_metrics_diffusion, + compute_timing_metrics_diffusion, +) +from verl.trainer.ppo.core_algos import get_adv_estimator_fn +from verl.trainer.ppo.metric_utils import compute_variance_proxy_metrics, process_validation_metrics +from verl.trainer.ppo.reward import extract_reward +from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model +from verl.utils import 
tensordict_utils as tu +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.import_utils import load_class_from_fqn +from verl.utils.metric import reduce_metrics +from verl.utils.py_functional import rename_dict +from verl.utils.tracking import ValidationGenerationsLogger +from verl.workers.utils.padding import embeds_padding_2_no_padding + + +def compute_response_mask(data: DataProto): + """Compute the valid-step mask for diffusion latents. + + For diffusion models, every denoising timestep is a valid optimization step, + so the returned mask is all-ones covering all timesteps. + + Args: + data (DataProto): The data containing batched diffusion model outputs, including ``all_latents``. + + Returns: + torch.Tensor: An all-ones int32 mask of shape ``[batch, num_timesteps]``. + """ + all_latents = data.batch["all_latents"] + b, t = all_latents.shape[:2] + return torch.ones((b, t), dtype=torch.int32, device=all_latents.device) + + +def compute_advantage( + data: DataProto, + adv_estimator: str, + norm_adv_by_std_in_grpo: bool = True, + global_std: bool = True, + config: Optional[DiffusionAlgoConfig] = None, +) -> DataProto: + """Compute advantage estimates for diffusion policy optimization. + + This function computes advantage estimates for diffusion models using the registered + advantage estimator (e.g., Flow-GRPO). The advantage estimates are used to guide + policy optimization across denoising timesteps. + + Args: + data (DataProto): The data containing batched diffusion model outputs and inputs. + adv_estimator (str): Name of the advantage estimator to use (e.g., Flow-GRPO). + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard + deviation in GRPO. Defaults to True. + global_std (bool, optional): Whether to use global standard deviation for normalization. 
+ Defaults to True. + config (DiffusionAlgoConfig, optional): Configuration object for algorithm settings. + Defaults to None. + + Returns: + DataProto: The updated data with computed ``advantages`` and ``returns`` in its batch. + """ + if "response_mask" not in data.batch.keys(): + data.batch["response_mask"] = compute_response_mask(data) + + adv_kwargs = { + "sample_level_rewards": data.batch["sample_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: + adv_kwargs["index"] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch: + adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] + + adv_estimator_fn = get_adv_estimator_fn(adv_estimator) + if adv_estimator == DiffusionAdvantageEstimator.FLOW_GRPO: + adv_kwargs["norm_adv_by_std_in_grpo"] = norm_adv_by_std_in_grpo + adv_kwargs["global_std"] = global_std + advantages, returns = adv_estimator_fn(**adv_kwargs) + + data.batch["advantages"] = advantages + data.batch["returns"] = returns + return data + + +class RayFlowGRPOTrainer: + """Distributed Flow-GRPO trainer using Ray for scalable reinforcement learning. + + This trainer orchestrates distributed PPO training across multiple nodes and GPUs, + managing actor rollouts and reward computation with Ray backend. + Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration. + """ + + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup, + processor=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name=None, + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. 
+ + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping or Role.ActorRolloutRef in role_worker_mapping, ( + f"{role_worker_mapping.keys()=}" + ) + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(self.config) + + self.use_rm = need_reward_model(self.config) + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + self.validation_generations_logger = ValidationGenerationsLogger( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + ) + + # if ref_in_actor is True, the reference policy will be actor without lora applied + lora_rank = config.actor_rollout_ref.model.get("lora", 
{}).get("rank", 0) + if lora_rank <= 0: + lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) + self.ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + self.checkpoint_manager = None + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("train_max_samples", -1), + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("val_max_samples", -1), + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + 
shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? 
Error: {e}")
+
+    def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path):
+        """Dump rollout/validation samples as JSONL."""
+        os.makedirs(dump_path, exist_ok=True)
+
+        visual_folder = os.path.join(dump_path, f"{self.global_steps}")
+        os.makedirs(visual_folder, exist_ok=True)
+
+        output_paths = []
+        images_pil = outputs.cpu().float().permute(0, 2, 3, 1).numpy()
+        images_pil = (images_pil * 255).round().clip(0, 255).astype("uint8")
+        for i, image in enumerate(images_pil):
+            image_path = os.path.join(visual_folder, f"{i}.jpg")
+            Image.fromarray(image).save(image_path)
+            output_paths.append(image_path)
+
+        filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")
+
+        n = len(inputs)
+        base_data = {
+            "input": inputs,
+            "output": output_paths,
+            "gts": gts,
+            "score": scores,
+            "step": [self.global_steps] * n,
+        }
+
+        for k, v in reward_extra_infos_dict.items():
+            if len(v) == n:
+                base_data[k] = v
+
+        lines = []
+        for i in range(n):
+            entry = {k: v[i] for k, v in base_data.items()}
+            lines.append(json.dumps(entry, ensure_ascii=False))
+
+        with open(filename, "w") as f:
+            f.write("\n".join(lines) + "\n")
+
+        print(f"Dumped generations to {filename}")
+
+    def _log_rollout_data(
+        self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str
+    ):
+        """Log rollout data to disk.
+ Args: + batch (DataProto): The batch containing rollout data + reward_extra_infos_dict (dict): Additional reward information to log + timing_raw (dict): Timing information for profiling + rollout_data_dir (str): Directory path to save the rollout data + """ + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["input_ids"], skip_special_tokens=True) + outputs = batch.batch["responses"] + scores = batch.batch["sample_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch] + + reward_extra_infos_to_dump = reward_extra_infos_dict.copy() + if "request_id" in batch.non_tensor_batch: + reward_extra_infos_to_dump.setdefault( + "request_id", + batch.non_tensor_batch["request_id"].tolist(), + ) + + self._dump_generations( + inputs=inputs, + outputs=outputs, + gts=sample_gts, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_to_dump, + dump_path=rollout_data_dir, + ) + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + if "wandb" in self.config.trainer.logger: + import wandb + + outputs = [wandb.Image(image.float(), file_type="jpg") for image in outputs] + samples = list(zip(inputs, outputs, scores, strict=True)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def 
_get_gen_batch(self, batch: DataProto) -> DataProto: + reward_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() + + # pop those keys for generation + batch_keys_to_pop = [] + non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_keys + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), + ) + + # For agent loop, we need reward model keys to compute score. + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + + return gen_batch + + def _compute_reward_colocate(self, batch: DataProto) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor: + """ + compute reward use colocate reward model + """ + assert self.reward_loop_manager is not None, "RewardLoopManager is None" + batch_reward = self.reward_loop_manager.compute_rm_score(batch) + return batch_reward + + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_gts = [] + sample_scores = [] + sample_turns = [] + sample_uids = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + if "uid" not in test_batch.non_tensor_batch: + test_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object + ) + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + ground_truths = [ + item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + ] + sample_gts.extend(ground_truths) + + test_gen_batch = self._get_gen_batch(test_batch) + test_gen_batch.meta_info = { + "recompute_log_prob": False, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # 
pad to be divisible by dp_size + size_divisor = self.config.actor_rollout_ref.rollout.agent.num_workers + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + if self.use_rm and "rm_scores" not in test_output_gen_batch_padded.batch.keys(): + # for colocate reward models, we need to sleep rollout model + # to spare GPU memory for reward model + self.checkpoint_manager.sleep_replicas() + batch_reward = self._compute_reward_colocate(test_output_gen_batch_padded) + test_output_gen_batch_padded = test_output_gen_batch_padded.union(batch_reward) + # wake up rollout model + # replace with wake_up method once supported + self.checkpoint_manager.update_weights(self.global_steps) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + print("validation generation end") + + # Store generated outputs + output_images = test_output_gen_batch.batch["responses"] + sample_outputs.append(output_images) + + test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # Store original inputs + input_ids = test_batch.batch["prompts"] + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + sample_uids.extend(test_batch.non_tensor_batch["uid"]) + + # evaluate using reward_function + reward_tensor, reward_extra_info = extract_reward(test_batch) + + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + for key, values in reward_extra_info.items(): + if key not in reward_extra_infos_dict: + reward_extra_infos_dict[key] = [] + if isinstance(values, np.ndarray): + reward_extra_infos_dict[key].extend(values.tolist()) + else: + reward_extra_infos_dict[key].extend(values if isinstance(values, list) else [values]) + + # 
collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + sample_outputs = torch.cat(sample_outputs, dim=0) + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + return self._val_metrics_update(data_sources, sample_uids, reward_extra_infos_dict, sample_turns) + + def _val_metrics_update(self, data_sources, sample_uids, reward_extra_infos_dict, sample_turns): + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + 
metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + + return metric_dict + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each enabled role (actor/ref/reward) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout + if self.hybrid_engine: + actor_rollout_resource_pool = self.resource_pool_manager.get_resource_pool(actor_role) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[actor_role], + config=self.config.actor_rollout_ref, + role=str(actor_role), + ) + self.resource_pool_to_cls[actor_rollout_resource_pool][str(actor_role)] = actor_rollout_cls + else: + raise NotImplementedError + + # create reference policy if needed + if self.use_reference_policy and Role.RefPolicy in self.role_worker_mapping: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. 
+ all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + if not class_dict: + continue + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_reference_policy and not self.ref_in_actor: + if str(Role.RefPolicy) in all_wg: + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + else: + # Model engine: ActorRolloutRefWorker + assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}" + self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)] + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg[str(actor_role)] + self.actor_rollout_wg.init_model() + + if self.ref_in_actor: + self.ref_policy_wg = self.actor_rollout_wg + + # create reward loop manager + from verl.experimental.reward_loop import RewardLoopManager + + # initalize reward loop manager + # reward model (colocate or standalone): get resource_pool + # no reward model: resource_pool = None + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None + self.reward_loop_manager = RewardLoopManager( + config=self.config, + rm_resource_pool=resource_pool, + ) + + # create async rollout manager and request scheduler + # Note: mode is always "async" since sync mode is deprecated + self.async_rollout_mode = True + + # Support custom AgentLoopManager via config + manager_class_fqn = 
self.config.actor_rollout_ref.rollout.get("agent", {}).get("agent_loop_manager_class") + if manager_class_fqn: + AgentLoopManager = load_class_from_fqn(manager_class_fqn, "AgentLoopManager") + else: + from verl.experimental.agent_loop import AgentLoopManager + + # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design + # agent_reward_loop: streaming reward computation with actor rollout + # two conditions satisfied: (1) no reward model, or (2) reward model with extra resource pool + enable_agent_reward_loop = not self.use_rm or self.config.reward.reward_model.enable_resource_pool + + # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager + # to stream reward computation with actor rollout + reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None + self.async_rollout_manager = AgentLoopManager.create( + config=self.config, + worker_group=self.actor_rollout_wg, + rollout_resource_pool=actor_rollout_resource_pool, + reward_loop_worker_handles=reward_loop_worker_handles, + ) + + checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine) + self.checkpoint_manager = CheckpointEngineManager( + config=checkpoint_engine_config, + trainer=self.actor_rollout_wg, + replicas=self.async_rollout_manager.rollout_replicas, + ) + + # sleep all replicas to load checkpoint + self.checkpoint_manager.sleep_replicas() + + def _save_checkpoint(self): + from verl.utils.fs import local_mkdir_safe + + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + 
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print("Warning: remove_previous_ckpt_in_save is deprecated," + " set max_actor_ckpt_to_keep=1 instead") + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) + + # save dataloader + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + # latest checkpointed iteration tracker (for atomic usage) + if ( + hasattr(self.config.actor_rollout_ref.actor.checkpoint, "async_save") + and self.config.actor_rollout_ref.actor.checkpoint.async_save + ) or ( + "async_save" in self.config.actor_rollout_ref.actor.checkpoint + and self.config.actor_rollout_ref.actor.checkpoint["async_save"] + ): + print("skip write latest_checkpointed_iteration.txt when async_save is True") + return + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) 
+ global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + + # load dataloader, + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _compute_ref_log_prob(self, batch: DataProto) -> DataProto: + batch_td = batch.to_tensordict() + batch_td = embeds_padding_2_no_padding(batch_td) + tu.pop(batch_td, key="input_ids") + metadata = { + "compute_loss": False, + "height": self.config.actor_rollout_ref.model.height, + "width": self.config.actor_rollout_ref.model.width, + "vae_scale_factor": self.config.actor_rollout_ref.model.get("vae_scale_factor", 8), + } + if 
self.ref_in_actor: + metadata["no_lora_adapter"] = True + tu.assign_non_tensor(batch_td, **metadata) + if self.ref_in_actor: + output = self.actor_rollout_wg.compute_log_prob(batch_td) + else: + output = self.ref_policy_wg.compute_ref_log_prob(batch_td) + # gather output + log_probs = tu.get(output, "log_probs") + prev_sample_mean = tu.get(output, "prev_sample_mean") + ref_log_prob = tu.get_tensordict( + {"ref_log_prob": log_probs.float(), "ref_prev_sample_mean": prev_sample_mean.float()} + ) + return DataProto.from_tensordict(ref_log_prob) + + def _compute_old_log_prob(self, batch: DataProto): + batch_td = batch.to_tensordict() + batch_td = embeds_padding_2_no_padding(batch_td) + tu.pop(batch_td, key="input_ids") + tu.assign_non_tensor( + batch_td, + compute_loss=False, + height=self.config.actor_rollout_ref.model.height, + width=self.config.actor_rollout_ref.model.width, + vae_scale_factor=self.config.actor_rollout_ref.model.get("vae_scale_factor", 8), + ) + output = self.actor_rollout_wg.compute_log_prob(batch_td) + log_probs = tu.get(output, "log_probs") + old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float()}) + return DataProto.from_tensordict(old_log_prob) + + def _update_actor(self, batch: DataProto) -> DataProto: + rollout_config = self.config.actor_rollout_ref.rollout + batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable + # update actor + batch_td = batch.to_tensordict() + # step 2: convert from padding to no-padding + batch_td = embeds_padding_2_no_padding(batch_td) + tu.pop(batch_td, key="input_ids") + ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size + ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n + ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs + seed = self.config.actor_rollout_ref.actor.data_loader_seed + shuffle = self.config.actor_rollout_ref.actor.shuffle + tu.assign_non_tensor( + batch_td, + global_batch_size=ppo_mini_batch_size, + 
mini_batch_size=ppo_mini_batch_size, + epochs=ppo_epochs, + seed=seed, + dataloader_kwargs={"shuffle": shuffle}, + height=self.config.actor_rollout_ref.model.height, + width=self.config.actor_rollout_ref.model.width, + vae_scale_factor=self.config.actor_rollout_ref.model.get("vae_scale_factor", 8), + ) + + actor_output = self.actor_rollout_wg.update_actor(batch_td) + actor_output = tu.get(actor_output, "metrics") + actor_output = rename_dict(actor_output, "actor/") + return DataProto.from_single_dict(data={}, meta_info={"metrics": actor_output}) + + def fit(self): + """ + The training loop of FlowGRPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint and update weights before doing anything + self._load_checkpoint() + self.checkpoint_manager.update_weights(self.global_steps) + + current_epoch = self.global_steps // len(self.train_dataloader) + + # perform validation before training + # currently, we only support validation using the reward_function. 
+ if self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + for epoch in range(current_epoch, self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) + metrics = {} + timing_raw = {} + + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + self.checkpoint_manager.sleep_replicas() + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # compute 
global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + batch_reward = self._compute_reward_colocate(batch) + batch = batch.union(batch_reward) + + # extract reward_tensor and reward_extra_infos_dict for training + reward_tensor, reward_extra_infos_dict = extract_reward(batch) + + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] + else: # Recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self._compute_old_log_prob(batch) + batch = batch.union(old_log_prob) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): + ref_log_prob = self._compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + batch.batch["sample_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + batch.batch["sample_level_rewards"] = batch.batch["sample_level_scores"] + + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - 
actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + global_std=self.config.algorithm.global_std, + config=self.config.algorithm, + ) + + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + actor_output = self._update_actor(batch) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. 
+ if self.config.trainer.save_freq > 0 and ( + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 + or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + # update weights from trainer to rollout + with marked_timer("update_weights", timing_raw, color="red"): + self.checkpoint_manager.update_weights(self.global_steps) + + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if self.config.trainer.test_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.test_freq == 0 + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics_diffusion(batch=batch)) + n_gpus = self.resource_pool_manager.get_n_gpus() + num_images = batch.batch["advantages"].shape[0] + metrics.update(compute_timing_metrics_diffusion(timing_raw=timing_raw, num_images=num_images)) + metrics.update(compute_throughput_metrics_diffusion(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # compute variance proxy metrics + gradient_norm = metrics.get("actor/grad_norm", None) + metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm)) + + # this is 
experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/verl/trainer/main_flowgrpo.py b/verl/trainer/main_flowgrpo.py new file mode 100644 index 00000000000..33c7e66a0b3 --- /dev/null +++ b/verl/trainer/main_flowgrpo.py @@ -0,0 +1,266 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Entrypoint for FlowGRPO / diffusion model training. 
+""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.experimental.reward_loop import migrate_legacy_reward_impl +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.diffusion.ray_diffusion_trainer import RayFlowGRPOTrainer +from verl.trainer.ppo.utils import need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import auto_set_device + + +@hydra.main(config_path="config", config_name="diffusion_trainer", version_base=None) +def main(config): + """Main entry point for FlowGRPO / diffusion model training with Hydra configuration management. + + Args: + config: Hydra configuration dictionary containing training parameters. + """ + # Automatically set `config.trainer.device = npu` when running on Ascend NPU. + auto_set_device(config) + config = migrate_legacy_reward_impl(config) + run_flowgrpo(config) + + +def run_flowgrpo(config, task_runner_class=None) -> None: + """Initialize Ray cluster and run distributed FlowGRPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed FlowGRPO training including Ray initialization settings, + model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. 
+ """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + if task_runner_class is None: + task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + runner = task_runner_class.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +class TaskRunner: + """Ray remote class for executing distributed FlowGRPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. 
+ + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation + """ + + def __init__(self): + self.role_worker_mapping = {} + self.mapping = {} + + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + from verl.trainer.ppo.ray_trainer import Role + + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl != "disable": + raise NotImplementedError( + "FlowGRPO only supports the new engine path (trainer.use_legacy_worker_impl=disable)." + ) + + from verl.workers.engine_workers import ActorRolloutRefWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0) + if lora_rank <= 0: + lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) + ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None + if need_reference_policy(config) and not ref_in_actor: + role = Role.ActorRolloutRef + else: + role = Role.ActorRollout + self.role_worker_mapping[role] = ray.remote(actor_rollout_cls) + self.mapping[role] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + def init_resource_pool_mgr(self, config): + """Initialize resource pool manager.""" + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + + if config.reward.reward_model.enable_resource_pool: + if config.reward.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward.reward_model.n_gpus_per_node must be greater than 0") + if config.reward.reward_model.nnodes <= 0: + raise ValueError("config.reward.reward_model.nnodes must be greater than 0") + + reward_pool = 
[config.reward.reward_model.n_gpus_per_node] * config.reward.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + else: + config.reward.reward_model.nnodes = config.trainer.nnodes + config.reward.reward_model.n_gpus_per_node = config.trainer.n_gpus_per_node + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping) + return resource_pool_manager + + def add_reward_model_resource_pool(self, config): + """Add reward model worker if enabled.""" + from verl.trainer.ppo.ray_trainer import Role + + if config.reward.reward_model.enable: + # we do not use reward model workers, so we only register reward model in resource pool + # without continue to register reward model worker in role mapping + if config.reward.reward_model.enable_resource_pool: + self.mapping[Role.RewardModel] = "reward_pool" + else: + self.mapping[Role.RewardModel] = "global_pool" + + def add_ref_policy_worker(self, config, ref_policy_cls): + """Add reference policy worker if KL loss or KL reward is used.""" + # Ref policy has been fused into ActorRolloutRefWorker in new model engine. + # we don't need to add a separate ref policy worker group. + return + + def run(self, config): + """Execute the main FlowGRPO training workflow. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the FlowGRPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. 
+ from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + + self.add_reward_model_resource_pool(config) + + # Add a reference policy worker if KL loss is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(config), + use_critic=False, + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(config.actor_rollout_ref.model.tokenizer_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + if os.path.exists(os.path.join(local_path, "processor")): + processor_path = os.path.join(local_path, "processor") + else: + processor_path = local_path + processor = hf_processor(processor_path, trust_remote_code=trust_remote_code, use_fast=True) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. 
+ train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + is_train=True, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + is_train=False, + max_samples=config.data.get("val_max_samples", -1), + ) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the FlowGRPO trainer. + trainer = RayFlowGRPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + + # Start the training process. + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/verl/trainer/ppo/utils.py b/verl/trainer/ppo/utils.py index 315c7e548b3..876eaad37ae 100644 --- a/verl/trainer/ppo/utils.py +++ b/verl/trainer/ppo/utils.py @@ -76,7 +76,7 @@ def need_reference_policy( config: DictConfig, ) -> bool: """Given the config, do we need ref policy.""" - return config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss + return config.algorithm.get("use_kl_in_reward", False) or config.actor_rollout_ref.actor.use_kl_loss def need_teacher_policy( diff --git a/verl/utils/config.py b/verl/utils/config.py index ccac5f1764f..6aca40f3c52 100644 --- a/verl/utils/config.py +++ b/verl/utils/config.py @@ -166,7 +166,7 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): "actor_rollout_ref.rollout", ) - if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + if config.algorithm.get("use_kl_in_reward", False) and config.actor_rollout_ref.actor.use_kl_loss: print("NOTICE: You have both enabled in-reward kl and kl loss.") # 
critic diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 9e38a24fdf3..dc5a73d51e3 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -141,6 +141,9 @@ def __init__( self.shuffle = config.get("shuffle", False) self.seed = config.get("seed") + # For diffusion model training only + self.negative_prompt_key = config.get("negative_prompt_key", "negative_prompt") + self._download() self._read_files_and_tokenize() @@ -193,7 +196,7 @@ def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None): def doc2len(doc) -> int: try: - messages = self._build_messages(doc) + messages = self._build_messages(doc, key=self.prompt_key) # pass tool schemas if available so the processor can format prompts apply_kwargs = dict(**self.apply_chat_template_kwargs) if self.tool_schemas is not None: @@ -300,7 +303,7 @@ def __getstate__(self): def __len__(self): return len(self.dataframe) - def _build_messages(self, example: dict): + def _build_messages(self, example: dict, key: str): """Replace and