diff --git a/examples/hopper_matmul/matmul_v2.py b/examples/hopper_matmul/matmul_v2.py
index b1bac5c6..789f515a 100644
--- a/examples/hopper_matmul/matmul_v2.py
+++ b/examples/hopper_matmul/matmul_v2.py
@@ -118,7 +118,9 @@ def main():
     headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
     workloads = [
         [4096, 4096, 4096],
-        # [128, 48, 16],
+        [4096, 4096, 14336],
+        [8192, 8192, 8192],
+        [10240, 10240, 10240],
     ]
 
     rows = []
diff --git a/examples/hopper_matmul/matmul_v3.py b/examples/hopper_matmul/matmul_v3.py
new file mode 100644
index 00000000..9122540d
--- /dev/null
+++ b/examples/hopper_matmul/matmul_v3.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+
+import pandas
+import tilus
+import torch
+from tilus import float16, float32, int32, uint32
+from tilus.utils import benchmark_func, cdiv
+
+
+@tilus.autotune("num_stages", [2, 3, 4, 5, 6, 7])
+@tilus.autotune(
+    "block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]]
+)
+@tilus.autotune("block_k", [16, 32, 64])
+class MatmulWGMMAV3(tilus.Script):
+    def __init__(
+        self,
+        num_stages,
+        block_m,
+        block_n,
+        block_k,
+    ):
+        super().__init__()
+        self.num_stages = num_stages
+        self.block_m = block_m
+        self.block_n = block_n
+        self.block_k = block_k
+
+    def __call__(
+        self,
+        m_size: int32,
+        n_size: int,
+        k_size: int,
+        a_ptr: ~float16,
+        b_ptr: ~float16,
+        c_ptr: ~float16,
+    ):
+        self.attrs.blocks = [
+            cdiv(m_size, self.block_m),
+            cdiv(n_size, self.block_n),
+        ]
+        self.attrs.warps = 5
+
+        block_m, block_n, block_k = self.block_m, self.block_n, self.block_k
+        offset_m: int32 = block_m * self.blockIdx.x
+        offset_n: int32 = block_n * self.blockIdx.y
+
+        ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
+        gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
+        sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k])
+        sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_n, block_k])
+        acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)
+
+        consumer_barriers = self.mbarrier.alloc(count=[2 for _ in range(self.num_stages)])
+        producer_barriers = self.mbarrier.alloc(
+            count=[128 for _ in range(self.num_stages)]
+        )
+
+        with self.thread_group(thread_begin=128, num_threads=32):
+            stage: int32 = 0
+            producer_phases = self.register_tensor(
+                dtype=uint32, shape=[self.num_stages], init=1
+            )
+            for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages):
+                self.mbarrier.wait(producer_barriers[stage], phase=producer_phases[stage])
+                producer_phases[stage] ^= 1
+                with self.single_thread():
+                    self.tma.global_to_shared(
+                        src=ga,
+                        dst=sa[stage],
+                        offsets=[offset_m, offset_k],
+                        mbarrier=consumer_barriers[stage],
+                    )
+                    self.tma.global_to_shared(
+                        src=gb,
+                        dst=sb[stage],
+                        offsets=[offset_n, offset_k],
+                        mbarrier=consumer_barriers[stage],
+                    )
+                stage = (stage + 1) % self.num_stages
+
+            for _ in self.range(min(self.num_stages, cdiv(k_size, self.block_k))):
+                self.mbarrier.wait(
+                    producer_barriers[stage], phase=producer_phases[stage]
+                )  # wait until the stage is ready to be filled
+                producer_phases[stage] ^= 1
+                stage = (stage + 1) % self.num_stages
+
+        with self.thread_group(thread_begin=0, num_threads=128):
+            consumer_phases = self.register_tensor(
+                dtype=uint32, shape=[self.num_stages], init=0
+            )
+            stage: int32 = 0
+            for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages):
+                self.mbarrier.wait(consumer_barriers[stage], phase=consumer_phases[stage])
+                consumer_phases[stage] ^= 1
+                self.wgmma.fence()
+                self.wgmma.mma(sa[stage], sb[stage].transpose(), acc)
+                self.wgmma.commit_group()
+                self.wgmma.wait_group(0)
+                self.mbarrier.arrive(producer_barriers[stage])
+                stage = (stage + 1) % self.num_stages
+        self.sync()
+        casted_acc = self.cast(acc, dtype=float16)
+        gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
+        self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])
+
+
+def main():
+    headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
+    workloads = [
+        [4096, 4096, 4096],
+        [4096, 4096, 14336],
+        [8192, 8192, 8192],
+        [10240, 10240, 10240],
+    ]
+
+    rows = []
+    for m, n, k in workloads:
+        matmul = MatmulWGMMAV3()
+
+        a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
+        b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
+        c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
+        c_expect = a @ b.T
+        matmul(m, n, k, a, b, c_actual)
+        torch.cuda.synchronize()
+
+        # check correctness
+        torch.testing.assert_close(c_expect, c_actual)
+
+        # benchmark
+        for name, func in [
+            ("torch", lambda: torch.matmul(a, b.T, out=c_expect)),
+            ("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
+        ]:
+            latency = benchmark_func(func, warmup=5, repeat=20)
+            tflops = 2 * m * n * k / latency * 1e-9
+            rows.append([m, n, k, name, latency, tflops])
+
+    df = pandas.DataFrame(rows, columns=headers)
+    print(df)
+
+
+# %%
+
+if __name__ == "__main__":
+    main()
diff --git a/python/tilus/backends/codegen.py b/python/tilus/backends/codegen.py
index 129dc6a9..46a8d574 100644
--- a/python/tilus/backends/codegen.py
+++ b/python/tilus/backends/codegen.py
@@ -278,7 +278,6 @@ def visit_ForStmt(self, stmt: ForStmt) -> None:
     def visit_ThreadGroupStmt(self, stmt: ThreadGroupStmt) -> None:
         # check the validity of the thread group
         parent_num_threads = self.thread_group_stack.num_threads[-1]
-        assert parent_num_threads % stmt.num_threads == 0
         assert 0 <= stmt.thread_begin and stmt.thread_begin + stmt.num_threads <= parent_num_threads
 
         self.builder.comment(
diff --git a/python/tilus/ir/utils/thread_group_stack.py b/python/tilus/ir/utils/thread_group_stack.py
index a5ddcb28..5c4389ff 100644
--- a/python/tilus/ir/utils/thread_group_stack.py
+++ b/python/tilus/ir/utils/thread_group_stack.py
@@ -43,8 +43,6 @@ def push(self, thread_begin: int, num_threads: int) -> None:
         depth = self.stack_depth()
         if depth > 0:
             parent_num_threads = self.num_threads[-1]
-            if parent_num_threads % num_threads != 0:
-                raise ValueError("group_size must be a divisor of the parent group_size")
             if thread_begin < 0 or thread_begin + num_threads > parent_num_threads:
                 raise ValueError(
                     "thread_begin must be in [0, parent_num_threads - num_threads), got thread_begin={}, num_threads={}, parent_num_threads={}".format(
diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py
index 28881c0a..76560ef8 100644
--- a/tests/examples/test_examples.py
+++ b/tests/examples/test_examples.py
@@ -58,6 +58,7 @@
     ("hopper_matmul", "matmul_v0.py", nvgpu_sm90a),
     ("hopper_matmul", "matmul_v1.py", nvgpu_sm90a),
     ("hopper_matmul", "matmul_v2.py", nvgpu_sm90a),
+    ("hopper_matmul", "matmul_v3.py", nvgpu_sm90a),
     # quantization examples (SM 8.0+)
     ("quantization", "matmul_a16wx.py", nvgpu_sm80),
     # flash attention decode examples (SM 8.0+)