Description
🧐 Problem Description
In supervised fine-tuning (SFT) for decoder-only models, training data sequences typically include a mix of system prompts, user inputs, and model responses. However, the loss during fine-tuning should only be computed on the model responses. This practice, known as loss masking, has been shown to significantly improve model performance and generalization.
Currently, Fast-LLM does not support propagating token spans for loss masking from the dataset to the loss function, which limits its effectiveness for SFT. Implementing this feature requires:
- A mechanism to identify and propagate spans corresponding to excluded portions (e.g., system prompts and user inputs).
- Modifications to the training pipeline to exclude these spans during the loss computation.
💡 Proposed Solution
To support loss masking, we propose the following approaches, ranked by priority:
1. Enhanced Memory-Mapped Indexed Format (Preferred Approach)
- Extend the existing memory-mapped dataset format to include optional fields for token spans that specify which portions of the input should be excluded from the loss computation.
- Enhance the `prepare` command to:
  - Convert Hugging Face datasets into the memory-mapped format.
  - Extract and preprocess character spans into token spans during dataset preparation (see the offset-mapping sketch after this approach's trade-offs).
- Assumptions:
  - The `text` field contains fully formatted prompts (system prompt, user prompt(s), and model response(s)), optionally annotated with special chat tokens or tags.
  - Character spans align cleanly with token spans.
Advantages:
- Retains the scalability and efficiency of the memory-mapped dataset format.
- Integrates seamlessly into Fast-LLM’s current workflow.
- Supports both pretraining and SFT use cases.
Trade-offs:
- Adds complexity to the dataset preparation process.
- Assumes accurate character-to-token alignment, which may cause issues in edge cases.
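As a concrete illustration of the character-to-token conversion that `prepare` would perform, here is a minimal sketch using the offset mapping of a Hugging Face fast tokenizer. The helper name `char_spans_to_token_spans` is hypothetical and not part of Fast-LLM; any tokenizer that exposes offset mappings would work the same way.

```python
# Hypothetical helper, not part of Fast-LLM: map character spans onto token
# index spans using the offset mapping of a Hugging Face *fast* tokenizer.
from transformers import AutoTokenizer

def char_spans_to_token_spans(text, char_spans, tokenizer):
    encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    offsets = encoding["offset_mapping"]  # one (char_start, char_end) pair per token
    token_spans = []
    for char_start, char_end in char_spans:
        # A token belongs to the span if its character range overlaps it at all.
        overlapping = [
            i for i, (tok_start, tok_end) in enumerate(offsets)
            if tok_start < char_end and tok_end > char_start
        ]
        if overlapping:
            token_spans.append((overlapping[0], overlapping[-1] + 1))  # [start, end)
    return token_spans

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(char_spans_to_token_spans("What is the capital of France?", [(0, 30)], tokenizer))
```

When the clean-alignment assumption breaks (a span boundary falls inside a token), the overlap rule above errs on the side of masking the whole token; the opposite convention is equally viable.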
2. JSONL Dataset Format
- Use a simpler dataset format with JSONL files, where each entry includes:
  - `"text"`: Fully formatted prompt containing system prompts, user inputs, and model responses.
  - `"masked_spans"`: A list of character span tuples to exclude from the loss computation.
- Create a dataset wrapper to read and process JSONL files during training (a sketch follows this approach's trade-offs).
Example JSONL Entry:

```json
{
  "text": "<system>\nYou are an assistant that provides concise and accurate answers.\n</system>\n<user>\nWhat is the capital of France?\n</user>\n<assistant>\nThe capital of France is Paris.\n</assistant>\n<user>\nCan you tell me more about Paris?\n</user>\n<assistant>\nParis is the capital city of France, known for its art, culture, and history.\n</assistant>",
  "masked_spans": [
    [0, 61],   // System message
    [62, 91],  // User prompt: "What is the capital of France?"
    [116, 146] // User prompt: "Can you tell me more about Paris?"
  ]
}
```
Advantages:
- Simplifies the pipeline, especially for smaller datasets.
- Avoids complex preprocessing and format conversions.
Trade-offs:
- Lacks scalability optimizations for larger datasets.
- Requires tight control over dataset formatting to avoid inconsistencies.
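To make the wrapper idea concrete, here is a minimal sketch of a JSONL-backed dataset that tokenizes on read and converts the character spans to token spans. The class name, file path, and returned dictionary keys are assumptions for illustration, not Fast-LLM interfaces.

```python
# Illustrative sketch only; class name and output keys are assumptions.
import json
import torch
from torch.utils.data import Dataset

class JsonlMaskedSpanDataset(Dataset):
    def __init__(self, path, tokenizer):
        # Expects a Hugging Face *fast* tokenizer so offset mappings are available.
        self.tokenizer = tokenizer
        with open(path, encoding="utf-8") as f:
            self.entries = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, index):
        entry = self.entries[index]
        encoding = self.tokenizer(
            entry["text"], return_offsets_mapping=True, add_special_tokens=False
        )
        offsets = encoding["offset_mapping"]
        token_spans = []
        for char_start, char_end in entry["masked_spans"]:
            overlapping = [
                i for i, (tok_start, tok_end) in enumerate(offsets)
                if tok_start < char_end and tok_end > char_start
            ]
            if overlapping:
                token_spans.append((overlapping[0], overlapping[-1] + 1))
        return {
            "input_ids": torch.tensor(encoding["input_ids"], dtype=torch.long),
            "masked_spans": token_spans,
        }

# Usage (hypothetical path and tokenizer):
# dataset = JsonlMaskedSpanDataset("sft_data.jsonl",
#                                  transformers.AutoTokenizer.from_pretrained("gpt2"))
```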
3. Hugging Face Dataset Integration
- Use Hugging Face datasets directly instead of the memory-mapped format.
- Perform tokenization on the fly and dynamically extract spans for loss masking (see the sketch after this approach's trade-offs).
- Require users to specify a unified chat format to delineate system prompts, user inputs, and model responses.
Advantages:
- User-friendly and works well with smaller datasets already in Hugging Face format.
- Reduces the need for format conversions.
Trade-offs:
- On-the-fly tokenization introduces runtime overhead.
- Not optimized for large-scale pretraining datasets.
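For illustration, if the unified chat format uses the `<assistant>`…`</assistant>` tags from the JSONL example above, the masked character spans could be derived on the fly rather than stored. The tag names and regex below are assumptions, not a format Fast-LLM prescribes.

```python
# Illustrative only: treat everything outside <assistant>...</assistant> blocks
# as excluded from the loss. Tag names follow the JSONL example above.
import re

ASSISTANT_BLOCK = re.compile(r"<assistant>.*?</assistant>", re.DOTALL)

def derive_masked_spans(text):
    """Return character spans covering everything outside assistant responses."""
    spans, cursor = [], 0
    for match in ASSISTANT_BLOCK.finditer(text):
        if match.start() > cursor:
            spans.append((cursor, match.start()))  # non-response text before this block
        cursor = match.end()
    if cursor < len(text):
        spans.append((cursor, len(text)))  # trailing non-response text
    return spans

example = "<user>\nWhat is the capital of France?\n</user>\n<assistant>\nParis.\n</assistant>"
print(derive_masked_spans(example))  # [(0, 46)]: the user turn is masked
```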
Pipeline Integration
For all approaches, token spans for excluded portions of the input should be passed as part of the `kwargs` in the model’s forward function:

```python
def forward(self, input_: torch.Tensor, kwargs: dict):
    masked_spans = kwargs.pop("masked_spans", [])
    ...
```
This masking logic ensures that the loss computation is limited to the relevant portions (e.g., model responses).
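One plausible way to consume the spans in the loss computation, assuming `masked_spans` holds token-index `(start, end)` pairs aligned with the labels (a sketch, not Fast-LLM's actual loss code):

```python
# Sketch, not Fast-LLM's actual loss code: positions inside masked_spans are
# set to the ignore_index so cross_entropy skips them.
import torch
import torch.nn.functional as F

def loss_with_masked_spans(logits, labels, masked_spans, ignore_index=-100):
    # logits: (seq_len, vocab_size); labels: (seq_len,)
    labels = labels.clone()
    for start, end in masked_spans:
        labels[start:end] = ignore_index  # excluded from the loss
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)

logits = torch.randn(8, 32)
labels = torch.randint(0, 32, (8,))
print(loss_with_masked_spans(logits, labels, [(0, 5)]))  # loss over the last 3 positions only
```

Note that if every position falls inside a masked span, `cross_entropy` averages over zero valid targets and returns NaN, the mirror image of the risk noted for the "invert the mask" alternative below.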
🔄 Alternatives Considered
- Ignore Loss Masking: Simplifies implementation but results in suboptimal performance.
- Invert the Mask: Specify spans to include in the loss computation instead of excluding spans. However, this approach risks producing NaN losses if no spans are provided.
📈 Potential Benefits
- Improved Model Performance: Aligns loss computation with best practices in SFT, enhancing both performance and generalization.
- Flexibility: Supports multiple dataset formats for diverse user needs.
- Ease of Use: Automates span extraction, reducing user overhead.
📝 Additional Context
- This feature is critical for SFT.
- Reference PR for RLHF GRPO: #20.
Deliverables
- Code Implementation:
- Extend dataset handling and pipeline logic for one or more proposed formats.
- Modify the loss computation to respect token spans for loss masking.
- Add tests.
- Documentation:
- Add clear guidelines for preparing datasets for loss masking.
- Provide examples for each supported format.