
Conversation

@lxd-cumt lxd-cumt commented Nov 11, 2025

PR Category

Train

PR Types

New Features

PR Description

Integrate TransformerEngine-FL to Megatron-LM for unified training backend
TransformerEngine-FL

Background:
Megatron-LM natively relies on NVIDIA TransformerEngine for distributed training, in which the core operators, including GEMM, LayerNorm, Attention, and communication primitives, are monolithically encapsulated within the proprietary NCCL+cuBLAS stack. This initiative refactors TransformerEngine to construct a unified distributed training backend, leveraging FlagGems and FlagCX as foundational components.

Primary Roadmap:

  • Initial Development: Leveraging FlagOS (FlagGems & FlagCX), implement core operators, including Linear (column-wise/row-wise parallel), DotProductAttention (FlashAttention), and RMSNorm, to enable end-to-end training of models such as Qwen3 with correct convergence and performance aligned with expectations.
  • Performance Optimization: Design and implement computation-communication (comp/comm) overlap schemes, such as GEMM + sequence-parallel communication overlap and FlashAttention + context parallelism, to iteratively refine TransformerEngine-FL's performance.
  • Hardware Ecosystem Compatibility: Support multiple hardware vendors and apply architecture-specific operator optimizations to improve end-to-end model training performance.

Current Progress:

  • Development of Linear, DotProductAttention, RMSNorm, AdamW operators has been completed.
  • End-to-end convergence, as illustrated in the figure below, demonstrates alignment with TransformerEngine.
[Screenshot: loss-convergence comparison, 2025-11-25]
  • The kernel-level operator inventory, profiled with Nsight Systems, is enumerated below.

Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name


 23.8      73434757203     123324     595462.0    565601.0    172896    4311494     459319.4  mm_kernel_general
 20.7      63740396844      64878     982465.5    366879.5     97248  425961771    5327781.4  ncclDevKernel_AllGather_RING_LL(ncclDevComm *, unsigned long, ncclWork *)
 18.1      55932338886      41336    1353114.4    365569.0    304385  496029996    8331941.6  ncclDevKernel_ReduceScatter_Sum_bf16_RING_LL(ncclDevComm *, unsigned long, ncclWork *)
  9.3      28677100782       2720   10543051.8   8746512.5   5851277   72716183    5706959.2  ncclDevKernel_ReduceScatter_Sum_f32_RING_LL(ncclDevComm *, unsigned long, ncclWork *)
  6.9      21159530707      10240    2066360.4   2044545.0   2038149    2157983      38519.5  _attn_bwd
  6.7      20597525353        160  128734533.5  18290506.5     45856  529057641  180623640.3  ncclDevKernel_AllReduce_Sum_u8_TREE_LL(ncclDevComm *, unsigned long, ncclWork *)
  2.6       8102566635        160   50641041.5  33241872.5     49408   99884105   32285638.0  ncclDevKernel_AllReduce_Sum_f32_TREE_LL(ncclDevComm *, unsigned long, ncclWork *)
  2.1       6385626731         80   79820334.1  74899619.5     22688  316885564   76672028.6  ncclDevKernel_AllReduce_Sum_u32_RING_LL(ncclDevComm *, unsigned long, ncclWork *)
  2.0       6155190811     228706      26913.1      7424.0      1344     585313      52799.6  add_func_kernel_rank_1
  1.2       3586448520      10108     354812.9    353185.0    351361     370112       3977.2  _attn_fwd
  1.0       3040437164     156992      19366.8      3200.0      1312     530017      43798.4  mul_func_scalar_kernel_rank_1
  0.7       2193838986       1520    1443315.1     42928.0     13985   34803310    5967409.2  ncclDevKernel_AllReduce_Sum_f32_RING_LL(ncclDevComm *, unsigned long, ncclWork *)
  0.5       1529975869     111980      13662.9     10816.0      4255      33856       7757.2  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl_nocast<at::n…
  0.3        862768767      26100      33056.3     16064.0      2048     297248      47648.7  l2_norm_kernel_1
  0.3        848179397      10240      82830.0     82304.0     80992      92736       1810.9  triton_poi_fused_cat_0
  0.3        841353024      41280      20381.6     15520.0      8096      48288      11868.9  rms_norm_grad_dw_kernel
  0.3        838766524      42080      19932.7      2048.0      1312     231072      30059.9  true_div_func_tensor_scalar_kernel_rank_1
  0.2        721676008      41280      17482.5     15328.0      6080      42560       9245.4  rms_norm_grad_dx_kernel
  0.2        690642939      61792      11176.9      2880.0      1536     345825      33153.8  void at::native::vectorized_elementwise_kernel<(int)4, at::native::bfloat16_copy_kernel_cuda(at::Te…
  0.2        662149297       1570     421751.1     26784.0      9504    6105775    1036900.3  ncclDevKernel_Broadcast_RING_LL(ncclDevComm *, unsigned long, ncclWork *)
  0.2        656237310      40696      16125.4     15328.0     14720      24096       1740.9  mul_func_kernel_rank_4
  0.2        614338150      21440      28653.8      1984.0      1344     332161      44474.2  true_div_func_kernel_rank_1
  0.2        591549664      62776       9423.2      3424.0      1312     229856      19177.0  mul_func_kernel_rank_1
  0.2        562876515      40752      13812.2      8768.0      5600      43200       9282.8  rms_norm_kernel
  0.2        526119217         80    6576490.2   7069194.0     24640    9801797    2067934.3  ncclDevKernel_AllReduce_Sum_u32_TREE_LL(ncclDevComm *, unsigned long, ncclWork *)
  0.2        512128718      74112       6910.2      6720.0      2208      24544       4225.8  cat_copy_func_kernel_4
  0.2        494684161      61840       7999.4      4608.0      2624      26656       5036.1  sum_dim_kernel_non_inner
  0.1        423398358      10114      41862.6     40608.0     39904      50048       2419.5  triton_poi_fused_mul_silu_0
  0.1        395244884      20720      19075.5      7136.0      1344     229152      30226.3  sqrt_func_kernel_rank_1
  0.1        389856610      20720      18815.5      6688.0      1344     229761      30029.7  add_func_tensor_scalar_kernel_rank_1
  0.1        374265252         80    4678315.7   4669215.0   4654854    5058286      60766.1  fill_scalar_func_kernel_rank_1
  0.1        267491494        320     835910.9    873089.5    311104    1093409     197953.7  embedding_backward_kernel
  0.1        252687527      11690      21615.7     16000.0      2208     134912      18665.5  count_nonzero_kernel_1
  0.1        248953630      40696       6117.4      7904.0      2624      15776       3187.6  neg_func_kernel_rank_4
  0.1        199178669      10240      19451.0     18464.0     17664      28512       2021.3  _attn_bwd_preprocess
  0.1        178310444      33130       5382.1      2240.0      1215     223393      19263.9  zeros_kernel
  0.1        177752928      20222       8790.1      7776.0      6976      17824       1846.9  triton_poi_fused_add_0
  0.1        172528842      10240      16848.5     15968.0     15616      24640       1773.1  add_func_kernel_rank_4
  0.0        129051736        960     134428.9      4832.0      1984     397346     185091.0  void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
  0.0        128656001        320     402050.0    402481.0    388192     410145       3474.4  true_div_func_kernel_rank_3
  0.0        128654046        320     402043.9    401905.0    389121     415617       4864.4  sub_func_kernel_rank_3
  0.0        126241110        320     394503.5    394496.0    391328     397889       1096.0  mul_func_kernel_rank_3
  0.0        117415251        320     366922.7    366832.0    361025     371265       1334.3  void at::native::vectorized_elementwise_kernel<(int)4, at::native::exp_kernel_cuda(at::TensorIterat…
  0.0        115360242      20216       5706.4      5039.5      4544      12288       1565.5  cos_func_kernel_rank_1
  0.0         90851159      20216       4494.0      3872.0      3680      11552       1483.1  sin_func_kernel_rank_1
  0.0         69295147      26100       2655.0      2080.0      1376      12416       1516.5  l2_norm_kernel_2
  0.0         59334903        320     185421.6    185407.5    180864     191808       1901.7  max_kernel
  0.0         58171124        320     181784.8    181728.5    179232     184672        969.6  sum_dim_kernel_inner
  0.0         51856723        634      81792.9     80448.0     80064      88064       2535.4  masked_fill_kernel_kernel_rank_3
  0.0         19589908      12010       1631.1      1568.0      1312       2560        203.9  sub_func_scalar_tensor_kernel_rank_1
  0.0          8699435       3840       2265.5      2304.0      2080       2560         60.4  nonzero_kernel
  0.0          8570665        314      27295.1     26336.0     23616      39840       3406.7  embedding_kernel
  0.0          7695050       3040       2531.3      1472.0      1344      10496       2085.6  isinf_func_kernel_rank_1
  0.0          7239914       3040       2381.6      2016.0      1536       8096        994.7  isnan_func_kernel_rank_1
  0.0          6594691        640      10304.2     10416.0      9280      11616        636.8  void at::native::index_elementwise_kernel<(int)128, (int)4, void at::native::gpu_index_kernel<void …
  0.0          5971307       3840       1555.0      1568.0      1439       2048         53.2  reduce_then_scan_root_scan_kernel_row
  0.0          5650567       3840       1471.5      1472.0      1344       1728         44.1  gt_func_scalar_kernel_rank_1
  0.0          2415396        320       7548.1      7552.0      7328       7712         57.3  void at::native::index_elementwise_kernel<(int)128, (int)4, void at::native::gpu_index_kernel<void …
  0.0          2041922        960       2127.0      2240.0      1376       2784        336.3  sum_kernel_1
  0.0          2033763        954       2131.8      2144.0      1952       2464         57.3  masked_fill_kernel_kernel_rank_1
  0.0          1907392        634       3008.5      2688.0      2272       8000       1012.8  lt_func_scalar_kernel_rank_1
  0.0          1642143        960       1710.6      1728.0      1376       1984         90.8  sum_kernel_2
  0.0          1623040        634       2560.0      2144.0      2080       6752       1133.9  ge_func_scalar_kernel_rank_1
  0.0          1422273        634       2243.3      1952.0      1856       6144        893.7  bitwise_or_func_kernel_rank_1
  0.0          1356771        640       2120.0      2112.0      1952       2464        116.1  sub_func_kernel_rank_1
  0.0          1327523        634       2093.9      2080.0      1952       2272         47.7  sub_func_tensor_scalar_kernel_rank_1
  0.0          1131074        640       1767.3      1760.0      1728       1824         28.5  arange_func
  0.0           754400        320       2357.5      2336.0      2016       2816        114.7  void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
  0.0           736416        320       2301.3      2304.0      2240       2496         51.9  log_func_kernel_rank_1
  0.0           478783        320       1496.2      1504.0      1344       1536         29.7  clamp_func_min_kernel_rank_1
  0.0           437536        320       1367.3      1376.0      1280       1440         48.7  ones_kernel
  0.0           182208         80       2277.6      2272.0      2112       2464         78.2  vstack_kernel
  0.0           182080         80       2276.0      2336.0      1824       2688        203.8  pow_func_tensor_scalar_kernel_rank_1

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant changes related to integrating a 'FlagEngine' and 'Flag Gems' into the training framework, particularly within the TransformerEngine backend. It includes new configuration files, modifications to Megatron-LM components, and new implementations for attention, GEMM, and RMSNorm operations. While the intent appears to be performance optimization and feature expansion, several critical and high-severity issues have been identified. These include hardcoded paths, potential security risks from changing submodule sources, strong coupling of new features, and most importantly, known correctness issues in the newly introduced gems_rms_norm implementation, which is then used in other core components. The hardcoding of use_gems_flash_attention = True also raises concerns about flexibility and potential regressions. Addressing these issues is crucial for the stability, correctness, and maintainability of the codebase.

+ # d_out is expected to be in FP8 if is_output_fp8=True,
+ # but in the case it's not, convert it to FP8 before any operation
+ assert not ctx.fp8, "Gems Flash Attention do not support fp8 now"
+ assert not ctx.use_FAv2_bwd, "do not support use other flash attention kernels now"

critical

The assertion assert not ctx.fp8, "Gems Flash Attention do not support fp8 now" in the backward pass reiterates the critical limitation regarding FP8 support. This is a major drawback for a component within TransformerEngine, which is designed to leverage FP8 for performance.

.gitmodules Outdated
Comment on lines 28 to 30
[submodule "third_party/TransformerEngine"]
path = third_party/TransformerEngine
url = https://gitee.com/lxdcumt/TransformerEngine.git

critical

Similar to the Megatron-LM change, using a personal Gitee fork for TransformerEngine introduces potential security and maintainability risks. Relying on unofficial forks can lead to issues with updates, bug fixes, and overall code integrity. Please ensure this fork is actively maintained and secure, or consider using the official repository if possible.

Comment on lines 47 to 67
+ # TODO(lixianduo): polish
# Compute RMSNorm backward pass
- dx, dw = rmsnorm_bwd(
+ # dx, dw = rmsnorm_bwd(
+ # dy,
+ # x,
+ # rstdevs,
+ # w,
+ # self._sm_margins["backward"],
+ # self.zero_centered_gamma,
+ # )
+
+ dx, dw = rms_norm_backward(
dy,
x,
rstdevs,
+ [x.shape[-1]],
w,
- self._sm_margins["backward"],
- self.zero_centered_gamma,
+ self.eps,

critical

The commented-out rmsnorm_bwd call and its replacement with rms_norm_backward from gems_rms_norm is a critical issue. As noted in gems_rms_norm.py, rms_norm_backward is explicitly marked as 'wrong' and 'needs to be fixed'. Using a known incorrect backward pass will lead to training instability and incorrect model updates.
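A candidate backward implementation can be sanity-checked against a minimal NumPy reference plus a finite-difference probe. The sketch below is illustrative only; the function names are hypothetical and this is not the `gems_rms_norm` API. It assumes the usual RMSNorm definition y = x · rstd · w with rstd = 1/√(mean(x²) + eps):

```python
import numpy as np

def rms_norm_fwd_ref(x, w, eps=1e-6):
    """Reference RMSNorm forward: y = x * rstd * w, over the last axis."""
    rstd = 1.0 / np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    return x * rstd * w, rstd

def rms_norm_bwd_ref(dy, x, rstd, w):
    """Reference RMSNorm backward for the forward above.

    dx_i = rstd * (dy_i*w_i - x_i * rstd^2 * mean_j(dy_j*w_j*x_j))
    dw_j = sum over rows of dy_j * x_j * rstd
    """
    xhat = x * rstd                      # normalized input
    wdy = dy * w
    dx = rstd * (wdy - xhat * rstd * (wdy * x).mean(axis=-1, keepdims=True))
    dw = (dy * xhat).sum(axis=0)
    return dx, dw
```

Comparing a replacement kernel's `dx`/`dw` against this reference (and against a central finite difference of the forward) is a cheap way to catch a broken backward before it destabilizes training.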

Comment on lines 127 to 145
+ # TODO(lixianduo): polish
+ # dgrad, dgamma = tex.rmsnorm_bwd(
+ # dgrad,
+ # inputmat,
+ # rsigma,
+ # ln_weight,
+ # ctx.bwd_ln_sm_margin,
+ # ctx.zero_centered_gamma,
+ # )
+
+ dgrad, dgamma = rms_norm_backward(
dgrad,
inputmat,
rsigma,
+ [inputmat.shape[-1]],
ln_weight,
- ctx.bwd_ln_sm_margin,
- ctx.zero_centered_gamma,
+ ctx.eps,

critical

The commented-out tex.rmsnorm_bwd call and its replacement with rms_norm_backward from gems_rms_norm is a critical issue. As noted in gems_rms_norm.py, rms_norm_backward is explicitly marked as 'wrong' and 'needs to be fixed'. Using a known incorrect backward pass will lead to training instability and incorrect model updates.

Comment on lines 23 to 35
+ # TODO(lixianduo): polish
+ # ln_out, mu, rsigma = apply_normalization(
+ # inputmat,
+ # None, # ln_out
+ # ln_weight,
+ # ln_bias,
+ # eps,
+ # input_quantizer if with_quantized_norm else None,
+ # inputmat.dtype,
+ # normalization,
+ # fwd_ln_sm_margin,
+ # zero_centered_gamma,
+ # )

critical

The commented-out apply_normalization call and its replacement with rms_norm_forward (from gems_rms_norm) is concerning, especially since rms_norm_backward in gems_rms_norm.py is marked as 'wrong'. This direct replacement without a fully functional backward pass could lead to incorrect gradients and training failures. The TODO(lixianduo): polish also indicates incomplete work.

Comment on lines 64 to 67
- dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+ # TODO(lixianduo): polish
+ # dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
+ dw, db, *_ = gems_general_gemm(x, dy, **wgrad_gemm_kwargs)

high

The commented-out general_gemm call and its replacement with gems_general_gemm for the wgrad GEMM operation carries the same risks as noted for dgrad. Consistency and correctness of gems_general_gemm are crucial here.
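The wgrad GEMM has a backend-independent reference that any replacement should satisfy: for y = x @ w with upstream gradient dy, the weight gradient is dw = xᵀ @ dy and the bias gradient is dy summed over rows. A minimal NumPy check (a sketch under that assumed layout; the TE call's actual weight layout depends on `wgrad_gemm_kwargs`):

```python
import numpy as np

def wgrad_ref(x, dy):
    """Reference gradients for y = x @ w + b.

    x:  (rows, in_features), dy: (rows, out_features)
    """
    dw = x.T @ dy          # (in_features, out_features)
    db = dy.sum(axis=0)    # (out_features,)
    return dw, db
```

Running `gems_general_gemm` and this reference on the same random inputs and asserting closeness is a quick regression guard for the swapped-in kernel.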

Comment on lines 19 to 29
+ # TODO(lixianduo): polish
+ # y, _, rstdevs = rmsnorm_fwd(
+ # x,
+ # w,
+ # self.eps,
+ # None,
+ # next_op_input_quantizer,
+ # TE_DType[dtype],
+ # sm_margin,
+ # self.zero_centered_gamma,
+ # )

high

The commented-out rmsnorm_fwd call and its replacement with rms_norm_forward from gems_rms_norm is concerning. While rms_norm_forward itself might be correct, the overall gems_rms_norm module has a known issue with its backward pass. The TODO(lixianduo): polish also indicates incomplete work.

data_path: /share/project/lixianduo/demo_data/pile/pile_wikipedia_demo
split: 1
no_mmap_bin_files: true
tokenizer:

high

The data_path is hardcoded to a specific user's directory (/share/project/lixianduo/demo_data/pile/pile_wikipedia_demo). This makes the configuration non-portable and difficult to use in different environments. It should be made configurable (e.g., using environment variables or relative paths) or moved to a more generic location.
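One common pattern is to resolve the path from an environment variable with a repo-relative default. The sketch below is illustrative; the variable name `DATA_PATH` and the default are hypothetical, not part of this PR:

```python
import os

def resolve_data_path(default="data/pile_wikipedia_demo"):
    """Prefer an environment override; fall back to a repo-relative default."""
    return os.environ.get("DATA_PATH", default)
```

The config file would then reference the resolved value instead of an absolute per-user directory.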

host, master_addr, master_port, nnodes, node_rank, nproc_per_node, self.config
)

nsys_cmd = "/share/project/lixianduo/envs/nsys/nsight-system/2025.5.1/bin/nsys profile -s none -t nvtx,cuda,osrt -o /share/project/lixianduo/scale_gems_cx/nsys_reps/$HOSTNAME.nsys-rep --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop".split(" ")

high

The nsys_cmd contains a hardcoded path /share/project/lixianduo/envs/nsys/nsight-system/2025.5.1/bin/nsys. This path is specific to a particular environment and Nsight Systems installation. It should be made configurable (e.g., via an environment variable or a configuration parameter) to ensure portability across different systems.
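The profiler binary can be located portably via an explicit override or the system PATH. A sketch (the env-var name `NSYS_PATH` is illustrative, not an existing convention of this codebase):

```python
import os
import shutil

def find_nsys():
    """Locate the nsys binary: env override, then PATH, then fail loudly."""
    override = os.environ.get("NSYS_PATH")
    if override:
        return override
    found = shutil.which("nsys")
    if found:
        return found
    raise FileNotFoundError("nsys not found; set NSYS_PATH or add nsys to PATH")
```

The rest of `nsys_cmd` (flags, output path) could similarly come from configuration rather than a user-specific absolute path.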

print(f"use flag_gems for training")
print(f"{args.flag_gems_log_path=}")
print(f"{args.flag_gems_unused=}")
flag_gems.enable(record=True, once=True, unused=args.flag_gems_unused, path=args.flag_gems_log_path)

high

The flag_gems.enable call uses once=True. If flag_gems is intended to be enabled for the entire training run, once=True might prevent it from being re-enabled if the training loop restarts or if there are multiple phases where it needs to be active. Please clarify if this is the intended behavior or if once=False should be used for continuous enablement.

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@lxd-cumt lxd-cumt changed the title [Train] FlagEngine: Initialize FlagEngine [Train] Initialize TransformerEngine-FlagOS Nov 21, 2025
@lxd-cumt lxd-cumt changed the title [Train] Initialize TransformerEngine-FlagOS [Train] Refactor TransformerEngine for FlagOS Nov 21, 2025
@lxd-cumt lxd-cumt changed the title [Train] Refactor TransformerEngine for FlagOS [FlagOS][Train] Refactor TransformerEngine for FlagOS Nov 21, 2025
@lxd-cumt lxd-cumt changed the title [FlagOS][Train] Refactor TransformerEngine for FlagOS [FlagOS][Train] Integrate TransformerEngine-FL to Megatron-LM for unified training backend Nov 25, 2025