diff --git a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py b/sgl-kernel/benchmark/bench_sum_scale.py
similarity index 75%
rename from benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
rename to sgl-kernel/benchmark/bench_sum_scale.py
index ec6b2f2f219..ad9621ee1f1 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
+++ b/sgl-kernel/benchmark/bench_sum_scale.py
@@ -1,9 +1,17 @@
+import os
+
 import torch
 import triton
 import triton.language as tl
 from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
 from triton.testing import do_bench
 
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
 
 @triton.jit
 def _moe_sum_reduce_kernel(
@@ -38,7 +46,6 @@ def _moe_sum_reduce_kernel(
     base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
     accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
 
-
     for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
         tile = tl.load(
             base_ptrs + i * input_stride_1,
@@ -110,7 +117,7 @@ def compute_sum_scaled_compiled(
     return out
 
 
-def get_benchmark():
+def get_benchmark(dtype=torch.bfloat16):
     num_tokens_range = [2**i for i in range(0, 13)]
 
     @triton.testing.perf_report(
@@ -122,7 +129,7 @@ def get_benchmark():
             line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"],
             styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
             ylabel="us",
-            plot_name="sum_scaled_performance",
+            plot_name=f"sum_scaled_performance_{str(dtype).split('.')[-1]}",
             args={},
         )
     )
@@ -174,8 +181,8 @@ def benchmark(num_tokens, version):
     return benchmark
 
 
-def verify_correctness(num_tokens=1024):
-    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
+def verify_correctness(num_tokens=1024, dtype=torch.bfloat16):
+    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=dtype)
     scaling_factor = 0.3
 
     out_baseline = torch.empty_like(x[:, 0])
@@ -184,33 +191,57 @@ def verify_correctness(num_tokens=1024):
     out_compiled = torch.empty_like(out_baseline)
     compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
 
-    out_triton = torch.empty_like(out_baseline)
-    moe_sum_reduce_triton(x, out_triton, scaling_factor)
-
     out_cuda = torch.empty_like(out_baseline)
     moe_sum_reduce_cuda(x, out_cuda, scaling_factor)
 
-    if (
-        torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2)
-        and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2)
-        and torch.allclose(out_baseline, out_cuda, atol=1e-2, rtol=1e-2)
-    ):
-        print("✅ All implementations match")
+    triton_skipped = dtype == torch.float64
+    if not triton_skipped:
+        out_triton = torch.empty_like(out_baseline)
+        moe_sum_reduce_triton(x, out_triton, scaling_factor)
+
+    if dtype == torch.float64:
+        atol, rtol = 1e-12, 1e-12
+    elif dtype == torch.float32:
+        atol, rtol = 1e-6, 1e-6
+    else:  # bfloat16 / float16
+        atol, rtol = 1e-2, 1e-2
+
+    ok_compiled = torch.allclose(out_baseline, out_compiled, atol=atol, rtol=rtol)
+    ok_cuda = torch.allclose(out_baseline, out_cuda, atol=atol, rtol=rtol)
+    ok_triton = (
+        True
+        if triton_skipped
+        else torch.allclose(out_baseline, out_triton, atol=atol, rtol=rtol)
+    )
+
+    if ok_compiled and ok_triton and ok_cuda:
+        msg = "✅ All implementations match"
+        if triton_skipped:
+            msg += " (Triton skipped for float64)"
+        print(msg)
     else:
         print("❌ Implementations differ")
         print(
             f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
        )
-        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
+        if not triton_skipped:
+            print(
+                f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}"
+            )
         print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")
 
 
 if __name__ == "__main__":
-    print("Running correctness verification...")
-    verify_correctness()
+    print("Running correctness verification for bfloat16...")
+    verify_correctness(dtype=torch.bfloat16)
+
+    # CI environment uses simplified parameters
+    if not IS_CI:
+        print("Running correctness verification for float64...")
+        verify_correctness(dtype=torch.float64)
 
