@@ -27,6 +27,11 @@
 import numpy as np
 from tqdm.auto import tqdm

+import paddle
+import paddle.distributed as dist
+from paddle.distributed.fleet import fleet
+
+from paddlenlp.transformers.moe_gate import PretrainedMoEGate
 from paddlenlp.transformers.moe_utils import offload, reload
 from paddlenlp.utils.log import logger

@@ -44,6 +49,7 @@
     "EarlyStoppingCallback",
     "StepFlexToken",
     "FP8QuantWeightCallback",
+    "MoECorrectionBiasAdjustCallback",
 ]


@@ -671,3 +677,55 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
         if (not g_shard_bypass_dygraph_optimizer) and hasattr(model, "fp8_quant_weight"):
             for name in self.moe_weights_name:
                 reload(optimizer._master_weights[name])
+
+
+class MoECorrectionBiasAdjustCallback(TrainerCallback):
+    """used for moe aux loss free balance"""
+
+    def __init__(self, lr=0.001, use_mp=False):
+        super().__init__()
+        self.update_lr = lr
+        self.use_mp = use_mp
+
+    def on_optimizer_end(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+
+        biases = []
+        usages = []
+
+        def get_stat(layer):
+            if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
+                biases.append(layer.e_score_correction_bias)
+                usages.append(layer.expert_usage)
+
+        model.apply(get_stat)
+
+        usages_tensor = paddle.stack(usages, 0)  # [num_layers, num_experts]
+        if not hasattr(fleet, "_hcg"):
+            dist.all_reduce(usages_tensor)
+            return
+
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        dp_group = hcg.get_data_parallel_group()
+        sd_group = hcg.get_sharding_parallel_group()
+
+        if self.use_mp and mp_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=mp_group)
+        if dp_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=dp_group)
+        if sd_group.nranks > 1:
+            dist.all_reduce(usages_tensor, group=sd_group)
+
+        usages_mean = usages_tensor.mean(-1, keepdim=True)
+        update = paddle.sign(usages_mean - usages_tensor) * self.update_lr
+        update_list = list(update)
+
+        def update_bias(layer):
+            if isinstance(layer, PretrainedMoEGate) and layer.topk_method == "noaux_tc":
+                with paddle.no_grad():
+                    if not layer.weight.stop_gradient:
+                        biases.pop(0).add_(update_list.pop(0))
+                    usages.pop(0).zero_()
+
+        model.apply(update_bias)
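
The added callback follows the auxiliary-loss-free load-balancing recipe: after each optimizer step it aggregates every noaux_tc gate's per-expert usage counters across the configured parallel groups, nudges e_score_correction_bias up by lr for experts used less than the layer average and down by lr for experts used more, then resets the counters. Below is a minimal standalone sketch of that sign-based update rule only (a toy illustration with made-up usage values, not part of this patch):

    import paddle

    lr = 0.001
    usage = paddle.to_tensor([[30.0, 20.0, 10.0]])  # one layer, three experts
    bias = paddle.zeros([3])                        # stands in for e_score_correction_bias

    mean = usage.mean(-1, keepdim=True)             # 20.0
    update = paddle.sign(mean - usage) * lr         # [[-0.001, 0.000, +0.001]]
    bias += update[0]                               # over-used expert 0 is pushed down,
                                                    # under-used expert 2 is pushed up

In use, the callback would be passed to the trainer like any other TrainerCallback, e.g. Trainer(..., callbacks=[MoECorrectionBiasAdjustCallback(lr=0.001)]); use_mp=True additionally reduces the usage counters over the tensor-parallel group, which matters only when those counters differ across tensor-parallel ranks (an assumption about the surrounding setup, not something this diff enforces).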