From c0d3d52957f20b304d0b93144b63dc1fce22a72c Mon Sep 17 00:00:00 2001 From: lxdcumt <1141051934@qq.com> Date: Fri, 28 Nov 2025 15:17:47 +0800 Subject: [PATCH 1/2] support te_fl --- .../qwen3/conf/train/32b_te_cx_gems_nsys.yaml | 98 ++++++++++++++++ .../qwen3/conf/train_te_cx_gems_nsys.yaml | 36 ++++++ .../core/tensor_parallel/layers.py.patch | 106 ++++++++++++++++++ .../transformer/transformer_config.py.patch | 16 ++- .../megatron/training/arguments.py.patch | 82 +++++++++----- .../megatron/training/checkpointing.py.patch | 26 +++-- flagscale/runner/runner_train.py | 12 +- flagscale/train/device_wrapper.py | 79 +++++++++++++ flagscale/train/train.py | 24 ++++ 9 files changed, 437 insertions(+), 42 deletions(-) create mode 100644 examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml create mode 100644 examples/qwen3/conf/train_te_cx_gems_nsys.yaml create mode 100644 flagscale/backends/Megatron-LM/megatron/core/tensor_parallel/layers.py.patch create mode 100644 flagscale/train/device_wrapper.py diff --git a/examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml b/examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml new file mode 100644 index 0000000000..ac78c0595c --- /dev/null +++ b/examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml @@ -0,0 +1,98 @@ +system: + no_shared_fs: ${experiment.runner.no_shared_fs} + num_workers: 2 + tensor_model_parallel_size: 8 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: true + overlap_grad_reduce: true + overlap_param_gather: true + # profiling + profile: true + profile_step_start: 5 + profile_step_end: 15 + profile_ranks: [0,7,8,15] + precision: + bf16: true + attention_softmax_in_fp32: true + accumulate_allreduce_grads_in_fp32: true + logging: + log_interval: 1 + tensorboard_log_interval: 1 + wandb_project: ${experiment.exp_name} + wandb_exp_name: ${experiment.exp_name} + log_timers_to_tensorboard: true + log_validation_ppl_to_tensorboard: true + log_throughput: true + log_params_norm: true + log_num_zeros_in_grad: true + log_memory_to_tensorboard: true + checkpoint: + save_interval: ${experiment.save_steps} + load: ${experiment.load} + ckpt_format: ${experiment.ckpt_format} + +model: + transformer_impl: transformer_engine + enable_transformer_engine_fl: true + flag_gems_log_path: /share/project/lixianduo/scale_gems_cx/gems_te_cx_gems_log + flag_gems_unused: ['index_put', 'index_put_'] + num_layers: 4 + hidden_size: 5120 + ffn_hidden_size: 25600 + num_attention_heads: 64 + kv_channels: 128 + group_query_attention: true + num_query_groups: 8 + seq_length: 4096 + max_position_embeddings: 40960 + norm_epsilon: 1e-6 + use_rotary_position_embeddings: true + rotary_base: 1000000 + swiglu: true + normalization: RMSNorm + qk_layernorm: true + init_method_std: 0.02 + attention_dropout: 0.0 + hidden_dropout: 0.0 + untie_embeddings_and_output_weights: true + no_position_embedding: true + no_rope_fusion: true + disable_bias_linear: true + + # training + seed: ${experiment.seed} + finetune: false + micro_batch_size: 1 + global_batch_size: 8 #2048 + eval_iters: 0 + train_iters: 102400 + + optimizer: + clip_grad: 1.0 + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + lr_scheduler: + lr: 3.0e-3 + min_lr: 3.0e-4 + lr_warmup_fraction: 0.1 + lr_decay_style: WSD + lr_wsd_decay_style: cosine + lr_wsd_decay_iters: 10 + +data: + reset_position_ids: True + reset_attention_mask: True + # data_path: /share/project/lixianduo/demo_data/pile/pile_wikipedia_demo + data_path: 
/share/project/lizhiyu/hetero_data/HQ_wo_fim/Nemotron-CC-high-actual-actual-high_text_document + split: 1 + no_mmap_bin_files: true + tokenizer: + legacy_tokenizer: true + tokenizer_type: QwenTokenizerFS + tokenizer_path: /share/project/lixianduo/qwentokenizer + vocab_size: 151851 + make_vocab_size_divisible_by: 64 diff --git a/examples/qwen3/conf/train_te_cx_gems_nsys.yaml b/examples/qwen3/conf/train_te_cx_gems_nsys.yaml new file mode 100644 index 0000000000..10e4e8c0e8 --- /dev/null +++ b/examples/qwen3/conf/train_te_cx_gems_nsys.yaml @@ -0,0 +1,36 @@ +defaults: + - _self_ + # - train: 30b_a3b + - train: 32b_te_cx_gems_nsys + +experiment: + # exp_name: Qwen3-30b-a3b-Train + exp_name: Qwen3-32b-Train + seed: 42 + save_steps: 10000 + load: None + exp_dir: /share/project/lixianduo/scale_gems_cx/experiments_te_cx_gems_nsys/${experiment.exp_name} + ckpt_format: torch + task: + type: train + backend: megatron + entrypoint: flagscale/train/train_gpt.py + runner: + per_node_task: false + no_shared_fs: false + rdzv_backend: static + hostfile: /share/project/lixianduo/scale_gems_cx/1node_hostfile + ssh_port: 7878 + cmds: + before_start: ulimit -n 1048576 && source /root/miniconda3/bin/activate /share/project/lixianduo/envs/flagscale-train-gems-cx + envs: + LOGLEVEL: "INFO" + CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7" + CUDA_DEVICE_MAX_CONNECTIONS: 1 + USE_NSYS_PROFILE: True + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/flagscale/backends/Megatron-LM/megatron/core/tensor_parallel/layers.py.patch b/flagscale/backends/Megatron-LM/megatron/core/tensor_parallel/layers.py.patch new file mode 100644 index 0000000000..bb1d2abdbf --- /dev/null +++ b/flagscale/backends/Megatron-LM/megatron/core/tensor_parallel/layers.py.patch @@ -0,0 +1,106 @@ +diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py +index 5c6f34b70..669a7e394 100644 +--- a/megatron/core/tensor_parallel/layers.py ++++ b/megatron/core/tensor_parallel/layers.py +@@ -448,6 +448,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + grad_output_buffer, + wgrad_deferral_limit, + tp_group, ++ use_transformer_engine_fl, + ): + """Forward.""" + if gradient_accumulation_fusion and hasattr(weight, "main_grad"): +@@ -466,6 +467,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + ctx.wgrad_deferral_limit = wgrad_deferral_limit + ctx.grad_output_buffer = grad_output_buffer + ctx.tp_group = tp_group ++ ctx.use_transformer_engine_fl = use_transformer_engine_fl + + if sequence_parallel: + dim_size = list(input.size()) +@@ -556,16 +558,23 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + if hasattr(weight, "__fsdp_param__"): + weight.main_grad = weight.get_main_grad() + +- if weight.main_grad.dtype == torch.float32: +- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( +- total_input, grad_output, weight.main_grad +- ) +- elif weight.main_grad.dtype in (torch.float16, torch.bfloat16): +- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( +- total_input, grad_output, weight.main_grad +- ) ++ if not ctx.use_transformer_engine_fl: ++ if weight.main_grad.dtype == torch.float32: ++ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( ++ total_input, grad_output, weight.main_grad ++ ) ++ elif weight.main_grad.dtype in (torch.float16, torch.bfloat16): ++ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( ++ total_input, grad_output, weight.main_grad ++ ) ++ else: ++ raise 
RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: +- raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") ++ if weight.main_grad.dtype in (torch.float32, torch.float16, torch.bfloat16): ++ grad_weight = torch.matmul(grad_output.t(), total_input) ++ weight.main_grad += grad_weight.view_as(weight.main_grad) ++ else: ++ raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + + if hasattr(weight, "grad_added_to_main_grad"): + # When overlap_grad_reduce is True, need to ensure that backward hooks +@@ -607,12 +616,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): + handle.wait() + # Need to return None's as gradient has to flow for all the input arguments + # provided during forward +- return (sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None) ++ return (sub_grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None) + + if ctx.allreduce_dgrad: + handle.wait() + +- return grad_input, grad_weight, grad_bias, None, None, None, None, None, None ++ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None + + + def linear_with_grad_accumulation_and_async_allreduce( +@@ -626,6 +635,7 @@ def linear_with_grad_accumulation_and_async_allreduce( + wgrad_deferral_limit: Optional[int] = 0, + async_grad_allreduce: Optional[bool] = None, + tp_group: Optional[torch.distributed.ProcessGroup] = None, ++ use_transformer_engine_fl: Optional[bool] = False, + ) -> torch.Tensor: + """Linear layer execution with asynchronous communication and + gradient accumulation fusion in backprop. +@@ -711,6 +721,7 @@ def linear_with_grad_accumulation_and_async_allreduce( + grad_output_buffer, + wgrad_deferral_limit, + tp_group, ++ use_transformer_engine_fl, + ] + + if not linear_with_grad_accumulation_and_async_allreduce.warned: +@@ -807,6 +818,7 @@ class ColumnParallelLinear(torch.nn.Module): + tp_group: Optional[torch.distributed.ProcessGroup] = None, + ): + super(ColumnParallelLinear, self).__init__() ++ print(f"[ColumnParallelLinear], {config.transformer_impl=}") + + # Keep input parameters + self.input_size = input_size +@@ -938,6 +950,8 @@ class ColumnParallelLinear(torch.nn.Module): + if not weight.requires_grad: + return linear_with_frozen_weight(input, weight, *args, **kwargs) + else: ++ if self.config.enable_transformer_engine_fl: ++ kwargs['use_transformer_engine_fl'] = True + return linear_with_grad_accumulation_and_async_allreduce(input, weight, *args, **kwargs) + + def forward( +@@ -1298,3 +1312,4 @@ class RowParallelLinear(torch.nn.Module): + f"{type(self).__name__}(in_features={self.input_size}, " + f"out_features={self.output_size}, bias={use_bias}, TP={tp})" + ) ++ diff --git a/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch b/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch index 2e1b69b884..83bc10506c 100644 --- a/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index 5ff62f74c..4bea0c328 100644 +index 5ff62f74c..02b469933 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -317,6 +317,15 @@ class 
TransformerConfig(ModelParallelConfig): @@ -41,7 +41,7 @@ index 5ff62f74c..4bea0c328 100644 flash_decode: bool = False """ Use the optimized flash decoding kernel during inference. """ -@@ -705,6 +723,26 @@ class TransformerConfig(ModelParallelConfig): +@@ -705,6 +723,31 @@ class TransformerConfig(ModelParallelConfig): """Transformer implementation to use. Options are 'transformer_engine' for Transformer Engine and 'local' for MCore.""" @@ -64,11 +64,16 @@ index 5ff62f74c..4bea0c328 100644 + """Lora a init method""" + lora_out_init_method: Optional[str] = None + """Lora b init method""" ++ ++ #################### ++ # TransformerEngine-FL ++ #################### ++ enable_transformer_engine_fl: Optional[bool] = False + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more -@@ -1481,6 +1519,9 @@ class TransformerConfig(ModelParallelConfig): +@@ -1481,6 +1524,9 @@ class TransformerConfig(ModelParallelConfig): f"the number of layers ({self.num_layers})" ) @@ -78,3 +83,8 @@ index 5ff62f74c..4bea0c328 100644 @dataclass class MLATransformerConfig(TransformerConfig): +@@ -1569,3 +1615,4 @@ class MLATransformerConfig(TransformerConfig): + assert ( + self.apply_rope_fusion is False + ), "Rope Fusion is not compatible with caching latents" ++ diff --git a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch index 6e946d062a..c0405d0445 100644 --- a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 1120c7529..190fac52b 100644 +index 1120c7529..99f7d7f44 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -67,6 +67,7 @@ def add_megatron_arguments(parser: argparse.ArgumentParser): @@ -10,7 +10,7 @@ index 1120c7529..190fac52b 100644 parser = _add_moe_args(parser) parser = _add_mla_args(parser) parser = _add_heterogeneous_args(parser) -@@ -94,6 +95,10 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): +@@ -94,6 +95,11 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): allow_abbrev=False) parser = add_megatron_arguments(parser) @@ -18,10 +18,11 @@ index 1120c7529..190fac52b 100644 + parser = _add_auto_tuner_args(parser) + parser = _add_auto_skip_spiky_loss(parser) + parser = _add_peft_args(parser) ++ parser = _add_transformer_engine_fl_args(parser) # Custom arguments. if extra_args_provider is not None: -@@ -368,63 +373,68 @@ def validate_args(args, defaults={}): +@@ -368,63 +374,68 @@ def validate_args(args, defaults={}): "legacy model format only supports the 'torch' checkpoint format." update_use_dist_ckpt(args) @@ -141,7 +142,7 @@ index 1120c7529..190fac52b 100644 if args.hierarchical_context_parallel_sizes: from numpy import prod -@@ -433,8 +443,8 @@ def validate_args(args, defaults={}): +@@ -433,8 +444,8 @@ def validate_args(args, defaults={}): assert args.hierarchical_context_parallel_sizes is not None, \ "--hierarchical-context-parallel-sizes must be set when a2a+p2p is used in cp comm" @@ -152,7 +153,7 @@ index 1120c7529..190fac52b 100644 # Deprecated arguments. 
assert args.batch_size is None, '--batch-size argument is no longer ' \ -@@ -530,6 +540,7 @@ def validate_args(args, defaults={}): +@@ -530,6 +541,7 @@ def validate_args(args, defaults={}): if args.virtual_pipeline_model_parallel_size == 1: args.virtual_pipeline_model_parallel_size = None elif args.num_layers_per_virtual_pipeline_stage is not None or args.num_virtual_stages_per_pipeline_rank is not None: @@ -160,7 +161,7 @@ index 1120c7529..190fac52b 100644 if args.num_virtual_stages_per_pipeline_rank is None: assert args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None, \ 'please use --num-virtual-stages-per-pipeline-rank to specify virtual pipeline parallel degree when enable uneven pipeline parallelism' -@@ -571,8 +582,9 @@ def validate_args(args, defaults={}): +@@ -571,8 +583,9 @@ def validate_args(args, defaults={}): if args.account_for_loss_in_pipeline_split: num_layers += 1 @@ -172,7 +173,7 @@ index 1120c7529..190fac52b 100644 if args.virtual_pipeline_model_parallel_size is not None: if args.overlap_p2p_comm: -@@ -796,12 +808,22 @@ def validate_args(args, defaults={}): +@@ -796,12 +809,22 @@ def validate_args(args, defaults={}): # Checks. if args.ffn_hidden_size is None: if args.swiglu: @@ -201,7 +202,7 @@ index 1120c7529..190fac52b 100644 else: args.ffn_hidden_size = 4 * args.hidden_size -@@ -1175,6 +1197,141 @@ def validate_args(args, defaults={}): +@@ -1175,6 +1198,147 @@ def validate_args(args, defaults={}): args.recompute_granularity != 'full' ), 'recompute_granularity must not be full when CUDA Graphs are enabled.' @@ -339,11 +340,17 @@ index 1120c7529..190fac52b 100644 + assert args.num_experts is None, "PEFT is not tested with MoE currently" + assert args.recompute_method is None and args.recompute_granularity is None and args.recompute_num_layers is None, "PEFT will raise comfilcts with recompute currently" + assert args.ckpt_format == 'torch', "PEFT is only tested with torch format checkpoint" ++ ++ if args.enable_transformer_engine_fl: ++ assert args.context_parallel_size == 1, "DotProductAttention in FlagEngine, do not support context parallel now" ++ assert not args.tp_comm_overlap, "Do not support gemm/sp comm in FlagEngine now" ++ args.distributed_backend = 'flagcx' ++ args.use_flag_gems = True + # Print arguments. _print_args("arguments", args) -@@ -1585,6 +1742,8 @@ def _add_network_size_args(parser): +@@ -1585,6 +1749,8 @@ def _add_network_size_args(parser): help='Which normalization technique to use.') group.add_argument('--norm-epsilon', type=float, default=1e-5, help='Epsilon for layer norm and RMS norm.') @@ -352,7 +359,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--apply-layernorm-1p', action='store_true', help='Adjust LayerNorm weights such that they are centered ' 'around zero. This improves numerical stability.') -@@ -1608,6 +1767,10 @@ def _add_network_size_args(parser): +@@ -1608,6 +1774,10 @@ def _add_network_size_args(parser): group.add_argument('--glu-linear-offset', type=float, default=0.0, help='Offset term in the GLU activation function: activation_func(x[0]) * (x[1] + offset). 
' 'Only used when gated_linear_unit is True') @@ -363,7 +370,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--onnx-safe', type=bool, required=False, help='Use workarounds for known problems with ' 'Torch ONNX exporter') -@@ -1820,6 +1983,14 @@ def _add_logging_args(parser): +@@ -1820,6 +1990,14 @@ def _add_logging_args(parser): help='The wandb experiment name.') group.add_argument('--wandb-save-dir', type=str, default='', help='Path to save the wandb results locally.') @@ -378,7 +385,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--logging-level', type=int, default=None, help='Set default logging level') return parser -@@ -2001,6 +2172,25 @@ def _add_training_args(parser): +@@ -2001,6 +2179,25 @@ def _add_training_args(parser): '"shared_experts": recompute the shared experts in the MoE layer.' '"moe_act", "layernorm", and "mla_up_proj" use output-discarding checkpointing, ' '"core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing.') @@ -404,7 +411,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false', help='If not set, clone the output of the scatter in embedding layer to GC original tensor.', dest='clone_scatter_output_in_embedding') -@@ -2087,6 +2277,10 @@ def _add_training_args(parser): +@@ -2087,6 +2284,10 @@ def _add_training_args(parser): help='Total number of samples to train over all ' 'training runs. Note that either train-iters or ' 'train-samples should be provided.') @@ -415,7 +422,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--exit-interval', type=int, default=None, -@@ -2210,6 +2404,10 @@ def _add_training_args(parser): +@@ -2210,6 +2411,10 @@ def _add_training_args(parser): help='The communicator group names to use high priority streams.') group.add_argument('--use-te-activation-func', action='store_true', help='Use activation function kernel from Transformer Engine in MLP module.') @@ -426,7 +433,7 @@ index 1120c7529..190fac52b 100644 return parser -@@ -2268,11 +2466,26 @@ def _add_learning_rate_args(parser): +@@ -2268,11 +2473,26 @@ def _add_learning_rate_args(parser): 'and initial warmup, the learning rate at each ' 'iteration would be different.') group.add_argument('--lr-decay-style', type=str, default='linear', @@ -454,7 +461,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--lr-decay-iters', type=int, default=None, help='number of iterations to decay learning rate over,' ' If None defaults to `--train-iters`') -@@ -2331,6 +2544,8 @@ def _add_checkpointing_args(parser): +@@ -2331,6 +2551,8 @@ def _add_checkpointing_args(parser): group.add_argument('--save-retain-interval', type=int, default=None, help='Number of iterations between retained checkpoints (other' 'checkpoints _except the last checkpoint_ are automatically deleted).') @@ -463,7 +470,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--no-save-optim', action='store_true', default=None, help='Do not save current optimizer.') group.add_argument('--no-save-rng', action='store_true', default=None, -@@ -2380,6 +2595,8 @@ def _add_checkpointing_args(parser): +@@ -2380,6 +2602,8 @@ def _add_checkpointing_args(parser): group.add_argument('--no-use-tokenizer-model-from-checkpoint-args', action='store_false', dest='use_tokenizer_model_from_checkpoint_args', help='If set, do not use tokenizer model path from checkpoint') @@ -472,7 +479,7 @@ index 1120c7529..190fac52b 100644 
group.add_argument('--exit-on-missing-checkpoint', action='store_true', help="If '--load' is set, but checkpoint is not found " "(e.g., path typo), then exit instead of random " -@@ -2541,7 +2758,7 @@ def _add_distributed_args(parser): +@@ -2541,7 +2765,7 @@ def _add_distributed_args(parser): default=False, help='if set, overlap pipeline parallel communication in warmup and flush', dest='overlap_p2p_comm_warmup_flush') group.add_argument('--distributed-backend', default='nccl', @@ -481,7 +488,7 @@ index 1120c7529..190fac52b 100644 help='Which backend to use for distributed training.') group.add_argument('--distributed-timeout-minutes', type=int, default=10, help='Timeout minutes for torch.distributed.') -@@ -2592,6 +2809,11 @@ def _add_distributed_args(parser): +@@ -2592,6 +2816,11 @@ def _add_distributed_args(parser): 'complete it instead. Also turns on ' '--use-cpu-initialization flag. This is for ' 'external DDP manager.' ) @@ -493,7 +500,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--account-for-embedding-in-pipeline-split', action='store_true', default=False, help='If set, *input* embedding layer will be treated as a standard transformer' 'layer in the context of partition and placement for pipeline parallelism.') -@@ -2636,6 +2858,10 @@ def _add_distributed_args(parser): +@@ -2636,6 +2865,10 @@ def _add_distributed_args(parser): help='If set, keep the fp8 transpose cache when using Megatron FSDP.') group.add_argument('--enable-full-sharding-in-hsdp', action='store_true', help='If set, enable full sharding in megatron-fsdp Hybrid Sharded Data Parallel (HSDP) mode.') @@ -504,7 +511,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--num-distributed-optimizer-instances', type=int, default=1, help='Number of Distributed Optimizer copies across Data Parallel domain.') group.add_argument('--use-torch-fsdp2', action='store_true', -@@ -2690,6 +2916,9 @@ def _add_validation_args(parser): +@@ -2690,6 +2923,9 @@ def _add_validation_args(parser): group.add_argument('--eval-interval', type=int, default=1000, help='Interval between running evaluation on ' 'validation set.') @@ -514,7 +521,7 @@ index 1120c7529..190fac52b 100644 group.add_argument("--test-mode", action="store_true", help='Run all real-time test alongside the experiment.') group.add_argument('--skip-train', action='store_true', default=False, help='If set, bypass the training loop, ' -@@ -2708,6 +2937,8 @@ def _add_tokenizer_args(parser): +@@ -2708,6 +2944,8 @@ def _add_tokenizer_args(parser): 'automatically calculated from vocab-size.') group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file.') @@ -523,7 +530,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--merge-file', type=str, default=None, help='Path to the BPE merge file.') group.add_argument('--vocab-extra-ids', type=int, default=0, -@@ -2726,8 +2957,17 @@ def _add_tokenizer_args(parser): +@@ -2726,8 +2964,17 @@ def _add_tokenizer_args(parser): 'MultimodalTokenizer', 'NullTokenizer', 'NullMultimodalTokenizer', @@ -542,7 +549,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--tokenizer-model', type=str, default=None, help='Sentencepiece tokenizer model.') group.add_argument('--tokenizer-metadata', type=str, default=None, -@@ -2768,6 +3008,11 @@ def _add_data_args(parser): +@@ -2768,6 +3015,11 @@ def _add_data_args(parser): group.add_argument('--valid-data-path', nargs='*', default=None, help='The weight and prefix list for an independent validation dataset. 
' 'Follows the same pattern rules as --data-path.') @@ -554,7 +561,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--test-data-path', nargs='*', default=None, help='The weight and prefix list for an independent test dataset. ' 'Follows the same pattern rules as --data-path.') -@@ -2816,11 +3061,18 @@ def _add_data_args(parser): +@@ -2816,11 +3068,18 @@ def _add_data_args(parser): 'end-of-document token.') group.add_argument('--eod-mask-loss', action='store_true', help='Mask loss for the end of document tokens.') @@ -573,7 +580,7 @@ index 1120c7529..190fac52b 100644 group.add_argument('--object-storage-cache-path', type=str, default=None, help='Path to cache index files when using s3 or msc dataloader') group.add_argument('--mid-level-dataset-surplus', type=float, default=0.005, -@@ -2897,6 +3149,19 @@ def _add_biencoder_args(parser): +@@ -2897,6 +3156,19 @@ def _add_biencoder_args(parser): return parser @@ -593,7 +600,7 @@ index 1120c7529..190fac52b 100644 def _add_vision_args(parser): group = parser.add_argument_group(title="vision") -@@ -2967,6 +3232,8 @@ def _add_vision_args(parser): +@@ -2967,6 +3239,8 @@ def _add_vision_args(parser): help='Whether to layer normalize the q and k attention embeddings.') group.add_argument('--qk-l2-norm', action='store_true', help='Use llama 4 qk l2 norm') @@ -602,7 +609,7 @@ index 1120c7529..190fac52b 100644 return parser -@@ -3275,3 +3542,75 @@ def _add_sft_args(parser): +@@ -3275,3 +3549,94 @@ def _add_sft_args(parser): group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') return parser @@ -677,4 +684,23 @@ index 1120c7529..190fac52b 100644 + choices=['normal', 'kaiming', 'xavier', 'zero'], + help='Init method of lora b') + return parser ++ ++def _add_transformer_engine_fl_args(parser): ++ group = parser.add_argument_group(title="flag engine") ++ group.add_argument('--enable-transformer-engine-fl', action='store_true', ++ help='Enable transformer engine fl for training') ++ group.add_argument('--use-flag-gems', action='store_true', ++ help='Enable flag gems for training') ++ group.add_argument('--flag-gems-log-path', type=str, default=None, ++ help='Path of flag gems logging') ++ group.add_argument('-use-flag-engine-optimizer', action='store_true', ++ help='Use FlagEngine.FusedAdam for training') ++ group.add_argument( ++ '--flag-gems-unused', ++ nargs='*', ++ default=None, ++ help='Flag Gems unused ops list' ++ ) ++ return parser +########## FlagScale End ########## ++ diff --git a/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch b/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch index a541bbdc84..8e1c68997d 100644 --- a/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch @@ -1,8 +1,8 @@ diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py -index 104fa6882..1c501cc1b 100644 +index 104fa6882..722859bf6 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py -@@ -286,12 +286,14 @@ def read_metadata(tracker_filename): +@@ -286,12 +286,15 @@ def read_metadata(tracker_filename): print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( tracker_filename)) sys.exit() @@ -15,11 +15,12 @@ index 104fa6882..1c501cc1b 100644 # Get the max iteration retrieved across the ranks. 
if torch.distributed.is_initialized(): - iters_cuda = torch.tensor([iteration], dtype=torch.long, device='cuda') -+ iters_cuda = torch.tensor([iteration], dtype=torch.long, device='cuda' if 'nccl' in torch.distributed.get_backend() else 'cpu') ++ # iters_cuda = torch.tensor([iteration], dtype=torch.long, device='cuda' if 'nccl' in torch.distributed.get_backend() else 'cpu') ++ iters_cuda = torch.tensor([iteration], dtype=torch.long, device='cuda' if torch.distributed.get_backend() != 'gloo' else 'cpu') torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) max_iter = iters_cuda[0].item() -@@ -692,6 +694,28 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati +@@ -692,6 +695,28 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati if not torch.distributed.is_initialized() \ or is_last_rank(): def wandb_finalize_fn(): @@ -48,7 +49,7 @@ index 104fa6882..1c501cc1b 100644 wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration) if args.async_save: assert async_save_request is not None -@@ -774,9 +798,7 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path) +@@ -774,9 +799,7 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path) torch.distributed.barrier(group=mpu.get_data_parallel_group()) @@ -59,7 +60,7 @@ index 104fa6882..1c501cc1b 100644 torch.distributed.barrier(group=mpu.get_data_parallel_group()) dataloader_save_dict = {} -@@ -1239,6 +1261,10 @@ def load_args_from_checkpoint( +@@ -1239,6 +1262,10 @@ def load_args_from_checkpoint( checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear') ) @@ -70,7 +71,7 @@ index 104fa6882..1c501cc1b 100644 def _set_arg(arg_name, old_arg_name=None, force=False): if not force and getattr(args, arg_name, None) is not None: return -@@ -1274,6 +1300,8 @@ def load_args_from_checkpoint( +@@ -1274,6 +1301,8 @@ def load_args_from_checkpoint( _set_arg('add_qkv_bias', force=True) _set_arg('squared_relu', force=True) _set_arg('swiglu', force=True) @@ -79,7 +80,7 @@ index 104fa6882..1c501cc1b 100644 _set_arg('untie_embeddings_and_output_weights', force=True) _set_arg('apply_layernorm_1p', force=True) _set_arg('normalization', force=True) -@@ -1432,6 +1460,14 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', +@@ -1432,6 +1461,14 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', mismatch_msg = "(TP, PP) mismatch after resume ({} vs {} from checkpoint)".format( run_tp_pp, ckpt_tp_pp ) @@ -94,7 +95,7 @@ index 104fa6882..1c501cc1b 100644 # Determine if RNG state will be loaded if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng -@@ -1468,6 +1504,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', +@@ -1468,6 +1505,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', ckpt_tp_pp != run_tp_pp and sharded_sd_metadata['distrib_optim_sharding_type'] not in DistributedOptimizer.checkpoint_fully_reshardable_formats @@ -102,7 +103,7 @@ index 104fa6882..1c501cc1b 100644 ): raise RuntimeError(f"{mismatch_msg}: not supported for DistributedOptimizer with sharding type" f" {sharded_sd_metadata['distrib_optim_sharding_type']}." 
-@@ -1481,7 +1518,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', +@@ -1481,7 +1519,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', gen_sd_optim = None gen_sd_opt_param_scheduler = None @@ -111,3 +112,8 @@ index 104fa6882..1c501cc1b 100644 model_sd_kwargs = dict(metadata=sharded_sd_metadata) # Determine if rerun state will be loaded +@@ -1829,3 +1867,4 @@ def load_biencoder_checkpoint(model, only_query_model=False, + print(' successfully loaded {}'.format(checkpoint_name)) + + return model ++ diff --git a/flagscale/runner/runner_train.py b/flagscale/runner/runner_train.py index cbd2c0d45d..61f1def4e8 100644 --- a/flagscale/runner/runner_train.py +++ b/flagscale/runner/runner_train.py @@ -269,7 +269,7 @@ def _generate_run_script_train( f.write(f"\n") f.write(f"cd {root_dir}\n") f.write(f"\n") - f.write(f"export PYTHONPATH={root_dir}:{megatron_dir}:${{PYTHONPATH}}\n") + f.write(f"export PYTHONPATH={megatron_dir}:{root_dir}:${{PYTHONPATH}}\n") f.write(f"\n") f.write(f'cmd="{cmd}"\n') f.write(f"\n") @@ -432,6 +432,16 @@ def _run_each( runner_cmd = _get_runner_cmd_train( host, master_addr, master_port, nnodes, node_rank, nproc_per_node, self.config ) + + nsys_cmd = "/share/project/lixianduo/envs/nsys/nsight-system/2025.5.1/bin/nsys profile -s none -t nvtx,cuda,osrt -o /share/project/lixianduo/scale_gems_cx/nsys_reps/$HOSTNAME.nsys-rep --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop".split( + " " + ) + if "USE_NSYS_PROFILE" in self.user_envs.keys(): + runner_cmd = nsys_cmd + runner_cmd + print(f"use nsys profile") + else: + print(f"not use nsys profile") + # update hetero-current-device-type according to the device_type in hostfile if device_type is not None: if "--hetero-current-device-type" in self.user_args: diff --git a/flagscale/train/device_wrapper.py b/flagscale/train/device_wrapper.py new file mode 100644 index 0000000000..2266a2fbce --- /dev/null +++ b/flagscale/train/device_wrapper.py @@ -0,0 +1,79 @@ +import functools +import logging + +from typing import Any, Callable + +import torch + +logging.basicConfig(level=logging.INFO, format='[CUDA-LOG] %(message)s') +logger = logging.getLogger(__name__) + + +CUDA_TO_NEW_MAPPING = { + # Memory + 'memory_allocated': 'memory_allocated', + 'max_memory_allocated': 'max_memory_allocated', + 'empty_cache': 'empty_cache', + 'reset_peak_memory_stats': 'reset_peak_memory_stats', + 'reset_max_memory_allocated': 'reset_max_memory_allocated', + # Device + 'device_count': 'device_count', + 'current_device': 'current_device', + 'get_device_name': 'get_device_name', + 'get_device_properties': 'get_device_properties', + # Stream + 'Stream': 'Stream', + 'current_stream': 'current_stream', + 'default_stream': 'default_stream', + 'stream': 'stream', + 'synchronize': 'synchronize', + # Event + 'Event': 'Event', + # AMP + 'amp': 'amp', + # Random + 'manual_seed': 'manual_seed', + 'manual_seed_all': 'manual_seed_all', +} + + +def create_cuda_to_new_function(new_device, name: str) -> Any: + if name in CUDA_TO_NEW_MAPPING: + new_attr_name = CUDA_TO_NEW_MAPPING[name] + else: + new_attr_name = name + + if hasattr(new_device, new_attr_name): + return getattr(new_device, new_attr_name) + + return None + + +class NewDevice: + def __init__(self, _original_cuda, _new_device): + self._original_cuda = _original_cuda + self._new_device = _new_device + self._patched_attrs = {} + + def __getattr__(self, name: str) -> Any: + if name in self._patched_attrs: + return 
self._patched_attrs[name] + new_attr = create_cuda_to_new_function(self._new_device, name) + if new_attr is not None: + print(f"[AUTO] Redirecting torch.cuda.{name} -> torch.new.{name}") + self._patched_attrs[name] = new_attr + return new_attr + else: + print(f"[FALLBACK] torch.new.{name} not in ,using original torch.cuda.{name}") + orig_attr = getattr(self._original_cuda, name) + self._patched_attrs[name] = orig_attr + return orig_attr + + +def patch_cuda_to_new_device(): + if not hasattr(torch, 'cuda'): + raise RuntimeError("torch.cuda not found") + _original_cuda = torch.cuda + _new_device = torch.cuda + torch.cuda = NewDevice(_original_cuda, _new_device) + return True diff --git a/flagscale/train/train.py b/flagscale/train/train.py index e02e2890e5..5f7ba3b0e2 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -144,6 +144,15 @@ from flagscale.train.peft.peft import PEFT from flagscale.train.peft.lora import LoRA +try: + import flag_gems + HAVE_GEMS = True +except ImportError: + HAVE_GEMS = False + +# wrapper for torch.cuda.xxx +from device_wrapper import patch_cuda_to_new_device + def destroy_global_state(): destroy_global_vars() destroy_num_microbatches_calculator() @@ -798,6 +807,14 @@ def pretrain( args = get_args() timers = get_timers() + ###### FlagScale Begin ###### + args = get_args() + if args.enable_transformer_engine_fl: + os.environ['USE_TRANSFORMER_ENGINE_FL'] = "True" + print(f"[TransformerEngineFL], apply device patching") + patch_cuda_to_new_device() + ###### FlagScale End ###### + if args.log_progress: append_to_progress_log("Starting job") @@ -2546,6 +2563,13 @@ def get_e2e_base_metrics(): ) cuda_graph_helper.create_cudagraphs() + # enable flag_gems for transformer_engine_fl + if args.use_flag_gems: + if not HAVE_GEMS: + raise ValueError(f"Can not import flag gems") + else: + flag_gems.enable(record=True, once=True, unused=args.flag_gems_unused, path=args.flag_gems_log_path) + # Run training iterations till done. 
buffered_rollouts = None while iteration < args.train_iters: From 2264cfc5fcb6bed6520d34a21288d0ea80323900 Mon Sep 17 00:00:00 2001 From: lxdcumt <1141051934@qq.com> Date: Fri, 28 Nov 2025 15:58:09 +0800 Subject: [PATCH 2/2] polish --- examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml | 2 +- .../megatron/core/tensor_parallel/layers.py.patch | 4 ++-- .../megatron/core/transformer/transformer_config.py.patch | 4 ++-- .../Megatron-LM/megatron/training/arguments.py.patch | 6 +++--- flagscale/train/train.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml b/examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml index ac78c0595c..4e82d20833 100644 --- a/examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml +++ b/examples/qwen3/conf/train/32b_te_cx_gems_nsys.yaml @@ -36,7 +36,7 @@ system: model: transformer_impl: transformer_engine - enable_transformer_engine_fl: true + use_transformer_engine_fl: true flag_gems_log_path: /share/project/lixianduo/scale_gems_cx/gems_te_cx_gems_log flag_gems_unused: ['index_put', 'index_put_'] num_layers: 4 diff --git a/flagscale/backends/Megatron-LM/megatron/core/tensor_parallel/layers.py.patch b/flagscale/backends/Megatron-LM/megatron/core/tensor_parallel/layers.py.patch index bb1d2abdbf..8ac388dcbf 100644 --- a/flagscale/backends/Megatron-LM/megatron/core/tensor_parallel/layers.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/core/tensor_parallel/layers.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py -index 5c6f34b70..669a7e394 100644 +index 5c6f34b70..ae07dc556 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -448,6 +448,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): @@ -94,7 +94,7 @@ index 5c6f34b70..669a7e394 100644 if not weight.requires_grad: return linear_with_frozen_weight(input, weight, *args, **kwargs) else: -+ if self.config.enable_transformer_engine_fl: ++ if self.config.use_transformer_engine_fl: + kwargs['use_transformer_engine_fl'] = True return linear_with_grad_accumulation_and_async_allreduce(input, weight, *args, **kwargs) diff --git a/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch b/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch index 83bc10506c..f947218e3c 100644 --- a/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index 5ff62f74c..02b469933 100644 +index 5ff62f74c..3c0571b2d 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -317,6 +317,15 @@ class TransformerConfig(ModelParallelConfig): @@ -68,7 +68,7 @@ index 5ff62f74c..02b469933 100644 + #################### + # TransformerEngine-FL + #################### -+ enable_transformer_engine_fl: Optional[bool] = False ++ use_transformer_engine_fl: Optional[bool] = False + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. 
diff --git a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch index c0405d0445..7df0590d6e 100644 --- a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 1120c7529..99f7d7f44 100644 +index 1120c7529..aad514a3c 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -67,6 +67,7 @@ def add_megatron_arguments(parser: argparse.ArgumentParser): @@ -341,7 +341,7 @@ index 1120c7529..99f7d7f44 100644 + assert args.recompute_method is None and args.recompute_granularity is None and args.recompute_num_layers is None, "PEFT will raise comfilcts with recompute currently" + assert args.ckpt_format == 'torch', "PEFT is only tested with torch format checkpoint" + -+ if args.enable_transformer_engine_fl: ++ if args.use_transformer_engine_fl: + assert args.context_parallel_size == 1, "DotProductAttention in FlagEngine, do not support context parallel now" + assert not args.tp_comm_overlap, "Do not support gemm/sp comm in FlagEngine now" + args.distributed_backend = 'flagcx' @@ -687,7 +687,7 @@ index 1120c7529..99f7d7f44 100644 + +def _add_transformer_engine_fl_args(parser): + group = parser.add_argument_group(title="flag engine") -+ group.add_argument('--enable-transformer-engine-fl', action='store_true', ++ group.add_argument('--use-transformer-engine-fl', action='store_true', + help='Enable transformer engine fl for training') + group.add_argument('--use-flag-gems', action='store_true', + help='Enable flag gems for training') diff --git a/flagscale/train/train.py b/flagscale/train/train.py index 5f7ba3b0e2..bdf5a379dc 100644 --- a/flagscale/train/train.py +++ b/flagscale/train/train.py @@ -809,7 +809,7 @@ def pretrain( ###### FlagScale Begin ###### args = get_args() - if args.enable_transformer_engine_fl: + if args.use_transformer_engine_fl: os.environ['USE_TRANSFORMER_ENGINE_FL'] = "True" print(f"[TransformerEngineFL], apply device patching") patch_cuda_to_new_device()
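
The unfused weight-gradient branch added in layers.py.patch (taken when use_transformer_engine_fl is set, bypassing fused_weight_gradient_mlp_cuda) accumulates main_grad as grad_output^T @ total_input. Below is a minimal standalone sketch of that accumulation, with purely illustrative sizes and 2D (already flattened) activations; it is not part of the patches above.

import torch

# Illustrative sizes only: 16 tokens, in_features=64, out_features=128.
total_input = torch.randn(16, 64)    # activations saved in forward, flattened to 2D
grad_output = torch.randn(16, 128)   # gradient w.r.t. the linear layer output
main_grad = torch.zeros(128, 64)     # weight.main_grad, shape [out_features, in_features]

# Unfused accumulation path from the patch:
main_grad += torch.matmul(grad_output.t(), total_input).view_as(main_grad)

# The same quantity written as an einsum, for comparison.
reference = torch.einsum('to,ti->oi', grad_output, total_input)
assert torch.allclose(main_grad, reference, atol=1e-5)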