Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions examples/fully_async/run_qwen3_30b_a3b_fully_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from dataclasses import dataclass
from typing import Literal
Comment on lines +1 to +2
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Move the os import to the top level and include field from dataclasses to support dynamic default values for dataclass fields.

Suggested change
from dataclasses import dataclass
from typing import Literal
from dataclasses import dataclass, field
import os
from typing import Literal
References
  1. PEP 8: Imports should be at the top of the file, before any other code except module docstrings. (link)


import typer

import miles.utils.external_utils.command_utils as U

# in_place + broadcast
#python run_qwen3_30b_a3b_fully_async.py

# retract + p2p
#python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode p2p

# retract + broadcast
#python run_qwen3_30b_a3b_fully_async.py --pause-generation-mode retract --update-weight-transfer-mode broadcast

@dataclass
class ScriptArgs(U.ExecuteTrainConfig):
mode: Literal["normal", "debug_minimal"] = "normal"
run_id: str = U.create_run_id()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In Python dataclasses, dynamic default values should be defined using field(default_factory=...). Using a function call directly in the class definition assigns the result of that call at module load time, meaning all instances of ScriptArgs will share the same run_id generated when the script is first imported. Using a factory ensures a fresh ID is generated upon instantiation.

Suggested change
run_id: str = U.create_run_id()
run_id: str = field(default_factory=U.create_run_id)

model_name: str = "Qwen3-30B-A3B"
megatron_model_type: str = "qwen3-30B-A3B"
num_gpus_per_node: int = 8
data_dir: str = "/root/datasets"
model_dir: str = "/root/models"
megatron_path: str = "/root/Megatron-LM"
pause_generation_mode: Literal["in_place", "retract"] = "in_place"
update_weight_transfer_mode: Literal["broadcast", "p2p"] = "broadcast"
extra_args: str = ""


def prepare(args: ScriptArgs):
U.exec_command(f"mkdir -p {args.model_dir} {args.data_dir}")
U.exec_command(f"hf download Qwen/{args.model_name} --local-dir {args.model_dir}/{args.model_name}")
U.hf_download_dataset("zhuzilin/dapo-math-17k", data_dir=args.data_dir)
U.convert_checkpoint(
model_name=args.model_name,
megatron_model_type=args.megatron_model_type,
num_gpus_per_node=args.num_gpus_per_node,
dir_dst=args.model_dir,
hf_checkpoint=f"{args.model_dir}/{args.model_name}",
megatron_path=args.megatron_path,
)


def execute(args: ScriptArgs):
if args.pause_generation_mode == "in_place" and args.update_weight_transfer_mode == "p2p":
raise ValueError(
"in_place + p2p is not supported: P2P transfer engine conflicts with "
"active NCCL inference. Use broadcast with in_place, or retract with p2p."
)

ref_load_path = f"{args.model_dir}/{args.model_name}_torch_dist"
load_save_path = f"{args.output_dir}/{args.run_id}/checkpoints"

ckpt_args = (
f"--hf-checkpoint {args.model_dir}/{args.model_name}/ "
f"--ref-load {ref_load_path} "
f"--load {load_save_path} "
)

rollout_args = (
"--rollout-function-path fully_async_rollout.generate_rollout_fully_async "
f"--prompt-data {args.data_dir}/dapo-math-17k/dapo-math-17k.jsonl "
"--input-key prompt "
"--label-key label "
"--apply-chat-template "
"--rollout-shuffle "
"--rm-type dapo "
"--reward-key score "
"--num-rollout 3000 "
"--rollout-batch-size 32 "
"--n-samples-per-prompt 8 "
f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} "
"--rollout-temperature 1 "
"--global-batch-size 256 "
"--balance-data "
f"--pause-generation-mode {args.pause_generation_mode} "
)

