Skip to content

Conversation

@yuan-luo
Copy link
Collaborator

@yuan-luo yuan-luo commented Sep 29, 2025

Motivation

The deterministic feature introduced float64 data type in MLA test. The current moe_sum_reduce cuda kernel does not cover this data type. So as when using this new cuda kernel like in #10654 , the following srt test failure occurred:

python test/srt/hicache/test_hicache_mla.py
...
[2025-09-29 02:36:08] INFO:     127.0.0.1:35010 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-29 02:36:08] INFO:     127.0.0.1:35022 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-29 02:36:08] INFO:     127.0.0.1:35044 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-29 02:36:08] INFO:     127.0.0.1:35068 - "POST /v1/chat/completions HTTP/1.1" 200 OK
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:30<00:00,  2.12it/s]
Writing report to /tmp/mmlu_deepseek-ai_DeepSeek-Coder-V2-Lite-Instruct.html
{'other': np.float64(0.0), 'other:std': np.float64(0.0), 'score:std': np.float64(0.0), 'stem': np.float64(0.0), 'stem:std': np.float64(0.0), 'humanities': np.float64(0.0), 'humanities:std': np.float64(0.0), 'social_sciences': np.float64(0.0), 'social_sciences:std': np.float64(0.0), 'score': np.float64(0.0)}
Writing results to /tmp/mmlu_deepseek-ai_DeepSeek-Coder-V2-Lite-Instruct.json
Total latency: 30.266 s
Score: 0.000
E
======================================================================
ERROR: test_mgsm_en (__main__.TestHierarchicalMLA)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/utils.py", line 2259, in retry
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/sglang/test/test_utils.py", line 1437, in <lambda>
    lambda: super(CustomTestCase, self)._callTestMethod(method),
AssertionError: np.float64(0.0) not greater than 0.8

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/sglang/test/test_utils.py", line 1436, in _callTestMethod
    retry(
  File "/usr/local/lib/python3.10/dist-packages/sglang/srt/utils.py", line 2262, in retry
    raise Exception(f"retry() exceed maximum number of retries.")
Exception: retry() exceed maximum number of retries.

This PR is to fix this problem, supporting bfloat16, float32 and float64. It also fixes the unsupported runtime topk_num in the previous version, which enters a dead-end when topk_num is not 2/4/8/9 and returns unexpected result.
With the fix the test case passed.

$python test/srt/hicache/test_hicache_mla.py
Auto-configed device: cuda
command=python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --hicache-ratio 2 --device cuda --host 127.0.0.1 --port 8000
INFO 09-30 11:16:21 [__init__.py:216] Automatically detected platform cuda.
All deep_gemm operations loaded successfully!
W0930 11:16:21.852000 351367 site-packages/torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0930 11:16:21.852000 351367 site-packages/torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
config.json: 1.52kB [00:00, 7.95MB/s]
configuration_deepseek.py: 10.3kB [00:00, 44.8MB/s]
......

 58%|██████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                       | 37/64 [00:08<00:03,  6.83it/s][2025-09-30 11:18:13] Decode batch. #running-req: 9, #token: 3558, token usage: 0.00, cuda graph: True, gen throughput (token/s): 980.25, #queue-req: 0, 
[2025-09-30 11:18:13] INFO:     127.0.0.1:35528 - "POST /v1/chat/completions HTTP/1.1" 200 OK
 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                               | 52/64 [00:09<00:01, 10.24it/s][2025-09-30 11:18:13] Decode batch. #running-req: 8, #token: 3367, token usage: 0.00, cuda graph: True, gen throughput (token/s): 980.47, #queue-req: 0, 
[2025-09-30 11:18:13] INFO:     127.0.0.1:35588 - "POST /v1/chat/completions HTTP/1.1" 200 OK
 84%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                          | 54/64 [00:09<00:01,  9.94it/s][2025-09-30 11:18:13] INFO:     127.0.0.1:35484 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-30 11:18:13] INFO:     127.0.0.1:35446 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-30 11:18:13] Decode batch. #running-req: 5, #token: 2306, token usage: 0.00, cuda graph: True, gen throughput (token/s): 778.97, #queue-req: 0, 
[2025-09-30 11:18:13] INFO:     127.0.0.1:35368 - "POST /v1/chat/completions HTTP/1.1" 200 OK
 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋             | 59/64 [00:09<00:00, 10.59it/s][2025-09-30 11:18:14] Decode batch. #running-req: 4, #token: 1995, token usage: 0.00, cuda graph: True, gen throughput (token/s): 621.07, #queue-req: 0, 
[2025-09-30 11:18:14] INFO:     127.0.0.1:35518 - "POST /v1/chat/completions HTTP/1.1" 200 OK
 95%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████        | 61/64 [00:10<00:00,  9.69it/s][2025-09-30 11:18:14] INFO:     127.0.0.1:35386 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-30 11:18:14] Decode batch. #running-req: 2, #token: 952, token usage: 0.00, cuda graph: True, gen throughput (token/s): 574.43, #queue-req: 0, 
[2025-09-30 11:18:14] Decode batch. #running-req: 2, #token: 544, token usage: 0.00, cuda graph: True, gen throughput (token/s): 404.95, #queue-req: 0, 
[2025-09-30 11:18:14] INFO:     127.0.0.1:35584 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-30 11:18:14] INFO:     127.0.0.1:35450 - "POST /v1/chat/completions HTTP/1.1" 200 OK
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:10<00:00,  6.08it/s]
Writing report to /tmp/mmlu_deepseek-ai_DeepSeek-Coder-V2-Lite-Instruct.html
{'other': np.float64(0.75), 'other:std': np.float64(0.4330127018922193), 'score:std': np.float64(0.49607837082461076), 'stem': np.float64(0.5454545454545454), 'stem:std': np.float64(0.49792959773196915), 'humanities': np.float64(0.4782608695652174), 'humanities:std': np.float64(0.4995271866554807), 'social_sciences': np.float64(0.5), 'social_sciences:std': np.float64(0.5), 'score': np.float64(0.5625)}
Writing results to /tmp/mmlu_deepseek-ai_DeepSeek-Coder-V2-Lite-Instruct.json
Total latency: 10.529 s
Score: 0.562
.
----------------------------------------------------------------------
Ran 2 tests in 118.597s

