diff --git a/examples/pre-training/ernie/pretrain_auto.py b/examples/pre-training/ernie/pretrain_auto.py new file mode 100644 index 000000000..d871cd27b --- /dev/null +++ b/examples/pre-training/ernie/pretrain_auto.py @@ -0,0 +1,359 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import json +import numpy as np +import random +import paddle +import paddle.distributed.fleet as fleet +from src.utils import logger +from paddleformers.trainer import ( + PdArgumentParser, + get_last_checkpoint, +) +from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer +from omegaconf.listconfig import ListConfig +from omegaconf.dictconfig import DictConfig +from src.callbacks import ( + ProgreesiveBatchingCallback, + GlobalRNGCallback, +) +from models.ernie import ( + ErnieForCausalLMAuto, +) +from models.ernie_moe.configuration import ( + ErnieConfig, + ErnieMoEConfig, +) +from src.trainers import AutoPretrainingTrainer, AutoPreTrainingArguments +from src.utils import ( + setup_logger_output_file, +) +from src.utils.misc import global_training_logs +from pretrain import create_pretrained_dataset + + +from config import get_config + +try: + from paddleformers.trainer.trainer_utils import log_trainer_start +except ImportError: + + def log_trainer_start(): + """Print main process messgae""" + if "MAIN_PROCESS_STARTED" not in os.environ: + start_time = time.strftime("%Y-%m-%d %H:%M:%S", 
time.localtime()) + logger.info( + f"The Training Main Process Started Successfully. time: {start_time}, pid: {os.getpid()}" + ) + os.environ["MAIN_PROCESS_STARTED"] = "1" + + +log_trainer_start() + + +try: + from paddle.distributed.fleet import monitor_perf as collective_perf +except ImportError: + from paddle.distributed.fleet import collective_perf + + +assert paddle.version.mkl() == "OFF", ( + "MKL is not supported" + " in this version. Please set -DWITH_MKL=OFF when compiling PaddlePaddle." +) + + +def update_model_config_from_args(config: ErnieConfig, model_args: dict): + for k, v in model_args.items(): + if hasattr(config, k): + logger.info(f"update model config: {k} = {v}") + setattr(config, k, v) + else: + logger.warning(f"model config key: {k} does not exist") + return config + + +def init_parameter(model): + + for param in model.parameters(): + param.initialize() + + +def main(): + """Main function""" + config = get_config(verbose=True) + os.makedirs(config.model_args.output_dir, exist_ok=True) + parser = PdArgumentParser(AutoPreTrainingArguments) + if not hasattr(config.trainer_args, "pipeline_parallel_config"): + config.trainer_args.pipeline_parallel_config = "" + + if "enable_dp_comm_overlap" in config.trainer_args.pipeline_parallel_config: + logger.warning( + "Pipeline dp_comm_overlap and FusedLinearWithGradAdd can not be used at " + "the same time." 
+ ) + + if "enable_timer" in config.trainer_args.pipeline_parallel_config: + from paddle.distributed.fleet.meta_parallel.pipeline_parallel import ( + PipelineParallel, + ) + + PipelineParallel.timer_printer = lambda _: None + + def formatv(v): + if isinstance(v, ListConfig): + return list(v) + elif isinstance(v, DictConfig): + return dict(v) + return v + + model_args = {k: formatv(v) for k, v in dict(config.model_args).items()} + trainer_args = {k: formatv(v) for k, v in dict(config.trainer_args).items()} + (args,) = parser.parse_dict(dict(**model_args, **trainer_args)) + + if args.strategy.pipeline.enable and args.virtual_pp_degree > 1: + pipeline = args.strategy.pipeline + pipeline.vpp_degree = args.virtual_pp_degree + pipeline.vpp_seg_method = args.virtual_pipeline_seg_method + + if args.modality_ratio is not None: + args.modality_interleave = ( + sum(args.modality_ratio) + if args.modality_interleave == "acc" + else sum(args.modality_ratio) * args.gradient_accumulation_steps + ) + args.modality_ratio = [ + i / sum(args.modality_ratio) for i in args.modality_ratio + ] + + args.eval_iters = 10 + args.test_iters = args.eval_iters * 10 + + args.use_moe = dict(**dict(config.model_args), **dict(config.trainer_args)).get( + "use_moe", False + ) + model_config = dict(getattr(config.model_args, "model_config", {})) + model_config = {k: formatv(v) for k, v in model_config.items()} + logger.info(f"model_config_from_yaml: {json.dumps(model_config, indent=4)}") + setup_logger_output_file(config.model_args.output_dir, args.local_rank) + paddle.set_device(args.device) + + np.random.seed(args.seed) + random.seed(args.seed) + paddle.seed(args.seed) + # set_seed(args.seed) + + prop = paddle.device.cuda.get_device_properties() + if prop.total_memory < args.pre_alloc_memory * 1024 * 1024 * 1024: + logger.warning( + "Invalid value for `pre_alloc_memory`, so pre-allocating just failed." 
+ ) + elif args.pre_alloc_memory > 0: + logger.warning( + f"pre-allocating a tensor whose memory capacity is {args.pre_alloc_memory} GB " + "and then release it." + ) + memory_size = int(args.pre_alloc_memory * 1024 * 1024 * 1024) + x = paddle.empty([memory_size], dtype=paddle.uint8) + del x + + # add fleet test + try: + collective_perf( + "allgather", + round=50, + size_and_time={67108864: 0.00625, 234881024: 0.02, 637534208: 0.057}, + ) + logger.info("======monitor allgather done!=======\n") + collective_perf( + "allreduce", + round=50, + size_and_time={67108864: 0.02, 134217728: 0.038, 268435456: 0.075}, + ) + logger.info("======monitor allreduce done!=======\n") + except Exception as e: + logger.warning(f"fleet test unexcepted error! skip exception[{e}]...") + + # Detecting last checkpoint. + last_checkpoint = None + if ( + os.path.isdir(args.output_dir) + and args.do_train + and not args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(args.output_dir) + if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0: + raise ValueError( + f"Output directory ({args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Define the metrics of tasks. 
+ def compute_metrics(p): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + + output = paddle.to_tensor(preds) + labels = paddle.to_tensor(p.label_ids) + output = [t.astype("float32").cuda() for t in output] + labels = [t[t != tokenizer.ignored_index] for t in labels] + labels = [t.cuda() for t in labels] + all_numel = ( + (paddle.concat(labels, 0) != tokenizer.ignored_index).astype("int64").sum() + ) + ignored = (paddle.concat(labels, 0) == -100).astype("int64").sum() + labels = all_numel - ignored + output = sum(output) + logger.info(f"output : {output.item()}, labels : {labels.item()}") + nll_loss = output / (labels + 1.0e-6) # nll_loss is global loss + ppl = paddle.exp(nll_loss) + + return { + "nll_loss": nll_loss.item(), + "ppl": ppl.item(), + "num_token": labels.item(), + } + + # model + dtype = "float32" + if args.fp16 and args.fp16_opt_level == "O2": + paddle.set_default_dtype("float16") + dtype = "float16" + elif args.bf16: + paddle.set_default_dtype("bfloat16") + dtype = "bfloat16" + + if args.use_moe: + global ErnieConfig, ErnieForCausalLMAuto + ErnieConfig = ErnieMoEConfig + + if args.moe_group.lower() in {"mp", "tp", "model", "dummy"}: + logger.info(f"disable moe flag when using moe-group={args.moe_group}") + args.use_moe = False + + args.multi_token_pred_depth = model_config.get("multi_token_pred_depth", 0) + + cfg = ErnieConfig.from_pretrained(args.model_name_or_path) + cfg.seqlen = args.max_seq_length + cfg.token_balance_seqlen = args.max_seq_length * args.per_device_train_batch_size + cfg.fp16_opt_level = args.fp16_opt_level + cfg.moe_group = args.moe_group + cfg.dtype = dtype + cfg.pipeline_parallel_degree = args.pipeline_parallel_degree + cfg.virtual_pp_degree = args.virtual_pp_degree + if args.tensor_parallel_degree > 1: + cfg.sequence_parallel = args.sequence_parallel + cfg.tensor_parallel_degree = max( + fleet.get_hybrid_communicate_group().get_model_parallel_world_size(), 1 + ) + cfg.tensor_parallel_rank = 
max( + fleet.get_hybrid_communicate_group().get_model_parallel_rank(), 0 + ) + else: + cfg.sequence_parallel = False + cfg.tensor_parallel_degree = 1 + cfg.tensor_parallel_rank = 0 + + cfg.micro_batch_size = args.per_device_train_batch_size + tokenizer = ErnieBotTokenizer.from_pretrained(args.tokenizer_name) + tokenizer.ignored_index = cfg.ignored_index + logger.info( + f"using tokenizer={type(tokenizer)}, bos:{tokenizer.bos_token_id} " + f"eos:{tokenizer.eos_token_id} pad:{tokenizer.pad_token_id} " + ) + + cfg = update_model_config_from_args(cfg, model_config) + + if args.from_scratch: + with paddle.LazyGuard(): + model = ErnieForCausalLMAuto(cfg) + else: + with paddle.LazyGuard(): + model = ErnieForCausalLMAuto.from_pretrained( + args.model_name_or_path, + config=cfg, + ) + + cfg = model.config + logger.info(f"using model type:{type(model)}") + paddle.set_default_dtype("float32") + + logger.info(f"using model={type(model)}, cfg={cfg}") + + freeze_config = set(args.freeze_config.split(" ")) + if "freeze_vision" in freeze_config and hasattr(model, "freeze_vision"): + logger.info("Freeze model vision module") + model.freeze_vision() + + # data + logger.info("loading data...") + train_dataset, eval_dataset, test_dataset, data_collator = ( + create_pretrained_dataset(args) + ) + + callbacks = [] + callbacks += [GlobalRNGCallback()] + + if args.batch_size_warmup_steps: + progreesive_batcing_callback = ProgreesiveBatchingCallback( + args.gradient_accumulation_steps, + args.max_gradient_accumulation_steps, + args.batch_size_warmup_steps, + args.batch_size_warmup_increment, + ) + callbacks.append(progreesive_batcing_callback) + + init_parameter(model) + model.apply(model.init_weights) + trainer = AutoPretrainingTrainer( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + callbacks=callbacks, + ) + global_training_logs.accumulate = 
args.gradient_accumulation_steps + checkpoint = None + if args.resume_from_checkpoint is not None: + checkpoint = args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + trainer.save_model(args.output_dir) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluate and tests model + if args.do_eval: + eval_metrics = trainer.evaluate() + trainer.log_metrics("eval", eval_metrics) + + +if __name__ == "__main__": + main() diff --git a/examples/pre-training/ernie/src/callbacks/__init__.py b/examples/pre-training/ernie/src/callbacks/__init__.py index 15b0cb9f5..659ae6f19 100644 --- a/examples/pre-training/ernie/src/callbacks/__init__.py +++ b/examples/pre-training/ernie/src/callbacks/__init__.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging + +from .tensorboard_callback import TensorBoardCallback + from .gc_callback import GCCallback +from .progressive_batching_callback import ProgreesiveBatchingCallback from .logging_callback import LoggingCallback +from .stopper_callback import StopperCallback +from .adaptivegradclip_callback import ClipGradByAdaptiveNormCallback + from .moe_correction_bias_adjust_callback import MoECorrectionBiasAdjustCallback from .moe_logging_callback import GlobalRNGCallback, MoeLoggingCallback from .sp_grad_sync_callback import SPGradSyncCallback -from .tensorboard_callback import TensorBoardCallback from .fp8_quant_weight_callback import FP8QuantWeightCallback from .ortho_loss_callback import OrthogonalCallback @@ -31,4 +38,7 @@ "MoECorrectionBiasAdjustCallback", "FP8QuantWeightCallback", "OrthogonalCallback", + "ClipGradByAdaptiveNormCallback", + "StopperCallback", + "ProgreesiveBatchingCallback", ] diff --git a/examples/pre-training/ernie/src/callbacks/adaptivegradclip_callback.py b/examples/pre-training/ernie/src/callbacks/adaptivegradclip_callback.py new file mode 100644 index 000000000..f05e45000 --- /dev/null +++ b/examples/pre-training/ernie/src/callbacks/adaptivegradclip_callback.py @@ -0,0 +1,122 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+""" ClipGradByAdaptiveNormCallback """
+
+import os
+import paddle
+from paddleformers.trainer.trainer_callback import TrainerCallback
+from paddleformers.trainer.trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    get_last_checkpoint,
+)
+from src.utils import logger
+
+
+class ClipGradByAdaptiveNormCallback(TrainerCallback):
+    """
+    Load and save adaptive norm state hook, hack version
+    """
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        """
+        load adaptive norm state at the beginning of training.
+        """
+        optimizer = kwargs.get("optimizer", None)
+        assert optimizer is not None
+        if optimizer._grad_clip is None:
+            logger.info("grad_clip is None.")
+            return
+        elif not hasattr(optimizer._grad_clip, "state_dict"):
+            # BUGFIX: was a plain string, so "{optimizer._grad_clip}" was
+            # logged literally instead of being interpolated; made it an f-string.
+            logger.info(f"grad_clip {optimizer._grad_clip} has not state_dict method.")
+            return
+
+        if args.adaptive_norm_force_clear_state:
+            logger.info("force clear ClipGradByAdaptiveNorm state dict.")
+            return
+
+        resume_from_checkpoint = (
+            None if not args.resume_from_checkpoint else args.resume_from_checkpoint
+        )
+        # Load potential model checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(
+                    f"No valid checkpoint found in output directory ({args.output_dir})"
+                )
+
+        if resume_from_checkpoint is None:
+            return
+
+        # if use distributed training
+        if args.world_size > 1:
+            process_index = args.process_index
+            path = os.path.join(
+                resume_from_checkpoint, f"adaptivenorm_clip_state_{process_index}.pth"
+            )
+            if not os.path.isfile(path):
+                logger.info(
+                    f"Didn't find an adaptivenorm clip state file for process {process_index}, if you are resuming "
+                    "a training that wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+ ) + return + else: + path = os.path.join(resume_from_checkpoint, "adaptivenorm_clip_state.pth") + if not os.path.isfile(path): + logger.info( + "Didn't find an adaptivenorm clip state file, if you are resuming a training that was " + "launched in a distributed fashion, reproducibility is not guaranteed." + ) + return + + logger.info(f"Loading adaptivenorm clip state state to {path}") + state_dict = paddle.load(path) + + optimizer._grad_clip.set_state_dict(state_dict) + logger.info("load ClipGradByAdaptiveNorm state dict success.") + + def on_save(self, args, state, control, **kwargs): + """ + Event called after a checkpoint save. + """ + optimizer = kwargs.get("optimizer", None) + assert optimizer is not None + + if optimizer._grad_clip is None or not hasattr( + optimizer._grad_clip, "state_dict" + ): + return + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" + + run_dir = args.output_dir + + output_dir = os.path.join(run_dir, checkpoint_folder) + + os.makedirs(output_dir, exist_ok=True) + + if args.world_size > 1: + # use global process_index to save + process_index = args.process_index + path = os.path.join( + output_dir, f"adaptivenorm_clip_state_{process_index}.pth" + ) + else: + path = os.path.join(output_dir, "adaptivenorm_clip_state.pth") + logger.info(f"Saving randompos rng state to {path}") + paddle.save(optimizer._grad_clip.state_dict(), path) diff --git a/examples/pre-training/ernie/src/callbacks/progressive_batching_callback.py b/examples/pre-training/ernie/src/callbacks/progressive_batching_callback.py new file mode 100644 index 000000000..79de8beba --- /dev/null +++ b/examples/pre-training/ernie/src/callbacks/progressive_batching_callback.py @@ -0,0 +1,70 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import numpy as np +from paddleformers.trainer.trainer_callback import TrainerCallback + +logger = logging.getLogger(__name__) + + +def progressive_accumulate_steps( + acc_step_begin, acc_step_end, warmup_global_steps, increment, step +): + + assert step >= 0, step + if step >= warmup_global_steps: + return acc_step_end + slope = (acc_step_end - acc_step_begin) / warmup_global_steps + acc_steps = int(slope * step + acc_step_begin) + acc_steps = int(np.ceil(acc_steps / increment) * increment) + return acc_steps + + +class ProgreesiveBatchingCallback(TrainerCallback): + def __init__(self, acc_step_bigin, acc_step_end, warmup_global_steps, increment): + self.acc_step_bigin = acc_step_bigin + self.acc_step_end = acc_step_end + self.warmup_global_steps = warmup_global_steps + self.increment = increment + + def on_train_begin(self, args, state, control, **kwargs): + new_acc_step = progressive_accumulate_steps( + self.acc_step_bigin, + self.acc_step_end, + self.warmup_global_steps, + self.increment, + state.global_step, + ) + if new_acc_step != args.gradient_accumulation_steps: + logger.info( + f"updating acc_step{args.gradient_accumulation_steps}->{new_acc_step}, global_step={state.global_step}" + ) + args.gradient_accumulation_steps = new_acc_step + + def on_step_end(self, args, state, control, **kwargs): + new_acc_step = progressive_accumulate_steps( + self.acc_step_bigin, + self.acc_step_end, + self.warmup_global_steps, + self.increment, + state.global_step, + ) + if new_acc_step != args.gradient_accumulation_steps: + logger.info( + 
f"updating acc_step{args.gradient_accumulation_steps}->{new_acc_step}, global_step={state.global_step}" + ) + args.gradient_accumulation_steps = new_acc_step diff --git a/examples/pre-training/ernie/src/callbacks/stopper_callback.py b/examples/pre-training/ernie/src/callbacks/stopper_callback.py new file mode 100644 index 000000000..2b7763095 --- /dev/null +++ b/examples/pre-training/ernie/src/callbacks/stopper_callback.py @@ -0,0 +1,29 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import logging +from paddleformers.trainer.trainer_callback import TrainerCallback + +logger = logging.getLogger(__name__) + + +class StopperCallback(TrainerCallback): + + def on_substep_end(self, args, state, control, **kwargs): + if os.path.exists("/root/stop"): + control.should_training_stop = True diff --git a/examples/pre-training/ernie/src/clip/__init__.py b/examples/pre-training/ernie/src/clip/__init__.py index 6484ef448..215b51562 100644 --- a/examples/pre-training/ernie/src/clip/__init__.py +++ b/examples/pre-training/ernie/src/clip/__init__.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from .clip import ClipGradByAdaptiveNorm from .moe_clip import ClipGradForMOEByGlobalNorm -__all__ = ['ClipGradForMOEByGlobalNorm'] +__all__ = [ + "ClipGradForMOEByGlobalNorm", + "ClipGradByAdaptiveNorm", +] diff --git a/examples/pre-training/ernie/src/clip/clip.py b/examples/pre-training/ernie/src/clip/clip.py new file mode 100644 index 000000000..d795061f9 --- /dev/null +++ b/examples/pre-training/ernie/src/clip/clip.py @@ -0,0 +1,316 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +from collections import defaultdict +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet + +try: + from paddle.base import framework +except ImportError: + from paddle.fluid import framework +from paddle.nn.clip import ClipGradBase, _squared_l2_norm +from src.utils import logger + + +class ClipGradByAdaptiveNorm(ClipGradBase): + + def __init__( + self, + clip_ratio=1.03, + start_clip_steps=100, + beta=0.98, + epsilon=1e-8, + shard_clip=False, + enable_record=False, + enable_record_clip_history=False, + verbose=False, + ): + super().__init__() + self.clip_ratio = clip_ratio + self.beta = beta + self.epsilon = epsilon + self.state = defaultdict(dict) + self.start_clip_steps = start_clip_steps + self.shard_clip = shard_clip + self.enable_record = enable_record + self.steps = 0 + self.enable_record_clip_history = enable_record_clip_history + self.verbose = verbose + self.keys = [ + "clip_ratio", + "beta", + "epsilon", + "start_clip_steps", + "shard_clip", + "enable_record", + "steps", + "enable_record_clip_history", + ] + + if start_clip_steps < 0: + raise ValueError( + "start_clip_steps {}, please start_clip_steps >= 0.".format( + start_clip_steps + ) + ) + + def __str__(self): + return "ClipGradByAdaptiveNorm, clip_ratio={}, beta={}, start_clip_steps={}, \ + shard_clip={}, enable_record={}".format( + self.clip_ratio, + self.beta, + self.start_clip_steps, + self.shard_clip, + self.enable_record, + ) + + def clip_by_norm(self, param, grad, norm_value, global_norm): + + state = self.state[param.name] + + if "norm_value" not in state: + state["norm_value"] = norm_value + + if "clip_times" not in state: + state["clip_times"] = 0 + + if self.enable_record_clip_history: + if "clip_history" not in state: + state["clip_history"] = {} + + avg_norm_value = state["norm_value"] + + if self.enable_record: + if "norm_history" not in state: + state["norm_history"] = {} + state["norm_history"][self.steps] = [ + float(norm_value), + 
float(avg_norm_value), + ] + + if self.steps <= self.start_clip_steps: + clip_coeff = 1.0 / (global_norm + self.epsilon) + if clip_coeff < 1.0: + grad.multiply_(clip_coeff) + param._reset_grad_inplace_version(True) + + if norm_value < state["norm_value"]: + state["norm_value"] = norm_value + else: + if norm_value > self.clip_ratio * avg_norm_value: + # clip grad + coef = (self.clip_ratio * avg_norm_value) / (norm_value + self.epsilon) + grad.multiply_(coef) + param._reset_grad_inplace_version(True) + norm_value_old = norm_value + norm_value = self.clip_ratio * avg_norm_value + state["clip_times"] = state["clip_times"] + 1 + if self.enable_record_clip_history: + state["clip_history"][self.steps] = [ + float(norm_value_old), + float(norm_value), + ] + if self.verbose: + logger.info( + "{} gradclip {} times, clip from {} to {}".format( + param.name, + state["clip_times"], + float(norm_value_old), + float(norm_value), + ) + ) + + logger.info( + "{} steps {}, gradclip {} times, clip_ratio {}, clip from {} to {}".format( + param.name, + self.steps, + state["clip_times"], + self.clip_ratio, + float(norm_value_old), + float(norm_value), + ) + ) + state["norm_value"] = avg_norm_value * self.beta + norm_value * ( + 1.0 - self.beta + ) + + return grad + + @paddle.no_grad() + def _dygraph_clip(self, params_grads): + global_norm_tensor = None + if self.steps <= self.start_clip_steps: + hcg = fleet.get_hybrid_communicate_group() + mp_size = hcg.get_model_parallel_world_size() + mp_group = hcg.get_model_parallel_group() + pp_size = hcg.get_pipe_parallel_world_size() + pp_group = hcg.get_pipe_parallel_group() + sharding_size = hcg.get_sharding_parallel_world_size() + sharding_group = hcg.get_sharding_parallel_group() + + norm_squared_values = [] + for p, g in params_grads: + if g is None: + continue + if getattr(p, "need_clip", True) is False: + continue + norm_squared_value = _squared_l2_norm(g) + if not p.is_distributed and mp_size > 1: + norm_squared_value = norm_squared_value 
/ mp_size + norm_squared_values.append(norm_squared_value) + + global_norm_squared_tensor = paddle.stack(norm_squared_values).sum() + + if mp_size > 1: + dist.all_reduce(global_norm_squared_tensor, group=mp_group) + if pp_size > 1: + dist.all_reduce(global_norm_squared_tensor, group=pp_group) + if sharding_size > 1: + dist.all_reduce(global_norm_squared_tensor, group=sharding_group) + global_norm_tensor = paddle.sqrt(global_norm_squared_tensor) + + if self.verbose and global_norm_tensor is not None: + logger.info( + "step: {}, global norm: {}".format( + self.steps, float(global_norm_tensor) + ) + ) + + if hasattr(self, "sharding_stage1_v2") and self.sharding_stage1_v2: + need_sync = False + if not self.shard_clip: + hcg = fleet.get_hybrid_communicate_group() + mp_size = hcg.get_model_parallel_world_size() + mp_group = hcg.get_model_parallel_group() + sharding_size = hcg.get_sharding_parallel_world_size() + sharding_group = hcg.get_sharding_parallel_group() + if mp_size > 1 or sharding_size > 1: + need_sync = True + + norm_squared_values = [ + paddle.zeros([1], dtype=params_grads[0][1].dtype) + for _ in range(self.num_params) + ] + + for p, g in params_grads: + if g is None: + continue + if getattr(p, "need_clip", True) is False: + continue + norm_squared_value = _squared_l2_norm(g) + if need_sync and not p.is_distributed: + norm_squared_values[self.pname_to_paramindex[p.name]] = ( + 1 / mp_size + ) * norm_squared_value + else: + norm_squared_values[self.pname_to_paramindex[p.name]] = ( + norm_squared_value + ) + + num_has_grad = len(norm_squared_values) + norm_squared_tensor = paddle.concat(norm_squared_values, axis=0) + if need_sync: + if mp_size > 1: + dist.all_reduce(norm_squared_tensor, group=mp_group) + if sharding_size > 1: + dist.all_reduce(norm_squared_tensor, group=sharding_group) + + norm_tensor = paddle.sqrt(norm_squared_tensor) + norm_values = paddle.split(norm_tensor, num_has_grad, axis=0) + + params_and_grads = [] + for p, g in params_grads: + if g is 
None:
+                    continue
+                if getattr(p, "need_clip", True) is False:
+                    params_and_grads.append((p, g))
+                    continue
+                new_grad = self.clip_by_norm(
+                    p,
+                    g,
+                    norm_values[self.pname_to_paramindex[p.name]],
+                    global_norm_tensor,
+                )
+                params_and_grads.append((p, new_grad))
+        else:
+            need_sync = False
+            if not self.shard_clip:
+                hcg = fleet.get_hybrid_communicate_group()
+                mp_size = hcg.get_model_parallel_world_size()
+                mp_group = hcg.get_model_parallel_group()
+                if mp_size > 1:
+                    need_sync = True
+
+            norm_squared_values = []
+            for p, g in params_grads:
+                if g is None:
+                    continue
+                if getattr(p, "need_clip", True) is False:
+                    continue
+                norm_squared_value = _squared_l2_norm(g)
+                if need_sync and not p.is_distributed:
+                    norm_squared_values.append((1 / mp_size) * norm_squared_value)
+                else:
+                    norm_squared_values.append(norm_squared_value)
+
+            num_has_grad = len(norm_squared_values)
+            norm_squared_tensor = paddle.concat(norm_squared_values, axis=0)
+            if need_sync:
+                dist.all_reduce(norm_squared_tensor, group=mp_group)
+
+            norm_tensor = paddle.sqrt(norm_squared_tensor)
+            norm_values = paddle.split(norm_tensor, num_has_grad, axis=0)
+
+            params_and_grads = []
+            idx = 0
+            for p, g in params_grads:
+                if g is None:
+                    continue
+                if getattr(p, "need_clip", True) is False:
+                    params_and_grads.append((p, g))
+                    continue
+                new_grad = self.clip_by_norm(p, g, norm_values[idx], global_norm_tensor)
+                params_and_grads.append((p, new_grad))
+                idx += 1
+
+        self.steps += 1
+        return params_and_grads
+
+    @framework.dygraph_only
+    def state_dict(self):
+
+        state_dict = {k: v for k, v in self.state.items()}
+        for key in self.keys:
+            state_dict[key] = self.__dict__[key]
+        return state_dict
+
+    @framework.dygraph_only
+    def set_state_dict(self, state_dict):
+
+        # BUGFIX: the original condition `len(state_dict) == 0 or state_dict is None`
+        # evaluated len() before the None check, raising TypeError on None input, and
+        # then fell through to `key in state_dict`, which also fails on None.
+        # Check None first and return early — there is nothing to restore.
+        if state_dict is None or len(state_dict) == 0:
+            logger.info("state_dict is empty, please check if it is right.")
+            return
+
+        for key in self.keys:
+            if key in state_dict:
+                self.__dict__[key] = state_dict[key]
+            else:
+                logger.info("Can't find [ {} ] in
state_dict".format(key)) + + for k in state_dict: + if k in self.keys: + continue + self.state[k] = copy.deepcopy(state_dict[k]) diff --git a/examples/pre-training/ernie/src/datasets/__init__.py b/examples/pre-training/ernie/src/datasets/__init__.py new file mode 100644 index 000000000..b9c4df26b --- /dev/null +++ b/examples/pre-training/ernie/src/datasets/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""pretraining task +""" + +from .dist_data_loader import DistDataLoader, DistDataLoaderAuto diff --git a/examples/pre-training/ernie/src/datasets/dist_data_loader.py b/examples/pre-training/ernie/src/datasets/dist_data_loader.py new file mode 100644 index 000000000..c9ca138cf --- /dev/null +++ b/examples/pre-training/ernie/src/datasets/dist_data_loader.py @@ -0,0 +1,300 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
from collections import OrderedDict
from itertools import groupby
from functools import reduce
from dataclasses import dataclass

import paddle
import paddle.distributed as dist
from paddle.utils.layers_utils import flatten, map_structure, pack_sequence_as
from paddleformers.data import DistDataLoader

from src.utils.misc import global_training_logs


class DistDataLoaderAuto(DistDataLoader):
    """Data loader for auto-parallel training.

    Only ranks with ``self._need_data`` set actually pull a batch from the
    wrapped dataloader; the batch is then broadcast across the model-parallel
    group and the pipeline-parallel group so that every rank iterates over
    identical data.
    """

    def __init__(
        self,
        dataset,
        batch_sampler=None,
        collate_fn=None,
        num_workers=0,
        prefetch_factor=2,
    ):
        super().__init__(
            dataset=dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            num_workers=num_workers,
            prefetch_factor=prefetch_factor,
        )
        # Communication group used to re-broadcast the batch to the other
        # pipeline stages (their tensor-parallel rank-0 processes).
        self._pp_data_group = self._hcg.get_pipe_parallel_group()

    def __next__(self):
        # Step 1: only data-owning ranks actually read from the dataloader.
        if self._need_data:
            data = next(self._dataloader_iter)
            if "data_not_valid" in data:
                # Track the fraction of invalid samples for monitoring.
                global_training_logs.update(
                    data_not_valid=data["data_not_valid"].astype("float32").mean()
                )
            (
                input_ids,
                labels,
                data_type,
                images,
                token_type_ids,
                image_type_ids,
                audio_input_ids,
                audio_labels,
                grid_thw,
                inbatch_pack_offset,
                position_ids,
                log_prob,
            ) = (
                data["input_ids"],
                data["labels"],
                data.get("data_type", None),
                data.get("images", None),
                data.get("token_type_ids", None),
                data.get("image_type_ids", None),
                data.get("audio_input_ids", None),
                data.get("audio_labels", None),
                data.get("grid_thw", None),
                data.get("inbatch_pack_offset", None),
                data.get("position_ids", None),
                data.get("log_prob", None),
            )
            # int64 is required so the broadcast below is shape/dtype stable.
            assert {input_ids.dtype, labels.dtype} == {paddle.int64}, (
                f"Distloader requires dtype == `int64`, "
                f"got:{[input_ids.dtype, labels.dtype]}"
            )
        else:
            # Non-owning ranks start with all-None placeholders and receive
            # real tensors through the broadcasts below.
            (
                input_ids,
                labels,
                data_type,
                images,
                token_type_ids,
                image_type_ids,
                audio_input_ids,
                audio_labels,
                grid_thw,
                inbatch_pack_offset,
                position_ids,
                log_prob,
            ) = (
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )

        # Step 2: broadcast inside the model-parallel group, but only on the
        # first pipeline stage (or when there is no pipeline group at all).
        pp_broadcast = (self._pp_data_group is None) or self.pp_rank == 0
        # print(f'pp_broadcast:{pp_broadcast}')
        if self.mp_group is not None and self.mp_group.nranks > 1 and pp_broadcast:
            (
                input_ids,
                labels,
                data_type,
                images,
                token_type_ids,
                image_type_ids,
                audio_input_ids,
                audio_labels,
                grid_thw,
                inbatch_pack_offset,
                position_ids,
                log_prob,
            ) = broadcast_data_obj(
                [
                    input_ids,
                    labels,
                    data_type,
                    images,
                    token_type_ids,
                    image_type_ids,
                    audio_input_ids,
                    audio_labels,
                    grid_thw,
                    inbatch_pack_offset,
                    position_ids,
                    log_prob,
                ],
                self.mp_src_rank,
                self.mp_group,
            )

        # Step 3: broadcast from the first pipeline stage to the others.
        if self._pp_data_group is not None and self._pp_data_group.nranks > 1:
            (
                input_ids,
                labels,
                data_type,
                images,
                token_type_ids,
                image_type_ids,
                audio_input_ids,
                audio_labels,
                grid_thw,
                inbatch_pack_offset,
                position_ids,
                log_prob,
            ) = broadcast_data_obj(
                [
                    input_ids,
                    labels,
                    data_type,
                    images,
                    token_type_ids,
                    image_type_ids,
                    audio_input_ids,
                    audio_labels,
                    grid_thw,
                    inbatch_pack_offset,
                    position_ids,
                    log_prob,
                ],
                self._pp_data_group.ranks[0],
                self._pp_data_group,
            )
        # NOTE(review): `log_prob` is broadcast above but intentionally (?) not
        # included in `to_return` — confirm against downstream consumers.
        to_return = OrderedDict(
            [
                ("input_ids", input_ids),
                ("labels", labels),
                ("data_type", data_type),
                ("images", images),
                ("token_type_ids", token_type_ids),
                ("image_type_ids", image_type_ids),
                ("audio_input_ids", audio_input_ids),
                ("audio_labels", audio_labels),
                ("grid_thw", grid_thw),
                ("inbatch_pack_offset", inbatch_pack_offset),
                ("position_ids", position_ids),
            ]
        )
        # Optional fields that stayed None (absent on every rank) are dropped.
        optional_keys = [
            "data_type",
            "images",
            "token_type_ids",
            "image_type_ids",
            "audio_input_ids",
            "audio_labels",
            "grid_thw",
            "inbatch_pack_offset",
            "position_ids",
            "log_prob",
        ]
        none_keys = [
            k for k, v in to_return.items() if v is None and k in optional_keys
        ]
        for k in none_keys:
            to_return.pop(k)

        data_dict = to_return

        # Only the text-pretraining inputs are ultimately handed to the model.
        return OrderedDict(
            [("input_ids", data_dict["input_ids"]), ("labels", data_dict["labels"])]
        )


@dataclass
class _DtypeSndShape:
    """Lightweight (dtype, shape) descriptor used as the broadcast template."""

    dtype: paddle.dtype
    shape: list

    def size(self):
        # Number of elements implied by `shape`.
        return reduce(lambda x, y: x * y, self.shape)


def split_group(grouped, split_size):
    """Yield chunks of `grouped` whose accumulated element count stays near
    `split_size`.

    Note: consumes `grouped` from the tail via pop(), and a chunk may exceed
    `split_size` by one entry because the size check runs before appending.
    Both sender and receiver call this with identical input, so the chunking
    is deterministic across ranks.
    """
    ret = []
    while grouped:
        if sum([r[1].size() for r in ret]) > split_size:
            yield ret
            ret = []
        ret.append(grouped.pop())
    if ret:
        yield ret


def broadcast_data_obj(data, src_rank, group):
    """Broadcast a (possibly nested) structure of tensors/None from `src_rank`
    to every rank in `group`.

    First broadcasts a dtype/shape template object list, then flattens the
    tensors, groups them by dtype, concatenates each group into one flat
    buffer and broadcasts it, so many small tensors cost few collectives.
    """
    # print("lzx debug broadcast_data_obj")
    this_rank = dist.get_rank()
    if this_rank == src_rank:
        # None entries are encoded as an empty-dtype placeholder so receivers
        # can reconstruct them without a tensor broadcast.
        template = [
            map_structure(
                lambda x: (
                    _DtypeSndShape(dtype=x.dtype, shape=x.shape)
                    if x is not None
                    else _DtypeSndShape(dtype="", shape=[0])
                ),
                data,
            )
        ]
    else:
        template = [None]
    dist.broadcast_object_list(template, src_rank, group)
    template = template[0]

    temp_flat = flatten(template)
    data_flat = flatten(data)

    def keyfn(i):
        # Group key: stringified dtype of the template entry.
        return str(i[1].dtype)

    # -1 marks "not yet received"; filled entries are tensors or None.
    ret_flat = [-1 for _ in range(len(temp_flat))]
    for dtype, grouped in groupby(sorted(enumerate(temp_flat), key=keyfn), keyfn):
        grouped = list(grouped)
        for grouped_chunk in split_group(
            grouped, 2**18
        ):  # splitting a tensor with > 2**31 elements triggers a Paddle bug,
            # so large groups are broadcast in bounded chunks.
            idxs = [g[0] for g in grouped_chunk]
            if not dtype:
                # Empty dtype is the None placeholder — nothing to transfer.
                for id in idxs:
                    ret_flat[id] = None
                continue

            data_buf_shapes = [
                reduce(lambda x, y: x * y, g[1].shape) for g in grouped_chunk
            ]
            if this_rank == src_rank:
                # Sender flattens and concatenates its tensors into one buffer.
                data_buf = paddle.concat([data_flat[i].reshape([-1]) for i in idxs], 0)
            else:
                data_buf = paddle.empty(
                    [sum(data_buf_shapes)], dtype=grouped_chunk[0][1].dtype
                )
            dist.broadcast(data_buf, src_rank, group)

            if this_rank != src_rank:
                # Receiver splits the flat buffer back into per-tensor pieces.
                if len(data_buf_shapes) == 1:
                    data_buf = [data_buf]
                else:
                    data_buf = data_buf.split(data_buf_shapes, axis=0)
                for g, data_chunk in zip(grouped_chunk, data_buf):
                    ret_flat[g[0]] = data_chunk.reshape(g[1].shape)

    if this_rank != src_rank:
        # NOTE(review): `r is -1` relies on CPython small-int caching for the
        # sentinel (plain `==` would be elementwise on tensors) — consider a
        # dedicated sentinel object; confirm before changing.
        assert not [r for r in ret_flat if r is -1], ret_flat
        data = pack_sequence_as(template, ret_flat)
    return data
def get_cosine_schedule_with_warmup(
    learning_rate: float,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    min_lr: float = 0.0,
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine
    function between the initial lr set in the optimizer and `min_lr`, after a warmup period
    during which it increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        learning_rate (`float`):
            The initial (peak) learning rate. Must be positive.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the
            max value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
        min_lr (`float`, *optional*, defaults to 0.0):
            The floor the schedule decays to instead of 0.

    Return:
        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.

    Raises:
        ValueError: if `learning_rate` is not positive.
    """
    if learning_rate <= 0:
        # BUG FIX: the original computed `min_lr / learning_rate` inside the
        # per-step lambda, so a zero learning rate only surfaced as a
        # ZeroDivisionError on the first training step. Fail fast instead.
        raise ValueError(f"learning_rate must be positive, got {learning_rate}")

    # Hoist the loop-invariant floor ratio out of the per-step callback.
    min_ratio = min_lr / learning_rate

    def lr_lambda(current_step):
        # Linear warmup from 0 to 1 over `num_warmup_steps`.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay from 1 down to `min_ratio` over the remaining steps.
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        ratio = max(
            0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        )
        return ratio * (1 - min_ratio) + min_ratio

    return LambdaDecay(learning_rate, lr_lambda, last_epoch)
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""AutoPretrainingTrainer""" + +__all__ = [ + "AutoPretrainingTrainer", +] + + +import sys +import re +import os +import json +import contextlib +from typing import Optional +from collections import OrderedDict +from dataclasses import dataclass, field +import time +import math +import logging + + +import paddle +import paddle.nn as nn +from paddle.io import DataLoader +import paddle.amp.auto_cast as autocast +from paddle.distributed.communication.group import _get_global_group + +from paddleformers.trainer import ( + speed_metrics, +) + +from paddleformers.trainer.auto_trainer import AutoTrainer + +try: + from paddleformers.utils.env import ( + PADDLE_OPTIMIZER_NAME, + ) +except ImportError: + from paddleformers.trainer.trainer import ( + OPTIMIZER_NAME, + ) + + PADDLE_OPTIMIZER_NAME = OPTIMIZER_NAME +from paddleformers.utils.batch_sampler import ( + DistributedBatchSampler as PaddleNLPDistributedBatchSampler, +) + +try: + from paddleformers.trainer.trainer import ( + PADDLE_WEIGHT_FILE_NAME as PADDLE_WEIGHTS_NAME, + ) +except ImportError: + from paddleformers.utils.env import PADDLE_WEIGHTS_NAME +from paddleformers.trainer.utils import add_start_docstrings +from paddleformers.trainer.trainer_callback import PrinterCallback +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.pipelining.schedules import get_pp_schedule +import paddle.distributed as dist +from typing import Any, Dict, Union + +from paddleformers.transformers.model_utils import _add_variant + +from src.lr_schedulers import get_cosine_schedule_with_warmup +from src.utils.training_utils import ( + reset_per_device_batch_size, +) +from src.callbacks import ( + TensorBoardCallback, + LoggingCallback, + StopperCallback, + ClipGradByAdaptiveNormCallback, +) +from src.datasets import ( + DistDataLoaderAuto, +) +from paddle.distributed import in_auto_parallel_align_mode +from 
try:
    from paddleformers.trainer.trainer import (
        is_dp_group_support_in_group_sharded_parallel,
    )
except Exception:

    def is_dp_group_support_in_group_sharded_parallel():
        """
        hack for paddlenlp develop branch.
        """
        return True


logger = logging.getLogger(__name__)

try:
    from paddleformers.trainer import AutoTrainingArguments
except ImportError:
    # Older paddleformers has no AutoTrainingArguments; degrade gracefully.
    from paddleformers.trainer import TrainingArguments as AutoTrainingArguments

    logger.warning("paddlenlp.trainer.AutoTrainingArguments CANNOT import!")
    logger.warning("Use TrainingArguments as an alternative but will lose some args!")


# Numeric ids for the three modality streams used by the data pipeline.
DATATYPE_2_ID = {"mm": 0, "lm": 1, "audio": 2}


@dataclass
@add_start_docstrings(AutoTrainingArguments.__doc__)
class AutoPreTrainingArguments(AutoTrainingArguments):
    """Training arguments for auto-parallel ERNIE pretraining.

    Extends `AutoTrainingArguments` with data-pipeline, batching, gradient
    clipping, MoE and pipeline-parallel options. Several `help` strings are
    intentionally left in their original (Chinese) wording: they are runtime
    strings surfaced by the CLI.
    """

    vocab_path: str = field(
        default=None, metadata={"help": "eb35 streaming data vocab"}
    )
    task_need_convert: str = field(default=None, metadata={"help": "glm task id"})
    multimodal: bool = field(
        default=False, metadata={"help": "whether training with multimodal"}
    )
    model_name_or_path: str = field(
        default=None,
        metadata={
            "help": "Path to pretrained model or model identifier from "
            "https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html"
        },
    )
    vision_model_name_or_path: str = field(
        default=None,
        metadata={
            "help": "Path to pretrained model or model identifier from "
            "https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html"
        },
    )
    inception_model_name_or_path: str = field(
        default=None,
        metadata={
            "help": "Path to pretrained model or model identifier from "
            "https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html"
        },
    )
    prefetch_factor: int = field(
        default=2,
        metadata={"help": "global random seed factor."},
    )
    eval_iters: int = field(
        default=-1,
        metadata={"help": "eval iteration for every evaluation."},
    )
    num_consecutive: int = field(
        default=1,
        metadata={
            "help": "H5文件连续采样。为了保证AFS性能,在读取AFS H5文件的时候需要尽量读取一片ID"
            ",这个参数指定了一次连续读取的`样本`大小"
        },
    )
    train_emb_only: int = field(
        default=0,
        metadata={"help": "是否只训练embedding,通常用于热启换词表"},
    )
    use_train_part_sharding: Optional[int] = field(
        default=1,
        metadata={"help": "根据file进行数据切片,只在预训练时候使用。否则会很慢"},
    )
    min_lr: float = field(
        default=0.0,
        metadata={"help": "minus learning rate"},
    )
    use_map_style_data: int = field(
        default=0,
        metadata={
            "help": "以为HF dataset为中心的 MapStyle SFT数据流(支持ShareGPT/DistillGPT)等数据",
        },
    )
    use_streaming_data: int = field(
        default=0,
        metadata={
            "help": "标准线上明文数据流",
        },
    )
    dataset: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    data_load_process_num: int = field(
        default=10,
        metadata={
            "help": "是否使用多进程加速原始数据读取,与DataLoader的num_workers意义不同"
        },
    )

    input_dir: str = field(default=None, metadata={"help": "data path"})
    split: str = field(
        default="949,50,1", metadata={"help": "Train/valid/test data split ratio"}
    )

    data_dir: str = field(default=None, metadata={"help": "数据路径(指向一个目录)"})

    data_filelist: str = field(
        default=None, metadata={"help": "数据文件列表,与`args.data_dir`互斥"}
    )
    data_weights: str = field(default=None, metadata={"help": "数据配比权重"})

    dev_data: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )

    max_seq_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    global_batch_size: int = field(
        default=-1,
        metadata={
            "help": "if `global_batch_size` and `per_device_train_batch_size` is provied, "
            "`gradient_accumulation_steps` will be ignored"
        },
    )
    init_global_batch_size: int = field(
        default=-1,
        metadata={
            "help": "开启动态Batching。必须提供`global_batch_size`, "
            "global_batch_size 会在 `batch_size_warumup_steps` 步内从 "
            "`init_global_batch_size` 提升到 `global_batch_size`, "
            "每次 `batchsize` 的提升量为`batch_size_warmup_increment`"
        },
    )
    batch_size_warmup_steps: int = field(
        default=-1,
        metadata={
            "help": "开启动态Batching。必须提供`global_batch_size`, "
            "global_batch_size 会在 `batch_size_warumup_steps` 步内从 "
            "`init_global_batch_size` 提升到 `global_batch_size`, "
            "每次 `batchsize` 的提升量为`batch_size_warmup_increment`"
        },
    )
    batch_size_warmup_increment: int = field(
        default=1,
        metadata={
            "help": "开启动态Batching。必须提供`global_batch_size`, "
            "global_batch_size 会在 `batch_size_warumup_steps` 步内从 "
            "`init_global_batch_size` 提升到 `global_batch_size`, "
            "每次 `batchsize` 的提升量为`batch_size_warmup_increment`"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    init_ckpt: Optional[str] = field(
        default=None,
        metadata={},
    )
    sequence_parallel: Optional[int] = field(
        default=0,
        metadata={},
    )

    config_file: Optional[str] = field(
        default=None,
        metadata={"help": "config file (YAML) to update hyper-parameters"},
    )
    virtual_pp_degree: Optional[int] = field(
        default=1,
        metadata={
            "help": "vpp",
        },
    )
    from_scratch: Optional[int] = field(default=1, metadata={"help": "是否重头训练"})
    no_shuffle: Optional[int] = field(default=0, metadata={"help": "不要shuffle数据"})
    no_part_shuffle: Optional[int] = field(
        default=0, metadata={"help": "不进行part内数据shuffle"}
    )
    record_optimizer_stat: Optional[bool] = field(
        default=False, metadata={"help": "是否记录优化器momentum信息"}
    )
    skip_optimizer_badcases: Optional[bool] = field(
        default=False, metadata={"help": "是否跳过optimizer badcase很多的step"}
    )
    same_data: Optional[bool] = field(
        default=False,
        metadata={"help": "热启时,数据、配比、DP数是否完全一致, 支持续线"},
    )
    base_seq_length: Optional[int] = field(
        default=4096, metadata={"help": "reeao最小seq_length"}
    )
    shuffle_consecutive: Optional[bool] = field(
        default=False,
        metadata={
            "help": "是否对num_consecutive片段进行shuffle, same_data=True热启时,该值需与上一次保持一致"
        },
    )
    global_shuffle_num_examples: Optional[int] = field(
        default=0,
        metadata={
            "help": "part间shuffle的num_example总数限制,默认不做限制, "
            "这个值与最小配比的积 必须大于1, 改变该值时,需要设置same_data=False"
        },
    )
    adaptive_norm_clip: Optional[bool] = field(
        default=False, metadata={"help": "是否启用 AdaptiveNormClip 梯度裁剪策略"}
    )
    adaptive_norm_clip_ratio: Optional[float] = field(
        default=1.03,
        metadata={"help": "AdaptiveNormClip 裁剪阈值, 大于设定的阈值才会启动裁剪"},
    )
    adaptive_norm_force_clear_state: Optional[bool] = field(
        default=False, metadata={"help": "AdaptiveNormClip 强制清空 state dict"}
    )
    adaptive_norm_shard_clip: Optional[bool] = field(
        default=False, metadata={"help": "AdaptiveNormClip 在切分参数上是否在局部clip"}
    )
    adaptive_norm_enable_record: Optional[bool] = field(
        default=False, metadata={"help": "AdaptiveNormClip 是否启用统计历史norm值"}
    )
    adaptive_norm_start_clip_steps: Optional[int] = field(
        default=100, metadata={"help": "AdaptiveNormClip 开始裁剪的step"}
    )
    adaptive_norm_enable_record_clip_history: Optional[bool] = field(
        default=False, metadata={"help": "AdaptiveNormClip 是否启用统计历史裁剪的记录"}
    )
    adaptive_norm_verbose: Optional[bool] = field(
        default=False, metadata={"help": "AdaptiveNormClip 是否开启裁剪日志打印"}
    )
    use_async_save: Optional[bool] = field(
        default=False, metadata={"help": "是否开启异步保存功能"}
    )
    pre_alloc_memory: float = field(
        default=0.0,
        metadata={
            "help": "Pre-allocate one specific-capacity empty tensor "
            "and release it for avoiding memory fragmentation"
        },
    )
    enable_global_training_logs: bool = field(
        default=False, metadata={"help": "是否启用global_training_logs"}
    )
    use_dummy_dataset: Optional[bool] = field(
        default=False, metadata={"help": "是否使用DummyDataSet, 仅用于Debug"}
    )
    reshard_save_then_exit: Optional[bool] = field(
        default=False, metadata={"help": "是否在reshard后直接退出程序"}
    )
    moe_group: Optional[str] = field(
        default="dp", metadata={"help": "moe 的通信组,目前支持“dp|sharding|mp|dummy”"}
    )
    use_moe: Optional[bool] = field(
        default=False, metadata={"help": "expert parallel 临时替代"}
    )
    moe_use_all2all: Optional[bool] = field(
        default=False, metadata={"help": "是否使用all2all通信方式"}
    )
    log_global_grad_norm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "打印全局grad-norm, 只有在开启`enable_global_training_logs`时生效"
        },
    )

    multi_token_pred_depth: Optional[int] = field(
        default=0,
        metadata={},
    )

    lr_scheduler: str = field(
        default="cosine",
        metadata={
            "help": "The scheduler type to use. suppor linear, cosine, constant, constant_with_warmup"
        },
    )
    image_token_len: int = field(
        default=64,
        metadata={"help": "number of images tokens from resampler per image"},
    )
    freeze_config: str = field(
        default="",
        metadata={
            "help": (
                "Some additional config for freeze params, we provide some option to config it."
                "following config is support: freeze_vision,freeze_lm"
            )
        },
    )
    moe_gate_lr_ratio: float = field(
        default=None,
        metadata={"help": ("启用 moe 时,对 gate/router 的 LR 做特殊处理")},
    )
    vit_lr_ratio: float = field(
        default=None,
        metadata={"help": ("启用vit训练时,对 vit 的 LR 做特殊处理")},
    )
    modality_interleave: str = field(default="acc", metadata={"help": "acc"})
    modality_ratio: tuple = field(
        default=None,
        metadata={"help": "ratio of modality tokens to be masked out"},
    )
    bos_retry_max_time: int = field(
        default=0, metadata={"help": "when bos download failed, #retry times"}
    )
    bos_retry_interval: float = field(
        default=1, metadata={"help": "when bos download failed, interval between retry"}
    )

    pipeline_schedule_mode: str = field(
        default="1F1B",
        metadata={"help": "The pipeline schedule mode, support 1F1B and VPP"},
    )
    virtual_pipeline_seg_method: str = field(
        default="ErnieDecoderLayerAuto",
        metadata={"help": "The seg method of spliting pp layer for virtual pipeline."},
    )
    pp_need_data_degree: int = field(
        default=0,
        metadata={
            "help": "pipline 并行中的机器也需要 fetch 数据,提升吞吐,搭配 `ErniemmMoEForCausalPipe` 使用"
        },
    )
    pp_need_data: bool = field(default=False, metadata={"help": "向前兼容"})
    custom_data_status: str = field(
        default=None,
        metadata={"help": "load data status from custom trainer_state.json"},
    )
    model_type: Optional[str] = field(
        default="ernie",
        metadata={"help": "Only support for ernie pre-training for now."},
    )
    n_microbatches: int = field(
        default=1,
        metadata={"help": "Control the num of microbatches in one pp step."},
    )

    @property
    def need_data(self):
        # A rank fetches data only when it is tensor-parallel rank 0 and
        # (when pp_need_data_degree is set) its pipeline rank is one of the
        # first/last stages that participate in data loading.
        if self.pp_need_data_degree:
            assert self.pipeline_parallel_degree > 1
            assert (
                self.pp_need_data_degree >= 2
                and self.pp_need_data_degree <= self.pipeline_parallel_degree
            ), (
                self.pp_need_data_degree,
                self.pipeline_parallel_degree,
            )
            # Middle pipeline stages in this range do not load data.
            no_need_data_range = list(
                range(self.pp_need_data_degree - 1, self.pipeline_parallel_degree - 1)
            )
            return self.tensor_parallel_rank == 0 and (
                self.pipeline_parallel_rank not in no_need_data_range
            )
        return self.pipeline_parallel_rank == 0 and self.tensor_parallel_rank == 0

    @property
    def combine_batch(self):
        # How many base-length sequences are packed into one max-length sample.
        return self.max_seq_length // self.base_seq_length

    @property
    def reeao_dataset_rank(self):
        # Dataset shard index for this rank; None for pipeline stages that do
        # not load data (see `need_data`).
        if not self.pp_need_data_degree:
            return super().dataset_rank
        no_need_data_range = list(
            range(self.pp_need_data_degree - 1, self.pipeline_parallel_degree - 1)
        )
        ranks = [
            i
            for i in range(self.pipeline_parallel_degree)
            if i not in no_need_data_range
        ]
        if self.pipeline_parallel_rank not in ranks:
            return None
        reeao_pp_rank = ranks.index(self.pipeline_parallel_rank)

        # NOTE(review): this compares `data_parallel_rank` (a rank) against 1
        # while the other operand is a degree — confirm `data_parallel_degree`
        # was not intended here.
        assert not (self.sharding_parallel_degree > 1 and self.data_parallel_rank > 1)
        return (
            max(self.pp_need_data_degree, 1) * self.sharding_parallel_rank
            + reeao_pp_rank
        )

    @property
    def reeao_dataset_world_size(self):
        # Total number of dataset shards when pp stages also load data.
        if not self.pp_need_data:
            return super().dataset_world_size
        return (
            max(self.sharding_parallel_degree, 1)
            * max(self.data_parallel_degree, 1)
            * max(self.pipeline_parallel_degree, 1)
        )

    def __post_init__(self):
        super().__post_init__()
        # Alignment mode disables all stochastic behavior for comparison runs.
        if in_auto_parallel_align_mode():
            self.adaptive_norm_clip = False
            self.adaptive_norm_clip_ratio = 0.0
            self.no_shuffle = 1
            self.no_part_shuffle = 1

        # NOTE(review): this assert runs before the `global_batch_size > 0`
        # reconciliation below, so it appears to require the caller to pass a
        # fully consistent gbs up front (the -1 default would fail) — confirm.
        assert (
            self.global_batch_size
            == self.per_device_train_batch_size
            * self.gradient_accumulation_steps
            * max(self.sharding_parallel_degree, 1)
            * max(self.data_parallel_degree, 1)
        ), (
            f"`gbs` should be equal to `lbs * acc * (dp_degree or sd_degree)`, "
            f"but got gbs={self.global_batch_size}, "
            f"lbs={self.per_device_train_batch_size}, "
            f"acc={self.gradient_accumulation_steps}, "
            f"dp_degree={max(self.data_parallel_degree, 1)}, "
            f"sd_degree={max(self.sharding_parallel_degree, 1)}"
        )

        if self.global_batch_size > 0:
            # Derive micro batch size / accumulation steps from the global batch.
            micro_bsz, acc_steps = reset_per_device_batch_size(
                self.global_batch_size,
                self.per_device_train_batch_size,
                self.dataset_world_size,
            )
            logger.info(
                f"global_batch={self.global_batch_size} micro-bsz:{micro_bsz}, accumulate_steps:{acc_steps}"
            )
            if (
                acc_steps != 1
                and self.gradient_accumulation_steps != 1
                and acc_steps != self.gradient_accumulation_steps
            ):
                raise ValueError(
                    f"global_accumulation_steps={self.gradient_accumulation_steps}"
                    f"& global_batch={self.global_batch_size} are both set"
                )
            self.per_device_train_batch_size, self.gradient_accumulation_steps = (
                micro_bsz,
                acc_steps,
            )

        if self.batch_size_warmup_steps > 0:
            # Progressive batching: start from init_global_batch_size and grow.
            assert self.global_batch_size > 0, self.global_batch_size
            assert self.init_global_batch_size > 0, self.init_global_batch_size
            self.max_gradient_accumulation_steps = self.gradient_accumulation_steps
            (
                self.per_device_train_batch_size,
                self.gradient_accumulation_steps,
            ) = reset_per_device_batch_size(
                self.init_global_batch_size,
                self.per_device_train_batch_size,
                self.dataset_world_size,
            )
            logger.info(
                f"using progressive batching, accumulate step will increese from {self.gradient_accumulation_steps}"
                f"to {self.max_gradient_accumulation_steps} in {self.batch_size_warmup_steps} steps"
            )
        else:
            self.max_gradient_accumulation_steps = (
                self.gradient_accumulation_steps
            )  # hack add new

        if self.pipeline_parallel_degree > 1:
            self.per_device_eval_batch_size = (
                self.per_device_train_batch_size * self.gradient_accumulation_steps
            )  # hack Eval for PP!
            logger.warn(
                f"eval_batch_size set to {self.per_device_eval_batch_size} in Pipeline Parallel!"
            )
            user_defined_strategy = fleet.fleet._user_defined_strategy
            user_defined_strategy.strategy.pipeline_configs.accumulate_steps = (
                self.gradient_accumulation_steps
            )
            if self.pp_need_data and not self.pp_need_data_degree:
                self.pp_need_data_degree = self.pipeline_parallel_degree
            if self.pp_need_data_degree:
                assert (
                    self.gradient_accumulation_steps % self.pp_need_data_degree == 0
                ), (
                    f"gradient_accumulation_steps[{self.gradient_accumulation_steps}] should be divisible by "
                    f"pp_need_data_degree[{self.pp_need_data_degree}]"
                )
                # Each data-loading pp stage handles a slice of the accumulation.
                self.gradient_accumulation_steps = (
                    self.gradient_accumulation_steps // self.pp_need_data_degree
                )
                logger.info(
                    f"pp-need-data hack args.gradient_accumulation_steps to - {self.gradient_accumulation_steps}"
                )
                self.max_gradient_accumulation_steps = (
                    self.gradient_accumulation_steps
                )  # hack add new
            logger.info(f"fixing pp configs: {user_defined_strategy.pipeline_configs}")
        else:
            self.per_device_eval_batch_size = self.per_device_train_batch_size
            logger.warn(f"eval_batch_size set to {self.per_device_eval_batch_size}")

        if self.sharding_parallel_degree > 1:
            sharding_parallel_config = (
                set(self.sharding_parallel_config.split(" "))
                if self.sharding_parallel_config
                else set()
            )
            sharding_comm_overlap_non_pp = (
                True
                if "shardingv1_comm_overlap" in sharding_parallel_config
                or "sharding_comm_overlap" in sharding_parallel_config
                else False
            )
            if sharding_comm_overlap_non_pp:
                # Comm-overlap needs the strategy to know the accumulation count.
                assert hasattr(fleet.fleet, "_user_defined_strategy")
                user_defined_strategy = fleet.fleet._user_defined_strategy
                user_defined_strategy.hybrid_configs[
                    "sharding_configs"
                ].accumulate_steps = self.gradient_accumulation_steps

        if hasattr(fleet.fleet, "_user_defined_strategy"):
            user_defined_strategy = fleet.fleet._user_defined_strategy
            if (
                hasattr(user_defined_strategy, "hybrid_configs")
                and "sharding_configs" in user_defined_strategy.hybrid_configs
            ):
                sd_configs = user_defined_strategy.hybrid_configs["sharding_configs"]
                if sd_configs.comm_overlap:
                    # Sanity-check batch arithmetic against the strategy config.
                    assert self.global_batch_size % self.dataset_world_size == 0, (
                        f"global_batch_size[{self.global_batch_size}] should be divisible by "
                        f"dataset_world_size[{self.dataset_world_size}]"
                    )
                    lbs = self.global_batch_size // self.dataset_world_size
                    assert lbs % self.per_device_train_batch_size == 0, (
                        f"local_batch_size[{lbs}] should be divisible by "
                        f"per_device_train_batch_size[{self.per_device_train_batch_size}]"
                    )
                    assert (
                        lbs // self.per_device_train_batch_size
                        == sd_configs.accumulate_steps
                    ), (
                        f"local_batch_size[{lbs}] should be equal to "
                        f"accumulate_steps[{sd_configs.accumulate_steps}] * "
                        f"per_device_train_batch_size[{self.per_device_train_batch_size}]"
                    )
        # Providing a vision tower implies multimodal training.
        if self.vision_model_name_or_path is not None:
            self.multimodal = True
    def compute_pipeline_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.
        Subclass and override for custom behavior.
        """
        # Pop the supervision targets out of `inputs` so only model inputs are
        # fed to the pipeline schedule below.
        if self.criterion is not None:
            if "labels" in inputs:
                labels = inputs.pop("labels")

            elif "start_positions" in inputs and "end_positions" in inputs:
                labels = (inputs.pop("start_positions"), inputs.pop("end_positions"))
            elif self.args.label_names is not None:
                # NOTE(review): the guard reads `self.args.label_names` but the
                # loop iterates `self.label_names` — confirm both are kept in
                # sync by the base trainer.
                labels = []
                for label in self.label_names:
                    labels.append(inputs.pop(label))
                labels = tuple(labels)
            elif "generator_labels" in inputs:
                # NOTE(review): unlike the branches above this does not pop(),
                # so "generator_labels" is also left in `inputs` — confirm
                # this is intentional.
                labels = inputs["generator_labels"]
            else:
                labels = None
        # NOTE(review): when self.criterion is None, `labels` is unbound for
        # the last pipeline stage below; in practice __init__ installs a
        # criterion whenever pipeline parallelism is enabled.

        pp_rank = self.comm_group_in_pp.rank
        losses = []
        # Only the first stage feeds inputs into the schedule; only the last
        # stage receives targets and accumulates the per-microbatch losses.
        if pp_rank == 0:
            self.pp_schedule.step(**inputs)
        elif pp_rank == self.args.pipeline_parallel_degree - 1:
            self.pp_schedule.step(target=labels, losses=losses)
        else:
            self.pp_schedule.step()

        # Non-last stages return None; the last stage averages microbatch losses.
        final_loss = None
        if len(losses) != 0:
            # Each entry is a list whose first element is the microbatch loss.
            losses = [loss[0] for loss in losses]
            final_loss = paddle.stack(losses).mean()

        return final_loss
+ with self.autocast_smart_context_manager(): + loss = self.compute_pipeline_loss(model, inputs) + + return loss + + def dynamic_training( + self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]] + ) -> paddle.Tensor: + if self.args.pipeline_parallel_degree > 1: + return self.dynamic_auto_parallel_pipeline_training(model, inputs) + else: + return super().dynamic_training(model, inputs) + + def autocast_smart_context_manager(self): + + if self.enable_autocast_context_manager: + black = [ + "reduce_sum", + "c_softmax_with_cross_entropy", + "elementwise_div", + "sin", + "cos", + ] + white = [ + "lookup_table", + "lookup_table_v2", + "flash_attn", + "flash_attn_v1", + "matmul", + "matmul_v2", + "fused_gemm_epilogue", + ] + if self.args.bf16 and self.args.fp16_opt_level == "O2": + black.append("c_embedding") + + ctx_manager = autocast( + True, + custom_black_list=black, + custom_white_list=white, + level=self.args.fp16_opt_level, + dtype=self.amp_dtype, + ) + else: + ctx_manager = ( + contextlib.nullcontext() + if sys.version_info >= (3, 7) + else contextlib.suppress() + ) + + return ctx_manager + + def _load_optimizer_state(self, checkpoint): + # def _load_moe_optimizer_state(checkpoint): + # opt_moe_suffix = re.sub(r"moe\d\d", "moe00", self.args.optimizer_name_suffix) + # return self._load_optimizer_state_of_one_shard(checkpoint, opt_moe_suffix) + + def _broadcast_moe_optimizer_state(state_dict): + # boardcast_keys + base_state_dict = {"master_weights": {}} + buf = [ + { + i: j.shape + for i, j in state_dict.items() + if i not in ["master_weights", "LR_Scheduler"] + }, + {i: j.shape for i, j in state_dict["master_weights"].items()}, + {"LR_Scheduler": state_dict.get("LR_Scheduler", {})}, + ] + + if self.args.use_hybrid_parallel: + hcg = fleet.get_hybrid_communicate_group() + src_rank = hcg.get_data_parallel_group_src_rank() + group = hcg.get_data_parallel_group() + else: + src_rank = 0 + group = None + + dist.broadcast_object_list(buf, src=src_rank, 
group=group) + for k, s in buf[0].items(): + v = state_dict.get(k, paddle.zeros(s, "float32")).cuda() + v.name = k + dist.broadcast(v, src=src_rank, group=group) + logger.info(f"broadcast moe optimizer {k} from {src_rank}") + base_state_dict[k] = v.cpu() + for k, s in buf[1].items(): + v = ( + state_dict["master_weights"] + .get(k, paddle.zeros(s, "float32")) + .cuda() + ) + v.name = k + dist.broadcast(v, src=src_rank, group=group) + logger.info( + f"broadcast moe optimizer-master_weights {k} from {src_rank}" + ) + base_state_dict["master_weights"][k] = v.cpu() + base_state_dict.update(buf[2]) + return base_state_dict + + state_dict = super()._load_optimizer_state(checkpoint) + + if self.args.use_moe: + base_state_dict = _broadcast_moe_optimizer_state(state_dict) + if self.args.data_parallel_rank > 0: + master_weight = state_dict.pop("master_weights", {}) + base_state_dict.update(state_dict) + if master_weight: + if "master_weights" in base_state_dict: + base_state_dict["master_weights"].update(master_weight) + else: + base_state_dict["master_weights"] = master_weight + state_dict = base_state_dict + del base_state_dict + return state_dict + + def _save_moe_weights(self, output_dir): + + optimizer_name = _add_variant( + PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix + ) + saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}") + + os.makedirs(output_dir, exist_ok=True) + state_dict = self.model.state_dict() + optimzier_state_dict = self.optimizer.state_dict() + + filtered_state_dict = OrderedDict() + filter_optimzier_state_dict = OrderedDict() + + param_names_in_master_weights = ( + list(optimzier_state_dict["master_weights"].keys()) + if self.args.bf16 + else [] + ) + filter_optimzier_state_dict["master_weights"] = OrderedDict() + + for k, v in state_dict.items(): + if getattr(v, "no_sync", False): + + if v.name in param_names_in_master_weights: + filter_optimzier_state_dict["master_weights"][v.name] = ( + 
optimzier_state_dict["master_weights"][v.name] + ) + if not ( + getattr(self.args, "should_save_sharding_stage1_model", False) + or getattr(self.args, "save_sharding_stage1_model", False) + ): + filtered_state_dict[k] = v + for op_k, op_v in optimzier_state_dict.items(): + if op_k.startswith(v.name): + filter_optimzier_state_dict[op_k] = op_v + + if getattr(self.args, "should_save_sharding_stage1_model", False) or getattr( + self.args, "save_sharding_stage1_model", False + ): + self._save(output_dir=output_dir) + else: + if self.args.sharding_parallel_rank == 0: + paddle.save( + filtered_state_dict, + os.path.join( + output_dir, + _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix), + ), + ) + paddle.save( + filter_optimzier_state_dict, os.path.join(output_dir, optimizer_name) + ) + with open(saved_signal_path, mode="w+") as f: + f.write("1") + + def evaluate( + self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval" + ): + + self.model_wrapped.accumulate_steps = self.args.gradient_accumulation_steps + eval_dataloader = self.get_eval_dataloader(eval_dataset) + + start_time = time.time() + compute_metrics = self.compute_metrics + eval_loop = self.evaluation_loop + + output = eval_loop( + eval_dataloader, + description="Evaluation", + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + max_eval_iters=self.args.eval_iters, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate( + self.args, self.state, self.control, output.metrics + ) + return output.metrics + + def prediction_pipeline_step( + self, model, inputs, prediction_loss_only, ignore_keys + ): + + loss, _, labels = super().prediction_pipeline_step( + model, 
    def get_train_dataloader(self):
        """Build the training dataloader.

        Returns a plain ``DataLoader`` for iterable (streaming) datasets —
        collation then happens inside the stream itself — and a
        ``DistDataLoaderAuto`` with a distributed batch sampler otherwise.
        Ranks with ``need_data`` false get no sampler.

        Raises:
            ValueError: if this rank needs data but no train_dataset was given.
        """
        if self.args.need_data and self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        _DataLoader = DistDataLoaderAuto

        train_dataset = self.train_dataset
        if self._is_iterable_dataset(train_dataset):
            return DataLoader(
                train_dataset,
                batch_size=None,  # we do data collation in Stream
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                use_shared_memory=True,
                prefetch_factor=self.args.prefetch_factor,
            )
        # Map-style dataset: only data-loading ranks get a sampler.
        if self.args.need_data:
            train_sampler = self._get_train_sampler()
        else:
            train_sampler = None
        return _DataLoader(
            train_dataset,
            batch_sampler=train_sampler,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            prefetch_factor=self.args.prefetch_factor,
        )
+ + def create_scheduler(self, num_training_steps): + + if self.args.warmup_steps > 0: + warmup = self.args.warmup_steps + else: + warmup = int(self.args.warmup_ratio * num_training_steps) + self.lr_scheduler = get_cosine_schedule_with_warmup( + self.args.learning_rate, + warmup, + self.args.max_steps, + min_lr=self.args.min_lr if self.args.min_lr else 0.0, + ) + + return self.lr_scheduler + + def create_optimizer(self, lr_scheduler=None): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + optimizer_params = ( + [p for n, p in self.model.named_parameters() if "embeddings" in n] + if self.args.train_emb_only + else self.model.parameters() + ) + if self.args.train_emb_only: + logger.info( + f"using `train-emb-only`, #embedding params={len(optimizer_params)}" + ) + if self.optimizer is None: + + def need_decay(name): + if ( + name == "ernie.norm.weight" + and self.args.pipeline_parallel_degree > 1 + ): + return True + return not any(nd in name for nd in ["bias", "norm"]) + + decay_parameters = [ + p.name for n, p in self.model.named_parameters() if need_decay(n) + ] + + def apply_decay_param_fun(x): + return x in decay_parameters + + optimizer_cls, optimizer_kwargs = AutoTrainer.get_optimizer_cls_and_kwargs( + self.args + ) + + if self.args.adaptive_norm_clip: + if "split_param" in self.args.sharding_parallel_config: + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizerV2, + ) + + v2_assign_slice_grad = DygraphShardingOptimizerV2._assign_slice_grad + + def _assign_slice_grad(self): + v2_assign_slice_grad(self) + assert isinstance( + self._grad_clip, ClipGradByAdaptiveNorm + ), "self._grad_clip must be ClipGradByAdaptiveNorm" + if not hasattr(self._grad_clip, "pname_to_paramindex"): + pname_to_paramindex = 
{} + assert not isinstance(self._parameter_list[0], dict) + for idx, param in enumerate(self._parameter_list): + param = self._slice_params[param.name] + if param._is_initialized(): + pname_to_paramindex[param.name] = idx + self._grad_clip.pname_to_paramindex = pname_to_paramindex + self._grad_clip.num_params = len(self._parameter_list) + self._grad_clip.sharding_stage1_v2 = True + + DygraphShardingOptimizerV2._assign_slice_grad = _assign_slice_grad + logger.info( + "Hack DygraphShardingOptimizerV2._assign_slice_grad for ClipGradByAdaptiveNorm" + ) + + grad_clip = ClipGradByAdaptiveNorm( + clip_ratio=self.args.adaptive_norm_clip_ratio, + start_clip_steps=self.args.adaptive_norm_start_clip_steps, + shard_clip=self.args.adaptive_norm_shard_clip, + enable_record=self.args.adaptive_norm_enable_record, + enable_record_clip_history=self.args.adaptive_norm_enable_record_clip_history, + verbose=self.args.adaptive_norm_verbose, + ) + logger.info("using ClipGradByAdaptiveNorm") + elif ( + self.args.use_moe + and not self.args.use_hybrid_parallel + and not self.args.enable_auto_parallel + ): + logger.info("using moe Global clip") + + def expert_fn(p): + return getattr(p, "no_sync", False) + + grad_clip = ClipGradForMOEByGlobalNorm( + self.args.max_grad_norm, + is_expert_param_func=expert_fn, + moe_group=_get_global_group(), + local_clip=False, + ) + else: + grad_clip = ( + nn.ClipGradByGlobalNorm(self.args.max_grad_norm) + if self.args.max_grad_norm > 0 + else None + ) + + self.static_name_to_dyg_name = { + p.name: n for n, p in self.model.state_dict().items() + } + gate_pattern = re.compile(r"ernie\.layers\.0\.mlp\.gate\.weight") + vit_pattern = re.compile( + r"vision_model\.(cls_token|pos_embed|patch_embed|blocks)" + ) + vit_blocks_pattern = re.compile(r"vision_model\.blocks\.(\d+)\.") + + def lr_ratio_fn(param): + if param.name in self.static_name_to_dyg_name.keys(): + name = self.static_name_to_dyg_name[param.name] + if self.args.moe_gate_lr_ratio is not None and 
gate_pattern.match( + name + ): + logger.info( + f"apply moe_gate_lr_ratio to {name}, ratio={self.args.moe_gate_lr_ratio}" + ) + return float(self.args.moe_gate_lr_ratio) + elif self.args.vit_lr_ratio is not None and vit_pattern.match(name): + n_layers = self.model.config.vision_config.layers + if vit_blocks_pattern.match(name): + layer_id = int(vit_blocks_pattern.match(name).group(1)) + else: + layer_id = 0 + lr_ratio = self.args.vit_lr_ratio ** (n_layers - 1 - layer_id) + logger.info(f"apply vit lr_ratio to {name}, ratio={lr_ratio}") + return float(lr_ratio) + return 1.0 + + self.optimizer = optimizer_cls( + learning_rate=( + self.lr_scheduler if lr_scheduler is None else lr_scheduler + ), + apply_decay_param_fun=apply_decay_param_fun, + parameters=optimizer_params, + weight_decay=self.args.weight_decay, + grad_clip=grad_clip, + multi_precision=True, + lr_ratio=( + lr_ratio_fn + if ( + self.args.moe_gate_lr_ratio is not None + or self.args.vit_lr_ratio is not None + ) + else None + ), + **optimizer_kwargs, + ) + + self.static_name_to_dyg_name = { + p.name: n for n, p in self.model.named_parameters() + } + + return self.optimizer + + def save_model(self, output_dir=None): + + super().save_model(output_dir) + if self.args.should_save: + with open( + os.path.join(output_dir, "static_name_to_dyg_name.json"), "w" + ) as of: + of.write(json.dumps(self.static_name_to_dyg_name)) + + def _load_rng_state(self, checkpoint): + pass + + def _get_meshes_for_loader(self): + def _get_mesh(pp_idx=0): + return self.global_mesh.get_mesh_with_dim("pp")[pp_idx] + + meshes = [] + if self.args.pipeline_parallel_degree > 1: + # input_ids + meshes.append(_get_mesh(0)) + # labels + meshes.append(_get_mesh(self.args.pipeline_parallel_degree - 1)) + else: + meshes.append(_get_mesh(0)) + return meshes + + def _wrap_for_dist_loader(self, train_dataloader, dense_tensor_idx=None): + self.dense_tensor_idx = dense_tensor_idx + dist_loader = dist.shard_dataloader( + dataloader=train_dataloader, + 
meshes=self._get_meshes_for_loader(), + shard_dims="dp", + dense_tensor_idx=dense_tensor_idx, + is_dataset_splitted=True, + ) + return dist_loader diff --git a/examples/pre-training/ernie/src/utils/__init__.py b/examples/pre-training/ernie/src/utils/__init__.py index edcdc5290..30361d4cd 100644 --- a/examples/pre-training/ernie/src/utils/__init__.py +++ b/examples/pre-training/ernie/src/utils/__init__.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .logging import logger, setup_logger_output_file +# from .logging import logger, setup_logger_output_file -__all__ = ['logger', 'setup_logger_output_file'] +# __all__ = ['logger', 'setup_logger_output_file'] + +from .logging import * # noqa +from .seed_utils import * # noqa +from .training_utils import * # noqa diff --git a/examples/pre-training/model_configs_auto/model_config.json b/examples/pre-training/model_configs_auto/model_config.json new file mode 100644 index 000000000..a9b844cea --- /dev/null +++ b/examples/pre-training/model_configs_auto/model_config.json @@ -0,0 +1,66 @@ +{ + "architectures": [ + "ErnieForCausalLM" + ], + "bos_token_id": 0, + "eos_token_id": 1, + "hidden_act": "silu", + "hidden_size": 8192, + "intermediate_size": 28672, + "initializer_range": 0.00482174, + "max_sequence_length": 4096, + "max_position_embeddings": 4096, + "model_type": "ernie", + "num_attention_heads": 64, + "num_key_value_heads": 8, + "num_hidden_layers": 4, + "pad_token_id": -1, + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "transformers_version": "4.27.0.dev0", + "use_cache": true, + "vocab_size": 100352, + "rope_theta": 10000, + "use_recompute": false, + "use_recompute_attn": false, + "use_recompute_moe": false, + "use_recompute_loss_fn": false, + "use_rmsnorm": true, + "fuse_rms_norm": true, + "use_bias": false, + "use_fast_ln": true, + "fuse_attn_ffn": true, + "fuse_linear": true, + "rope_reorder": false, + "fuse_rope": true, + 
"fuse_swiglu": true, + "fuse_gate_detach_matmul": true, + "remove_tail_layer": 0, + "refined_recompute": { + "mlp_row_ln": -1, + "flash_attn": -1, + "attention_row_ln": -1, + "attention_column_ln": 2, + "mlp_column_ln": 0 + }, + "moe_num_experts": 16, + "moe_num_shared_experts": 0, + "moe_layer_start_index": 2, + "moe_group_experts": false, + "moe_intermediate_size": 3584, + "moe_capacity": [8,8,8], + "moe_gate": "top2_fused", + "moe_gate_scale": false, + "moe_gate_detach": 1.0, + "moe_k": 8, + "moe_aux_loss_lambda": 1e-5, + "moe_group_orthogonal_loss": true, + "moe_orthogonal_loss_lambda": 0.0, + "moe_z_loss_lambda": 0.0, + "moe_layer_interval": 1, + "z_loss_lambda": 0, + "using_precision_check": false, + "use_ep_comm_overlap": true, + "moe_use_all2all": true, + "tie_word_embeddings": true +} diff --git a/examples/pre-training/models/aadiff_decorator.py b/examples/pre-training/models/aadiff_decorator.py new file mode 100644 index 000000000..64b7aa632 --- /dev/null +++ b/examples/pre-training/models/aadiff_decorator.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +The AADiff decorator. +""" +import os +import paddle +import decorator + + +def get_md5(tensors): + """ + Get MD5 of tensor, list of tensors or the combination of them. 
+ """ + if tensors is None: + return None + elif isinstance(tensors, paddle.Tensor): + return tensors._md5sum() + elif isinstance(tensors, (list, tuple)): + return [get_md5(t) for t in tensors] + else: + raise ValueError(tensors) + + +def check_aadiff(ntimes=None): + """ + The AADiff decorator. + """ + if ntimes is None: + ntimes = int(os.getenv("AADIFF_TIMES", "0")) + + @decorator.decorator + def __impl__(_func, *args, **kwargs): + if ntimes > 0: + with paddle.no_grad(): + old_md5 = None + for idx in range(ntimes): + ret = _func(*args, **kwargs) + print("AADiff Pass {}/{} ...".format(idx, ntimes)) + cur_md5 = get_md5(ret) + del ret + if old_md5 is None: + old_md5 = cur_md5 + else: + assert old_md5 == cur_md5, "Rank {} has aadiff".format( + paddle.distributed.get_rank() + ) + + return _func(*args, **kwargs) + + return __impl__ diff --git a/examples/pre-training/models/ernie/__init__.py b/examples/pre-training/models/ernie/__init__.py index b00b05790..382402d6b 100644 --- a/examples/pre-training/models/ernie/__init__.py +++ b/examples/pre-training/models/ernie/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .configuration import ErnieMoEConfig -from .modeling_pp import ErnieMoEForCausalLMPipe -__all__ = ['ErnieMoEConfig', 'ErnieMoEForCausalLMPipe'] +from .configuration import * # noqa +from .modeling import * # noqa +from .modeling_auto import * # noqa diff --git a/examples/pre-training/models/ernie/modeling_auto.py b/examples/pre-training/models/ernie/modeling_auto.py new file mode 100644 index 000000000..a2ef0a440 --- /dev/null +++ b/examples/pre-training/models/ernie/modeling_auto.py @@ -0,0 +1,2978 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle Ernie model""" +import math +from functools import partial +import logging +from typing import Optional, Tuple +import contextlib +import inspect +from paddle.distributed.auto_parallel.pipelining.schedules import ( + parse_args, + return_args, + get_pp_stage_id, +) + +try: + from fast_ln import fast_ln +except ImportError: + fast_ln = None + +from copy import deepcopy +from dataclasses import dataclass +import numpy as np +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import recompute +from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker +from paddle.incubate.nn.memory_efficient_attention import ( + memory_efficient_attention, + BlockDiagonalCausalMask, +) +from paddle.distributed import in_auto_parallel_align_mode + +from models.comm_utils import subbatch + +from models.moe.top2_gate_auto_auto import Top2Gate +from models.moe.top2_gate_auto import TopKGateFusedAuto + + +# from src/ops which is install in build_envs + +from paddleformers.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) + +from paddleformers.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions as _BaseModelOutput, +) +from paddleformers.transformers.model_outputs import CausalLMOutputWithCrossAttentions + +from 
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(_BaseModelOutput):
    """Base model output extended with MoE routing information."""

    # Auxiliary router loss produced by the MoE gates; None for dense models.
    router_loss: Optional[paddle.Tensor] = None
    # Gate logits, presumably one tensor per MoE layer — confirm at call sites.
    gate_logits: Optional[Tuple[paddle.Tensor]] = None
def global_mesh_starts_with_pp():
    """Return the global auto-parallel mesh, ordered so "pp" is the leading dim.

    When no pipeline-parallel dimension exists, the mesh is returned unchanged.
    """
    mesh = fleet.auto.get_mesh()
    if not is_pp_enable():
        return mesh
    return mesh.get_mesh_with_dim("pp")
+ """ + if flash_attention_with_sparse_mask is None: + return True + + args = inspect.getfullargspec(flash_attention_with_sparse_mask).args + return "causal" in args + + +IS_FLEETY = is_fleety_func() + + +def get_triangle_upper_mask(x, mask=None): + + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, -np.inf, dtype=x.dtype) + mask.stop_gradient = True + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def naive_fuse_split_tp( + weight, + tensor_parallel_degree, + tensor_parallel_rank=None, + is_column=True, + fuse_tensor_parts=2, +): + + logging.info(f"spliting fused-ffn: {weight.shape}") + axis = -1 if is_column else 0 + splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis) + return np.concatenate( + splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis + ) + + +def parallel_matmul( + x, + y, + bias=None, + transpose_y=False, + tensor_parallel_degree=1, + tensor_parallel_output=True, +): + + if transpose_y: + logits = paddle.matmul(x, y, transpose_y=True) + if bias is not None: + logits += bias + else: + logits = F.linear(x, y, bias) + + if tensor_parallel_degree > 1 and not tensor_parallel_output: + logits = dist.reshard(logits, get_mesh(-1), [dist.Shard(0), dist.Replicate()]) + + return logits + + +def calc_lm_head_logits( + config, + hidden_states, + weight, + bias, + sparse_label_idx=None, + tensor_parallel_output=None, +): + """the core function to calc lm head""" + if config.sequence_parallel: + + assert ( + not config.use_sparse_head_and_loss_fn + ), "use_sparse_head_and_loss_fn is not supported now." 
def finfo(dtype: paddle.dtype = None):
    """Return float-type info (numpy ``finfo``-like) for a paddle dtype.

    Defaults to the current paddle default dtype. For bfloat16 — which numpy
    cannot describe — a minimal stand-in class exposing only ``min`` is
    returned. Unrecognized dtypes fall through and return None (presumably
    callers only pass float dtypes — confirm).
    """
    if dtype is None:
        dtype = paddle.get_default_dtype()

    if dtype == paddle.bfloat16:

        class BFloatFInfo:
            """
            Numpy do not support `np.finfo(np.uint16)`, so try to construct a finfo object to fetch min value
            """

            # Most-negative finite bfloat16 value: -(2 - 2**-7) * 2**127.
            min = -3.3895313892515355e38

        return BFloatFInfo
    if dtype == paddle.float32:
        return np.finfo(np.float32)
    if dtype == paddle.float16:
        return np.finfo(np.float16)
    if dtype == paddle.float64:
        return np.finfo(np.float64)
def inbatch_pack_offset_to_attn_mask_start_row_indices(inbatch_pack_offset):
    """convert inbatch_pack_offset to attn_mask_start_row_indices

    ``inbatch_pack_offset`` holds, per batch row, the cumulative end offsets of
    the sequences packed into that row, padded with -1. For every token this
    produces the end offset of the sequence it belongs to (the first attention
    row that must be masked), shaped [bsz, 1, 1, seqlen], plus the smallest
    first-sequence end across the batch.
    """
    inbatch_pack_offset = inbatch_pack_offset.numpy()
    attn_mask_row_start_indices = []
    min_start_row = np.inf
    for bidx in range(inbatch_pack_offset.shape[0]):
        item = inbatch_pack_offset[bidx]
        # Strip the -1 padding, leaving the cumulative offsets [0, e1, e2, ...].
        cumsum_item = item[item != -1]
        # Per-sequence lengths from consecutive offsets.
        record_lens = cumsum_item[1:] - cumsum_item[0:-1]
        min_start_row = min(cumsum_item[1], min_start_row)
        # Repeat each sequence's end offset once per token of that sequence.
        row_start_indices = np.repeat(cumsum_item[1:], record_lens)
        attn_mask_row_start_indices.append(row_start_indices[None, None, ...])
    attn_mask_row_start_indices = np.concatenate(attn_mask_row_start_indices, axis=0)
    return paddle.to_tensor(attn_mask_row_start_indices, dtype=paddle.int32), int(
        min_start_row
    )
can_use_fa = config.use_flash_attn and flash_attention is not None + can_use_fa_sparse_mask = ( + config.use_mem_eff_attn + and inbatch_pack_offset is not None + and flash_attention_with_sparse_mask is not None + ) + + if not can_use_fa and not can_use_fa_sparse_mask: + if query_states.shape[-2] != key_states.shape[-2]: + key_states = key_states.repeat_interleave( + num_heads // num_key_value_heads, axis=-2 + ) + if query_states.shape[-2] != value_states.shape[-2]: + value_states = value_states.repeat_interleave( + num_heads // num_key_value_heads, axis=-2 + ) + + if can_use_fa: + if rr_flash_attn is not None: + attn_output, attn_weights = rr_flash_attn( + query_states, + key_states, + value_states, + dropout=config.attention_probs_dropout_prob, + causal=is_causal and query_states.shape[1] != 1, + return_softmax=output_attentions, + ) + else: + attn_output, attn_weights = flash_attention( + query_states, + key_states, + value_states, + dropout=config.attention_probs_dropout_prob, + causal=is_causal and query_states.shape[1] != 1, + return_softmax=output_attentions, + ) + + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return attn_output, attn_weights + elif config.use_mem_eff_attn and inbatch_pack_offset is not None: + assert ( + not output_attentions + ), "output_attentions should be False when use_mem_eff_attn=True" + if config.use_flash_attn_with_mask: + if flash_attention_with_sparse_mask is not None: + causal_mask_indices, attn_mask_min_start_row = ( + inbatch_pack_offset_to_attn_mask_start_row_indices( + inbatch_pack_offset + ) + ) + if IS_FLEETY: + kwargs = { + "causal": True, + "dropout": config.attention_probs_dropout_prob, + } + else: + kwargs = { + "is_causal": True, + "dropout_p": config.attention_probs_dropout_prob, + } + attn_output = flash_attention_with_sparse_mask( + query_states.astype(value_states.dtype), + key_states.astype(value_states.dtype), + value_states.astype(value_states.dtype), + 
attn_mask_start_row_indices=causal_mask_indices, + attn_mask_start_row=attn_mask_min_start_row, + **kwargs, + ) + else: + attn_mask = to_block_diag_causal_mask( + inbatch_pack_offset, q_len, float("-inf"), "bfloat16" + ) + attn_output = flash_attention_with_mask( + query_states, + key_states, + value_states, + attn_mask, + config.attention_probs_dropout_prob, + ) + else: + attn_output = mem_eff_attn( + query_states, + key_states, + value_states, + inbatch_pack_offset, + drop_prob=config.attention_probs_dropout_prob, + ) + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return attn_output, None + else: + + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) / math.sqrt( + head_dim + ) + # merge with the next tranpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + # Pipeline 的Attention mask不能从外面传。 + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + if training: + attn_weights = attention_mask + attn_weights + attn_weights = paddle.maximum( + attn_weights, + paddle.to_tensor( + float(finfo(query_states.dtype).min), dtype=query_states.dtype + ), + ) + + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + attn_weights = F.softmax( + attn_weights, axis=-1, dtype="float32" + ).astype(query_states.dtype) + else: + attn_weights = F.softmax(attn_weights, axis=-1, 
dtype="float32").astype( + query_states.dtype + ) + else: # use inplace operation to save memory + attn_weights = attn_weights.cast(paddle.float32) + attention_mask = attention_mask.cast(paddle.float32) + attn_weights = attn_weights.add_(attention_mask) + attn_weights = F.softmax_(attn_weights, axis=-1).astype(query_states.dtype) + + if config.attention_probs_dropout_prob > 0.0: + if config.tensor_parallel_degree > 1: + with get_rng_state_tracker().rng_state("local_seed"): + attn_weights = F.dropout( + attn_weights, + config.attention_probs_dropout_prob, + training=training, + mode="upscale_in_train", + ) + else: + attn_weights = F.dropout( + attn_weights, + config.attention_probs_dropout_prob, + training=training, + mode="upscale_in_train", + ) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + if output_attentions: + return attn_output, attn_weights + return attn_output, None + + +def _make_causal_mask(input_ids_shape, past_key_values_length, dtype): + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + + mask = paddle.full((target_length, target_length), float(finfo(dtype).min)) + + mask_cond = paddle.arange(mask.shape[-1]) + mask = masked_fill( + mask, mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0 + ) + + if past_key_values_length > 0: + mask = paddle.concat( + [paddle.zeros([target_length, past_key_values_length]), mask], axis=-1 + ) + + return mask[None, None, :, :].expand( + [batch_size, 1, target_length, target_length + past_key_values_length] + ) + + +def _expand_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. 
+ """ + if mask.ndim == 4: + expanded_mask = mask + elif mask.ndim == 3: + expanded_mask = mask[:, None, :, :] + else: + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = mask[:, None, None, :].expand( + [batch_size, 1, tgt_length, src_length] + ) + + inverted_mask = 1.0 - expanded_mask + return masked_fill( + inverted_mask, inverted_mask.cast("bool"), float(finfo(dtype).min) + ) + + +def slice_experts(experts, moe_world_size): + moe_num_experts_per_device = len(experts) // moe_world_size + experts_per_device = [[] for _ in range(moe_world_size)] + + for i, expert in enumerate(experts): + ep_group_id = i // moe_num_experts_per_device + experts_per_device[ep_group_id].append(expert) + + lm_experts = nn.LayerList([]) + for experts_list in experts_per_device: + lm_experts.extend(experts_list[: moe_num_experts_per_device // 2]) + return lm_experts + + +def get_gate( + config: ErnieMoEConfig, + expert: Tuple[Tuple[int, nn.Layer]], + layer_idx: int, + ipp: int = 0, +) -> Tuple[nn.Layer, nn.LayerList]: + + moe_num_experts = config.moe_num_experts + assert ( + moe_num_experts >= config.moe_world_size + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={config.moe_world_size}" + assert ( + moe_num_experts % config.moe_world_size == 0 + ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={config.moe_world_size} == 0" + moe_num_experts_per_device = moe_num_experts // config.moe_world_size + experts = nn.LayerList([]) + for expert_id, (experts_num, fc) in enumerate(expert): + assert experts_num % config.moe_world_size == 0 + experts_to_append = [] + if not hasattr(fc, "__len__"): + experts_to_append.append(fc) + if expert_id == 1: + with paddle.utils.unique_name.guard("_mm_deepcopy"): + for _ in range(experts_num - 1): + experts_to_append.append(deepcopy(fc)) + else: + for _ in range(experts_num - 1): + experts_to_append.append(deepcopy(fc)) + else: + 
experts_to_append = fc + for ex in experts_to_append: + for p in ex.parameters(): + p.expert_type = f"expert_type_{expert_id}" + experts.extend(experts_to_append) + + logger.info( + f"using moe-world-size: {config.moe_world_size} " + f"expert-per-device: {moe_num_experts_per_device} " + ) + if config.moe_use_hard_gate and moe_num_experts <= 2: + gate = None + logger.info("MOE-GATE:-hard-gate") + else: + logger.info(f"MOE-GATE:-{config.moe_gate}") + gate = gate_class[config.moe_gate.lower()]( + config, layer_idx=layer_idx, group=config.moe_group, ipp=ipp + ) + + lm_gate, lm_experts = None, None + logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") + + index = 0 if config.moe_group == "dp" else 1 + ep_sub_meshes = dist.auto_parallel.api.split_mesh(get_mesh(ipp), index) + + for i, expert in enumerate(experts): + ep_group_id = i // moe_num_experts_per_device + if isinstance(expert, (ErnieMoeMLPFused, ErnieMoeMLP)): + experts[i].redistribute_expert( + ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()] + ) + experts[i].ep_group_id = ep_group_id + + return gate, experts, lm_gate, lm_experts + + +def _parse_moe_group(moe_group: str): + moe_group = moe_group.lower() + assert moe_group in { + "dp", + "mp", + "none", + }, f"moe-group not supported, got: {moe_group}" + logger.info(f"using moe-group: {moe_group}") + + return moe_group + + +class RMSNorm(nn.Layer): + """ + RMSNorm is a variant of layer normalization. 
+ """ + + def __init__(self, config, ipp=0): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + def forward(self, hidden_states): + + if self.config.fuse_rms_norm: + return fused.fused_rms_norm( + hidden_states, self.weight, self.variance_epsilon + )[0] + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = ( + paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + ) + else: + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = ( + paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + ) + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + +class LayerNorm(nn.LayerNorm): + """ + layer normalization. + """ + + def __init__(self, config, ipp=0): + super().__init__(config.hidden_size, epsilon=config.rms_norm_eps) + + self.use_fast_ln = config.use_fast_ln + if self.use_fast_ln: + assert fast_ln is not None + self.ipp = ipp + if config.pipeline_parallel_degree > 1: + self.weight = dist.shard_tensor( + self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()] + ) + self.bias = dist.shard_tensor( + self.bias, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()] + ) + + def forward(self, hidden_states): + """ + The layer normalization operator. + """ + if self.use_fast_ln: + return fast_ln(hidden_states, self.weight, self.bias, self._epsilon)[0] + else: + return super().forward(hidden_states) + + +class FusedLayerNorm(nn.Layer): + """ + FusedLayerNorm is a variant of layer normalization. 
+ """ + + def __init__(self, config, ipp=0): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.bias = paddle.create_parameter( + shape=[self.hidden_size], dtype=paddle.get_default_dtype(), is_bias=True + ) + self.variance_epsilon = config.rms_norm_eps + self.ipp = ipp + if config.pipeline_parallel_degree > 1: + self.weight = dist.shard_tensor( + self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()] + ) + self.bias = dist.shard_tensor( + self.bias, get_mesh(self.ipp), [dist.Replicate(), dist.Replicate()] + ) + + def forward(self, hidden_states): + + return fused.fused_ln( + hidden_states, self.weight, self.bias, self.variance_epsilon + )[0] + + +class RotaryEmbedding(nn.Layer): + r""" + RotaryEmbedding Layer + """ + + def __init__(self, dim, max_position_embeddings=4096, base=10000): + + super().__init__() + # dtype = paddle.get_default_dtype() + self.base = base + self.max_position_embeddings = max_position_embeddings + inv_freq = 1.0 / ( + base ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / dim) + ) + + # self.register_buffer("inv_freq", inv_freq.cast(dtype)) + + # higher acc using float32 + t = paddle.arange(max_position_embeddings, dtype="float32") + freqs = paddle.einsum("i,j->ij", t, inv_freq.cast("float32")) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = paddle.concat([freqs, freqs], axis=-1) + + # [bs, seqlen, nhead, head_dim] + self.cos_cached = emb.cos() # [None, :, None, :] # .astype(dtype) + self.sin_cached = emb.sin() # [None, :, None, :] # .astype(dtype) + + self._cast_to_low_precision = False # 兼容develop分支paddle + self._cast_to_low_precison = False + + def forward(self, x, seq_len=None): + + return ( + self.cos_cached[:seq_len, :], + self.sin_cached[:seq_len, 
            :],
        )

    @classmethod
    def rotate_half(cls, x):
        """Rotates half the hidden dims of the input."""

        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return paddle.concat([-x2, x1], axis=-1)

    @classmethod
    def apply_rotary_pos_emb(cls, q, k, cos, sin, offset: int = 0, position_ids=None):
        """Apply rotary position embedding to query and key states.

        Either gathers cos/sin rows by explicit position_ids (offset must then
        be 0) or slices the tables starting at `offset`.
        """
        if position_ids is not None:
            # logger.info(f'applying pos:{position_ids}')
            assert offset == 0, offset
            cos = F.embedding(position_ids, cos)
            sin = F.embedding(position_ids, sin)
        else:
            cos = cos.unsqueeze(0)
            sin = sin.unsqueeze(0)
        cos = cos[:, offset : q.shape[1] + offset, None, :]
        sin = sin[:, offset : q.shape[1] + offset, None, :]

        q_embed = paddle.add(
            paddle.multiply(q, cos), paddle.multiply(cls.rotate_half(q), sin)
        )
        k_embed = paddle.add(
            paddle.multiply(k, cos), paddle.multiply(cls.rotate_half(k), sin)
        )
        q_embed = q_embed.astype(q.dtype)  # fp32->bf16
        k_embed = k_embed.astype(k.dtype)
        return q_embed, k_embed


class RopeEmbeddingLegacy(nn.Layer):
    """Legacy rotary embedding using interleaved (even/odd) pairing."""

    def __init__(self, head_dim, compression_ratio=1.0, base=10000):
        super().__init__()
        self.head_dim = head_dim
        # compression_ratio rescales positions (NTK-style context compression)
        self.compression_ratio = compression_ratio
        self.base = base

    def forward(self, seq_length, position_ids=None):
        """Return [b, 1, s, head_dim] sin/cos table for the given positions."""
        indices = paddle.arange(0, self.head_dim, 2, dtype="float32")
        indices = 1 / self.base ** (indices / self.head_dim)
        if position_ids is None:
            position_ids = paddle.arange(0, seq_length, 1, dtype="float32").unsqueeze(1)
            position_ids = position_ids / self.compression_ratio
            sinusoid_inp = position_ids * indices.unsqueeze(0)
        else:
            position_ids = position_ids / self.compression_ratio
            seq_length = position_ids.shape[-1]
            sinusoid_inp = position_ids.unsqueeze(-1).astype(
                "float32"
            ) * indices.unsqueeze(
                0
            )  # [b, s, 1] * [1, d/2] -> [b, s, d/2]
        pos_emb = paddle.concat(
            [paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1
        )
        pos_emb = paddle.reshape(pos_emb, (-1, 1, seq_length, self.head_dim))
        pos_emb.stop_gradient = True
        return pos_emb

    def apply_rotary(self, rp, q, k):
        """Rotate q/k with the table produced by forward()."""
        # sin [sequence_length, embed_size_per_head//2]
        # cos [sequence_length, embed_size_per_head//2]
        sin, cos = paddle.chunk(rp, 2, axis=-1)
        # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
        sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), rp.shape)
        # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
        cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), rp.shape)
        # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
        rotate_half_q = paddle.reshape(
            paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1),
            paddle.shape(q),
        )
        query = paddle.add(
            paddle.multiply(q.astype("float32"), cos_pos),
            paddle.multiply(rotate_half_q.astype("float32"), sin_pos),
        )
        # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
        rotate_half_k = paddle.reshape(
            paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1),
            paddle.shape(k),
        )
        key = paddle.add(
            paddle.multiply(k.astype("float32"), cos_pos),
            paddle.multiply(rotate_half_k.astype("float32"), sin_pos),
        )
        return query, key

    def forward_single(self, position_ids):
        """Return a stacked [2, b, s, 1, head_dim] (cos, sin) table."""
        batch_size, seq_length = position_ids.shape[:2]
        rope_emb = paddle.zeros(
            (2, batch_size, seq_length, 1, self.head_dim), dtype="float32"
        )
        inv_freq = self.base ** (
            -paddle.arange(0, self.head_dim, 2, dtype="float32") / self.head_dim
        )
        position_ids = position_ids.cast("float32")
        position_ids = position_ids / self.compression_ratio
        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
        # shape: [B, S, D]
        emb = paddle.stack([freqs, freqs], axis=-1).reshape(
            (batch_size, seq_length, self.head_dim)
        )
        # shape: [B, S, 1, D]
        emb = paddle.unsqueeze(emb, 2)

        rope_emb[0] = paddle.cos(emb)
        rope_emb[1] = paddle.sin(emb)
        return rope_emb

    @staticmethod
    def apply_rotary_single(x, rope_emb):
        """Rotate a single tensor with a table from forward_single()."""
        rotate_half_x = paddle.reshape(
            paddle.stack([-x[:, :, :, 1::2], x[:, :, :, 0::2]], axis=-1),
            paddle.shape(x),
        )
        return x * rope_emb[0] + rotate_half_x * rope_emb[1]


class ErnieLinear(nn.Layer):
    """Linear layer whose output is resharded for sequence parallelism.

    The bias is added *after* the reshard so it is applied on the
    sequence-sharded output.
    """

    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
        ipp=0,
    ):
        super(ErnieLinear, self).__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self.bias = self.create_parameter(
            shape=[out_features],
            attr=self._bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )
        self.name = name
        self.ipp = ipp

    def forward(self, input):
        """Apply the matmul, reshard to [Shard(1), Shard(0)], then add bias."""
        out = F.linear(x=input, weight=self.weight, bias=None, name=self.name)
        out = dist.reshard(
            out,
            get_mesh(self.ipp),
            [dist.Shard(1), dist.Shard(0)],
        )
        if self.bias:
            out += self.bias
        return out


class ErnieMLP(nn.Layer):
    """SwiGLU feed-forward block with optional auto-parallel sharding."""

    def __init__(self, config, ipp=None, do_shard_tensor=True):
        super().__init__()
        self.config = config
        self.ipp = ipp
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        LinearFN = nn.Linear
        self.gate_proj = LinearFN(
            self.hidden_size, self.intermediate_size, bias_attr=config.use_bias
        )
        self.up_proj = LinearFN(
            self.hidden_size, self.intermediate_size, bias_attr=config.use_bias
        )

        if config.sequence_parallel:
            # the down projection reshards its output for sequence parallelism
            self.down_proj = ErnieLinear(
                self.intermediate_size,
                self.hidden_size,
                bias_attr=config.use_bias,
                ipp=self.ipp,
            )
        else:
            self.down_proj = LinearFN(
                self.intermediate_size, self.hidden_size, bias_attr=config.use_bias
            )

        if do_shard_tensor and (
            self.config.tensor_parallel_degree > 1
            or self.config.pipeline_parallel_degree > 1
        ):
            # column-shard the up/gate projections, row-shard the down projection
            self.gate_proj.weight = dist.shard_tensor(
                self.gate_proj.weight,
                get_mesh(self.ipp),
                [dist.Replicate(), dist.Shard(1)],
            )
            self.up_proj.weight = dist.shard_tensor(
                self.up_proj.weight,
                get_mesh(self.ipp),
                [dist.Replicate(), dist.Shard(1)],
            )
            if config.use_bias:
                self.gate_proj.bias = dist.shard_tensor(
                    self.gate_proj.bias,
                    get_mesh(self.ipp),
                    [dist.Replicate(), dist.Shard(0)],
                )
                self.up_proj.bias = dist.shard_tensor(
                    self.up_proj.bias,
                    get_mesh(self.ipp),
                    [dist.Replicate(), dist.Shard(0)],
                )
            self.down_proj.weight = dist.shard_tensor(
                self.down_proj.weight,
                get_mesh(self.ipp),
                [dist.Replicate(), dist.Shard(0)],
            )
            if config.use_bias:
                # the row-sharded down projection keeps a replicated bias
                self.down_proj.bias = dist.shard_tensor(
                    self.down_proj.bias,
                    get_mesh(self.ipp),
                    [dist.Replicate(), dist.Replicate()],
                )

        self.fuse_swiglu = config.fuse_swiglu
        if self.fuse_swiglu:
            assert fused_swiglu is not None, "fused_swiglu operator is not found."

    def forward(self, x):
        """SwiGLU feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x))."""
        if self.fuse_swiglu:
            x = fused_swiglu(self.gate_proj(x), self.up_proj(x))
        else:
            x = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(x)


class ErnieAttentionAuto(nn.Layer):
    """Multi-head self-attention (optionally GQA) with rotary embeddings and
    auto-parallel tensor sharding."""

    def __init__(self, config, ipp: Optional[int] = None):
        super().__init__()
        self.ipp = ipp
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.use_recompute_attn = config.use_recompute_attn  # aka recompute core-attn
        self.is_gqa = (
            config.num_key_value_heads is not None
            and config.num_key_value_heads != self.num_heads
        )
        if config.fuse_rope:
            assert fused_rope is not None, "fused_rope is not supported"
        self.fuse_rope = config.fuse_rope

        if self.is_gqa:
            logger.info(
                f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}"
            )
            assert (
                self.num_heads % self.num_key_value_heads == 0
            ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}"
            kv_hidden_size = (
                self.hidden_size // self.num_heads * self.num_key_value_heads
            )

        LinearFN = nn.Linear
        self.q_proj = LinearFN(
            self.hidden_size,
            self.hidden_size,
            bias_attr=config.use_bias,
        )
        # K/V projections shrink to kv_hidden_size under GQA
        self.k_proj = LinearFN(
            self.hidden_size,
            self.hidden_size if not self.is_gqa else kv_hidden_size,
            bias_attr=config.use_bias,
        )
        self.v_proj = LinearFN(
            self.hidden_size,
            self.hidden_size if not self.is_gqa else kv_hidden_size,
            bias_attr=config.use_bias,
        )

        if config.sequence_parallel:
            # output projection reshards for sequence parallelism
            self.o_proj = ErnieLinear(
                self.hidden_size,
                self.hidden_size,
                bias_attr=config.use_bias,
                ipp=self.ipp,
            )
        else:
            self.o_proj = LinearFN(
                self.hidden_size,
                self.hidden_size,
                bias_attr=config.use_bias,
            )

        self.config = config

        if (
            self.config.tensor_parallel_degree > 1
            or self.config.pipeline_parallel_degree > 1
        ):
            # column-shard QKV, row-shard the output projection
            self.q_proj.weight = dist.shard_tensor(
                self.q_proj.weight,
                get_mesh(self.ipp),
                [dist.Replicate(), dist.Shard(1)],
            )
            self.k_proj.weight = dist.shard_tensor(
                self.k_proj.weight,
                get_mesh(self.ipp),
                [dist.Replicate(), dist.Shard(1)],
            )
            self.v_proj.weight = dist.shard_tensor(
                self.v_proj.weight,
                get_mesh(self.ipp),
                [dist.Replicate(), dist.Shard(1)],
            )
            if config.use_bias:
                self.q_proj.bias = dist.shard_tensor(
                    self.q_proj.bias,
                    get_mesh(self.ipp),
                    [dist.Replicate(), dist.Shard(0)],
                )
                self.k_proj.bias = dist.shard_tensor(
                    self.k_proj.bias,
                    get_mesh(self.ipp),
                    [dist.Replicate(), dist.Shard(0)],
                )
                self.v_proj.bias = dist.shard_tensor(
                    self.v_proj.bias,
                    get_mesh(self.ipp),
                    [dist.Replicate(), dist.Shard(0)],
                )
            self.o_proj.weight = dist.shard_tensor(
                self.o_proj.weight,
                get_mesh(self.ipp),
                [dist.Replicate(), dist.Shard(0)],
            )

    def forward(
        self,
        hidden_states,
        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        position_ids: Optional[Tuple[paddle.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        inbatch_pack_offset: Optional[Tuple[paddle.Tensor]] = None,
    ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        if self.config.sequence_parallel:
            # do all-gather
            hidden_states = dist.reshard(
                hidden_states, get_mesh(self.ipp), [dist.Shard(1), dist.Replicate()]
            )

        query_states = (
            self.q_proj(hidden_states).reshape(
                shape=[0, 0, self.num_heads, self.head_dim]
            )
            # .transpose([0, 2, 1, 3])
        )
        key_states = (
            self.k_proj(hidden_states).reshape(
                shape=[
                    0,
                    0,
                    self.num_key_value_heads if self.is_gqa else self.num_heads,
                    self.head_dim,
                ]
            )
            # .transpose([0, 2, 1, 3])
        )
        value_states = (
            self.v_proj(hidden_states).reshape(
                shape=[
                    0,
                    0,
                    self.num_key_value_heads if self.is_gqa else self.num_heads,
                    self.head_dim,
                ]
            )
            # .transpose([0, 2, 1, 3])
        )

        if self.config.sequence_parallel:
            # sequence-parallel activations are seq-major; restore batch-first
            query_states = paddle.transpose(query_states, [1, 0, 2, 3])
            key_states = paddle.transpose(key_states, [1, 0, 2, 3])
            value_states = paddle.transpose(value_states, [1, 0, 2, 3])

        if self.use_recompute_attn:
            assert past_key_value is None, "do not use kv cache in recompute"
            assert not use_cache
            attn_output, attn_weights, past_key_value = recompute(
                self.rope_attn,
                None,
                query_states,
                key_states,
                value_states,
                attention_mask,
                position_ids,
                output_attentions,
                past_key_value,
                use_cache,
                inbatch_pack_offset,
                use_reentrant=False,
            )
        else:
            attn_output, attn_weights, past_key_value = self.rope_attn(
                mix_layer=None,
                query_states=query_states,
                key_states=key_states,
                value_states=value_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                past_key_value=past_key_value,
                use_cache=use_cache,
                inbatch_pack_offset=inbatch_pack_offset,
            )

        if self.config.sequence_parallel:
            attn_output = paddle.transpose(attn_output, [1, 0, 2])

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def rope_attn(
        self,
        mix_layer,
        query_states,
        key_states,
        value_states,
        attention_mask,
        position_ids,
        output_attentions=False,
        past_key_value=None,
        use_cache=False,
        inbatch_pack_offset=None,
    ):
        """Apply rotary embeddings to Q/K, update the KV cache, run attention."""
        if mix_layer is not None:
            # fused QKV projection: split into the three states
            query_states, key_states, value_states = paddle.split(mix_layer, 3, axis=-1)
        query_states_dtype = query_states.dtype

        kv_seq_len = key_states.shape[-3]
        offset = 0
        if past_key_value is not None:
            # rotary positions continue after the cached prefix
            offset = past_key_value[0].shape[-3]
            kv_seq_len += offset

        if self.config.rope_reorder:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

            query_states, key_states = self.rotary_emb.apply_rotary_pos_emb(
                query_states,
                key_states,
                cos,
                sin,
                position_ids=position_ids,
                offset=offset if position_ids is None else 0,
            )
        else:
            if offset > 0 or position_ids is not None or not self.fuse_rope:
                cos_sin = self.rotary_emb(kv_seq_len, position_ids).transpose(
                    [0, 2, 1, 3]
                )  # [b,h,s,d]->[b,s,h,d]
                if offset > 0 and position_ids is None:
                    # position_ids has been sliced in prepare_inputs_for_generation
                    cos_sin = cos_sin[:, offset:]
                query_states, key_states = self.rotary_emb.apply_rotary(
                    cos_sin, query_states, key_states
                )
            else:
                bsz, q_len, num_heads, head_dim = query_states.shape
                _, kv_seq_len, num_key_value_heads, _ = key_states.shape
                if num_heads != num_key_value_heads:
                    # GQA: the fused kernel cannot rotate q and k jointly
                    query_states, _, _ = fused_rope(query_states, None, None)
                    key_states, _, _ = fused_rope(key_states, None, None)
                else:
                    query_states, key_states, _ = fused_rope(
                        query_states, key_states, None
                    )

        if use_cache:
            query_states = query_states.astype(query_states_dtype)
            key_states = key_states.astype(query_states_dtype)
            if past_key_value is not None:
                # append the new K/V to the cached prefix along the seq axis
                key_states = paddle.concat([past_key_value[0], key_states], axis=1)
                value_states = paddle.concat([past_key_value[1], value_states], axis=1)

        past_key_value = [key_states, value_states] if use_cache else None

        attn_output, attn_weights = scaled_dot_product_attention(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            config=self.config,
            inbatch_pack_offset=inbatch_pack_offset,
            training=self.training,
        )
        return attn_output, attn_weights, past_key_value


class ErnieMoeMLP(ErnieMLP):
    """SwiGLU expert MLP used inside a MoE layer, with expert-parallel
    redistribution support."""

    def __init__(self, config, ipp=0):
        """Build one expert.

        When `config.disable_ffn_model_parallel` is set, the expert is built on
        a copy of the config with TP/sequence-parallel disabled so its weights
        stay unsharded.
        """
        disable_ffn_model_parallel = getattr(
            config, "disable_ffn_model_parallel", False
        )
        if disable_ffn_model_parallel:
            # assert config.moe_group == "mp", f"when using mp_moe, expect moe-group == mp, but get {config.moe_group}"
            config = deepcopy(config)
            config.tensor_parallel_degree = 1
            config.sequence_parallel = False

        super().__init__(config, ipp, do_shard_tensor=not disable_ffn_model_parallel)
        self.moe_dropout_prob = config.moe_dropout_prob
        self.fuse_swiglu = config.fuse_swiglu
        if self.fuse_swiglu:
            assert fused_swiglu is not None, "fused_swiglu operator is not found."

    def redistribute_expert(self, mesh, placements):
        """
        Place the experts on different devices.
        """
        self.gate_proj.weight = dist.shard_tensor(
            self.gate_proj.weight, mesh, placements
        )
        # self.gate_proj.bias = dist.shard_tensor(self.gate_proj.bias, mesh, placements)
        self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements)
        # self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements)
        self.down_proj.weight = dist.shard_tensor(
            self.down_proj.weight, mesh, placements
        )
        if self.config.use_bias:
            self.gate_proj.bias = dist.shard_tensor(
                self.gate_proj.bias, mesh, placements
            )
            self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements)
            self.down_proj.bias = dist.shard_tensor(
                self.down_proj.bias, mesh, placements
            )

    def forward(self, x):
        """SwiGLU feed-forward with optional expert dropout before down_proj."""
        if self.fuse_swiglu:
            x = fused_swiglu(self.gate_proj(x), self.up_proj(x))
        else:
            x = F.silu(self.gate_proj(x)) * self.up_proj(x)
        if self.moe_dropout_prob > 0:
            # use the tracked local seed so dropout agrees across ranks
            with get_rng_state_tracker().rng_state("local_seed"):
                x = F.dropout(x=x, p=self.moe_dropout_prob)
        ret = self.down_proj(x)
        return ret


class BMMLinear(nn.Layer):
    """Batched linear layer: one weight matrix per expert, applied via bmm."""

    def __init__(self, experts, d_in, d_out, use_bias=False):
        super().__init__()
        self.weight = self.create_parameter(
            [experts, d_in, d_out], dtype=paddle.get_default_dtype()
        )
        if use_bias:
            self.bias = self.create_parameter(
                [experts, d_out], dtype=paddle.get_default_dtype(), is_bias=True
            )
        else:
            self.bias = None

    def forward(self, x):
        """x: [num_experts, Seq, dim]"""
        if self.bias is not None:
            return paddle.bmm(x, self.weight) + self.bias
        return paddle.bmm(x, self.weight)


class ErnieMoeMLPFused(nn.Layer):
    """Fused Implement of ErnieMoeMLP"""

    def __init__(self, config):
        # fused experts only work without tensor parallelism on the FFN
        assert (
            hasattr(config, "disable_ffn_model_parallel")
            or config.tensor_parallel_degree == 1
        ), f"fused mlp only suport mp-moe, mp={config.tensor_parallel_degree}"
        assert config.fuse_attn_ffn, "fused mlp only support fuse_attn_ffn"
        super().__init__()
        self.moe_dropout_prob = config.moe_dropout_prob
        self.num_local_experts = config.moe_num_experts // config.moe_world_size
        logger.info(
            f"fused-expert-weight-shape: {[self.num_local_experts, config.hidden_size, config.intermediate_size]}"
        )

        # up and gate projections are concatenated into one batched matmul
        self.up_gate_proj = BMMLinear(
            self.num_local_experts, config.hidden_size, config.intermediate_size * 2
        )
        self.down_proj = BMMLinear(
            self.num_local_experts, config.intermediate_size, config.hidden_size
        )
        self.fuse_swiglu = config.fuse_swiglu
        if self.fuse_swiglu:
            assert fused_swiglu is not None, "fused_swiglu operator is not found."

    def __len__(self):
        return self.num_local_experts

    def __iter__(self):
        # behaves as a collection containing itself once, so get_gate treats
        # this fused layer as a pre-built expert group
        return (self for _ in range(1))

    def forward(self, x):
        """x"""
        if self.fuse_swiglu:
            x = fused_swiglu(self.up_gate_proj(x))
        else:
            gate, x = self.up_gate_proj(x).chunk(2, axis=-1)
            x = F.silu(gate) * x
        x = self.down_proj(x)
        return x


class ErnieDecoderLayerAuto(nn.Layer):
    """
    ErnieDecoderLayerAuto is a decoder layer in Ernie model.
    It is composed of self-attention, cross-attention and feedforward layers.
    """

    def __init__(self, config, layer_idx=0, ipp=0):
        """
        Initializes the ErnieBlock module.

        Args:
            config (ErnieConfig): The model configuration.
            layer_idx (int, optional): The index of this block in the model. Defaults to 0.
            ipp (int, optional): The index of this block in the pipeline parallelism. Defaults to 0.
