diff --git a/examples/fully_async/run_qwen3_30b_a3b_fully_async.py b/examples/fully_async/run_qwen3_30b_a3b_fully_async.py new file mode 100644 index 000000000..0a2dbca92 --- /dev/null +++ b/examples/fully_async/run_qwen3_30b_a3b_fully_async.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import Literal + +import typer + +import miles.utils.external_utils.command_utils as U + +# in_place + broadcast +# python run_qwen3_30b_a3b_fully_async.py + +# retract + p2p +# python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode p2p + +# retract + broadcast +# python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode broadcast + + +@dataclass +class ScriptArgs(U.ExecuteTrainConfig): + mode: Literal["normal", "debug_minimal"] = "normal" + run_id: str = U.create_run_id() + model_name: str = "Qwen3-30B-A3B" + megatron_model_type: str = "qwen3-30B-A3B" + num_gpus_per_node: int = 8 + data_dir: str = "/root/datasets" + model_dir: str = "/root/models" + megatron_path: str = "/root/Megatron-LM" + pause_generation_mode: Literal["in_place", "retract"] = "in_place" + update_weight_transfer_mode: Literal["broadcast", "p2p"] = "broadcast" + extra_args: str = "" + + +def prepare(args: ScriptArgs): + U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}") + U.exec_command(f"hf download Qwen/{args.model_name} --local-dir {args.model_dir}/{args.model_name}") + U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir) + U.convert_checkpoint( + model_name=args.model_name, + megatron_model_type=args.megatron_model_type, + num_gpus_per_node=args.num_gpus_per_node, + dir_dst=args.model_dir, + hf_checkpoint=f"{args.model_dir}/{args.model_name}", + megatron_path=args.megatron_path, + ) + + +def execute(args: ScriptArgs): + if args.pause_generation_mode == "in_place" and args.update_weight_transfer_mode == "p2p": + raise ValueError( + "in_place + p2p is not supported: P2P transfer engine conflicts with " + "active NCCL inference. Use broadcast with in_place, or retract with p2p." + ) + + ref_load_path = f"{args.model_dir}/{args.model_name}_torch_dist" + load_save_path = f"{args.output_dir}/{args.run_id}/checkpoints" + + ckpt_args = ( + f"--hf-checkpoint {args.model_dir}/{args.model_name}/ " + f"--ref-load {ref_load_path} " + f"--load {load_save_path} " + ) + + rollout_args = ( + "--rollout-function-path fully_async_rollout.generate_rollout_fully_async " + f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl " + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type dapo " + "--reward-key score " + "--num-rollout 3000 " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " + "--rollout-temperature 1 " + "--global-batch-size 256 " + "--balance-data " + f"--pause-generation-mode {args.pause_generation_mode} " + ) + + perf_args = ( + "--tensor-model-parallel-size 8 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 8 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + "--use-tis " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_extra = "" + if args.update_weight_transfer_mode == "p2p": + sglang_extra = "--sglang-remote-instance-weight-loader-start-seed-via-transfer-engine " + + sglang_args = ( + "--rollout-num-gpus-per-engine 8 " + f"--sglang-mem-fraction-static 0.7 {sglang_extra}" + "--sglang-cuda-graph-max-bs 512 " + ) + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + f"--attention-backend flash --update-weight-transfer-mode {args.update_weight_transfer_mode} " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {args.num_gpus_per_node} " + f"--num-gpus-per-node {args.num_gpus_per_node} " + f"--rollout-num-gpus {args.num_gpus_per_node} " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} " + f"{perf_args} " + f"{sglang_args} " + f"{misc_args} " + f"{args.extra_args} " + ) + + import os + + fully_async_dir = os.path.join(os.path.dirname(os.path.abspath(__file__))) + U.execute_train( + train_args=train_args, + num_gpus_per_node=args.num_gpus_per_node, + megatron_model_type=args.megatron_model_type, + train_script="train_async.py", + megatron_path=args.megatron_path, + extra_env_vars={ + "FLASHINFER_DISABLE_VERSION_CHECK": "1", + "PYTHONPATH": f"{args.megatron_path}:{fully_async_dir}", + }, + ) + + +@U.dataclass_cli +def main(args: ScriptArgs): + prepare(args) + execute(args) + + +if __name__ == "__main__": + typer.run(main) diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py index 006f15516..670799382 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed/mixin.py @@ -141,7 +141,8 @@ def _pause_and_prepare_engines(self) -> None: if dist.get_rank() == 0: mode = self.args.pause_generation_mode ray.get([engine.pause_generation.remote(mode=mode) for engine in self.rollout_engines]) - ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + if mode not in ("in_place"): + ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) # int4/fp4 pre_process if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: