[feat] Expand mxfp8 quantization to support fine-grained layer precision config #614
zianglih wants to merge 15 commits into radixark:main
Conversation
Summary of Changes

Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the model quantization pipeline by introducing partial MXFP8 quantization. Specifically, it allows users to designate a certain number of the last decoder layers to retain their original BF16 precision, while the rest of the model is converted to MXFP8. This feature provides finer-grained control over model precision, potentially balancing performance and accuracy by preserving higher precision in critical layers.

Highlights
Changelog
Activity
Code Review
This pull request introduces a feature to skip quantization for the last N layers of a model, keeping them in BF16 format. The changes span across model conversion scripts and training execution scripts to support this new functionality, primarily for MXFP8 quantization. My review identifies a logical contradiction in the feature's activation condition, a magic number that hurts maintainability, and some redundant and non-robust file handling. Addressing these points will improve the code's correctness and quality.
scripts/run_qwen3_30b_a3b.py
Outdated
```python
if (args.train_fp8 or args.train_mxfp8) and args.num_layers_at_end_in_bf16 > 0:
    misc_args += (
        "--first-last-layers-bf16 "
        "--num-layers-at-start-in-bf16 0 "
        f"--num-layers-at-end-in-bf16 {args.num_layers_at_end_in_bf16} "
    )
```
There's a logical contradiction regarding when this feature is enabled. The __post_init__ check on line 40 asserts that num_layers_at_end_in_bf16 is only supported when rollout_mxfp8 is enabled. However, this block enables the feature for both train_fp8 and train_mxfp8. If train_fp8 is used (which usually implies rollout_fp8 and not rollout_mxfp8), the assertion on line 40 will fail, making the feature unusable with train_fp8.
Given that other changes in this PR are specific to mxfp8, it seems this feature is intended only for mxfp8. If so, the condition should be narrowed to resolve the contradiction.
Suggested change:

```diff
-if (args.train_fp8 or args.train_mxfp8) and args.num_layers_at_end_in_bf16 > 0:
+if args.train_mxfp8 and args.num_layers_at_end_in_bf16 > 0:
     misc_args += (
         "--first-last-layers-bf16 "
         "--num-layers-at-start-in-bf16 0 "
         f"--num-layers-at-end-in-bf16 {args.num_layers_at_end_in_bf16} "
     )
```
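As a self-contained illustration of the narrowed condition together with the `__post_init__` assertion the reviewer describes, here is a minimal sketch. The `Args` dataclass and `build_misc_args` helper are hypothetical stand-ins for the script's actual argument handling, not its real API:

```python
# Hypothetical sketch: the BF16-tail feature is asserted to require
# rollout_mxfp8, so the launcher must only emit the flags for mxfp8 runs.
from dataclasses import dataclass


@dataclass
class Args:
    train_fp8: bool = False
    train_mxfp8: bool = False
    rollout_mxfp8: bool = False
    num_layers_at_end_in_bf16: int = 0

    def __post_init__(self):
        # Mirrors the check the review refers to: the feature is only
        # supported together with rollout_mxfp8.
        if self.num_layers_at_end_in_bf16 > 0:
            assert self.rollout_mxfp8, (
                "--num-layers-at-end-in-bf16 is only supported with rollout_mxfp8"
            )


def build_misc_args(args: Args) -> str:
    misc_args = ""
    # Narrowed condition from the suggestion above: mxfp8 only, so a
    # train_fp8 run can never reach flags its assertion would reject.
    if args.train_mxfp8 and args.num_layers_at_end_in_bf16 > 0:
        misc_args += (
            "--first-last-layers-bf16 "
            "--num-layers-at-start-in-bf16 0 "
            f"--num-layers-at-end-in-bf16 {args.num_layers_at_end_in_bf16} "
        )
    return misc_args
```

With the narrowed condition, constructing a `train_fp8`-only config with a nonzero BF16 tail fails fast in `__post_init__` instead of producing flags that contradict the assertion later.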
tools/convert_hf_to_mxfp8.py
Outdated
```python
num_maybe_mtp_layers = 1
dynamic_skip_layer_prefixes: set[str] = {
    f"model.layers.{i}." for i in range(tail_start_idx, num_hidden_layers + num_maybe_mtp_layers)
}
```
The use of the magic number 1 for num_maybe_mtp_layers makes the code less readable and harder to maintain. It's not immediately clear why this value is 1 and if it's model-specific.
To improve clarity and maintainability, please define this as a named constant with a comment explaining its purpose. For example:
```python
# Number of MTP (Multi-Token Prediction) layers to account for, which might
# not be included in `num_hidden_layers`. This can be model-specific.
NUM_MAYBE_MTP_LAYERS = 1
# ...
# ... range(tail_start_idx, num_hidden_layers + NUM_MAYBE_MTP_LAYERS)
```

Force-pushed d5236a8 to
56515e8
Force-pushed ae984d3 to f5790bb
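Tying the conversion snippet discussed above into a runnable form, here is a hedged end-to-end sketch of how skip prefixes could be built and consulted during quantization. `build_skip_prefixes` and `should_keep_bf16` are illustrative names, not the PR's actual API:

```python
# Number of potential MTP (Multi-Token Prediction) layers appended after the
# regular decoder stack; model-specific, as noted in the review.
NUM_MAYBE_MTP_LAYERS = 1


def build_skip_prefixes(num_hidden_layers: int, num_layers_at_end_in_bf16: int) -> set[str]:
    """Collect name prefixes for the tail layers that should stay in BF16."""
    tail_start_idx = num_hidden_layers - num_layers_at_end_in_bf16
    return {
        f"model.layers.{i}."
        for i in range(tail_start_idx, num_hidden_layers + NUM_MAYBE_MTP_LAYERS)
    }


def should_keep_bf16(param_name: str, skip_prefixes: set[str]) -> bool:
    # A weight is skipped (kept in BF16) if its fully qualified name falls
    # under any of the tail-layer prefixes; everything else gets MXFP8.
    return any(param_name.startswith(p) for p in skip_prefixes)
```

For a 4-layer model with a 2-layer BF16 tail, the prefixes cover layers 2 and 3 plus one possible MTP layer at index 4, so `model.layers.3.mlp.down_proj.weight` would be skipped while `model.layers.0.self_attn.q_proj.weight` would be quantized.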
This reverts commit f3965d8.
@HumansAnd
Expand mxfp8 quantization utils to support fine-grained layer precision config:
- `--num-layers-at-start-in-bf16`, `--num-layers-at-end-in-bf16`
- `--extra-high-precision-layers-hf`, `--extra-high-precision-layers-megatron`, used by weight conversion
- `--te-precision-config-file` in `tools/convert_hf_to_mxfp8.py`
- Use FlashInfer mxfp8 quantizer for faster weight sync, falling back to Triton if unavailable

This cannot be merged until the SGLang v0.6.10 bump with the following PRs:

- flashinfer_trtllm_routedmoe backend sgl-project/sglang#20214
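The FlashInfer-with-Triton-fallback behavior mentioned in the description can be sketched as a simple import-guarded backend selection. The function name is a hypothetical placeholder; the real weight-sync code presumably dispatches to the respective quantization kernels directly:

```python
# Hedged sketch of the "prefer FlashInfer, fall back to Triton" pattern:
# probe for the optional dependency once and choose the backend accordingly.
def get_mxfp8_quantizer_backend() -> str:
    try:
        import flashinfer  # noqa: F401  # fast path if the package is installed
        return "flashinfer"
    except ImportError:
        # Fallback: a Triton-based quantizer, slower but always available
        # in the training environment.
        return "triton"
```

Probing at selection time (rather than at each quantization call) keeps the hot weight-sync path free of repeated import attempts.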