-    print("\nRunning performance benchmark...")
-    benchmark = get_benchmark()
+    print("\nRunning performance benchmark for bfloat16...")
+    benchmark = get_benchmark(dtype=torch.bfloat16)
     benchmark.run(
         print_data=True,
         # save_path="./configs/benchmark_ops/sum_scaled/"
diff --git a/sgl-kernel/csrc/moe/moe_sum_reduce.cu b/sgl-kernel/csrc/moe/moe_sum_reduce.cu
index 6e5454336f2..791ce620b29 100644
--- a/sgl-kernel/csrc/moe/moe_sum_reduce.cu
+++ b/sgl-kernel/csrc/moe/moe_sum_reduce.cu
@@ -1,3 +1,4 @@
+#include <ATen/OpMathType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
@@ -12,25 +13,36 @@
 #include "utils.h"
 
 template <typename T>
-__device__ __forceinline__ float to_float(T x) {
-  return static_cast<float>(x);
-}
+using opmath_t = at::opmath_type<T>;
 
-template <>
-__device__ __forceinline__ float to_float(half x) {
-  return __half2float(x);
+template <typename T>
+__device__ __forceinline__ opmath_t<T> to_acc(T x) {
+  return static_cast<opmath_t<T>>(x);
 }
 
 template <typename T>
-__device__ __forceinline__ T from_float(float x) {
+__device__ __forceinline__ T from_acc(opmath_t<T> x) {
   return static_cast<T>(x);
 }
 
 template <>
-__device__ __forceinline__ half from_float<half>(float x) {
+__device__ __forceinline__ opmath_t<at::Half> to_acc(at::Half x) {
+  return __half2float(__nv_half(x));
+}
+template <>
+__device__ __forceinline__ at::Half from_acc<at::Half>(opmath_t<at::Half> x) {
   return __float2half_rn(x);
 }
 
+template <>
+__device__ __forceinline__ opmath_t<at::BFloat16> to_acc(at::BFloat16 x) {
+  return __bfloat162float(__nv_bfloat16(x));
+}
+template <>
+__device__ __forceinline__ at::BFloat16 from_acc<at::BFloat16>(opmath_t<at::BFloat16> x) {
+  return __float2bfloat16_rn(x);
+}
+
 template <typename T>
 __device__ __forceinline__ T ldg_cg(const T* p) {
   return __ldg(p);
 }
@@ -111,22 +123,22 @@ __global__ void moe_sum_reduce_kernel_warp_token_topk(
     const int64_t stride_token,
     const int64_t stride_topk,
     const int64_t out_stride_token,
-    const float scale) {
+    const opmath_t<scalar_t> scale) {
   const int warp_id = threadIdx.x / 32;
   const int lane = threadIdx.x % 32;
   const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
   if (t >= token_num) return;
 
   for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
-    float acc = 0.f;
+    opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
     const int64_t base = t * stride_token + d;
 #pragma unroll
     for (int k = 0; k < TOPK; ++k) {
-      acc += to_float(ldg_cg(&x[base + (int64_t)k * stride_topk]));
+      acc += to_acc(x[base + (int64_t)k * stride_topk]);
     }
     acc *= scale;
-    y[t * out_stride_token + d] = from_float<scalar_t>(acc);
+    y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
   }
 }
 
@@ -139,23 +151,79 @@ __global__ void moe_sum_reduce_kernel(
     const int64_t stride_token,
     const int64_t stride_topk,
     const int64_t out_stride_token,
-    const float scale) {
+    const opmath_t<scalar_t> scale) {
   for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
     for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
       const int64_t base = t * stride_token + d;
-      float acc = 0.f;
+      opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
 #pragma unroll
       for (int k = 0; k < TOPK; ++k) {
-        acc += to_float(x[base + (int64_t)k * stride_topk]);
+        acc += to_acc(x[base + (int64_t)k * stride_topk]);
       }
       acc *= scale;
-      y[t * out_stride_token + d] = from_float<scalar_t>(acc);
+      y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
     }
   }
 }
 
+// -------------------- general-topk fallback kernels --------------------
+// small-token
+template <typename scalar_t>
+__global__ void moe_sum_reduce_kernel_general(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int64_t token_num,
+    const int64_t hidden_dim,
+    const int64_t stride_token,
+    const int64_t stride_topk,
+    const int64_t out_stride_token,
+    const int topk_num,
+    const opmath_t<scalar_t> scale) {
+  for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
+    for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
+      const int64_t base = t * stride_token + d;
+      opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
+#pragma unroll 1
+      for (int k = 0; k < topk_num; ++k) {
+        acc += to_acc(x[base + (int64_t)k * stride_topk]);
+      }
+      acc *= scale;
+      y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
+    }
+  }
+}
+
+// warp-per-token
+template <typename scalar_t, int WARPS_PER_BLOCK>
+__global__ void moe_sum_reduce_kernel_warp_token_general(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int64_t token_num,
+    const int64_t hidden_dim,
+    const int64_t stride_token,
+    const int64_t stride_topk,
+    const int64_t out_stride_token,
+    const int topk_num,
+    const opmath_t<scalar_t> scale) {
+  const int warp_id = threadIdx.x / 32;
+  const int lane = threadIdx.x % 32;
+  const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id;
+  if (t >= token_num) return;
+
+  for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) {
+    opmath_t<scalar_t> acc = opmath_t<scalar_t>(0);
+    const int64_t base = t * stride_token + d;
+#pragma unroll 1
+    for (int k = 0; k < topk_num; ++k) {
+      acc += to_acc(x[base + (int64_t)k * stride_topk]);
+    }
+    acc *= scale;
+    y[t * out_stride_token + d] = from_acc<scalar_t>(acc);
+  }
+}
+
 void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor) {
   TORCH_CHECK(input.is_cuda(), "input must be CUDA tensor");
   TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
@@ -175,8 +243,6 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
   const int64_t in_stride_topk = input.stride(1);
   const int64_t out_stride_token = output.stride(0);
 
-  const float scale = static_cast<float>(routed_scaling_factor);
-
   auto stream = at::cuda::getCurrentCUDAStream();
 
   const bool fast_bf16_vec_ok = (input.scalar_type() == at::kBFloat16) && (token_num > 256) && (hidden_dim % 8 == 0);
@@ -198,6 +264,7 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     auto stream = at::cuda::getCurrentCUDAStream();
 
+    const float scale = static_cast<float>(routed_scaling_factor);
     moe_sum_reduce_warp_per_token_vec_kernel<<<grid, block, 0, stream>>>(
         reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
         reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),
         token_num,
         hidden_dim,
         in_stride_token,
         in_stride_topk,
@@ -209,32 +276,12 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
         out_stride_token,
         scale);
 