+ """ + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.ipp = ipp + self.hidden_size = config.hidden_size + self.self_attn = ErnieAttentionAuto(config, ipp) + self.use_moe = config.use_moe if hasattr(config, "use_moe") else False + if self.use_moe: + moe_layer_start_index = ( + min(config.moe_layer_start_index) + if isinstance(config.moe_layer_start_index, (tuple, list)) + else config.moe_layer_start_index + ) + moe_layer_end_index = ( + max(config.moe_layer_end_index) + if isinstance(config.moe_layer_end_index, (tuple, list)) + else config.moe_layer_end_index + ) + + if ( + self.use_moe + and ((layer_idx + 1) % config.moe_layer_interval == 0) + and layer_idx >= moe_layer_start_index + and layer_idx <= moe_layer_end_index + ): + self.create_moe_mlp_layer(layer_idx, ipp) + else: + self.mlp = ErnieMLP(config, ipp) + Norm = RMSNorm if config.use_rmsnorm else LayerNorm + if not config.use_rmsnorm and config.fuse_ln: + Norm = FusedLayerNorm + self.input_layernorm = Norm(config, ipp) + self.post_attention_layernorm = Norm(config, ipp) + self.residual_add1 = FusedDropoutImpl( + config.hidden_dropout_prob, mode="upscale_in_train" + ) + self.residual_add2 = FusedDropoutImpl( + config.hidden_dropout_prob, mode="upscale_in_train" + ) + + def create_moe_mlp_layer(self, layer_idx, ipp): + _ex_cfg = deepcopy(self.config) + fc_cls = ErnieMoeMLPFused if _ex_cfg.moe_fuse_experts else ErnieMoeMLP + if _ex_cfg.moe_intermediate_size: + if isinstance(_ex_cfg.moe_intermediate_size, (tuple, list)): + assert isinstance(_ex_cfg.moe_num_experts, (tuple, list)) and len( + _ex_cfg.moe_num_experts + ) == len(_ex_cfg.moe_intermediate_size) + fc = [] + for _i, (num_experts, intermediate_size) in enumerate( + zip(_ex_cfg.moe_num_experts, _ex_cfg.moe_intermediate_size) + ): + _ex_cfg_real = deepcopy(_ex_cfg) + _ex_cfg_real.intermediate_size = intermediate_size + cur_modality_start_layer_idx = ( + self.config.moe_layer_start_index[_i] + if 
isinstance(self.config.moe_layer_start_index, (tuple, list)) + else self.config.moe_layer_start_index + ) + cur_modality_end_layer_idx = ( + self.config.moe_layer_end_index[_i] + if isinstance(self.config.moe_layer_end_index, (tuple, list)) + else self.config.moe_layer_end_index + ) + if ( + layer_idx >= cur_modality_start_layer_idx + and layer_idx <= cur_modality_end_layer_idx + ): + if _i == 1: + with paddle.utils.unique_name.guard( + f"mm_expert_{layer_idx}_" + ): + fc.append((num_experts, fc_cls(_ex_cfg_real))) + else: + fc.append((num_experts, fc_cls(_ex_cfg_real))) + else: + logger.info( + f"moe multimodal experts use Identity layer_idx: {layer_idx}" + ) + fc.append((num_experts, nn.Identity())) + else: + _ex_cfg.intermediate_size = _ex_cfg.moe_intermediate_size + fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))] + else: + fc = [(_ex_cfg.moe_num_experts, fc_cls(_ex_cfg))] + gate, experts, lm_gate, lm_experts = get_gate( + self.config, fc, layer_idx, self.ipp + ) + _sh_cfg = deepcopy(self.config) + + if _sh_cfg.moe_num_shared_experts > 0: + if _sh_cfg.moe_intermediate_size: + _sh_inter_size = ( + _sh_cfg.moe_intermediate_size[0] + if isinstance(_sh_cfg.moe_intermediate_size, (tuple, list)) + else _sh_cfg.moe_intermediate_size + ) + _sh_cfg.intermediate_size = ( + _sh_inter_size * _sh_cfg.moe_num_shared_experts + ) + else: + _sh_cfg.intermediate_size = ( + _sh_cfg.intermediate_size * _sh_cfg.moe_num_shared_experts + ) + _sh_cfg.disable_ffn_model_parallel = False # split shared epxert + shared_experts = ErnieMoeMLP(_sh_cfg, ipp) + else: + shared_experts = None + + is_moe_infer = self.config.get("is_moe_infer", False) + if is_moe_infer: + raise NotImplementedError + elif self.config.moe_use_size_all2all: + raise NotImplementedError + else: + logger.info(f"moe-logging:{self.config.moe_logging}") + moe_cls = MOELayerAuto + self.mlp = moe_cls( + gate, + experts, + layer_idx=layer_idx, + shared_experts=shared_experts, + group=self.config.moe_group, + 
recompute=self.config.use_recompute_moe, + enable_logging=self.config.moe_logging, + k=self.config.moe_k, + enable_pbr=self.config.moe_use_bpr, + all_to_all_dropout=self.config.moe_all_to_all_dropout, + group_experts=self.config.moe_group_experts, + config=self.config, + ipp=self.ipp, + ) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + inbatch_pack_offset: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + output_gate_logits=True, # PP model should not output gate logits, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `cache` key value states are returned and can be used to speed up decoding + (see `cache`). 
+ cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + (hidden_states, self_attn_weights, present_key_value, *router_loss_attn) = ( + self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + use_cache=use_cache, + inbatch_pack_offset=inbatch_pack_offset, + ) + ) + + if ( + self.config.tensor_parallel_degree > 1 + and self.config.hidden_dropout_prob > 0.0 + ): + current_seed = ( + "local_seed" if self.config.sequence_parallel else "global_seed" + ) + with get_rng_state_tracker().rng_state(current_seed): + hidden_states = self.residual_add1(hidden_states, residual) + else: + hidden_states = self.residual_add1(hidden_states, residual) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if isinstance( + self.mlp, + (MOELayerAuto), + ): + + hidden_states, _, router_loss, gate_logits = self.mlp( + hidden_states, token_type_ids + ) + else: + if self.config.sequence_parallel: + # do all-gather + hidden_states = dist.reshard( + hidden_states, + get_mesh(self.ipp), + [dist.Shard(1), dist.Replicate()], + ) + hidden_states = self.mlp(hidden_states) + gate_logits = None + + if ( + self.config.tensor_parallel_degree > 1 + and self.config.hidden_dropout_prob > 0.0 + ): + current_seed = ( + "local_seed" if self.config.sequence_parallel else "global_seed" + ) + with get_rng_state_tracker().rng_state(current_seed): + hidden_states = self.residual_add2(hidden_states, residual) + else: + hidden_states = self.residual_add2(hidden_states, residual) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if hasattr(self.config, "use_moe") and self.config.use_moe: + if 
            if router_loss_attn:
                router_loss_attn = router_loss_attn[0]
                # NOTE(review): in the dense-MLP path `router_loss` is only
                # defined if the attention returned one — assumed to hold when
                # use_moe is on; confirm against ErnieAttentionAuto.
                router_loss = router_loss + router_loss_attn

            if isinstance(self.mlp, (MOELayerAuto)):
                outputs += (router_loss,)
            else:
                # Dense layer under a MoE model: emit a zero placeholder so the
                # output tuple has a uniform shape across layers.
                outputs += (paddle.zeros([1], dtype=paddle.float32),)

        if output_gate_logits:
            outputs += (gate_logits,)

        # remove empty tuple for pipeline parallel
        if type(outputs) is tuple and len(outputs) == 1:
            outputs = outputs[0]
        return outputs


class ErniePretrainedModelAuto(PretrainedModel):
    """
    ErniePretrainedModelAuto is a pretrained model class for Ernie model.
    It is composed of a encoder and a decoder.
    """

    config_class = ErnieMoEConfig
    base_model_prefix = "ernie"

    @classmethod
    def _get_name_mappings(cls, config: ErnieMoEConfig) -> "list[StateDictNameMapping]":
        """Build HF-style -> Paddle state-dict name mappings (with transposes
        for linear weights) for checkpoint conversion."""
        mappings: "list[StateDictNameMapping]" = []
        model_mappings = [
            ["embed_tokens.weight"],
            ["norm.weight"],
        ]
        for layer_index in range(
            config.num_hidden_layers
            if not config.remove_tail_layer
            else config.num_hidden_layers - 1
        ):
            layer_mappings = [
                [
                    f"layers.{layer_index}.self_attn.q_proj.weight",
                    None,
                    "transpose",
                ],
                [
                    f"layers.{layer_index}.self_attn.k_proj.weight",
                    None,
                    "transpose",
                ],
                [
                    f"layers.{layer_index}.self_attn.v_proj.weight",
                    None,
                    "transpose",
                ],
                [
                    f"layers.{layer_index}.self_attn.o_proj.weight",
                    None,
                    "transpose",
                ],
                [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
                [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
                [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"],
                [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"],
                [f"layers.{layer_index}.input_layernorm.weight"],
                [f"layers.{layer_index}.post_attention_layernorm.weight"],
            ]
            model_mappings.extend(layer_mappings)

        init_name_mappings(mappings=model_mappings)
        if "ErnieModelAuto" not in config.architectures:
            # Wrapping architecture (e.g. causal-LM head): prefix both sides.
            for mapping in model_mappings:
                mapping[0] = "model." + mapping[0]
                mapping[1] = "ernie." + mapping[1]
            model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])

        mappings = [
            StateDictNameMapping(*mapping, index=index)
            for index, mapping in enumerate(model_mappings)
        ]
        return mappings

    @classmethod
    def _get_tensor_parallel_mappings(cls, config, is_split=True):
        """Return per-parameter split/merge actions for tensor parallelism.

        Column-parallel weights are split on the output dim, row-parallel on
        the input dim; the same action is replicated for every layer index.
        """
        from paddleformers.transformers.conversion_utils import split_or_merge_func

        fn = split_or_merge_func(
            is_split=is_split,
            tensor_parallel_degree=config.tensor_parallel_degree,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
        )

        def get_tensor_parallel_split_mappings(num_layers):
            final_actions = {}
            base_actions = {
                # Column Linear
                "layers.0.self_attn.q_proj.weight": partial(fn, is_column=True),
                "layers.0.self_attn.k_proj.weight": partial(fn, is_column=True),
                "layers.0.self_attn.v_proj.weight": partial(fn, is_column=True),
                "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True),
                "layers.0.mlp.up_proj.weight": partial(fn, is_column=True),
                "lm_head.weight": partial(fn, is_column=not config.tie_word_embeddings),
                # Row Linear
                "embed_tokens.weight": partial(fn, is_column=False),
                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
                "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
            }
            if config.use_bias:
                base_actions.update(
                    {
                        # Column Linear
                        "layers.0.self_attn.q_proj.bias": partial(fn, is_column=True),
                        "layers.0.self_attn.k_proj.bias": partial(fn, is_column=True),
                        "layers.0.self_attn.v_proj.bias": partial(fn, is_column=True),
                        "layers.0.mlp.gate_proj.bias": partial(fn, is_column=True),
                        "layers.0.mlp.up_proj.bias": partial(fn, is_column=True),
                        "lm_head.bias": partial(fn, is_column=True),
                    }
                )
            # Expand the layer-0 templates to every layer index.
            for key, action in base_actions.items():
                if "layers.0." in key:
                    for i in range(num_layers):
                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
                final_actions[key] = action

            return final_actions

        mappings = get_tensor_parallel_split_mappings(
            config.num_hidden_layers
            if not config.remove_tail_layer
            else config.num_hidden_layers - 1
        )

        return mappings

    def init_weights(self, layer):
        """Initialization hook"""
        # Under tensor parallelism, draw init noise from the tracked RNG so
        # replicated parameters are identical across ranks.
        if self.config.tensor_parallel_degree > 1:
            rng_tracker = get_rng_state_tracker().rng_state
        else:
            rng_tracker = contextlib.nullcontext

        if isinstance(
            layer,
            (
                ErnieLMHead,
                nn.Embedding,
                nn.Linear,
                paddle.incubate.nn.FusedLinear,
            ),
        ):

            with rng_tracker():
                dtype = paddle.get_default_dtype()
                paddle.set_default_dtype("float32")
                if layer.weight._is_initialized():
                    if layer.weight.is_dist():
                        # Initialize only the local shard of a dist tensor.
                        layer.weight._local_value().set_value(
                            paddle.randn(
                                layer.weight._local_shape, dtype=layer.weight.dtype
                            ).scale(self.config.initializer_range)
                        )
                    else:
                        layer.weight.set_value(
                            paddle.randn(
                                layer.weight.shape, dtype=layer.weight.dtype
                            ).scale(self.config.initializer_range)
                        )
                paddle.set_default_dtype(dtype)
                logger.info(
                    f"dist-init-fc: shape={layer.weight.shape}, "
                    f" range={self.config.initializer_range},"
                    f' type={type(layer)},norm={layer.weight.astype("float32").norm()}'
                )

        elif isinstance(layer, RotaryEmbedding):
            head_dim = self.config.hidden_size // self.config.num_attention_heads
            inv_freq = 1.0 / (
                layer.base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)
            )
            # self.register_buffer("inv_freq", inv_freq.cast(dtype))

            # higher acc using float32
            t = np.arange(layer.max_position_embeddings, dtype="float32")
            freqs = np.einsum("i,j->ij", t, inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = np.concatenate([freqs, freqs], axis=-1)
            # [bs, seqlen, nhead, head_dim]
            cos_cached = np.cos(emb)[:, :]
            sin_cached = np.sin(emb)[:, :]
            layer.cos_cached.set_value(cos_cached)
            layer.sin_cached.set_value(sin_cached)
        elif isinstance(layer, Top2Gate):
            if not hasattr(layer, "weight"):
                return
            with rng_tracker("model_parallel_rng"):
                dtype = paddle.get_default_dtype()
                paddle.set_default_dtype("float32")
                if self.config.moe_group_experts:
                    if layer.weight._is_initialized():
                        layer.weight.set_value(
                            paddle.randn(
                                layer.weight.shape, dtype=layer.weight.dtype
                            ).scale(self.config.initializer_range)
                        )
                else:
                    if layer.weight._is_initialized():
                        # granularity: how many fused sub-experts share one
                        # gate column; repeat columns to match expert count.
                        granularity = (
                            1
                            if self.config.moe_intermediate_size == 0
                            else self.config.intermediate_size
                            // self.config.moe_intermediate_size
                        )
                        layer.weight.set_value(
                            paddle.randn(
                                [
                                    self.config.hidden_size,
                                    self.config.moe_num_experts // granularity,
                                ],
                                dtype="float32",
                            )
                            .scale(self.config.initializer_range)
                            .repeat_interleave(granularity, axis=-1)
                        )
                logger.info(
                    f"dist-init-moe_gate: shape={layer.weight.shape}, dtype={layer.weight.dtype} "
                    f"range={self.config.initializer_range},type={type(layer)}, "
                    f'norm={layer.weight.astype("float32").norm()}'
                )


@register_base_model
class ErnieModelAuto(ErniePretrainedModelAuto):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`ErnieDecoderLayerAuto`] + Args: + config: ErnieMoEConfig + """ + + def __init__(self, config: ErnieMoEConfig): + if hasattr(config, "use_moe") and config.use_moe: + if config.moe_group in {"mp", "model", "tp", "mpdp"}: + assert config.sequence_parallel + logger.info( + f"disable FFN tensor model parallel, moe-group={config.moe_group}" + ) + config.disable_ffn_model_parallel = True + + config.moe_group = _parse_moe_group(config.moe_group) + if config.moe_group in fleet.auto.get_mesh().dim_names: + config.moe_world_size = fleet.auto.get_mesh().get_dim_size( + config.moe_group + ) + if config.moe_world_size < 0: + config.moe_world_size = 1 + else: + config.moe_world_size = 1 + + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.config = config + + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + if not in_auto_parallel_align_mode(): + self.embed_tokens.weight = dist.shard_tensor( + self.embed_tokens.weight, + get_mesh(), + [dist.Replicate(), dist.Shard(1)], + ) + + layers_list = [] + + def get_layer_pp_info(ipp): + mesh = fleet.auto.get_mesh() + if is_pp_enable() is False: + return None, False + else: + pp_degree = mesh.get_dim_size("pp") + layer_num = ( + config.num_hidden_layers - 1 + if config.remove_tail_layer + else config.num_hidden_layers + ) + layer_per_stage = math.ceil(layer_num / pp_degree) + input_need_reshard = ipp % layer_per_stage == 0 + return ipp // layer_per_stage, input_need_reshard + + self.next_pp_stage_indexes = [] + for layer_idx in range( + config.num_hidden_layers - 1 + if config.remove_tail_layer + else config.num_hidden_layers + ): + pp_stage_id, input_need_reshard = get_layer_pp_info(layer_idx) + layers_list.append(ErnieDecoderLayerAuto(config, layer_idx, pp_stage_id)) + if input_need_reshard: + 
self.next_pp_stage_indexes.append(layer_idx) + self.layers = nn.LayerList(layers_list) + Norm = RMSNorm if config.use_rmsnorm else LayerNorm + if not config.use_rmsnorm and config.fuse_ln: + Norm = FusedLayerNorm + self.norm = Norm(config, -1) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + + self.placements = ( + [dist.Shard(1), dist.Shard(0)] + if self.config.sequence_parallel + else [dist.Shard(0), dist.Replicate()] + ) + self.all_gate_logits = () if hasattr(self.config, "use_moe") else None + self.lm_head = ErnieLMHead(config) + self.inbatch_pack_offset = None + self.token_type_ids = None + self.past_key_values = None + self.inbatch_pack_offset = None + self.inputs_embeds = None + self.all_hidden_states = None + self.all_self_attns = None + self.next_decoder_cache = None + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @classmethod + def _prepare_decoder_attention_mask( + cls, attention_mask, input_shape, past_key_values_length, dtype + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length, dtype=dtype + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, dtype, tgt_length=input_shape[-1] + ) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + combined_attention_mask = paddle.maximum( + combined_attention_mask.astype(dtype), + paddle.to_tensor(float(finfo(dtype).min), dtype=dtype), + ) + return combined_attention_mask + + def recompute_training( + self, + layer_module, + hidden_states, + attention_mask, + position_ids, + output_attentions, + 
past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids, + ): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_gate_logits=False) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + position_ids, + output_attentions, + past_key_value, + use_cache, + inbatch_pack_offset, + token_type_ids, + use_reentrant=False, + ) + return hidden_states + + def embed_inputs(self, input_ids, attention_mask, position_ids): + inputs_embeds = self.inputs_embeds + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if self.past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + + seq_length_with_past = seq_length + cache_length = 0 + + if past_key_values[0] is not None: + cache_length = paddle.shape(past_key_values[0][0])[1] + seq_length_with_past += cache_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids).astype( + self.embed_tokens.weight.dtype + ) + + global_mesh = global_mesh_starts_with_pp() + if self.config.sequence_parallel: + # [B, S, H] -> [S, B, H] + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) + + if position_ids is not None: + position_ids = dist.shard_tensor( + position_ids, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + can_use_fa = self.config.use_flash_attn and flash_attention is not None + can_mem_eff_attn = ( + self.config.use_mem_eff_attn and self.inbatch_pack_offset is not None + ) + if can_use_fa or can_mem_eff_attn: + if attention_mask 
is not None: + attention_mask = None + + elif attention_mask is None: + attention_mask = paddle.ones( + (batch_size, seq_length_with_past), dtype=paddle.bool + ) + + if attention_mask is not None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + cache_length, + inputs_embeds.dtype, + ) + attention_mask = dist.shard_tensor( + attention_mask, + global_mesh, + [dist.Replicate() for _ in range(len(global_mesh._shape))], + ) + + hidden_states = inputs_embeds + if self.config.tensor_parallel_degree > 1: + hidden_states = dist.reshard(hidden_states, get_mesh(0), self.placements) + + return hidden_states, attention_mask, position_ids + + def decode_layer( + self, + decoder_layer, + hidden_states, + attention_mask, + position_ids, + all_router_loss=None, + ): + if self.config.output_hidden_states: + self.all_hidden_states += (hidden_states,) + has_gradient = not hidden_states.stop_gradient + ipp = decoder_layer.ipp + if not is_pp_enable(): + position_ids_input = position_ids + attention_mask_input = attention_mask + token_type_ids_input = self.token_type_ids + else: + if position_ids is not None: + position_ids_input = dist.reshard( + position_ids, + get_mesh(ipp), + [dist.Replicate(), dist.Replicate()], + ) + else: + position_ids_input = position_ids + attention_mask_input = ( + dist.reshard( + attention_mask, + get_mesh(ipp), + [dist.Replicate(), dist.Replicate()], + ) + if attention_mask is not None + else None + ) + token_type_ids_input = ( + dist.reshard( + self.token_type_ids, + get_mesh(ipp), + [dist.Replicate(), dist.Replicate()], + ) + if self.token_type_ids is not None + else None + ) + if self.config.use_recompute and has_gradient: + layer_outputs = self.recompute_training( + decoder_layer, + hidden_states, + attention_mask_input, + position_ids_input, + self.config.output_attentions, + self.past_key_values, + self.config.use_cache, + self.inbatch_pack_offset, + token_type_ids_input, + ) + else: + 
layer_outputs = decoder_layer( + hidden_states, + attention_mask_input, + position_ids_input, + self.config.output_attentions, + self.past_key_values, + self.config.use_cache, + self.inbatch_pack_offset, + token_type_ids_input, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if self.config.use_cache: + self.next_decoder_cache += ( + layer_outputs[2 if self.config.output_attentions else 1], + ) + + if self.config.output_attentions: + self.all_self_attns += (layer_outputs[1],) + if hasattr(self.config, "use_moe") and self.config.use_moe: + if not (self.config.use_recompute and has_gradient): + layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1] + self.all_gate_logits = self.all_gate_logits + (gate_logits,) + router_loss = layer_outputs[-1] + if all_router_loss is not None: + all_router_loss += router_loss + return hidden_states, all_router_loss + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + inbatch_pack_offset=None, + token_type_ids=None, + **kwargs, + ): + self.inputs_embeds = inputs_embeds + self.past_key_values = past_key_values + self.inbatch_pack_offset = inbatch_pack_offset + self.token_type_ids = token_type_ids + self.inbatch_pack_offset = inbatch_pack_offset + if use_cache is not None: + self.config.use_cache = use_cache + if return_dict is not None: + self.config.return_dict = return_dict + if output_hidden_states is not None: + self.config.output_hidden_states = output_hidden_states + if output_attentions is not None: + self.config.output_attentions = output_attentions + + hidden_states, attention_mask, position_ids = self.embed_inputs( + input_ids, attention_mask, position_ids + ) + + self.all_hidden_states = () if output_hidden_states else None + self.all_self_attns = () if 
        self.next_decoder_cache = () if use_cache else None

        all_router_loss = None
        if hasattr(self.config, "use_moe") and self.config.use_moe:
            all_router_loss = paddle.to_tensor(0.0)

        for idx, (decoder_layer) in enumerate(self.layers):
            hidden_states, all_router_loss = self.decode_layer(
                decoder_layer,
                hidden_states,
                attention_mask,
                position_ids,
                all_router_loss,
            )

        # For cached decoding (non-MoE), only the last position is needed.
        if use_cache and not (hasattr(self.config, "use_moe") and self.config.use_moe):
            hidden_states = paddle.unsqueeze(hidden_states[:, -1, :], 1)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            self.all_hidden_states += (hidden_states,)

        next_cache = self.next_decoder_cache if use_cache else None

        if not return_dict:
            # Tuple form drops None entries, so positions are not fixed.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    self.all_hidden_states,
                    self.all_self_attns,
                    all_router_loss,
                    self.all_gate_logits,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=self.all_hidden_states,
            attentions=self.all_self_attns,
            cross_attentions=None,
            router_loss=all_router_loss,
            gate_logits=self.all_gate_logits,
        )


