-
Notifications
You must be signed in to change notification settings - Fork 150
[fix] Skip flush_cache in in_place mode and add fully async example #974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,175 @@ | ||||||
| from dataclasses import dataclass | ||||||
| from typing import Literal | ||||||
|
|
||||||
| import typer | ||||||
|
|
||||||
| import miles.utils.external_utils.command_utils as U | ||||||
|
|
||||||
| # in_place + broadcast | ||||||
| #python run_qwen3_30b_a3b_fully_async.py | ||||||
|
|
||||||
| # retract + p2p | ||||||
| #python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode p2p | ||||||
|
|
||||||
| # retract + broadcast | ||||||
| #python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode broadcast | ||||||
|
|
||||||
| @dataclass | ||||||
| class ScriptArgs(U.ExecuteTrainConfig): | ||||||
| mode: Literal["normal", "debug_minimal"] = "normal" | ||||||
| run_id: str = U.create_run_id() | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In Python dataclasses, dynamic default values should be defined using
Suggested change
|
||||||
| model_name: str = "Qwen3-30B-A3B" | ||||||
| megatron_model_type: str = "qwen3-30B-A3B" | ||||||
| num_gpus_per_node: int = 8 | ||||||
| data_dir: str = "/root/datasets" | ||||||
| model_dir: str = "/root/models" | ||||||
| megatron_path: str = "/root/Megatron-LM" | ||||||
| pause_generation_mode: Literal["in_place", "retract"] = "in_place" | ||||||
| update_weight_transfer_mode: Literal["broadcast", "p2p"] = "broadcast" | ||||||
| extra_args: str = "" | ||||||
|
|
||||||
|
|
||||||
| def prepare(args: ScriptArgs): | ||||||
| U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}") | ||||||
| U.exec_command(f"hf download Qwen/{args.model_name} --local-dir {args.model_dir}/{args.model_name}") | ||||||
| U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir) | ||||||
| U.convert_checkpoint( | ||||||
| model_name=args.model_name, | ||||||
| megatron_model_type=args.megatron_model_type, | ||||||
| num_gpus_per_node=args.num_gpus_per_node, | ||||||
| dir_dst=args.model_dir, | ||||||
| hf_checkpoint=f"{args.model_dir}/{args.model_name}", | ||||||
| megatron_path=args.megatron_path, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def execute(args: ScriptArgs): | ||||||
| if args.pause_generation_mode == "in_place" and args.update_weight_transfer_mode == "p2p": | ||||||
| raise ValueError( | ||||||
| "in_place + p2p is not supported: P2P transfer engine conflicts with " | ||||||
| "active NCCL inference. Use broadcast with in_place, or retract with p2p." | ||||||
| ) | ||||||
|
|
||||||
| ref_load_path = f"{args.model_dir}/{args.model_name}_torch_dist" | ||||||
| load_save_path = f"{args.output_dir}/{args.run_id}/checkpoints" | ||||||
|
|
||||||
| ckpt_args = ( | ||||||
| f"--hf-checkpoint {args.model_dir}/{args.model_name}/ " | ||||||
| f"--ref-load {ref_load_path} " | ||||||
| f"--load {load_save_path} " | ||||||
| ) | ||||||
|
|
||||||
| rollout_args = ( | ||||||
| "--rollout-function-path fully_async_rollout.generate_rollout_fully_async " | ||||||
| f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl " | ||||||
| "--input-key prompt " | ||||||
| "--label-key label " | ||||||
| "--apply-chat-template " | ||||||
| "--rollout-shuffle " | ||||||
| "--rm-type dapo " | ||||||
| "--reward-key score " | ||||||
| "--num-rollout 3000 " | ||||||
| "--rollout-batch-size 32 " | ||||||
| "--n-samples-per-prompt 8 " | ||||||
| f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " | ||||||
| "--rollout-temperature 1 " | ||||||
| "--global-batch-size 256 " | ||||||
| "--balance-data " | ||||||
| f"--pause-generation-mode {args.pause_generation_mode} " | ||||||
| ) | ||||||
|
|
||||||
| perf_args = ( | ||||||
| "--tensor-model-parallel-size 8 " | ||||||
| "--sequence-parallel " | ||||||
| "--pipeline-model-parallel-size 1 " | ||||||
| "--context-parallel-size 1 " | ||||||
| "--expert-model-parallel-size 8 " | ||||||
| "--expert-tensor-parallel-size 1 " | ||||||
| "--recompute-granularity full " | ||||||
| "--recompute-method uniform " | ||||||
| "--recompute-num-layers 1 " | ||||||
| "--use-dynamic-batch-size " | ||||||
| "--max-tokens-per-gpu 9216 " | ||||||
| "--optimizer-cpu-offload " | ||||||
| "--overlap-cpu-optimizer-d2h-h2d " | ||||||
| "--use-precision-aware-optimizer " | ||||||
| ) | ||||||
|
|
||||||
| grpo_args = ( | ||||||
| "--advantage-estimator grpo " | ||||||
| "--use-kl-loss " | ||||||
| "--kl-loss-coef 0.00 " | ||||||
| "--kl-loss-type low_var_kl " | ||||||
| "--entropy-coef 0.00 " | ||||||
| "--eps-clip 0.2 " | ||||||
| "--eps-clip-high 0.28 " | ||||||
| "--use-tis " | ||||||
| ) | ||||||
|
|
||||||
| optimizer_args = ( | ||||||
| "--optimizer adam " | ||||||
| "--lr 1e-6 " | ||||||
| "--lr-decay-style constant " | ||||||
| "--weight-decay 0.1 " | ||||||
| "--adam-beta1 0.9 " | ||||||
| "--adam-beta2 0.98 " | ||||||
| ) | ||||||
|
|
||||||
| sglang_extra = "" | ||||||
| if args.update_weight_transfer_mode == "p2p": | ||||||
| sglang_extra = "--sglang-remote-instance-weight-loader-start-seed-via-transfer-engine " | ||||||
|
|
||||||
| sglang_args = ( | ||||||
| "--rollout-num-gpus-per-engine 8 " | ||||||
| f"--sglang-mem-fraction-static 0.7 {sglang_extra}" | ||||||
| "--sglang-cuda-graph-max-bs 512 " | ||||||
| ) | ||||||
|
|
||||||
| misc_args = ( | ||||||
| "--attention-dropout 0.0 " | ||||||
| "--hidden-dropout 0.0 " | ||||||
| "--accumulate-allreduce-grads-in-fp32 " | ||||||
| "--attention-softmax-in-fp32 " | ||||||
| f"--attention-backend flash --update-weight-transfer-mode {args.update_weight_transfer_mode} " | ||||||
| "--actor-num-nodes 1 " | ||||||
| f"--actor-num-gpus-per-node {args.num_gpus_per_node} " | ||||||
| f"--num-gpus-per-node {args.num_gpus_per_node} " | ||||||
| f"--rollout-num-gpus {args.num_gpus_per_node} " | ||||||
| ) | ||||||
|
|
||||||
| train_args = ( | ||||||
| f"{ckpt_args} " | ||||||
| f"{rollout_args} " | ||||||
| f"{optimizer_args} " | ||||||
| f"{grpo_args} " | ||||||
| f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} " | ||||||
| f"{perf_args} " | ||||||
| f"{sglang_args} " | ||||||
| f"{misc_args} " | ||||||
| f"{args.extra_args} " | ||||||
| ) | ||||||
|
|
||||||
| import os | ||||||
|
|
||||||
| fully_async_dir = os.path.join(os.path.dirname(os.path.abspath(__file__))) | ||||||
|
Comment on lines
+153
to
+155
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| U.execute_train( | ||||||
| train_args=train_args, | ||||||
| num_gpus_per_node=args.num_gpus_per_node, | ||||||
| megatron_model_type=args.megatron_model_type, | ||||||
| train_script="train_async.py", | ||||||
| megatron_path=args.megatron_path, | ||||||
| extra_env_vars={ | ||||||
| "FLASHINFER_DISABLE_VERSION_CHECK": "1", | ||||||
| "PYTHONPATH": f"{args.megatron_path}:{fully_async_dir}", | ||||||
| }, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| @U.dataclass_cli | ||||||
| def main(args: ScriptArgs): | ||||||
| prepare(args) | ||||||
| execute(args) | ||||||
|
|
||||||
|
|
||||||
| if __name__ == "__main__": | ||||||
| typer.run(main) | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -141,7 +141,8 @@ def _pause_and_prepare_engines(self) -> None: | |||||
| if dist.get_rank() == 0: | ||||||
| mode = self.args.pause_generation_mode | ||||||
| ray.get([engine.pause_generation.remote(mode=mode) for engine in self.rollout_engines]) | ||||||
| ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) | ||||||
| if mode not in ("in_place"): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The expression
Suggested change
|
||||||
| ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) | ||||||
|
|
||||||
| # int4/fp4 pre_process | ||||||
| if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the
osimport to the top level and includefieldfromdataclassesto support dynamic default values for dataclass fields.References