-    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel launch failed");
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (bf16 vec) launch failed");
     return;
   }
 
   const bool per_token_use_one_warp = (token_num > 128);
 
-  auto dispatch_topk = [&](auto&& launch_kernel) {
-    switch (topk_num) {
-      case 2:
-        launch_kernel(std::integral_constant<int, 2>{});
-        break;
-      case 4:
-        launch_kernel(std::integral_constant<int, 4>{});
-        break;
-      case 8:
-        launch_kernel(std::integral_constant<int, 8>{});
-        break;
-      case 9:
-        launch_kernel(std::integral_constant<int, 9>{});
-        break;
-      default:
-        launch_kernel(std::integral_constant<int, 0>{});
-        break;
-    }
-  };
-
   if (!per_token_use_one_warp) {
     // ---------- small-token ----------
     const int block_size = 256;
@@ -245,28 +292,55 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     dim3 block(block_size);
     dim3 grid(static_cast<int>(grid_x), static_cast<int>(grid_y));
 
+#define LAUNCH_SMALL_TOKEN_KERNEL(TOPK)                               \
+  moe_sum_reduce_kernel<scalar_t_, TOPK><<<grid, block, 0, stream>>>( \
+      input.data_ptr<scalar_t_>(),                                    \
+      output.data_ptr<scalar_t_>(),                                   \
+      token_num,                                                      \
+      hidden_dim,                                                     \
+      in_stride_token,                                                \
+      in_stride_topk,                                                 \
+      out_stride_token,                                               \
+      scale);
+
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_small_token", [&] {
           using scalar_t_ = scalar_t;
-
-          auto lauch_small_token_kernel = [&](auto topk_c) {
-            constexpr int TK = decltype(topk_c)::value;
-
-            moe_sum_reduce_kernel<scalar_t_, TK><<<grid, block, 0, stream>>>(
-                input.data_ptr<scalar_t_>(),
-                output.data_ptr<scalar_t_>(),
-                token_num,
-                hidden_dim,
-                in_stride_token,
-                in_stride_topk,
-                out_stride_token,
-                scale);
-          };
-          dispatch_topk(lauch_small_token_kernel);
+          using acc_t_ = opmath_t<scalar_t_>;
+          const acc_t_ scale = static_cast<acc_t_>(routed_scaling_factor);
+
+          switch (topk_num) {
+            case 2:
+              LAUNCH_SMALL_TOKEN_KERNEL(2);
+              break;
+            case 4:
+              LAUNCH_SMALL_TOKEN_KERNEL(4);
+              break;
+            case 8:
+              LAUNCH_SMALL_TOKEN_KERNEL(8);
+              break;
+            case 9:
+              LAUNCH_SMALL_TOKEN_KERNEL(9);
+              break;
+            default:  // launch general kernel
+              moe_sum_reduce_kernel_general<scalar_t_><<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t_>(),
+                  output.data_ptr<scalar_t_>(),
+                  token_num,
+                  hidden_dim,
+                  in_stride_token,
+                  in_stride_topk,
+                  out_stride_token,
+                  static_cast<int>(topk_num),
+                  scale);
+          }
         });