OK

Modifications

Accuracy Tests

Benchmarking and Profiling

$python ./benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
Running correctness verification for bfloat16...
✅ All implementations match
Running correctness verification for float64...
✅ All implementations match (Triton skipped for float64)

Running performance benchmark for bfloat16...
sum_scaled_performance_bfloat16:
    num_tokens    Original  TorchCompile  TritonKernel  CudaKernel
0          1.0   12.288000     15.520000      9.152000    7.776000
1          2.0   14.368000     20.064000      9.248000    7.808000
2          4.0   14.304000     19.888001      9.216000    7.808000
3          8.0   14.272000     19.264000      9.584000    7.840000
4         16.0   14.208000     18.464001     10.208000    8.160000
5         32.0   14.304000     17.983999     10.240000    9.088000
6         64.0   14.720000     24.800001     10.240000    9.760000
7        128.0   16.352000     40.927999     10.336000   11.808000
8        256.0   20.000000     75.199999     12.288000   13.888000
9        512.0   26.528001    142.719999     16.384000   16.031999
10      1024.0   40.991999    277.999997     25.599999   25.632000
11      2048.0   67.423999    547.648013     40.544000   41.088000
12      4096.0  114.720002   1087.520003     67.520000   68.672001

Checklist

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the moe_sum_reduce CUDA kernel by introducing comprehensive support for float64 data types, which was a critical missing feature leading to test failures. Beyond float64, the changes ensure more robust handling of bfloat16 and float32 and provide a solution for dynamic topk_num values, improving the kernel's flexibility and reliability. The accompanying benchmarking script has also been updated to reflect and verify these expanded capabilities.

Highlights

  • Float64 Support: The moe_sum_reduce CUDA kernel now supports float64 data types, resolving a previous limitation that caused MLA test failures.
  • Expanded Data Type Compatibility: The kernel has been updated to robustly handle bfloat16, float32, and float64 data types.
  • Runtime topk_num Fix: Addressed and fixed the issue of unsupported runtime topk_num values in the previous kernel version by introducing fallback kernels.
  • Improved Type Handling in CUDA Kernels: Refactored type conversion and accumulation logic in CUDA kernels using at::opmath_type for better precision and flexibility across different floating-point types.
  • Enhanced Benchmarking Script: The benchmarking script (benchmark_sum_scale.py) has been updated to include correctness verification and performance benchmarks for both bfloat16 and float64.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request successfully adds float64 support to the moe_sum_reduce CUDA kernel and its corresponding Triton benchmark, addressing a failure in MLA tests. It also enhances flexibility by introducing a fallback mechanism for runtime topk_num values. The use of opmath_t in the CUDA kernel for handling various precisions is a solid implementation choice. My review highlights a critical precision loss issue in the Triton kernel for float64 inputs and provides medium-severity suggestions to refactor duplicated code in the CUDA kernel's dispatch logic, which will enhance maintainability.

@yuan-luo yuan-luo changed the title Support float64 moe_sum_reduce cuda kernel [sgl-kernel] Support float64 moe_sum_reduce cuda kernel Sep 29, 2025
@Alcanderian
Copy link
Collaborator

Using torch.tensor(xxx, ...) will raise a 2us memcpy h2d kernel, and it may not compatible with cuda graph?

@yuan-luo
Copy link
Collaborator Author

yuan-luo commented Sep 30, 2025

Using torch.tensor(xxx, ...) will raise a 2us memcpy h2d kernel, and it may not compatible with cuda graph?

This is a ref kernel to verify the float64 acc result. In product code, the triton kernel is unchanged.
The cuda kernel is the one concerned.

@yuan-luo yuan-luo force-pushed the float64_moe_sum_reduce branch 2 times, most recently from c7325f1 to 340a228 Compare October 1, 2025 08:12
@yuan-luo yuan-luo changed the title [sgl-kernel] Support float64 moe_sum_reduce cuda kernel [WIP][sgl-kernel] Support float64 moe_sum_reduce cuda kernel Oct 2, 2025
@yuan-luo yuan-luo force-pushed the float64_moe_sum_reduce branch from 9551d2f to c31c659 Compare October 2, 2025 01:16
@yuan-luo yuan-luo changed the title [WIP][sgl-kernel] Support float64 moe_sum_reduce cuda kernel [sgl-kernel] Support float64 moe_sum_reduce cuda kernel Oct 2, 2025
@yuan-luo yuan-luo force-pushed the float64_moe_sum_reduce branch from 2519942 to 02bd5e1 Compare October 2, 2025 04:41
@BBuf BBuf added the run-ci label Oct 3, 2025
@yuan-luo yuan-luo enabled auto-merge (squash) October 6, 2025 11:04
@yuan-luo yuan-luo added run-ci and removed run-ci labels Oct 7, 2025
@yuan-luo yuan-luo merged commit 4f42c8c into sgl-project:main Oct 7, 2025
171 of 183 checks passed
ch-tiger1 pushed a commit to ch-tiger1/sglang that referenced this pull request Oct 9, 2025
lpc0220 pushed a commit to lpc0220/sglang that referenced this pull request Oct 29, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants