[feat] enable jit_kernels to skip launch #182
base: main
Conversation
Thanks for this! My only comment is to skip launching in …
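For context, a skip-launch guard could look something like the sketch below. Only the macro name `MAYBE_LAUNCH` appears in the diff later in this thread; its definition and the `set_skip_launch` toggle here are assumptions, not the PR's actual code:

```cpp
// Hypothetical skip-launch guard. When skipping is enabled, kernels are still
// JIT-compiled and cached, but the actual launch is elided -- which is all a
// warmup pass needs.
#include <atomic>

static std::atomic<bool> g_skip_launch{false};

// Hypothetical toggle; the real PR may expose this differently.
inline void set_skip_launch(const bool skip) {
    g_skip_launch.store(skip);
}

#define MAYBE_LAUNCH(launch_expr)      \
    do {                               \
        if (!g_skip_launch.load())     \
            launch_expr;               \
    } while (false)
```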
LGTM, and this works well on sgl-kernel. Please address @LyricZhao's comment. Thanks!
@zhyncs @Alcanderian After diving deep into the sglang code and the DeepGEMM code, I believe the intent is to move this logic from sglang into DeepGEMM.
Actually, the reason we want to do warmup for the M=1–32K range is that the get_best_config function can pick a different config (and therefore a different compiled kernel) for each M.
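For illustration, a warmup pass built on a skip-launch toggle might look like the sketch below. This is hedged: `fp8_gemm`, `set_skip_launch`, and the tensor dtypes/shapes are stand-ins for whatever API DeepGEMM actually exposes:

```cpp
// Hypothetical warmup: with launches skipped, stepping M through 1..32768
// forces the JIT to compile every kernel variant that get_best_config would
// select for this (N, K), without paying for real GEMM launches.
#include <torch/torch.h>

void warmup_fp8_gemm(const int n, const int k, const int max_m = 32768) {
    set_skip_launch(true);  // assumed toggle, see the sketch above
    const auto fp8  = torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA);
    const auto bf16 = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
    auto a = torch::empty({max_m, k}, fp8);
    auto b = torch::empty({n, k}, fp8);
    auto d = torch::empty({max_m, n}, bf16);
    for (int m = 1; m <= max_m; ++m) {
        // Slice an m-row view so each M value hits its own best config.
        fp8_gemm(a.narrow(0, 0, m), b, d.narrow(0, 0, m));  // hypothetical entry point
    }
    set_skip_launch(false);
}
```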
I get it, but in my testing the warmup only generates one kernel for the whole 1–32K range for a given (N, K) pair. Maybe there is an issue with the get_best_config result?
I just ran a test printing out m, n, k together with the generated code's hash. Deduplicating the combinations with

```
grep 'hash:' result.log | awk -F',' '{print $2","$3","$4}' | sort -n | uniq
```

gives

```
n:2112, k:7168, hash:ff553c76cf29a593aa64d3978324b464
n:6144, k:1536, hash:7fcb2cba293de06aed40e104a84329ed
n:7168, k:4096, hash:cd6abf55b6918d860e44b6eda2227393
n:7168, k:4608, hash:2162a9ac1ef3ec57bcb61026f14a22c1
n:8192, k:512, hash:178c4989723bf2cf6b216774fe798a1f
n:9216, k:7168, hash:d7b3bb3a047b35716ae4dc8a3f73b678
```

For the full output, check the attachment result.log. The diff used to produce result.log is:

```diff
--- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
+++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
@@ -8,6 +8,7 @@
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
+#include "../../utils/hash.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
@@ -142,6 +143,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
.tensor_map_d = tensor_map_d
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
+ std::cout << fmt::format("m:{}, n:{}, k:{}, hash:{}", m, n, k, get_hex_digest(code)) << std::endl;
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
}
```
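One way to narrow this down is to probe the heuristic directly and count how many distinct configs it actually picks across the M range. This is a sketch: the real signature of get_best_config and the `describe` pretty-printer are assumptions:

```cpp
// Hypothetical probe: if this prints 1 for a given (N, K), the single compiled
// kernel observed above is expected behavior; if it prints more, the warmup or
// hashing path is collapsing variants that should be distinct.
#include <cstdio>
#include <set>
#include <string>

int main() {
    const int n = 7168, k = 4096;  // one of the (N, K) pairs from the log above
    std::set<std::string> distinct;
    for (int m = 1; m <= 32768; ++m) {
        const auto config = get_best_config(m, n, k);  // assumed signature
        distinct.insert(describe(config));             // assumed pretty-printer
    }
    std::printf("distinct configs for n=%d, k=%d: %zu\n", n, k, distinct.size());
    return 0;
}
```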