diff --git a/examples/pre-training/ernie/pretrain.py b/examples/pre-training/ernie/pretrain.py
index ab2cfe9da..87683c85c 100644
--- a/examples/pre-training/ernie/pretrain.py
+++ b/examples/pre-training/ernie/pretrain.py
@@ -35,6 +35,10 @@
     PdArgumentParser,
     get_last_checkpoint,
 )
+from paddleformers.trainer.unified_checkpoint import unified_checkpoint
+from paddleformers.transformers.model_utils import unwrap_model
+
+from safetensors import safe_open
 
 try:
     from paddleformers.utils.downloader import get_static_model_on_pdc
@@ -202,6 +206,184 @@ def _collate_data(data, stack_fn=Stack()):
     return train_dataset, valid_dataset, test_dataset, _collate_data
 
 
+def load_huggingface_checkpoint(model, args):
+    fused_rms_norm_replace = [
+        ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
+        ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
+    ]
+    shared_layers_prefix = "shared_layers.embed_weight_share."
+    unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]
+
+    logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}")
+    with open(
+        os.path.join(args.model_name_or_path, "model.safetensors.index.json")
+    ) as f:
+        weight_map = json.load(f)["weight_map"]
+
+    ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size()
+    ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank()
+    expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
+    use_torch_format = False
+
+    def param_to_weight(name):
+        # for PP=1, we only need to substitute the fused_rms_norm and expert_id
+        for src, dst in fused_rms_norm_replace:
+            name = name.replace(src, dst)
+        if m := re.search(r"mlp\.experts\.(\d+)", name):
+            expert_id = expert_offset + int(m.group(1))
+            s, e = m.span()
+            name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
+        if isinstance(model, ErnieMoEForCausalLM):
+            return name
+
+        # for PP>1, we also need to handle special layers and adjust layer_idx
+        if name.startswith(shared_layers_prefix):
+            return "ernie." + name[len(shared_layers_prefix) :]
+        layer_idx, stem = name.split(".", maxsplit=1)
+        if stem == "weight":
+            return unnamed_layers.pop(0)
+        if stem.startswith("mtp"):
+            return f"ernie.{stem}"
+        return f"ernie.layers.{int(layer_idx) - 1}.{stem}"
+
+    def try_torch_format(weight_key):
+        if weight_key.startswith("ernie."):
+            weight_key = "model." + weight_key[6:]
+
+        key_decompose = [weight_key]
+        if ".up_gate_proj." in weight_key:
+            key_decompose = [
+                weight_key.replace(".up_gate_proj.", ".gate_proj."),
+                weight_key.replace(".up_gate_proj.", ".up_proj."),
+            ]
+        elif ".qkv_proj." in weight_key:
+            key_decompose = [
+                weight_key.replace(".qkv_proj.", ".q_proj."),
+                weight_key.replace(".qkv_proj.", ".k_proj."),
+                weight_key.replace(".qkv_proj.", ".v_proj."),
+            ]
+
+        tensor_decompose = []
+        for key in key_decompose:
+            if not (weight_file := weight_map.get(key)):
+                return None
+            with safe_open(
+                os.path.join(args.model_name_or_path, weight_file),
+                framework="numpy",
+            ) as f:
+                tensor = paddle.to_tensor(f.get_tensor(key))
+            if "_proj." in key or ".gate." in key:
+                tensor = tensor.T.contiguous()
+            tensor_decompose.append(tensor)
+
+        if len(tensor_decompose) == 1:
+            return tensor_decompose[0]
+        else:
+            return paddle.concat(tensor_decompose, axis=-1)
+
+    def auto_fix_shape(param, weight):
+        assert len(param.shape) == len(weight.shape), "rank not match"
+        assert all(
+            p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape)
+        ), "weight too small"
+        indices = tuple(slice(0, dim) for dim in param.shape)
+        return weight[indices].contiguous()
+
+    for name, param in model.named_parameters():
+        weight_key = param_to_weight(name)
+        if weight_file := weight_map.get(weight_key):
+            with safe_open(
+                os.path.join(args.model_name_or_path, weight_file),
+                framework="numpy",
+            ) as f:
+                weight = paddle.to_tensor(f.get_tensor(weight_key))
+        elif (weight := try_torch_format(weight_key)) is not None:
+            use_torch_format = True
+        else:
+            logger.warning(
+                f"param `{name}`'s weight `{weight_key}` not found. "
+                "Skip initializing."
+            )
+            continue
+        if use_torch_format and "lm_head" in weight_key:
+            weight = weight.T.contiguous()
+        if param.shape != weight.shape:
+            logger.warning(
+                f"param `{name}`'s shape doesn't match weight `{weight_key}`: "
+                f"{param.shape} and {weight.shape}. Auto fixing."
+            )
+            weight = auto_fix_shape(param, weight)
+        param.copy_(weight)
+
+
+def get_expected_state_dict(model, **kwargs):
+    fused_rms_norm_replace = [
+        ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
+        ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
+    ]
+    shared_layers_prefix = "embed_share."
+
+    model = unwrap_model(model)
+    hcg = fleet.get_hybrid_communicate_group()
+    ep_degree = hcg.get_expert_parallel_world_size()
+    ep_rank = hcg.get_expert_parallel_rank()
+    expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
+
+    if model.config.head_dim is None:
+        head_dim = model.config.hidden_size // model.config.num_attention_heads
+    else:
+        head_dim = model.config.head_dim
+    q_dim = head_dim * model.config.num_attention_heads
+    kv_dim = head_dim * model.config.num_key_value_heads
+
+    def copy_attr(out, param):
+        if hasattr(param, "is_distributed"):
+            out.is_distributed = param.is_distributed
+        if hasattr(param, "no_sync"):
+            out.no_sync = param.no_sync
+        return out
+
+    def param_to_weight(name):
+        # for PP=1, we only need to substitute the fused_rms_norm and expert_id
+        for src, dst in fused_rms_norm_replace:
+            name = name.replace(src, dst)
+        if m := re.search(r"\.experts\.(\d+)\.", name):
+            expert_id = expert_offset + int(m.group(1))
+            s, e = m.span()
+            name = name[:s] + f".experts.{expert_id}." + name[e:]
+        if isinstance(model, ErnieMoEForCausalLM):
+            return name
+
+        # for PP>1, we also need to handle shared layers
+        if name.startswith(shared_layers_prefix):
+            return "ernie." + name[len(shared_layers_prefix) :]
+        return name
+
+    state_dict = {}
+    for name, param in model.state_dict().items():
+        name = param_to_weight(name)
+        if name.startswith("ernie."):
+            name = "model." + name[6:]
+
+        if "_proj." in name or ".gate." in name or "lm_head" in name:
+            param = copy_attr(param.T, param)
+
+        if ".up_gate_proj." in name:
+            gate, up = param.split(2)
+            gate, up = copy_attr(gate, param), copy_attr(up, param)
+            state_dict[name.replace(".up_gate_proj.", ".gate_proj.")] = gate
+            state_dict[name.replace(".up_gate_proj.", ".up_proj.")] = up
+        elif ".qkv_proj." in name:
+            assert q_dim + kv_dim * 2 == param.shape[0]
+            state_dict[name.replace(".qkv_proj.", ".q_proj.")] = param[:q_dim]
+            state_dict[name.replace(".qkv_proj.", ".k_proj.")] = param[q_dim:-kv_dim]
+            state_dict[name.replace(".qkv_proj.", ".v_proj.")] = param[-kv_dim:]
+        else:
+            state_dict[name] = param
+
+    return state_dict
+
+
 def main():
     if set_affinity is not None:
         set_affinity_code = set_affinity()
