Pr adapt flex checkpoint #11065
Conversation
Thanks for your contribution!
LGTM
paddlenlp/trainer/trainer.py
Outdated
```diff
@@ -949,14 +949,36 @@ def train(
         if delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
             self._load_optimizer_and_scheduler(resume_from_checkpoint)
-        else:
+        elif not self.args.using_flex_checkpoint:
```
Don't use `elif not`-style inverted logic for this branch; it makes adding further branches later more complicated. Prefer:

```python
elif self.args.using_flex_checkpoint:
    load_from_flex_checkpoint
else:
    load_from_default
```
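A concrete version of the suggested positive-logic structure might look like the sketch below; `_load_flex_checkpoint` is a hypothetical helper name used purely for illustration, not the PR's actual code:

```python
if delay_optimizer_creation:
    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
    self._load_optimizer_and_scheduler(resume_from_checkpoint)
elif self.args.using_flex_checkpoint:
    # Explicit branch for the new flex checkpoint path.
    self._load_flex_checkpoint(resume_from_checkpoint)  # hypothetical helper
else:
    # Existing default path for optimizer/scheduler state.
    self._load_optimizer_and_scheduler(resume_from_checkpoint)
```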
Done
paddlenlp/trainer/trainer_utils.py
Outdated
```diff
@@ -53,6 +53,21 @@
 from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
 from .utils.helper import distributed_file

+try:
```
Why is the try logic needed here?
On closer consideration, the try is not needed here; it has been removed.
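For context, the pattern under discussion is an import guard along these lines; the module name below is a placeholder, since the actual import is truncated in the diff above:

```python
# Sketch of a guarded import (placeholder module name). This pattern
# only makes sense when the dependency may be absent at runtime.
try:
    from paddle.distributed import flex_checkpoint_module  # placeholder
    _HAS_FLEX_CHECKPOINT = True
except ImportError:
    flex_checkpoint_module = None
    _HAS_FLEX_CHECKPOINT = False
```

When the dependency is always available, a plain unguarded import is simpler and fails fast on a broken environment.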
paddlenlp/trainer/training_args.py
Outdated
```diff
@@ -407,6 +407,10 @@ class TrainingArguments:
             Whether to release gradients during training. Default is `False`.
         ckpt_quant_stage (`str`, *optional*):
             Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
+        using_flex_checkpoint(`bool`, *optional*):
```
Considering interconversion with sharding_io checkpoints, should the switch be split into separate save and load flags?
Done
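Splitting the switch by direction might look like the sketch below; the field names `save_flex_checkpoint` / `load_flex_checkpoint` are illustrative assumptions, not confirmed against the final diff:

```python
from dataclasses import dataclass, field

@dataclass
class TrainingArguments:
    # Separate switches for each direction, so a run can e.g. load a
    # sharding_io checkpoint but save in the flex checkpoint format.
    save_flex_checkpoint: bool = field(
        default=False,
        metadata={"help": "Save checkpoints in the flex checkpoint format."},
    )
    load_flex_checkpoint: bool = field(
        default=False,
        metadata={"help": "Load checkpoints from the flex checkpoint format."},
    )
```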
/re-run all-failed
Force-pushed from 40c375d to cfc3e7f
Force-pushed from cfc3e7f to 92d9b66
PR types
New features

PR changes
Others

Description
Adapt flex_checkpoint and fix a bug where DP (data parallel) training hangs.