From 997e8f99b46d2f3c87fb4df3ff160f3b2b7718b6 Mon Sep 17 00:00:00 2001
From: Liang Shuhao
Date: Mon, 8 Sep 2025 05:07:05 +0000
Subject: [PATCH] Implement auxiliary-loss-free load balancing

---
 llm/model_config/DeepSeek-V3/config.json  |  3 +-
 llm/run_pretrain.py                       |  5 ++
 paddlenlp/trainer/trainer_callback.py     | 62 +++++++++++++++++++
 .../transformers/deepseek_v2/modeling.py  |  7 +++
 paddlenlp/transformers/moe_gate.py        |  2 +-
 5 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/llm/model_config/DeepSeek-V3/config.json b/llm/model_config/DeepSeek-V3/config.json
index f17844eda4c9..631e42f781f4 100644
--- a/llm/model_config/DeepSeek-V3/config.json
+++ b/llm/model_config/DeepSeek-V3/config.json
@@ -9,7 +9,8 @@
     "AutoModel": "modeling_deepseek.DeepseekV3Model",
     "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
   },
-  "aux_loss_alpha": 0.001,
+  "aux_loss_alpha": 0.0001,
+  "aux_loss_free_gamma": 0.001,
   "bos_token_id": 0,
   "eos_token_id": 1,
   "ep_size": 1,
diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index 7b45b3165a35..1a622a4c1539 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -28,6 +28,7 @@
 )
 from paddlenlp.trainer import (
     FP8QuantWeightCallback,
+    MoECorrectionBiasAdjustCallback,
     PdArgumentParser,
     StepFlexToken,
     Trainer,
@@ -571,6 +572,10 @@ def main():
 
     callbacks = [StepFlexToken(), FP8QuantWeightCallback()]
 
+    if getattr(config, "topk_method", None) == "noaux_tc":
+        aux_loss_free_gamma = getattr(config, "aux_loss_free_gamma", 0.001)
+        callbacks += [MoECorrectionBiasAdjustCallback(aux_loss_free_gamma)]
+
     trainer = PretrainingTrainer(
         model=model,
         args=training_args,
diff --git a/paddlenlp/trainer/trainer_callback.py b/paddlenlp/trainer/trainer_callback.py
index 3c3fc6ed7550..0dd6b7fc02db 100644
--- a/paddlenlp/trainer/trainer_callback.py
+++ b/paddlenlp/trainer/trainer_callback.py
@@ -27,6 +27,11 @@
 import numpy as np
 from tqdm.auto import tqdm
 
+import paddle
+import paddle.distributed as dist
+from paddle.distributed.fleet import fleet
+
+from paddlenlp.transformers.moe_gate import PretrainedMoEGate
 from paddlenlp.transformers.moe_utils import offload, reload
 from paddlenlp.utils.log import logger
 
@@ -44,6 +49,7 @@
     "EarlyStoppingCallback",
     "StepFlexToken",
     "FP8QuantWeightCallback",
+    "MoECorrectionBiasAdjustCallback",
 ]
 
 
@@ -671,3 +677,59 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
         if (not g_shard_bypass_dygraph_optimizer) and hasattr(model, "fp8_quant_weight"):
             for name in self.moe_weights_name:
                 reload(optimizer._master_weights[name])
+
+
+class MoECorrectionBiasAdjustCallback(TrainerCallback):
+    """Adjust expert-score correction biases for MoE auxiliary-loss-free load balancing."""
+
+    def __init__(self, lr=0.001, use_mp=False):
+        super().__init__()
+        self.update_lr = lr
+        self.use_mp = use_mp
+
+    def on_optimizer_end(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+
+        biases = []
+        usages = []
+
+        def get_stat(layer):
+            if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
+                biases.append(layer.e_score_correction_bias)
+                usages.append(layer.expert_usage)
+
+        model.apply(get_stat)
+
+        usages_tensor = paddle.stack(usages, 0)  # [num_layers, num_experts]
+        if not hasattr(fleet, "_hcg"):
+            dist.all_reduce(usages_tensor)
+            return
+
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dp_group = hcg.get_data_parallel_group()
+        sd_group = hcg.get_sharding_parallel_group()
+
+        if self.use_mp and mp_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=mp_group)
+        if dp_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=dp_group)
+        if sd_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=sd_group)
+
+        usages_mean = usages_tensor.mean(-1, keepdim=True)
+        update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
+        update_list = list(update)
+
+        logger.debug(f"on_optimizer_end bias: {[bias.cast('float64').round(6).tolist() for bias in biases]}")
+        logger.debug(f"on_optimizer_end usage: {usages_tensor.tolist()}")
+        logger.debug(f"on_optimizer_end update: {update.cast('float64').round(6).tolist()}")
+
+        def update_bias(layer):
+            if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
+                with paddle.no_grad():
+                    if not layer.weight.stop_gradient:
+                        biases.pop(0).add_(update_list.pop(0))
+                    usages.pop(0).zero_()
+
+        model.apply(update_bias)
diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index 33fcd4520411..e737093ec8a3 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -925,6 +925,11 @@ def __init__(
                 default_initializer=nn.initializer.Constant(0.0),
             )
             self.e_score_correction_bias.is_distributed = True
+            self.expert_usage = paddle.zeros(
+                shape=[num_experts],
+                dtype=paddle.int64,
+            )
+            self.expert_usage.stop_gradient = True
 
         if self.using_post_norm_recompute:
             assert norm_weight is not None and norm_eps is not None
@@ -970,6 +975,8 @@ def forward(self, hidden_states):
             scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(
                 scores
             )  # (scores, routing_map, exp_counts, l_aux, l_zloss)
+            with paddle.no_grad():
+                self.expert_usage += exp_counts
             ret = (scores, routing_map, l_aux, l_zloss)
         else:
             ret = self.topkgating(scores)  # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss)
diff --git a/paddlenlp/transformers/moe_gate.py b/paddlenlp/transformers/moe_gate.py
index d1eb14abfa98..5507702fdcd8 100644
--- a/paddlenlp/transformers/moe_gate.py
+++ b/paddlenlp/transformers/moe_gate.py
@@ -301,7 +301,7 @@ def _topk_noaux_tc(
         assert n_experts % n_group == 0, "n_experts must be divisible by n_groups"
         assert self.e_score_correction_bias is not None, "e_score_correction_bias is None"
 
-        scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0)
+        scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.detach().unsqueeze(0)
         reshape_tmp_rst = scores_for_choice.reshape([bsz_seq_len, self.n_group, -1])
         top_k = min(reshape_tmp_rst.shape[2], 2)
         group_scores = reshape_tmp_rst.topk(top_k, axis=-1)[0].sum(axis=-1)  # fmt:skip [n, n_group]
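
For reference, the per-step adjustment made by MoECorrectionBiasAdjustCallback follows the auxiliary-loss-free strategy: experts that received fewer tokens than the layer average get their e_score_correction_bias raised by aux_loss_free_gamma, over-used experts get it lowered, and the bias only affects expert selection (it is detached in _topk_noaux_tc). A minimal standalone sketch of that rule on toy numbers, illustrative only and not part of the patch:

# Illustrative sketch: sign-based bias update used for aux-loss-free balancing.
import paddle

gamma = 0.001  # corresponds to aux_loss_free_gamma in config.json
# toy per-expert token counts for one MoE layer with 4 experts
expert_usage = paddle.to_tensor([120.0, 80.0, 100.0, 100.0])

# under-used experts move up, over-used experts move down
bias_update = paddle.sign(expert_usage.mean() - expert_usage) * gamma
print(bias_update.tolist())  # approx. [-0.001, 0.001, 0.0, 0.0]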