Commits (38)
6a79346
init ernie dynamic auto code(no moe op)
waliwali777 Jul 31, 2025
23a1b7c
rm ernie_mm_moe, dfnrope
waliwali777 Aug 12, 2025
1e4ba91
rm ernie_moe
waliwali777 Aug 12, 2025
61240a6
rm tokenizer_model_auto
waliwali777 Aug 12, 2025
a56f691
rm multimodel callback
waliwali777 Aug 12, 2025
86c338f
rm bos client
waliwali777 Aug 12, 2025
e335954
rm elastic_utils and fancy_print
waliwali777 Aug 12, 2025
665f122
rm openclip
waliwali777 Aug 12, 2025
d4236dc
rm longcontext_ops, global_random, align_mode_utils
waliwali777 Aug 12, 2025
4f7726c
rm sft_task_reader streaming_pretrain_reader
waliwali777 Aug 12, 2025
f1e913d
fix tp import
waliwali777 Aug 12, 2025
47999e5
rm other_gate ernie/moe/uesless_file
waliwali777 Aug 12, 2025
56f40c6
rm refine recompute
waliwali777 Aug 12, 2025
e4a6353
rm refine recompute 2
waliwali777 Aug 12, 2025
66608bf
rm useless callback
waliwali777 Aug 12, 2025
ba116f8
rm bns lock
waliwali777 Aug 12, 2025
bf87225
rm dataset uneless files
waliwali777 Aug 12, 2025
6751157
rmmodel io
waliwali777 Aug 12, 2025
475280f
rmmodel submatrix_parallel
waliwali777 Aug 12, 2025
25cc2bf
rm model callback
waliwali777 Aug 12, 2025
1ab338b
rm pretrain_iterable_dataset
waliwali777 Aug 12, 2025
7759273
rm token dispatcher
waliwali777 Aug 12, 2025
6b706de
restore moe_clip, logging, misc, seed_utils, trainging_utils
waliwali777 Aug 12, 2025
c3a9f33
move ErnieConfig
waliwali777 Aug 12, 2025
76c6b91
add moe_layer_auto_utils
waliwali777 Aug 12, 2025
b8e1b16
add moe_utils_auto
waliwali777 Aug 12, 2025
5d22913
add fp8_utils_auto
waliwali777 Aug 12, 2025
24ed7d5
add restore
waliwali777 Aug 12, 2025
4238284
add top2_gate_auto_auto, sequence_parallel_utils_auto, fp8_linear_auto
waliwali777 Aug 12, 2025
5a3c7c1
restore model_config
waliwali777 Aug 12, 2025
af43a79
pre-commit
waliwali777 Aug 13, 2025
49827a7
remove
waliwali777 Aug 13, 2025
dd1dab1
adapt data loader
waliwali777 Aug 13, 2025
41e242e
Merge branch 'pr' into eb45pp
Xing-lil Aug 14, 2025
3b79031
pp use
Xing-lil Aug 14, 2025
9686577
mv auto_trainer.py pp use to pretrain
Xing-lil Aug 19, 2025
f002934
update
Xing-lil Aug 19, 2025
20ad8b4
Merge branch 'eb45pp' of https://github.com/Xing-lil/ERNIE into eb45pp
Xing-lil Aug 19, 2025
359 changes: 359 additions & 0 deletions examples/pre-training/ernie/pretrain_auto.py
@@ -0,0 +1,359 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import json
import numpy as np
import random
import paddle
import paddle.distributed.fleet as fleet
from src.utils import logger
from paddleformers.trainer import (
PdArgumentParser,
get_last_checkpoint,
)
from src.tokenizers.tokenization_eb_v2 import ErnieBotTokenizer
from omegaconf.listconfig import ListConfig
from omegaconf.dictconfig import DictConfig
from src.callbacks import (
ProgreesiveBatchingCallback,
GlobalRNGCallback,
)
from models.ernie import (
ErnieForCausalLMAuto,
)
from models.ernie_moe.configuration import (
ErnieConfig,
ErnieMoEConfig,
)
from src.trainers import AutoPretrainingTrainer, AutoPreTrainingArguments
from src.utils import (
setup_logger_output_file,
)
from src.utils.misc import global_training_logs
from pretrain import create_pretrained_dataset


from config import get_config

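# Fallback: older paddleformers releases do not ship log_trainer_start, so a
# local stand-in is defined below; the MAIN_PROCESS_STARTED env var keeps the
# message from being printed more than once per process.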
try:
from paddleformers.trainer.trainer_utils import log_trainer_start
except ImportError:

def log_trainer_start():
"""Print main process messgae"""
if "MAIN_PROCESS_STARTED" not in os.environ:
start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
logger.info(
f"The Training Main Process Started Successfully. time: {start_time}, pid: {os.getpid()}"
)
os.environ["MAIN_PROCESS_STARTED"] = "1"


log_trainer_start()


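# Prefer monitor_perf (newer Paddle) and alias it to the older collective_perf name.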
try:
from paddle.distributed.fleet import monitor_perf as collective_perf
except ImportError:
from paddle.distributed.fleet import collective_perf


assert paddle.version.mkl() == "OFF", (
"MKL is not supported"
" in this version. Please set -DWITH_MKL=OFF when compiling PaddlePaddle."
)


def update_model_config_from_args(config: ErnieConfig, model_args: dict):
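"""Overwrite matching fields of `config` with values from the YAML model_args;
keys that do not exist on the config are logged and skipped."""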
for k, v in model_args.items():
if hasattr(config, k):
logger.info(f"update model config: {k} = {v}")
setattr(config, k, v)
else:
logger.warning(f"model config key: {k} does not exist")
return config


def init_parameter(model):
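"""Materialize every parameter of a model built under paddle.LazyGuard."""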

for param in model.parameters():
param.initialize()


def main():
"""Main function"""
config = get_config(verbose=True)
os.makedirs(config.model_args.output_dir, exist_ok=True)
parser = PdArgumentParser(AutoPreTrainingArguments)
if not hasattr(config.trainer_args, "pipeline_parallel_config"):
config.trainer_args.pipeline_parallel_config = ""

if "enable_dp_comm_overlap" in config.trainer_args.pipeline_parallel_config:
logger.warning(
"Pipeline dp_comm_overlap and FusedLinearWithGradAdd can not be used at "
"the same time."
)

if "enable_timer" in config.trainer_args.pipeline_parallel_config:
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
PipelineParallel,
)

PipelineParallel.timer_printer = lambda _: None

def formatv(v):
if isinstance(v, ListConfig):
return list(v)
elif isinstance(v, DictConfig):
return dict(v)
return v

model_args = {k: formatv(v) for k, v in dict(config.model_args).items()}
trainer_args = {k: formatv(v) for k, v in dict(config.trainer_args).items()}
(args,) = parser.parse_dict(dict(**model_args, **trainer_args))

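# Mirror the CLI virtual-pipeline settings into the auto-parallel strategy so
# the vpp degree and segmentation method take effect.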
if args.strategy.pipeline.enable and args.virtual_pp_degree > 1:
pipeline = args.strategy.pipeline
pipeline.vpp_degree = args.virtual_pp_degree
pipeline.vpp_seg_method = args.virtual_pipeline_seg_method

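# Derive modality_interleave from the raw ratio (scaled by the gradient
# accumulation steps unless set to "acc") and normalize modality_ratio into
# fractions that sum to 1.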
if args.modality_ratio is not None:
args.modality_interleave = (
sum(args.modality_ratio)
if args.modality_interleave == "acc"
else sum(args.modality_ratio) * args.gradient_accumulation_steps
)
args.modality_ratio = [
i / sum(args.modality_ratio) for i in args.modality_ratio
]

args.eval_iters = 10
args.test_iters = args.eval_iters * 10

args.use_moe = dict(**dict(config.model_args), **dict(config.trainer_args)).get(
"use_moe", False
)
model_config = dict(getattr(config.model_args, "model_config", {}))
model_config = {k: formatv(v) for k, v in model_config.items()}
logger.info(f"model_config_from_yaml: {json.dumps(model_config, indent=4)}")
setup_logger_output_file(config.model_args.output_dir, args.local_rank)
paddle.set_device(args.device)

np.random.seed(args.seed)
random.seed(args.seed)
paddle.seed(args.seed)
# set_seed(args.seed)

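# Optionally pre-allocate (and immediately free) a large buffer so the caching
# allocator reserves a contiguous pool up front, presumably to reduce
# fragmentation later in training.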
prop = paddle.device.cuda.get_device_properties()
if prop.total_memory < args.pre_alloc_memory * 1024 * 1024 * 1024:
logger.warning(
"Invalid value for `pre_alloc_memory`, so pre-allocating just failed."
)
elif args.pre_alloc_memory > 0:
logger.warning(
f"pre-allocating a tensor whose memory capacity is {args.pre_alloc_memory} GB "
"and then release it."
)
memory_size = int(args.pre_alloc_memory * 1024 * 1024 * 1024)
x = paddle.empty([memory_size], dtype=paddle.uint8)
del x

# Sanity-check collective communication (allgather/allreduce) performance via fleet
try:
collective_perf(
"allgather",
round=50,
size_and_time={67108864: 0.00625, 234881024: 0.02, 637534208: 0.057},
)
logger.info("======monitor allgather done!=======\n")
collective_perf(
"allreduce",
round=50,
size_and_time={67108864: 0.02, 134217728: 0.038, 268435456: 0.075},
)
logger.info("======monitor allreduce done!=======\n")
except Exception as e:
logger.warning(f"fleet test unexcepted error! skip exception[{e}]...")

# Detecting last checkpoint.
last_checkpoint = None
if (
os.path.isdir(args.output_dir)
and args.do_train
and not args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(args.output_dir)
if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0:
raise ValueError(
f"Output directory ({args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)

# Define the evaluation metrics.
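# `predictions` carries summed NLL values and `label_ids` the token labels;
# perplexity is computed as exp(total NLL / number of non-ignored tokens).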
def compute_metrics(p):
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions

output = paddle.to_tensor(preds)
labels = paddle.to_tensor(p.label_ids)
output = [t.astype("float32").cuda() for t in output]
labels = [t[t != tokenizer.ignored_index] for t in labels]
labels = [t.cuda() for t in labels]
all_numel = (
(paddle.concat(labels, 0) != tokenizer.ignored_index).astype("int64").sum()
)
ignored = (paddle.concat(labels, 0) == -100).astype("int64").sum()
labels = all_numel - ignored
output = sum(output)
logger.info(f"output : {output.item()}, labels : {labels.item()}")
nll_loss = output / (labels + 1.0e-6) # nll_loss is global loss
ppl = paddle.exp(nll_loss)

return {
"nll_loss": nll_loss.item(),
"ppl": ppl.item(),
"num_token": labels.item(),
}

# model
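# Set the default dtype before building the model so parameters are created
# directly in fp16 (O2) or bf16 rather than cast after construction.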
dtype = "float32"
if args.fp16 and args.fp16_opt_level == "O2":
paddle.set_default_dtype("float16")
dtype = "float16"
elif args.bf16:
paddle.set_default_dtype("bfloat16")
dtype = "bfloat16"

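# MoE setup: swap in the MoE config class, but fall back to a dense model when
# moe_group would coincide with the tensor-parallel/model group.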
if args.use_moe:
global ErnieConfig, ErnieForCausalLMAuto
ErnieConfig = ErnieMoEConfig

if args.moe_group.lower() in {"mp", "tp", "model", "dummy"}:
logger.info(f"disable moe flag when using moe-group={args.moe_group}")
args.use_moe = False

args.multi_token_pred_depth = model_config.get("multi_token_pred_depth", 0)

cfg = ErnieConfig.from_pretrained(args.model_name_or_path)
cfg.seqlen = args.max_seq_length
cfg.token_balance_seqlen = args.max_seq_length * args.per_device_train_batch_size
cfg.fp16_opt_level = args.fp16_opt_level
cfg.moe_group = args.moe_group
cfg.dtype = dtype
cfg.pipeline_parallel_degree = args.pipeline_parallel_degree
cfg.virtual_pp_degree = args.virtual_pp_degree
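# Under tensor parallelism, read the actual TP degree/rank from the hybrid
# communicate group; otherwise run fully replicated with sequence parallel off.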
if args.tensor_parallel_degree > 1:
cfg.sequence_parallel = args.sequence_parallel
cfg.tensor_parallel_degree = max(
fleet.get_hybrid_communicate_group().get_model_parallel_world_size(), 1
)
cfg.tensor_parallel_rank = max(
fleet.get_hybrid_communicate_group().get_model_parallel_rank(), 0
)
else:
cfg.sequence_parallel = False
cfg.tensor_parallel_degree = 1
cfg.tensor_parallel_rank = 0

cfg.micro_batch_size = args.per_device_train_batch_size
tokenizer = ErnieBotTokenizer.from_pretrained(args.tokenizer_name)
tokenizer.ignored_index = cfg.ignored_index
logger.info(
f"using tokenizer={type(tokenizer)}, bos:{tokenizer.bos_token_id} "
f"eos:{tokenizer.eos_token_id} pad:{tokenizer.pad_token_id} "
)

cfg = update_model_config_from_args(cfg, model_config)

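# Build the model under LazyGuard so parameters are declared but not yet
# allocated; they are materialized below via init_parameter()/init_weights(),
# or loaded by from_pretrained.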
if args.from_scratch:
with paddle.LazyGuard():
model = ErnieForCausalLMAuto(cfg)
else:
with paddle.LazyGuard():
model = ErnieForCausalLMAuto.from_pretrained(
args.model_name_or_path,
config=cfg,
)

cfg = model.config
logger.info(f"using model type:{type(model)}")
paddle.set_default_dtype("float32")

logger.info(f"using model={type(model)}, cfg={cfg}")

freeze_config = set(args.freeze_config.split(" "))
if "freeze_vision" in freeze_config and hasattr(model, "freeze_vision"):
logger.info("Freeze model vision module")
model.freeze_vision()

# data
logger.info("loading data...")
train_dataset, eval_dataset, test_dataset, data_collator = (
create_pretrained_dataset(args)
)

callbacks = [GlobalRNGCallback()]

if args.batch_size_warmup_steps:
progressive_batching_callback = ProgreesiveBatchingCallback(
args.gradient_accumulation_steps,
args.max_gradient_accumulation_steps,
args.batch_size_warmup_steps,
args.batch_size_warmup_increment,
)
callbacks.append(progressive_batching_callback)

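# Materialize the lazily-declared parameters, then run the model's own weight
# initialization.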
init_parameter(model)
model.apply(model.init_weights)
trainer = AutoPretrainingTrainer(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
callbacks=callbacks,
)
global_training_logs.accumulate = args.gradient_accumulation_steps
checkpoint = None
if args.resume_from_checkpoint is not None:
checkpoint = args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint

# Training
if args.do_train:
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
trainer.save_model(args.output_dir)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Evaluate the model.
if args.do_eval:
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)


if __name__ == "__main__":
main()
12 changes: 11 additions & 1 deletion examples/pre-training/ernie/src/callbacks/__init__.py
@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from .tensorboard_callback import TensorBoardCallback

from .gc_callback import GCCallback
from .progressive_batching_callback import ProgreesiveBatchingCallback
from .logging_callback import LoggingCallback
from .stopper_callback import StopperCallback
from .adaptivegradclip_callback import ClipGradByAdaptiveNormCallback

from .moe_correction_bias_adjust_callback import MoECorrectionBiasAdjustCallback
from .moe_logging_callback import GlobalRNGCallback, MoeLoggingCallback
from .sp_grad_sync_callback import SPGradSyncCallback
from .fp8_quant_weight_callback import FP8QuantWeightCallback
from .ortho_loss_callback import OrthogonalCallback

@@ -31,4 +38,7 @@
"MoECorrectionBiasAdjustCallback",
"FP8QuantWeightCallback",
"OrthogonalCallback",
"ClipGradByAdaptiveNormCallback",
"StopperCallback",
"ProgreesiveBatchingCallback",
]