[rollout] feat: trainer-side FP8 weight quantization for colocated and disaggregated modes #5976
yxs wants to merge 1 commit into verl-project:main from
Conversation
Code Review
This pull request implements trainer-side FP8 quantization for model weights in engine_workers.py and fsdp_workers.py before they are transmitted to the rollout worker. The sglang_rollout logic is updated to recognize pre-quantized weights and skip redundant quantization steps. Feedback highlights issues with hardcoded quantization parameters and data types, as well as opportunities to optimize the code by reusing quantizer instances and avoiding redundant imports.
Support for trainer-side FP8 quantization in disaggregated mode is a higher priority than colocated mode. cc @sophiayyya #5972
A few comments:
```python
        # If quantization fails, use original weights
        yield (k, v)
```

```python
    def quant_weights_by_name_sync(self, weights, dtype=torch.bfloat16):
```
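The fallback pattern in the quoted hunk can be read as a standalone generator; `quantize_fn` below is a hypothetical stand-in for the real blockwise FP8 quantizer, used only to illustrate the yield-original-on-failure behavior:

```python
def quant_weights_by_name_sync(weights, quantize_fn):
    """Yield (name, tensor) pairs, quantizing where possible.

    `weights` is an iterable of (name, tensor) pairs; `quantize_fn` is a
    hypothetical stand-in for the real blockwise FP8 quantizer.
    """
    for k, v in weights:
        try:
            yield (k, quantize_fn(k, v))
        except Exception:
            # If quantization fails, fall back to the original weights
            yield (k, v)


# Illustrative quantizer: refuses router gate weights, tags everything else.
def quantize_fn(name, value):
    if name.endswith("gate.weight"):
        raise ValueError("router weights are not quantized")
    return ("fp8", value)


out = dict(quant_weights_by_name_sync(
    [("mlp.up.weight", 1.0), ("mlp.gate.weight", 2.0)], quantize_fn))
# "mlp.up.weight" comes back tagged; "mlp.gate.weight" passes through unchanged.
```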
We can use `ensure_async_iterator` in the checkpoint engines to iterate weights with `async for`, and eliminate this sync version.
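The suggestion above can be sketched as follows. This `ensure_async_iterator` is a minimal hypothetical implementation (verl's actual helper may differ, e.g. by offloading sync iteration to a thread); it lets callers uniformly `async for` over either kind of weight iterator:

```python
import asyncio
import inspect


async def ensure_async_iterator(it):
    """Hypothetical sketch: pass async iterators through, wrap sync ones.

    A production version would likely run the sync branch in a thread so a
    slow weight iterator does not block the event loop.
    """
    if inspect.isasyncgen(it) or hasattr(it, "__aiter__"):
        async for item in it:
            yield item
    else:
        for item in it:
            yield item


async def main():
    out = []
    # A plain sync iterator of (name, tensor) pairs is consumed via async for.
    async for k, v in ensure_async_iterator(iter([("a", 1), ("b", 2)])):
        out.append((k, v))
    return out


result = asyncio.run(main())
```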
What does this PR do?
Move FP8 blockwise weight quantization from the rollout GPU to the trainer GPU in the weight-sync path, controlled by a new config flag `trainer_quantize_fp8`. This halves transfer bandwidth in disaggregated mode (1 byte vs 2 bytes per param) and reduces peak memory during bucketed transfer in colocated mode.

Supports both colocated (async quantization via `quant_weights_by_name`) and disaggregated (sync quantization via `quant_weights_by_name_sync` for NCCL/NIXL/HCCL checkpoint engines) modes.

Related: #5836 (Q2 roadmap → Weight refit optimization → "fp8 rollout: quantize weights on trainer side")
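For context, blockwise FP8 quantization of this kind can be simulated in numpy (a sketch only: the real path casts to a torch float8 e4m3 dtype, and the 128-element block size and per-block `amax / 448` scaling here are assumptions based on common blockwise-FP8 recipes, not verl's exact implementation). The bandwidth claim follows from the payload: 1 byte per param plus one 4-byte fp32 scale per 128-element block (~1.03 bytes/param), versus 2 bytes/param for bf16.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3
BLOCK = 128           # assumed block size for per-block scaling


def quantize_blockwise(w: np.ndarray):
    """Simulated blockwise FP8 quantization (numpy stand-in).

    Real code would cast `blocks / scales` to a float8 dtype; here we only
    compute the per-block scales and scaled values, which is the part that
    determines transfer size: 1 byte/param plus one scale per block.
    """
    flat = w.reshape(-1)
    pad = (-flat.size) % BLOCK
    flat = np.pad(flat, (0, pad))
    blocks = flat.reshape(-1, BLOCK)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / FP8_E4M3_MAX
    scales = np.where(scales == 0, 1.0, scales)  # avoid div-by-zero blocks
    q = np.clip(blocks / scales, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scales


w = np.linspace(-1.0, 1.0, 256, dtype=np.float32)
q, s = quantize_blockwise(w)
# Dequantize to sanity-check the scaling round trip (exact here, since we
# never actually narrow to 8 bits in this simulation).
deq = (q * s).reshape(-1)[: w.size]
```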
Checklist Before Starting

- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, like `[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API, prepend `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
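The title convention above can be checked mechanically. The validator below is a hypothetical illustration of the rule, not the actual CI check:

```python
import re

# Module and type vocabularies copied from the PR template above.
MODULES = {"fsdp", "megatron", "veomni", "sglang", "vllm", "rollout", "trainer",
           "ci", "training_utils", "recipe", "hardware", "deployment", "ray",
           "worker", "single_controller", "misc", "perf", "model", "algo",
           "env", "tool", "ckpt", "doc", "data", "cfg", "reward",
           "fully_async", "one_step_off"}
TYPES = {"feat", "fix", "refactor", "chore", "test"}

# Optional [BREAKING] prefix, then [modules], then "type: description".
PATTERN = re.compile(r"^(\[BREAKING\])?\[([^\]]+)\] (\w+): .+")


def title_ok(title: str) -> bool:
    m = PATTERN.match(title)
    if not m:
        return False
    mods = [x.strip() for x in m.group(2).split(",")]
    return all(x in MODULES for x in mods) and m.group(3) in TYPES


ok = title_ok("[rollout] feat: trainer-side FP8 weight quantization")
breaking = title_ok("[BREAKING][fsdp, megatron] feat: dynamic batching")
bad = title_ok("feat: missing module tag")
```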
Environment: 8×H100, Qwen2.5-7B, real GSM8K (7473 samples), SGLang rollout, FSDP trainer
1. Disaggregated mode (NCCL checkpoint engine, 4 trainer + 4 rollout GPUs)

Trainer-side FP8 quantization + NCCL weight transfer verified via logs:

- `FP8 trainer-side quantization enabled (disaggregated)`
- `Skipping FP8 quantization, weights pre-quantized on trainer side`
- `Rank 0 send weights done, time cost: 0.42s`

2. Convergence comparison (127 global steps, disaggregated mode)
Both converge to ~0.73-0.75. No convergence regression.
3. MoE model verification (Qwen3-30B-A3B, 128 experts)

- Weight transfer completes (`send weights done, time cost: 6.42s`)
- Router weights (`mlp.gate.weight`) correctly skipped by quantization (unit test: 25/25 PASS)

4. Selective quantization verification
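The selective-quantization behavior verified above (MoE router weights such as `mlp.gate.weight` kept in high precision) amounts to a name-based filter on the weight stream. This is an illustrative sketch; verl's actual skip rules may cover more cases:

```python
# Assumed skip list for illustration: router gates stay in high precision
# because quantizing them degrades expert selection.
SKIP_SUFFIXES = ("mlp.gate.weight",)


def should_quantize(name: str) -> bool:
    """Quantize only weight tensors whose name is not on the skip list."""
    return name.endswith(".weight") and not name.endswith(SKIP_SUFFIXES)


expert_w = should_quantize("model.layers.0.mlp.experts.3.up_proj.weight")
router_w = should_quantize("model.layers.0.mlp.gate.weight")
# Expert projection weights are quantized; the router gate is passed through.
```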
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
`pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`