paddlenlp/trainer/trainer.py: 168 changes (132 additions, 36 deletions)
@@ -161,6 +161,7 @@
get_last_checkpoint,
get_scheduler,
has_length,
init_optimizer,
set_seed,
should_skip_data,
speed_metrics,
@@ -199,7 +200,6 @@
if is_datasets_available():
import datasets


try:
from paddle.distributed.fleet.utils import mix_precision_utils
except:
@@ -812,6 +812,10 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
if resume_from_checkpoint is not None:
path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
if self.args.zcc_save_ema_coef is not None and self.sharding_io is not None:
success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
else:
success, err_msg = True, None
@@ -822,6 +826,11 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
if success:
logger.info(f"ZCC EMA load from {path}")
self.zcc_manager.set_ema_state_dict(path)
else:
logger.info(f"ZCC EMA does not load {path} because {err_msg}")
else:
logger.info(f"ZCC EMA state dict not found, in: {path}")

@@ -929,13 +938,13 @@ def train(
self._memory_tracker.start()

if not self.args.enable_auto_parallel:
if not self.args.should_load_sharding_stage1_model:
if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
self._load_from_checkpoint(resume_from_checkpoint)

if self.args.should_load_sharding_stage1_model:
model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)

elif self.args.should_save_sharding_stage1_model:
elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
# In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
# In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
model = self._wrap_model(self.model_wrapped)
@@ -949,13 +958,43 @@
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
elif self.args.load_flex_checkpoint:
model = self._wrap_model(self.model_wrapped)
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
if not self.args.ignore_load_lr_and_optim:
model_sharded_state_dict = self.model.sharded_state_dict()
accessible_files = os.listdir(resume_from_checkpoint)
metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
assert len(metadata_files) == 1, "Only support one metadata file now."
metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
state_dict_metadata = metadata.state_dict_metadata
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
self._load_scheduler(resume_from_checkpoint)
else:
model_sharded_state_dict = self.model.sharded_state_dict()
sharded_state_dict = model_sharded_state_dict
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
else:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

self._load_optimizer_and_scheduler(resume_from_checkpoint)
else:
model = self.model_wrapped
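Reviewer note: the new `load_flex_checkpoint` resume branch above is spread across several nested conditions, so here is a condensed, self-contained sketch of what it does. The wrapper function and the `trainer` parameter are illustrative only; `init_optimizer`, `sharded_state_dict`, `dist.load_state_dict`, `aoa_config`, and `_load_scheduler` are taken from the diff, and the import path of `init_optimizer` is assumed (the diff only shows the name being added to an existing import block).

    import os

    import paddle
    import paddle.distributed as dist

    from paddlenlp.trainer.trainer_utils import init_optimizer  # assumed import path

    def resume_from_flex_checkpoint(trainer, resume_from_checkpoint):
        """Illustrative sketch of the load_flex_checkpoint branch in Trainer.train."""
        model_sharded_state_dict = trainer.model.sharded_state_dict()
        if not trainer.args.ignore_load_lr_and_optim:
            # Exactly one *.metadata file is expected next to the sharded tensors.
            metadata_files = [f for f in os.listdir(resume_from_checkpoint) if f.endswith(".metadata")]
            assert len(metadata_files) == 1, "Only support one metadata file now."
            metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
            # Materialize optimizer state so the sharded tensors can be loaded into it in place.
            init_optimizer(trainer.optimizer, model_sharded_state_dict, metadata.state_dict_metadata)
            optimizer_sharded_state_dict = trainer.optimizer.sharded_state_dict(model_sharded_state_dict)
            sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
            dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=trainer.args.aoa_config)
            # LR scheduler and loss scaler are restored by the new _load_scheduler helper.
            trainer._load_scheduler(resume_from_checkpoint)
        else:
            # ignore_load_lr_and_optim: restore model weights only.
            dist.load_state_dict(model_sharded_state_dict, resume_from_checkpoint, aoa_config=trainer.args.aoa_config)
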
@@ -1357,6 +1396,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
logger.warning(
f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
)

elif isinstance(self.optimizer, HybridParallelOptimizer):
self.optimizer._step(parameters_list)
else:
@@ -1993,7 +2033,6 @@ def apply_decay_param_fun(x):
grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
**optimizer_kwargs,
)

return self.optimizer

def _apply_to_optimizer(self, action):
@@ -2033,6 +2072,13 @@ def _load_rng_state(self, checkpoint):
return

rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth")
if not os.path.isfile(rng_file):
logger.info(
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
"fashion, reproducibility is not guaranteed."
)
return
@@ -2238,7 +2284,6 @@ def _wrap_model(self, model, training=True):
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
assert self.optimizer is not None, "optimizer is empty!"
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)

# Pipeline mode
if in_pipeline_parallel_mode:
if self.args.amp_master_grad:
@@ -2288,15 +2333,13 @@ def get_expected_keys(inputs, keys):
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

if (
hasattr(self.args, "enable_sharding_comm_overlap")
and self.args.enable_sharding_comm_overlap
and self.args.unified_checkpoint
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model.register_sharding_comm_overlap_hook(self.optimizer)

# No pipeline mode, sharding only
if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
# Sharded DDP!
@@ -2310,7 +2353,6 @@ def get_expected_keys(inputs, keys):
model = paddle.distributed.fleet.meta_parallel.TensorParallel(
model, hcg, strategy=fleet.fleet._user_defined_strategy
)

if ShardingOption.SHARD_OP in self.args.sharding:
if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use
@@ -2352,6 +2394,7 @@ def get_expected_keys(inputs, keys):
offload=cpu_offload,
**extra_kwargs,
)

if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
assert hasattr(optimizer, "use_main_grad"), (
"Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2377,7 +2420,6 @@ def get_expected_keys(inputs, keys):
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# stage1 has v1 and v2 version
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
if "split_param" in self.args.sharding_parallel_config:
@@ -2720,6 +2762,10 @@ def _save_checkpoint(self, model, metrics=None):
else:
self.save_model(output_dir)

if self.args.save_flex_checkpoint:
model_sharded_state_dict = self.model.sharded_state_dict()
os.makedirs(output_dir, exist_ok=True)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
@@ -2779,23 +2825,38 @@ def _save_checkpoint(self, model, metrics=None):
signal_dir,
)
else:
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
if self.args.save_flex_checkpoint:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)

if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
else:
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(
os.getenv("FLAG_LLM_PDC", "False")
), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
)
else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)

else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2806,7 +2867,12 @@ def _save_checkpoint(self, model, metrics=None):
or "remove_master_weight" not in self.args.unified_checkpoint_config
):
paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
if self.args.should_save or self.args.use_expert_parallel:

if (
self.args.should_save
or self.args.use_expert_parallel
or (self.args.data_parallel_degree > 1 and self.args.save_flex_checkpoint)
):
if not self.args.use_hybrid_parallel:
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2816,6 +2882,17 @@ def _save_checkpoint(self, model, metrics=None):
output_dir,
signal_dir,
)
elif self.args.save_flex_checkpoint:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
else:
if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
self._save_ckpt_func(
@@ -2849,7 +2926,17 @@ def _save_checkpoint(self, model, metrics=None):

if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
self._offload_optimizer()

else:
if self.args.save_flex_checkpoint:
dist.save_state_dict(
model_sharded_state_dict,
output_dir,
)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
self.runtime_timer.stop()

# Maybe delete some older checkpoints.
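Reviewer note: the `save_flex_checkpoint` handling above is split across the hybrid-parallel, expert-parallel, and plain branches of `_save_checkpoint`; the common flow is condensed below. The wrapper function and the `include_optimizer` flag are illustrative only; `sharded_state_dict`, `dist.save_state_dict`, `save_tokenizer`, and the training-args save come from the diff, and the value of `TRAINING_ARGS_NAME` is assumed.

    import os

    import paddle
    import paddle.distributed as dist

    TRAINING_ARGS_NAME = "training_args.bin"  # assumed value of the constant used in trainer.py

    def save_flex_checkpoint(trainer, output_dir, include_optimizer=True):
        """Illustrative sketch of the save_flex_checkpoint paths in Trainer._save_checkpoint."""
        os.makedirs(output_dir, exist_ok=True)
        model_sharded_state_dict = trainer.model.sharded_state_dict()
        sharded_state_dict = dict(model_sharded_state_dict)
        if include_optimizer:
            # Optimizer shards are resolved against the model's sharded state dict.
            sharded_state_dict.update(trainer.optimizer.sharded_state_dict(model_sharded_state_dict))
        # All participating ranks write their shards into the same checkpoint directory.
        dist.save_state_dict(sharded_state_dict, output_dir)
        if trainer.args.should_save:
            if trainer.tokenizer is not None and trainer.args.save_tokenizer:
                trainer.tokenizer.save_pretrained(output_dir)
            # Good practice: save the training arguments together with the trained model.
            paddle.save(trainer.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
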
@@ -3064,6 +3151,7 @@ def _save(
else:
if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
config_to_save = None
self.sharding_io.set_optimizer(self.optimizer)
state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
self.model, merge_tensor_parallel=merge_tensor_parallel
)
@@ -3093,6 +3181,24 @@ def _save(
with open(path, "w") as f:
json.dump(model_meta, f)

def _load_scheduler(self, checkpoint):
if checkpoint is None:
self.runtime_timer.stop()
return

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)

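Reviewer note: `_load_scheduler` is the LR-scheduler and loss-scaler loading logic factored out of `_load_optimizer_and_scheduler` (see the replaced block near the end of this diff), so the flex-checkpoint resume path in `train` can reuse it without going through the optimizer-loading code.
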
def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
self.runtime_timer.start("checkpoint loading time")
@@ -3134,6 +3240,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model = self.model_wrapped

opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
model=model,
optimizer=self.optimizer,
@@ -3165,18 +3272,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)
self._load_scheduler(checkpoint)

if self.args.offload_optim:
logger.info("Offloading optimizer state...")