Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
def quantize_params_mxfp8(args, megatron_name, converted_named_params, quantization_config):
assert quantization_config["quant_method"] == "mxfp8"

if getattr(args, "extra_high_precision_layers_megatron", False):
for layer_name in getattr(args, "extra_high_precision_layers_megatron", ()):
if layer_name in megatron_name:
return converted_named_params

decoder_layers_pattern = r"decoder\.layers\.(\d+)\.(.+)"
match = re.search(decoder_layers_pattern, megatron_name)

Expand All @@ -23,6 +28,16 @@ def quantize_params_mxfp8(args, megatron_name, converted_named_params, quantizat
else:
layer_idx, rest = match.groups()

# Skip quantization for BF16 tail of main decoder layers.
if getattr(args, "first_last_layers_bf16", False):
num_layers = int(args.num_layers)
num_layers_at_start_in_bf16 = int(getattr(args, "num_layers_at_start_in_bf16", 0))
num_layers_at_end_in_bf16 = int(getattr(args, "num_layers_at_end_in_bf16", 0))
head_end_idx = max(0, num_layers_at_start_in_bf16)
tail_start_idx = max(0, num_layers - num_layers_at_end_in_bf16)
if int(layer_idx) < head_end_idx or int(layer_idx) >= tail_start_idx:
return converted_named_params

# experts
expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)"
match = re.match(expert_pattern, rest)
Expand Down
17 changes: 17 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,23 @@ def add_train_arguments(parser):
default="raw",
help="The method to convert megatron weights to hugging face weights for SGLang.",
)
parser.add_argument(
"--extra-high-precision-layers-hf",
type=str,
nargs="*",
default=(),
help=("Extra substrings for HF weight names to skip quantization " "(e.g. .kv_b_proj.)."),
)
parser.add_argument(
"--extra-high-precision-layers-megatron",
type=str,
nargs="*",
default=(),
help=(
"Extra substrings for Megatron weight names to skip quantization in Megatron-to-HF paths "
"(e.g. .linear_kv_up_proj.)."
),
)
parser.add_argument(
"--custom-model-provider-path",
type=str,
Expand Down
4 changes: 3 additions & 1 deletion scripts/run_qwen3_30b_a3b.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def prepare(args: ScriptArgs):

if args.rollout_mxfp8:
U.exec_command(
f"python tools/convert_hf_to_mxfp8.py --model-dir {args.model_dir}/{args.model_name} --save-dir {args.model_dir}/{args.model_name}-MXFP8"
f"python tools/convert_hf_to_mxfp8.py --model-dir {args.model_dir}/{args.model_name} "
f"--save-dir {args.model_dir}/{args.model_name}-MXFP8 "
f"{args.extra_args} "
)

if args.rollout_int4:
Expand Down
Loading