class ErniePretrainingCriterionBase(paddle.nn.Layer):
    """
    Criterion for Ernie.
    It calculates the final loss.
    """

    def __init__(self, config, return_tuple=True):
        super(ErniePretrainingCriterionBase, self).__init__()
        # Labels equal to ignored_index are excluded from the loss.
        self.ignored_index = getattr(config, "ignored_index", -100)
        self.config = config
        self.return_tuple = return_tuple
        self.enable_parallel_cross_entropy = (
            config.tensor_parallel_degree > 1 and config.tensor_parallel_output
        )

        self.loss_func = paddle.nn.CrossEntropyLoss(
            reduction="none",
        )

    def forward(self, prediction_scores, masked_lm_labels):
        """Dispatch to the sparse-head, recompute-loss, or plain loss path.

        When use_sparse_head_and_loss_fn is on, `prediction_scores` is the
        tuple (hidden_states, lm_head_weight, lm_head_bias) and logits are
        computed here only for positions with valid labels.
        """
        if self.config.use_sparse_head_and_loss_fn:
            hidden_states, outlinear_weight, outlinear_bias = prediction_scores

            if self.config.sequence_parallel:
                masked_lm_labels, sparse_label_idx = (
                    sequence_parallel_sparse_mask_labels(
                        masked_lm_labels, self.ignored_index
                    )
                )
            else:
                masked_lm_labels = masked_lm_labels.flatten()
                sparse_label_idx = paddle.nonzero(
                    masked_lm_labels != self.ignored_index
                ).flatten()
                masked_lm_labels = paddle.take_along_axis(
                    masked_lm_labels, sparse_label_idx, axis=0
                )

            # Gather only the hidden states at labeled positions.
            hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
            hidden_states = paddle.take_along_axis(
                hidden_states, sparse_label_idx.reshape([-1, 1]), axis=0
            )

            if self.config.use_recompute_loss_fn:
                res = recompute(
                    self.forward_impl_with_calc_logits,
                    masked_lm_labels,
                    hidden_states,
                    outlinear_weight,
                    outlinear_bias,
                    sparse_label_idx,
                )
            else:
                logits = calc_lm_head_logits(
                    self.config,
                    hidden_states,
                    outlinear_weight,
                    outlinear_bias,
                    sparse_label_idx,
                )
                res = self.forward_impl(logits, masked_lm_labels)
        elif self.config.use_recompute_loss_fn:
            assert isinstance(prediction_scores, tuple) and len(prediction_scores) in [
                3,
                4,
            ]
            res = recompute(
                self.forward_impl_with_calc_logits, masked_lm_labels, *prediction_scores
            )
        else:
            res = self.forward_impl(prediction_scores, masked_lm_labels)

        return res

    def forward_impl_with_calc_logits(
        self,
        masked_lm_labels,
        hidden_states,
        outlinear_weight,
        outlinear_bias,
        sparse_label_idx=None,
        tensor_parallel_output=None,
    ):
        """Compute logits from hidden states inside the (recomputed) loss fn."""
        logits = calc_lm_head_logits(
            self.config,
            hidden_states,
            outlinear_weight,
            outlinear_bias,
            sparse_label_idx,
            tensor_parallel_output,
        )

        return self.forward_impl(logits, masked_lm_labels)

    def loss_impl(self, prediction_scores, masked_lm_labels):
        """extract loss impl for subbatch"""
        masked_lm_loss = self.loss_func(
            prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(-1)
        )
        return masked_lm_loss

    def forward_impl(self, prediction_scores, masked_lm_labels):
        """Compute (loss, loss_sum) over non-ignored positions in fp32."""
        with paddle.amp.auto_cast(False):
            # Sub-batch the loss over the sequence axis when it is very long.
            if self.config.use_sparse_head_and_loss_fn and prediction_scores.shape[
                0
            ] > self.config.get("loss_subbatch_seqlen", 32768):
                sb_loss_func = subbatch(
                    self.loss_impl,
                    [0, 1],
                    [0, 0],
                    self.config.get("loss_subbatch_seqlen", 32768),
                    0,
                )
                masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
            else:
                masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels)
            lossmask = masked_lm_labels != self.ignored_index

            if (~lossmask).all():  # empty span
                logger.warning(
                    f"encounter empty span when calculate loss, ignored_index={self.ignored_index}"
                )
                # Keep the graph connected with a zero loss.
                loss = paddle.mean(masked_lm_loss) * 0.0
                loss_sum = masked_lm_loss.sum().detach()
            else:
                lossmask_ = lossmask.reshape([-1]).cast(paddle.float32)
                # element-wise masking, aggregated in full (fp32) precision
                masked_lm_loss_ = paddle.sum(
                    masked_lm_loss.cast(paddle.float32).reshape([-1]) * lossmask_
                )
                loss = masked_lm_loss_ / lossmask_.sum()
                loss_sum = masked_lm_loss_.sum().detach()

        if not self.return_tuple:  # only used in pp
            if self.training:
                return loss
            return loss_sum
        return loss, loss_sum


