MPO #2544
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Here is a rough colab notebook that I have created. In the notebook:
Here are my queries:
trl/trainer/dpo_trainer.py
Outdated
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
    model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps
)
if "," in self.loss_type:
Ideally, at this point, we would have:

- `self.loss_type`: a list of strings (e.g. `["sigmoid", "bco_pair"]`)
- `self.loss_type_to_weights`: a dict of `str` to `float` which associates a weight with each loss type.
Parsing for loss type could be done directly in the config, as here:
trl/trl/trainer/nash_md_config.py
Lines 34 to 46 in fe4b5ef
mixture_coef: list[float] = field(
    default_factory=lambda: [0.5],
    metadata={
        "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided "
        "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the "
        "rest of the epochs."
    },
)

def __post_init__(self):
    super().__post_init__()
    if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
        self.mixture_coef = self.mixture_coef[0]
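A minimal sketch of what that could look like for the DPO config (this is not the actual DPOConfig; the `loss_weights` field and the derived `loss_type_to_weights` attribute follow the discussion above and are assumptions about the PR's API):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class LossConfigSketch:
    # In the real DPOConfig these fields would sit alongside the existing ones.
    loss_type: list[str] = field(
        default_factory=lambda: ["sigmoid"],
        metadata={"help": "Loss type(s): a single string, a comma-separated string, or a list of strings."},
    )
    loss_weights: Optional[dict[str, float]] = field(
        default=None,
        metadata={"help": "Optional weight per loss type; keys must appear in `loss_type`."},
    )

    def __post_init__(self):
        # Normalize a single (possibly comma-separated) string into a list,
        # mirroring the mixture_coef handling shown above.
        if isinstance(self.loss_type, str):
            self.loss_type = [t.strip() for t in self.loss_type.split(",")]
        # Default every configured loss to weight 1.0 unless overridden.
        weights = self.loss_weights or {}
        self.loss_type_to_weights = {t: weights.get(t, 1.0) for t in self.loss_type}

For example, `LossConfigSketch(loss_type="sigmoid,bco_pair", loss_weights={"bco_pair": 0.1})` would yield `loss_type_to_weights == {"sigmoid": 1.0, "bco_pair": 0.1}`.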
@qgallouedec Here is a rough colab notebook with the current MPO training. Let me know what you think.
Looks good! Some initial comments :)
trl/trainer/dpo_config.py
Outdated
loss_weights: Optional[Dict[str, float]] = field(
    default_factory=lambda: ["your_values"]
)
loss_type: List[str] | str = field(
I need to check if this works with the parser and the CLI.
If you could let me know how you might check it, I could do it and report back to you.
running this should work:
trl dpo --output_dir tmp_dir --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_preference --report_to none
I tried the above command, which resulted in an issue with `self.loss_type` being `None`.
Upon changing the command to
trl dpo --output_dir tmp_dir --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_preference --report_to none --loss-type sigmoid
it started to train. I am not sure why this happens 🤔
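One way to reproduce the parsing in isolation, outside the full trl dpo command, is to run the config through HfArgumentParser directly (a sketch only; the root cause of the `None` default is not confirmed here, but the trl CLI parser is built on HfArgumentParser):

from transformers import HfArgumentParser
from trl import DPOConfig

parser = HfArgumentParser(DPOConfig)

# Parse once without --loss_type and once with it, and compare what the
# dataclass actually receives as the default.
(no_flag,) = parser.parse_args_into_dataclasses(args=["--output_dir", "tmp_dir"])
(with_flag,) = parser.parse_args_into_dataclasses(args=["--output_dir", "tmp_dir", "--loss_type", "sigmoid"])
print("default:", no_flag.loss_type, "| explicit:", with_flag.loss_type)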
Really nice!!! But don't do this (important):

# Apply the chat template
prompt = processor.apply_chat_template(prompt, tokenize=False)
chosen = processor.apply_chat_template(chosen, tokenize=False)
rejected = processor.apply_chat_template(rejected, tokenize=False)

The DPO Trainer handles applying the chat template. See #1930 for more info. This code snippet is present in so many examples online, it's a scourge. 😩
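For reference, a minimal sketch of the conversational format the trainer templates itself (the tiny test model from the command above is reused here as a placeholder; the toy dataset content is invented for illustration):

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Conversational format: lists of {"role": ..., "content": ...} messages.
# No manual apply_chat_template call -- the trainer applies the template itself.
train_dataset = Dataset.from_dict({
    "prompt": [[{"role": "user", "content": "What color is the sky?"}]],
    "chosen": [[{"role": "assistant", "content": "It is blue."}]],
    "rejected": [[{"role": "assistant", "content": "It is green."}]],
})

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="tmp_dir", report_to="none"),
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()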
trl/trainer/dpo_config.py
Outdated
@@ -15,7 +15,7 @@
 import warnings
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Union, List, Dict
you can use `list` and `dict` instead
loss_weights (`dict[str, float]` or `None`, *optional*, defaults to `None`):
    Use to weight a combination of losses. The keys must be in `loss_type`. By default (if not specified
    in the dict), the weight for a loss in `loss_type` is 1.0.
loss_type (`str` or `list`, *optional*, defaults to `"sigmoid"`):
can you also document that, when a list is passed, the loss is the sum of these values?
add "sft"
as well
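Putting both suggestions together, the docstring entries could read roughly like this (the wording is only a sketch; the exact list of supported loss types is the PR author's call):

loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`):
    Type of loss to use. Can also be a list of loss types (e.g. `["sigmoid", "bco_pair", "sft"]`); in that
    case the final loss is the sum of the individual losses, each scaled by its entry in `loss_weights`.
loss_weights (`dict[str, float]` or `None`, *optional*, defaults to `None`):
    Weights for the losses in `loss_type`. Keys must be in `loss_type`; any loss type not present in the
    dict uses a weight of 1.0.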
curr_losses, curr_chosen_rewards, curr_rejected_rewards = self.dpo_loss(
    curr_loss_type, model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps
)
curr_loss_weight = getattr(self.loss_weights, curr_loss_type, 1.0)
nice!
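For reference, a sketch of how the weighted combination might be assembled. Note that `self.loss_weights` is a dict, so `dict.get` is presumably intended rather than `getattr`, which would always fall back to 1.0; variable names follow the diff above, and how rewards are aggregated across loss types is the PR's design choice (they are simply summed here):

losses, chosen_rewards, rejected_rewards = 0.0, 0.0, 0.0
for curr_loss_type in self.loss_type:
    curr_losses, curr_chosen_rewards, curr_rejected_rewards = self.dpo_loss(
        curr_loss_type,
        model_output["chosen_logps"],
        model_output["rejected_logps"],
        ref_chosen_logps,
        ref_rejected_logps,
    )
    # Look the weight up in the dict, defaulting to 1.0 when the loss type is not listed.
    curr_loss_weight = (self.loss_weights or {}).get(curr_loss_type, 1.0)
    losses = losses + curr_loss_weight * curr_losses
    chosen_rewards = chosen_rewards + curr_chosen_rewards
    rejected_rewards = rejected_rewards + curr_rejected_rewards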
@qgallouedec I have resolved the merge conflicts and have also worked on the review suggestions. Could you help me with another round of review? If this looks good, I can start a small training run on VLM with MPO. WDYT?
@qgallouedec a gentle ping here!
What does this PR do?
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.