3 files changed: 4 additions, 6 deletions.

File 1 of 3 (DeepSeek-V3 model config, JSON):

@@ -10,6 +10,7 @@
     "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
   },
   "aux_loss_alpha": 0.001,
+  "aux_loss_free_gamma": 0.001,
   "bos_token_id": 0,
   "eos_token_id": 1,
   "ep_size": 1,
File 2 of 3 (pretraining entry script):

@@ -573,7 +573,8 @@ def main():
     callbacks = [StepFlexToken(), FP8QuantWeightCallback()]
 
     if getattr(config, "topk_method", None) == "noaux_tc":
-        callbacks += [MoECorrectionBiasAdjustCallback()]
+        aux_loss_free_gamma = getattr(config, "aux_loss_free_gamma", 0.0)
+        callbacks += [MoECorrectionBiasAdjustCallback(aux_loss_free_gamma)]
 
     trainer = PretrainingTrainer(
         model=model,
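
The callback's constructor is not shown in this diff, but the wiring above together with the self.update_lr usage in File 3 suggests the gamma becomes the per-step bias adjustment rate. A hypothetical sketch under that assumption (the base class, signature, and default value are guesses, not the repo's code):

from paddlenlp.trainer import TrainerCallback  # assumed base class


class MoECorrectionBiasAdjustCallback(TrainerCallback):
    """Nudges per-expert gate biases after each optimizer step (noaux_tc routing)."""

    def __init__(self, update_lr: float = 0.0):
        # aux_loss_free_gamma from the config is passed straight through here;
        # the 0.0 fallback used in main() would freeze the biases.
        self.update_lr = update_lr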
File 3 of 3 (MoE correction-bias adjustment logic):

@@ -700,7 +700,7 @@ def get_stat(layer):
 
         model.apply(get_stat)
 
-        usages_tensor = paddle.stack(usages, 0)  # [num_layers, num_local_experts]
+        usages_tensor = paddle.stack(usages, 0)  # [num_layers, num_experts]
         if not hasattr(fleet, "_hcg"):
            dist.all_reduce(usages_tensor)
            return

@@ -721,10 +721,6 @@ def get_stat(layer):
         update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
         update_list = list(update)
 
-        print('on_optimizer_end bias:', [bias.tolist() for bias in biases])
-        print('on_optimizer_end usage:', usages_tensor.tolist())
-        print('on_optimizer_end update:', update.tolist())
-
         def update_bias(layer):
             if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
                 with paddle.no_grad():
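
For context, the update visible in this hunk is the auxiliary-loss-free balancing step: each expert's gate bias is nudged by plus or minus update_lr depending on whether its usage is below or above the mean. A self-contained toy illustration, not the repo's code (the tensor values, shapes, and the per-layer mean axis are assumptions):

import paddle

update_lr = 0.001  # would come from aux_loss_free_gamma

# Toy token counts routed to 4 experts across 2 MoE layers: [num_layers, num_experts].
usages_tensor = paddle.to_tensor([[10.0, 2.0, 5.0, 7.0],
                                  [1.0, 9.0, 8.0, 2.0]])
usages_mean = usages_tensor.mean(axis=-1, keepdim=True)  # per-layer mean usage

# Under-used experts (usage < mean) get +update_lr, over-used ones get -update_lr,
# so routing rebalances over training steps without an auxiliary loss term.
update = paddle.sign(usages_mean - usages_tensor) * update_lr
print(update.tolist())  # signs: [[-1, +1, +1, -1], [+1, -1, -1, +1]], scaled by update_lr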