class ErniePretrainingCriterion(ErniePretrainingCriterionBase):
    """
    Criterion for Ernie.
    It calculates the final loss.
+ """ + + def __init__(self, config, return_tuple=True): + super(ErniePretrainingCriterion, self).__init__( + config, return_tuple=return_tuple + ) + + def forward(self, prediction_scores, masked_lm_labels, router_loss=None): + """ + calculates the final loss + """ + res = super().forward( + prediction_scores, + masked_lm_labels, + ) + if self.return_tuple: + loss, loss_sum = res + else: + loss, loss_sum = res, None + # global_training_logs.update(lm_loss=loss.clone().detach()) + if router_loss is not None and not in_auto_parallel_align_mode(): + loss = loss + router_loss - router_loss.detach() + # if isinstance(router_loss, paddle.Tensor): + # global_training_logs.update(router_loss=router_loss.detach()) + return loss, loss_sum + + +class ErnieLMHead(nn.Layer): + """ + ErnieLMHead is the linear layer used to project hidden state of decoder into word embeddings. + """ + + def __init__(self, config): + super(ErnieLMHead, self).__init__() + self.config = config + vocab_size = config.vocab_size + self.weight = self.create_parameter( + shape=( + [vocab_size, config.hidden_size] + if config.tie_word_embeddings + else [config.hidden_size, vocab_size] + ), + dtype=paddle.get_default_dtype(), + ) + + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + self.weight = dist.shard_tensor( + self.weight, + get_mesh(-1), + [dist.Replicate(), dist.Shard(1)], + ) + + logger.info( + f"output-weight:{self.weight.shape} config.tie_word_embeddings={config.tie_word_embeddings}" + ) + if config.weight_share_add_bias and config.use_bias: + self.bias = self.create_parameter( + shape=[vocab_size], + dtype=paddle.get_default_dtype(), + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.constant.Constant(0.0) + ), + ) + if ( + self.config.tensor_parallel_degree > 1 + or self.config.pipeline_parallel_degree > 1 + ): + self.bias = dist.shard_tensor( + self.bias, + get_mesh(-1), + [dist.Replicate(), dist.Shard(0)], + ) + else: + self.bias = 
None + + # Must set distributed attr for Tensor Parallel ! + self.weight.is_distributed = ( + True if (vocab_size != config.vocab_size) else False + ) + if config.weight_share_add_bias and config.use_bias: + self.bias.is_distributed = ( + True if (vocab_size != config.vocab_size) else False + ) + + if self.weight.is_distributed: + self.weight.split_axis = 1 + if ( + config.weight_share_add_bias + and config.use_bias + and self.bias.is_distributed + ): + self.bias.split_axis = 0 + + if self.config.use_recompute_loss_fn: + logger.info( + "Using recompute_loss_fn, the calculation of logits will be moved into " + "loss_fn for memory optimization" + ) + + def forward(self, hidden_states, tensor_parallel_output=None): + + if self.config.use_recompute_loss_fn or self.config.use_sparse_head_and_loss_fn: + out_tensors = ( + (hidden_states, self.weight, self.bias) + if tensor_parallel_output is None + else (hidden_states, self.weight, self.bias, tensor_parallel_output) + ) + + return out_tensors + + return calc_lm_head_logits( + self.config, + hidden_states, + self.weight, + self.bias, + None, + tensor_parallel_output, + ) + + +class ErnieModelAutoPP(ErnieModelAuto): + def __init__(self, config, layer_idx=0, ipp=0): + super().__init__(config) + self.layer = ErnieDecoderLayerAuto(config, layer_idx, ipp) + self.config.use_cache = False + + def forward(self, args): + hidden_states, attention_mask, position_ids = parse_args(args) + if self.layer.layer_idx == 0: + hidden_states, attention_mask, position_ids = self.embed_inputs( + hidden_states, attention_mask, position_ids + ) + + hidden_states, all_router_loss = self.decode_layer( + self.layer, hidden_states, attention_mask, position_ids + ) + + if self.layer.layer_idx == self.config.num_hidden_layers - 1: + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + ret_args = return_args(logits=logits) + else: + ret_args = return_args( + hidden_states=hidden_states, + attention_mask=attention_mask, + 
                position_ids=position_ids,
            )
        return ret_args


