Skip to content
Open
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`
- `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `vllm_omni`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`
- If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
Expand Down
17 changes: 16 additions & 1 deletion .github/workflows/vllm_omni.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ on:
- ".github/workflows/vllm_omni.yml"
- "tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py"
- "tests/experimental/agent_loop/test_diffusion_agent_loop.py"
- "tests/special_e2e/run_flowgrpo_trainer_diffusers.sh"
- "tests/special_e2e/create_dummy_diffusion_data.py"
- "verl/workers/rollout/vllm_rollout/vllm_omni_async_server.py"
- "verl/trainer/diffusion/ray_diffusion_trainer.py"
- "verl/trainer/main_flowgrpo.py"
- "verl/trainer/config/diffusion_trainer.yaml"

# Cancel jobs on the same ref if a new one is triggered
concurrency:
Expand Down Expand Up @@ -95,7 +100,7 @@ jobs:
vllm_omni:
needs: setup
runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"]
timeout-minutes: 35 # Increase this timeout value as needed
timeout-minutes: 50 # Increased to accommodate e2e FlowGRPO training
env:
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
Expand All @@ -120,6 +125,16 @@ jobs:
run: |
ray stop --force
pytest tests/experimental/agent_loop/test_diffusion_agent_loop.py -v -s
- name: Install diffusers for e2e training
run: |
pip3 install diffusers==0.37.0
- name: Prepare dummy diffusion dataset
run: |
python3 tests/special_e2e/create_dummy_diffusion_data.py
- name: E2E FlowGRPO diffusion training
run: |
ray stop --force
bash tests/special_e2e/run_flowgrpo_trainer_diffusers.sh

cleanup:
runs-on: ubuntu-latest
Expand Down
104 changes: 104 additions & 0 deletions examples/flowgrpo_trainer/data_process/qwenimage_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the OCR dataset to parquet format (for Qwen-Image training).
You can obtain the raw dataset from https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr
"""

import argparse
import os

import datasets

from verl.utils.hdfs_io import copy, makedirs


def extract_solution(solution_str):
    """Extract the quoted ground-truth text from a raw OCR sample.

    The raw dataset stores each sample in the form:
    ``The image displays "xxx".`` — the solution is the text between the
    first pair of double quotes.

    Args:
        solution_str: Raw text field of one dataset row.

    Returns:
        The substring enclosed in the first pair of double quotes.

    Raises:
        ValueError: If the input contains no double-quoted segment
            (previously this surfaced as an opaque IndexError).
    """
    parts = solution_str.split('"')
    if len(parts) < 2:
        raise ValueError(f"No quoted solution text found in: {solution_str!r}")
    return parts[1]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default=None)
    parser.add_argument("--hdfs_dir", default=None)
    parser.add_argument(
        "--local_dataset_path", default="~/dataset/ocr/", help="The local path to the raw dataset, if it exists."
    )
    parser.add_argument(
        "--local_save_dir", default="~/data/ocr", help="The save directory for the preprocessed dataset."
    )

    args = parser.parse_args()

    data_source = "flow_grpo/ocr"

    # Bug fix: `local_dataset_path` used to be assigned only inside the
    # `is not None` branch and then read unconditionally afterwards, so a
    # None value raised NameError instead of the intended error below.
    if args.local_dataset_path is not None:
        local_dataset_path = os.path.expanduser(args.local_dataset_path)
        dataset = datasets.load_dataset(local_dataset_path)
    else:
        raise NotImplementedError(
            "It is not existed in huggingface hub. "
            "Please get dataset from https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr"
        )

    train_dataset = dataset["train"]
    test_dataset = dataset["test"]

    # Shared system prompt used for both the positive and negative prompts.
    system_prompt = (
        "Describe the image by detailing the color, shape, size, "
        "texture, quantity, text, spatial relationships of the objects and background:"
    )
    # A single space serves as the "empty" negative user prompt.
    negative_user_prompt = " "

    def make_map_fn(split):
        """Return a map function converting raw OCR rows into the verl chat-style schema."""

        def process_fn(example, idx):
            text = example.pop("text")
            solution = extract_solution(text)
            data = {
                "data_source": data_source,
                "prompt": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": text},
                ],
                "negative_prompt": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": negative_user_prompt},
                ],
                "ability": "ocr",
                "reward_model": {"style": "model", "ground_truth": solution},
                "extra_info": {"split": split, "index": idx},
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)

    hdfs_dir = args.hdfs_dir
    # `--local_dir` is kept only for backward compatibility; prefer `--local_save_dir`.
    local_save_dir = args.local_dir
    if local_save_dir is not None:
        print("Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.")
    else:
        local_save_dir = args.local_save_dir

    local_save_dir = os.path.expanduser(local_save_dir)
    train_dataset.to_parquet(os.path.join(local_save_dir, "train.parquet"))
    test_dataset.to_parquet(os.path.join(local_save_dir, "test.parquet"))

    # Optionally mirror the processed parquet files to HDFS.
    if hdfs_dir is not None:
        makedirs(hdfs_dir)
        copy(src=local_save_dir, dst=hdfs_dir)
71 changes: 71 additions & 0 deletions examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env bash
# Qwen-Image LoRA RL fine-tuning (FlowGRPO) using the vllm_omni rollout engine.
set -x

# Preprocessed OCR dataset produced by examples/flowgrpo_trainer/data_process/qwenimage_ocr.py
ocr_train_path=$HOME/data/ocr/train.parquet
ocr_test_path=$HOME/data/ocr/test.parquet

# Rollout engine for the diffusion actor and for the VLM-based reward model.
ENGINE=vllm_omni
REWARD_ENGINE=vllm

reward_path=examples/flowgrpo_trainer/reward_fn.py
reward_model_name=$HOME/models/Qwen/Qwen3-VL-8B-Instruct


# NOTE: "$@" (quoted) forwards extra Hydra overrides verbatim; the previous
# unquoted $@ would word-split any override containing spaces.
python3 -m verl.trainer.main_flowgrpo \
    algorithm.adv_estimator=flow_grpo \
    data.train_files=$ocr_train_path \
    data.val_files=$ocr_test_path \
    data.train_batch_size=32 \
    data.max_prompt_length=256 \
    actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen-Image \
    actor_rollout_ref.model.tokenizer_path=$HOME/models/Qwen/Qwen-Image/tokenizer \
    actor_rollout_ref.model.external_lib="examples.flowgrpo_trainer.diffusers.qwen_image" \
    actor_rollout_ref.model.lora_rank=64 \
    actor_rollout_ref.model.lora_alpha=128 \
    actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \
    actor_rollout_ref.actor.optim.lr=3e-4 \
    actor_rollout_ref.actor.optim.weight_decay=0.0001 \
    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=$ENGINE \
    actor_rollout_ref.rollout.n=16 \
    actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \
    actor_rollout_ref.rollout.agent.num_workers=4 \
    actor_rollout_ref.rollout.load_format=safetensors \
    actor_rollout_ref.rollout.layered_summon=True \
    actor_rollout_ref.rollout.val_kwargs.num_inference_steps=50 \
    +actor_rollout_ref.rollout.extra_configs.true_cfg_scale=4.0 \
    +actor_rollout_ref.rollout.extra_configs.noise_level=1.2 \
    +actor_rollout_ref.rollout.extra_configs.sde_type="sde" \
    +actor_rollout_ref.rollout.extra_configs.sde_window_size=2 \
    +actor_rollout_ref.rollout.extra_configs.sde_window_range="[0,5]" \
    +actor_rollout_ref.rollout.extra_configs.max_sequence_length=256 \
    +actor_rollout_ref.rollout.val_kwargs.extra_configs.noise_level=0.0 \
    +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=examples.flowgrpo_trainer.vllm_omni.pipeline_qwenimage.QwenImagePipelineWithLogProb \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    reward.num_workers=4 \
    reward.reward_manager.name=visual \
    reward.reward_model.enable=True \
    reward.reward_model.model_path=$reward_model_name \
    reward.reward_model.rollout.name=$REWARD_ENGINE \
    reward.reward_model.rollout.tensor_model_parallel_size=4 \
    reward.custom_reward_function.path=$reward_path \
    reward.custom_reward_function.name=compute_score_ocr \
    trainer.use_legacy_worker_impl=disable \
    trainer.logger='["console", "wandb"]' \
    trainer.project_name=flow_grpo \
    trainer.experiment_name=qwen_image_ocr_lora \
    trainer.log_val_generations=8 \
    trainer.val_before_train=False \
    trainer.n_gpus_per_node=4 \
    trainer.nnodes=1 \
    trainer.save_freq=30 \
    trainer.test_freq=30 \
    trainer.total_epochs=15 \
    trainer.total_training_steps=300 "$@"
17 changes: 11 additions & 6 deletions examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def _maybe_to_cpu(v):
return v


def _coalesce_not_none(value, default):
return default if value is None else value


# Custom pipeline class for QwenImage that returns log probabilities during the diffusion process.
# This is compatible with the vllm-omni custom-pipeline API.
class QwenImagePipelineWithLogProb(QwenImagePipeline):
Expand Down Expand Up @@ -188,6 +192,7 @@ def diffuse(
generator=generator,
noise_level=cur_noise_level,
sde_type=sde_type,
return_logprobs=logprobs,
return_dict=False,
)

Expand Down Expand Up @@ -252,16 +257,16 @@ def forward(
num_inference_steps = sp.num_inference_steps or num_inference_steps
max_sequence_length = sp.max_sequence_length or max_sequence_length

noise_level = sp.extra_args.get("noise_level", None) or noise_level
sde_window_size = sp.extra_args.get("sde_window_size", None) or sde_window_size
sde_window_range = sp.extra_args.get("sde_window_range", None) or sde_window_range
sde_type = sp.extra_args.get("sde_type", None) or sde_type
logprobs = sp.extra_args.get("logprobs", None)
noise_level = _coalesce_not_none(sp.extra_args.get("noise_level", None), noise_level)
sde_window_size = _coalesce_not_none(sp.extra_args.get("sde_window_size", None), sde_window_size)
sde_window_range = _coalesce_not_none(sp.extra_args.get("sde_window_range", None), sde_window_range)
sde_type = _coalesce_not_none(sp.extra_args.get("sde_type", None), sde_type)
logprobs = _coalesce_not_none(sp.extra_args.get("logprobs", None), logprobs)

generator = sp.generator or generator
if generator is None and sp.seed is not None:
generator = torch.Generator(device=self.device).manual_seed(sp.seed)
true_cfg_scale = sp.true_cfg_scale or true_cfg_scale
true_cfg_scale = _coalesce_not_none(sp.true_cfg_scale, true_cfg_scale)
req_num_outputs = getattr(sp, "num_outputs_per_prompt", None)
if req_num_outputs and req_num_outputs > 0:
num_images_per_prompt = req_num_outputs
Expand Down
12 changes: 6 additions & 6 deletions tests/experimental/agent_loop/test_diffusion_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def init_config() -> DictConfig:
prompt_template_encode_start_idx = 34
max_length = tokenizer_max_length + prompt_template_encode_start_idx

with open_dict(config.actor_rollout_ref.model.extra_configs):
config.actor_rollout_ref.model.extra_configs.true_cfg_scale = 4.0
config.actor_rollout_ref.model.extra_configs.max_sequence_length = max_length
config.actor_rollout_ref.model.extra_configs.noise_level = 1.0
config.actor_rollout_ref.model.extra_configs.sde_window_size = 2
config.actor_rollout_ref.model.extra_configs.sde_window_range = [0, 5]
with open_dict(config.actor_rollout_ref.rollout.extra_configs):
config.actor_rollout_ref.rollout.extra_configs.true_cfg_scale = 4.0
config.actor_rollout_ref.rollout.extra_configs.max_sequence_length = max_length
config.actor_rollout_ref.rollout.extra_configs.noise_level = 1.0
config.actor_rollout_ref.rollout.extra_configs.sde_window_size = 2
config.actor_rollout_ref.rollout.extra_configs.sde_window_range = [0, 5]

config.actor_rollout_ref.rollout.nnodes = 1

Expand Down
1 change: 0 additions & 1 deletion tests/models/test_diffusers_fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def create_data_samples(num_device: int, model_config: DiffusionModelConfig) ->
data.meta_info["height"] = height
data.meta_info["width"] = width
data.meta_info["vae_scale_factor"] = vae_scale_factor
data.meta_info["gradient_accumulation_steps"] = 1

return data

Expand Down
93 changes: 93 additions & 0 deletions tests/special_e2e/create_dummy_diffusion_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create a small synthetic parquet dataset for FlowGRPO diffusion e2e testing.

The dataset uses the jpeg_compressibility reward (a self-contained rule-based
reward that needs no external reward model) so the e2e test can run without
spinning up a separate vLLM reward server.
"""

import argparse
import os

import pandas as pd

SYSTEM_PROMPT = (
"Describe the image by detailing the color, shape, size, "
"texture, quantity, text, spatial relationships of the objects and background:"
)

USER_PROMPTS = [
"A red circle on a white background",
"A blue square on a black background",
"A green triangle next to an orange rectangle",
"The word HELLO written in bold letters",
"A yellow star above a purple crescent moon",
"Two overlapping circles, one red and one blue",
"A gradient from dark blue to light blue",
"A checkerboard pattern of black and white squares",
]


def build_rows(split: str, n: int):
    """Create ``n`` synthetic sample dicts for the given dataset *split*.

    User prompts cycle through ``USER_PROMPTS``. The rule-based
    ``jpeg_compressibility`` reward needs no reference answer, so
    ``ground_truth`` is left empty.
    """

    def _make_row(idx: int) -> dict:
        user_text = USER_PROMPTS[idx % len(USER_PROMPTS)]
        return {
            "data_source": "jpeg_compressibility",
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_text},
            ],
            "negative_prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": " "},
            ],
            "reward_model": {"style": "rule", "ground_truth": ""},
            "extra_info": {"split": split, "index": idx},
        }

    return [_make_row(i) for i in range(n)]


def main():
    """Parse CLI arguments, build the dummy splits, and write them as parquet files."""
    parser = argparse.ArgumentParser(description="Generate dummy diffusion parquet data for e2e testing")
    parser.add_argument(
        "--local_save_dir",
        default=os.path.expanduser("~/data/dummy_diffusion"),
        help="Directory to write train.parquet and test.parquet",
    )
    parser.add_argument("--train_size", type=int, default=32, help="Number of training samples")
    parser.add_argument("--val_size", type=int, default=8, help="Number of validation samples")
    args = parser.parse_args()

    os.makedirs(args.local_save_dir, exist_ok=True)

    # (split name, sample count, output file, label used in the log message)
    specs = [
        ("train", args.train_size, "train.parquet", "train"),
        ("test", args.val_size, "test.parquet", "val"),
    ]
    for split, size, filename, label in specs:
        frame = pd.DataFrame(build_rows(split, size))
        out_path = os.path.join(args.local_save_dir, filename)
        frame.to_parquet(out_path)
        print(f"Wrote {len(frame)} {label} samples to {out_path}")


if __name__ == "__main__":
    main()
Loading
Loading