
[feat] Implement Loss Masking to Exclude Predefined Token Spans from LM Loss #109

Open
tscholak opened this issue Jan 8, 2025 · 4 comments · May be fixed by #113
tscholak commented Jan 8, 2025

🧐 Problem Description

In supervised fine-tuning (SFT) for decoder-only models, training data sequences typically include a mix of system prompts, user inputs, and model responses. However, the loss during fine-tuning should only be computed on the model responses. This practice, known as loss masking, has been shown to significantly improve model performance and generalization.

Currently, Fast-LLM does not support propagating token spans for loss masking from the dataset to the loss function, which limits its effectiveness for SFT. Implementing this feature requires:

  1. A mechanism to identify and propagate spans corresponding to excluded portions (e.g., system prompts and user inputs).
  2. Modifications to the training pipeline to exclude these spans during the loss computation.

💡 Proposed Solution

To support loss masking, we propose the following approaches, ranked by priority:


1. Enhanced Memory-Mapped Indexed Format (Preferred Approach)

  • Extend the existing memory-mapped dataset format to include optional fields for token spans that specify which portions of the input should be excluded from the loss computation.
  • Enhance the prepare command to:
    • Convert Hugging Face datasets into the memory-mapped format.
    • Extract and preprocess character spans into token spans during dataset preparation (see the sketch below).
  • Assumptions:
    • The text field contains fully formatted prompts (system prompt, user prompt(s), and model response(s)), optionally annotated with special chat tokens or tags.
    • Character spans align cleanly with token spans.

Advantages:

  • Retains the scalability and efficiency of the memory-mapped dataset format.
  • Integrates seamlessly into Fast-LLM’s current workflow.
  • Supports both pretraining and SFT use cases.

Trade-offs:

  • Adds complexity to the dataset preparation process.
  • Assumes accurate character-to-token alignment, which may cause issues in edge cases.
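
As a rough sketch of the span-preprocessing step in the prepare command, character spans could be mapped to token spans with a fast tokenizer's offset mapping. This is illustrative only, not Fast-LLM's actual prepare code: the tokenizer name, the half-open span convention, and the overlap rule (a token is masked if its character range overlaps a masked character span) are assumptions.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder; any fast tokenizer with offset mapping works

def char_spans_to_token_spans(text, char_spans):
    # Convert half-open character spans [start, end) into inclusive token spans.
    offsets = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)["offset_mapping"]
    token_spans = []
    for char_start, char_end in char_spans:
        # A token belongs to the span if its character range overlaps the span.
        hits = [i for i, (start, end) in enumerate(offsets) if start < char_end and end > char_start]
        if hits:
            token_spans.append((hits[0], hits[-1]))
    return token_spans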

2. JSONL Dataset Format

  • Use a simpler dataset format with JSONL files, where each entry includes:
    • "text": Fully formatted prompt containing system prompts, user inputs, and model responses.
    • "masked_spans": A list of character span tuples to exclude from the loss computation.
  • Create a dataset wrapper to read and process JSONL files during training (see the sketch after the example entry below).

Example JSONL Entry:

{
    "text": "<system>\nYou are an assistant that provides concise and accurate answers.\n</system>\n<user>\nWhat is the capital of France?\n</user>\n<assistant>\nThe capital of France is Paris.\n</assistant>\n<user>\nCan you tell me more about Paris?\n</user>\n<assistant>\nParis is the capital city of France, known for its art, culture, and history.\n</assistant>",
    "masked_spans": [
        [0, 61],
        [62, 91],
        [116, 146]
    ]
}

Here the three spans cover the system message and the two user prompts ("What is the capital of France?" and "Can you tell me more about Paris?"); only the assistant responses contribute to the loss. In an actual JSONL file each entry occupies a single line; it is pretty-printed here for readability.
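
One possible shape for the dataset wrapper mentioned above is a thin iterable over the JSONL file. This is a hypothetical sketch, not an existing Fast-LLM class; it only assumes the field names from the example entry.

import json
from torch.utils.data import IterableDataset

class JsonlMaskedSpanDataset(IterableDataset):
    # Hypothetical wrapper: yields (text, masked_spans) pairs from a JSONL file.
    def __init__(self, path):
        self.path = path

    def __iter__(self):
        with open(self.path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                entry = json.loads(line)
                # A missing "masked_spans" field defaults to no masking, i.e. loss on all tokens.
                yield entry["text"], entry.get("masked_spans", [])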

Advantages:

  • Simplifies the pipeline, especially for smaller datasets.
  • Avoids complex preprocessing and format conversions.

Trade-offs:

  • Lacks scalability optimizations for larger datasets.
  • Requires tight control over dataset formatting to avoid inconsistencies.

3. Hugging Face Dataset Integration

  • Use Hugging Face datasets directly instead of the memory-mapped format.
  • Perform tokenization on-the-fly and dynamically extract spans for loss masking (see the sketch below).
  • Require users to specify a unified chat format to delineate system prompts, user inputs, and model responses.

Advantages:

  • User-friendly and works well with smaller datasets already in Hugging Face format.
  • Reduces the need for format conversions.

Trade-offs:

  • On-the-fly tokenization introduces runtime overhead.
  • Not optimized for large-scale pretraining datasets.
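
To make the on-the-fly variant concrete, the sketch below tokenizes role-tagged chat messages one by one and records a per-token loss mask. The "messages" field, the role names, and the file name are assumptions about the unified chat format, not an existing Fast-LLM or Hugging Face interface.

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder

def tokenize_with_loss_mask(example):
    # Tokenize each message on its own; only assistant tokens contribute to the loss.
    input_ids, loss_mask = [], []
    for message in example["messages"]:  # e.g. {"role": "user", "content": "..."}
        ids = tokenizer(message["content"], add_special_tokens=False)["input_ids"]
        input_ids.extend(ids)
        loss_mask.extend([message["role"] == "assistant"] * len(ids))
    return {"input_ids": input_ids, "loss_mask": loss_mask}

dataset = load_dataset("json", data_files="sft_data.jsonl", split="train")
dataset = dataset.map(tokenize_with_loss_mask, remove_columns=dataset.column_names)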

Pipeline Integration

For all approaches, token spans for excluded portions of the input should be passed as part of the kwargs in the model’s forward function:

def forward(self, input_: torch.Tensor, kwargs: dict):
    masked_spans = kwargs.pop("masked_spans", [])
    ...

The spans are then used to restrict the loss computation to the relevant portions of the sequence (e.g., model responses), as sketched below.
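
Downstream, one common way to apply such spans is to set the corresponding label positions to the ignore index of the cross-entropy loss. The sketch below illustrates that pattern for a single sequence; it is not Fast-LLM's actual loss implementation and assumes inclusive token spans.

import torch.nn.functional as F

def masked_lm_loss(logits, labels, masked_spans):
    # logits: (seq_len, vocab_size); labels: (seq_len,); masked_spans: inclusive (start, end) token indices.
    labels = labels.clone()
    for start, end in masked_spans:
        labels[start:end + 1] = -100  # -100 is the default ignore_index of cross_entropy
    return F.cross_entropy(logits, labels, ignore_index=-100)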


🔄 Alternatives Considered

  • Ignore Loss Masking: Simplifies implementation but results in suboptimal performance.
  • Invert the Mask: Specify spans to include in the loss computation instead of excluding spans. However, this approach risks producing NaN losses if no spans are provided.

📈 Potential Benefits

  • Improved Model Performance: Aligns loss computation with best practices in SFT, enhancing both performance and generalization.
  • Flexibility: Supports multiple dataset formats for diverse user needs.
  • Ease of Use: Automates span extraction, reducing user overhead.

📝 Additional Context

  • This feature is critical for SFT.
  • Reference PR for RLHF GRPO: #20.

Deliverables

  1. Code Implementation:
    • Extend dataset handling and pipeline logic for one or more proposed formats.
    • Modify the loss computation to respect token spans for loss masking.
    • Add tests.
  2. Documentation:
    • Add clear guidelines for preparing datasets for loss masking.
    • Provide examples for each supported format.
@tscholak tscholak added the enhancement New feature or request label Jan 8, 2025
@tscholak tscholak added this to the 0.3.0 milestone Jan 8, 2025
sohamparikh commented:

I agree with the preferred approach as well. As @tscholak mentioned, the input for the prepare command would be a dataset containing a text field and a list of character span tuples indicating the tokens for loss computation.
The prepare command would:

  1. Tokenize the text
  2. Convert the character spans to token spans
  3. Add two extra fields per document in the .idx files:
    • number of spans
    • a list of spans, where each span contains the start and end position ids
  4. In the absence of any character spans, the default behaviour would be to compute the loss on all tokens, i.e., the character span would be (0, len(text) - 1). This would ensure backwards compatibility.
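
For illustration, the per-document fields from steps 3 and 4 might be assembled like this during prepare. The layout is hypothetical (the actual .idx encoding may differ), and it follows this comment's framing, where spans mark the tokens the loss is computed on.

import numpy as np

def build_span_fields(documents):
    # documents: list of dicts with "num_tokens" and optional "token_spans" (inclusive (start, end) pairs).
    span_counts, flat_spans = [], []
    for doc in documents:
        # Default: a single span covering the whole document, i.e. loss on all tokens (backwards compatible).
        spans = doc.get("token_spans") or [(0, doc["num_tokens"] - 1)]
        span_counts.append(len(spans))
        flat_spans.extend(spans)
    return np.asarray(span_counts, dtype=np.int32), np.asarray(flat_spans, dtype=np.int32)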

tscholak commented Jan 9, 2025

Thanks @sohamparikh!

"a list of character span tuples indicating the tokens for loss computation."

Let's be clearer about this. There are two options:

  1. the spans define which tokens will be included in the loss (positive mask).
  2. the spans define which tokens will be excluded from the loss (negative mask).

I have a preference for 2, because if the list is empty, the default is clearly to include everything.

tscholak commented:

Re: exact alignment between character span boundaries and token boundaries. After some thinking, I believe this cannot and should not be generally assumed. Instead, we should tokenize each segment that results from applying the spans to the text individually. This will be much more useful and correct because it aligns with how the model has to consume and generate tokens at runtime: model prompts will end with the last character of a masked segment, and generation will start with the first character of an unmasked segment.
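
A sketch of this segment-wise approach (illustrative only; the tokenizer and the half-open, sorted, non-overlapping span convention are assumptions): split the text at span boundaries, tokenize each segment on its own, and derive token spans from the running token count.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder

def tokenize_by_segments(text, masked_char_spans):
    token_ids, masked_token_spans = [], []
    cursor = 0
    for char_start, char_end in masked_char_spans:
        # Unmasked segment before this span, tokenized on its own.
        token_ids.extend(tokenizer(text[cursor:char_start], add_special_tokens=False)["input_ids"])
        # Masked segment: tokenize separately and record its (inclusive) token span.
        span_ids = tokenizer(text[char_start:char_end], add_special_tokens=False)["input_ids"]
        if span_ids:
            masked_token_spans.append((len(token_ids), len(token_ids) + len(span_ids) - 1))
        token_ids.extend(span_ids)
        cursor = char_end
    # Trailing unmasked segment.
    token_ids.extend(tokenizer(text[cursor:], add_special_tokens=False)["input_ids"])
    return token_ids, masked_token_spans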

sohamparikh commented:

Just saw your comment; I came to the same conclusion and implemented it (changes are still WIP and untested):
9367fcd

@sohamparikh sohamparikh linked a pull request Jan 15, 2025 that will close this issue