class ErnieForCausalLMAuto(ErniePretrainedModelAuto):
    """
    ErnieForCausalLMAuto is the model class for causal language modeling.
    """

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        if config.sequence_parallel:
            logger.info(f"using sequence_parallel, input seqlen={config.seqlen}")
            # NOTE(review): the message text contradicts the asserted
            # condition (`not config.micro_batch_size`) — confirm which is
            # intended; also "dygramic_sequnence_length" is a typo in the
            # runtime string.
            if config.using_dynamic_sequence_length:
                assert (
                    not config.micro_batch_size
                ), "sequence-parallel needs micro_batch_size setting when using dygramic_sequnence_length"
            else:
                assert config.seqlen is not None

            assert (
                config.tensor_parallel_degree > 1
            ), f"sequence-parallel needs mp>1, got mp={config.tensor_parallel_degree}"

        # initialize-trick for big model, see
        # https://github.com/bigscience-workshop/bigscience/blob/master/train/tr11-176B-ml/README.md#std-init
        new_initializer_range = math.sqrt(0.3333 / config.hidden_size)
        logger.info(
            f"change initializer-range from {config.initializer_range} to {new_initializer_range}"
        )
        config.initializer_range = new_initializer_range
        self.config = config
        self.ernie = ErnieModelAuto(config)
        self.lm_head = ErnieLMHead(config)
        self.criterion = ErniePretrainingCriterion(config)

        self.tie_weights()  # maybe weight share

        if self.config.use_rmsnorm:
            if self.config.fuse_rms_norm:
                logger.info("Use fusedRMSNorm")
            else:
                logger.info("Use normal RMSNorm")
        else:
            if self.config.fuse_ln:
                logger.info("Use fusedLN")
            else:
                logger.info("Use normal LayerNorm")
        # One PP-staged wrapper per decoder layer (see ErnieModelAutoPP).
        decoder_layers = []
        for i in range(config.num_hidden_layers):
            pp_stage_id = get_pp_stage_id(
                i, config.num_hidden_layers, config.virtual_pp_degree
            )
            decoder_layers.append(ErnieModelAutoPP(config, i, pp_stage_id))
        self.layers = nn.LayerList(decoder_layers)

    def _post_init(self, original_init, *args, **kwargs):
        """
        Initialize weights and apply final processing
        """
        # NOTE(review): `super()._post_init` is already bound to `self`, so
        # passing `self` explicitly shifts the arguments (original_init would
        # bind to self in the parent) — verify against the paddleformers
        # PretrainedModel._post_init signature.
        super()._post_init(self, original_init, *args, **kwargs)
        # Residual-branch down-scaling: 1/sqrt(2*N) on o_proj/down_proj.
        factor = 1 / math.sqrt(2 * self.config.num_hidden_layers)
        logger.info(f"using post init div: factor:{factor}")

        def scale_by_factor_if_valid(w):
            # Only touch initialized distributed weights.
            if w.is_dist() and w._is_initialized():
                w.scale_(factor)

        if hasattr(self.config, "use_moe") and self.config.use_moe:
            with paddle.no_grad():
                for left in self.ernie.layers:
                    if isinstance(
                        left.self_attn.o_proj,
                        (MOELayerAuto),
                    ):
                        for e in left.self_attn.o_proj.experts:
                            if isinstance(e, ErnieMoeMLP):
                                scale_by_factor_if_valid(e.weight)
                    else:
                        scale_by_factor_if_valid(left.self_attn.o_proj.weight)

                    if isinstance(
                        left.mlp,
                        (MOELayerAuto),
                    ):
                        for e in left.mlp.experts:
                            if isinstance(e, ErnieMoeMLP):
                                scale_by_factor_if_valid(e.down_proj.weight)
                    else:
                        scale_by_factor_if_valid(left.mlp.down_proj.weight)
        else:
            with paddle.no_grad():
                for left in self.ernie.layers:
                    scale_by_factor_if_valid(left.self_attn.o_proj.weight)
                    scale_by_factor_if_valid(left.mlp.down_proj.weight)

    def get_input_embeddings(self):
        """Return the backbone's token-embedding layer."""
        return self.ernie.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding layer."""
        self.ernie.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the decoder backbone."""
        self.ernie = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.ernie

    @staticmethod
    def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id):
        """Build an int64 [batch, seq] mask: mask pad tokens only when pad is
        present and pad != eos; otherwise attend to everything."""
        is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
            input_ids == pad_token_id
        ).numpy().item()
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
            (eos_token_id is not None) and (pad_token_id != eos_token_id)
        )
        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
            attention_mask = (input_ids != pad_token_id).astype("int64")
        else:
            attention_mask = paddle.ones_like(input_ids, dtype="int64")
        return attention_mask

    def prepare_inputs_for_generation(
        self,
        input_ids,
        use_cache=False,
        past_key_values=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble model inputs for one generation step (last token only
        once a cache exists)."""
        if past_key_values:
            input_ids = input_ids[:, -1:]

        attention_mask = kwargs.get("attention_mask", None)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": True,  # use_cache,
                "attention_mask": attention_mask,
                "return_dict": True,
            }
        )
        return model_inputs

    @staticmethod
    def update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=False
    ):
        """Carry cache / masks / token_type_ids forward between decode steps."""
        # update cache
        if (
            isinstance(outputs, tuple)
            and len(outputs) > 1
            and not isinstance(outputs[1], paddle.Tensor)
        ):
            model_kwargs["past_key_values"] = outputs[1]

        if (
            isinstance(outputs, CausalLMOutputWithCrossAttentions)
            and "past_key_values" in outputs
        ):
            model_kwargs["past_key_values"] = outputs.past_key_values

        # update token_type_ids with last value
        if (
            "token_type_ids" in model_kwargs
            and model_kwargs["token_type_ids"] is not None
        ):
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = paddle.concat(
                [token_type_ids, token_type_ids[:, -1:]], axis=-1
            )

        if not is_encoder_decoder:
            # update attention mask
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = paddle.concat(
                    [
                        attention_mask,
                        paddle.ones([attention_mask.shape[0], 1], dtype="int64"),
                    ],
                    axis=-1,
                )
            # update role_ids
            if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
                role_ids = model_kwargs["role_ids"]
                model_kwargs["role_ids"] = paddle.concat(
                    [role_ids, role_ids[:, -1:]], axis=-1
                )

        return model_kwargs

    def forward(
        self,
        input_ids,
        labels=None,
        position_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        use_cache=False,
        past_key_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=False,
        ignored_index=0,
        inbatch_pack_offset=None,
        token_type_ids=None,
    ):
        """Run backbone + LM head; returns the criterion's (loss, loss_sum)
        in training, or a CausalLMOutputWithCrossAttentionsAuto when
        return_dict is set (generation path)."""
        # Trainer may pack (input_ids, labels) into one list argument.
        if isinstance(input_ids, list):
            input_ids, labels = input_ids[:2]

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.ernie(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            inbatch_pack_offset=inbatch_pack_offset,
            token_type_ids=token_type_ids,
        )

        hidden_states = outputs.last_hidden_state

        logits = self.lm_head(
            hidden_states,
        )  # tensor_parallel_output=tensor_parallel_output)

        if return_dict:  # aka Generate Decoding
            if labels is not None:
                loss, _ = self.criterion(logits, labels)
            else:
                loss = None
            return CausalLMOutputWithCrossAttentionsAuto(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                router_loss=outputs.router_loss if self.config.use_moe else None,
            )

        assert labels is not None
        router_loss = (
            outputs.router_loss
            if hasattr(self.config, "use_moe") and self.config.use_moe
            else None
        )
        return self.criterion(logits, labels, router_loss)
# !/usr/bin/env python3

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Ernie model configuration"""
import logging
import json
from typing import Union
import paddle.distributed.communication.group

from paddleformers.transformers.configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

__all__ = [
    "ERNIE_PRETRAINED_INIT_CONFIGURATION",
    "ErnieMoEConfig",
]

# Canned configuration for the tiny test model; keys mirror ErnieConfig
# constructor arguments.
ERNIE_PRETRAINED_INIT_CONFIGURATION = {
    "ernie/tiny-random-ernie": {
        "hidden_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 11008,
        "max_position_embeddings": 2048,
        "model_type": "ernie",
        "num_attention_heads": 2,
        "num_hidden_layers": 2,
        "rms_norm_eps": 1e-06,
        "vocab_size": 32000,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "pad_token_id": 0,
        "use_cache": False,
        "use_recompute": False,
        "use_flash_attn": True,
        "use_pure_fp16": False,
    },
}


class ErnieConfig(PretrainedConfig):
    r"""
    Configuration class for [`~ErnieModel`]. Instantiating a configuration
    with the defaults yields a configuration similar to Ernie-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from
    [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size: the number of different tokens representable by
            `inputs_ids` passed when calling [`~ErnieModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 2):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 2):
            Number of attention heads per attention layer.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Std of the truncated_normal_initializer (overridden by the model).
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether the model should return the last key/value attentions.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.

        The remaining keyword arguments are feature flags (recompute, fusion,
        FP8, MoE dispatch, sequence-parallel, ...) stored verbatim on the
        config; see the assignments below.

    Example:
    ```python
    >>> from paddleformers.transformer import ErnieModel, ErnieConfig

    >>> # Initializing a Ernie ernie-7b style configuration
    >>> configuration = ErnieConfig()

    >>> # Initializing a model from the ernie-7b style configuration
    >>> model = ErnieModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "ernie"
    attribute_map = {
        "n_positions": "max_position_embeddings",
        "n_embd": "hidden_size",
        "n_layer": "num_hidden_layers",
        "n_head": "num_attention_heads",
        "n_inner": "intermediate_size",
        "activation_function": "hidden_act",
    }
    pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=11008,
        max_position_embeddings=32768,
        num_hidden_layers=2,
        num_attention_heads=2,
        head_dim=None,
        initializer_range=0.02,  # no use
        rms_norm_eps=1e-6,
        use_cache=False,
        use_flash_attn=True,
        use_mem_eff_attn=False,
        use_flash_attn_with_mask=False,
        use_recompute=False,
        use_recompute_attn=False,
        recompute_use_reentrant=False,
        use_rmsnorm=True,
        z_loss_lambda=None,
        fuse_rms_norm=False,
        fuse_ln=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        fuse_attn_ffn=False,
        fuse_swiglu=False,
        use_bias=False,
        expert_mlp_use_bias=None,
        rope_reorder=True,
        rope_theta=10000,
        fuse_rope=False,
        use_fast_ln=False,
        weight_share_add_bias=True,
        fuse_linear=False,
        seqlen=False,
        ignored_index=-100,
        remove_tail_layer=False,
        use_recompute_lm_head=False,
        use_recompute_loss_fn=False,
        use_recompute_mtp=False,
        use_recompute_dnd=False,
        selective_no_recompute_num=0,
        use_mp_gathered_weight=False,
        refined_recompute=dict(),
        attention_probs_dropout_prob=0.0,
        hidden_dropout_prob=0.0,
        compression_ratio: float = 1.0,
        quant_bits=-1,
        num_key_value_heads=None,
        submatrix_parallel=False,
        submatrix_parallel_low_memory=True,
        use_sparse_head_and_loss_fn=False,
        using_dynamic_sequence_length=False,
        micro_batch_size=-1,
        using_precision_check=False,
        use_qk_norm=False,
        use_tpsp_comm_overlap=False,
        offload_pp_data_chunk_size=0,
        use_fused_head_loss_fn=False,
        use_recompute_resampler=False,
        resampler_fuse_rms_norm=False,
        token_loss_equal_weight=False,
        token_balance_loss=False,
        token_balance_seqlen=False,
        use_fp8=False,
        fp8_configs=dict(),
        use_fp8_mlp=False,
        fp8_mem_configs=dict(),
        fp8_fused_ops_configs=dict(),
        drop_before_deepep=False,
        deepep_drop_padding=False,
        disable_pipeline_warmup=False,
        skip_align_position_id=False,
        rope_3d=False,
        freq_allocation=0,
        moe_layer_feed_fake_token=False,
        decoderlayer_act_offload_settings={"type": "", "value": ""},
        loss_subbatch_seqlen=32768,
        gate_force_zero_padding_grad=False,
        recompute_num_layers=None,
        use_combine_before_a2a=False,
        use_quant_before_a2a=False,
        rope_yarn_config={},
        **kwargs,
    ):
        # untied embeddings unless the caller explicitly asks otherwise
        if "tie_word_embeddings" not in kwargs:
            kwargs["tie_word_embeddings"] = False
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.use_recompute_attn = use_recompute_attn
        if use_recompute_attn:
            # attention-only recompute supersedes full-layer recompute
            logger.warning("set `use_recompute_attn`=True, disabling `use_recompute`")
            use_recompute = False
        self.use_recompute = use_recompute
        self.recompute_num_layers = (
            recompute_num_layers
            if recompute_num_layers is not None
            else num_hidden_layers
        )
        self.use_flash_attn = use_flash_attn
        self.recompute_use_reentrant = recompute_use_reentrant
        self.use_mem_eff_attn = use_mem_eff_attn
        self.use_flash_attn_with_mask = use_flash_attn_with_mask
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.fuse_attn_ffn = fuse_attn_ffn
        self.fuse_swiglu = fuse_swiglu
        self.fuse_rms_norm = fuse_rms_norm
        self.fuse_ln = fuse_ln
        self.use_rmsnorm = use_rmsnorm
        self.z_loss_lambda = z_loss_lambda
        self.using_dynamic_sequence_length = using_dynamic_sequence_length
        if using_dynamic_sequence_length:
            assert (
                micro_batch_size > 0
            ), "micro_batch_size should be set when using_dynamic_sequence_length"
        self.micro_batch_size = micro_batch_size
        self.using_precision_check = using_precision_check
        self.use_qk_norm = use_qk_norm

        self.seqlen = seqlen
        self.use_bias = use_bias
        self.weight_share_add_bias = weight_share_add_bias
        self.rope_reorder = rope_reorder
        self.rope_yarn_config = rope_yarn_config
        self.rope_theta = rope_theta
        self.fuse_rope = fuse_rope
        self.use_fast_ln = use_fast_ln

        self.fuse_linear = fuse_linear
        self.ignored_index = ignored_index
        self.remove_tail_layer = remove_tail_layer
        self.use_recompute_lm_head = use_recompute_lm_head
        self.use_recompute_loss_fn = use_recompute_loss_fn
        self.use_recompute_mtp = use_recompute_mtp
        self.use_recompute_dnd = use_recompute_dnd

        self.use_mp_gathered_weight = use_mp_gathered_weight
        self.selective_no_recompute_num = selective_no_recompute_num  # only PP

        self.refined_recompute = refined_recompute
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob
        self.compression_ratio = compression_ratio
        self.skip_recompute_ops = dict()
        self.quant_bits = quant_bits
        self.num_key_value_heads = num_key_value_heads
        self.submatrix_parallel = submatrix_parallel
        self.submatrix_parallel_low_memory = submatrix_parallel_low_memory
        self.use_sparse_head_and_loss_fn = use_sparse_head_and_loss_fn
        self.use_tpsp_comm_overlap = use_tpsp_comm_overlap
        self.offload_pp_data_chunk_size = offload_pp_data_chunk_size
        self.use_fused_head_loss_fn = use_fused_head_loss_fn
        self.use_recompute_resampler = use_recompute_resampler
        self.resampler_fuse_rms_norm = resampler_fuse_rms_norm
        # BUG FIX: these three constructor arguments were accepted but never
        # stored, so the settings were silently dropped.
        self.token_loss_equal_weight = token_loss_equal_weight
        self.use_combine_before_a2a = use_combine_before_a2a
        self.use_quant_before_a2a = use_quant_before_a2a
        self.token_balance_loss = token_balance_loss
        self.token_balance_seqlen = token_balance_seqlen
        self.rope_3d = rope_3d
        self.freq_allocation = freq_allocation
        self.decoderlayer_act_offload_settings = decoderlayer_act_offload_settings
        self.loss_subbatch_seqlen = loss_subbatch_seqlen
        self.gate_force_zero_padding_grad = gate_force_zero_padding_grad

        # Default FP8 settings
        default_fp8_configs = {
            "quant_scheme": "DelayedScaling",
            "recipe": {
                "format": "hybrid",
                "calibrating": True,
                "amax_history_len": 1024,
                "amax_compute_algo": "max",
                "fuse_wgrad_accumulation": False,
                "quant_weight_at_first_microbatch": False,
            },
            "layers": {
                "attn_fc1_linear": True,
                "attn_fc2_linear": True,
                "mlp_fc1_linear": True,
                "mlp_fc2_linear": True,
                "attn_tp_fc1_linear": True,
                "attn_tp_fc2_linear": True,
                "mlp_tp_fc1_linear": True,
                "mlp_tp_fc2_linear": True,
            },
            "smooth_swiglu": False,
        }

        def update_nested_dict(default_dict, update_dict):
            # recursively overlay user-provided values on the defaults
            for key, value in update_dict.items():
                if (
                    isinstance(value, dict)
                    and key in default_dict
                    and isinstance(default_dict[key], dict)
                ):
                    update_nested_dict(default_dict[key], value)
                else:
                    default_dict[key] = value

        # Overlay user overrides on the defaults
        update_nested_dict(default_fp8_configs, fp8_configs)
        self.fp8_configs = default_fp8_configs
        self.use_fp8 = use_fp8
        self.expert_mlp_use_bias = expert_mlp_use_bias
        self.use_fp8_mlp = use_fp8_mlp
        default_fp8_mem_configs = {
            "shared_expert": False,
            "recompute_fwd_gate_up": False,
            "dequant_input": False,
        }
        update_nested_dict(default_fp8_mem_configs, fp8_mem_configs)
        self.fp8_mem_configs = default_fp8_mem_configs
        default_fp8_fused_ops_configs = {
            "stack_quant": False,
            "swiglu_probs_bwd": False,
            "split_group_gemm": True,
        }
        update_nested_dict(default_fp8_fused_ops_configs, fp8_fused_ops_configs)
        self.fp8_fused_ops_configs = default_fp8_fused_ops_configs
        self.drop_before_deepep = drop_before_deepep
        self.deepep_drop_padding = deepep_drop_padding
        self.disable_pipeline_warmup = disable_pipeline_warmup
        self.skip_align_position_id = skip_align_position_id
        self.moe_layer_feed_fake_token = moe_layer_feed_fake_token

        # sequence_parallel / tensor_parallel_degree come from PretrainedConfig
        # kwargs handling — TODO confirm they are always present at this point
        if self.sequence_parallel:
            assert (
                self.using_dynamic_sequence_length or self.seqlen
            ), "seqlen not provided in sequence-parallel when not using dygramic sequence length"

            assert (
                self.tensor_parallel_degree > 1
            ), f"senquence-parallel only works in mp, got mp={self.tensor_parallel_degree}"

        # runtime-only knobs must not be persisted into saved config files
        self.register_nonsaveable_keys("use_recompute")
        self.register_nonsaveable_keys("recompute_use_reentrant")
        self.register_nonsaveable_keys("refined_recompute")
        self.register_nonsaveable_keys("use_recompute_attn")
        self.register_nonsaveable_keys("use_recompute_lm_head")
        self.register_nonsaveable_keys("use_recompute_mtp")
        self.register_nonsaveable_keys("use_recompute_dnd")
        self.register_nonsaveable_keys("use_recompute_loss_fn")
        self.register_nonsaveable_keys("using_precision_check")
        self.register_nonsaveable_keys("skip_recompute_ops")

    def __setattr__(self, name: str, value):
        """Forbid enabling both full-layer and attention-only recompute."""
        super().__setattr__(name, value)
        if getattr(self, "use_recompute", False):
            assert not getattr(
                self, "use_recompute_attn", False
            ), "cannot set `use_recompute_attn=True` when `use_recompute=True`"

    def register_nonsaveable_keys(self, keys):
        """Bridge over paddleformers API renames for non-saveable keys."""
        if hasattr(super(), "register_nonsaveable_keys"):
            return super().register_nonsaveable_keys(keys)
        elif hasattr(super(), "register_unsavable_keys"):
            return super().register_unsavable_keys(keys)
        else:
            raise AttributeError(
                "register_nonsaveable_keys not found in PretrainedConfig"
            )
"register_nonsaveable_keys not found in PretrainedConfig" + ) + + +class ErnieMoEConfig(ErnieConfig): + r""" + This is the configuration class to store the configuration of a [`~ErnieModel`]. It is used to instantiate an Ernie + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Ernie-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Ernie model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ErnieModel`] or [`~TFErnieModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + ```python + >>> from paddleformers.transformer import ErnieModel, ErnieConfig + + >>> # Initializing a Ernie ernie-7b style configuration + >>> configuration = ErnieConfig() + + >>> # Initializing a model from the ernie-7b style configuration + >>> model = ErnieModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ernie" + attribute_map = { + "n_positions": "max_position_embeddings", + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "n_inner": "intermediate_size", + "activation_function": "hidden_act", + } + pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION + + def __init__( + self, + moe_num_experts: Union[int, list] = 0, + use_fake_gate=False, + use_recompute_moe=False, + moe_capacity=(), + moe_layer_interval=2, + moe_layer_start_index: Union[int, list] = 0, + moe_layer_end_index: Union[int, list] = -1, + moe_aux_loss_lambda=1e-2, + moe_z_loss_lambda=1e-4, + moe_orthogonal_loss_lambda=1e-2, + moe_use_size_all2all=False, + sinkhorn_2gate=True, + sinkhorn_temp=3e-2, + global_aux_loss=False, + moe_dropout_prob=0.0, + moe_group="world", + moe_gate="top2", + moe_num_attn_experts=False, + moe_logging=False, + num_experts_per_tok: int = 8, + moe_intermediate_size: Union[int, list] = 0, + moe_num_shared_experts: int = 0, + moe_num_dense_experts: int = 0, + moe_dense_experts_token_type_id: int = 3, + moe_multimodal_dispatch_use_allgather: str = "", + moe_multimodal_paired_experts: bool = False, + moe_reverse_token_drop: bool = False, + moe_gate_act: str = "softmax", + moe_norm_gate_logits=True, + moe_use_hard_gate: bool = False, + moe_use_bpr: bool = False, + moe_fuse_experts: bool = False, + moe_all_to_all_dropout: float = 0.0, + moe_use_token_type_bias: bool = False, + moe_k=2, + moe_use_aux_free: bool = False, + 
moe_group_experts: bool = False, + moe_group_orthogonal_loss: bool = False, + moe_with_send_router_loss: bool = True, + enable_delay_scale_loss: bool = True, + num_acc_steps: int = None, + insert_empty_layer: list = None, + pp_no_recompute_layer: list = None, + multi_token_pred_depth: int = 0, + multi_token_pred_lambda: float = 0.3, + fuse_gate_detach_matmul: bool = False, + enable_mtp_magic_send: bool = False, + use_elastic_topk: bool = False, + use_deepep: bool = False, + use_elastic_expert_num: bool = False, + elastic_min_expert_num: int = 0, + all_expert_ratio: float = 1.0, + use_elastic_topk_for_mbs: bool = False, + elastic_min_topk: int = 1, + elastic_max_topk: int = None, + n_group: int = 0, + topk_group: int = 0, + scaling_factor: float = None, + aux_loss_type: str = "", + deepep_fine_grained: bool = False, + deepep_use_fused: bool = False, + deepep_tokens_per_subbatch: int = 0, + use_linear_residual_norm_recompute: bool = False, + use_rms_qkv_recompute: bool = False, + build_skip_comm_buffer: bool = False, + use_norm_gate_recompute: bool = False, + moe_state_dict_use_global_expert_id: bool = False, + enable_entropy_logging: bool = False, + use_fp8_fuse_node: bool = False, + use_combine_before_a2a: bool = False, + use_fp8_dispatch_a2a: bool = False, + use_ep_comm_overlap: bool = False, + **kwargs, + ): + """ + config + """ + if use_recompute_moe: + logger.warning("set `use_recompute_moe`=True, disabling `use_recompute`") + kwargs["use_recompute"] = False + super().__init__(**kwargs) + # moe + self.use_fake_gate = use_fake_gate + self.use_recompute_moe = use_recompute_moe + self.moe_num_experts = moe_num_experts + self.moe_capacity = moe_capacity + self.moe_aux_loss_lambda = moe_aux_loss_lambda + self.moe_z_loss_lambda = moe_z_loss_lambda + self.moe_orthogonal_loss_lambda = moe_orthogonal_loss_lambda + self.global_aux_loss = global_aux_loss + self.sinkhorn_2gate = sinkhorn_2gate + self.sinkhorn_temp = sinkhorn_temp + self.moe_layer_interval = 
moe_layer_interval + self.moe_dropout_prob = moe_dropout_prob + self.moe_group = moe_group + self.moe_gate = moe_gate + self.moe_num_attn_experts = moe_num_attn_experts + # implemtent size-all2all as https://arxiv.org/pdf/2303.06182.pdf + self.moe_use_size_all2all = moe_use_size_all2all + self.moe_logging = moe_logging + self.num_experts_per_tok = num_experts_per_tok + self.moe_num_shared_experts = moe_num_shared_experts + self.moe_num_dense_experts = moe_num_dense_experts + self.moe_dense_experts_token_type_id = moe_dense_experts_token_type_id + self.moe_intermediate_size = moe_intermediate_size + self.moe_reverse_token_drop = moe_reverse_token_drop + self.moe_use_hard_gate = moe_use_hard_gate + self.moe_fuse_experts = moe_fuse_experts + self.moe_k = moe_k + self.moe_all_to_all_dropout = moe_all_to_all_dropout + self.moe_use_token_type_bias = moe_use_token_type_bias + self.moe_use_bpr = moe_use_bpr + self.moe_group_experts = moe_group_experts + self.moe_group_orthogonal_loss = moe_group_orthogonal_loss + # optimize send without router loss + self.moe_with_send_router_loss = moe_with_send_router_loss + self.enable_delay_scale_loss = enable_delay_scale_loss + self.num_acc_steps = num_acc_steps + self.moe_layer_start_index = moe_layer_start_index + self.moe_layer_end_index = ( + self.num_hidden_layers - 1 + if moe_layer_end_index == -1 + else moe_layer_end_index + ) + self.moe_multimodal_dispatch_use_allgather = ( + moe_multimodal_dispatch_use_allgather + ) + self.moe_multimodal_paired_experts = moe_multimodal_paired_experts + self.moe_gate_act = moe_gate_act + self.moe_norm_gate_logits = moe_norm_gate_logits + self.moe_use_aux_free = moe_use_aux_free + self.fuse_gate_detach_matmul = fuse_gate_detach_matmul + if insert_empty_layer is not None: + assert isinstance( + insert_empty_layer, list + ), "insert_empty_layer should be a list" + else: + insert_empty_layer = [] + + # Overlap A2A communication with shared expert and auxiliary loss. 
+ self.use_ep_comm_overlap = use_ep_comm_overlap + # Move the combine operation before A2A communication. + self.use_combine_before_a2a = use_combine_before_a2a + # Use FP8 for dispatch communication. + self.use_fp8_dispatch_a2a = use_fp8_dispatch_a2a + + # Multi-Token Prediction (MTP) + self.multi_token_pred_depth = multi_token_pred_depth + self.multi_token_pred_lambda = multi_token_pred_lambda + self.enable_mtp_magic_send = enable_mtp_magic_send + + # The insert_empty_layer is a list of integer which will be used under pipeline parallel. + # After each layer indicated in the insert_empty_layer, an empty layer will be inserted. + # For example, a model with 4 layers, insert_empty_layer = [1, 3], the model actually passed to + # pp is: transformer, transformer, EMPTY, transformer, transformer, EMPTY + self.insert_empty_layer = insert_empty_layer + + # elastic + self.use_elastic_topk = use_elastic_topk + self.use_elastic_expert_num = use_elastic_expert_num + self.elastic_min_expert_num = elastic_min_expert_num + self.all_expert_ratio = all_expert_ratio + self.use_elastic_topk_for_mbs = use_elastic_topk_for_mbs + self.elastic_min_topk = elastic_min_topk + if elastic_max_topk is None: + self.elastic_max_topk = self.moe_k * 2 - 1 + + # Using fusion expert node in moe layer. + self.use_fp8_fuse_node = use_fp8_fuse_node + + # Perform MoE computation at expert granularity. + self.deepep_fine_grained = deepep_fine_grained + # Requires deepep_fine_grained to be enabled; further disperses token + # granularity within experts to compute subbatches. + self.deepep_tokens_per_subbatch = deepep_tokens_per_subbatch + # Fuse combine and scatter operations when using BF16 for expert computation. + self.deepep_use_fused = deepep_use_fused + + assert not ( + self.use_combine_before_a2a and self.use_deepep + ), "combine_before_a2a is not supported for deepep now." 
+ + assert not ( + self.use_fp8_dispatch_a2a and not self.use_fp8_fuse_node + ), "fp8_dispatch_a2a must be used with use_fp8_fuse_node." + + assert not ( + self.use_fp8_dispatch_a2a and self.use_ep_comm_overlap + ), "fp8_dispatch_a2a connot be used with use_ep_comm_overlap." + + if self.deepep_tokens_per_subbatch: + assert ( + self.deepep_fine_grained + ), "deepep_fine_grained must be enabled when deepep_tokens_per_subbatch is set." + + # node limit routing + self.n_group = n_group + self.topk_group = topk_group + + # router scaling_factor + self.scaling_factor = scaling_factor + + self.build_skip_comm_buffer = build_skip_comm_buffer + + # router loss type + assert aux_loss_type in ["", "default", "seq_aux_loss", "switch_aux_loss"] + self.aux_loss_type = aux_loss_type + + self.use_deepep = use_deepep + if self.moe_multimodal_paired_experts and isinstance( + self.moe_num_experts, (tuple, list) + ): + logger.warning( + "moe_num_experts must be one element when using paired experts" + ) + self.moe_num_experts = self.moe_num_experts[0] + + if pp_no_recompute_layer is not None: + assert isinstance( + insert_empty_layer, list + ), "pp_no_recompute_layer should be a list" + + # Indicating layers not do recompute under pipeline parallel. + # Note that, when insert_empty_layer is not None, the pp_no_recompute_layer should be indicating + # layers number in origin model structure, AKA model before insert empty layers. 
+ self.pp_no_recompute_layer = pp_no_recompute_layer + self.register_nonsaveable_keys("moe_group") + self.register_nonsaveable_keys("pp_no_recompute_layer") + + if ( + self.moe_group in ["dp", "data"] + and self.moe_multimodal_dispatch_use_allgather + ): + assert ( + self.moe_num_shared_experts == 0 + ), "shared experts are not supported when using dp moe and moe_allgather_layer" + assert ( + self.moe_num_dense_experts == 0 + ), "dense experts are not supported when using dp moe and moe_allgather_layer" + + self.use_linear_residual_norm_recompute = use_linear_residual_norm_recompute + self.use_rms_qkv_recompute = use_rms_qkv_recompute + self.use_norm_gate_recompute = use_norm_gate_recompute + self.moe_state_dict_use_global_expert_id = moe_state_dict_use_global_expert_id + self.enable_entropy_logging = enable_entropy_logging + + @property + def multimodel_experts(self) -> bool: + + return ( + isinstance(self.moe_num_experts, (tuple, list)) + and len(self.moe_num_experts) > 1 + ) + + @property + def use_moe(self) -> bool: + """_summary_ + + Returns: + bool: _description_ + """ + return ( + sum(self.moe_num_experts) > 0 + if self.multimodel_experts + else self.moe_num_experts > 0 + ) + + def __setattr__(self, name: str, value): + super().__setattr__(name, value) + if getattr(self, "use_recompute", False): + assert not getattr( + self, "use_recompute_moe", False + ), "cannot set `use_recompute_moe=True` when `use_recompute=True`" + + def to_json_string(self, use_diff: bool = True) -> str: + + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + + def _serializer(obj): + if isinstance(obj, paddle.distributed.communication.group.Group): + return repr(obj) + raise TypeError(f"Type {type(obj)} is not serializable") + + return ( + json.dumps( + config_dict, + indent=2, + sort_keys=True, + ensure_ascii=False, + default=_serializer, + ) + + "\n" + ) diff --git a/examples/pre-training/models/moe/moe_layer_auto.py 
b/examples/pre-training/models/moe/moe_layer_auto.py new file mode 100644 index 000000000..0b7fc0cf7 --- /dev/null +++ b/examples/pre-training/models/moe/moe_layer_auto.py @@ -0,0 +1,851 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""_summary_ + +Returns: + _type_: _description_ +""" +from typing import Any, Tuple, List, Optional, Callable +import logging +from collections import namedtuple +from contextlib import contextmanager +from functools import partial + +import paddle +from paddle import framework +from paddle import nn +from paddle.distributed.communication import stream +import paddle.nn.functional as F +from paddle.distributed import in_auto_parallel_align_mode + +from paddle.autograd import PyLayer +from paddle.distributed.communication.group import Group +from paddle.distributed import fleet + +import paddle.distributed as dist +from paddle import Tensor +from paddleformers.trainer.plugins.timer import get_timers + + +from models.moe.top2_gate_auto import TopKGateFusedAuto +from models.moe.moe_utils_auto import get_flatten_mesh, get_mesh, _reshard +from models.moe.moe_layer_auto_utils import MOELayer + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} + + +logger = logging.getLogger(__name__) + +try: + import moe_ops +except ImportError: + moe_ops = None + logger.warning( + "`moe-ops` not 
found, run " + "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) + +try: + import moe_ops_auto +except ImportError: + moe_ops_auto = None + logger.warning( + "`moe_ops_auto` not found, run " + "`python3 src/ernie_core/ops/moe/setup_auto.py install` to install" + ) + +try: + import moe_combine_auto +except ImportError: + moe_combine_auto = None + logger.warning( + "`moe_combine_auto` not found, run " + "`python3 src/ernie_core/ops/moe/setup_auto.py install` to install" + ) + + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + ], +) + + +@contextmanager +def profile(name): + """doc""" + if get_timers() is not None: + get_timers()(name).start() + yield + if get_timers() is not None: + get_timers()(name).stop() + + +class GateCombineForStatic(PyLayer): + """GateCombine""" + + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + """ + Input: + x: [seqlen * k, hidden_size] + combine_weights: [seqlen, k] + scatter_index: [seqlen, k] + Output: + y: [seqlen, hidden_size] + """ + ctx.save_for_backward(x, combine_weights, scatter_index) + assert moe_combine_auto is not None + return moe_combine_auto.moe_combine_auto(x, combine_weights, scatter_index) + + @staticmethod + def backward(ctx, grad_y, *_): + """ + Input: + grad_y: [seqlen, hidden_size] + combine_weights: [seqlen, k] + scatter_index: [seqlen, k] + Output: + grad_x: [seqlen * k, hidden_size] + grad_combine_weight: [seqlen, k] + + """ + x, combine_weights, scatter_index = ctx.saved_tensor() + assert moe_combine_auto is not None + grad_x, grad_combine_weight_helper = moe_combine_auto.moe_combine_bwd_auto( + x, combine_weights, scatter_index, grad_y + ) + # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim] + # reduce the hidden shape + # TODO: implement reduce in cuda ops + grad_combine_weight = grad_combine_weight_helper.sum(-1) + # NOTE: PyLayer do not support some inputs with stop_gradient=True in static mode, + # this means that 
there must be a gradient for each input + scatter_index_grad = paddle.zeros_like(scatter_index) + return grad_x, grad_combine_weight, scatter_index_grad + + +class GateCombine(PyLayer): + """GateCombine""" + + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + """ + Input: + x: [seqlen * k, hidden_size] + combine_weights: [seqlen, k] + scatter_index: [seqlen, k] + Output: + y: [seqlen, hidden_size] + """ + ctx.x = x + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + assert moe_combine_auto is not None + return moe_combine_auto.moe_combine_auto(x, combine_weights, scatter_index) + + @staticmethod + def backward(ctx, grad_y, *_): + """ + Input: + grad_y: [seqlen, hidden_size] + combine_weights: [seqlen, k] + scatter_index: [seqlen, k] + Output: + grad_x: [seqlen * k, hidden_size] + grad_combine_weight: [seqlen, k] + + """ + + assert moe_combine_auto is not None + grad_x, grad_combine_weight_helper = moe_combine_auto.moe_combine_bwd_auto( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + # grad_combine_weight_helper is the same shape with grad x [seqlen * K, dim] + # reduce the hidden shape + # TODO: implement reduce in cuda ops + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + + +def combining_fused_auto(x, combine_weights, scatter_index, hard_gate=False): + """ + Args: + x: Tensor[seq, dim] + combine_weights: [s, k] + scatter_index: ** [k, s] ** + + Returns: + y: Tensor[s, dim] + """ + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + return x_gatherd.squeeze(-2) + ret = moe_combine_auto.moe_combine_auto(x, combine_weights, scatter_index) + + ret.stop_gradient = False + return ret + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + + output = None + # init_output = paddle.zeros([num_experts * capacity, x.shape[-1]], dtype='float32') + # output = init_output + 0. 
* x.sum() + orig_dtype = x.dtype + scatter_index = scatter_index.unbind(1) + dispatch_mask = dispatch_mask.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros( + [num_experts * capacity, x.shape[-1]], dtype="float32" + ) + updates = x * i_dispatch_mask.unsqueeze(-1).cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + + dim = x.shape[-1] + scatter_index = scatter_index.reshape([-1]) + num_k = combine_weights.shape[-1] + x = dist.reshard(x, get_mesh(0), [dist.Replicate(), dist.Shard(0)]) + combine_weights = combine_weights.unsqueeze(1) + # num_k = 2 + x = paddle.gather(x, scatter_index).reshape([-1, num_k, dim]) # [seq,2,dim] + return paddle.matmul(combine_weights, x).squeeze( + 1 + ) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] + + +class AlltoAll(PyLayer): + """ + AlltoAll w/ backward + """ + + @staticmethod + def forward(ctx, x, group): + """ + All-to-all communication in the group. + """ + ctx.group = group + if dist.get_world_size(group) <= 1: + return x + output = paddle.empty_like(x) + output.stop_gradient = False + with profile("moe-all2all"): + stream.alltoall_single(output, x, None, None, group, True, True) + return output + + @staticmethod + def backward(ctx, *dx): + """backward""" + return AlltoAll.apply(*dx, group=ctx.group) + + +class AlltoAllAsync(PyLayer): + """ + AlltoAll async w/ backward + """ + + @staticmethod + def forward(ctx, x, *fn_args, group=None, fn=None, is_first_fwd=False): + """ + All-to-all communication in the group. 
+ Args: + x: Tensor + args: List[Any], argument(s) to `fn` + group: ProcessGroup + fn: callable, called while doing alltoall + is_first_fwd: if using recompute, don't record bacward when first forward + Returns: + x: Tensor + fn_out: List[Tensor] + """ + assert fn is not None, "use AlltoAll no async" + ctx.group = group + if dist.get_world_size(group) <= 1: + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + return (x,) + fn_out + x_out = paddle.empty_like(x) + x_out.stop_gradient = False + with profile("moe-all2all"): + task = stream.alltoall_single( + x_out, + x, + None, + None, + group, + sync_op=False, + ) + ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args) + task.wait() + return (x_out,) + fn_out + + @staticmethod + def backward(ctx, dx_out, *fn_out_grads): + """backward""" + if dist.get_world_size(ctx.group) <= 1: + fn_args_grads = ctx.bwf(*fn_out_grads) + return (dx_out,) + fn_args_grads + + dx = paddle.empty_like(dx_out) + dx.stop_gradient = False + with profile("moe-all2all"): + task = stream.alltoall_single( + dx, + dx_out, + None, + None, + ctx.group, + sync_op=False, + ) + fn_args_grads = ctx.bwf(*fn_out_grads) + task.wait() + return (dx,) + fn_args_grads + + +def detach_and_requires_grad_(*args): + """detach_and_requires_grad_""" + ret = [a.detach() if a is not None else None for a in args] + for r, a in zip(ret, args): + if a is not None: + r.stop_gradient = a.stop_gradient + return ret + + +def manual_backward(f: Callable, is_first_fwd: bool, *args: List[Any]): + """ + Args: + f(callable) + args(*Any) + Returns + bw_f(callable): manual backward fn + out(List[Tensor]): output of f(*args) + """ + tracer = framework._dygraph_tracer() + orig = tracer._has_grad + if not is_first_fwd: + tracer._has_grad = True # turn on grad trace so we can manual backward + + detached_args = detach_and_requires_grad_(*args) + detached_args_clone = [a.clone() if a is not None else None for a in detached_args] + out = f(*detached_args_clone) + for 
a in detached_args: + if a is not None: + a._clear_dataptr() # free mem + if isinstance(out, list): + out = tuple(out) + elif not isinstance(out, tuple): + out = (out,) + + if is_first_fwd: + tracer._has_grad = orig + return None, out + + out_cached = [ + o.clone() for o in out if o is not None and not o.stop_gradient + ] # do not cache stop_gradient output + for o in out_cached: + o._clear_dataptr() # free mem + tracer._has_grad = orig + + def bwd_f(*grad): + nonlocal out_cached, detached_args, f + grad = list(grad) + grad = [g for g in grad if g is not None] + assert len(grad) == len(out_cached), (len(grad), len(out_cached), f) + # out, grad = zip(*[(o, g) for o, g in zip(out, grad) if g is not None]) + paddle.autograd.backward(out_cached, grad) + return tuple([t.grad if t is not None else None for t in detached_args]) + + return bwd_f, out + + +def bpr_preprocess(input, logits, capacity, buffer): + """impletment bpr sorting""" + assert input.ndim == 2, input.shape + idx = paddle.argsort(logits.max(-1), axis=0, descending=True) + input = input[idx] + logits = logits[idx] + buffer["idx"] = idx + return input, logits + + +def bpr_postprocess(output, buffer): + """bpr sorting""" + idx = buffer.pop("idx") + rev_idx = paddle.argsort(idx) + output = output[rev_idx] + return output + + +class MOELayerAuto(MOELayer): + + def __init__( + self, + gate: nn.Layer, + experts: List[nn.Layer], + layer_idx, + shared_experts: Optional[List[nn.Layer]] = None, + group: Group = None, + recompute=False, + enable_logging: bool = False, + k=2, + enable_pbr: bool = False, + all_to_all_dropout=0, + group_experts=False, + config=None, + ipp=0, + ): + nn.Layer.__init__(self) + self.config = config + self.gate = gate + self.layer_idx = layer_idx + self.ipp = ipp + self.recompute = recompute + logger.info(f"using moe recompute={recompute}") + for p in self.gate.parameters(): + p.is_gate = True + if isinstance(experts, nn.LayerList): + self.experts = experts + else: + logger.info(f"using 
fused experts, type={type(experts)}") + self.experts = experts + self.shared_experts = shared_experts + + self.group = group + self.k = k + self.all_to_all_dropout = all_to_all_dropout + self.enable_logging = enable_logging + is_mp_moe = ( + hasattr(fleet.fleet, "_hcg") + and group is fleet.get_hybrid_communicate_group().get_model_parallel_group() + ) + is_dummy_moe = config.moe_world_size == 1 + + for p in experts.parameters(): + p.expert = not (is_mp_moe or is_dummy_moe) # type: ignore + p.no_sync = not (is_mp_moe or is_dummy_moe) + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if is_mp_moe or is_mp_moe: + p.is_distributed = True + + self.world_size = config.moe_world_size + if self.group in fleet.auto.get_mesh().dim_names: + self.rank = fleet.auto.get_mesh().get_rank_by_dim_and_process_id( + self.group, dist.get_rank() + ) + if self.rank < 0: + self.rank = 0 + else: + self.rank = 0 + + self.num_experts_per_group = len(self.experts) + self.ep_group_num = config.moe_world_size + self.num_local_experts = self.num_experts_per_group // self.ep_group_num + + self.moe_mesh_dim = 0 if config.moe_group == "dp" else 1 + self.dispatch_by_task = ( + hasattr(self.gate, "dispatch_by_task") and self.gate.dispatch_by_task + ) + + if self.dispatch_by_task: + assert 0, "no supported, checkout earylier code" + assert self.num_local_experts == 1 + + if enable_pbr: + logger.info("using BPR") + prepost_process_buffer = {} + self.input_preprocess = partial( + bpr_preprocess, buffer=prepost_process_buffer + ) + self.output_postprocess = partial( + bpr_postprocess, buffer=prepost_process_buffer + ) + else: + self.input_preprocess = self.output_postprocess = None + self.group_experts = group_experts + + def _cal_multimodel_experts_prob( + self, gate_logits, token_type_ids, group_experts, moe_k + ): + + if not self.gate.experts_type_ids.is_dist(): + self.gate.experts_type_ids = dist.shard_tensor( + self.gate.experts_type_ids, + get_mesh(), + [dist.Replicate(), dist.Replicate()], + 
) + return super()._cal_multimodel_experts_prob( + gate_logits, token_type_ids, group_experts, moe_k + ) + + def forward_experts(self, dispatched_input): + """ + call experts sequently + Args: + dispatched_input: Tensor[num_experts, capacity, dim] + Returns: + expert_output: Tensor[num_experts, capacity, dim] + """ + assert isinstance(self.experts, nn.LayerList) + if self.config.moe_group == "mp": + local_input_list = dist.auto_parallel.api.moe_sub_mesh_tensors( + dispatched_input, + get_mesh(self.ipp), + self.moe_mesh_dim, + [dist.Shard(2), dist.Shard(0)], + ) + + assert len(self.experts) % len(local_input_list) == 0, ( + "num of experts must be divided by num of ep_group, " + f"but got {len(self.experts)} and {len(local_input_list)}" + ) + expert_group_outputs = [] + for i_ep_group, local_input in enumerate(local_input_list): + chunks = local_input.unbind(1) + experts = self.experts[ + i_ep_group + * self.num_local_experts : (i_ep_group + 1) + * self.num_local_experts + ] + ep_output = [] + assert len(experts) == len( + chunks + ), f"num of experts must be equal to num of chunks, but got {len(experts)} and {len(chunks)}" + for chunk_id, (chunk, expert) in enumerate(zip(chunks, experts)): + ep_output += [expert(chunk)] + expert_group_outputs += [paddle.stack(ep_output, axis=1)] + return expert_group_outputs + else: + chunks = dispatched_input.unbind(1) + expert_outputs = [] + assert len(chunks) == len(self.experts), (len(chunks), len(self.experts)) + for chunk, expert in zip(chunks, self.experts): + expert_outputs += [expert(chunk)] + expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] + return expert_output + + def gate_and_distpach(self, input, token_type_ids): + """ + calc gate and dispatch inputs (and do logging, optionaly) + Args: + input: Tensor[seq, dim], float + token_type_ids: Tensor[seq], int + Returns: + dispatched_input: Tensor[num_experts, capacity, dim] + combine_weights: [seq, k] + scatter_index: [seq, k] + router_loss: scalar + 
gate_logits: [seq, num_experts] + """ + with profile("moe-gate"): + args = () + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape([-1]) + args = (token_type_ids,) + use_fuse = isinstance(self.gate, (TopKGateFusedAuto)) + if use_fuse: + (gate_logits, capacity, router_loss, local_capacity) = self.gate( + input, *args + ) + else: + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + router_loss, + gate_logits, + ) = self.gate(input, *args) + prob = None + if self.input_preprocess is not None: + input, gate_logits = self.input_preprocess(input, gate_logits, capacity) + + with profile("moe-dispatch"): + if use_fuse: + # capacity no use + k = self.k + prob, max_prob = self.fused_gate_logits_process( + gate_logits, token_type_ids + ) + ( + dispatched_input, + combine_weights_unnorm, + scatter_index, + dispatch_mask, + _, + ) = moe_ops_auto.moe_gate_dispatch_auto( + input, prob, k, local_capacity, True + ) + dispatched_input.stop_gradient = False + combine_weights_unnorm.stop_gradient = False + # NOTE: PyLayer do not support some inputs with stop_gradient=True in static mode + # it's a bug that will be fixed in the future + # scatter_index.stop_gradient = True + dispatch_mask.stop_gradient = True + + scatter_index = scatter_index.transpose([1, 0]) # [k,s] ->[s,k] + + if self.group_experts: + if max_prob is not None: + if token_type_ids is not None: + p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1)) + p = paddle.scatter_nd_add( + p, paddle.nonzero(token_type_ids == 0), -1 + max_prob + ) + else: + p = max_prob + combine_weights_unnorm = ( + combine_weights_unnorm.unsqueeze(-1) * p + ).squeeze(-1) + prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape( + [p.shape[0], -1] + ) + combine_weights = combine_weights_unnorm / paddle.clip( + combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 + ) + combine_weights = combine_weights.cast(dispatched_input.dtype) + else: + dispatched_input = dispatching( + input, + dispatch_mask, + 
scatter_index, + num_experts=self.config.moe_num_experts, + capacity=capacity, + ) + dispatch_mask.stop_gradient = True + scatter_index.stop_gradient = True + return ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + prob, + ) + + def combine_expert_output(self, expert_output, combine_weights, scatter_index): + """ + Combine Expert output + Args: + expert_output: Tensor[num_experts, caapcity, dim] + combine_weights: + Returns: + combined_output: Tensor[seqlen, dim] + """ + with profile("moe-combine"): + if self.config.moe_use_all2all and self.config.moe_group == "mp": + expert_output = dist.auto_parallel.moe_utils._dist_reshape( + expert_output, + [-1, expert_output.shape[-1]], + get_flatten_mesh(get_mesh(self.ipp)), + [dist.Shard(0)], + ) + else: + expert_output = expert_output.reshape( + [-1, expert_output.shape[-1]] + ) # [e*c,m] + + if not self.config.moe_use_all2all: + if self.config.moe_group == "mp": + expert_output = dist.reshard( + expert_output, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) + else: + expert_output = dist.reshard( + expert_output, get_mesh(), [dist.Shard(0), dist.Replicate()] + ) + use_fuse = isinstance(self.gate, (TopKGateFusedAuto)) + combine_fn = combining_fused_auto if use_fuse else combining + combined_output = combine_fn(expert_output, combine_weights, scatter_index) + + if self.output_postprocess is not None: + combined_output = self.output_postprocess(combined_output) + return combined_output + + def forward( + self, + input: Tensor, + token_type_ids=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ + Args: + input (`Tensor`): The input data with shape ``(s, d)``. + Only one token is supported for now. + token_type_ids (`Tensor`) int64 tensor with shape (s), + if specified, rount tensor according to `token_type_ids`. 
+ Returns: + output (`Tensor`): The final output tensor with shape ``(s, d)`` where ``m`` is the + size of model parameters. + combine_weights (`Tensor`, optional): A tensor with shape ``(s,)``, which represents weights + for each expert in MoE. + router_loss (`Tensor`, optional): A scalar tensor representing the loss of routing function. + """ + if self.shared_experts is not None: + shared_expert_input = dist.reshard( + input, + get_mesh(self.ipp), + [dist.Shard(1), dist.Replicate()], + ) + if input.ndim == 3: + orig_shape = input.shape + input = dist.reshard( + input, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)] + ) + if self.config.moe_use_all2all: + input = dist.auto_parallel.moe_utils._dist_reshape( + input, + [-1, input.shape[-1]], + get_flatten_mesh(get_mesh(self.ipp)), + [dist.Shard(0)], + ) + else: + input = input.reshape([-1, input.shape[-1]]) + else: + orig_shape = None + assert ( + len(input.shape) == 2 + ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" + seqlen, d_model = input.shape + + if token_type_ids is not None: + token_type_ids = token_type_ids.clone()[:, :-1] + if self.config.sequence_parallel: + token_type_ids = token_type_ids.reshape([-1]) + # token_type_ids = ScatterOp.apply(token_type_ids) + token_type_ids.stop_gradient = True + + assert self.gate is not None + if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout: + orig_shape_2 = input.shape + output = self.forward_experts(input) + output += self.gate.weight.sum() * 0.0 # hack for grad + output = output.reshape(orig_shape or orig_shape_2) # [e*1,c,m] + return output, None, 0 + + ( + dispatched_input, + combine_weights, + dispatch_mask, + scatter_index, + router_loss, + gate_logits, + gate_prob, + ) = self.gate_and_distpach(input, token_type_ids) + if self.config.moe_use_all2all and self.config.moe_group == "mp": + dispatched_input = _reshard( + dispatched_input, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(1)] + ) + if 
self.config.moe_group == "mp": + # TODO(zhangyichen): 统一 moe_group 是 mp 和其他情况下的代码 + dispatched_input = dist.reshard( + dispatched_input, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(0)] + ) + + if self.shared_experts is not None: + shared_out = self.shared_experts(shared_expert_input) + dispatched_input = dispatched_input.reshape( + [self.config.moe_world_size, self.num_local_experts, -1, d_model] + ) + expert_out = self.forward_experts(dispatched_input) + if self.config.moe_group == "mp": + expert_out = dist.auto_parallel.api.moe_global_mesh_tensor( + expert_out, + get_mesh(self.ipp), + [dist.Shard(2), dist.Shard(0)], + self.moe_mesh_dim, + ) + expert_out = dist.auto_parallel.moe_utils._dist_reshape( + expert_out, + [self.config.moe_world_size * self.num_local_experts, -1, d_model], + get_mesh(self.ipp), + [dist.Shard(1), dist.Shard(0)], + ) + expert_out = dist.reshard( + expert_out, get_mesh(self.ipp), [dist.Shard(1), dist.Shard(1)] + ) + if not in_auto_parallel_align_mode(): + router_loss2 = self.calc_router_loss_and_logging( + router_loss, + combine_weights, + dispatch_mask, + gate_logits, + gate_prob, + token_type_ids, + ) + else: + router_loss2 = router_loss + router_loss2 = dist.shard_tensor( + router_loss2, get_flatten_mesh(get_mesh(self.ipp)), [dist.Replicate()] + ) + combined_output = self.combine_expert_output( + expert_out, combine_weights, scatter_index + ) + + if self.shared_experts is not None: + shared_out = dist.auto_parallel.moe_utils._dist_reshape( + shared_out, + [-1, shared_out.shape[-1]], + get_flatten_mesh(get_mesh(self.ipp)), + [dist.Shard(0)], + ) + combined_output += shared_out + + if orig_shape: + if self.config.moe_use_all2all: + combined_output = dist.auto_parallel.moe_utils._dist_reshape( + combined_output, + orig_shape[:-1] + [combined_output.shape[-1]], + get_mesh(self.ipp), + [dist.Shard(1), dist.Shard(0)], + ) + router_loss2 = _reshard( + router_loss2, + get_mesh(self.ipp), + [dist.Replicate(), dist.Replicate()], + ) + else: + 
combined_output = combined_output.reshape( + orig_shape[:-1] + [combined_output.shape[-1]] + ) + return combined_output, combine_weights, router_loss2, gate_logits diff --git a/examples/pre-training/models/moe/moe_layer_auto_utils.py b/examples/pre-training/models/moe/moe_layer_auto_utils.py new file mode 100644 index 000000000..f9ad59958 --- /dev/null +++ b/examples/pre-training/models/moe/moe_layer_auto_utils.py @@ -0,0 +1,1454 @@ +# !/usr/bin/env python3 + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Tuple, List, Optional +import logging +from collections import namedtuple +import inspect + +import paddle +from paddle import framework +from paddle import nn +from paddle.distributed.communication import stream +import paddle.nn.functional as F + +from paddle.autograd import PyLayer +from paddle.distributed.communication.group import Group +from paddle.distributed.fleet.utils import recompute +from paddle.distributed import fleet +from paddle.distributed import in_auto_parallel_align_mode + +import paddle.distributed as dist +from paddle import Tensor + +from models.moe.top2_gate_auto_auto import ( + TopKGateFused, + cast_if_needed, +) +from models.sequence_parallel_utils_auto import ScatterOp +from models.utils import ( + global_training_logs_enabled, + manual_backward, +) + +from models.comm_utils import profile + + +from paddle.incubate.nn.functional import ( + moe_combine, +) + + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} +try: + import moe_router_loss_ops +except ImportError: + moe_router_loss_ops = None + + +logger = logging.getLogger(__name__) + + +try: + import moe_ops +except ImportError: + moe_ops = None + logger.warning( + "`moe-ops` not found, run " + "`python3 src/ernie_core/ops/moe/setup.py install` to install" + ) + +try: + import moe_ops_fp8 +except ImportError: + moe_ops_fp8 = None + logger.warning( + "`moe-ops` not found, run " + "`python3 src/ernie_core/ops/moe/setup_fp8.py install` to install" + ) + +try: + from moe_combine import moe_combine_no_weight +except ImportError: + moe_combine_no_weight = None + + +try: + import fused_ln as fused +except ImportError: + logger.warning( + "fused-ln not found, run `python src/ops/fused_ln_setup.py install` to build fused ln" + ) + fused = None + +try: + from custom_setup_ops import matmul_bwd +except ImportError: + matmul_bwd = None + + +GateOutput = namedtuple( + "GateOutput", + [ + "aux", + "z", + "logits", + 
], +) + + +class GateCombine_ori(PyLayer): + + @staticmethod + def forward(ctx, x, combine_weights, scatter_index): + ctx.x = x + ctx.combine_weights = combine_weights + ctx.scatter_index = scatter_index + assert moe_combine is not None + ret = moe_combine.moe_combine(x, combine_weights, scatter_index) + return ret + + @staticmethod + def backward(ctx, grad_y, *_): + assert moe_combine is not None + grad_x, grad_combine_weight_helper = moe_combine.moe_combine_bwd( + ctx.x, ctx.combine_weights, ctx.scatter_index, grad_y + ) + + grad_combine_weight = grad_combine_weight_helper.sum(-1) + return grad_x, grad_combine_weight.reshape(ctx.combine_weights.shape), None + + +def combining_fused(x, combine_weights, scatter_index, hard_gate=False): + + if hard_gate: + x_gatherd = F.embedding(scatter_index, x) # [s,k,dim] + return x_gatherd.squeeze(-2) + ret = GateCombine_ori.apply(x, combine_weights, scatter_index) + ret.stop_gradient = False + return ret + + +def recompute_fwd_gate_up_func(config, layer_idx): + + if "recompute_fwd_gate_up" in config.fp8_mem_configs: + if isinstance(config.fp8_mem_configs["recompute_fwd_gate_up"], bool): + return config.fp8_mem_configs["recompute_fwd_gate_up"] + if isinstance(config.fp8_mem_configs["recompute_fwd_gate_up"], list): + return layer_idx in config.fp8_mem_configs["recompute_fwd_gate_up"] + + return False + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + + output = None + # init_output = paddle.zeros([num_experts * capacity, x.shape[-1]], dtype='float32') + # output = init_output + 0. 
* x.sum() + orig_dtype = x.dtype + scatter_index = scatter_index.unbind(1) + dispatch_mask = dispatch_mask.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros( + [num_experts * capacity, x.shape[-1]], dtype="float32" + ) + updates = x * i_dispatch_mask.unsqueeze(-1).cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + + dim = x.shape[-1] + scatter_index = scatter_index.reshape([-1]) + num_k = combine_weights.shape[-1] + combine_weights = combine_weights.unsqueeze(1) + # num_k = 2 + x = paddle.gather(x, scatter_index).reshape([-1, num_k, dim]) # [seq,2,dim] + return paddle.matmul(combine_weights, x).squeeze( + 1 + ) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] + + +def fuse_logging(gate_logits, combine_weights, token_type_ids): + with paddle.no_grad(): + gate_expert_per_token_type_0, gate_expert_per_token_type_1 = None, None + gate_experts_per_token = None + ce = moe_router_loss_ops.cal_cross_entropy_info(gate_logits).mean(0) + if token_type_ids is not None: + ( + gate_expert_per_token_type_0, + gate_expert_per_token_type_1, + gate_experts_per_token, + ) = moe_router_loss_ops.cal_gate_experts_per_token_info( + combine_weights, token_type_ids + ) + else: + gate_experts_per_token = paddle.count_nonzero(combine_weights) / ( + gate_logits.shape[0] + ) + + return ( + gate_expert_per_token_type_0, + gate_expert_per_token_type_1, + gate_experts_per_token, + ce, + ) + + +class Fp8MoeGateDispatchAndQuant(paddle.autograd.PyLayer): + + @staticmethod + def forward( + ctx, x, gate_logtis, corr_bias, k, capacity, use_pad, use_pow2_scale=True + ): + ( + out_fp8, + scale, + combine_weights, + 
            scatter_index,
            expert_offset,
            expert_id,
        ) = moe_ops_fp8.moe_gate_dispatch_and_quant(
            x,
            gate_logtis,  # NOTE(review): spelling "gate_logtis" mirrors the (unseen) parameter name above — confirm against the signature
            corr_bias=corr_bias,
            k=k,
            capacity=capacity,
            use_pad=use_pad,
            use_pow2_scale=use_pow2_scale,
        )
        # One FP8 scale row is expected per dispatched row.
        assert out_fp8.shape[0] == scale.shape[0]

        # Gradients flow only through the dispatched activations and combine
        # weights; index/offset/id/scale outputs are pure bookkeeping.
        out_fp8.stop_gradient = False
        combine_weights.stop_gradient = False
        scatter_index.stop_gradient = True
        expert_offset.stop_gradient = True
        expert_id.stop_gradient = True
        scale.stop_gradient = True

        # Stash everything backward() needs to replay the dispatch backward op.
        ctx.k = k
        ctx.capacity = capacity
        ctx.use_pad = use_pad
        ctx.combine_weights = combine_weights
        ctx.scatter_index = scatter_index
        ctx.expert_id = expert_id
        ctx.has_corr_bias = corr_bias is not None

        return (
            out_fp8,
            combine_weights,
            scatter_index,
            expert_offset,
            expert_id,
            {
                "scale": scale,
            },
        )

    @staticmethod
    def backward(ctx, *grads):
        """Backward of the fused FP8 gate-dispatch.

        Only the first two incoming grads (dispatched output, combine
        weights) carry information; the other forward outputs were marked
        ``stop_gradient``.  Returns an extra ``None`` slot for ``corr_bias``
        when one was supplied in forward.
        """
        out_grad, combine_weights_grad = grads[0], grads[1]
        x_grad, gate_logits_grad = moe_ops.moe_gate_dispatch_bwd(
            ctx.combine_weights,
            ctx.scatter_index,
            ctx.expert_id,
            out_grad,
            combine_weights_grad,
            k=ctx.k,
            capacity=ctx.capacity,
            use_pad=ctx.use_pad,
        )
        if ctx.has_corr_bias:
            return x_grad, gate_logits_grad, None
        else:
            return x_grad, gate_logits_grad


class AlltoAll(PyLayer):
    """Differentiable all-to-all: backward re-applies the same collective to the grads."""

    @staticmethod
    def forward(ctx, x, group, sync_op=True):
        # Single-rank group: the collective is the identity.
        ctx.group = group
        if dist.get_world_size(group) <= 1:
            return x
        output = paddle.empty_like(x)
        output.stop_gradient = False
        task = stream.alltoall_single(
            output, x, None, None, group, sync_op=sync_op, use_calc_stream=sync_op
        )
        if not sync_op:
            # Async mode: caller must wait() on the task before reading output.
            return output, task
        else:
            return output

    @staticmethod
    def backward(ctx, *dx):
        # The gradient of all-to-all is all-to-all of the gradients (same group).
        return AlltoAll.apply(*dx, group=ctx.group)


class AlltoAllExpertOverlap(PyLayer):
    """Pipelines per-expert all-to-all comms with expert computation.

    Stage ``i+1``'s dispatch is issued asynchronously while stage ``i``'s
    expert runs, hiding communication latency behind compute.
    """

    @staticmethod
    def forward(
        ctx, input, group, num_local_experts, forward_func_dict, is_first_fwd=False
    ):
        assert (
            dist.get_world_size(group) > 1
        ), "AlltoAllExpertOverlap is not supported for a world size less than or equal to 1."

        ctx.bw_funcs = {}
        ctx.group = group
        ctx.num_local_experts = num_local_experts

        assert isinstance(forward_func_dict, nn.LayerList)
        all2all_tasks = []
        all2all_ins = paddle.unbind(input, axis=0)
        # Prime the pipeline: launch the async dispatch for stage 0 only.
        for stage_id in range(1):
            stage_input = all2all_ins[stage_id]
            x_out, task = AlltoAll.apply(stage_input, group=group, sync_op=False)
            all2all_tasks.append((task, x_out))

        expert_outputs = []
        for stage_id in range(num_local_experts):
            # Kick off the next stage's comm before computing the current stage.
            if stage_id + 1 != num_local_experts:
                stage_input = all2all_ins[stage_id + 1]
                x_out, task = AlltoAll.apply(stage_input, group=group, sync_op=False)
                all2all_tasks.append((task, x_out))

            task, dispatched_input = all2all_tasks[stage_id]
            task.wait()
            # manual_backward returns the closure to replay in backward().
            bwf, (expert_outputs_cur_stage,) = manual_backward(
                forward_func_dict[stage_id], is_first_fwd, dispatched_input
            )
            ctx.bw_funcs[stage_id] = bwf
            expert_outputs.append(expert_outputs_cur_stage)

        expert_output = paddle.stack(expert_outputs, axis=1)
        return expert_output

    @staticmethod
    def backward(ctx, out_grad):
        """Replay each stage's saved backward, overlapping the grad all-to-alls."""
        all2all_tasks = []
        expert_outputs = []

        out_grad_list = paddle.split(
            out_grad, num_or_sections=out_grad.shape[1], axis=1
        )
        for stage_id in range(ctx.num_local_experts):
            (grad_cur_stage,) = ctx.bw_funcs[stage_id](out_grad_list[stage_id])

            x_out, task = AlltoAll.apply(grad_cur_stage, group=ctx.group, sync_op=False)
            all2all_tasks.append(task)
            expert_outputs.append(x_out)

        for task in all2all_tasks:
            task.wait()

        expert_output = paddle.stack(expert_outputs, axis=0)
        return expert_output


class AlltoAllAsync(PyLayer):
    """All-to-all overlapped with an arbitrary function ``fn`` (e.g. shared experts)."""

    @staticmethod
    def forward(ctx, x, *fn_args, group=None, fn=None, is_first_fwd=False):
        assert fn is not None, "use AlltoAll no async"
        ctx.group = group
        if dist.get_world_size(group) <= 1:
            # No comm needed; still run fn so outputs/grads stay consistent.
            ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args)
            return (x,) + fn_out
        x_out = paddle.empty_like(x)
        x_out.stop_gradient = False
        task = stream.alltoall_single(
            x_out,
            x,
            None,
            None,
            group,
            sync_op=False,
        )
        # Overlap: run fn while the collective is in flight, then wait.
        ctx.bwf, fn_out = manual_backward(fn, is_first_fwd, *fn_args)
        task.wait()
        return (x_out,) + fn_out

    @staticmethod
    def backward(ctx, dx_out, *fn_out_grads):
        if dist.get_world_size(ctx.group) <= 1:
            fn_args_grads = ctx.bwf(*fn_out_grads)
            return (dx_out,) + fn_args_grads

        dx = paddle.empty_like(dx_out)
        dx.stop_gradient = False
        task = stream.alltoall_single(
            dx,
            dx_out,
            None,
            None,
            ctx.group,
            sync_op=False,
        )
        # Overlap fn's backward with the grad all-to-all.
        fn_args_grads = ctx.bwf(*fn_out_grads)
        task.wait()
        return (dx,) + fn_args_grads


class FusedNormGateFunc(paddle.autograd.PyLayer):
    """RMSNorm + gate linear fused; recomputes the norm in backward to save memory."""

    @staticmethod
    def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
        ctx.dtype = paddle.float32
        norm_output, invar = fused.fused_rms_norm(x, rms_norm_weight, eps)
        # Gate matmul always runs in fp32, independent of the AMP state.
        with paddle.amp.auto_cast(False):
            gate_logits = F.linear(
                cast_if_needed(norm_output, ctx.dtype),
                cast_if_needed(moe_gate_weight, ctx.dtype),
            )

        ctx.save_for_backward(x, rms_norm_weight, moe_gate_weight, eps)
        return gate_logits, norm_output

    @staticmethod
    def backward(ctx, d_gate_logits, d_norm_output):
        x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
        # Recompute norm_output and invar instead of having saved them.
        norm_output, invar = fused.fused_rms_norm(x, rms_norm_weight, eps)
        d_norm_output_linear, d_moe_gate_weight = matmul_bwd(
            cast_if_needed(norm_output, ctx.dtype),
            cast_if_needed(moe_gate_weight, ctx.dtype),
            d_gate_logits,
            False,
            False,
        )
        d_norm_output_linear, d_moe_gate_weight = cast_if_needed(
            d_norm_output_linear, norm_output.dtype
        ), cast_if_needed(d_moe_gate_weight, moe_gate_weight.dtype)
        # norm_output feeds two paths (gate linear + direct use); sum their grads.
        d_norm_output = d_norm_output + d_norm_output_linear
        dx, d_rms_norm_weight = fused.fused_rms_norm_grad_func(
            x, rms_norm_weight, invar, d_norm_output, eps
        )

        return dx, d_rms_norm_weight, d_moe_gate_weight


class MOELayer(nn.Layer):
    """Mixture-of-experts layer: gate, dispatch (all-to-all), expert forward, combine."""
    def __init__(
        self,
        gate: nn.Layer,
        experts: List[nn.Layer],
        layer_idx,
        shared_experts: Optional[List[nn.Layer]] = None,
        group: Group = None,
        recompute=False,
        enable_logging: bool = False,
        k=2,
        enable_bpr: bool = False,
        all_to_all_dropout=0,
        group_experts=False,
        moe_statics=None,
    ):
        """Build the MoE layer and tag expert/gate parameters for the distributed runtime.

        Args:
            gate: router layer; its ``config`` becomes ``self.config``.
            experts: LayerList of per-rank experts, or a fused expert module.
            layer_idx: index of this layer (used for per-layer log keys).
            shared_experts: optional always-on experts run alongside routed ones.
            group: communication group for the dispatch/combine all-to-all.
            recompute: re-run expert forward in backward to save memory.
            k: top-k experts per token.
            moe_statics: holds correction-bias / usage counters; presence
                enables the aux-free correction-bias path.
        """
        super().__init__()
        self.gate = gate
        self.layer_idx = layer_idx
        self.recompute = recompute
        logger.info(f"using moe recompute={recompute}")
        for p in self.gate.parameters():
            p.is_gate = True
        if isinstance(experts, nn.LayerList):
            self.experts = experts
        else:
            # Fused experts: a single module handling all local experts at once.
            logger.info(f"using fused experts, type={type(experts)}")
            self.experts = experts
        self.shared_experts = shared_experts

        self.group = group
        self.k = k
        self.all_to_all_dropout = all_to_all_dropout
        self.enable_logging = enable_logging
        self.use_correction_bias = moe_statics is not None
        self.moe_statics = moe_statics
        if self.use_correction_bias:
            logger.info(
                f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}"
            )
            assert self.gate.config.moe_use_aux_free

        # mp-MoE: the moe group coincides with the model-parallel group.
        self.is_mp_moe = (
            hasattr(fleet.fleet, "_hcg")
            and group is fleet.get_hybrid_communicate_group().get_model_parallel_group()
        )
        self.is_ep_moe = (
            hasattr(fleet.fleet, "_hcg")
            and hasattr(
                fleet.get_hybrid_communicate_group(),
                "get_moe_sharding_parallel_world_size",
            )
            and fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_world_size()
            > 0
        )
        is_dummy_moe = dist.get_world_size(group) == 1

        # Mark expert params so the optimizer/comm hooks skip the default
        # gradient sync for genuinely sharded experts.
        for p in experts.parameters():
            p.expert = not (self.is_mp_moe or is_dummy_moe)  # type: ignore
            p.no_sync = not (self.is_mp_moe or is_dummy_moe)
            logger.info(f"expert no-sync={p.no_sync}-{p.name}")
            if self.is_mp_moe or self.is_ep_moe:
                p.is_distributed = True

        expert_color = None
        if self.is_ep_moe:
            moe_grad_group = (
                fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group()
            )
            expert_color = {"color": "moe_expert", "group": moe_grad_group}
        elif (
            # NOTE(review): self.config is only assigned further down in this
            # __init__ (self.config = self.gate.config); reading it here looks
            # like it would raise AttributeError on this branch — verify.
            self.config.offline_quant_expert_weight
            and self.config.clear_origin_weight_when_offline_quant
        ):
            expert_color = {"color": "moe_expert"}

        if expert_color is not None:
            for p in self.experts.parameters():
                setattr(p, "color", expert_color)

        self.world_size = dist.get_world_size(self.group)
        # assert self.world_size > 1, f'moe-group not found, world_size {self.world_size}'
        self.rank = dist.get_rank(self.group)
        # Clamp degenerate values reported for non-initialized groups.
        if self.world_size < 1:
            self.world_size = 1
        if self.rank < 0:
            self.rank = 0

        self.num_local_experts = len(self.experts)
        self.dispatch_by_task = (
            hasattr(self.gate, "dispatch_by_task") and self.gate.dispatch_by_task
        )

        if self.dispatch_by_task:
            # Dead path kept for reference; unconditionally rejected.
            assert 0, "no supported, checkout earylier code"
            assert self.num_local_experts == 1

        self.input_preprocess = self.output_postprocess = None
        self.group_experts = group_experts
        self.config = self.gate.config
        self.zero = paddle.to_tensor(0, dtype=paddle.float32)

        self._rr_moe_gate_dispatch = None
        self._rr_moe_combine = None
        self.use_norm_gate_recompute = None

        # NOTE(review): both branches assign the attribute back to None, so the
        # skip_recompute_ops switches currently have no effect — confirm intent.
        if self.config.use_recompute and self.config.skip_recompute_ops.get(
            "moe_gate_dispatch", False
        ):
            self._rr_moe_gate_dispatch = None
        if self.config.use_recompute and self.config.skip_recompute_ops.get(
            "moe_combine", False
        ):
            self._rr_moe_combine = None
        if hasattr(fleet.fleet, "_hcg"):
            hcg = fleet.get_hybrid_communicate_group()
            if (
                hasattr(hcg, "get_moe_sharding_parallel_world_size")
                and hcg.get_moe_sharding_parallel_world_size() > 0
            ):
                moe_grad_group = hcg.get_moe_sharding_parallel_group()
                for p in self.experts.parameters():
                    setattr(
                        p, "color", {"color": "moe_expert", "group": moe_grad_group}
                    )

    def forward_experts(self, dispatched_input):
        """Run the local experts over dispatched tokens.

        Reshapes the input to [world_size, num_local_experts, capacity, dim]
        and returns the stacked expert outputs ([e, c, m] layout).
        """
        with profile("fwd-expert"):
            dispatched_input = dispatched_input.reshape(
                [
                    self.world_size,
                    self.num_local_experts,
                    -1,
                    dispatched_input.shape[-1],
                ]
            )  # [e,1,c,m]
            expert_outputs = []
            if isinstance(self.experts, nn.LayerList):
                # One chunk per local expert, each covering all source ranks.
                chunks = dispatched_input.transpose([1, 0, 2, 3]).contiguous().unbind(0)
                assert len(chunks) == len(self.experts), (
                    len(chunks),
                    len(self.experts),
                )
                for chunk, expert in zip(chunks, self.experts):
                    expert_outputs += [expert(chunk)]
                # logger.info(
                #     f"moe-fwd-expert: {chunk.shape}"
                #     f'-> {expert_outputs[-1].shape}: {chunk.astype("float32").norm(axis=-1)}'
                # )
                expert_output = paddle.stack(expert_outputs, axis=1)  # [ecm]

            else:
                # Fused-experts path: a single call processes every chunk.
                dispatched_input = dispatched_input.transpose([1, 0, 2, 3])
                dispatched_input.contiguous()
                orig_shape = dispatched_input.shape
                chunks = dispatched_input.reshape([orig_shape[0], -1, orig_shape[-1]])
                chunks = self.experts(chunks)
                chunks = chunks.reshape(orig_shape[:-1] + [chunks.shape[-1]]).unbind(0)
                expert_outputs += chunks
                expert_output = paddle.stack(expert_outputs, axis=1)  # [ecm]
        return expert_output

    def fused_gate_logits_process(
        self, gate_logits, token_type_ids, offload_helper=None
    ):
        """Turn gate logits into routing probabilities.

        With hard multimodal gating (token_type_ids given and
        ``moe_use_hard_gate``), text (type 0) and image (type 1) tokens are
        activated only against their own expert subset; otherwise a single
        softmax-style activation is applied, optionally per expert group.

        Returns:
            (prob, max_prob): routing probs and, in the grouped-experts case,
            the per-token/group maximum used for normalization (else None).
        """
        k = self.k
        experts_type_ids = self.gate.experts_type_ids
        use_hard_gate = self.config.moe_use_hard_gate
        max_prob = None

        if token_type_ids is not None and use_hard_gate:
            if offload_helper is None:
                # Precompute (token, expert) masks and per-type token counts.
                offload_helper = dict()
                lm_mask = token_type_ids == 0
                is_lm = lm_mask.any()
                mm_mask = token_type_ids == 1
                is_mm = mm_mask.any()
                seq_lm = lm_mask.sum()
                seq_mm = mm_mask.sum()
                lm_mask = lm_mask.unsqueeze(1) & (experts_type_ids == 0).unsqueeze(0)
                mm_mask = mm_mask.unsqueeze(1) & (experts_type_ids == 1).unsqueeze(0)
                offload_helper["lm_mask"] = [lm_mask, is_lm, seq_lm]
                offload_helper["mm_mask"] = [mm_mask, is_mm, seq_mm]

            is_lm = offload_helper["lm_mask"][1]
            prob = paddle.zeros_like(gate_logits)
            # Handle the text (lm) probabilities.
            if is_lm:
                lm_mask = offload_helper["lm_mask"][0]
                seq_lm_cpu = offload_helper["lm_mask"][2]
                lm_mask_nonzero = lm_mask.nonzero()
                lm_partial_gate_logits = gate_logits.gather_nd(lm_mask_nonzero).reshape(
                    [seq_lm_cpu, -1]
                )
                if self.group_experts:
                    lm_prob = self.gate.act(
                        lm_partial_gate_logits.reshape(
                            [lm_partial_gate_logits.shape[0], k, -1]
                        )
                    )
                    max_prob = lm_prob.max(-1, keepdim=True)  # [s_l, k, 1]
                    lm_prob /= max_prob
                else:
                    lm_prob = self.gate.act(lm_partial_gate_logits)
                prob = paddle.scatter_nd_add(prob, lm_mask_nonzero, lm_prob.flatten())
            is_mm = offload_helper["mm_mask"][1]
            # Handle the image (mm) probabilities.
            if is_mm:
                mm_mask = offload_helper["mm_mask"][0]
                seq_mm_cpu = offload_helper["mm_mask"][2]
                mm_mask_nonzero = paddle.nonzero(mm_mask)
                mm_partial_gate_logits = gate_logits.gather_nd(mm_mask_nonzero).reshape(
                    [seq_mm_cpu, -1]
                )
                mm_prob = self.gate.act(mm_partial_gate_logits)
                prob = paddle.scatter_nd_add(prob, mm_mask_nonzero, mm_prob.flatten())
        else:
            if self.group_experts:
                # Grouped experts: activate within each of the k groups, then
                # normalize by the per-group max.
                prob = self.gate.act(gate_logits.reshape([gate_logits.shape[0], k, -1]))
                max_prob = prob.max(-1, keepdim=True)
                prob /= max_prob
                prob = prob.reshape([prob.shape[0], -1])
            else:
                prob = self.gate.act(gate_logits)
        return prob, max_prob

    def gate_distpach_and_quant(self, input, token_type_ids):
        """Gate + dispatch + FP8 quantization in one fused step.

        FP8 fast path of :meth:`gate_and_distpach`; requires the
        ``moe_ops_fp8`` kernels.  (Method name keeps the historical
        "distpach" spelling for API compatibility.)

        Returns the dispatch tuple plus an ``fp8_dispatched_handle`` dict
        carrying the per-row FP8 scales.
        """
        assert isinstance(self.gate, (TopKGateFused)), "Only fused gate is supported."
        assert not self.config.use_ep_comm_overlap, "ep_comm_overlap is not supported"
        assert (
            self._rr_moe_gate_dispatch is None
        ), "rr_moe_gate_dispatch is not supported"
        assert moe_ops_fp8 is not None

        args = ()
        if token_type_ids is not None:
            token_type_ids = token_type_ids.reshape([-1])
            args = (token_type_ids,)

        (
            gate_logits,
            capacity,
            router_loss,
        ) = self.gate(input, *args)

        if self.config.moe_multimodal_paired_experts:
            # Append the token type as an extra feature column.
            assert token_type_ids is not None
            input = paddle.concat(
                [input, token_type_ids.unsqueeze(-1).astype(input.dtype)], axis=-1
            )
        if self.input_preprocess is not None:
            input, gate_logits = self.input_preprocess(input, gate_logits, capacity)

        k = self.k
        prob, max_prob = self.fused_gate_logits_process(gate_logits, token_type_ids)

        with profile("dispatch_op"):
            corr_bias = (
                self.moe_statics.e_score_correction_bias[0].detach()
                if self.use_correction_bias
                else None
            )

            (
                dispatched_input,
                combine_weights_unnorm,
                scatter_index,
                dispatch_mask,
                _,
                fp8_dispatched_handle,
            ) = Fp8MoeGateDispatchAndQuant.apply(
                input, prob, corr_bias, k=k, capacity=capacity, use_pad=True
            )

        # Cumulative counts -> per-expert counts.
        dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0)))
        if self.use_correction_bias:
            # Track expert usage for the aux-free correction bias.
            if self.gate.config.multimodel_experts:
                for i in range(len(self.moe_statics.expert_usage)):
                    self.moe_statics.expert_usage[i] += dispatch_mask[
                        self.gate.experts_type_mask[i]
                    ].detach()
            else:
                self.moe_statics.expert_usage[0] += dispatch_mask.detach()
        dispatched_input.stop_gradient = False
        combine_weights_unnorm.stop_gradient = False
        scatter_index.stop_gradient = True
        dispatch_mask.stop_gradient = True

        scatter_index = scatter_index.transpose([1, 0])
        if self.group_experts:
            if max_prob is not None:
                if token_type_ids is not None:
                    # Only text tokens (type 0) were max-normalized; undo just those.
                    p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1))
                    p = paddle.scatter_nd_add(
                        p, paddle.nonzero(token_type_ids == 0), -1 + max_prob
                    )
                else:
                    p = max_prob
                combine_weights_unnorm = (
                    combine_weights_unnorm.unsqueeze(-1) * p
                ).squeeze(-1)
                # Restore gate_prob to its un-normalized scale.
                prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape([p.shape[0], -1])
        if self.gate.norm_gate_logits:
            combine_weights = combine_weights_unnorm / paddle.clip(
                combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12
            )
        else:
            combine_weights = combine_weights_unnorm
        combine_weights = combine_weights.cast("bfloat16")

        def reshape_for_a2a(tensor):
            # Layout expected by the dispatch all-to-all.
            return tensor.reshape(
                [
                    self.world_size * self.num_local_experts,
                    capacity,
                    -1,
                ]
            )

        dispatched_input = reshape_for_a2a(dispatched_input)
        fp8_dispatched_handle["scale"] = reshape_for_a2a(fp8_dispatched_handle["scale"])
        dispatch_mask.stop_gradient = True
        scatter_index.stop_gradient = True
        return (
            dispatched_input,
            combine_weights,
            dispatch_mask,
            scatter_index,
            router_loss,
            gate_logits,
            prob,
            fp8_dispatched_handle,
        )

    def gate_and_distpach(self, input, token_type_ids):
        """Gate tokens and dispatch them to experts (non-FP8 path).

        (Method name keeps the historical "distpach" spelling for API
        compatibility.)  Returns
        (dispatched_input, combine_weights, dispatch_mask, scatter_index,
        router_loss, gate_logits, prob).
        """
        seqlen, d_model = input.shape
        args = ()
        if token_type_ids is not None:
            token_type_ids = token_type_ids.reshape([-1])
            args = (token_type_ids,)

        use_fuse = isinstance(self.gate, (TopKGateFused))
        if use_fuse:
            if self.use_norm_gate_recompute:
                # Fused RMSNorm+gate: also returns the normalized input.
                (
                    gate_logits,
                    capacity,
                    router_loss,
                    norm_res,
                ) = self.fused_norm_gate(input)
                input = norm_res
            else:
                (
                    gate_logits,
                    capacity,
                    router_loss,
                ) = self.gate(input, *args)
        else:
            # Non-fused gate computes dispatch tensors itself.
            (
                capacity,
                dispatch_mask,
                combine_weights,
                scatter_index,
                router_loss,
                gate_logits,
            ) = self.gate(
                input,
                *args,
                correction_bias=(
                    self.moe_statics.e_score_correction_bias[0]
                    if self.use_correction_bias
                    else None
                ),
            )
            prob = None
        if self.config.moe_multimodal_paired_experts:
            # Append the token type as an extra feature column.
            assert token_type_ids is not None
            input = paddle.concat(
                [input, token_type_ids.unsqueeze(-1).astype(input.dtype)], axis=-1
            )
        if self.input_preprocess is not None:
            input, gate_logits = self.input_preprocess(input, gate_logits, capacity)
        if use_fuse:
            # capacity no use
            k = self.k
            prob, max_prob = self.fused_gate_logits_process(gate_logits, token_type_ids)

            assert moe_ops is not None
            with profile("dispatch_op"):
                # Older moe-ops builds lack the corr_bias parameter; probe it.
                if (
                    "corr_bias"
                    in inspect.signature(moe_ops.moe_gate_dispatch).parameters
                ):
                    if self.use_correction_bias:
                        compat_args = (self.moe_statics.e_score_correction_bias[0],)
                    else:
                        compat_args = (None,)
                else:
                    assert (
                        not self.use_correction_bias
                    ), "correction bias not supported, rebuild moe-ops"
                    compat_args = ()
                if not self.config.use_ep_comm_overlap:
                    if self._rr_moe_gate_dispatch is None:
                        (
                            dispatched_input,
                            combine_weights_unnorm,
                            scatter_index,
                            dispatch_mask,
                            _,
                        ) = moe_ops.moe_gate_dispatch(
                            input,
                            prob,
                            *compat_args,
                            k=k,
                            capacity=capacity,
                            use_pad=True,
                        )
                    else:
                        # NOTE(review): compat_args is passed positionally here
                        # (not unpacked with *) unlike the other call sites — verify.
                        (
                            dispatched_input,
                            combine_weights_unnorm,
                            scatter_index,
                            dispatch_mask,
                            _,
                        ) = self._rr_moe_gate_dispatch(
                            input,
                            prob,
                            compat_args,
                            k=k,
                            capacity=capacity,
                            use_pad=True,
                        )
                else:
                    # Permuted layout required by the overlapped a2a path.
                    (
                        dispatched_input,
                        combine_weights_unnorm,
                        scatter_index,
                        dispatch_mask,
                        _,
                    ) = moe_ops.moe_gate_dispatch_permute(
                        input,
                        prob,
                        *compat_args,
                        k=k,
                        capacity=capacity,
                        world_size=self.group.nranks,
                    )
            # Cumulative counts -> per-expert counts.
            dispatch_mask = paddle.diff(F.pad(dispatch_mask, (1, 0)))
            if self.use_correction_bias and framework._dygraph_tracer()._has_grad:
                if self.gate.config.multimodel_experts:
                    for i in range(len(self.moe_statics.expert_usage)):
                        self.moe_statics.expert_usage[i] += dispatch_mask[
                            self.gate.experts_type_mask[i]
                        ].detach()
                else:
                    self.moe_statics.expert_usage[0] += dispatch_mask.detach()
            dispatched_input.stop_gradient = False
            combine_weights_unnorm.stop_gradient = False
            scatter_index.stop_gradient = True
            dispatch_mask.stop_gradient = True

            scatter_index = scatter_index.transpose([1, 0])  # [k,s] ->[s,k]
            if self.group_experts:
                if max_prob is not None:
                    if token_type_ids is not None:
                        # Only text tokens (type 0) were max-normalized; undo those.
                        p = paddle.ones_like(combine_weights_unnorm.unsqueeze(-1))
                        p = paddle.scatter_nd_add(
                            p, paddle.nonzero(token_type_ids == 0), -1 + max_prob
                        )
                    else:
                        p = max_prob
                    combine_weights_unnorm = (
                        combine_weights_unnorm.unsqueeze(-1) * p
                    ).squeeze(-1)
                    # Restore gate_prob to its un-normalized scale.
                    prob = (prob.reshape([p.shape[0], k, -1]) * p).reshape(
                        [p.shape[0], -1]
                    )
            if self.gate.norm_gate_logits:
                # Normalize top-k weights so they sum to one per token.
                combine_weights = combine_weights_unnorm / paddle.clip(
                    combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12
                )
            else:
                combine_weights = combine_weights_unnorm
            combine_weights = combine_weights.cast(dispatched_input.dtype)
        else:
            dispatched_input = dispatching(
                input,
                dispatch_mask,
                scatter_index,
                num_experts=self.world_size * self.num_local_experts,
                capacity=capacity,
            )
            if self.use_correction_bias and framework._dygraph_tracer()._has_grad:
                usage = paddle.bincount(
                    scatter_index.reshape([-1]) // capacity,
                    minlength=self.world_size * self.num_local_experts,
                )
                assert (
                    not self.config.multimodel_experts
                ), "correction bias not supported, use top2-fused gate"
                self.moe_statics.expert_usage[0] += usage.detach()
        if not self.config.use_ep_comm_overlap:
            dispatched_input = dispatched_input.reshape(
                [
                    self.world_size * self.num_local_experts,
                    capacity,
                    (
                        d_model
                        if not self.config.moe_multimodal_paired_experts
                        else d_model + 1
                    ),
                ]
            )  # .clone()
        else:
            assert (
                len(dispatched_input.shape) == 4
                and dispatched_input.shape[1] == self.world_size
                and dispatched_input.shape[0] == self.num_local_experts
            ), (
                f"When using ep_comm_overlap, moe_gate_dispatch_permute is needed. "
                f"Expected dispatched_input to have shape[1] == {self.world_size} "
                f"and shape[0] == {self.num_local_experts}, "
                f"but got shape {dispatched_input.shape}"
            )
            dispatched_input = dispatched_input  # .clone()
        dispatch_mask.stop_gradient = True
        scatter_index.stop_gradient = True
        return (
            dispatched_input,
            combine_weights,
            dispatch_mask,
            scatter_index,
            router_loss,
            gate_logits,
            prob,
        )

    def _calc_router_loss(
        self,
        dispatch_mask,
        gate_logits,
        gate_prob,
        num_experts,
        use_group,
        layer_idx,
        token_type=None,
        tokens_type_mask=None,
        dispatch_tokens_mask=None,
        prefix="",
    ):
        """Accumulate aux / orthogonal / z losses for one token-type slice.

        Also pushes per-layer (and ``prefix``-scoped) values into the global
        training logs when logging is enabled and grads are being traced.
        """
        log = {}
        router_loss, l_aux, orthogonal_loss, zloss = 0.0, None, None, None
        if self.gate.config.moe_aux_loss_lambda:
            l_aux = self.gate._cal_aux_loss(
                gate_prob,
                dispatch_mask,
                num_experts,
                use_group,
                tokens_type_mask,
                dispatch_tokens_mask,
            )
            router_loss += self.gate.moe_aux_loss_lambda[token_type or 0] * l_aux
        else:
            router_loss += (
                self.zero * gate_prob[0, 0]
            )  # must use gate prob to avoid zero pointer
        if self.gate.config.moe_orthogonal_loss_lambda:
            orthogonal_loss = self.gate._cal_orthogonal_loss(token_type, use_group)
            router_loss += (
                self.gate.moe_orthogonal_loss_lambda[token_type or 0] * orthogonal_loss
            )
        if self.gate.config.moe_z_loss_lambda and not in_auto_parallel_align_mode():
            zloss = self.gate._cal_z_loss(gate_logits, tokens_type_mask)
            router_loss += self.gate.moe_z_loss_lambda[token_type or 0] * zloss

        tracer = framework._dygraph_tracer()
        if self.enable_logging and global_training_logs_enabled() and tracer._has_grad:
            if l_aux is not None:
                log[f"aux_loss_layer_{self.layer_idx}"] = l_aux

            if orthogonal_loss is not None:
                log[f"orthogonal_loss_layer_{self.layer_idx}"] = orthogonal_loss

            if zloss is not None:
                log[f"zloss_layer_{self.layer_idx}"] = zloss

            # Log both the per-layer key and the layer-agnostic aggregate key.
            global_training_logs.update(
                **log,
                **{
                    k.replace(f"_layer_{self.layer_idx}", ""): v for k, v in log.items()
                },
            )
            # Prefix-scoped copies (e.g. "lm_aux_loss" / "mm_aux_loss").
            global_training_logs.update(
                **{
                    prefix + "_" + k.replace(f"_layer_{self.layer_idx}", ""): v
                    for k, v in log.items()
                }
            )
        return router_loss

    def calc_router_loss_and_logging(
        self,
        router_loss,
        combine_weights,
        dispatch_mask,
        gate_logits,
        gate_prob,
        token_type_ids,
        dispatch_token_type_ids=None,
        offload_helper=None,
    ):
        """Add router losses (per token type under hard gating) and emit gate stats.

        Returns the accumulated router loss; also logs experts-per-token and
        gate-entropy metrics when logging is enabled.
        """
        use_fuse = isinstance(self.gate, (TopKGateFused))
        if use_fuse:
            assert gate_prob is not None
            if token_type_ids is not None and self.gate.config.moe_use_hard_gate:
                if not self.gate.weight.stop_gradient:
                    # Text (type 0) slice.
                    lm_tokens_mask = token_type_ids == 0
                    if offload_helper is not None:
                        is_lm = offload_helper["lm_mask"][1]
                    else:
                        is_lm = lm_tokens_mask.any()
                    if is_lm:
                        dispatch_tokens_mask = (
                            dispatch_token_type_ids == 0
                            if dispatch_token_type_ids is not None
                            else None
                        )
                        router_loss += self._calc_router_loss(
                            (
                                dispatch_mask[self.gate.experts_type_mask[0]]
                                if hasattr(self.gate, "experts_type_mask")
                                else dispatch_mask
                            ),
                            (
                                gate_logits[:, self.gate.experts_type_mask[0]]
                                if hasattr(self.gate, "experts_type_mask")
                                else gate_logits
                            ),
                            (
                                gate_prob[:, self.gate.experts_type_mask[0]]
                                if hasattr(self.gate, "experts_type_mask")
                                else gate_prob
                            ),
                            (
                                self.gate.num_experts_list[0]
                                if hasattr(self.gate, "num_experts_list")
                                else self.gate.num_experts_tensor
                            ),
                            self.group_experts,
                            self.layer_idx,
                            0,
                            lm_tokens_mask,
                            dispatch_tokens_mask,
                            prefix="lm",
                        )
                # Image (type 1) slice.
                mm_tokens_mask = token_type_ids == 1
                if offload_helper is not None:
                    is_mm = offload_helper["mm_mask"][1]
                else:
                    is_mm = mm_tokens_mask.any()
                if is_mm:
                    dispatch_tokens_mask = (
                        dispatch_token_type_ids == 1
                        if dispatch_token_type_ids is not None
                        else None
                    )
                    router_loss += self._calc_router_loss(
                        dispatch_mask[self.gate.experts_type_mask[1]],
                        gate_logits[:, self.gate.experts_type_mask[1]],
                        gate_prob[:, self.gate.experts_type_mask[1]],
                        self.gate.num_experts_list[1],
                        False,
                        self.layer_idx,
                        1,
                        mm_tokens_mask,
                        dispatch_tokens_mask,
                        prefix="mm",
                    )

            else:
                router_loss += self._calc_router_loss(
                    dispatch_mask,
                    gate_logits,
                    gate_prob,
                    self.gate.num_experts_tensor,
                    self.group_experts,
                    self.layer_idx,
                )

            if self.enable_logging and global_training_logs_enabled():
                seqlen = gate_logits.shape[0]
                num_active = paddle.count_nonzero(combine_weights)
                gate_experts_per_token = num_active.item() / seqlen

                if token_type_ids is not None:
                    # Per-modality experts-per-token stats.
                    token_type_ids = token_type_ids.reshape([-1])
                    combine_weights_type_0 = combine_weights[token_type_ids == 0]
                    if combine_weights_type_0.size:
                        gate_expert_per_token_type_0 = (
                            paddle.count_nonzero(combine_weights_type_0).item()
                            / combine_weights_type_0.shape[0]
                        )
                        global_training_logs.update(
                            experts_per_token_text=gate_expert_per_token_type_0,
                        )

                    combine_weights_type_1 = combine_weights[token_type_ids == 1]
                    if combine_weights_type_1.size:
                        gate_expert_per_token_type_1 = (
                            paddle.count_nonzero(combine_weights_type_1).item()
                            / combine_weights_type_1.shape[0]
                        )
                        global_training_logs.update(
                            experts_per_token_image=gate_expert_per_token_type_1,
                        )

                # Mean entropy of the gate distribution.
                ce = (
                    (-F.softmax(gate_logits, -1) * F.log_softmax(gate_logits, -1))
                    .sum(-1)
                    .mean(0)
                )
                _log = {
                    f"gate_prob_ce_layer_{self.layer_idx}": ce.item(),
                    f"experts_per_token_layer_{self.layer_idx}": gate_experts_per_token,
                }
                global_training_logs.update(
                    **_log,
                    **{
                        k.replace(f"_layer_{self.layer_idx}", ""): v
                        for k, v in _log.items()
                    },
                )
        else:
            # Non-fused gate: derive top-1/top-2/leakage stats from the mask.
            seqlen = dispatch_mask.shape[0]
            dispatch_mask = dispatch_mask.unbind(-1)
            top1_gate_experts_per_token = (
                paddle.cast(dispatch_mask[0], dtype="float32").sum() / seqlen
            )
            if (
                self.enable_logging
                and global_training_logs_enabled()
                and len(dispatch_mask) == 2
            ):
                top2_gate_experts_per_token = (
                    paddle.cast(dispatch_mask[1], dtype="float32").sum() / seqlen
                )
                leakage_experts_per_token = (
                    paddle.cast(
                        (~dispatch_mask[0]) & (~dispatch_mask[1]), dtype="float32"
                    ).sum()
                    / seqlen
                )
                experts_per_token = (
                    top1_gate_experts_per_token + top2_gate_experts_per_token
                )
                global_training_logs.update(
                    experts_per_token=experts_per_token.detach(),
                    top1_experts_per_token=top1_gate_experts_per_token.detach(),
                    top2_experts_per_token=top2_gate_experts_per_token.detach(),
                    leakage_experts_per_token=leakage_experts_per_token.detach(),
                )
            elif (
                self.enable_logging
                and global_training_logs_enabled()
                and len(dispatch_mask) == 1
            ):
                experts_per_token = top1_gate_experts_per_token
                leakage_experts_per_token = (
                    paddle.cast(~dispatch_mask[0], dtype="float32").sum() / seqlen
                )
                global_training_logs.update(
                    experts_per_token=experts_per_token.detach(),
                    top1_experts_per_token=top1_gate_experts_per_token.detach(),
                    leakage_experts_per_token=leakage_experts_per_token.detach(),
                )

        return router_loss

    def combine_expert_output(self, expert_output, combine_weights, scatter_index):
        """Weighted-combine expert outputs back into token order."""
        expert_output = expert_output.reshape([-1, expert_output.shape[-1]])
        use_fuse = isinstance(self.gate, (TopKGateFused))
        combine_fn = combining_fused if use_fuse else combining
        combined_output = combine_fn(expert_output, combine_weights, scatter_index)

        if self.output_postprocess is not None:
            combined_output = self.output_postprocess(combined_output)
        return combined_output

    def forward_single_stage(self, dispatched_input, stage_id):
        """Run a single local expert (used by the overlapped a2a path)."""
        assert isinstance(self.experts, nn.LayerList)
        return self.experts[stage_id](dispatched_input)

    def forward(
        self,
        input: Tensor,
        token_type_ids=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """MoE forward: gate → dispatch (a2a) → experts → combine.

        Args:
            input: [s, d] tokens, or [b, s, d] (flattened internally).
            token_type_ids: optional per-token modality ids (0=text, 1=image).

        Returns:
            (combined_output, combine_weights, router_loss, gate_logits).
        """
        if input.ndim == 3:
            orig_shape = input.shape
            input = input.reshape([-1, input.shape[-1]])
        else:
            orig_shape = None
        assert (
            len(input.shape) == 2
        ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}"
        hidden_size = input.shape[1]
        if token_type_ids is not None:
            # Drop the last position (shift) before scattering across ranks.
            token_type_ids = token_type_ids.clone()[:, :-1]
            if self.config.sequence_parallel:
                token_type_ids = token_type_ids.reshape([-1])
                token_type_ids = ScatterOp.apply(token_type_ids)
                token_type_ids.stop_gradient = True

        assert self.gate is not None
        if hasattr(self, "rng") and self.rng.random() < self.all_to_all_dropout:
            # All-to-all dropout path: skip comms and run experts locally.
            # NOTE(review): returns a 3-tuple here vs the 4-tuple of the normal
            # path below — confirm callers handle both arities.
            orig_shape_2 = input.shape
            if self.config.moe_multimodal_paired_experts:
                assert token_type_ids is not None
                input = paddle.concat(
                    [input, token_type_ids.unsqueeze(-1).astype(input.dtype)], axis=-1
                )
            output = self.forward_experts(input)
            output += self.gate.weight.sum() * 0.0  # hack for grad
            output = output.reshape(orig_shape or orig_shape_2)  # [e*1,c,m]
            return output, None, 0

        is_first_fwd = not framework._dygraph_tracer()._has_grad
        use_async = self.shared_experts is not None
        if in_auto_parallel_align_mode():
            gate_input = paddle.assign(input)
        else:
            gate_input = input

        use_fp8_fuse_node = (
            self.config.use_combine_before_a2a and self.config.use_fp8_fuse_node
        )
        use_fp8_dispatch_a2a = self.config.use_fp8_dispatch_a2a and use_fp8_fuse_node

        with profile("fused_gate_and_dispatch"):
            fp8_dispatched_handle = None
            if use_fp8_dispatch_a2a:
                (
                    dispatched_input,
                    combine_weights,
                    dispatch_mask,
                    scatter_index,
                    router_loss,
                    gate_logits,
                    gate_prob,
                    fp8_dispatched_handle,
                ) = self.gate_distpach_and_quant(gate_input, token_type_ids)
            else:
                (
                    dispatched_input,
                    combine_weights,
                    dispatch_mask,
                    scatter_index,
                    router_loss,
                    gate_logits,
                    gate_prob,
                ) = self.gate_and_distpach(gate_input, token_type_ids)

        # TODO(shenliang03): to fuse one kernel to optimize
        if self.config.use_combine_before_a2a:
            # Pre-weight tokens so the combine after the return a2a is a plain sum.
            assert (
                not self.config.use_ep_comm_overlap
            ), "Dont support use_ep_comm_overlap"
            assert (
                moe_combine_no_weight is not None
            ), "use_combine_before_a2a can only use with moe_combine_no_weight op, please install it first."
            cw_shape = combine_weights.shape
            si_shape = scatter_index.shape
            scatter_index = scatter_index.reshape([-1])

            token_combine_weights = paddle.zeros(
                [cw_shape[0] * cw_shape[1]], dtype=combine_weights.dtype
            )
            token_combine_weights = paddle.scatter(
                token_combine_weights,
                scatter_index,
                combine_weights.reshape([-1]),
                overwrite=False,
            )

            token_combine_weights = token_combine_weights.reshape(
                [cw_shape[0], cw_shape[1], 1]
            )
            token_combine_weights = AlltoAll.apply(token_combine_weights, self.group)

        if not self.config.use_ep_comm_overlap:
            if use_fp8_dispatch_a2a:
                shared_out = (
                    self.shared_experts(input)
                    if self.shared_experts is not None
                    else None
                )
            else:
                with profile("moe_comm_and_shared_expert"):
                    if use_async:
                        # Overlap the dispatch a2a with the shared experts.
                        dispatched_input, shared_out = AlltoAllAsync.apply(
                            dispatched_input,
                            input,
                            group=self.group,
                            fn=self.shared_experts,
                            is_first_fwd=is_first_fwd,
                        )
                    else:
                        dispatched_input = AlltoAll.apply(dispatched_input, self.group)

            expert_out = (
                recompute(self.forward_experts, dispatched_input)
                if self.recompute and self.training
                else self.forward_experts(dispatched_input)
            )

            if self.config.use_combine_before_a2a:
                token_combine_weights = token_combine_weights.clone().reshape(
                    expert_out.shape[:-1] + [1]
                )
                expert_out = expert_out * token_combine_weights
        else:
            assert (
                len(dispatched_input.shape) == 4
                and dispatched_input.shape[1] == self.world_size
                and dispatched_input.shape[0] == self.num_local_experts
            ), (
                f"When using ep_comm_overlap, moe_gate_dispatch_permute is needed. "
                f"Expected dispatched_input to have shape[1] == {self.world_size} "
                f"and shape[0] == {self.num_local_experts}, "
                f"but got shape {dispatched_input.shape}"
            )
            with profile("moe_comm_and_forward_expert"):
                expert_out = AlltoAllExpertOverlap.apply(
                    dispatched_input,
                    self.group,
                    self.num_local_experts,
                    self.experts,
                    is_first_fwd=is_first_fwd,
                )
                if self.shared_experts is not None:
                    shared_out = self.shared_experts(input)

        with profile("moe_comm_and_calc_routerloss"):
            # Overlap the return a2a with router-loss computation.
            expert_out, router_loss2 = AlltoAllAsync.apply(
                expert_out,
                router_loss,
                combine_weights,
                dispatch_mask,
                gate_logits,
                gate_prob,
                token_type_ids,
                group=self.group,
                fn=self.calc_router_loss_and_logging,
                is_first_fwd=is_first_fwd,
            )

        with profile("combine"):
            if self.config.use_combine_before_a2a:
                expert_out = expert_out.reshape([-1, hidden_size])

                scatter_index = scatter_index.reshape(si_shape)
                combined_output = moe_combine_no_weight(
                    expert_out, combine_weights, scatter_index, epsilon=1e-15
                )
            else:
                combined_output = self.combine_expert_output(
                    expert_out, combine_weights, scatter_index
                )

        if self.shared_experts is not None:
            combined_output += shared_out

        if orig_shape:
            combined_output = combined_output.clone().reshape(
                orig_shape[:-1] + [combined_output.shape[-1]]
            )
        return combined_output, combine_weights, router_loss2, gate_logits
diff --git a/examples/pre-training/models/moe/moe_utils.py b/examples/pre-training/models/moe/moe_utils.py
new file mode 100644
index 000000000..cd797ab45
--- /dev/null
+++ b/examples/pre-training/models/moe/moe_utils.py
@@ -0,0 +1,229 @@
# !/usr/bin/env python3

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" moe utils for allgather dispatcher """
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
import paddle.nn.functional as F
from paddle import nn
from paddle.autograd import PyLayer

from models.sequence_parallel_utils import (
    AllGatherOp,
    ReduceScatterOp,
)


class MOEGather(PyLayer):
    """
    MOE Gather: differentiable take_along_axis over axis 0.
    """

    @staticmethod
    def forward(ctx, input_, map_):
        """
        MOE Gather forward: gather rows of input_ selected by map_.
        """
        ctx.input_shape = input_.shape
        ctx.map = map_
        return paddle.take_along_axis(input_, map_, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        MOE Gather backward: scatter grads back to the gathered positions.
        """
        input_shape = ctx.input_shape
        map_ = ctx.map

        output = paddle.zeros(input_shape, dtype=grad_output.dtype)
        # Second return slot is the (None) grad for map_.
        return paddle.put_along_axis(output, map_, grad_output, 0), None


class MOEScatter(PyLayer):
    """
    MOE Scatter: differentiable put_along_axis over axis 0.
    """

    @staticmethod
    def forward(ctx, input_, map_, output_size=None):
        """
        MOE Scatter forward: place rows of input_ at positions map_
        in a zero tensor of output_size (or input_'s shape).
        """
        ctx.map = map_

        if output_size is not None:
            output = paddle.zeros(output_size, dtype=input_.dtype)
        else:
            output = paddle.zeros_like(input_)

        return paddle.put_along_axis(output, map_, input_, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        MOE Scatter backward: gather grads from the scattered positions.
        """
        map_ = ctx.map
        return paddle.take_along_axis(grad_output, map_, 0), None


class AllgatherDispatcherReturn(object):
    """
    MOE allgather dispatcher return value
    """

    def __init__(
        self,
        global_hidden_states,
        dispatched_input,
        combine_weights,
        scatter_index,
        gather_scatter_mask,
        dispatch_mask,
        tokens_per_expert,
    ):
        self.global_hidden_states = global_hidden_states
        self.dispatched_input = dispatched_input
        self.combine_weights = combine_weights
        self.scatter_index = scatter_index
        self.gather_scatter_mask = gather_scatter_mask
        self.dispatch_mask = dispatch_mask
        self.tokens_per_expert = tokens_per_expert


class MOEAllGatherDispatcher(nn.Layer):
    """
    MOE with allgather dispatcher.
    Contains two static methods:
    MOEAllGatherDispatcher.token_dispatcher
    MOEAllGatherDispatcher.token_combine
    """

    @staticmethod
    def token_dispatcher(
        hidden_states,
        local_gate_logits,
        top_k,
        local_expert_indices,
        num_moe_experts,
        num_local_experts,
    ):
        """
        MOE token dispatcher with allgather: every rank gathers all tokens,
        then keeps only the ones routed to its local experts.
        """
        seq_len = local_gate_logits.shape[0]
        num_experts = local_gate_logits.shape[-1]
        # Grouped softmax (top_k groups) normalized by the per-group max.
        prob = F.softmax(local_gate_logits.reshape([seq_len, top_k, -1]), axis=-1)
        max_prob = prob.max(-1, keepdim=True)
        prob /= max_prob
        prob = prob.reshape([-1, num_experts])

        probs, scatter_index = paddle.topk(prob, top_k, axis=-1)
        dispatch_mask = paddle.cumsum(
            paddle.histogram(scatter_index.flatten(), bins=num_experts)
        )

        # dispatch
        with paddle.no_grad():
            global_indices = AllGatherOp.apply(scatter_index)
            # Tokens routed to one of this rank's experts.
            global_local_mask = (global_indices >= local_expert_indices[0]) & (
                global_indices <= local_expert_indices[-1]
            )
            local_indices = global_indices.masked_select(global_local_mask)

        global_hidden_states = AllGatherOp.apply(hidden_states)
        global_probs = AllGatherOp.apply(probs)

        # get local hidden states
        combine_weights = global_probs.masked_select(global_local_mask).cast(
            dtype=hidden_states.dtype
        )
        gather_scatter_mask = global_local_mask.nonzero()[:, 0]
        gather_scatter_mask = paddle.reshape(gather_scatter_mask, shape=[-1, 1])
        gather_scatter_mask = paddle.expand(
            gather_scatter_mask, shape=[-1, hidden_states.shape[-1]]
        )
        local_hidden_states = MOEGather.apply(global_hidden_states, gather_scatter_mask)

        with paddle.no_grad():
            # The indices of local_indices that give its sorted order along dim 0.
            scatter_index = paddle.argsort(local_indices, axis=0)
            tokens_per_expert = paddle.bincount(
                paddle.reshape(local_indices, [-1]), minlength=num_moe_experts
            )
            if num_local_experts < num_moe_experts:
                # Keep only this rank's expert counts.
                start = local_expert_indices[0]
                end = local_expert_indices[-1] + 1
                tokens_per_expert = tokens_per_expert[start:end]

        scatter_index = paddle.reshape(scatter_index, shape=[-1, 1])
        scatter_index = paddle.expand(
            scatter_index, shape=[-1, hidden_states.shape[-1]]
        )

        # Sort tokens by destination expert.
        dispatched_input = MOEGather.apply(local_hidden_states, scatter_index)

        return AllgatherDispatcherReturn(
            global_hidden_states,
            dispatched_input,
            combine_weights,
            scatter_index,
            gather_scatter_mask,
            dispatch_mask,
            tokens_per_expert,
        )

    @staticmethod
    def token_combine(
        expert_out,
        shared_out,
        combine_weights,
        scatter_index,
        gather_scatter_mask,
        global_shape,
    ):
        """
        MOE token combine with reduce scatter: invert the dispatch sort,
        apply combine weights, scatter back to the global layout and
        reduce-scatter across ranks.
        """
        expert_out = MOEScatter.apply(expert_out, scatter_index)
        expert_out = expert_out * paddle.reshape(combine_weights, shape=[-1, 1])
        expert_out = MOEScatter.apply(expert_out, gather_scatter_mask, global_shape)
        combine_out = expert_out + shared_out
        combine_out = ReduceScatterOp.apply(combine_out)
        return combine_out


def get_flatten_mesh(mesh):
    """Return a 1-D ProcessMesh over the same process ids."""
    return dist.ProcessMesh(mesh.process_ids)


def get_mesh(pp_idx=0):
    """Return the global mesh, sliced to pipeline stage pp_idx when pp is present."""
    mesh = fleet.auto.get_mesh()
    if "pp" in mesh.dim_names:
        mesh = mesh.get_mesh_with_dim("pp", pp_idx)
    return mesh


def _reshard(tensor, mesh, placements):
    """Reshape-reshard a dist tensor onto mesh with the given placements."""
    dst_tensor = dist.auto_parallel.moe_utils._dist_reshape(
        tensor, tensor.shape, mesh, placements
    )
    return dst_tensor
diff --git a/examples/pre-training/models/moe/moe_utils_auto.py b/examples/pre-training/models/moe/moe_utils_auto.py
new file mode 100644
index 000000000..fbaba34fd
--- /dev/null
+++
# !/usr/bin/env python3

# Copyright (c) 2025 PaddlePaddle Authors. Licensed under the Apache License,
# Version 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for terms.

"""Mesh helpers shared by the auto-parallel MoE modules."""

import paddle.distributed as dist
from paddle.distributed import fleet


def get_flatten_mesh(mesh):
    """Collapse ``mesh`` into a flat, one-dimensional process mesh."""
    flat_ids = mesh.process_ids
    return dist.ProcessMesh(flat_ids)


def get_mesh(pp_idx=0):
    """Return the global auto-parallel mesh, sliced to pipeline stage
    ``pp_idx`` when the mesh carries a "pp" dimension."""
    full_mesh = fleet.auto.get_mesh()
    if "pp" not in full_mesh.dim_names:
        return full_mesh
    return full_mesh.get_mesh_with_dim("pp", pp_idx)


def _reshard(tensor, mesh, placements):
    """Redistribute ``tensor`` onto ``mesh`` with ``placements`` using the
    MoE dist-reshape primitive (logical shape is preserved)."""
    reshape_fn = dist.auto_parallel.moe_utils._dist_reshape
    return reshape_fn(tensor, tensor.shape, mesh, placements)
# Copyright (c) 2025 PaddlePaddle Authors. Licensed under the Apache License,
# Version 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for terms.

"""
top2gate
"""


from typing import Tuple
import logging
from paddle import Tensor
import paddle.distributed as dist


logger = logging.getLogger(__name__)

from models.moe.top2_gate_auto_auto import TopKGateFused
from models.moe.moe_utils_auto import get_mesh, get_flatten_mesh


class TopKGateFusedAuto(TopKGateFused):
    """Auto-parallel variant of TopKGateFused: shards the gate weight
    (replicated) on the flattened mesh of pipeline stage ``ipp`` and
    returns both global and per-shard expert capacities."""

    def __init__(self, config, layer_idx: int, group, gate_weight=None, ipp=0) -> None:
        super().__init__(config, layer_idx, group, gate_weight)
        # Pipeline-stage index this gate lives on.
        self.ipp = ipp
        # Replicate the gate weight across the flattened stage mesh.
        self.weight = dist.shard_tensor(
            self.weight, get_flatten_mesh(get_mesh(self.ipp)), [dist.Replicate()]
        )

    def forward(
        self,
        input: Tensor,
        token_type_ids=None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:  # type: ignore
        """
        Args:
            input: paddle.Tensor, hidden-states of layer
        Returns:
            paddle.Tensor [Seq, Expert, Capacity]: float32, combine weights
            paddle.Tensor [Seq, Expert, Capacity]: bool, dispatch mask
            Tuple[paddle.Tensor]: `GateOutput`
        """
        num_experts = (
            sum(self.num_experts)
            if self.config.multimodel_experts
            else self.num_experts
        )
        # Capacity factor selection: cap[0] while training; at eval time use
        # cap[2] for short sequences (seqlen < num_expert) else cap[1].
        if self.training:
            cap = self.cap[0]
        elif input.shape[0] < num_experts:  # seqlen < num_expert
            cap = self.cap[2]
        else:
            cap = self.cap[1]
        num_tokens = input.shape[0]
        # capacity = 2S/E
        global_capacity = int(cap * num_tokens // num_experts)
        # Per-shard capacity from the local (per-rank) shape of the dist tensor.
        local_num_tokens = input._local_shape[0]
        local_capacity = int(cap * local_num_tokens // num_experts)

        logits, _, router_loss = super().forward(input, token_type_ids)

        return logits, global_capacity, router_loss, local_capacity
Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple +from functools import partial +import logging +import numpy as np +import paddle +from paddle import Tensor +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.utils import unique_name +from paddle.nn.clip import _squared_l2_norm +from paddle.distributed import fleet +from paddleformers.utils.tools import get_env_device +from models.utils import global_training_logs_enabled + +try: + from src.utils.misc import global_training_logs +except ModuleNotFoundError: + global_training_logs = {} +try: + import moe_router_loss_ops +except ImportError: + moe_router_loss_ops = None + +try: + from custom_setup_ops import matmul_bwd +except ImportError: + matmul_bwd = None + +try: + from bincount_ops import int_bincount +except ImportError: + int_bincount = None + +logger = logging.getLogger(__name__) + + +class CalOrthogonalLossOptEachWeightFunctor(paddle.autograd.PyLayer): + + @staticmethod + def forward(ctx, gate_weight, moe_k, use_group, eps=1e-12): + if gate_weight.dtype != paddle.float32: + gate_weight = gate_weight.astype(paddle.float32) + ( + orthogonal_loss, + wnorm, + weight_scale, + normed_weight, + weight_matmul, + ) = moe_router_loss_ops.cal_orthogonal_loss_opt_each_weight( + gate_weight, moe_k, use_group, eps + ) + ctx.save_for_backward( + gate_weight, wnorm, weight_scale, 
normed_weight, weight_matmul + ) + ctx.moe_k = moe_k + ctx.use_group = use_group + ctx.eps = eps + return orthogonal_loss + + @staticmethod + def backward(ctx, out_grad): + gate_weight, wnorm, weight_scale, normed_weight, weight_matmul = ( + ctx.saved_tensor() + ) + if gate_weight.stop_gradient: + return None + moe_k = ctx.moe_k + use_group = ctx.use_group + eps = ctx.eps + return moe_router_loss_ops.cal_orthogonal_loss_opt_each_weight_grad( + out_grad, + wnorm, + weight_scale, + normed_weight, + weight_matmul, + moe_k, + use_group, + eps, + ) + + +class CalZLossFunctor(paddle.autograd.PyLayer): + + @staticmethod + def forward(ctx, logits, loss_mask=None, clip_min=1e-6): + if loss_mask is not None: + assert loss_mask.stop_gradient + loss, max_logits, safe_sumexp, logsumexp_per_token = ( + moe_router_loss_ops.cal_z_loss(logits, loss_mask, clip_min) + ) + ctx.save_for_backward( + logits, loss_mask, max_logits, safe_sumexp, logsumexp_per_token + ) + ctx.clip_min = clip_min + return loss + + @staticmethod + def backward(ctx, out_grad): + logits, loss_mask, max_logits, safe_sumexp, logsumexp_per_token = ( + ctx.saved_tensor() + ) + if logits.stop_gradient: + return None + clip_min = ctx.clip_min + return moe_router_loss_ops.cal_z_loss_grad( + out_grad, + logits, + loss_mask, + max_logits, + safe_sumexp, + logsumexp_per_token, + clip_min, + ) + + +class CalAuxLossFunctor(paddle.autograd.PyLayer): + + @staticmethod + def forward( + ctx, + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min=1e-6, + ): + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + loss, seqlen_float, ce = paddle.incubate.nn.functional.cal_aux_loss( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + clip_min, + ) + ctx.save_for_backward(gate_prob, seqlen_float, ce) + ctx.num_experts = num_experts + ctx.use_group 
= use_group + ctx.moe_k = moe_k + return loss + + @staticmethod + def backward(ctx, out_grad): + gate_prob, seqlen_float, ce = ctx.saved_tensor() + num_experts = ctx.num_experts + use_group = ctx.use_group + moe_k = ctx.moe_k + return paddle.incubate.nn.functional.cal_aux_loss_grad( + out_grad, gate_prob, seqlen_float, ce, num_experts, use_group, moe_k + ) + + +def cal_orthogonal_loss_opt_each_weight_func( + weight, moe_k, use_group, eps, xpu_matmul=None, training=True +): + weight = weight.transpose([1, 0]).contiguous() # transpose weight here + wnorm = weight.norm(axis=1) + weight = weight / paddle.maximum(wnorm, eps).unsqueeze(1) + + if use_group: + weight = weight.reshape([moe_k, -1, weight.shape[1]]) # [K, E/K, H] + eye_matrix = paddle.eye(weight.shape[1], dtype=weight.dtype).unsqueeze(0) + else: + eye_matrix = paddle.eye(weight.shape[0], dtype=weight.dtype) + + if get_env_device() == "xpu" and xpu_matmul is not None: + weight_matmul = xpu_matmul(weight, weight, transpose_y=True, training=training) + else: + weight_matmul = paddle.matmul(weight, weight, transpose_y=True) + + orthogonal_loss = weight_matmul - eye_matrix + orthogonal_loss = _squared_l2_norm(orthogonal_loss) / orthogonal_loss.size + return orthogonal_loss + + +def cal_z_loss_func(logits, loss_mask): + if loss_mask is not None: + loss_mask = loss_mask.astype(logits.dtype) + l_zloss = (logits.logsumexp(1).square() * loss_mask).sum() / paddle.clip( + loss_mask.sum(), min=1e-6 + ) + else: + l_zloss = logits.logsumexp(1).square().mean() + return l_zloss + + +def cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + moe_k, + global_aux_loss=False, + rank=None, + group=None, +): + if tokens_mask is not None and tokens_mask.dtype != gate_prob.dtype: + tokens_mask = tokens_mask.astype(gate_prob.dtype) + + scale = None + if dispatch_tokens_mask is not None: + seqlen_float = dispatch_tokens_mask.astype(gate_prob.dtype).sum() + if ( + 
tokens_mask is not None + and gate_prob.shape[0] != dispatch_tokens_mask.shape[0] + ): + scale = seqlen_float / paddle.clip(tokens_mask.sum(), min=1e-6) + elif tokens_mask is not None: + seqlen_float = tokens_mask.sum() + else: + seqlen_float = gate_prob.numel().astype(gate_prob.dtype) / num_experts + seqlen_float = paddle.clip(seqlen_float, min=1e-6) + + if len(dispatch_mask.shape) == 2: + dispatch_mask = dispatch_mask.sum(0) + ce = dispatch_mask.astype(gate_prob.dtype).detach() / seqlen_float + me = paddle.sum(gate_prob, axis=0) / seqlen_float + # me = paddle.mean(gate_prob, axis=0) + # ce = paddle.mean(dispatch_mask.cast("float32"), axis=0) + if global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=group) + dist.all_gather(ce_list, ce, group=group) + + me_list[rank] = me + ce_list[rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + + l_aux = paddle.sum(me * ce) * num_experts + if use_group: + l_aux = l_aux / moe_k + + if scale is not None: + l_aux = l_aux + (scale - 1) * l_aux.detach() + + return l_aux + + +def masked_fill(x, mask, value): + + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + + +@paddle.no_grad() +def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): + + n, _ = M.shape + P = F.softmax(-M / lam) + u = paddle.zeros(n, "float32") + for _ in range(max_iters): + if (u - P.sum(1)).abs().max() < epsilon: + break + u = P.sum(1) + P *= (r / (u + 1e-8)).reshape((-1, 1)) + P *= (c / (P.sum(0) + 1e-8)).reshape((1, -1)) + P = paddle.where(~P.isnan(), P, paddle.zeros_like(P)) + return P, _ + + +def cast_if_needed(x, dtype): + + return x.cast(dtype) if x.dtype != dtype else x + + +class FusedGateDetachMatmul(paddle.autograd.PyLayer): + + @staticmethod + def forward(ctx, x, w): + + ctx.dtype = paddle.float32 + ctx.save_for_backward(x, w) + return F.linear(cast_if_needed(x, ctx.dtype), cast_if_needed(w, ctx.dtype)) + + @staticmethod + def 
backward(ctx, y_grad): + + x, w = ctx.saved_tensor() + assert ctx.dtype == y_grad.dtype, "dtype not match" + x_g, w_g = matmul_bwd( + cast_if_needed(x, ctx.dtype), + cast_if_needed(w, ctx.dtype), + y_grad, + False, + False, + ) + return cast_if_needed(x_g, x.dtype), cast_if_needed(w_g, w.dtype) + + +def gate_detach_matmul(x, weight, use_fuse, use_fake_gate=False): + + if use_fuse: + score = FusedGateDetachMatmul.apply(x, weight) + else: + x = cast_if_needed(x, paddle.float32) + score = F.linear(x, weight) + + if use_fake_gate: + score = paddle.randn(score.shape).astype(score.dtype) + score - score + return score + + +class Top2Gate(nn.Layer): + + def __init__(self, config, layer_idx: int, group, gate_weight=None) -> None: + + super().__init__() + if get_env_device() == "xpu": + try: + from paddle_xpu.layers.nn import xpu_matmul + + self.xpu_matmul = xpu_matmul() + except ImportError: + self.xpu_matmul = None + else: + self.xpu_matmul = None + self.config = config + self.fuse_gate_detach_matmul = config.fuse_gate_detach_matmul + if self.fuse_gate_detach_matmul: + assert matmul_bwd is not None, "matmul_bwd is not supported" + + self.use_fake_gate = config.use_fake_gate + if self.use_fake_gate: + logging.warning( + "You are use fake_gate, which is just for test, not for real training." 
+ ) + + self.model_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.num_experts_tensor = ( + sum(config.moe_num_experts) + if config.multimodel_experts + else config.moe_num_experts + ) # paddle.to_tensor(config.moe_num_experts, dtype="float32").sum() + + self.cap = config.moe_capacity + self.group = group + + self.layer_idx = layer_idx + self.global_aux_loss = config.global_aux_loss + if self.global_aux_loss: + self.rank = dist.get_rank(self.group) + + self.sinkhorn_2gate = config.sinkhorn_2gate + self.sinkhorn_temp = config.sinkhorn_temp + self.use_token_type_bias = config.moe_use_token_type_bias + self.use_correction_bias = config.moe_use_aux_free + + if config.moe_gate_act == "softmax": + self.act = partial(F.softmax, axis=-1) # [S,E] + elif config.moe_gate_act == "sigmoid": + self.act = F.sigmoid + else: + raise ValueError(f"{config.moe_gate_act} is not supported.") + self.no_jitter = True + self.expert_drop = False + self.eye_matrix = None + self.eye_matrix_size = None + self.enable_logging = config.moe_logging + self.norm_gate_logits = config.moe_norm_gate_logits + self.one = paddle.ones([], dtype="float32") + + self.moe_aux_loss_lambda = paddle.to_tensor( + config.moe_aux_loss_lambda, dtype="float32" + ) + self.moe_z_loss_lambda = paddle.to_tensor( + config.moe_z_loss_lambda, dtype="float32" + ) + self.moe_orthogonal_loss_lambda = paddle.to_tensor( + config.moe_orthogonal_loss_lambda, dtype="float32" + ) + if self.moe_aux_loss_lambda.ndim == 0: + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) + if self.moe_z_loss_lambda.ndim == 0: + self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) + if self.moe_orthogonal_loss_lambda.ndim == 0: + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze( + 0 + ) + + self.experts_type_ids = None + if config.moe_orthogonal_loss_lambda: + if hasattr(fleet.fleet, "_user_defined_strategy"): + strategy = fleet.fleet._user_defined_strategy + 
sharding_configs = strategy.hybrid_configs["sharding_configs"] + pp_config = strategy.hybrid_configs["pp_configs"] + assert ( + not sharding_configs.comm_overlap + and not pp_config.sharding_comm_overlap + ), "orthogonal loss will cause twice gradient accumulate, will break pp/sharding overlap" + + self.eps = paddle.to_tensor([1e-12], dtype="float32") + if config.multimodel_experts: + if config.moe_use_hard_gate: + self.num_experts_list = [] + self.experts_type_mask = [] + experts_ids = paddle.zeros( + [sum(self.num_experts)], dtype="int64" + ).reshape([config.moe_world_size, -1]) + offset = 0 + for i, expert_num in enumerate(self.num_experts): + experts_ids[ + :, offset : offset + expert_num // config.moe_world_size + ] = i + offset += expert_num // config.moe_world_size + self.experts_type_ids = experts_ids.reshape([-1]) + logger.info( + f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" + ) + for i, expert_num in enumerate(self.num_experts): + self.experts_type_mask.append( + self.experts_type_ids == i, + ) + self.num_experts_list.append(expert_num) + else: + assert ( + not config.moe_group_experts + ), "group_experts must use hard_gate when multimodel_experts is True" + else: + self.num_experts_list = [self.num_experts] + if gate_weight is not None: + self.weight = gate_weight + assert ( + not self.config.moe_use_token_type_bias + ), "gate_weights is from outside, token_type_bias can't be used" + logger.info("moe use gate_weight from outside") + self._cast_to_low_precision = False + self._cast_to_low_precison = False + else: + self._create_gate_parameter() + logger.info( + f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " + f"use_token_type_bias:{self.use_token_type_bias} gate_act:{config.moe_gate_act} " + f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}" + ) + + def _create_gate_parameter(self): + + if self.config.multimodel_experts: + # support setting lambda for each expert group + 
self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( + len(self.num_experts) + ) + self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand( + len(self.num_experts) + ) + + for i, num_experts in enumerate(self.num_experts): + if i == 1: + with paddle.utils.unique_name.guard(f"mm_gate_{self.layer_idx}_"): + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate") + ), + ) + else: + p = self.create_parameter( + shape=[self.model_dim, num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + p.expert_type = f"expert_type_{i}" + self.add_parameter( + ("weight" if i == 0 else f"weight_{i}"), + p, + ) + else: + self.weight = self.create_parameter( + shape=[self.model_dim, self.num_experts], + dtype="float32", + attr=paddle.ParamAttr(name=unique_name.generate("moe_gate")), + ) + logger.info(f"moe-Gate, {self.weight}") + + if self.use_token_type_bias: + if self.config.multimodel_experts: + assert ( + not self.config.moe_use_hard_gate + ), "multimodel_experts with hard_gate is not support token_type_bias." 
+ num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + bias_type_num = ( + len(self.num_experts) if self.config.multimodel_experts else 1 + ) + self.bias = self.create_parameter( + shape=[bias_type_num, num_experts], + dtype="float32", + attr=paddle.ParamAttr( + name=unique_name.generate("moe_gate_bias"), + initializer=paddle.nn.initializer.Assign( + np.zeros([bias_type_num, num_experts]) + ), + ), + ) + logger.info(f"using token type bias, bias: {self.bias},") + self._cast_to_low_precision = False + self._cast_to_low_precison = False + + def get_gate_weight(self, transform_weight): + if not self.config.multimodel_experts: + return self.weight + if not transform_weight: + return paddle.concat( + [ + getattr(self, "weight" if i == 0 else f"weight_{i}") + for i in range(len(self.num_experts)) + ], + -1, + ) + weight = paddle.zeros( + [ + self.model_dim, + self.config.moe_world_size, + sum(self.num_experts) // self.config.moe_world_size, + ], + dtype="float32", + ) + offset = 0 + for i, num_experts in enumerate(self.num_experts): + weight[ + :, :, offset : offset + num_experts // self.config.moe_world_size + ] = getattr(self, "weight" if i == 0 else f"weight_{i}").reshape( + [self.model_dim, self.config.moe_world_size, -1] + ) + offset += num_experts // self.config.moe_world_size + weight = weight.reshape([self.model_dim, -1]) + + return weight + + def forward( + self, + input: Tensor, + token_type_ids: Tensor = None, + transform_weight: bool = True, # [seq] + correction_bias: Tensor = None, # [seq] + ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore + + orig_dtype = input.dtype + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + if get_env_device() == "xpu" and self.xpu_matmul is not None: + assert not self.fuse_gate_detach_matmul, "not supported on XPU" + input_32 = input.cast("float32") + logits = self.xpu_matmul( + input_32, + weight, + training=self.training, + ) + else: + 
logits = gate_detach_matmul( + input, weight, self.fuse_gate_detach_matmul, self.use_fake_gate + ) + + if self.use_token_type_bias: + assert token_type_ids is not None + bias = self.bias[token_type_ids] # [seq] + # logger.info(f"adding bias: {bias}") + logits = logits + bias + ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + l_aux, + l_zloss, + ) = self.top2_gating(logits, correction_bias=correction_bias) + orthogonal_loss = self._cal_orthogonal_loss() + router_loss = ( + l_aux * self.moe_aux_loss_lambda + + l_zloss * self.moe_z_loss_lambda + + orthogonal_loss * self.moe_orthogonal_loss_lambda + ) + router_loss.stop_gradient = False + if self.enable_logging and global_training_logs_enabled(): + _log = { + f"aux_loss_layer_{self.layer_idx}": l_aux.item(), + f"orthogonal_loss_layer_{self.layer_idx}": orthogonal_loss.item(), + f"zloss_layer_{self.layer_idx}": l_zloss.item(), + } + global_training_logs.update( + **_log, + **{ + k.replace(f"_layer_{self.layer_idx}", ""): v + for k, v in _log.items() + }, + ) + if self.use_token_type_bias: + _bias_log = { + f"token_type_bias_layer_{self.layer_idx}_expert{i}_gap": v + for i, v in enumerate((self.bias[0] - self.bias[1]).numpy()) + } + global_training_logs.update(**_bias_log) + + combine_weights = combine_weights.cast(orig_dtype) + return ( + capacity, + dispatch_mask, + combine_weights, + scatter_index, + router_loss, + logits, + ) + + def get_capacity(self, num_tokens, cap_factor=None): + + num_experts = ( + sum(self.num_experts) + if self.config.multimodel_experts + else self.num_experts + ) + if cap_factor is not None: + cap = cap_factor + else: + if self.training: + cap = self.cap[0] + elif num_tokens < num_experts: # seqlen < num_expert + cap = self.cap[2] + else: + cap = self.cap[1] + # capacity = 2S/E + capacity = int(cap * num_tokens // num_experts) + assert ( + capacity > 0 + ), f"requires capacity to >= 0. 
cap={cap}, num_tokens={num_tokens}" + return capacity + + def top2_gating(self, logits, cap=None, correction_bias=None): + + # logger.info(f'gate-input: {logits}') + l_zloss = self._cal_z_loss(logits) + gates = self.act(logits) + + # gates has shape of SE + assert logits.ndim == 2, logits.shape + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + # capacity = 2S/E + capacity = self.get_capacity(logits.shape[0], cap) + + # Create a mask for 1st's expert per token + score_for_argmax = ( + gates + correction_bias.unsqueeze(0) + if correction_bias is not None + else gates + ) + indices1_s = paddle.argmax(score_for_argmax, axis=1) + mask1 = F.one_hot(indices1_s, num_classes=num_experts).cast( + paddle.int64 + ) # [0,1] + + l_aux = self._cal_aux_loss(gates, mask1.sum(axis=0), self.num_experts_tensor) + + if self.training and not self.no_jitter: + gumbels = ( + -paddle.empty_like( + logits, + ) + .exponential_() + .log() + ) # ~Gumbel(0,1) + logits_w_noise = logits + gumbels + else: + logits_w_noise = logits + + logits_except1 = masked_fill( + logits_w_noise, mask1.cast(paddle.bool), float("-inf") + ) + score_for_argmax = ( + self.act(logits_except1) + correction_bias.unsqueeze(0) + if correction_bias is not None + else logits_except1 + ) + indices2_s_original = paddle.argmax(score_for_argmax, axis=1) + + if self.training and self.sinkhorn_2gate: + r = paddle.ones(num_tokens, "float32") / num_tokens + + c = capacity - mask1.cast("float32").sum(0) + c = paddle.maximum(c, paddle.zeros_like(c)) + c /= c.sum() + + pi, _ = compute_optimal_transport( + -logits_except1.cast("float32").detach(), r, c, lam=self.sinkhorn_temp + ) + pi = masked_fill(pi, mask1.cast(paddle.bool), float("-inf")) + indices2_s = paddle.argmax(pi, axis=1) + else: + indices2_s = indices2_s_original + + if self.enable_logging and global_training_logs_enabled(): + global_training_logs.update( + **{ + "redispatch_acc": (indices2_s_original == indices2_s) + .cast(paddle.float32) + .mean() + .item(), 
+ f"redispatch_acc_layer_{self.layer_idx}": ( + indices2_s_original == indices2_s + ) + .cast(paddle.float32) + .mean() + .item(), + } + ) + + mask2 = F.one_hot(indices2_s, num_classes=self.num_experts).cast(paddle.int64) + + # Compute locations in capacity buffer + locations1 = ( + paddle.cumsum(mask1, axis=0) - 1 + ) # [0,1,1,0,1,0,0] -> [0,0,0,0,1,1,1,] + locations2 = paddle.cumsum(mask2, axis=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + # Remove locations outside capacity from mask + mask1 *= (locations1 < capacity).cast(paddle.int64) # [0,1,1,0,0,0,0] + mask2 *= (locations2 < capacity).cast(paddle.int64) + + # Store the capacity location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = (gates * mask1_float).sum(axis=-1) + gates2_s = (gates * mask2_float).sum(axis=-1) + # logger.info(f'gates1_s:{gates1_s} gates2_s:{gates2_s} logits:{logits}') + + if self.norm_gate_logits: + denom_s = gates1_s + gates2_s # [0.2, 0.3] + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=1e-6) + gates1_s /= denom_s + gates2_s /= denom_s + if self.training and self.expert_drop: + # log.debug(gates2_s) + gates2_s = paddle.where( + 2 * gates2_s < paddle.rand_like(gates2_s), + paddle.zeros_like(gates2_s), + gates2_s, + ) + + # Calculate combine_weights and dispatch_mask + gates1 = gates1_s.unsqueeze(1) * mask1_float + gates2 = gates2_s.unsqueeze(1) * mask2_float + + expert1_index = paddle.argmax(gates1, -1) + combine1_weight = paddle.max(gates1, -1, keepdim=True) + scatter1_index = expert1_index * capacity + locations1_s + scatter1_index = scatter1_index.cast("int64") + dispatch1_mask = combine1_weight.cast(paddle.bool).detach() + + expert2_index = paddle.argmax(gates2, -1) + 
combine2_weight = paddle.max(gates2, -1, keepdim=True) + scatter2_index = expert2_index * capacity + locations2_s + scatter2_index = scatter2_index.cast("int64") + dispatch2_mask = combine2_weight.cast(paddle.bool).detach() + # logger.info(f'expert-id: {expert1_index} vs {expert2_index}, mask:{mask1_float} vs {mask2_float}') + if self.enable_logging and global_training_logs_enabled(): + global_training_logs.update( + **{ + "top1_gate": ( + combine1_weight.sum() + / (dispatch1_mask.cast("float32").sum() + 1e-9) + ).item(), + "top2_gate": ( + combine2_weight.sum() + / (dispatch2_mask.cast("float32").sum() + 1e-9) + ).item(), + f"top1_gate_layer_{self.layer_idx}": ( + combine1_weight.sum() + / (dispatch1_mask.cast("float32").sum() + 1e-9) + ).item(), + f"top2_gate_layer_{self.layer_idx}": ( + combine2_weight.sum() + / (dispatch2_mask.cast("float32").sum() + 1e-9) + ).item(), + } + ) + + seqlen = logits.shape[0] + top1_gate_experts_per_token = ( + paddle.cast(dispatch1_mask, dtype="float32").sum() / seqlen + ) + top2_gate_experts_per_token = ( + paddle.cast(dispatch2_mask, dtype="float32").sum() / seqlen + ) + leakage_experts_per_token = ( + paddle.cast( + (~dispatch1_mask) & (~dispatch2_mask), dtype="float32" + ).sum() + / seqlen + ) + + experts_per_token = ( + top1_gate_experts_per_token + top2_gate_experts_per_token + ) + _log = { + f"experts_per_token_layer_{self.layer_idx}": experts_per_token.item(), + f"top1_experts_per_token_layer_{self.layer_idx}": top1_gate_experts_per_token.item(), + f"top2_experts_per_token_layer_{self.layer_idx}": top2_gate_experts_per_token.item(), + f"leakage_experts_per_token_layer_{self.layer_idx}": leakage_experts_per_token.item(), + } + global_training_logs.update( + **_log, + **{ + k.replace(f"_layer_{self.layer_idx}", ""): v + for k, v in _log.items() + }, + ) + + return ( + capacity, + paddle.concat((dispatch1_mask, dispatch2_mask), 1), + paddle.concat((combine1_weight, combine2_weight), 1), + paddle.stack((scatter1_index, 
scatter2_index), 1), + l_aux, + l_zloss, + ) + + def _cal_aux_loss( + self, + gate_prob, + dispatch_mask, + num_experts=None, + use_group=None, + tokens_mask=None, + dispatch_tokens_mask=None, + ): + + if self.act is F.sigmoid: + gate_prob = gate_prob / gate_prob.sum(-1, keepdim=True) + + if self.use_correction_bias: + if tokens_mask is not None: + gate_prob_this_modality = gate_prob[tokens_mask.astype("bool")] + if gate_prob_this_modality.shape[0]: + _, top_idx = gate_prob_this_modality.topk( + k=self.config.moe_k, axis=-1 + ) + if int_bincount is not None: + dispatch_mask = int_bincount( + top_idx, 0, gate_prob.shape[-1], paddle.int64 + ) + else: + mask = paddle.zeros_like( + gate_prob_this_modality + ).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + else: + dispatch_mask = paddle.zeros(gate_prob.shape[-1], dtype="int64") + dist.stream.all_reduce( + dispatch_mask, + group=self.group, + use_calc_stream=True, + ) + else: + _, top_idx = gate_prob.topk(k=self.config.moe_k, axis=-1) + if int_bincount is not None: + dispatch_mask = int_bincount( + top_idx, 0, gate_prob.shape[-1], paddle.int64 + ) + else: + mask = paddle.zeros_like(gate_prob).put_along_axis( + top_idx, paddle.to_tensor(1.0), axis=1 + ) + dispatch_mask = paddle.sum(mask.cast(paddle.int64), axis=0) + + if num_experts is None: + num_experts = self.num_experts_tensor + if use_group is None: + use_group = self.config.moe_group_experts + + return cal_aux_loss_func( + gate_prob, + dispatch_mask, + tokens_mask, + dispatch_tokens_mask, + num_experts, + use_group, + self.config.moe_k, + self.global_aux_loss, + self.rank if self.global_aux_loss else None, + self.group if self.global_aux_loss else None, + ) + + def _cal_z_loss(self, logits, loss_mask=None): + + if ( + (moe_router_loss_ops is not None) + and (loss_mask is None or len(loss_mask.shape) == 1) + and (get_env_device() != "xpu") + and (logits.dtype == paddle.float32) + ): + return 
CalZLossFunctor.apply(logits, loss_mask) + else: + return cal_z_loss_func(logits, loss_mask) + + def _cal_orthogonal_loss_opt_each_weight(self, weight, use_group): + + if weight.dtype != paddle.float32: + weight = weight.astype(paddle.float32) + + if ( + (moe_router_loss_ops is not None) + and (get_env_device() != "xpu") + and (weight.dtype == paddle.float32) + ): + return CalOrthogonalLossOptEachWeightFunctor.apply( + weight, self.config.moe_k, use_group + ) + else: + return cal_orthogonal_loss_opt_each_weight_func( + weight, + self.config.moe_k, + use_group, + self.eps, + self.xpu_matmul, + self.training, + ) + + def _cal_orthogonal_loss(self, weight_id=None, use_group=None): + + if use_group is None: + use_group = ( + self.config.moe_group_experts and self.config.moe_group_orthogonal_loss + ) + + if weight_id is not None: + if weight_id == 0: + w_ = self.weight + else: + assert self.config.multimodel_experts + w_ = getattr(self, f"weight_{weight_id}") + return self._cal_orthogonal_loss_opt_each_weight(w_, use_group) + + orthogonal_loss = self._cal_orthogonal_loss_opt_each_weight( + self.weight, use_group + ) + if self.config.multimodel_experts: + for i in range(1, len(self.config.moe_num_experts)): + w_ = getattr(self, f"weight_{i}") + orthogonal_loss += self._cal_orthogonal_loss_opt_each_weight( + w_, use_group=False + ) + return orthogonal_loss + + +class TopKGateFused(Top2Gate): + + def forward( + self, + input: Tensor, + token_type_ids=None, + transform_weight=True, + ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore + + capacity = self.get_capacity(input.shape[0]) + weight = self.get_gate_weight(transform_weight) + with paddle.amp.auto_cast(False): + if get_env_device() == "xpu" and self.xpu_matmul is not None: + assert not self.fuse_gate_detach_matmul, "not supported on XPU" + input_32 = input.cast("float32") + logits = self.xpu_matmul( + input_32, + weight, + training=self.training, + ) + else: + logits = gate_detach_matmul( + input, weight, 
self.fuse_gate_detach_matmul, self.use_fake_gate + ) + if self.use_token_type_bias: + assert token_type_ids is not None + assert ( + token_type_ids.max() < self.bias.shape[0] + ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" + bias = self.bias[token_type_ids] # [seq] + logits = logits + bias + orthogonal_loss = None + router_loss = paddle.zeros([1], dtype="float32") + router_loss.stop_gradient = False + if ( + self.enable_logging + and global_training_logs_enabled() + and orthogonal_loss is not None + ): + _log = { + f"orthogonal_loss_layer_{self.layer_idx}": orthogonal_loss.item(), + } + global_training_logs.update( + **_log, + **{ + k.replace(f"_layer_{self.layer_idx}", ""): v + for k, v in _log.items() + }, + ) + + return logits, capacity, router_loss diff --git a/examples/pre-training/models/sequence_parallel_utils_auto.py b/examples/pre-training/models/sequence_parallel_utils_auto.py new file mode 100644 index 000000000..408a7227f --- /dev/null +++ b/examples/pre-training/models/sequence_parallel_utils_auto.py @@ -0,0 +1,229 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# !/usr/bin/env python3 + +import hashlib +import numpy as np +import logging + +import paddle +from paddle import distributed as dist +from paddle.autograd import PyLayer +from paddle.distributed import fleet + + +from models.comm_utils import ( + scatter, + all_gather, + reduce_scatter, +) + + +from paddle.distributed import in_auto_parallel_align_mode + + +try: + from paddle.nn.functional import gemm_reduce_scatter, all_gather_gemm +except ImportError: + gemm_reduce_scatter = None + all_gather_gemm = None + flux = None + +logger = logging.getLogger(__name__) + +if not hasattr(paddle.Tensor, "contiguous"): + + def contiguous(self): + + return self + + setattr(paddle.Tensor, "contiguous", contiguous) + + +if not hasattr(paddle.Tensor, "_md5sum"): + + def _md5sum(self): + numpy_array = np.array(self) + array_bytes = numpy_array.tobytes() + return hashlib.md5(array_bytes).hexdigest() + + setattr(paddle.Tensor, "_md5sum", _md5sum) + + +def get_hcg(): + return fleet.get_hybrid_communicate_group() + + +class ScatterOp(PyLayer): + + @staticmethod + def forward(ctx, input, axis=0, group=None): + ctx.axis = axis + ctx.group = group + return scatter(input, axis=axis, group=ctx.group) + + @staticmethod + def backward(ctx, grad): + return all_gather(grad, axis=ctx.axis, group=ctx.group) + + +class GatherOp(PyLayer): + + @staticmethod + def forward(ctx, input, axis=0, group=None): + ctx.axis = axis + ctx.group = group + return all_gather(input, axis=axis, group=group) + + @staticmethod + def backward(ctx, grad): + return scatter(grad, axis=ctx.axis, group=ctx.group) + + +class AllGatherOp(PyLayer): + + @staticmethod + def forward(ctx, input, group=None): + ctx.group = group + return all_gather(input, group=group) + + @staticmethod + def backward(ctx, grad): + if in_auto_parallel_align_mode(): + group = ctx.group + if group is None: + group = get_hcg().get_model_parallel_group() + pg = group.process_group + pg.allreduce(grad).wait() + return paddle.split(grad, group.nranks, 
axis=0)[group.rank] + else: + return reduce_scatter(grad, group=ctx.group) + + +class ReduceScatterOp(PyLayer): + @staticmethod + def forward(ctx, input, group=None): + ctx.group = group + return reduce_scatter(input, group=group) + + @staticmethod + def backward(ctx, grad): + return all_gather(grad, group=ctx.group) + + +class AllGatherVarlenOp(PyLayer): + + @staticmethod + def forward(ctx, input, group=None): + hcg = fleet.get_hybrid_communicate_group() + if group is None: + group = hcg.get_model_parallel_group() + + shape0 = paddle.to_tensor([input.shape[0]]) + shape0_all = paddle.empty(shape=[group.nranks], dtype=shape0.dtype) + dist.stream.all_gather(shape0_all, shape0, group=group, use_calc_stream=True) + shape0_all = shape0_all.numpy() + max_shape0 = shape0_all.max() + + indices = [] + for idx, s in enumerate(shape0_all): + offset = idx * max_shape0 + indices.append(list(range(offset, offset + s))) + indices = np.concatenate(indices, axis=0) + indices = indices.reshape([-1] + [1] * (len(input.shape) - 1)) + indices = paddle.to_tensor(indices, dtype=paddle.int32) + + padding = max_shape0 - input.shape[0] + + ctx.shape0 = input.shape[0] + ctx.max_shape0 = max_shape0 + ctx.shape0_all = shape0_all + ctx.padding = padding + ctx.indices = indices + ctx.group = group + + if padding > 0: + input_shape = input.shape + input_shape[0] = padding + padding_tensor = paddle.empty(shape=input_shape, dtype=input.dtype) + input = paddle.concat([input, padding_tensor], axis=0) + output = all_gather(input, group) + output = paddle.take_along_axis(output, indices, axis=0) + + return output + + @staticmethod + def backward(ctx, grad): + input_shape = grad.shape + input_shape[0] = ctx.max_shape0 * ctx.shape0_all.shape[0] + output = paddle.zeros(shape=input_shape, dtype=grad.dtype) + + grad = paddle.scatter(output, ctx.indices, grad) + + grad = scatter(grad, ctx.group) + + if ctx.padding > 0: + grad = grad[: ctx.shape0] + return grad + + +class GemmReduceScatterOp(PyLayer): + + 
@staticmethod + def forward(ctx, input, weight, group): + + ctx.save_for_backward(input, weight) + ctx.group = group + output = gemm_reduce_scatter(input, weight, group) + return output + + @staticmethod + def backward(ctx, grad): + input, weight = ctx.saved_tensor() + group = ctx.group + if input.stop_gradient and weight.stop_gradient: + return None, None + + if input.stop_gradient: + input_grad = None + grad_parallel = None + else: + input_grad, grad_parallel = all_gather_gemm( + grad, weight, group, deepcopy_input_parallel=False + ) + + if weight.stop_gradient: + weight_grad = None + else: + if grad_parallel is None: + grad_parallel = all_gather(grad) + weight_grad = paddle.matmul(input, grad_parallel, transpose_x=True) + return input_grad, weight_grad + + +def sequence_parallel_sparse_mask_labels(labels, ignore_label=-100): + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_model_parallel_group() + labels = labels.flatten() + labels_local = paddle.split(labels, group.nranks)[group.rank] + + tgt_index = paddle.nonzero(labels_local != ignore_label).squeeze() + if tgt_index.numel() == 0: + tgt_index = paddle.to_tensor([0]) + + tgt_index = tgt_index.reshape([-1]).astype(paddle.int32) + labels_local_gather = paddle.take_along_axis(labels_local, tgt_index, axis=0) + labels_all_gather = AllGatherVarlenOp.apply(labels_local_gather) + return labels_all_gather, tgt_index.reshape([-1, 1]) diff --git a/examples/pre-training/scripts/train_96_auto.sh b/examples/pre-training/scripts/train_96_auto.sh new file mode 100644 index 000000000..932b0d331 --- /dev/null +++ b/examples/pre-training/scripts/train_96_auto.sh @@ -0,0 +1,134 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+export NNODES=1
+export PADDLE_TRAINERS_NUM=1
+
+mpi_rank=${OMPI_COMM_WORLD_RANK:-0}
+node_rank=$((mpi_rank+offset))
+mpi_node=${OMPI_COMM_WORLD_SIZE:-1}
+echo "MPI status:${mpi_rank}/${mpi_node}"
+nnode_train=${nnode_set:-${mpi_node}}
+master_train=${master:-localhost}
+#
+echo "Distributed Training ${node_rank}/${nnode_train} master=${master_train}"
+set -x
+
+export CUDA_MODULE_LOADING=LAZY
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NCCL_DEBUG=INFO
+export PYTHONUNBUFFERED=1
+unset GLOG_vmodule GLOG_v
+export PADDLE_DISABLE_CUDNN_FA=1
+export FLAGS_use_auto_growth_pinned_allocator=True
+export FLAGS_pipeline_nccl_comm_init_option=1
+export FLAGS_sharding_v2_check_zero_padding=1
+export FLAGS_use_paddle_recall_error=0
+export FLAGS_tcp_max_syn_backlog=16384
+export FLAGS_call_stack_level=2
+
+
+# Unset the platform's preset environment variables: the framework upgrades
+# compatibly, and if it detects these settings it falls back to the legacy
+# launch mode.
+unset PADDLE_ELASTIC_JOB_ID
+unset PADDLE_TRAINER_ENDPOINTS
+unset DISTRIBUTED_TRAINER_ENDPOINTS
+unset FLAGS_START_PORT
+unset PADDLE_ELASTIC_TIMEOUT
+nnodes=$PADDLE_TRAINERS_NUM
+rank=$PADDLE_TRAINER_ID
+
+
+export FLAGS_shard_use_reduce=1
+export FLAGS_shard_norm_align_dp=0
+
+# Speed up pinned-memory time when saving checkpoints
+export FLAGS_use_auto_growth_pinned_allocator=True
+
+# export FLAGS_flash_attn_version=v1
+# Enable FlashAttention 3 (only on SM90 / Hopper GPUs)
+SM=`nvidia-smi --query-gpu=compute_cap --format=csv | tail -n 1 | sed 's/\.//g'`
+if [ $SM -eq 90 ]
+then
+    export FLAGS_flash_attn_version=3
+else
+    export FLAGS_flash_attn_version=2
+fi
+
+# Settings for cluster stability; unrelated to performance
+export NCCL_IB_QPS_PER_CONNECTION=8
+export NCCL_IB_TIMEOUT=22
+export NCCL_IB_GID_INDEX=3
+# Enable adaptive routing (AR)
+export NCCL_IB_ADAPTIVE_ROUTING=1
+
+# Cluster hang detection
+export PADDLE_PG_TIMEOUT=150000 # communication-group timeout in ms; default is 2 minutes
+export FLAGS_enable_async_trace=False # True enables communication debug tracing; False or unset disables it (default: enabled)
+# export CUDA_MODULE_LOADING=LAZY
+
+export FLAGS_pipeline_nccl_comm_init_option=1
+
+# Launch setup
+cuda_version=`nvidia-smi |grep "CUDA Version" |awk '{print $9}' |awk -F'.' '{print $1}'`
+if [ ${cuda_version} != "12" ];then
+    export LD_LIBRARY_PATH=/usr/local/cuda/compat:$LD_LIBRARY_PATH
+fi
+
+# nnodes=7
+START_RANK=${START_RANK:-0}
+END_RANK=${END_RANK:-$nnodes}
+
+# if [[ $rank -lt $START_RANK ]]; then
+# exit 0
+# fi
+
+# if [[ $rank -ge $END_RANK ]]; then
+# exit 0
+# fi
+rank=$(($rank-$START_RANK))
+nnodes=$(($END_RANK-$START_RANK))
+master=`cat /root/paddlejob/workspace/hostfile | head -n $(($START_RANK+1)) | tail -n 1 | awk '{print $1}'`
+port=36677
+
+
+# Auto-parallel related
+export FLAGS_enable_fused_ffn_qkv_pass=1
+export FLAGS_enable_pir_api=1
+#export FLAGS_enable_sharding_stage1_tensor_fusion=1
+export FLAGS_enable_moe_utils=true
+
+# Debugging related
+export FLAGS_call_stack_level=2
+#export GLOG_v=6
+#export FLAGS_print_ir=1
+#export FLAGS_benchmark=1
+#export CUDA_VISIBLE_DEVICES=0,1
+
+export PYTHONPATH=$PYTHONPATH:./ernie
+
+LOG_DIR=output/paddle_distributed_logs
+
+rm -rf output
+rm -rf core.*
+
+python -m paddle.distributed.launch \
+    --log_dir $LOG_DIR \
+    --master $master:$port \
+    --nnodes $nnodes \
+    --rank $rank \
+    --run_mode=collective \
+    ${script:-ernie/pretrain_auto.py} \
+    --config yamls/pretrain_96_auto.yaml
diff --git a/examples/pre-training/yamls/pretrain_96_auto.yaml b/examples/pre-training/yamls/pretrain_96_auto.yaml
new file mode 100644
index 000000000..63f0214f3
--- /dev/null
+++ b/examples/pre-training/yamls/pretrain_96_auto.yaml
@@ -0,0 +1,158 @@
+# ----------- environment variables ----------#
+env:
+  HOME: null
+
+# ---------------------------model args-------------------------------------------------#
+model_args:
+  model_name_or_path: model_configs_auto/
+  tokenizer_name: 
./ernie/src/tokenizers/tokenizer_model + output_dir: ./output/ + data_load_process_num: 40 + max_seq_length: 4096 + base_seq_length: 4096 + num_consecutive: 32 + sequence_parallel: 1 + + enable_global_training_logs: False + moe_use_aux_free_update_coef: 0.001 + global_logging_interval: 10 + model_config: + moe_logging: True + moe_use_aux_free: true + multi_token_pred_depth: 0 + + + +# ---------------------------trainer args-------------------------------------------------# +trainer_args: + input_dir: "0.4 ./demo_data/data-1-part0 0.6 ./demo_data/data-1-part0" + split: "998,1,1" + loss_spike_settings: + enable_loss_spike_watcher: 1 + longjob_id: long-78f0ae68688b4659 + supervised_filename: output/paddle_distributed_logs/metrics_rank0.json + delimiter: "Loading configuration file" + watch_loss_spike_interval: 20 + loss_spike_restart_interval: 300 + params: + - data_type: null + data_type_human_read: "纯文" + max_loss_thr: 2.0 + max_tolerance_steps: 1 + allow_loss_fallback: 0 + start_check_step: 219700 + + use_sp_callback: true + moe_gate_lr_ratio: 0.01 + do_train: True + dataloader_num_workers: 8 + prefetch_factor: 32 + overwrite_output_dir: 1 + disable_tqdm: 1 + logging_steps: 1 + eval_steps: 1000 + eval_iters: -1 + save_steps: 3000 + max_steps: 10 + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_epsilon: 1e-8 + learning_rate: 2.2e-4 + min_lr: 2.2e-5 + + global_batch_size: 4 # 16660 + gradient_accumulation_steps: 2 # 8008: 14; + per_device_train_batch_size: 2 + per_device_eval_batch_size: 1 + + lr_scheduler: wsd:231084 + decay_function: 1-sqrt + max_grad_norm: 1.0 + + adaptive_norm_clip: 0 # 4350 step后,关闭 adaptive-norm-clip + adaptive_norm_clip_ratio: 1.2 + adaptive_norm_force_clear_state: 0 # 在切换分布式策略时, 开启强制刷新统计状态 + adaptive_norm_enable_record: 1 # 开启更详细的裁剪日志 + + use_async_save: True # enable asynchronize save to gain efficiency + + weight_decay: 0.1 + warmup_steps: 200 + save_total_limit: 5 + bf16: True + fp16_opt_level: "O2" + use_fp8: False + scale_loss: 4096 + seed: 
666 + use_train_part_sharding: 1 + # pre_alloc_memory: 60 + + # # N7 + # tensor_parallel_degree: 8 # N7:8, N4:8, N1:4 + # pipeline_parallel_degree: 7 # N7:7, N4:4, N1:2 + # virtual_pp_degree: 8 # N7:8, N4:8, N1:1 + + # # N4 + # tensor_parallel_degree: 8 # N7:8, N4:8, N1:4 + # pipeline_parallel_degree: 4 # N7:7, N4:4, N1:2 + # virtual_pp_degree: 8 # N7:8, N4:8, N1:1 + + # # N1 + # tensor_parallel_degree: 4 # N7:8, N4:8, N1:4 + # pipeline_parallel_degree: 2 # N7:7, N4:4, N1:2 + # virtual_pp_degree: 1 # N7:8, N4:8, N1:1 + + # N1 dynamic auto + tensor_parallel_degree: 4 # N7:8, N4:8, N1:4 + pipeline_parallel_degree: 2 # N7:7, N4:4, N1:2 + pipeline_schedule_mode: "VPP" + virtual_pp_degree: 2 # N7:8, N4:8, N1:1 + + data_parallel_degree: 1 + sharding: "stage1" + sharding_degree: 1 # 170 + # sharding_degree: 170 # + amp_master_grad: 1 + pipeline_parallel_config: enable_delay_scale_loss #enable_dp_comm_overlap + # pipeline_parallel_config: enable_delay_scale_loss enable_overlap_p2p_comm best_unbalanced_scheduler #enable_dp_comm_overlap + sharding_parallel_config: split_param enable_fuse_optimizer_states + sharding_comm_buffer_size_MB: 2048 + tensor_parallel_config: replace_with_parallel_cross_entropy + # tensor_parallel_config: sync_param sync_grad sync_moment + + + skip_profile_timer: False + + ignore_data_skip: 0 + shuffle_consecutive: True + + load_sharded_model: True + save_sharded_model: True + save_sharding_stage1_model_include_freeze_params: True + ignore_load_lr_and_optim: False + metrics_output_path: ./output/paddle_distributed_logs/ + + #TODO(@gexiao): move to longjob_args + pdc_download_ckpt: true + pdc_download_timeout: 300 + + # # Flash checkpoint settings + # enable_zero_cost_checkpoint: true + # save_tokenizer: false + # save_rng_states: false + # zcc_workers_num: 1 + # zcc_pipeline_hooks_capacity_usage: 0.8 + # flash_device_save_steps: 4 + # zcc_save_ema_coef: 0.9993 #exp((4/10000)*ln(1-0.9999)) + # zcc_ema_interval: 4 + + + use_moe: true + 
moe_with_send_router_loss: False + moe_group: mp + log_global_grad_norm: True + enable_optimizer_timer: False + gc_interval: 100000 + + enable_auto_parallel: 1 + to_static: 0