
Conversation

Alcanderian

No description provided.

@LyricZhao
Collaborator

Thanks for this! My only comment is to skip launching inside LaunchRuntime::launch instead of adding a macro to every kernel call site.
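
A minimal sketch of what that could look like, assuming all kernels already funnel through a common LaunchRuntime::launch. The CRTP shape, the launch_impl hook, and the DG_JIT_SKIP_LAUNCH environment variable are illustrative assumptions, not the actual DeepGEMM API:

#include <cstdlib>

// Shared launch path: the compile-only check lives here once, so individual
// kernel call sites no longer need a MAYBE_LAUNCH macro.
template <typename Derived>
struct LaunchRuntime {
    template <typename Runtime, typename Args>
    static void launch(const Runtime& runtime, const Args& args) {
        // Warmup / compile-only mode: the kernel has already been JIT-built
        // by the caller, so skip the actual device launch.
        if (std::getenv("DG_JIT_SKIP_LAUNCH") != nullptr)
            return;
        Derived::launch_impl(runtime, args);
    }
};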

Contributor

@zhyncs zhyncs left a comment


LGTM and this works well on sgl-kernel. Please address @LyricZhao's comment. Thanks!

@rainj-me
Contributor

rainj-me commented Sep 3, 2025

@zhyncs @Alcanderian After diving deep into the sglang code and the DeepGEMM code, I believe we are spreading logic from sglang into DeepGEMM.

  1. From the sglang source code, the ops that need to be warmed up (compiled) are:
  • fp8_gemm_nt
  • m_grouped_fp8_gemm_nt_contiguous
  • m_grouped_fp8_gemm_nt_masked
  2. From the DeepGEMM source code, these ops are compiled over the 'nk' dims, which means the 'm' dim is set to 0 (refer to the code).
  3. Based on 1 and 2, there is no reason to warm up the 'm' dim over the full 1-32K range in the code.

@Alcanderian
Author

@zhyncs @Alcanderian After diving deep into the sglang code and the DeepGEMM code, I believe we are spreading logic from sglang into DeepGEMM.

  1. From the sglang source code, the ops that need to be warmed up (compiled) are:
  • fp8_gemm_nt
  • m_grouped_fp8_gemm_nt_contiguous
  • m_grouped_fp8_gemm_nt_masked
  2. From the DeepGEMM source code, these ops are compiled over the 'nk' dims, which means the 'm' dim is set to 0 (refer to the code).
  3. Based on 1 and 2, there is no reason to warm up the 'm' dim over the full 1-32K range in the code.

Actually, the reason we want to do warmup for the M=1-32K range is that the get_best_config function (static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, ...)) determines tiling and some other strategies based on the value of M. Different ranges of M trigger the compilation of different kernels. If this compilation were to occur during the serving process, it would cause stuttering.
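
To illustrate the idea (a hypothetical sketch only; the real heuristic lives in get_best_config and is more involved than this): the chosen tiling depends on the runtime M, so each M range that maps to a new config triggers a separate JIT compilation the first time it is hit, which is what the warmup loop tries to front-load.

#include <cstdint>
#include <cstdio>
#include <set>
#include <utility>

struct GemmConfig { uint32_t block_m, block_n; };

// Hypothetical stand-in for the M-dependent part of the heuristic.
static GemmConfig pick_config(uint32_t m, uint32_t n) {
    const uint32_t block_m = (m <= 64) ? 64 : 128;      // small-M vs. large-M tiling
    const uint32_t block_n = (n % 256 == 0) ? 256 : 128;
    return {block_m, block_n};
}

int main() {
    std::set<std::pair<uint32_t, uint32_t>> compiled;   // distinct kernel configs that would be JIT-built
    for (uint32_t m = 1; m <= 32768; ++m) {
        const auto cfg = pick_config(m, 7168);
        compiled.insert({cfg.block_m, cfg.block_n});
    }
    std::printf("distinct kernel configs for M in [1, 32768]: %zu\n", compiled.size());
    return 0;
}

Each distinct (block_m, block_n) combination corresponds to a separate template instantiation and hence a separate JIT compilation.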

@rainj-me
Contributor

rainj-me commented Sep 3, 2025

@zhyncs @Alcanderian After diving deep into the sglang code and the DeepGEMM code, I believe we are spreading logic from sglang into DeepGEMM.

  1. From the sglang source code, the ops that need to be warmed up (compiled) are:
  • fp8_gemm_nt
  • m_grouped_fp8_gemm_nt_contiguous
  • m_grouped_fp8_gemm_nt_masked
  2. From the DeepGEMM source code, these ops are compiled over the 'nk' dims, which means the 'm' dim is set to 0 (refer to the code).
  3. Based on 1 and 2, there is no reason to warm up the 'm' dim over the full 1-32K range in the code.

Actually, the reason we want to do warmup for the M=1-32K range is that the get_best_config function (static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, ...)) determines tiling and some other strategies based on the value of M. Different ranges of M trigger the compilation of different kernels. If this compilation were to occur during the serving process, it would cause stuttering.

I get it, but from my testing the warmup only generates 1 kernel for the 1-32K range for a given N/K binding. Maybe there is some issue with the get_best_config result?

@rainj-me
Contributor

rainj-me commented Sep 3, 2025

Just ran a test printing out m, n, k together with the generated code's hash, and the combinations are:

~ grep 'hash:' result.log | awk -F',' '{print $2","$3","$4}' |sort -n |uniq
 n:2112, k:7168, hash:ff553c76cf29a593aa64d3978324b464
 n:6144, k:1536, hash:7fcb2cba293de06aed40e104a84329ed
 n:7168, k:4096, hash:cd6abf55b6918d860e44b6eda2227393
 n:7168, k:4608, hash:2162a9ac1ef3ec57bcb61026f14a22c1
 n:8192, k:512, hash:178c4989723bf2cf6b216774fe798a1f
 n:9216, k:7168, hash:d7b3bb3a047b35716ae4dc8a3f73b678

For the full output, check the attachment result.log.

And the code diff used to produce result.log is:

--- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
+++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
@@ -8,6 +8,7 @@
 #include "../../utils/exception.hpp"
 #include "../../utils/format.hpp"
 #include "../../utils/math.hpp"
+#include "../../utils/hash.hpp"
 #include "../heuristics/sm100.hpp"
 #include "runtime_utils.hpp"
 
@@ -142,6 +143,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
         .tensor_map_d = tensor_map_d
     };
     const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
+    std::cout << fmt::format("m:{}, n:{}, k:{}, hash:{}", m, n, k, get_hex_digest(code)) << std::endl;
     const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
     MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
 }