@@ -520,21 +702,12 @@ def sname_to_tname(pp_model):
         cfg.enable_delay_scale_loss = args.enable_delay_scale_loss
         register_pp_reshard_information(cfg.num_hidden_layers)
 
-        if args.from_scratch:
-            model = ErnieMoEForCausalLMPipe(cfg)
-        else:
-            model = ErnieMoEForCausalLMPipe.from_pretrained(
-                args.model_name_or_path,
-                config=cfg,
-            )
+        model = ErnieMoEForCausalLMPipe(cfg)
     else:
-        if args.from_scratch:
-            model = ErnieMoEForCausalLM(cfg)
-        else:
-            model = ErnieMoEForCausalLM.from_pretrained(
-                args.model_name_or_path,
-                config=cfg,
-            )
+        model = ErnieMoEForCausalLM(cfg)
+
+    if not args.from_scratch:
+        load_huggingface_checkpoint(model, args)
 
     cfg = model.config
     logger.info(f"using model type:{type(model)}")
@@ -581,6 +754,7 @@ def sname_to_tname(pp_model):
     if args.do_train:
         train_result = trainer.train(resume_from_checkpoint=checkpoint)
         metrics = train_result.metrics
+        unified_checkpoint.get_expected_state_dict = get_expected_state_dict
         trainer.save_model(args.output_dir)
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)