-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Rearrange DPOTrainer #3501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Rearrange DPOTrainer #3501
Conversation
Thanks! I like it! Please let us know when it's ready for review! |
Thanks for your attention! It's ready for review. cc @qgallouedec |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, let's see if the CI passes! Ideally we want to have something very close to SFT, and this PR is a good move in this direction
trl/trainer/dpo_trainer.py
Outdated
if model is None: | ||
raise ValueError("No model provided. Please provide a model to train.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can remove this, and replace
- model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ model:Union[PreTrainedModel, nn.Module, str],
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
trl/trainer/dpo_trainer.py
Outdated
if processing_class is None: | ||
raise ValueError("processing_class must be specified to tokenize a DPO dataset.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be nice to have instead
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model_id)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Thank you for your suggestions! I've updated the mentioned parts and more to sync the |
What does this PR do?
Rearranged the code structure in
DPOTrainer
. Details include:_create_model_from_path
,_prepare_peft_model
, and_prepare_gradient_checkpointing
, while keeping the original logic unchanged. This follows the implementation inSFTTrainer
, which also splits model preparation into various functional parts for clarity.disable_dropout_in_model(model)
andmodel.warnings_issued["estimate_tokens"]
ahead so that all operations on the model are grouped together. This is safe and would not influence the execution.model
andprocessing_class
. Now we can passNone
as the value formodel
andprocessing_class
inDPOTrainer
. This feature is adapted fromSFTTrainer
.SFTTrainer
.SFTTrainer
.Before submitting
Pull Request section?
to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.