
[feat] Implement Loss Masking to Exclude Predefined Token Spans from LM Loss #109

Closed
@tscholak

Description


🧐 Problem Description

In supervised fine-tuning (SFT) of decoder-only models, training sequences typically interleave system prompts, user inputs, and model responses. However, the fine-tuning loss should be computed only on the model responses. This practice, known as loss masking, has been shown to significantly improve model performance and generalization.

Currently, Fast-LLM does not support propagating token spans for loss masking from the dataset to the loss function, which limits its effectiveness for SFT. Implementing this feature requires:

  1. A mechanism to identify and propagate spans corresponding to excluded portions (e.g., system prompts and user inputs).
  2. Modifications to the training pipeline to exclude these spans during the loss computation.

💡 Proposed Solution

To support loss masking, we propose the following approaches, ranked by priority:


1. Enhanced Memory-Mapped Indexed Format (Preferred Approach)

  • Extend the existing memory-mapped dataset format to include optional fields for token spans that specify which portions of the input should be excluded from the loss computation.
  • Enhance the prepare command to:
    • Convert Hugging Face datasets into the memory-mapped format.
    • Extract and preprocess character spans into token spans during dataset preparation (see the conversion sketch after the trade-offs below).
  • Assumptions:
    • The text field contains fully formatted prompts (system prompt, user prompt(s), and model response(s)), optionally annotated with special chat tokens or tags.
    • Character spans align cleanly with token spans.

Advantages:

  • Retains the scalability and efficiency of the memory-mapped dataset format.
  • Integrates seamlessly into Fast-LLM’s current workflow.
  • Supports both pretraining and SFT use cases.

Trade-offs:

  • Adds complexity to the dataset preparation process.
  • Assumes accurate character-to-token alignment, which may cause issues in edge cases.
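
To make the prepare-time conversion concrete, here is a minimal sketch of mapping character spans to token spans via the offset mapping of a Hugging Face fast tokenizer. The function name and the overlap rule are illustrative, not Fast-LLM's actual prepare code:

from transformers import AutoTokenizer

def char_spans_to_token_spans(text, char_spans, tokenizer):
    # Fast tokenizers report, for each token, the character range it covers.
    encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    offsets = encoding["offset_mapping"]
    token_spans = []
    for begin_char, end_char in char_spans:
        # Exclude any token that overlaps the masked character span.
        indices = [i for i, (start, end) in enumerate(offsets)
                   if start < end_char and end > begin_char]
        if indices:
            token_spans.append((indices[0], indices[-1] + 1))
    return token_spans

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer

Tokens that straddle a span boundary are exactly the misalignment edge case flagged above; this sketch resolves it by masking any overlapping token.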

2. JSONL Dataset Format

  • Use a simpler dataset format with JSONL files, where each entry includes:
    • "text": Fully formatted prompt containing system prompts, user inputs, and model responses.
    • "masked_spans": A list of character span tuples to exclude from the loss computation.
  • Create a dataset wrapper to read and process JSONL files during training (a minimal reader is sketched below).

Example JSONL Entry (pretty-printed here for readability; in the file each entry is a single line, and the inline comments and span offsets are illustrative):

{
    "text": "<system>\nYou are an assistant that provides concise and accurate answers.\n</system>\n<user>\nWhat is the capital of France?\n</user>\n<assistant>\nThe capital of France is Paris.\n</assistant>\n<user>\nCan you tell me more about Paris?\n</user>\n<assistant>\nParis is the capital city of France, known for its art, culture, and history.\n</assistant>",
    "masked_spans": [
        [0, 61],   // System message
        [62, 91],  // User prompt: "What is the capital of France?"
        [116, 146]  // User prompt: "Can you tell me more about Paris?"
    ]
}

Advantages:

  • Simplifies the pipeline, especially for smaller datasets.
  • Avoids complex preprocessing and format conversions.

Trade-offs:

  • Lacks scalability optimizations for larger datasets.
  • Requires tight control over dataset formatting to avoid inconsistencies.
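
A minimal reader for this format might look as follows. This wrapper is hypothetical; in practice it would feed the tokenizer and the span conversion sketched above:

import json

class JsonlMaskedDataset:
    """Yields (text, masked_spans) pairs from a JSONL file."""

    def __init__(self, path):
        self._path = path

    def __iter__(self):
        # One JSON object per line, with "text" and "masked_spans" fields.
        with open(self._path, encoding="utf-8") as f:
            for line in f:
                entry = json.loads(line)
                yield entry["text"], [tuple(span) for span in entry["masked_spans"]]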

3. Hugging Face Dataset Integration

  • Use Hugging Face datasets directly instead of the memory-mapped format.
  • Perform tokenization on-the-fly and dynamically extract spans for loss masking (span extraction is sketched below).
  • Require users to specify a unified chat format to delineate system prompts, user inputs, and model responses.

Advantages:

  • User-friendly and works well with smaller datasets already in Hugging Face format.
  • Reduces the need for format conversions.

Trade-offs:

  • On-the-fly tokenization introduces runtime overhead.
  • Not optimized for large-scale pretraining datasets.
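
For illustration, assuming the <assistant> tagging used in the JSONL example above, masked spans can be derived on the fly by excluding everything outside the assistant response bodies (a sketch, not a proposed API):

import re

ASSISTANT_BODY = re.compile(r"<assistant>\n?(.*?)\n?</assistant>", re.DOTALL)

def masked_char_spans(text):
    # Mask the system prompt, user turns, and the tags themselves;
    # only the assistant response bodies contribute to the loss.
    spans, cursor = [], 0
    for match in ASSISTANT_BODY.finditer(text):
        if match.start(1) > cursor:
            spans.append((cursor, match.start(1)))
        cursor = match.end(1)
    if cursor < len(text):
        spans.append((cursor, len(text)))
    return spans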

Pipeline Integration

For all approaches, token spans for excluded portions of the input should be passed as part of the kwargs in the model’s forward function:

def forward(self, input_: torch.Tensor, kwargs: dict):
    # Token spans to exclude from the loss, e.g. [(begin, end), ...] per sample.
    masked_spans = kwargs.pop("masked_spans", [])
    ...

Downstream, the masking logic uses these spans to restrict the loss computation to the relevant portions (e.g., the model responses).
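
One way the loss computation could consume these spans is PyTorch's ignore_index convention: set masked labels to -100 so they drop out of the cross-entropy. A sketch under that assumption, not Fast-LLM's actual loss code:

import torch
import torch.nn.functional as F

def masked_lm_loss(logits, labels, masked_spans):
    # logits: (sequence_length, vocab_size); labels: (sequence_length,)
    labels = labels.clone()
    for begin, end in masked_spans:
        labels[begin:end] = -100  # excluded from the loss
    return F.cross_entropy(logits, labels, ignore_index=-100)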


🔄 Alternatives Considered

  • Ignore Loss Masking: Simplifies implementation but results in suboptimal performance.
  • Invert the Mask: Specify spans to include in the loss computation instead of excluding spans. However, this approach risks producing NaN losses if no spans are provided.
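
The NaN risk is easy to reproduce: with the ignore_index convention sketched above, a sequence whose labels are all masked reduces to a mean over zero elements:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
labels = torch.full((4,), -100)  # no spans included, so every label is ignored
print(F.cross_entropy(logits, labels, ignore_index=-100))  # tensor(nan)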

📈 Potential Benefits

  • Improved Model Performance: Aligns loss computation with best practices in SFT, enhancing both performance and generalization.
  • Flexibility: Supports multiple dataset formats for diverse user needs.
  • Ease of Use: Automates span extraction, reducing user overhead.

📝 Additional Context

  • This feature is critical for SFT.
  • Reference PR for RLHF GRPO: #20.

Deliverables

  1. Code Implementation:
    • Extend dataset handling and pipeline logic for one or more proposed formats.
    • Modify the loss computation to respect token spans for loss masking.
    • Add tests.
  2. Documentation:
    • Add clear guidelines for preparing datasets for loss masking.
    • Provide examples for each supported format.
