From d578631cda45b8dcf8be39fbfb75d6169e73bee5 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 2 Dec 2025 17:45:19 -0800 Subject: [PATCH 01/16] upd Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_tma.py | 118 +++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 examples/hopper_matmul/matmul_tma.py diff --git a/examples/hopper_matmul/matmul_tma.py b/examples/hopper_matmul/matmul_tma.py new file mode 100644 index 00000000..2af4015c --- /dev/null +++ b/examples/hopper_matmul/matmul_tma.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import math + +import pandas +import tilus +import torch +from tilus import float16, float32, int32, uint32 +from tilus.utils import benchmark_func, cdiv + + +tilus.option.cache_dir("./cache") +tilus.option.debug.dump_ir(True) + +@tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) +@tilus.autotune("block_k", [16, 32, 16]) +class MatmulTMA(tilus.Script): + def __init__( + self, + block_m, + block_n, + block_k, + ): + super().__init__() + self.block_m = block_m + self.block_n = block_n + self.block_k = block_k + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + self.attrs.blocks = [ + cdiv(m_size, self.block_m), + cdiv(n_size, self.block_n), + ] + self.attrs.warps = 4 + + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + offset_m: int32 = block_m * self.blockIdx.x + offset_n: int32 = block_n * self.blockIdx.y + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) + sa = self.shared_tensor(dtype=float16, shape=[block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[block_n, block_k]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + tma_barrier = self.mbarrier.alloc(count=[2]) + phase: uint32 = 0 + + for offset_k in range(0, k_size, block_k): + # issue asynchronous copy instructions to load tiles of A and B + with self.single_thread(): + self.tma.global_to_shared(src=ga, dst=sa, offsets=[offset_m, offset_k], mbarrier=tma_barrier) + self.tma.global_to_shared(src=gb, dst=sb, offsets=[offset_n, offset_k], mbarrier=tma_barrier) + self.mbarrier.wait(tma_barrier, phase=phase) + + # synchronize threads in the block to ensure data is available in shared memory + self.sync() + + # a = self.load_shared(sa) + # b = self.load_shared(sb) + self.dot(sa, sb.transpose(), acc, out=acc) + self.sync() + phase ^= 1 + + self.free_shared(sa) + self.free_shared(sb) + + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + +def main(): + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + workloads = [ + [4096, 4096, 4096], + ] + + rows = [] + for m, n, k in workloads: + matmul = MatmulTMA() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b.T + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + # check correctness + torch.testing.assert_close(c_expect, c_actual) + + # benchmark + for name, func in [ + ("torch", lambda: torch.matmul(a, b, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + + +# %% + +if __name__ == "__main__": + main() From 5ae6d16c56c5dba53d54185df4db02ceb3aba85f Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 2 Dec 2025 17:53:00 -0800 Subject: [PATCH 02/16] add instructions Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_wgmma.py | 121 +++++++++++++++++++++ python/tilus/ir/builders/stmt_builder.py | 21 ++++ python/tilus/lang/instructions/__init__.py | 2 + python/tilus/lang/instructions/wgmma.py | 41 +++++++ 4 files changed, 185 insertions(+) create mode 100644 examples/hopper_matmul/matmul_wgmma.py create mode 100644 python/tilus/lang/instructions/wgmma.py diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_wgmma.py new file mode 100644 index 00000000..c1da9fa7 --- /dev/null +++ b/examples/hopper_matmul/matmul_wgmma.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import math + +import pandas +import tilus +import torch +from tilus import float16, float32, int32, uint32 +from tilus.utils import benchmark_func, cdiv + + +tilus.option.cache_dir("./cache") +tilus.option.debug.dump_ir(True) + +@tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) +@tilus.autotune("block_k", [16, 32, 16]) +class MatmulTMA(tilus.Script): + def __init__( + self, + block_m, + block_n, + block_k, + ): + super().__init__() + self.block_m = block_m + self.block_n = block_n + self.block_k = block_k + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + self.attrs.blocks = [ + cdiv(m_size, self.block_m), + cdiv(n_size, self.block_n), + ] + self.attrs.warps = 4 + + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + offset_m: int32 = block_m * self.blockIdx.x + offset_n: int32 = block_n * self.blockIdx.y + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) + sa = self.shared_tensor(dtype=float16, shape=[block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[block_n, block_k]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + tma_barrier = self.mbarrier.alloc(count=[2]) + phase: uint32 = 0 + + for offset_k in range(0, k_size, block_k): + # issue asynchronous copy instructions to load tiles of A and B + with self.single_thread(): + self.tma.global_to_shared(src=ga, dst=sa, offsets=[offset_m, offset_k], mbarrier=tma_barrier) + self.tma.global_to_shared(src=gb, dst=sb, offsets=[offset_n, offset_k], mbarrier=tma_barrier) + self.mbarrier.wait(tma_barrier, phase=phase) + + # synchronize threads in the block to ensure data is available in shared memory + self.sync() + + # a = self.load_shared(sa) + # b = self.load_shared(sb) + self.wgmma.fence() + self.wgmma.mma(sa, sb, acc) + self.wgmma.commit_group() + self.wgmma.wait_group(1) + self.sync() + phase ^= 1 + + self.free_shared(sa) + self.free_shared(sb) + + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + +def main(): + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + workloads = [ + [4096, 4096, 4096], + ] + + rows = [] + for m, n, k in workloads: + matmul = MatmulTMA() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b.T + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + # check correctness + torch.testing.assert_close(c_expect, c_actual) + + # benchmark + for name, func in [ + ("torch", lambda: torch.matmul(a, b, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + + +# %% + +if __name__ == "__main__": + main() diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index 9bd24804..9ee0e745 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -1324,6 +1324,27 @@ def tcgen05_mma_ts(self, a: TMemoryTensor, b: SharedTensor, d: TMemoryTensor) -> inst = Tcgen05MmaTSInst.create(a=a, b=b, d=d) self.append(inst) + # wgmma + def wgmma_fence(self) -> None: + inst = WgmmaFenceInst.create() + self.append(inst) + + def wgmma_commit_group(self) -> None: + inst = WgmmaCommitGroupInst.create() + self.append(inst) + + def wgmma_wait_group(self, n: Expr | int) -> None: + inst = WgmmaWaitGroupInst.create(n=n) + self.append(inst) + + def wgmma_mma_ss(self, a: SharedTensor, b: SharedTensor, d: RegisterTensor) -> None: + inst = WgmmaMmaSSInst.create(a=a, b=b, d=d) + self.append(inst) + + def wgmma_mma_rs(self, a: RegisterTensor, b: SharedTensor, d: RegisterTensor) -> None: + inst = WgmmaMmaRSInst.create(a=a, b=b, d=d) + self.append(inst) + # annotations def annotate_layout(self, tensor: RegisterTensor | SharedTensor, layout: RegisterLayout | SharedLayout) -> None: inst = AnnotateLayoutInst.create(tensor=tensor, layout=layout) diff --git a/python/tilus/lang/instructions/__init__.py b/python/tilus/lang/instructions/__init__.py index 0d7dafe0..801e1ba9 100644 --- a/python/tilus/lang/instructions/__init__.py +++ b/python/tilus/lang/instructions/__init__.py @@ -22,6 +22,7 @@ from .root import RootInstructionGroup from .tcgen05 import Tcgen05InstructionGroup from .tma import TmaInstructionGroup +from .wgmma import WgmmaInstructionGroup class InstructionInterface(RootInstructionGroup): @@ -30,3 +31,4 @@ class InstructionInterface(RootInstructionGroup): mbarrier = BarrierInstructionGroup() clc = ClusterLaunchControlInstructionGroup() cluster = BlockClusterInstructionGroup() + wgmma = WgmmaInstructionGroup() \ No newline at end of file diff --git a/python/tilus/lang/instructions/wgmma.py b/python/tilus/lang/instructions/wgmma.py new file mode 100644 index 00000000..5dbeb3a1 --- /dev/null +++ b/python/tilus/lang/instructions/wgmma.py @@ -0,0 +1,41 @@ +import contextlib +from typing import Optional, Sequence + +from hidet.ir.expr import Expr +from hidet.ir.type import DataType + +from tilus.ir.inst import InstructionError +from tilus.ir.tensor import RegisterTensor, SharedTensor + +from .root import InstructionGroup + + +class WgmmaInstructionGroup(InstructionGroup): + def fence(self) -> None: + self._builder.wgmma_fence() + + def commit_group(self) -> None: + self._builder.wgmma_commit_group() + + def wait_group(self, n: Union[Expr, int]) -> None: + self._builder.wgmma_wait_group(n) + + def mma(self, a: SharedTensor | RegisterTensor, b: SharedTensor, d: RegisterTensor) -> None: + if isinstance(a, SharedTensor): + if len(a.shape) != 2: + raise InstructionError("mma requires 2D shared tensors, got shape {}".format(a.shape)) + if len(b.shape) != 2: + raise InstructionError("mma requires 2D shared tensors, got shape {}".format(b.shape)) + if len(d.shape) != 2: + raise InstructionError("mma requires 2D register tensors, got shape {}".format(d.shape)) + self._builder.wgmma_mma_ss(a, b, d) + elif isinstance(a, RegisterTensor): + if len(a.shape) != 2: + raise InstructionError("mma requires 2D register tensors, got shape {}".format(a.shape)) + if len(b.shape) != 2: + raise InstructionError("mma requires 2D shared tensors, got shape {}".format(b.shape)) + if len(d.shape) != 2: + raise InstructionError("mma requires 2D register tensors, got shape {}".format(d.shape)) + self._builder.wgmma_mma_rs(a, b, d) + else: + raise InstructionError("Invalid type of a: {}, expected SharedTensor or RegisterTensor".format(type(a))) \ No newline at end of file From 6ab05f8cd10871a06ea81f4f89c5265c6404f2b1 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 2 Dec 2025 18:21:32 -0800 Subject: [PATCH 03/16] add inference rules Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_wgmma.py | 2 +- python/tilus/ir/builders/stmt_builder.py | 11 ++++++++++- .../ir/layout/inference/inference_rules/__init__.py | 1 + python/tilus/ir/layout/inference/order.py | 2 ++ .../ir/layout/inference/validation_rules/always_ok.py | 4 +++- python/tilus/lang/instructions/wgmma.py | 2 +- 6 files changed, 18 insertions(+), 4 deletions(-) diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_wgmma.py index c1da9fa7..99ebfc62 100644 --- a/examples/hopper_matmul/matmul_wgmma.py +++ b/examples/hopper_matmul/matmul_wgmma.py @@ -68,7 +68,7 @@ def __call__( # a = self.load_shared(sa) # b = self.load_shared(sb) self.wgmma.fence() - self.wgmma.mma(sa, sb, acc) + self.wgmma.mma(sa, sb.transpose(), acc) self.wgmma.commit_group() self.wgmma.wait_group(1) self.sync() diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index 9ee0e745..b15d0f33 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -69,6 +69,13 @@ Tcgen05ViewInst, Tcgen05WaitInst, ) +from tilus.ir.instructions.cuda.wgmma import ( + WgmmaFenceInst, + WgmmaCommitGroupInst, + WgmmaWaitGroupInst, + WgmmaMmaSSInst, + WgmmaMmaRSInst, +) from tilus.ir.instructions.generic import ( AddInst, AllocateGlobalInst, @@ -1333,7 +1340,9 @@ def wgmma_commit_group(self) -> None: inst = WgmmaCommitGroupInst.create() self.append(inst) - def wgmma_wait_group(self, n: Expr | int) -> None: + def wgmma_wait_group(self, n: Union[Expr, int]) -> None: + if isinstance(n, int): + n = as_expr(n) inst = WgmmaWaitGroupInst.create(n=n) self.append(inst) diff --git a/python/tilus/ir/layout/inference/inference_rules/__init__.py b/python/tilus/ir/layout/inference/inference_rules/__init__.py index dfac471a..2b4fa6b1 100644 --- a/python/tilus/ir/layout/inference/inference_rules/__init__.py +++ b/python/tilus/ir/layout/inference/inference_rules/__init__.py @@ -31,4 +31,5 @@ transform_shared, transpose, where, + wgmma, ) diff --git a/python/tilus/ir/layout/inference/order.py b/python/tilus/ir/layout/inference/order.py index bfcb5ef3..61fe7e62 100644 --- a/python/tilus/ir/layout/inference/order.py +++ b/python/tilus/ir/layout/inference/order.py @@ -40,6 +40,7 @@ from .inference_rules.tcgen05.ldst import Tcgen05LoadRule, Tcgen05StoreRule from .inference_rules.tcgen05.mma import Tcgen05MmaSSRule, Tcgen05MmaTSRule from .inference_rules.tcgen05.slice import Tcgen05SliceRule +from .inference_rules.wgmma import WgmmaMmaSSRule #, WgmmaMmaRSRule from .inference_rules.transform import SqueezeRule, UnsqueezeRule from .inference_rules.transform_shared import PermuteSharedRule, SharedSliceRule from .inference_rules.transpose import TransposeRule @@ -65,6 +66,7 @@ # shared memory rules [LoadSharedInferSwizzledSharedRule, StoreSharedSwizzleRule], [SharedSliceRule, PermuteSharedRule], + [WgmmaMmaSSRule], #, WgmmaMmaRSRule], [CopyAsyncRule], [LoadSharedInferRegisterRule], [LoadSharedInferRowMajorSharedRule], diff --git a/python/tilus/ir/layout/inference/validation_rules/always_ok.py b/python/tilus/ir/layout/inference/validation_rules/always_ok.py index aebe5069..4a360cae 100644 --- a/python/tilus/ir/layout/inference/validation_rules/always_ok.py +++ b/python/tilus/ir/layout/inference/validation_rules/always_ok.py @@ -49,8 +49,10 @@ from tilus.ir.instructions.cuda.ldmatrix import LoadMatrixInst from tilus.ir.instructions.cuda.tcgen05 import Tcgen05CopyInst, Tcgen05LoadInst, Tcgen05MmaSSInst, Tcgen05StoreInst from tilus.ir.layout.inference.rule import LayoutValidationRule, register_rule +from tilus.ir.instructions.cuda.wgmma import WgmmaMmaSSInst, WgmmaMmaRSInst - +@register_rule(WgmmaMmaSSInst) # todo: should have its own rule +@register_rule(WgmmaMmaRSInst) # todo: should have its own rule @register_rule(Tcgen05LoadInst) # todo: should have its own rule @register_rule(Tcgen05StoreInst) # todo: should have its own rule @register_rule(Tcgen05CopyInst) # todo: should have its own rule diff --git a/python/tilus/lang/instructions/wgmma.py b/python/tilus/lang/instructions/wgmma.py index 5dbeb3a1..318d759d 100644 --- a/python/tilus/lang/instructions/wgmma.py +++ b/python/tilus/lang/instructions/wgmma.py @@ -1,5 +1,5 @@ import contextlib -from typing import Optional, Sequence +from typing import Optional, Sequence, Union from hidet.ir.expr import Expr from hidet.ir.type import DataType From f2cfb81fce99bbcd337fc4903e760bd8bf461713 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 2 Dec 2025 19:25:34 -0800 Subject: [PATCH 04/16] matrix d layout Signed-off-by: Qidong Su --- python/tilus/ir/instructions/cuda/wgmma.py | 41 +++++++++ .../layout/inference/inference_rules/wgmma.py | 83 +++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 python/tilus/ir/instructions/cuda/wgmma.py create mode 100644 python/tilus/ir/layout/inference/inference_rules/wgmma.py diff --git a/python/tilus/ir/instructions/cuda/wgmma.py b/python/tilus/ir/instructions/cuda/wgmma.py new file mode 100644 index 00000000..df7a3f41 --- /dev/null +++ b/python/tilus/ir/instructions/cuda/wgmma.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Sequence + +from hidet.ir.expr import Constant, Expr +from hidet.ir.type import DataType + +from tilus.ir.inst import Instruction +from tilus.ir.tensor import RegisterTensor, SharedTensor + +@dataclass(frozen=True, eq=False) +class WgmmaFenceInst(Instruction): + @staticmethod + def create() -> WgmmaFenceInst: + return WgmmaFenceInst(output=None, inputs=()) + +@dataclass(frozen=True, eq=False) +class WgmmaCommitGroupInst(Instruction): + @staticmethod + def create() -> WgmmaCommitGroupInst: + return WgmmaCommitGroupInst(output=None, inputs=()) + +@dataclass(frozen=True, eq=False) +class WgmmaWaitGroupInst(Instruction): + n: Expr + @staticmethod + def create(n: Expr) -> WgmmaWaitGroupInst: + return WgmmaWaitGroupInst(output=None, inputs=(), n=n) + +@dataclass(frozen=True, eq=False) +class WgmmaMmaSSInst(Instruction): + @staticmethod + def create(a: SharedTensor, b: SharedTensor, d: RegisterTensor) -> WgmmaMmaSSInst: + return WgmmaMmaSSInst(output=None, inputs=(a, b, d)) + +@dataclass(frozen=True, eq=False) +class WgmmaMmaRSInst(Instruction): + @staticmethod + def create(a: RegisterTensor, b: SharedTensor, d: RegisterTensor) -> WgmmaMmaRSInst: + return WgmmaMmaRSInst(output=None, inputs=(a, b, d)) \ No newline at end of file diff --git a/python/tilus/ir/layout/inference/inference_rules/wgmma.py b/python/tilus/ir/layout/inference/inference_rules/wgmma.py new file mode 100644 index 00000000..58a20dfb --- /dev/null +++ b/python/tilus/ir/layout/inference/inference_rules/wgmma.py @@ -0,0 +1,83 @@ +from tilus.ir.instructions.cuda.tcgen05 import Tcgen05MmaSSInst, Tcgen05MmaTSInst +from tilus.ir.layout import SharedLayout +from tilus.ir.layout.cuda.tcgen05.smem import ( + Tcgen05SwizzleMode, + generate_canonical_layout, +) +from tilus.ir.layout.inference.rule import ( + LayoutInferenceContext, + LayoutInferenceError, + LayoutInferenceRule, + register_rule, +) +from tilus.ir.tensor import SharedTensor, RegisterTensor +from tilus.ir.layout import RegisterLayout + +from tilus.ir.instructions.cuda.wgmma import WgmmaMmaSSInst, WgmmaMmaRSInst + +from tilus.ir.layout.ops.register_ops import spatial, local, column_spatial, column_local + + +def generate_wgmma_register_layout(num_column, dtype) -> RegisterLayout: # Same for A and D + T = 32 // dtype.nbits + return column_spatial(4).column_local(2, num_column // 8).spatial(8, 4).local(T) + +@register_rule(WgmmaMmaSSInst) +class WgmmaMmaSSRule(LayoutInferenceRule): + @staticmethod + def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedTensor, SharedLayout]: + a_tensor: SharedTensor = inst.inputs[0].as_shared_tensor() + b_tensor: SharedTensor = inst.inputs[1].as_shared_tensor() + d_tensor: RegisterTensor = inst.inputs[2].as_register_tensor() + + a_shape = a_tensor.shape + b_shape = b_tensor.shape + d_shape = d_tensor.shape + + if not len(a_shape) == len(b_shape) == len(d_shape) == 2: + raise LayoutInferenceError( + f"A, B, and D must have 2 dimensions, but got {len(a_shape)}, {len(b_shape)}, and {len(d_shape)}." + ) + if a_shape[1] != b_shape[0] or a_shape[0] != d_shape[0] or b_shape[1] != d_shape[1]: + raise LayoutInferenceError( + f"A, B, and D must have compatible shapes, but got {a_tensor.shape}, {b_tensor.shape}, and {d_tensor.shape}." + ) + m, n, k = d_shape[0], d_shape[1], a_shape[1] + + ret = {} + if not a_tensor.has_layout(): + for swizzle_mode in [ + Tcgen05SwizzleMode.B128_SWIZZLE, + Tcgen05SwizzleMode.B64_SWIZZLE, + Tcgen05SwizzleMode.B32_SWIZZLE, + Tcgen05SwizzleMode.NO_SWIZZLE, + ]: + try: + a_layout_canonical = generate_canonical_layout( + shape=(m, k), dtype=a_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode + ) + ret[a_tensor] = a_layout_canonical.as_shared_layout().simplify() + except ValueError: + continue + else: + break + if not b_tensor.has_layout(): + for swizzle_mode in [ + Tcgen05SwizzleMode.B128_SWIZZLE, + Tcgen05SwizzleMode.B64_SWIZZLE, + Tcgen05SwizzleMode.B32_SWIZZLE, + Tcgen05SwizzleMode.NO_SWIZZLE, + ]: + try: + b_layout_canonical = generate_canonical_layout( + shape=(n, k), dtype=b_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode + ) + ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]).simplify() + except ValueError: + continue + else: + break + if not d_tensor.has_layout(): + d_layout = generate_wgmma_register_layout(n, d_tensor.dtype) + ret[d_tensor] = d_layout + return ret \ No newline at end of file From 0d2d84b439c35d3850c407ed1da4fa4547e5bb5a Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 2 Dec 2025 22:22:01 -0800 Subject: [PATCH 05/16] upd Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_wgmma.py | 5 +++-- python/tilus/backends/emitters/cuda/__init__.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_wgmma.py index 99ebfc62..1bdf9342 100644 --- a/examples/hopper_matmul/matmul_wgmma.py +++ b/examples/hopper_matmul/matmul_wgmma.py @@ -13,8 +13,9 @@ tilus.option.cache_dir("./cache") tilus.option.debug.dump_ir(True) -@tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) -@tilus.autotune("block_k", [16, 32, 16]) +# @tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) +@tilus.autotune("block_m, block_n", [(64, 128),]) +@tilus.autotune("block_k", [16,]) class MatmulTMA(tilus.Script): def __init__( self, diff --git a/python/tilus/backends/emitters/cuda/__init__.py b/python/tilus/backends/emitters/cuda/__init__.py index d5af74f1..df3ed155 100644 --- a/python/tilus/backends/emitters/cuda/__init__.py +++ b/python/tilus/backends/emitters/cuda/__init__.py @@ -24,4 +24,5 @@ semaphore, simt_dot, tcgen05, + wgmma, ) From 9b289957a445caa309350785a6403221d3eb4429 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 3 Dec 2025 13:35:43 -0800 Subject: [PATCH 06/16] runnable Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_wgmma.py | 12 ++++++------ .../ir/layout/inference/inference_rules/wgmma.py | 6 +++--- python/tilus/ir/layout/inference/order.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_wgmma.py index 1bdf9342..5a9a9337 100644 --- a/examples/hopper_matmul/matmul_wgmma.py +++ b/examples/hopper_matmul/matmul_wgmma.py @@ -12,11 +12,12 @@ tilus.option.cache_dir("./cache") tilus.option.debug.dump_ir(True) +torch.set_printoptions(precision=4, sci_mode=False) # @tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) @tilus.autotune("block_m, block_n", [(64, 128),]) @tilus.autotune("block_k", [16,]) -class MatmulTMA(tilus.Script): +class MatmulWGMMA(tilus.Script): def __init__( self, block_m, @@ -66,8 +67,6 @@ def __call__( # synchronize threads in the block to ensure data is available in shared memory self.sync() - # a = self.load_shared(sa) - # b = self.load_shared(sb) self.wgmma.fence() self.wgmma.mma(sa, sb.transpose(), acc) self.wgmma.commit_group() @@ -87,11 +86,12 @@ def main(): headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] workloads = [ [4096, 4096, 4096], + # [128, 16, 32], ] rows = [] for m, n, k in workloads: - matmul = MatmulTMA() + matmul = MatmulWGMMA() a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) @@ -99,13 +99,13 @@ def main(): c_expect = a @ b.T matmul(m, n, k, a, b, c_actual) torch.cuda.synchronize() - + # check correctness torch.testing.assert_close(c_expect, c_actual) # benchmark for name, func in [ - ("torch", lambda: torch.matmul(a, b, out=c_expect)), + ("torch", lambda: torch.matmul(a, b.T, out=c_expect)), ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), ]: latency = benchmark_func(func, warmup=5, repeat=20) diff --git a/python/tilus/ir/layout/inference/inference_rules/wgmma.py b/python/tilus/ir/layout/inference/inference_rules/wgmma.py index 58a20dfb..065e6dcd 100644 --- a/python/tilus/ir/layout/inference/inference_rules/wgmma.py +++ b/python/tilus/ir/layout/inference/inference_rules/wgmma.py @@ -18,9 +18,9 @@ from tilus.ir.layout.ops.register_ops import spatial, local, column_spatial, column_local -def generate_wgmma_register_layout(num_column, dtype) -> RegisterLayout: # Same for A and D - T = 32 // dtype.nbits - return column_spatial(4).column_local(2, num_column // 8).spatial(8, 4).local(T) +def generate_wgmma_register_layout(num_column, dtype) -> RegisterLayout: + T = 64 // dtype.nbits + return column_spatial(4, 1).column_local(2, num_column // T // 4).spatial(8, 4).local(T) @register_rule(WgmmaMmaSSInst) class WgmmaMmaSSRule(LayoutInferenceRule): diff --git a/python/tilus/ir/layout/inference/order.py b/python/tilus/ir/layout/inference/order.py index 61fe7e62..383c9a78 100644 --- a/python/tilus/ir/layout/inference/order.py +++ b/python/tilus/ir/layout/inference/order.py @@ -52,6 +52,7 @@ # register layout rules [SliceRegisterRule, SliceAssignRule, AllocBarrierRule], [MmaDotRule], + [WgmmaMmaSSRule], #, WgmmaMmaRSRule], [Tcgen05LoadRule, Tcgen05StoreRule], [Tcgen05CopyRule], [BinaryRule, UnaryRule], @@ -66,7 +67,6 @@ # shared memory rules [LoadSharedInferSwizzledSharedRule, StoreSharedSwizzleRule], [SharedSliceRule, PermuteSharedRule], - [WgmmaMmaSSRule], #, WgmmaMmaRSRule], [CopyAsyncRule], [LoadSharedInferRegisterRule], [LoadSharedInferRowMajorSharedRule], From b81c9c2b2aba962c8640e16aa7047c456251a45a Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Thu, 4 Dec 2025 00:10:19 -0800 Subject: [PATCH 07/16] upd Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_wgmma.py | 19 ++++++++++----- python/tilus/ir/instructions/cuda/wgmma.py | 24 +++++++++++++++++++ .../layout/inference/inference_rules/wgmma.py | 9 ++++--- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_wgmma.py index 5a9a9337..5f3fa746 100644 --- a/examples/hopper_matmul/matmul_wgmma.py +++ b/examples/hopper_matmul/matmul_wgmma.py @@ -12,11 +12,11 @@ tilus.option.cache_dir("./cache") tilus.option.debug.dump_ir(True) -torch.set_printoptions(precision=4, sci_mode=False) +torch.set_printoptions(precision=3, sci_mode=False, linewidth=160) # @tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) -@tilus.autotune("block_m, block_n", [(64, 128),]) -@tilus.autotune("block_k", [16,]) +@tilus.autotune("block_m, block_n", [(128, 16),]) +@tilus.autotune("block_k", [16]) class MatmulWGMMA(tilus.Script): def __init__( self, @@ -70,8 +70,9 @@ def __call__( self.wgmma.fence() self.wgmma.mma(sa, sb.transpose(), acc) self.wgmma.commit_group() - self.wgmma.wait_group(1) + self.wgmma.wait_group(0) self.sync() + self.wgmma.fence() phase ^= 1 self.free_shared(sa) @@ -85,8 +86,8 @@ def __call__( def main(): headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] workloads = [ - [4096, 4096, 4096], - # [128, 16, 32], + # [4096, 4096, 4096], + [128, 16, 16], ] rows = [] @@ -99,6 +100,12 @@ def main(): c_expect = a @ b.T matmul(m, n, k, a, b, c_actual) torch.cuda.synchronize() + + v = 8 + print(c_actual[:v, :v]) + print(c_expect[:v, :v]) + print(c_actual[64:][:v, :v]) + print(c_expect[64:][:v, :v]) # check correctness torch.testing.assert_close(c_expect, c_actual) diff --git a/python/tilus/ir/instructions/cuda/wgmma.py b/python/tilus/ir/instructions/cuda/wgmma.py index df7a3f41..dd5b1160 100644 --- a/python/tilus/ir/instructions/cuda/wgmma.py +++ b/python/tilus/ir/instructions/cuda/wgmma.py @@ -5,9 +5,12 @@ from hidet.ir.expr import Constant, Expr from hidet.ir.type import DataType +from hidet.ir.dtypes import f16, bf16, tf32, f8e4m3, f8e5m2, i8, u8, u1 + from tilus.ir.inst import Instruction from tilus.ir.tensor import RegisterTensor, SharedTensor +from tilus.utils import gcd @dataclass(frozen=True, eq=False) class WgmmaFenceInst(Instruction): @@ -30,6 +33,27 @@ def create(n: Expr) -> WgmmaWaitGroupInst: @dataclass(frozen=True, eq=False) class WgmmaMmaSSInst(Instruction): + + @staticmethod + def get_inst_mnk(m: int, n: int, k: int, a_dtype: DataType, b_dtype: DataType, d_dtype: DataType) -> tuple[int, int, int]: + inst_m = 64 + inst_n = gcd(n, 256) # why? + if a_dtype == b_dtype == f16: + inst_k = 16 + elif a_dtype == b_dtype == bf16: + inst_k = 16 + elif a_dtype == b_dtype == tf32: + inst_k = 8 + elif a_dtype in (f8e4m3, f8e5m2) and b_dtype in (f8e4m3, f8e5m2): + inst_k = 32 + elif a_dtype in (i8, u8) and b_dtype in (i8, u8): + inst_k = 32 + elif a_dtype == d_dtype == u1: + inst_k = 256 + else: + raise ValueError(f"Unsupported data types for MMA: a_dtype={a_dtype}, b_dtype={b_dtype}") + return inst_m, inst_n, inst_k + @staticmethod def create(a: SharedTensor, b: SharedTensor, d: RegisterTensor) -> WgmmaMmaSSInst: return WgmmaMmaSSInst(output=None, inputs=(a, b, d)) diff --git a/python/tilus/ir/layout/inference/inference_rules/wgmma.py b/python/tilus/ir/layout/inference/inference_rules/wgmma.py index 065e6dcd..2577dbbc 100644 --- a/python/tilus/ir/layout/inference/inference_rules/wgmma.py +++ b/python/tilus/ir/layout/inference/inference_rules/wgmma.py @@ -18,9 +18,11 @@ from tilus.ir.layout.ops.register_ops import spatial, local, column_spatial, column_local -def generate_wgmma_register_layout(num_column, dtype) -> RegisterLayout: +def generate_wgmma_register_layout(num_row, num_column, inst_m, inst_n, dtype) -> RegisterLayout: T = 64 // dtype.nbits - return column_spatial(4, 1).column_local(2, num_column // T // 4).spatial(8, 4).local(T) + repeat_m = num_row // inst_m + repeat_n = num_column // inst_n + return local(repeat_m, repeat_n).column_spatial(inst_m // 16, 1).column_local(2, inst_n // T // 4).spatial(8, 4).local(T) @register_rule(WgmmaMmaSSInst) class WgmmaMmaSSRule(LayoutInferenceRule): @@ -78,6 +80,7 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT else: break if not d_tensor.has_layout(): - d_layout = generate_wgmma_register_layout(n, d_tensor.dtype) + inst_m, inst_n, inst_k = WgmmaMmaSSInst.get_inst_mnk(m, n, k, a_tensor.dtype, b_tensor.dtype, d_tensor.dtype) + d_layout = generate_wgmma_register_layout(m, n, inst_m, inst_n, d_tensor.dtype) ret[d_tensor] = d_layout return ret \ No newline at end of file From 0119b66aa3ba7e741514f6fa1f1692820017316d Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Thu, 4 Dec 2025 00:16:20 -0800 Subject: [PATCH 08/16] upd Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_wgmma.py | 19 ++- python/tilus/backends/emitters/cuda/wgmma.py | 152 ++++++++++++++++++ .../hidet/ir/primitives/cuda/wgmma.py | 49 ++++++ 3 files changed, 210 insertions(+), 10 deletions(-) create mode 100644 python/tilus/backends/emitters/cuda/wgmma.py create mode 100644 python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_wgmma.py index 5f3fa746..0add34d8 100644 --- a/examples/hopper_matmul/matmul_wgmma.py +++ b/examples/hopper_matmul/matmul_wgmma.py @@ -15,8 +15,8 @@ torch.set_printoptions(precision=3, sci_mode=False, linewidth=160) # @tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) -@tilus.autotune("block_m, block_n", [(128, 16),]) -@tilus.autotune("block_k", [16]) +@tilus.autotune("block_m, block_n", [(64, 128), (128, 128), (128, 256), (256, 128), (256, 256)]) +@tilus.autotune("block_k", [16, 32, 64]) class MatmulWGMMA(tilus.Script): def __init__( self, @@ -72,7 +72,6 @@ def __call__( self.wgmma.commit_group() self.wgmma.wait_group(0) self.sync() - self.wgmma.fence() phase ^= 1 self.free_shared(sa) @@ -86,8 +85,8 @@ def __call__( def main(): headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] workloads = [ - # [4096, 4096, 4096], - [128, 16, 16], + [4096, 4096, 4096], + # [128, 48, 16], ] rows = [] @@ -101,11 +100,11 @@ def main(): matmul(m, n, k, a, b, c_actual) torch.cuda.synchronize() - v = 8 - print(c_actual[:v, :v]) - print(c_expect[:v, :v]) - print(c_actual[64:][:v, :v]) - print(c_expect[64:][:v, :v]) + # v = 8 + # print(c_actual[:v, :v]) + # print(c_expect[:v, :v]) + # print(c_actual[64:][:v, :v]) + # print(c_expect[64:][:v, :v]) # check correctness torch.testing.assert_close(c_expect, c_actual) diff --git a/python/tilus/backends/emitters/cuda/wgmma.py b/python/tilus/backends/emitters/cuda/wgmma.py new file mode 100644 index 00000000..19e41e5f --- /dev/null +++ b/python/tilus/backends/emitters/cuda/wgmma.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from enum import Enum +from dataclasses import dataclass + +from hidet.ir.dtypes import uint32, DataType +from hidet.ir.expr import Expr, cast, if_then_else, Var +from hidet.ir.primitives.cuda.wgmma import wgmma_fence, wgmma_commit_group, wgmma_wait_group, wgmma_async, WgmmaConfig + + +from tilus.backends.emitter import BaseInstEmitter, register_emitter +from tilus.ir.layout import LayoutOperationError, RegisterLayout +from tilus.ir.tensor import RegisterTensor, SharedTensor +from tilus.target import nvgpu_sm90a +from tilus.ir.layout.utils.cute import CuteLayout + +from tilus.ir.instructions.cuda.wgmma import WgmmaMmaSSInst, WgmmaMmaRSInst, WgmmaFenceInst, WgmmaCommitGroupInst, WgmmaWaitGroupInst +from tilus.ir.layout.cuda.tcgen05.smem import canonicalize_shared_layout + +from tilus.extensions.hidet.ir.primitives.cuda.wgmma import wgmma_encode_smem_descriptor +from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import Tcgen05SwizzleMode + +def encode_swizzle_mode(swizzle_mode: Tcgen05SwizzleMode) -> int: + swizzle_mode_map = { + Tcgen05SwizzleMode.NO_SWIZZLE: 0, + Tcgen05SwizzleMode.B32_SWIZZLE: 3, + Tcgen05SwizzleMode.B64_SWIZZLE: 2, + Tcgen05SwizzleMode.B128_SWIZZLE: 1, + } + return swizzle_mode_map[swizzle_mode] + + +@dataclass +class SharedMatrixDescriptor: + addr: Expr | int + lbo: int + sbo: int + base_offset: int + swizzle_mode: int + + def encoded(self) -> Expr: + return wgmma_encode_smem_descriptor( + self.addr >> 4, + self.lbo >> 4, + self.sbo >> 4, + self.base_offset, + self.swizzle_mode, + ) + + @staticmethod + def decode(encoded: int) -> SharedMatrixDescriptor: + return SharedMatrixDescriptor( + addr=(encoded & 0x3FFF) << 4, + lbo=((encoded >> 16) & 0x3FFF) << 4, + sbo=((encoded >> 32) & 0x3FFF) << 4, + base_offset=(encoded >> 49) & 0x7, + swizzle_mode=(encoded >> 62) & 0x3, + ) + + + +@register_emitter(WgmmaFenceInst, target=nvgpu_sm90a) +class WgmmaFenceEmitter(BaseInstEmitter): + def emit(self, inst: WgmmaFenceInst) -> None: + self.append(wgmma_fence()) + +@register_emitter(WgmmaCommitGroupInst, target=nvgpu_sm90a) +class WgmmaCommitGroupEmitter(BaseInstEmitter): + def emit(self, inst: WgmmaCommitGroupInst) -> None: + self.append(wgmma_commit_group()) + +@register_emitter(WgmmaWaitGroupInst, target=nvgpu_sm90a) +class WgmmaWaitGroupEmitter(BaseInstEmitter): + def emit(self, inst: WgmmaWaitGroupInst) -> None: + self.append(wgmma_wait_group(inst.n)) + +@register_emitter(WgmmaMmaSSInst, target=nvgpu_sm90a) +class WgmmaMmaSSEmitter(BaseInstEmitter): + + def emit(self, inst: WgmmaMmaSSInst) -> None: + a, b, d = inst.inputs + a_tensor: SharedTensor = a.as_shared_tensor() + b_tensor: SharedTensor = b.as_shared_tensor() + d_tensor: RegisterTensor = d.as_register_tensor() + + a_shape = a_tensor.shape + b_shape = b_tensor.shape + d_shape = d_tensor.shape + + if len(a_shape) != 2 or len(b_shape) != 2 or len(d_shape) != 2: + raise ValueError(f"MMA requires 2D tensors, but got shapes: a={a_shape}, b={b_shape}, d={d_shape}") + if a_shape[1] != b_shape[0] or a_shape[0] != d_shape[0] or b_shape[1] != d_shape[1]: + raise ValueError(f"Incompatible shapes for MMA: a={a_shape}, b={b_shape}, d={d_shape}") + m, n, k = d_shape[0], d_shape[1], a_shape[1] + + a_dtype = a_tensor.dtype + b_dtype = b_tensor.dtype + d_dtype = d_tensor.dtype + + inst_m, inst_n, inst_k = inst.get_inst_mnk(m, n, k, a_dtype, b_dtype, d_dtype) + print(f"inst_m: {inst_m}, inst_n: {inst_n}, inst_k: {inst_k}") + wgmma_config = WgmmaConfig.get(inst_m, inst_n, inst_k, a_dtype.short_name, b_dtype.short_name, d_dtype.short_name) + + repeat_m = m // inst_m + repeat_n = n // inst_n + repeat_k = k // inst_k + print(f"repeat_m: {repeat_m}, repeat_n: {repeat_n}, repeat_k: {repeat_k}") + + a_canonical = canonicalize_shared_layout(a_tensor.layout, dtype=a_dtype) + b_canonical = canonicalize_shared_layout(b_tensor.layout.transpose(), dtype=b_dtype) + + if a_canonical is None: + raise ValueError(f"Cannot canonicalize the layout of a tensor: {a_tensor.layout}.") + if b_canonical is None: + raise ValueError(f"Cannot canonicalize the layout of b tensor: {b_tensor.layout}.") + + a_cute_layout: CuteLayout = a_canonical.swizzled_cute_layout.layout + b_cute_layout: CuteLayout = b_canonical.swizzled_cute_layout.layout + + a_shared_addr: Var = self.shared_tensor_shared_space_addr[a_tensor] + b_shared_addr: Var = self.shared_tensor_shared_space_addr[b_tensor] + d_register_addr: Var = ~(self.tensor2var[d_tensor][0]) + + d_local_stride = d_tensor.layout.local_size // repeat_m // repeat_n + + for k in range(repeat_k): + for i in range(repeat_m): + for j in range(repeat_n): + a_offset = a_cute_layout(i * inst_m, k * inst_k) + b_offset = b_cute_layout(j * inst_n, k * inst_k) + a_desc = SharedMatrixDescriptor( + addr=a_shared_addr + a_offset * a_tensor.dtype.nbytes, + lbo=a_canonical.LBO * a_tensor.dtype.nbytes, + sbo=a_canonical.SBO * a_tensor.dtype.nbytes, + base_offset=0, + swizzle_mode=encode_swizzle_mode(a_canonical.swizzle_mode), + ) + b_desc = SharedMatrixDescriptor( + addr=b_shared_addr + b_offset * b_tensor.dtype.nbytes, + lbo=b_canonical.LBO * b_tensor.dtype.nbytes, + sbo=b_canonical.SBO * b_tensor.dtype.nbytes, + base_offset=0, + swizzle_mode=encode_swizzle_mode(b_canonical.swizzle_mode), + ) + d_offset = (i * repeat_n + j) * d_local_stride + self.append(wgmma_async(wgmma_config, + a_desc.encoded(), + d_register_addr + d_offset, + b_desc.encoded(), + trans_a=0, + trans_b=0, + )) \ No newline at end of file diff --git a/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py b/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py new file mode 100644 index 00000000..2e53a9cd --- /dev/null +++ b/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py @@ -0,0 +1,49 @@ +from enum import Enum +from typing import Optional, Sequence, no_type_check + +from hidet.ir.dtypes import int32, uint8, uint32, uint64 +from hidet.ir.expr import Expr, as_expr +from hidet.ir.primitives.func import call_primitive_func +from hidet.ir.stmt import asm +from hidet.utils import initialize + +from tilus.extensions.hidet.ir.primitives.utils import register_primitive_function_decorator + +@initialize() +def register_wgmma_instructions(): + from hidet.lang import attrs, meta + + from tilus.extensions.hidet.lang import script + + @register_primitive_function_decorator + @no_type_check + @script + def wgmma_encode_smem_descriptor( + smem_addr: uint32, # 14 bits + lbo: uint32, # 14 bits + sbo: uint32, # 14 bits + mbo: uint8, # 3 bits + swizzle_mode: uint8, # 2 bits + ) -> uint64: + attrs.func_name = "cuda_wgmma_encode_smem_descriptor" + attrs.func_kind = "cuda_internal" + desc: uint64 = uint64(0) + desc = desc | uint64(lbo & uint32(0x3FFF)) << 16 + desc = desc | uint64(sbo & uint32(0x3FFF)) << 32 + desc = desc | uint64(mbo & uint8(0b111)) << 49 + desc = desc | uint64(swizzle_mode & uint8(0b11)) << 62 + desc = desc | uint64(smem_addr & uint32(0x3FFF)) + return desc + + +def wgmma_encode_smem_descriptor( + smem_addr: Expr | int, + lbo: Expr | int, + sbo: Expr | int, + mbo: Expr | int, + swizzle_mode: Expr | int, +) -> Expr: + func_name = "cuda_wgmma_encode_smem_descriptor" + return call_primitive_func( + func_name, [uint32(smem_addr), uint32(lbo), uint32(sbo), uint8(mbo), uint8(swizzle_mode)] + ) \ No newline at end of file From 01237e55f14a5b7799bf5d969d296bc8553718b7 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Thu, 4 Dec 2025 11:50:05 -0800 Subject: [PATCH 09/16] upd Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_wgmma.py | 7 ------ python/tilus/backends/emitters/cuda/wgmma.py | 14 +++++++++++ .../hidet/ir/primitives/cuda/wgmma.py | 14 +++++++++++ python/tilus/ir/instructions/cuda/wgmma.py | 14 +++++++++++ .../layout/inference/inference_rules/wgmma.py | 25 +++++++++++++++---- python/tilus/lang/instructions/wgmma.py | 14 +++++++++++ 6 files changed, 76 insertions(+), 12 deletions(-) diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_wgmma.py index 0add34d8..4860ffda 100644 --- a/examples/hopper_matmul/matmul_wgmma.py +++ b/examples/hopper_matmul/matmul_wgmma.py @@ -14,7 +14,6 @@ tilus.option.debug.dump_ir(True) torch.set_printoptions(precision=3, sci_mode=False, linewidth=160) -# @tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) @tilus.autotune("block_m, block_n", [(64, 128), (128, 128), (128, 256), (256, 128), (256, 256)]) @tilus.autotune("block_k", [16, 32, 64]) class MatmulWGMMA(tilus.Script): @@ -100,12 +99,6 @@ def main(): matmul(m, n, k, a, b, c_actual) torch.cuda.synchronize() - # v = 8 - # print(c_actual[:v, :v]) - # print(c_expect[:v, :v]) - # print(c_actual[64:][:v, :v]) - # print(c_expect[64:][:v, :v]) - # check correctness torch.testing.assert_close(c_expect, c_actual) diff --git a/python/tilus/backends/emitters/cuda/wgmma.py b/python/tilus/backends/emitters/cuda/wgmma.py index 19e41e5f..57a435c2 100644 --- a/python/tilus/backends/emitters/cuda/wgmma.py +++ b/python/tilus/backends/emitters/cuda/wgmma.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import annotations from enum import Enum diff --git a/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py b/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py index 2e53a9cd..1a8b4e22 100644 --- a/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py +++ b/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from enum import Enum from typing import Optional, Sequence, no_type_check diff --git a/python/tilus/ir/instructions/cuda/wgmma.py b/python/tilus/ir/instructions/cuda/wgmma.py index dd5b1160..efbd265e 100644 --- a/python/tilus/ir/instructions/cuda/wgmma.py +++ b/python/tilus/ir/instructions/cuda/wgmma.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import annotations from dataclasses import dataclass diff --git a/python/tilus/ir/layout/inference/inference_rules/wgmma.py b/python/tilus/ir/layout/inference/inference_rules/wgmma.py index 2577dbbc..b3a65049 100644 --- a/python/tilus/ir/layout/inference/inference_rules/wgmma.py +++ b/python/tilus/ir/layout/inference/inference_rules/wgmma.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from tilus.ir.instructions.cuda.tcgen05 import Tcgen05MmaSSInst, Tcgen05MmaTSInst from tilus.ir.layout import SharedLayout from tilus.ir.layout.cuda.tcgen05.smem import ( @@ -18,10 +32,11 @@ from tilus.ir.layout.ops.register_ops import spatial, local, column_spatial, column_local -def generate_wgmma_register_layout(num_row, num_column, inst_m, inst_n, dtype) -> RegisterLayout: - T = 64 // dtype.nbits - repeat_m = num_row // inst_m - repeat_n = num_column // inst_n +def generate_wgmma_register_layout(m, n, inst_m, inst_n, inst_k) -> RegisterLayout: + # See also: https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-register-fragment + T = inst_k // 8 + repeat_m = m // inst_m + repeat_n = n // inst_n return local(repeat_m, repeat_n).column_spatial(inst_m // 16, 1).column_local(2, inst_n // T // 4).spatial(8, 4).local(T) @register_rule(WgmmaMmaSSInst) @@ -81,6 +96,6 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT break if not d_tensor.has_layout(): inst_m, inst_n, inst_k = WgmmaMmaSSInst.get_inst_mnk(m, n, k, a_tensor.dtype, b_tensor.dtype, d_tensor.dtype) - d_layout = generate_wgmma_register_layout(m, n, inst_m, inst_n, d_tensor.dtype) + d_layout = generate_wgmma_register_layout(m, n, inst_m, inst_n, inst_k) ret[d_tensor] = d_layout return ret \ No newline at end of file diff --git a/python/tilus/lang/instructions/wgmma.py b/python/tilus/lang/instructions/wgmma.py index 318d759d..5f457893 100644 --- a/python/tilus/lang/instructions/wgmma.py +++ b/python/tilus/lang/instructions/wgmma.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import contextlib from typing import Optional, Sequence, Union From 088a2b0d07e59c0089cd0abeb4152889e5858d05 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Thu, 4 Dec 2025 14:04:17 -0800 Subject: [PATCH 10/16] format Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_tma.py | 10 ++-- examples/hopper_matmul/matmul_wgmma.py | 14 +++-- python/tilus/backends/emitters/cuda/wgmma.py | 52 +++++++++++-------- .../hidet/ir/primitives/cuda/wgmma.py | 13 +++-- python/tilus/ir/builders/stmt_builder.py | 6 +-- python/tilus/ir/instructions/cuda/wgmma.py | 23 ++++---- .../inference/inference_rules/__init__.py | 2 +- .../layout/inference/inference_rules/wgmma.py | 29 ++++++----- python/tilus/ir/layout/inference/order.py | 4 +- .../inference/validation_rules/always_ok.py | 7 +-- python/tilus/lang/instructions/__init__.py | 2 +- python/tilus/lang/instructions/wgmma.py | 6 +-- 12 files changed, 96 insertions(+), 72 deletions(-) diff --git a/examples/hopper_matmul/matmul_tma.py b/examples/hopper_matmul/matmul_tma.py index 2af4015c..106de870 100644 --- a/examples/hopper_matmul/matmul_tma.py +++ b/examples/hopper_matmul/matmul_tma.py @@ -9,10 +9,10 @@ from tilus import float16, float32, int32, uint32 from tilus.utils import benchmark_func, cdiv - tilus.option.cache_dir("./cache") tilus.option.debug.dump_ir(True) + @tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) @tilus.autotune("block_k", [16, 32, 16]) class MatmulTMA(tilus.Script): @@ -58,8 +58,12 @@ def __call__( for offset_k in range(0, k_size, block_k): # issue asynchronous copy instructions to load tiles of A and B with self.single_thread(): - self.tma.global_to_shared(src=ga, dst=sa, offsets=[offset_m, offset_k], mbarrier=tma_barrier) - self.tma.global_to_shared(src=gb, dst=sb, offsets=[offset_n, offset_k], mbarrier=tma_barrier) + self.tma.global_to_shared( + src=ga, dst=sa, offsets=[offset_m, offset_k], mbarrier=tma_barrier + ) + self.tma.global_to_shared( + src=gb, dst=sb, offsets=[offset_n, offset_k], mbarrier=tma_barrier + ) self.mbarrier.wait(tma_barrier, phase=phase) # synchronize threads in the block to ensure data is available in shared memory diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_wgmma.py index 4860ffda..9e2d9fd9 100644 --- a/examples/hopper_matmul/matmul_wgmma.py +++ b/examples/hopper_matmul/matmul_wgmma.py @@ -9,12 +9,14 @@ from tilus import float16, float32, int32, uint32 from tilus.utils import benchmark_func, cdiv - tilus.option.cache_dir("./cache") tilus.option.debug.dump_ir(True) torch.set_printoptions(precision=3, sci_mode=False, linewidth=160) -@tilus.autotune("block_m, block_n", [(64, 128), (128, 128), (128, 256), (256, 128), (256, 256)]) + +@tilus.autotune( + "block_m, block_n", [(64, 128), (128, 128), (128, 256), (256, 128), (256, 256)] +) @tilus.autotune("block_k", [16, 32, 64]) class MatmulWGMMA(tilus.Script): def __init__( @@ -59,8 +61,12 @@ def __call__( for offset_k in range(0, k_size, block_k): # issue asynchronous copy instructions to load tiles of A and B with self.single_thread(): - self.tma.global_to_shared(src=ga, dst=sa, offsets=[offset_m, offset_k], mbarrier=tma_barrier) - self.tma.global_to_shared(src=gb, dst=sb, offsets=[offset_n, offset_k], mbarrier=tma_barrier) + self.tma.global_to_shared( + src=ga, dst=sa, offsets=[offset_m, offset_k], mbarrier=tma_barrier + ) + self.tma.global_to_shared( + src=gb, dst=sb, offsets=[offset_n, offset_k], mbarrier=tma_barrier + ) self.mbarrier.wait(tma_barrier, phase=phase) # synchronize threads in the block to ensure data is available in shared memory diff --git a/python/tilus/backends/emitters/cuda/wgmma.py b/python/tilus/backends/emitters/cuda/wgmma.py index 57a435c2..48fa33f4 100644 --- a/python/tilus/backends/emitters/cuda/wgmma.py +++ b/python/tilus/backends/emitters/cuda/wgmma.py @@ -14,25 +14,25 @@ # limitations under the License. from __future__ import annotations -from enum import Enum from dataclasses import dataclass -from hidet.ir.dtypes import uint32, DataType -from hidet.ir.expr import Expr, cast, if_then_else, Var -from hidet.ir.primitives.cuda.wgmma import wgmma_fence, wgmma_commit_group, wgmma_wait_group, wgmma_async, WgmmaConfig - +from hidet.ir.expr import Expr, Var +from hidet.ir.primitives.cuda.wgmma import WgmmaConfig, wgmma_async, wgmma_commit_group, wgmma_fence, wgmma_wait_group from tilus.backends.emitter import BaseInstEmitter, register_emitter -from tilus.ir.layout import LayoutOperationError, RegisterLayout +from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import Tcgen05SwizzleMode +from tilus.extensions.hidet.ir.primitives.cuda.wgmma import wgmma_encode_smem_descriptor +from tilus.ir.instructions.cuda.wgmma import ( + WgmmaCommitGroupInst, + WgmmaFenceInst, + WgmmaMmaSSInst, + WgmmaWaitGroupInst, +) +from tilus.ir.layout.cuda.tcgen05.smem import canonicalize_shared_layout +from tilus.ir.layout.utils.cute import CuteLayout from tilus.ir.tensor import RegisterTensor, SharedTensor from tilus.target import nvgpu_sm90a -from tilus.ir.layout.utils.cute import CuteLayout - -from tilus.ir.instructions.cuda.wgmma import WgmmaMmaSSInst, WgmmaMmaRSInst, WgmmaFenceInst, WgmmaCommitGroupInst, WgmmaWaitGroupInst -from tilus.ir.layout.cuda.tcgen05.smem import canonicalize_shared_layout -from tilus.extensions.hidet.ir.primitives.cuda.wgmma import wgmma_encode_smem_descriptor -from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import Tcgen05SwizzleMode def encode_swizzle_mode(swizzle_mode: Tcgen05SwizzleMode) -> int: swizzle_mode_map = { @@ -72,25 +72,26 @@ def decode(encoded: int) -> SharedMatrixDescriptor: ) - @register_emitter(WgmmaFenceInst, target=nvgpu_sm90a) class WgmmaFenceEmitter(BaseInstEmitter): def emit(self, inst: WgmmaFenceInst) -> None: self.append(wgmma_fence()) + @register_emitter(WgmmaCommitGroupInst, target=nvgpu_sm90a) class WgmmaCommitGroupEmitter(BaseInstEmitter): def emit(self, inst: WgmmaCommitGroupInst) -> None: self.append(wgmma_commit_group()) + @register_emitter(WgmmaWaitGroupInst, target=nvgpu_sm90a) class WgmmaWaitGroupEmitter(BaseInstEmitter): def emit(self, inst: WgmmaWaitGroupInst) -> None: self.append(wgmma_wait_group(inst.n)) + @register_emitter(WgmmaMmaSSInst, target=nvgpu_sm90a) class WgmmaMmaSSEmitter(BaseInstEmitter): - def emit(self, inst: WgmmaMmaSSInst) -> None: a, b, d = inst.inputs a_tensor: SharedTensor = a.as_shared_tensor() @@ -113,8 +114,10 @@ def emit(self, inst: WgmmaMmaSSInst) -> None: inst_m, inst_n, inst_k = inst.get_inst_mnk(m, n, k, a_dtype, b_dtype, d_dtype) print(f"inst_m: {inst_m}, inst_n: {inst_n}, inst_k: {inst_k}") - wgmma_config = WgmmaConfig.get(inst_m, inst_n, inst_k, a_dtype.short_name, b_dtype.short_name, d_dtype.short_name) - + wgmma_config = WgmmaConfig.get( + inst_m, inst_n, inst_k, a_dtype.short_name, b_dtype.short_name, d_dtype.short_name + ) + repeat_m = m // inst_m repeat_n = n // inst_n repeat_k = k // inst_k @@ -157,10 +160,13 @@ def emit(self, inst: WgmmaMmaSSInst) -> None: swizzle_mode=encode_swizzle_mode(b_canonical.swizzle_mode), ) d_offset = (i * repeat_n + j) * d_local_stride - self.append(wgmma_async(wgmma_config, - a_desc.encoded(), - d_register_addr + d_offset, - b_desc.encoded(), - trans_a=0, - trans_b=0, - )) \ No newline at end of file + self.append( + wgmma_async( + wgmma_config, + a_desc.encoded(), + d_register_addr + d_offset, + b_desc.encoded(), + trans_a=0, + trans_b=0, + ) + ) diff --git a/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py b/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py index 1a8b4e22..248c9717 100644 --- a/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py +++ b/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py @@ -12,20 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum -from typing import Optional, Sequence, no_type_check +from typing import no_type_check -from hidet.ir.dtypes import int32, uint8, uint32, uint64 -from hidet.ir.expr import Expr, as_expr +from hidet.ir.dtypes import uint8, uint32, uint64 +from hidet.ir.expr import Expr from hidet.ir.primitives.func import call_primitive_func -from hidet.ir.stmt import asm from hidet.utils import initialize from tilus.extensions.hidet.ir.primitives.utils import register_primitive_function_decorator + @initialize() def register_wgmma_instructions(): - from hidet.lang import attrs, meta + from hidet.lang import attrs from tilus.extensions.hidet.lang import script @@ -60,4 +59,4 @@ def wgmma_encode_smem_descriptor( func_name = "cuda_wgmma_encode_smem_descriptor" return call_primitive_func( func_name, [uint32(smem_addr), uint32(lbo), uint32(sbo), uint8(mbo), uint8(swizzle_mode)] - ) \ No newline at end of file + ) diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index b15d0f33..6b362f17 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -70,11 +70,11 @@ Tcgen05WaitInst, ) from tilus.ir.instructions.cuda.wgmma import ( - WgmmaFenceInst, WgmmaCommitGroupInst, - WgmmaWaitGroupInst, - WgmmaMmaSSInst, + WgmmaFenceInst, WgmmaMmaRSInst, + WgmmaMmaSSInst, + WgmmaWaitGroupInst, ) from tilus.ir.instructions.generic import ( AddInst, diff --git a/python/tilus/ir/instructions/cuda/wgmma.py b/python/tilus/ir/instructions/cuda/wgmma.py index efbd265e..9b6306ce 100644 --- a/python/tilus/ir/instructions/cuda/wgmma.py +++ b/python/tilus/ir/instructions/cuda/wgmma.py @@ -15,43 +15,47 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Sequence -from hidet.ir.expr import Constant, Expr +from hidet.ir.dtypes import bf16, f8e4m3, f8e5m2, f16, i8, tf32, u1, u8 +from hidet.ir.expr import Expr from hidet.ir.type import DataType -from hidet.ir.dtypes import f16, bf16, tf32, f8e4m3, f8e5m2, i8, u8, u1 - from tilus.ir.inst import Instruction -from tilus.ir.tensor import RegisterTensor, SharedTensor +from tilus.ir.tensor import RegisterTensor, SharedTensor from tilus.utils import gcd + @dataclass(frozen=True, eq=False) class WgmmaFenceInst(Instruction): @staticmethod def create() -> WgmmaFenceInst: return WgmmaFenceInst(output=None, inputs=()) + @dataclass(frozen=True, eq=False) class WgmmaCommitGroupInst(Instruction): @staticmethod def create() -> WgmmaCommitGroupInst: return WgmmaCommitGroupInst(output=None, inputs=()) + @dataclass(frozen=True, eq=False) class WgmmaWaitGroupInst(Instruction): n: Expr + @staticmethod def create(n: Expr) -> WgmmaWaitGroupInst: return WgmmaWaitGroupInst(output=None, inputs=(), n=n) + @dataclass(frozen=True, eq=False) class WgmmaMmaSSInst(Instruction): - @staticmethod - def get_inst_mnk(m: int, n: int, k: int, a_dtype: DataType, b_dtype: DataType, d_dtype: DataType) -> tuple[int, int, int]: + def get_inst_mnk( + m: int, n: int, k: int, a_dtype: DataType, b_dtype: DataType, d_dtype: DataType + ) -> tuple[int, int, int]: inst_m = 64 - inst_n = gcd(n, 256) # why? + inst_n = gcd(n, 256) # why? if a_dtype == b_dtype == f16: inst_k = 16 elif a_dtype == b_dtype == bf16: @@ -72,8 +76,9 @@ def get_inst_mnk(m: int, n: int, k: int, a_dtype: DataType, b_dtype: DataType, d def create(a: SharedTensor, b: SharedTensor, d: RegisterTensor) -> WgmmaMmaSSInst: return WgmmaMmaSSInst(output=None, inputs=(a, b, d)) + @dataclass(frozen=True, eq=False) class WgmmaMmaRSInst(Instruction): @staticmethod def create(a: RegisterTensor, b: SharedTensor, d: RegisterTensor) -> WgmmaMmaRSInst: - return WgmmaMmaRSInst(output=None, inputs=(a, b, d)) \ No newline at end of file + return WgmmaMmaRSInst(output=None, inputs=(a, b, d)) diff --git a/python/tilus/ir/layout/inference/inference_rules/__init__.py b/python/tilus/ir/layout/inference/inference_rules/__init__.py index 2b4fa6b1..c35543f7 100644 --- a/python/tilus/ir/layout/inference/inference_rules/__init__.py +++ b/python/tilus/ir/layout/inference/inference_rules/__init__.py @@ -30,6 +30,6 @@ transform, transform_shared, transpose, - where, wgmma, + where, ) diff --git a/python/tilus/ir/layout/inference/inference_rules/wgmma.py b/python/tilus/ir/layout/inference/inference_rules/wgmma.py index b3a65049..630c99f6 100644 --- a/python/tilus/ir/layout/inference/inference_rules/wgmma.py +++ b/python/tilus/ir/layout/inference/inference_rules/wgmma.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from tilus.ir.instructions.cuda.tcgen05 import Tcgen05MmaSSInst, Tcgen05MmaTSInst -from tilus.ir.layout import SharedLayout +from tilus.ir.instructions.cuda.wgmma import WgmmaMmaSSInst +from tilus.ir.layout import RegisterLayout, SharedLayout from tilus.ir.layout.cuda.tcgen05.smem import ( Tcgen05SwizzleMode, generate_canonical_layout, @@ -24,20 +24,23 @@ LayoutInferenceRule, register_rule, ) -from tilus.ir.tensor import SharedTensor, RegisterTensor -from tilus.ir.layout import RegisterLayout +from tilus.ir.layout.ops.register_ops import local +from tilus.ir.tensor import RegisterTensor, SharedTensor -from tilus.ir.instructions.cuda.wgmma import WgmmaMmaSSInst, WgmmaMmaRSInst -from tilus.ir.layout.ops.register_ops import spatial, local, column_spatial, column_local - - -def generate_wgmma_register_layout(m, n, inst_m, inst_n, inst_k) -> RegisterLayout: +def generate_wgmma_register_layout(m: int, n: int, inst_m: int, inst_n: int, inst_k: int) -> RegisterLayout: # See also: https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-register-fragment T = inst_k // 8 repeat_m = m // inst_m repeat_n = n // inst_n - return local(repeat_m, repeat_n).column_spatial(inst_m // 16, 1).column_local(2, inst_n // T // 4).spatial(8, 4).local(T) + return ( + local(repeat_m, repeat_n) + .column_spatial(inst_m // 16, 1) + .column_local(2, inst_n // T // 4) + .spatial(8, 4) + .local(T) + ) + @register_rule(WgmmaMmaSSInst) class WgmmaMmaSSRule(LayoutInferenceRule): @@ -95,7 +98,9 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT else: break if not d_tensor.has_layout(): - inst_m, inst_n, inst_k = WgmmaMmaSSInst.get_inst_mnk(m, n, k, a_tensor.dtype, b_tensor.dtype, d_tensor.dtype) + inst_m, inst_n, inst_k = WgmmaMmaSSInst.get_inst_mnk( + m, n, k, a_tensor.dtype, b_tensor.dtype, d_tensor.dtype + ) d_layout = generate_wgmma_register_layout(m, n, inst_m, inst_n, inst_k) ret[d_tensor] = d_layout - return ret \ No newline at end of file + return ret diff --git a/python/tilus/ir/layout/inference/order.py b/python/tilus/ir/layout/inference/order.py index 383c9a78..ef9eb628 100644 --- a/python/tilus/ir/layout/inference/order.py +++ b/python/tilus/ir/layout/inference/order.py @@ -40,10 +40,10 @@ from .inference_rules.tcgen05.ldst import Tcgen05LoadRule, Tcgen05StoreRule from .inference_rules.tcgen05.mma import Tcgen05MmaSSRule, Tcgen05MmaTSRule from .inference_rules.tcgen05.slice import Tcgen05SliceRule -from .inference_rules.wgmma import WgmmaMmaSSRule #, WgmmaMmaRSRule from .inference_rules.transform import SqueezeRule, UnsqueezeRule from .inference_rules.transform_shared import PermuteSharedRule, SharedSliceRule from .inference_rules.transpose import TransposeRule +from .inference_rules.wgmma import WgmmaMmaSSRule # , WgmmaMmaRSRule from .inference_rules.where import WhereRule inference_order: list[list[Type[LayoutInferenceRule]]] = [ @@ -52,7 +52,7 @@ # register layout rules [SliceRegisterRule, SliceAssignRule, AllocBarrierRule], [MmaDotRule], - [WgmmaMmaSSRule], #, WgmmaMmaRSRule], + [WgmmaMmaSSRule], # , WgmmaMmaRSRule], [Tcgen05LoadRule, Tcgen05StoreRule], [Tcgen05CopyRule], [BinaryRule, UnaryRule], diff --git a/python/tilus/ir/layout/inference/validation_rules/always_ok.py b/python/tilus/ir/layout/inference/validation_rules/always_ok.py index 4a360cae..6d2c1248 100644 --- a/python/tilus/ir/layout/inference/validation_rules/always_ok.py +++ b/python/tilus/ir/layout/inference/validation_rules/always_ok.py @@ -48,11 +48,12 @@ ) from tilus.ir.instructions.cuda.ldmatrix import LoadMatrixInst from tilus.ir.instructions.cuda.tcgen05 import Tcgen05CopyInst, Tcgen05LoadInst, Tcgen05MmaSSInst, Tcgen05StoreInst +from tilus.ir.instructions.cuda.wgmma import WgmmaMmaRSInst, WgmmaMmaSSInst from tilus.ir.layout.inference.rule import LayoutValidationRule, register_rule -from tilus.ir.instructions.cuda.wgmma import WgmmaMmaSSInst, WgmmaMmaRSInst -@register_rule(WgmmaMmaSSInst) # todo: should have its own rule -@register_rule(WgmmaMmaRSInst) # todo: should have its own rule + +@register_rule(WgmmaMmaSSInst) # todo: should have its own rule +@register_rule(WgmmaMmaRSInst) # todo: should have its own rule @register_rule(Tcgen05LoadInst) # todo: should have its own rule @register_rule(Tcgen05StoreInst) # todo: should have its own rule @register_rule(Tcgen05CopyInst) # todo: should have its own rule diff --git a/python/tilus/lang/instructions/__init__.py b/python/tilus/lang/instructions/__init__.py index 801e1ba9..58c7e6eb 100644 --- a/python/tilus/lang/instructions/__init__.py +++ b/python/tilus/lang/instructions/__init__.py @@ -31,4 +31,4 @@ class InstructionInterface(RootInstructionGroup): mbarrier = BarrierInstructionGroup() clc = ClusterLaunchControlInstructionGroup() cluster = BlockClusterInstructionGroup() - wgmma = WgmmaInstructionGroup() \ No newline at end of file + wgmma = WgmmaInstructionGroup() diff --git a/python/tilus/lang/instructions/wgmma.py b/python/tilus/lang/instructions/wgmma.py index 5f457893..632ebce0 100644 --- a/python/tilus/lang/instructions/wgmma.py +++ b/python/tilus/lang/instructions/wgmma.py @@ -12,11 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -from typing import Optional, Sequence, Union +from typing import Union from hidet.ir.expr import Expr -from hidet.ir.type import DataType from tilus.ir.inst import InstructionError from tilus.ir.tensor import RegisterTensor, SharedTensor @@ -52,4 +50,4 @@ def mma(self, a: SharedTensor | RegisterTensor, b: SharedTensor, d: RegisterTens raise InstructionError("mma requires 2D register tensors, got shape {}".format(d.shape)) self._builder.wgmma_mma_rs(a, b, d) else: - raise InstructionError("Invalid type of a: {}, expected SharedTensor or RegisterTensor".format(type(a))) \ No newline at end of file + raise InstructionError("Invalid type of a: {}, expected SharedTensor or RegisterTensor".format(type(a))) From 1356e8990adbe6c73dc3f23d0bae029d4aba7383 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Thu, 4 Dec 2025 14:08:14 -0800 Subject: [PATCH 11/16] remove print Signed-off-by: Qidong Su --- python/tilus/backends/emitters/cuda/wgmma.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tilus/backends/emitters/cuda/wgmma.py b/python/tilus/backends/emitters/cuda/wgmma.py index 48fa33f4..f0b4cfcc 100644 --- a/python/tilus/backends/emitters/cuda/wgmma.py +++ b/python/tilus/backends/emitters/cuda/wgmma.py @@ -113,7 +113,6 @@ def emit(self, inst: WgmmaMmaSSInst) -> None: d_dtype = d_tensor.dtype inst_m, inst_n, inst_k = inst.get_inst_mnk(m, n, k, a_dtype, b_dtype, d_dtype) - print(f"inst_m: {inst_m}, inst_n: {inst_n}, inst_k: {inst_k}") wgmma_config = WgmmaConfig.get( inst_m, inst_n, inst_k, a_dtype.short_name, b_dtype.short_name, d_dtype.short_name ) @@ -121,7 +120,6 @@ def emit(self, inst: WgmmaMmaSSInst) -> None: repeat_m = m // inst_m repeat_n = n // inst_n repeat_k = k // inst_k - print(f"repeat_m: {repeat_m}, repeat_n: {repeat_n}, repeat_k: {repeat_k}") a_canonical = canonicalize_shared_layout(a_tensor.layout, dtype=a_dtype) b_canonical = canonicalize_shared_layout(b_tensor.layout.transpose(), dtype=b_dtype) From 7ea0aba074ac127b5f77dfd31d90c9ca31f936a8 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Thu, 4 Dec 2025 16:14:08 -0800 Subject: [PATCH 12/16] add test Signed-off-by: Qidong Su --- tests/examples/test_examples.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 84bdfd5b..86a4c2e0 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -23,7 +23,7 @@ from typing import Optional import pytest -from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm100a +from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90, nvgpu_sm100a # Get the project root directory PROJECT_ROOT = Path(__file__).parent.parent.parent @@ -54,6 +54,9 @@ ("blackwell_matmul", "matmul_v3.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v4.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v5.py", nvgpu_sm100a), + # hopper matmul example (SM 9.0) + ("hopper_matmul", "matmul_tma.py", nvgpu_sm90), + ("hopper_matmul", "matmul_wgmma.py", nvgpu_sm90), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), # flash attention decode examples (SM 8.0+) From 097a6c3bfced50ba0317a0c43508893c6a0ac83b Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Thu, 4 Dec 2025 22:39:43 -0800 Subject: [PATCH 13/16] fix Signed-off-by: Qidong Su --- .../{matmul_tma.py => matmul_v0.py} | 16 ++++++++-------- .../{matmul_wgmma.py => matmul_v1.py} | 0 python/tilus/ir/layout/inference/order.py | 4 ++-- python/tilus/lang/instructions/wgmma.py | 14 ++------------ 4 files changed, 12 insertions(+), 22 deletions(-) rename examples/hopper_matmul/{matmul_tma.py => matmul_v0.py} (90%) rename examples/hopper_matmul/{matmul_wgmma.py => matmul_v1.py} (100%) diff --git a/examples/hopper_matmul/matmul_tma.py b/examples/hopper_matmul/matmul_v0.py similarity index 90% rename from examples/hopper_matmul/matmul_tma.py rename to examples/hopper_matmul/matmul_v0.py index 106de870..18f60e87 100644 --- a/examples/hopper_matmul/matmul_tma.py +++ b/examples/hopper_matmul/matmul_v0.py @@ -9,12 +9,12 @@ from tilus import float16, float32, int32, uint32 from tilus.utils import benchmark_func, cdiv -tilus.option.cache_dir("./cache") -tilus.option.debug.dump_ir(True) -@tilus.autotune("block_m, block_n", [(128, 128), (128, 256), (128, 64)]) -@tilus.autotune("block_k", [16, 32, 16]) +@tilus.autotune( + "block_m, block_n", [(64, 128), (128, 128), (128, 256), (256, 128), (256, 256)] +) +@tilus.autotune("block_k", [16, 32, 64]) class MatmulTMA(tilus.Script): def __init__( self, @@ -69,9 +69,9 @@ def __call__( # synchronize threads in the block to ensure data is available in shared memory self.sync() - # a = self.load_shared(sa) - # b = self.load_shared(sb) - self.dot(sa, sb.transpose(), acc, out=acc) + a = self.load_shared(sa) + b = self.load_shared(sb) + self.dot(a, b.transpose(), acc, out=acc) self.sync() phase ^= 1 @@ -105,7 +105,7 @@ def main(): # benchmark for name, func in [ - ("torch", lambda: torch.matmul(a, b, out=c_expect)), + ("torch", lambda: torch.matmul(a, b.T, out=c_expect)), ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), ]: latency = benchmark_func(func, warmup=5, repeat=20) diff --git a/examples/hopper_matmul/matmul_wgmma.py b/examples/hopper_matmul/matmul_v1.py similarity index 100% rename from examples/hopper_matmul/matmul_wgmma.py rename to examples/hopper_matmul/matmul_v1.py diff --git a/python/tilus/ir/layout/inference/order.py b/python/tilus/ir/layout/inference/order.py index ef9eb628..6a793247 100644 --- a/python/tilus/ir/layout/inference/order.py +++ b/python/tilus/ir/layout/inference/order.py @@ -43,7 +43,7 @@ from .inference_rules.transform import SqueezeRule, UnsqueezeRule from .inference_rules.transform_shared import PermuteSharedRule, SharedSliceRule from .inference_rules.transpose import TransposeRule -from .inference_rules.wgmma import WgmmaMmaSSRule # , WgmmaMmaRSRule +from .inference_rules.wgmma import WgmmaMmaSSRule from .inference_rules.where import WhereRule inference_order: list[list[Type[LayoutInferenceRule]]] = [ @@ -52,7 +52,7 @@ # register layout rules [SliceRegisterRule, SliceAssignRule, AllocBarrierRule], [MmaDotRule], - [WgmmaMmaSSRule], # , WgmmaMmaRSRule], + [WgmmaMmaSSRule], [Tcgen05LoadRule, Tcgen05StoreRule], [Tcgen05CopyRule], [BinaryRule, UnaryRule], diff --git a/python/tilus/lang/instructions/wgmma.py b/python/tilus/lang/instructions/wgmma.py index 632ebce0..ebf7fa98 100644 --- a/python/tilus/lang/instructions/wgmma.py +++ b/python/tilus/lang/instructions/wgmma.py @@ -33,21 +33,11 @@ def wait_group(self, n: Union[Expr, int]) -> None: self._builder.wgmma_wait_group(n) def mma(self, a: SharedTensor | RegisterTensor, b: SharedTensor, d: RegisterTensor) -> None: + if any(len(tensor.shape) != 2 for tensor in (a, b, d)): + raise InstructionError("mma requires 2D tensors, got shapes {}".format([tensor.shape for tensor in (a, b, d)])) if isinstance(a, SharedTensor): - if len(a.shape) != 2: - raise InstructionError("mma requires 2D shared tensors, got shape {}".format(a.shape)) - if len(b.shape) != 2: - raise InstructionError("mma requires 2D shared tensors, got shape {}".format(b.shape)) - if len(d.shape) != 2: - raise InstructionError("mma requires 2D register tensors, got shape {}".format(d.shape)) self._builder.wgmma_mma_ss(a, b, d) elif isinstance(a, RegisterTensor): - if len(a.shape) != 2: - raise InstructionError("mma requires 2D register tensors, got shape {}".format(a.shape)) - if len(b.shape) != 2: - raise InstructionError("mma requires 2D shared tensors, got shape {}".format(b.shape)) - if len(d.shape) != 2: - raise InstructionError("mma requires 2D register tensors, got shape {}".format(d.shape)) self._builder.wgmma_mma_rs(a, b, d) else: raise InstructionError("Invalid type of a: {}, expected SharedTensor or RegisterTensor".format(type(a))) From 2c4e787a384cfba9002873fbfdc6a65beffe2496 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Fri, 5 Dec 2025 09:05:41 -0800 Subject: [PATCH 14/16] format Signed-off-by: Qidong Su --- examples/hopper_matmul/matmul_v0.py | 1 - python/tilus/lang/instructions/wgmma.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/hopper_matmul/matmul_v0.py b/examples/hopper_matmul/matmul_v0.py index 18f60e87..891f62b9 100644 --- a/examples/hopper_matmul/matmul_v0.py +++ b/examples/hopper_matmul/matmul_v0.py @@ -10,7 +10,6 @@ from tilus.utils import benchmark_func, cdiv - @tilus.autotune( "block_m, block_n", [(64, 128), (128, 128), (128, 256), (256, 128), (256, 256)] ) diff --git a/python/tilus/lang/instructions/wgmma.py b/python/tilus/lang/instructions/wgmma.py index ebf7fa98..bd414db7 100644 --- a/python/tilus/lang/instructions/wgmma.py +++ b/python/tilus/lang/instructions/wgmma.py @@ -34,7 +34,9 @@ def wait_group(self, n: Union[Expr, int]) -> None: def mma(self, a: SharedTensor | RegisterTensor, b: SharedTensor, d: RegisterTensor) -> None: if any(len(tensor.shape) != 2 for tensor in (a, b, d)): - raise InstructionError("mma requires 2D tensors, got shapes {}".format([tensor.shape for tensor in (a, b, d)])) + raise InstructionError( + "mma requires 2D tensors, got shapes {}".format([tensor.shape for tensor in (a, b, d)]) + ) if isinstance(a, SharedTensor): self._builder.wgmma_mma_ss(a, b, d) elif isinstance(a, RegisterTensor): From e5b11ec6330989e5e921f5d315699494c86c8e90 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Fri, 5 Dec 2025 09:54:38 -0800 Subject: [PATCH 15/16] upd Signed-off-by: Qidong Su --- python/tilus/backends/emitters/cuda/wgmma.py | 32 +++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/python/tilus/backends/emitters/cuda/wgmma.py b/python/tilus/backends/emitters/cuda/wgmma.py index f0b4cfcc..4435543f 100644 --- a/python/tilus/backends/emitters/cuda/wgmma.py +++ b/python/tilus/backends/emitters/cuda/wgmma.py @@ -22,6 +22,7 @@ from tilus.backends.emitter import BaseInstEmitter, register_emitter from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import Tcgen05SwizzleMode from tilus.extensions.hidet.ir.primitives.cuda.wgmma import wgmma_encode_smem_descriptor +from tilus.ir.inst import Instruction from tilus.ir.instructions.cuda.wgmma import ( WgmmaCommitGroupInst, WgmmaFenceInst, @@ -72,27 +73,42 @@ def decode(encoded: int) -> SharedMatrixDescriptor: ) +class WgmmaBaseEmitter(BaseInstEmitter): + def check_warp_group(self) -> None: + begin = self.current_thread_group_begin + end = self.current_thread_group_end + if begin % 128 != 0 or end - begin != 128: + raise ValueError("The number of threads in the current thread group must be 128") + + def emit(self, inst: Instruction) -> None: + self.check_warp_group() + self.emit_wgmma(inst) + + def emit_wgmma(self, inst: Instruction) -> None: + raise NotImplementedError("Subclasses must implement this method") + + @register_emitter(WgmmaFenceInst, target=nvgpu_sm90a) -class WgmmaFenceEmitter(BaseInstEmitter): - def emit(self, inst: WgmmaFenceInst) -> None: +class WgmmaFenceEmitter(WgmmaBaseEmitter): + def emit_wgmma(self, inst: WgmmaFenceInst) -> None: self.append(wgmma_fence()) @register_emitter(WgmmaCommitGroupInst, target=nvgpu_sm90a) -class WgmmaCommitGroupEmitter(BaseInstEmitter): - def emit(self, inst: WgmmaCommitGroupInst) -> None: +class WgmmaCommitGroupEmitter(WgmmaBaseEmitter): + def emit_wgmma(self, inst: WgmmaCommitGroupInst) -> None: self.append(wgmma_commit_group()) @register_emitter(WgmmaWaitGroupInst, target=nvgpu_sm90a) -class WgmmaWaitGroupEmitter(BaseInstEmitter): - def emit(self, inst: WgmmaWaitGroupInst) -> None: +class WgmmaWaitGroupEmitter(WgmmaBaseEmitter): + def emit_wgmma(self, inst: WgmmaWaitGroupInst) -> None: self.append(wgmma_wait_group(inst.n)) @register_emitter(WgmmaMmaSSInst, target=nvgpu_sm90a) -class WgmmaMmaSSEmitter(BaseInstEmitter): - def emit(self, inst: WgmmaMmaSSInst) -> None: +class WgmmaMmaSSEmitter(WgmmaBaseEmitter): + def emit_wgmma(self, inst: WgmmaMmaSSInst) -> None: a, b, d = inst.inputs a_tensor: SharedTensor = a.as_shared_tensor() b_tensor: SharedTensor = b.as_shared_tensor() From 5c01c3e3f0d33daeaecc93e162d279d166d45a4a Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Fri, 5 Dec 2025 10:55:20 -0800 Subject: [PATCH 16/16] fix Signed-off-by: Qidong Su --- tests/examples/test_examples.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 86a4c2e0..79eb38c3 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -55,8 +55,8 @@ ("blackwell_matmul", "matmul_v4.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v5.py", nvgpu_sm100a), # hopper matmul example (SM 9.0) - ("hopper_matmul", "matmul_tma.py", nvgpu_sm90), - ("hopper_matmul", "matmul_wgmma.py", nvgpu_sm90), + ("hopper_matmul", "matmul_v0.py", nvgpu_sm90), + ("hopper_matmul", "matmul_v1.py", nvgpu_sm90), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), # flash attention decode examples (SM 8.0+)