Support loss_mask in dataset to control loss calculation for specific turns #9630
Conversation
gemini-code-assist left a comment:

Code Review

This pull request introduces a useful `loss_mask` feature to selectively control loss calculation for different turns in a conversation. The implementation is mostly solid, with changes to the data converters to extract the mask and to the supervised processor to apply it. I found a critical bug in the `supervised.py` processor: the `loss_mask` is not handled correctly when `mask_history` is enabled, which causes masks to be applied to the wrong turns. I have left a detailed comment with a suggested fix that also refactors the code for better clarity and correctness. Once this issue is addressed, the PR should be in good shape.
hiyouga left a comment:
LGTM
Could you please resolve the failed tests?
```python
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
loss_mask=example_loss_mask,
```
Can we simply use `examples["_loss_mask"][i] or []`?
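Applied to the call site in the hunk above, that would read (a suggestion sketch; the surrounding arguments are copied from the diff):

```python
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
loss_mask=examples["_loss_mask"][i] or [],
```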
```python
batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
lengths = []
length2indexes = defaultdict(list)
loss_masks = examples.get("_loss_mask")
```
The `_loss_mask` key is always present in `examples`.
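If so, the `.get` lookup in the hunk above could become a plain index (a sketch of the implied one-line change):

```python
loss_masks = examples["_loss_mask"]  # key is guaranteed to exist, so no .get fallback needed
```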
```python
for mask_value, message in zip(loss_mask, prompt + response):
    if message.get("role") == Role.ASSISTANT.value:
        assistant_loss_mask.append(1 if mask_value else 0)
if len(assistant_loss_mask) != len(encoded_pairs):
```
I think this check is redundant
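If removed, the extraction would reduce to the loop alone (a sketch, assuming the variables from the hunk above):

```python
assistant_loss_mask = []
for mask_value, message in zip(loss_mask, prompt + response):
    if message.get("role") == Role.ASSISTANT.value:
        assistant_loss_mask.append(1 if mask_value else 0)
# No trailing length check: if the converter already validated len(loss_mask),
# the per-assistant-turn mask lines up with the encoded turn pairs by construction.
```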
```python
assistant_loss_mask: Optional[list[int]] = None
if loss_mask is not None:
    if len(loss_mask) != len(prompt) + len(response):
```
This check is already performed in the converter.
```python
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
loss_mask: Optional[list[int]] = None,
```
`list[int]`
```python
target_label = target_ids

if assistant_loss_mask is not None and turn_idx < len(assistant_loss_mask):
    if assistant_loss_mask[turn_idx] == 0:
```
Can we merge L97 and L98?
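Merged, the condition would read (a sketch; the branch body is taken from the following hunk):

```python
if assistant_loss_mask is not None and turn_idx < len(assistant_loss_mask) and assistant_loss_mask[turn_idx] == 0:
    target_label = [IGNORE_INDEX] * target_len
```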
```python
if assistant_loss_mask is not None and turn_idx < len(assistant_loss_mask):
    if assistant_loss_mask[turn_idx] == 0:
        target_label = [IGNORE_INDEX] * target_len
```
@hiyouga @CjangCjengh
Hi! I just read through this PR and noticed a potential issue when `mask_history=True` and `loss_mask` are used together. In that case, `mask_history` sets `IGNORE_INDEX` before `loss_mask` is applied, so the loss mask may not take effect as intended.
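One possible way to keep the two features consistent is to reverse the per-turn mask along with the encoded pairs (a sketch, not the PR's actual code; it assumes `encoded_pairs` is reversed when `mask_history` is enabled, as in the current processor):

```python
if self.data_args.mask_history:
    encoded_pairs = encoded_pairs[::-1]  # process the last turn first
    if assistant_loss_mask is not None:
        assistant_loss_mask = assistant_loss_mask[::-1]  # keep turn_idx aligned with the reversed order
```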
In my opinion, when `loss_mask` is used, `mask_history` shouldn't be responsible for setting `IGNORE_INDEX`. The main thing we still need from `mask_history` is reversing the IDs to avoid truncating the last turn. That said, I'm not sure whether this default behavior should also apply to `loss_mask`, since `loss_mask` is a fairly flexible option with different possible use cases.
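Concretely, that proposal might look like this (a sketch only; names such as `self.data_args` and `target_len` are assumed from the surrounding processor code):

```python
if self.data_args.mask_history and turn_idx != 0 and assistant_loss_mask is None:
    # Default mask_history behavior: train only on the last turn.
    target_label = [IGNORE_INDEX] * target_len
elif assistant_loss_mask is not None and assistant_loss_mask[turn_idx] == 0:
    # An explicit loss_mask takes over masking responsibility entirely.
    target_label = [IGNORE_INDEX] * target_len
else:
    target_label = target_ids
```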
Description
This PR introduces a new feature that allows users to control which assistant turns should participate in loss calculation by providing a `loss_mask` field in the dataset. This is useful for scenarios where certain responses in a multi-turn conversation should be excluded from training (e.g., low-quality responses or context-only turns).

Key Changes

- `src/llamafactory/data/converter.py`: added an `_extract_loss_mask` method to `DatasetConverter` to parse and validate the `loss_mask` field; updated `AlpacaDatasetConverter`, `SharegptDatasetConverter`, and `OpenAIDatasetConverter` to extract and pass `_loss_mask`.
- `src/llamafactory/data/processor/supervised.py`: updated `SupervisedDatasetProcessor` (and `PackedSupervisedDatasetProcessor`) to accept `loss_mask`; labels are masked (set to `IGNORE_INDEX`) for assistant turns where the corresponding `loss_mask` value is 0 (or False). A validation sketch follows below.
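For orientation, the converter-side parsing and validation described above might look roughly like this (a hypothetical sketch of a `DatasetConverter` method; the PR's exact signature and error message may differ):

```python
from typing import Optional


def _extract_loss_mask(self, example: dict, num_messages: int) -> Optional[list[int]]:
    """Parse and validate the optional per-message loss mask."""
    loss_mask = example.get("loss_mask")
    if loss_mask is None:
        return None

    if len(loss_mask) != num_messages:
        raise ValueError(
            f"loss_mask length ({len(loss_mask)}) does not match the number of messages ({num_messages})."
        )

    # Normalize truthy/falsy entries (e.g., booleans) to 0/1 integers.
    return [1 if value else 0 for value in loss_mask]
```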
Usage

To use this feature, add a `loss_mask` field to your dataset entry. The `loss_mask` must be a list whose length equals the total number of messages (user prompts + assistant responses).

Example (ShareGPT format)
{ "conversations": [ {"from": "human", "value": "Question 1"}, {"from": "gpt", "value": "Answer 1 (Ignore)"}, {"from": "human", "value": "Question 2"}, {"from": "gpt", "value": "Answer 2 (Train)"} ], "loss_mask": [0, 0, 0, 1] }In this example:
0, so its labels will be set toIGNORE_INDEXand it will not contribute to the loss.1, so it will be used for training.Before submitting