Skip to content

Conversation

@trevor-m
Copy link
Collaborator

Motivation

#10758 introduced a change to only apply w13 -> w31 weight mapping with flashinfer_trtllm for NVFP4 quantization, however FP8 quantization is also supported and needs this mapping too, otherwise GSM8K accuracy went to 0.

Modifications

Apply w13 -> w31 for FP8 quantization also.

Accuracy Tests

Command

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-0528  --trust-remote-code --tp 8 --moe-runner-backend flashinfer_trtllm --quantization fp8
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port=30000

Before fix

Accuracy: 0.025
Invalid: 0.067
Latency: 73.347 s
Output throughput: 9207.244 token/s

After fix

Accuracy: 0.959
Invalid: 0.000
Latency: 25.951 s
Output throughput: 5488.445 token/s

Benchmarking and Profiling

N/A

Checklist

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @trevor-m, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical accuracy regression in the flashinfer_trtllm MoE implementation when utilizing FP8 quantization. It rectifies an oversight where a necessary w13 -> w31 weight mapping was exclusively applied for NVFP4 quantization, leading to near-zero accuracy for FP8 models. The changes ensure this crucial mapping is correctly applied for FP8, thereby restoring high accuracy and improving the overall performance and reliability of FP8 quantized models.

Highlights

  • FP8 Quantization Support: Extended the w13 -> w31 weight mapping logic to include FP8 quantization for flashinfer_trtllm MoE, which was previously only applied to NVFP4. This ensures correct weight handling for FP8 models.
  • Accuracy Restoration: Resolved a critical accuracy issue for FP8 quantized DSR1 models using flashinfer_trtllm MoE, improving GSM8K accuracy from 0.025 to 0.959, as demonstrated by provided test results.
  • Configuration Assertion Update: Updated the server_args assertion message to correctly reflect that both modelopt_fp4 and fp8 quantizations are supported for flashinfer_trtllm MoE, preventing misleading error messages.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@trevor-m trevor-m force-pushed the fi-trtllm-moe-fp8-acc branch from fe5b41a to e79924e Compare September 29, 2025 21:53
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly fixes a critical accuracy issue for DSR1 models using MoE with FP8 quantization and the flashinfer_trtllm backend. The change extends the w13 -> w31 weight mapping to ModelOptFp8MoEMethod, which was previously only applied for ModelOptNvFp4FusedMoEMethod. The fix is well-motivated, and the provided accuracy metrics demonstrate a significant improvement. The code change is logical and improves the robustness of the type checking. I have one minor suggestion to further improve code readability.

@trevor-m
Copy link
Collaborator Author

cc @zhyncs

@LorrinWWW Can you check if this is ok for the problem you fixed in #10758

@trevor-m trevor-m force-pushed the fi-trtllm-moe-fp8-acc branch from e79924e to 507332f Compare September 29, 2025 21:56
@zhyncs zhyncs merged commit a6cc86d into sgl-project:main Sep 30, 2025
70 of 80 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants