@@ -1,9 +1,17 @@
+import os
+
 import torch
 import triton
 import triton.language as tl
 from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
 from triton.testing import do_bench
 
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
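GitHub Actions sets both CI=true and GITHUB_ACTIONS=true, so either branch of this check fires there. Note that the guard matches only the literal string "true" (case-insensitively); a standalone sketch of the behavior, including the CI=1 spelling some providers use, which this check does not treat as CI:

import os

for value in ("true", "TRUE", "1", "false"):
    os.environ["CI"] = value
    detected = os.getenv("CI", "false").lower() == "true"
    print(f"CI={value!r} -> {detected}")  # "1" -> False: only "true" matches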
+
 
 @triton.jit
 def _moe_sum_reduce_kernel(
@@ -38,7 +46,6 @@ def _moe_sum_reduce_kernel(
     base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
 
     accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
-
     for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
         tile = tl.load(
             base_ptrs + i * input_stride_1,
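For orientation: the kernel walks the top-k axis with a software-pipelined loop (num_stages=NUM_STAGE), accumulating in float32 before scaling and storing. A plain-PyTorch sketch of the same semantics, with illustrative names not taken from this file:

import torch

def moe_sum_reduce_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # x: (num_tokens, topk, hidden). Accumulate in float32, mirroring the
    # kernel's tl.zeros(..., dtype=tl.float32) accumulator, then scale and
    # cast back to the input dtype.
    return (x.to(torch.float32).sum(dim=1) * scale).to(x.dtype)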
@@ -110,7 +117,7 @@ def compute_sum_scaled_compiled(
     return out
 
 
-def get_benchmark():
+def get_benchmark(dtype=torch.bfloat16):
    num_tokens_range = [2**i for i in range(0, 13)]
 
     @triton.testing.perf_report(
@@ -122,7 +129,7 @@ def get_benchmark():
             line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"],
             styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
             ylabel="us",
-            plot_name="sum_scaled_performance",
+            plot_name=f"sum_scaled_performance_{str(dtype).split('.')[-1]}",
             args={},
         )
     )
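Because str(torch.bfloat16).split('.')[-1] evaluates to "bfloat16", each dtype now gets its own plot_name, so saved plots from successive runs no longer overwrite one another. A sketch of sweeping several dtypes with the parameterized factory (the float16/float32 entries are illustrative, assuming this file's namespace):

for dt in (torch.bfloat16, torch.float16, torch.float32):
    get_benchmark(dtype=dt).run(print_data=True)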
@@ -174,8 +181,8 @@ def benchmark(num_tokens, version):
     return benchmark
 
 
-def verify_correctness(num_tokens=1024):
-    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
+def verify_correctness(num_tokens=1024, dtype=torch.bfloat16):
+    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=dtype)
     scaling_factor = 0.3
 
     out_baseline = torch.empty_like(x[:, 0])
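Indexing x[:, 0] drops the top-k axis, so the output buffers are allocated directly at the reduced shape (num_tokens, 4096). A quick shape check with toy sizes:

import torch

x = torch.randn(4, 9, 16)        # (num_tokens, topk, hidden) in miniature
out = torch.empty_like(x[:, 0])  # topk axis gone: shape (4, 16)
assert out.shape == (4, 16)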
@@ -184,33 +191,60 @@ def verify_correctness(num_tokens=1024):
     out_compiled = torch.empty_like(out_baseline)
     compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
 
-    out_triton = torch.empty_like(out_baseline)
-    moe_sum_reduce_triton(x, out_triton, scaling_factor)
-
     out_cuda = torch.empty_like(out_baseline)
     moe_sum_reduce_cuda(x, out_cuda, scaling_factor)
 
-    if (
-        torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2)
-        and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2)
-        and torch.allclose(out_baseline, out_cuda, atol=1e-2, rtol=1e-2)
-    ):
-        print("✅ All implementations match")
+    triton_skipped = dtype == torch.float64
+    if not triton_skipped:
+        out_triton = torch.empty_like(out_baseline)
+        moe_sum_reduce_triton(x, out_triton, scaling_factor)
+
+    if dtype == torch.float64:
+        atol, rtol = 1e-12, 1e-12
+    elif dtype == torch.float32:
+        atol, rtol = 1e-6, 1e-6
+    else:  # bfloat16 / float16
+        atol, rtol = 1e-2, 1e-2
+
+    ok_compiled = torch.allclose(out_baseline, out_compiled, atol=atol, rtol=rtol)
+    ok_cuda = torch.allclose(out_baseline, out_cuda, atol=atol, rtol=rtol)
+    ok_triton = (
+        True
+        if triton_skipped
+        else torch.allclose(out_baseline, out_triton, atol=atol, rtol=rtol)
+    )
+
+    if ok_compiled and ok_triton and ok_cuda:
+        msg = "✅ All implementations match"
+        if triton_skipped:
+            msg += " (Triton skipped for float64)"
+        print(msg)
     else:
         print("❌ Implementations differ")
         print(
             f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
         )
-        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
+        if not triton_skipped:
+            print(
+                f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}"
+            )
         print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")
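The dtype-dependent tolerances track precision: torch.allclose passes when |a - b| <= atol + rtol * |b| elementwise, and bfloat16 carries only 8 significant bits (roughly two to three decimal digits), so 1e-2 is about as tight as cross-implementation bf16 comparisons can be, while float32 and float64 support the much tighter 1e-6 and 1e-12 thresholds. A small illustration:

import torch

a = torch.tensor([1.0, 100.0])
b = a * (1 + 3e-3)  # ~0.3% relative error, typical of bf16 round-off
print(torch.allclose(a, b, atol=1e-2, rtol=1e-2))  # True: within the bf16 tolerance
print(torch.allclose(a, b, atol=1e-6, rtol=1e-6))  # False: fails at the fp32 tolerance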


 if __name__ == "__main__":
-    print("Running correctness verification...")
-    verify_correctness()
+    print("Running correctness verification for bfloat16...")
+    verify_correctness(dtype=torch.bfloat16)
 
+    # CI environment uses simplified parameters
+    if not IS_CI:
+        print("Running correctness verification for float64...")
+        verify_correctness(dtype=torch.float64)
+
-    print("Running correctness verification for float64...")
-    verify_correctness(dtype=torch.float64)
-
-    print("\nRunning performance benchmark...")
-    benchmark = get_benchmark()
+    print("\nRunning performance benchmark for bfloat16...")
+    benchmark = get_benchmark(dtype=torch.bfloat16)
     benchmark.run(
         print_data=True,
         # save_path="./configs/benchmark_ops/sum_scaled/"