Commit 626b586

Implement auxiliary-loss-free load balancing
1 parent 03cf701

File tree: 5 files changed, +73 -2 lines
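In this scheme, each MoE gate accumulates a per-expert usage counter during the forward pass, and a trainer callback periodically nudges the gate's e_score_correction_bias toward balanced usage instead of relying only on the auxiliary loss: experts that received more tokens than the layer average have their bias lowered by aux_loss_free_gamma, under-used experts have it raised. A minimal sketch of one adjustment step (toy numbers; only the sign-update rule from the callback below is shown):

import paddle

gamma = 0.001  # aux_loss_free_gamma from config.json
usage = paddle.to_tensor([120.0, 80.0, 100.0, 100.0])  # toy per-expert token counts for one layer

# Over-used experts (above the layer mean) get their bias lowered, under-used ones raised.
update = paddle.sign(usage.mean() - usage) * gamma
print(update.numpy())  # -> [-0.001, 0.001, 0., 0.]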

llm/model_config/DeepSeek-V3/config.json

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@
     "AutoModel": "modeling_deepseek.DeepseekV3Model",
     "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
   },
-  "aux_loss_alpha": 0.001,
+  "aux_loss_alpha": 0.0001,
+  "aux_loss_free_gamma": 0.001,
   "bos_token_id": 0,
   "eos_token_id": 1,
   "ep_size": 1,

llm/run_pretrain.py

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,7 @@
 )
 from paddlenlp.trainer import (
     FP8QuantWeightCallback,
+    MoECorrectionBiasAdjustCallback,
     PdArgumentParser,
     StepFlexToken,
     Trainer,
@@ -571,6 +572,10 @@ def main():
 
     callbacks = [StepFlexToken(), FP8QuantWeightCallback()]
 
+    if getattr(config, "topk_method", None) == "noaux_tc":
+        aux_loss_free_gamma = getattr(config, "aux_loss_free_gamma", 0.001)
+        callbacks += [MoECorrectionBiasAdjustCallback(aux_loss_free_gamma)]
+
     trainer = PretrainingTrainer(
         model=model,
         args=training_args,
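The callback is only registered when the config selects the no-aux top-k routing path (topk_method == "noaux_tc"), and the update rate falls back to 0.001 if aux_loss_free_gamma is missing. The same wiring as a self-contained helper (illustrative; the import path follows the diff above):

from paddlenlp.trainer import MoECorrectionBiasAdjustCallback

def maybe_add_bias_adjust_callback(config, callbacks):
    # Only MoE models routed with "noaux_tc" carry an e_score_correction_bias to adjust.
    if getattr(config, "topk_method", None) == "noaux_tc":
        gamma = getattr(config, "aux_loss_free_gamma", 0.001)  # same default as run_pretrain.py
        callbacks.append(MoECorrectionBiasAdjustCallback(gamma))
    return callbacks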

paddlenlp/trainer/trainer_callback.py

Lines changed: 58 additions & 0 deletions
@@ -27,6 +27,11 @@
 import numpy as np
 from tqdm.auto import tqdm
 
+import paddle
+import paddle.distributed as dist
+from paddle.distributed.fleet import fleet
+
+from paddlenlp.transformers.moe_gate import PretrainedMoEGate
 from paddlenlp.transformers.moe_utils import offload, reload
 from paddlenlp.utils.log import logger
 
@@ -44,6 +49,7 @@
     "EarlyStoppingCallback",
     "StepFlexToken",
     "FP8QuantWeightCallback",
+    "MoECorrectionBiasAdjustCallback",
 ]
 
 
@@ -671,3 +677,55 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
         if (not g_shard_bypass_dygraph_optimizer) and hasattr(model, "fp8_quant_weight"):
             for name in self.moe_weights_name:
                 reload(optimizer._master_weights[name])
+
+
+class MoECorrectionBiasAdjustCallback(TrainerCallback):
+    """used for moe aux loss free balance"""
+
+    def __init__(self, lr=0.001, use_mp=False):
+        super().__init__()
+        self.update_lr = lr
+        self.use_mp = use_mp
+
+    def on_optimizer_end(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+
+        biases = []
+        usages = []
+
+        def get_stat(layer):
+            if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
+                biases.append(layer.e_score_correction_bias)
+                usages.append(layer.expert_usage)
+
+        model.apply(get_stat)
+
+        usages_tensor = paddle.stack(usages, 0)  # [num_layers, num_experts]
+        if not hasattr(fleet, "_hcg"):
+            dist.all_reduce(usages_tensor)
+            return
+
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dp_group = hcg.get_data_parallel_group()
+        sd_group = hcg.get_sharding_parallel_group()
+
+        if self.use_mp and mp_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=mp_group)
+        if dp_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=dp_group)
+        if sd_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=sd_group)
+
+        usages_mean = usages_tensor.mean(-1, keepdim=True)
+        update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
+        update_list = list(update)
+
+        def update_bias(layer):
+            if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
+                with paddle.no_grad():
+                    if not layer.weight.stop_gradient:
+                        biases.pop(0).add_(update_list.pop(0))
+                    usages.pop(0).zero_()
+
+        model.apply(update_bias)
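The callback runs after each optimizer step: it collects every gate's bias and usage counter via model.apply, all-reduces the stacked counts across the data-parallel and sharding groups (and optionally the tensor-parallel group) so all ranks agree, applies the sign update scaled by lr, and zeroes the counters for the next interval. A toy single-process simulation (hypothetical routing model, no distributed groups) showing the rule driving usage toward balance:

import paddle

paddle.seed(0)
num_experts, tokens, gamma = 8, 1024, 0.001
gate_skew = paddle.randn([num_experts]) * 0.1   # toy gate preference, deliberately imbalanced
bias = paddle.zeros([num_experts])

for _ in range(1000):
    scores = paddle.randn([tokens, num_experts]) * 0.05 + gate_skew
    # The bias only influences which expert is chosen, as in _topk_noaux_tc.
    choice = paddle.argmax(scores + bias, axis=-1)
    usage = paddle.bincount(choice, minlength=num_experts).astype("float32")
    # MoECorrectionBiasAdjustCallback's rule: push over-used experts down, under-used up.
    bias += paddle.sign(usage.mean() - usage) * gamma

print(usage.numpy())  # counts end up close to tokens / num_experts = 128 per expert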

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 7 additions & 0 deletions
@@ -925,6 +925,11 @@ def __init__(
             default_initializer=nn.initializer.Constant(0.0),
         )
         self.e_score_correction_bias.is_distributed = True
+        self.expert_usage = paddle.zeros(
+            shape=[num_experts],
+            dtype=paddle.int64,
+        )
+        self.expert_usage.stop_gradient = True
 
         if self.using_post_norm_recompute:
             assert norm_weight is not None and norm_eps is not None
@@ -970,6 +975,8 @@ def forward(self, hidden_states):
             scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(
                 scores
             )  # (scores, routing_map, exp_counts, l_aux, l_zloss)
+            with paddle.no_grad():
+                self.expert_usage += exp_counts
             ret = (scores, routing_map, l_aux, l_zloss)
         else:
             ret = self.topkgating(scores)  # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss)
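The gate now owns an integer buffer expert_usage of shape [num_experts] that is excluded from autograd and incremented by exp_counts after every forward pass; MoECorrectionBiasAdjustCallback reads and zeroes it after each optimizer step. A small sketch of how such per-expert counts follow from a boolean routing map (the exact shape and dtype of exp_counts are assumptions here, not taken from the diff):

import paddle

# Hypothetical shapes: routing_map[t, e] is True when token t is dispatched to expert e.
num_experts = 4
routing_map = paddle.to_tensor(
    [[1, 0, 0, 1],
     [0, 1, 0, 1],
     [1, 0, 1, 0],
     [0, 0, 1, 1],
     [1, 1, 0, 0],
     [0, 1, 0, 1]], dtype="bool")

exp_counts = routing_map.astype("int64").sum(axis=0)          # tokens routed to each expert
expert_usage = paddle.zeros([num_experts], dtype="int64")
expert_usage += exp_counts                                    # accumulated per forward, reset by the callback
print(expert_usage.numpy())                                   # [3, 3, 2, 4]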

paddlenlp/transformers/moe_gate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def _topk_noaux_tc(
301301
assert n_experts % n_group == 0, "n_experts must be divisible by n_groups"
302302

303303
assert self.e_score_correction_bias is not None, "e_score_correction_bias is None"
304-
scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0)
304+
scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.detach().unsqueeze(0)
305305
reshape_tmp_rst = scores_for_choice.reshape([bsz_seq_len, self.n_group, -1])
306306
top_k = min(reshape_tmp_rst.shape[2], 2)
307307
group_scores = reshape_tmp_rst.topk(top_k, axis=-1)[0].sum(axis=-1) # fmt:skip [n, n_group]
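Detaching e_score_correction_bias before adding it to the scores means the bias still steers which experts are selected, but it never receives gradients from the training loss; its only update path is the callback's sign rule. A quick illustrative check of that behaviour:

import paddle

scores = paddle.rand([3, 4])
scores.stop_gradient = False
bias = paddle.zeros([4])
bias.stop_gradient = False  # even if the bias were marked trainable...

scores_for_choice = scores + bias.detach().unsqueeze(0)
scores_for_choice.sum().backward()
print(scores.grad is not None, bias.grad is None)  # True True: no gradient reaches the bias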
