[FlagOS][Train] Integrate TransformerEngine-FL to Megatron-LM for unified training backend #917
base: main
Conversation
Code Review
This pull request introduces significant changes related to integrating a 'FlagEngine' and 'Flag Gems' into the training framework, particularly within the TransformerEngine backend. It includes new configuration files, modifications to Megatron-LM components, and new implementations for attention, GEMM, and RMSNorm operations. While the intent appears to be performance optimization and feature expansion, several critical and high-severity issues have been identified. These include hardcoded paths, potential security risks from changing submodule sources, strong coupling of new features, and most importantly, known correctness issues in the newly introduced gems_rms_norm implementation, which is then used in other core components. The hardcoding of use_gems_flash_attention = True also raises concerns about flexibility and potential regressions. Addressing these issues is crucial for the stability, correctness, and maintainability of the codebase.
+ # d_out is expected to be in FP8 if is_output_fp8=True,
+ # but in the case it's not, convert it to FP8 before any operation
+ assert not ctx.fp8, "Gems Flash Attention do not support fp8 now"
+ assert not ctx.use_FAv2_bwd, "do not support use other flash attention kernels now"
.gitmodules
[submodule "third_party/TransformerEngine"]
    path = third_party/TransformerEngine
    url = https://gitee.com/lxdcumt/TransformerEngine.git
Similar to the Megatron-LM change, using a personal Gitee fork for TransformerEngine introduces potential security and maintainability risks. Relying on unofficial forks can lead to issues with updates, bug fixes, and overall code integrity. Please ensure this fork is actively maintained and secure, or consider using the official repository if possible.
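If the FlagOS changes cannot land upstream yet, pinning the submodule to an organization-owned mirror would be preferable to a personal fork. For reference, pointing at the official upstream would look like the sketch below (the URL is NVIDIA's public repository; whether it carries the required patches is not verified here).

[submodule "third_party/TransformerEngine"]
    path = third_party/TransformerEngine
    url = https://github.com/NVIDIA/TransformerEngine.git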
+ # TODO(lixianduo): polish
  # Compute RMSNorm backward pass
- dx, dw = rmsnorm_bwd(
+ # dx, dw = rmsnorm_bwd(
+ #     dy,
+ #     x,
+ #     rstdevs,
+ #     w,
+ #     self._sm_margins["backward"],
+ #     self.zero_centered_gamma,
+ # )
+
+ dx, dw = rms_norm_backward(
      dy,
      x,
      rstdevs,
+     [x.shape[-1]],
      w,
-     self._sm_margins["backward"],
-     self.zero_centered_gamma,
+     self.eps,
The commented-out rmsnorm_bwd call and its replacement with rms_norm_backward from gems_rms_norm is a critical issue. As noted in gems_rms_norm.py, rms_norm_backward is explicitly marked as 'wrong' and 'needs to be fixed'. Using a known incorrect backward pass will lead to training instability and incorrect model updates.
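A quick parity check against a plain PyTorch reference would confirm or rule out the reported gradient error before the kernel is wired into core modules. The sketch below is illustrative only: check_rmsnorm_backward and the candidate_bwd adapter are hypothetical names, and it assumes zero_centered_gamma=False with no FP8 quantization.

import torch

def rmsnorm_ref(x, w, eps):
    # Reference RMSNorm: y = x * rsqrt(mean(x^2, dim=-1) + eps) * w
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rstd * w.float()).to(x.dtype)

def check_rmsnorm_backward(candidate_bwd, shape=(4, 128, 1024), eps=1e-5):
    # candidate_bwd is a hypothetical adapter around the kernel under test,
    # returning (dx, dw) for a given upstream gradient dy.
    x = torch.randn(shape, requires_grad=True)
    w = torch.randn(shape[-1], requires_grad=True)
    dy = torch.randn(shape)

    rmsnorm_ref(x, w, eps).backward(dy)  # populate autograd reference gradients
    dx, dw = candidate_bwd(dy, x.detach(), w.detach(), eps)

    torch.testing.assert_close(dx, x.grad, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(dw, w.grad, rtol=1e-4, atol=1e-4)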
+ # TODO(lixianduo): polish
+ # dgrad, dgamma = tex.rmsnorm_bwd(
+ #     dgrad,
+ #     inputmat,
+ #     rsigma,
+ #     ln_weight,
+ #     ctx.bwd_ln_sm_margin,
+ #     ctx.zero_centered_gamma,
+ # )
+
+ dgrad, dgamma = rms_norm_backward(
      dgrad,
      inputmat,
      rsigma,
+     [inputmat.shape[-1]],
      ln_weight,
-     ctx.bwd_ln_sm_margin,
-     ctx.zero_centered_gamma,
+     ctx.eps,
The commented-out tex.rmsnorm_bwd call and its replacement with rms_norm_backward from gems_rms_norm is a critical issue. As noted in gems_rms_norm.py, rms_norm_backward is explicitly marked as 'wrong' and 'needs to be fixed'. Using a known incorrect backward pass will lead to training instability and incorrect model updates.
+ # TODO(lixianduo): polish
+ # ln_out, mu, rsigma = apply_normalization(
+ #     inputmat,
+ #     None,  # ln_out
+ #     ln_weight,
+ #     ln_bias,
+ #     eps,
+ #     input_quantizer if with_quantized_norm else None,
+ #     inputmat.dtype,
+ #     normalization,
+ #     fwd_ln_sm_margin,
+ #     zero_centered_gamma,
+ # )
The commented-out apply_normalization call and its replacement with rms_norm_forward (from gems_rms_norm) is concerning, especially since rms_norm_backward in gems_rms_norm.py is marked as 'wrong'. This direct replacement without a fully functional backward pass could lead to incorrect gradients and training failures. The TODO(lixianduo): polish also indicates incomplete work.
- dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+ # TODO(lixianduo): polish
+ # dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+ dw, db, *_ = gems_general_gemm(x, dy, **wgrad_gemm_kwargs)
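If both GEMM paths are meant to coexist, a config-driven dispatch would keep the original TransformerEngine kernel as the default instead of swapping it out unconditionally. A minimal sketch reusing the names from this diff; use_flag_gems_gemm is a hypothetical flag assumed to be carried on ctx.

# Hypothetical dispatch: take the FlagGems GEMM only when explicitly enabled,
# otherwise fall back to the original TransformerEngine kernel.
if getattr(ctx, "use_flag_gems_gemm", False):
    dw, db, *_ = gems_general_gemm(x, dy, **wgrad_gemm_kwargs)
else:
    dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)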
+ # TODO(lixianduo): polish
+ # y, _, rstdevs = rmsnorm_fwd(
+ #     x,
+ #     w,
+ #     self.eps,
+ #     None,
+ #     next_op_input_quantizer,
+ #     TE_DType[dtype],
+ #     sm_margin,
+ #     self.zero_centered_gamma,
+ # )
data_path: /share/project/lixianduo/demo_data/pile/pile_wikipedia_demo
split: 1
no_mmap_bin_files: true
tokenizer:
The data_path is hardcoded to a specific user's directory (/share/project/lixianduo/demo_data/pile/pile_wikipedia_demo). This makes the configuration non-portable and difficult to use in different environments. It should be made configurable (e.g., using environment variables or relative paths) or moved to a more generic location.
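Assuming the YAML is loaded through OmegaConf/Hydra, as FlagScale configs typically are, the path could be resolved from an environment variable with a generic fallback; DATA_PATH and the fallback location below are illustrative.

# Resolve the dataset root from the environment rather than a user-specific path.
data_path: ${oc.env:DATA_PATH,/path/to/pile_wikipedia_demo}
split: 1
no_mmap_bin_files: true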
flagscale/runner/runner_train.py
    host, master_addr, master_port, nnodes, node_rank, nproc_per_node, self.config
)

nsys_cmd = "/share/project/lixianduo/envs/nsys/nsight-system/2025.5.1/bin/nsys profile -s none -t nvtx,cuda,osrt -o /share/project/lixianduo/scale_gems_cx/nsys_reps/$HOSTNAME.nsys-rep --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop".split(" ")
The nsys_cmd contains a hardcoded path /share/project/lixianduo/envs/nsys/nsight-system/2025.5.1/bin/nsys. This path is specific to a particular environment and Nsight Systems installation. It should be made configurable (e.g., via an environment variable or a configuration parameter) to ensure portability across different systems.
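One portable option is to resolve the nsys binary and the report directory from the environment, falling back to whatever is on PATH; a sketch in which NSYS_PATH and NSYS_OUTPUT_DIR are assumed variable names.

import os
import shutil
import socket

# Locate nsys via an environment override or PATH instead of a hardcoded install path.
nsys_bin = os.environ.get("NSYS_PATH") or shutil.which("nsys")
if nsys_bin is None:
    raise FileNotFoundError("nsys not found; set NSYS_PATH or add nsys to PATH")

out_dir = os.environ.get("NSYS_OUTPUT_DIR", "./nsys_reps")
nsys_cmd = [
    nsys_bin, "profile", "-s", "none", "-t", "nvtx,cuda,osrt",
    "-o", os.path.join(out_dir, f"{socket.gethostname()}.nsys-rep"),
    "--force-overwrite", "true",
    "--capture-range=cudaProfilerApi", "--capture-range-end=stop",
]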
| print(f"use flag_gems for training") | ||
| print(f"{args.flag_gems_log_path=}") | ||
| print(f"{args.flag_gems_unused=}") | ||
| flag_gems.enable(record=True, once=True, unused=args.flag_gems_unused, path=args.flag_gems_log_path) |
The flag_gems.enable call uses once=True. If flag_gems is intended to be enabled for the entire training run, once=True might prevent it from being re-enabled if the training loop restarts or if there are multiple phases where it needs to be active. Please clarify if this is the intended behavior or if once=False should be used for continuous enablement.
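If continuous enablement may be needed, the one-shot behavior could be surfaced as a CLI switch instead of being hardcoded; the sketch below reuses the flag_gems.enable call from this diff, while --flag-gems-once and the argument defaults are hypothetical.

import argparse
import flag_gems  # same package enabled elsewhere in this PR

parser = argparse.ArgumentParser()
parser.add_argument("--flag-gems-log-path", default=None)
parser.add_argument("--flag-gems-unused", default=None)
# Hypothetical switch: one-shot enablement becomes an explicit opt-in.
parser.add_argument("--flag-gems-once", action="store_true")
args = parser.parse_args()

flag_gems.enable(record=True, once=args.flag_gems_once,
                 unused=args.flag_gems_unused, path=args.flag_gems_log_path)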
|
|
PR Category
Train
PR Types
New Features
PR Description
Integrate TransformerEngine-FL to Megatron-LM for unified training backend
TransformerEngine-FL
Background:
Megatron-LM natively relies on NVIDIA TransformerEngine for distributed training, wherein its core operators (GEMM, LayerNorm, Attention, and the communication primitives) are monolithically encapsulated within the proprietary NCCL+cuBLAS stack. This initiative refactors TransformerEngine to construct a unified distributed training backend, leveraging FlagGems and FlagCX as foundational components.
Primary Roadmap:
Current Progress:
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name