🧐 Problem Description
In supervised fine-tuning (SFT) for decoder-only models, training data sequences typically include a mix of system prompts, user inputs, and model responses. However, the loss during fine-tuning should only be computed on the model responses. This practice, known as loss masking, has been shown to significantly improve model performance and generalization.
Currently, Fast-LLM does not support propagating token spans for loss masking from the dataset to the loss function, which limits its effectiveness for SFT. Implementing this feature requires:
A mechanism to identify and propagate spans corresponding to excluded portions (e.g., system prompts and user inputs).
Modifications to the training pipeline to exclude these spans during the loss computation.
💡 Proposed Solution
To support loss masking, we propose the following approaches, ranked by priority:
1. Enhanced Memory-Mapped Indexed Format (Preferred Approach)
Extend the existing memory-mapped dataset format to include optional fields for token spans that specify which portions of the input should be excluded from the loss computation.
Enhance the prepare command to:
Convert Hugging Face datasets into the memory-mapped format.
Extract and preprocess character spans into token spans during dataset preparation (sketched below).
Assumptions:
The text field contains fully formatted prompts (system prompt, user prompt(s), and model response(s)), optionally annotated with special chat tokens or tags.
Character spans align cleanly with token spans.
Advantages:
Retains the scalability and efficiency of the memory-mapped dataset format.
Integrates seamlessly into Fast-LLM’s current workflow.
Supports both pretraining and SFT use cases.
Trade-offs:
Adds complexity to the dataset preparation process.
Assumes accurate character-to-token alignment, which may cause issues in edge cases.
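For illustration, the character-to-token span conversion during prepare could rely on a fast tokenizer's offset mapping. This is only a sketch; the helper name, the placeholder tokenizer, and the surrounding pipeline are assumptions, not existing Fast-LLM code:

```python
# Sketch only: converting character spans to token spans during `prepare`,
# assuming a Hugging Face fast tokenizer that returns an offset mapping.
# The helper name and surrounding pipeline are placeholders, not Fast-LLM code.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

def char_spans_to_token_spans(text: str, excluded_char_spans: list[tuple[int, int]]):
    """Map character spans (to be excluded from the loss) to token-index spans."""
    encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    offsets = encoding["offset_mapping"]  # (char_start, char_end) per token
    token_spans = []
    for char_start, char_end in excluded_char_spans:
        overlapping = [
            i for i, (start, end) in enumerate(offsets)
            if start < char_end and end > char_start  # token overlaps the span
        ]
        if overlapping:
            token_spans.append((overlapping[0], overlapping[-1]))
    return encoding["input_ids"], token_spans
```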
2. JSONL Dataset Format
Use a simpler dataset format with JSONL files, where each entry includes:
"text": Fully formatted prompt containing system prompts, user inputs, and model responses.
"masked_spans": A list of character span tuples to exclude from the loss computation.
Create a dataset wrapper to read and process JSONL files during training (see the sketch below).
Example JSONL Entry:
{
"text": "<system>\nYou are an assistant that provides concise and accurate answers.\n</system>\n<user>\nWhat is the capital of France?\n</user>\n<assistant>\nThe capital of France is Paris.\n</assistant>\n<user>\nCan you tell me more about Paris?\n</user>\n<assistant>\nParis is the capital city of France, known for its art, culture, and history.\n</assistant>",
"masked_spans": [
[0, 61], // System message
[62, 91], // User prompt: "What is the capital of France?"
[116, 146] // User prompt: "Can you tell me more about Paris?"
]
}
Advantages:
Simplifies the pipeline, especially for smaller datasets.
Avoids complex preprocessing and format conversions.
Trade-offs:
Lacks scalability optimizations for larger datasets.
Requires tight control over dataset formatting to avoid inconsistencies.
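A minimal sketch of such a wrapper, assuming the JSONL schema shown above and a Hugging Face fast tokenizer; the class name and returned fields are illustrative and may not match the dataset interface Fast-LLM expects:

```python
# A minimal sketch of a JSONL dataset wrapper with loss masking; illustrative only.
import json
import torch
from torch.utils.data import Dataset

class MaskedJsonlDataset(Dataset):
    def __init__(self, path: str, tokenizer):
        with open(path, encoding="utf-8") as f:
            self.examples = [json.loads(line) for line in f]
        self.tokenizer = tokenizer  # assumed to be a fast tokenizer

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]
        encoding = self.tokenizer(
            example["text"], return_offsets_mapping=True, add_special_tokens=False
        )
        input_ids = torch.tensor(encoding["input_ids"])
        loss_mask = torch.ones_like(input_ids, dtype=torch.bool)
        for char_start, char_end in example.get("masked_spans", []):
            for i, (start, end) in enumerate(encoding["offset_mapping"]):
                if start < char_end and end > char_start:
                    loss_mask[i] = False  # exclude these tokens from the loss
        return {"input_ids": input_ids, "loss_mask": loss_mask}
```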
3. Hugging Face Dataset Integration
Use Hugging Face datasets directly instead of the memory-mapped format.
Perform tokenization on-the-fly and dynamically extract spans for loss masking.
Require users to specify a unified chat format to delineate system prompts, user inputs, and model responses (illustrated below).
Advantages:
User-friendly and works well with smaller datasets already in Hugging Face format.
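For illustration, a unified chat format could be rendered and its masked spans derived on the fly roughly as follows; the role tags mirror the JSONL example above, and the helper is an assumption rather than an existing Fast-LLM or Hugging Face API:

```python
# Illustrative only: deriving masked character spans on the fly from a unified
# chat format. Not an existing API.
def render_chat(messages: list[dict]) -> tuple[str, list[tuple[int, int]]]:
    """Return the formatted text and the character spans to exclude from the loss."""
    text, masked_spans = "", []
    for message in messages:
        segment = f"<{message['role']}>\n{message['content']}\n</{message['role']}>\n"
        if message["role"] != "assistant":
            masked_spans.append((len(text), len(text) + len(segment)))
        text += segment
    return text, masked_spans
```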
Pipeline Integration
For all approaches, token spans for the excluded portions of the input should be passed as part of the kwargs in the model's forward function. This masking logic ensures that the loss computation is limited to the relevant portions (e.g., model responses).
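A rough illustration of such a masked loss, assuming the spans arrive per sequence as inclusive (start, end) token indices; argument names and shapes are assumptions rather than Fast-LLM's actual signatures:

```python
# Rough illustration of a masked loss driven by spans passed through `kwargs`;
# argument names and shapes are assumptions, not Fast-LLM's actual signatures.
import torch
import torch.nn.functional as F

def masked_language_modeling_loss(logits, labels, masked_token_spans=None):
    """Cross-entropy over all tokens except those inside `masked_token_spans`.

    `masked_token_spans` holds, per sequence, a list of inclusive
    (token_start, token_end) pairs to exclude from the loss.
    """
    loss_mask = torch.ones_like(labels, dtype=torch.bool)
    for batch_index, spans in enumerate(masked_token_spans or []):
        for token_start, token_end in spans:
            loss_mask[batch_index, token_start:token_end + 1] = False
    per_token_loss = F.cross_entropy(logits.transpose(1, 2), labels, reduction="none")
    # Average only over the tokens that contribute to the loss.
    return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
```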
🔄 Alternatives Considered
Ignore Loss Masking: Simplifies implementation but results in suboptimal performance.
Invert the Mask: Specify spans to include in the loss computation instead of excluding spans. However, this approach risks producing NaN losses if no spans are provided.
📈 Potential Benefits
Improved Model Performance: Aligns loss computation with best practices in SFT, enhancing both performance and generalization.
Flexibility: Supports multiple dataset formats for diverse user needs.
Ease of Use: Automates span extraction, reducing user overhead.
I agree with the preferred approach as well. As @tscholak mentioned, the input for the prepare command would be a dataset containing a text field and a list of character span tuples indicating the tokens for loss computation.
The prepare command would:
Tokenize the text
Convert the character spans to token spans
Add two extra fields per document in the .idx files:
number of spans
a list of spans, where each span contains the start and end position ids
In the absence of any character spans, the default behaviour would be to compute the loss on all tokens, i.e., the character span would be (0, len(text) - 1). This would ensure backwards compatibility.
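For concreteness, the per-document metadata and the default-span fallback could look roughly like this; the exact .idx layout and dtypes are implementation details, not decided here:

```python
# Rough sketch of the per-document metadata described above (a span count plus
# (start, end) token positions) and the default-span fallback; illustrative only.
import numpy as np

def document_span_record(token_spans: list[tuple[int, int]], num_tokens: int):
    """Return the span metadata to store alongside one document."""
    if not token_spans:
        # No spans provided: compute the loss on all tokens (backwards compatible).
        token_spans = [(0, num_tokens - 1)]
    span_count = np.array([len(token_spans)], dtype=np.int32)
    spans = np.array(token_spans, dtype=np.int32)  # shape (num_spans, 2)
    return span_count, spans
```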
Re: exact alignment between character span boundaries and token boundaries. After some thinking, I believe this cannot and should not be generally assumed. Instead, we should tokenize each segment that results from applying the spans to the text individually. This is more useful and correct because it matches how the model will have to consume and generate tokens at runtime: model prompts will end with the last character of a masked segment, and generation will start with the first character of an unmasked segment.
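A sketch of that segment-wise tokenization, treating the provided character spans as the unmasked (loss-bearing) segments; helper and variable names are illustrative only:

```python
# Sketch of segment-wise tokenization: split the text at the span boundaries,
# tokenize each segment on its own, and derive token spans from the segment
# lengths. Spans are taken as half-open [start, end) character ranges marking
# the loss-bearing (unmasked) segments.
def tokenize_by_segments(text: str, loss_char_spans, tokenizer):
    input_ids, loss_token_spans, cursor, position = [], [], 0, 0
    for char_start, char_end in sorted(loss_char_spans):
        for segment, in_loss in ((text[cursor:char_start], False),
                                 (text[char_start:char_end], True)):
            ids = tokenizer(segment, add_special_tokens=False)["input_ids"] if segment else []
            if in_loss and ids:
                loss_token_spans.append((position, position + len(ids) - 1))
            input_ids.extend(ids)
            position += len(ids)
        cursor = char_end
    # Any trailing text after the last loss span stays masked.
    if text[cursor:]:
        input_ids.extend(tokenizer(text[cursor:], add_special_tokens=False)["input_ids"])
    return input_ids, loss_token_spans
```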