[feat] Implement Loss Masking to Exclude Predefined Token Spans from LM Loss #109
Comments
I agree with the preferred approach as well. As @tscholak mentioned, the input for the
Thanks @sohamparikh!
Let's be more clear about this. Two options:
I have a preference for 2, because if the list is empty, the default is clearly to include everything.
Re: exact alignment between character span boundaries and token boundaries. After some thinking, I believe this cannot and should not be generally assumed. Instead, what we should do is tokenize every segment that results from applying the spans to the text individually. This will be much more useful and correct because it will align with how the model will have to consume and generate tokens at runtime: model prompts will end with the last character of a masked segment, and generation will start with the first character of an unmasked segment.
Just saw your comment; I came to the same conclusion and implemented this (changes are still WIP and untested).
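A minimal sketch of the segment-wise tokenization described above, assuming a generic `tokenize` callable and character spans that mark the segments to exclude from the loss (names and span conventions are illustrative, not Fast-LLM's actual API):

```python
from typing import Callable

def tokenize_with_spans(
    text: str,
    masked_spans: list[tuple[int, int]],       # character spans to exclude from the loss
    tokenize: Callable[[str], list[int]],      # hypothetical tokenizer callable
) -> tuple[list[int], list[tuple[int, int]]]:
    """Tokenize each masked/unmasked segment separately and return the
    concatenated token ids plus the resulting *token* spans to mask."""
    token_ids: list[int] = []
    token_spans: list[tuple[int, int]] = []
    cursor = 0
    for start, end in sorted(masked_spans):
        # Unmasked segment before this span, tokenized on its own.
        if start > cursor:
            token_ids.extend(tokenize(text[cursor:start]))
        # Masked segment, also tokenized on its own, so its token boundary
        # matches the boundary the model will see at generation time.
        span_tokens = tokenize(text[start:end])
        token_spans.append((len(token_ids), len(token_ids) + len(span_tokens)))
        token_ids.extend(span_tokens)
        cursor = end
    # Trailing unmasked segment, if any.
    if cursor < len(text):
        token_ids.extend(tokenize(text[cursor:]))
    return token_ids, token_spans
```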
🧐 Problem Description
In supervised fine-tuning (SFT) for decoder-only models, training data sequences typically include a mix of system prompts, user inputs, and model responses. However, the loss during fine-tuning should only be computed on the model responses. This practice, known as loss masking, has been shown to significantly improve model performance and generalization.
Currently, Fast-LLM does not support propagating token spans for loss masking from the dataset to the loss function, which limits its effectiveness for SFT. Implementing this feature requires:
💡 Proposed Solution
To support loss masking, we propose the following approaches, ranked by priority:
1. Enhanced Memory-Mapped Indexed Format (Preferred Approach)
Extend the `prepare` command to handle data whose `text` field contains fully formatted prompts (system prompt, user prompt(s), and model response(s)), optionally annotated with special chat tokens or tags.
Advantages:
Trade-offs:
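As a rough sketch only (not Fast-LLM's actual format), a prepared, memory-mapped document could carry the token ids together with the token spans to exclude from the loss; the field names below are hypothetical:

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class PreparedDocument:
    # Token ids for the full formatted prompt (system + user + response).
    token_ids: np.ndarray            # shape (num_tokens,), dtype int32
    # Token spans to exclude from the loss, as (start, end) pairs.
    loss_masking_spans: np.ndarray   # shape (num_spans, 2), dtype int32
```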
2. JSONL Dataset Format
"text"
: Fully formatted prompt containing system prompts, user inputs, and model responses."masked_spans"
: A list of character span tuples to exclude from the loss computation.Example JSONL Entry:
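A minimal illustrative entry (the chat tags and span convention are assumptions; spans here are character offsets with an exclusive end, so `[0, 35]` masks everything up to and including the `<|assistant|>` tag):

```json
{"text": "<|user|>What is 2 + 2?<|assistant|>2 + 2 = 4.", "masked_spans": [[0, 35]]}
```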
Advantages:
Trade-offs:
3. Hugging Face Dataset Integration
Advantages:
Trade-offs:
Pipeline Integration
For all approaches, the token spans for excluded portions of the input should be passed as part of the `kwargs` in the model's forward function. This masking logic ensures that the loss computation is limited to the relevant portions of the sequence (e.g., the model responses).
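As a hedged sketch (the kwarg name `loss_masking_spans` and the function below are illustrative, not necessarily Fast-LLM's actual interface), the loss computation could zero out per-token losses inside the excluded token spans:

```python
import torch
import torch.nn.functional as F

def masked_lm_loss(
    logits: torch.Tensor,                              # (batch, seq_len, vocab)
    labels: torch.Tensor,                              # (batch, seq_len), valid token ids
    loss_masking_spans: list[list[tuple[int, int]]],   # per-sample token spans to exclude
) -> torch.Tensor:
    # Per-token cross-entropy, kept unreduced so spans can be masked out.
    per_token = F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), reduction="none"
    ).view_as(labels)
    # Build a keep-mask that is False inside every excluded span.
    keep = torch.ones_like(labels, dtype=torch.bool)
    for i, spans in enumerate(loss_masking_spans):
        for start, end in spans:
            keep[i, start:end] = False
    # Average only over the tokens that contribute to the loss.
    return (per_token * keep).sum() / keep.sum().clamp(min=1)
```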
🔄 Alternatives Considered
📈 Potential Benefits
📝 Additional Context
Deliverables