[megatron] fix: add missing FP8 padding for router replay #5989
eternally-z wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request adds FP8 padding support to the router replay utilities by passing a use_fp8_padding flag to the preprocessing functions. The review identifies a potential AttributeError when accessing tf_config.fp8 directly and suggests using getattr for safety. It also recommends passing pre_process=False in merge_router_topk_indices, since only the packing metadata is needed there, avoiding unnecessary memory allocation and computation for tensors that are discarded.
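The memory-saving point can be illustrated with a minimal, hypothetical sketch. The `preprocess` function and `PackedSeqParams` container below are stand-ins for illustration, not verl's actual API: when only the packing metadata is needed for post-processing, skipping the `pre_process` step avoids materializing the packed tensor at all.

```python
from dataclasses import dataclass


@dataclass
class PackedSeqParams:
    # cumulative sequence lengths, as used by packed/THD attention layouts
    cu_seqlens: list


def preprocess(seqs, pre_process=True):
    """Hypothetical stand-in for preprocess_packed_seqs.

    Always computes the packing metadata; only materializes the
    (potentially large) packed tensor when pre_process=True.
    """
    cu_seqlens = [0]
    for s in seqs:
        cu_seqlens.append(cu_seqlens[-1] + len(s))
    params = PackedSeqParams(cu_seqlens=cu_seqlens)
    if not pre_process:
        return None, params  # skip the allocation entirely
    packed = [tok for s in seqs for tok in s]
    return packed, params


# Only the metadata is needed for post-processing, so pre_process=False
# returns it without building the packed tensor.
_, params = preprocess([[1, 2, 3], [4, 5]], pre_process=False)
print(params.cu_seqlens)  # [0, 3, 5]
```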
```diff
+ fp8 = tf_config.fp8
+ use_fp8_padding = fp8 in ["e4m3", "hybrid"]

  if input_ids.is_nested:
      batch_size = input_ids.shape[0]
-     _, packed_seq_params, _ = preprocess_thd_engine(input_ids, pre_process=True)
+     _, packed_seq_params, _ = preprocess_thd_engine(
+         input_ids, pre_process=True, use_fp8_padding=use_fp8_padding
+     )
      layers_topk_idx = postprocess_thd_engine(
          layers_topk_idx, packed_seq_params, input_ids, batch_size, post_process=True
      )
  else:
      batch_size, seq_len = attention_mask.shape[:2]
-     _, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True)
+     _, packed_seq_params = preprocess_packed_seqs(
+         input_ids, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding
+     )
```
Accessing tf_config.fp8 directly can lead to an AttributeError if the attribute is missing in certain Megatron versions or custom configurations. It is safer to use getattr(tf_config, 'fp8', None).
Additionally, in merge_router_topk_indices, the calls to preprocess_thd_engine and preprocess_packed_seqs are only needed for the packed_seq_params consumed by the subsequent post-processing step. Passing pre_process=False avoids redundant memory allocation and computation for the processed tensor, which is currently discarded.
```python
use_fp8_padding = getattr(tf_config, 'fp8', None) in ["e4m3", "hybrid"]
if input_ids.is_nested:
    batch_size = input_ids.shape[0]
    _, packed_seq_params, _ = preprocess_thd_engine(
        input_ids, pre_process=False, use_fp8_padding=use_fp8_padding
    )
    layers_topk_idx = postprocess_thd_engine(
        layers_topk_idx, packed_seq_params, input_ids, batch_size, post_process=True
    )
else:
    batch_size, seq_len = attention_mask.shape[:2]
    _, packed_seq_params = preprocess_packed_seqs(
        input_ids, attention_mask, pre_process=False, use_fp8_padding=use_fp8_padding
    )
```

```python
fp8 = tf_config.fp8
use_fp8_padding = fp8 in ["e4m3", "hybrid"]
```
Use getattr(tf_config, 'fp8', None) to safely check for FP8 configuration and avoid potential AttributeError if the attribute is not present in the configuration object.
```diff
- fp8 = tf_config.fp8
- use_fp8_padding = fp8 in ["e4m3", "hybrid"]
+ use_fp8_padding = getattr(tf_config, 'fp8', None) in ["e4m3", "hybrid"]
```
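For reference, a minimal standalone example of the pattern the reviewer suggests (the config classes below are stand-ins, not Megatron's actual TransformerConfig):

```python
class LegacyConfig:
    """Stand-in for an older config object that has no fp8 attribute."""
    pass


class Fp8Config:
    """Stand-in for a config with FP8 enabled."""
    fp8 = "e4m3"


def wants_fp8_padding(cfg):
    # getattr returns None when the attribute is absent, so the
    # membership test degrades gracefully instead of raising
    # AttributeError on older configs.
    return getattr(cfg, "fp8", None) in ["e4m3", "hybrid"]


print(wants_fp8_padding(Fp8Config()))     # True
print(wants_fp8_padding(LegacyConfig()))  # False, no AttributeError
```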
What does this PR do?
The router replay path lacks FP8 padding logic. Consequently, enabling router replay during FP8 training leads to incorrect training results. This PR adds the missing FP8 padding support.
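To sketch why padding matters here: FP8 GEMM kernels typically require tensor dimensions aligned to a fixed multiple (16 is Transformer Engine's usual constraint; the exact value here is an assumption, and `pad_seqlen_for_fp8` is an illustrative helper, not verl's function). If the forward pass pads packed sequences for FP8 but the replay path does not, the packed lengths, and therefore the recorded routing indices, go out of sync.

```python
FP8_ALIGN = 16  # assumed alignment; FP8 GEMMs typically need dims divisible by 16


def pad_seqlen_for_fp8(seqlen: int, align: int = FP8_ALIGN) -> int:
    """Round a packed sequence length up to the FP8 alignment boundary."""
    return (seqlen + align - 1) // align * align


# A replay path that skips this rounding would index a shorter packed
# tensor than the one the FP8 forward pass actually produced.
print(pad_seqlen_for_fp8(1000))  # 1008
print(pad_seqlen_for_fp8(1024))  # 1024
```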
Checklist Before Starting
- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API, add `[BREAKING]` to the beginning of the title, like `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
API and Usage Example
```python
# Add code snippet or script demonstrating how to use this
```

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Pre-commit checks pass: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- CI has been requested via the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR changes the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.