perf_args = (
"--tensor-model-parallel-size 8 "
"--sequence-parallel "
"--pipeline-model-parallel-size 1 "
"--context-parallel-size 1 "
"--expert-model-parallel-size 8 "
"--expert-tensor-parallel-size 1 "
"--recompute-granularity full "
"--recompute-method uniform "
"--recompute-num-layers 1 "
"--use-dynamic-batch-size "
"--max-tokens-per-gpu 9216 "
"--optimizer-cpu-offload "
"--overlap-cpu-optimizer-d2h-h2d "
"--use-precision-aware-optimizer "
)

grpo_args = (
"--advantage-estimator grpo "
"--use-kl-loss "
"--kl-loss-coef 0.00 "
"--kl-loss-type low_var_kl "
"--entropy-coef 0.00 "
"--eps-clip 0.2 "
"--eps-clip-high 0.28 "
"--use-tis "
)

optimizer_args = (
"--optimizer adam "
"--lr 1e-6 "
"--lr-decay-style constant "
"--weight-decay 0.1 "
"--adam-beta1 0.9 "
"--adam-beta2 0.98 "
)

sglang_extra = ""
if args.update_weight_transfer_mode == "p2p":
sglang_extra = "--sglang-remote-instance-weight-loader-start-seed-via-transfer-engine "

sglang_args = (
"--rollout-num-gpus-per-engine 8 "
f"--sglang-mem-fraction-static 0.7 {sglang_extra}"
"--sglang-cuda-graph-max-bs 512 "
)

misc_args = (
"--attention-dropout 0.0 "
"--hidden-dropout 0.0 "
"--accumulate-allreduce-grads-in-fp32 "
"--attention-softmax-in-fp32 "
f"--attention-backend flash --update-weight-transfer-mode {args.update_weight_transfer_mode} "
"--actor-num-nodes 1 "
f"--actor-num-gpus-per-node {args.num_gpus_per_node} "
f"--num-gpus-per-node {args.num_gpus_per_node} "
f"--rollout-num-gpus {args.num_gpus_per_node} "
)

train_args = (
f"{ckpt_args} "
f"{rollout_args} "
f"{optimizer_args} "
f"{grpo_args} "
f"{U.get_default_wandb_args(__file__, run_id=args.run_id)} "
f"{perf_args} "
f"{sglang_args} "
f"{misc_args} "
f"{args.extra_args} "
)

import os

fully_async_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)))
Comment on lines +153 to +155
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The local import of os is unnecessary if moved to the top level, and os.path.join is redundant when called with a single argument.

Suggested change
import os
fully_async_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)))
fully_async_dir = os.path.dirname(os.path.abspath(__file__))

U.execute_train(
train_args=train_args,
num_gpus_per_node=args.num_gpus_per_node,
megatron_model_type=args.megatron_model_type,
train_script="train_async.py",
megatron_path=args.megatron_path,
extra_env_vars={
"FLASHINFER_DISABLE_VERSION_CHECK": "1",
"PYTHONPATH": f"{args.megatron_path}:{fully_async_dir}",
},
)


@U.dataclass_cli
def main(args: ScriptArgs):
prepare(args)
execute(args)


if __name__ == "__main__":
typer.run(main)
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def _pause_and_prepare_engines(self) -> None:
if dist.get_rank() == 0:
mode = self.args.pause_generation_mode
ray.get([engine.pause_generation.remote(mode=mode) for engine in self.rollout_engines])
ray.get([engine.flush_cache.remote() for engine in self.rollout_engines])
if mode not in ("in_place"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The expression mode not in ("in_place") is evaluated as a substring check because ("in_place") is a string literal, not a tuple. While this works for the current literal values, it is non-idiomatic and potentially confusing. Using a direct inequality check is clearer and more robust.

Suggested change
if mode not in ("in_place"):
if mode != "in_place":

ray.get([engine.flush_cache.remote() for engine in self.rollout_engines])

# int4/fp4 pre_process
if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]:
Expand Down
Loading