Skip to content

Commit d163cf0

Browse files
committed
Implement auxiliary-loss-free load balancing
1 parent 1176601 commit d163cf0

File tree

3 files changed

+4
-6
lines changed

3 files changed

+4
-6
lines changed

llm/model_config/DeepSeek-V3/config.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
1111
},
1212
"aux_loss_alpha": 0.001,
13+
"aux_loss_free_gamma": 0.001,
1314
"bos_token_id": 0,
1415
"eos_token_id": 1,
1516
"ep_size": 1,

llm/run_pretrain.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,8 @@ def main():
573573
callbacks = [StepFlexToken(), FP8QuantWeightCallback()]
574574

575575
if getattr(config, "topk_method", None) == "noaux_tc":
576-
callbacks += [MoECorrectionBiasAdjustCallback()]
576+
aux_loss_free_gamma = getattr(config, "aux_loss_free_gamma", 0.0)
577+
callbacks += [MoECorrectionBiasAdjustCallback(aux_loss_free_gamma)]
577578

578579
trainer = PretrainingTrainer(
579580
model=model,

paddlenlp/trainer/trainer_callback.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ def get_stat(layer):
700700

701701
model.apply(get_stat)
702702

703-
usages_tensor = paddle.stack(usages, 0) # [num_layers, num_local_experts]
703+
usages_tensor = paddle.stack(usages, 0) # [num_layers, num_experts]
704704
if not hasattr(fleet, "_hcg"):
705705
dist.all_reduce(usages_tensor)
706706
return
@@ -721,10 +721,6 @@ def get_stat(layer):
721721
update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
722722
update_list = list(update)
723723

724-
print('on_optimizer_end bias:', [bias.tolist() for bias in biases])
725-
print('on_optimizer_end usage:', usages_tensor.tolist())
726-
print('on_optimizer_end update:', update.tolist())
727-
728724
def update_bias(layer):
729725
if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
730726
with paddle.no_grad():

0 commit comments

Comments (0)