+#undef LAUNCH_SMALL_TOKEN_KERNEL
+
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (small-token) launch failed");
   } else {
-    // ---------- warp-token ----------
+    // ---------- warp-per-token ----------
     constexpr int WARPS_PER_BLOCK = 4;
     constexpr int THREADS = WARPS_PER_BLOCK * 32;
 
@@ -279,25 +353,51 @@ void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling
     dim3 block(THREADS);
     dim3 grid(static_cast<int>(gx), static_cast<int>(gy));
 
+#define LAUNCH_WARP_PER_TOKEN_KERNEL(TOPK)                                                              \
+  moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TOPK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>( \
+      input.data_ptr<scalar_t_>(),                                                                      \
+      output.data_ptr<scalar_t_>(),                                                                     \
+      token_num,                                                                                        \
+      hidden_dim,                                                                                       \
+      in_stride_token,                                                                                  \
+      in_stride_topk,                                                                                   \
+      out_stride_token,                                                                                 \
+      scale);
+
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_large_token", [&] {
           using scalar_t_ = scalar_t;
-
-          auto launch_large_token_kernel = [&](auto topk_c) {
-            constexpr int TK = decltype(topk_c)::value;
-
-            moe_sum_reduce_kernel_warp_token_topk<scalar_t_, TK, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
-                input.data_ptr<scalar_t_>(),
-                output.data_ptr<scalar_t_>(),
-                token_num,
-                hidden_dim,
-                in_stride_token,
-                in_stride_topk,
-                out_stride_token,
-                scale);
-          };
-          dispatch_topk(launch_large_token_kernel);
+          using acc_t_ = opmath_t<scalar_t_>;
+          const acc_t_ scale = static_cast<acc_t_>(routed_scaling_factor);
+
+          switch (topk_num) {
+            case 2:
+              LAUNCH_WARP_PER_TOKEN_KERNEL(2);
+              break;
+            case 4:
+              LAUNCH_WARP_PER_TOKEN_KERNEL(4);
+              break;
+            case 8:
+              LAUNCH_WARP_PER_TOKEN_KERNEL(8);
+              break;
+            case 9:
+              LAUNCH_WARP_PER_TOKEN_KERNEL(9);
+              break;
+            default:  // launch general kernel
+              moe_sum_reduce_kernel_warp_token_general<scalar_t_, WARPS_PER_BLOCK><<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t_>(),
+                  output.data_ptr<scalar_t_>(),
+                  token_num,
+                  hidden_dim,
+                  in_stride_token,
+                  in_stride_topk,
+                  out_stride_token,
+                  static_cast<int>(topk_num),
+                  scale);
+          }
         });
+#undef LAUNCH_WARP_PER_TOKEN_KERNEL
+
+    TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel (warp-token) launch failed");
   }
 
-  TORCH_CHECK(cudaGetLastError() == cudaSuccess, "CUDA kernel launch failed");
 }