diff --git a/examples/hopper_matmul/matmul_v0.py b/examples/hopper_matmul/matmul_v0.py new file mode 100644 index 00000000..891f62b9 --- /dev/null +++ b/examples/hopper_matmul/matmul_v0.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import math + +import pandas +import tilus +import torch +from tilus import float16, float32, int32, uint32 +from tilus.utils import benchmark_func, cdiv + + +@tilus.autotune( + "block_m, block_n", [(64, 128), (128, 128), (128, 256), (256, 128), (256, 256)] +) +@tilus.autotune("block_k", [16, 32, 64]) +class MatmulTMA(tilus.Script): + def __init__( + self, + block_m, + block_n, + block_k, + ): + super().__init__() + self.block_m = block_m + self.block_n = block_n + self.block_k = block_k + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + self.attrs.blocks = [ + cdiv(m_size, self.block_m), + cdiv(n_size, self.block_n), + ] + self.attrs.warps = 4 + + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + offset_m: int32 = block_m * self.blockIdx.x + offset_n: int32 = block_n * self.blockIdx.y + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) + sa = self.shared_tensor(dtype=float16, shape=[block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[block_n, block_k]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + tma_barrier = self.mbarrier.alloc(count=[2]) + phase: uint32 = 0 + + for offset_k in range(0, k_size, block_k): + # issue asynchronous copy instructions to load tiles of A and B + with self.single_thread(): + self.tma.global_to_shared( + src=ga, dst=sa, offsets=[offset_m, offset_k], mbarrier=tma_barrier + ) + self.tma.global_to_shared( + src=gb, dst=sb, offsets=[offset_n, offset_k], mbarrier=tma_barrier + ) + self.mbarrier.wait(tma_barrier, phase=phase) + + # synchronize threads in the block to ensure data is available in shared memory + self.sync() + + a = self.load_shared(sa) + b = self.load_shared(sb) + self.dot(a, b.transpose(), acc, out=acc) + self.sync() + phase ^= 1 + + self.free_shared(sa) + self.free_shared(sb) + + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + +def main(): + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + workloads = [ + [4096, 4096, 4096], + ] + + rows = [] + for m, n, k in workloads: + matmul = MatmulTMA() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b.T + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + # check correctness + torch.testing.assert_close(c_expect, c_actual) + + # benchmark + for name, func in [ + ("torch", lambda: torch.matmul(a, b.T, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + + +# %% + +if __name__ == "__main__": + main() diff --git a/examples/hopper_matmul/matmul_v1.py 
b/examples/hopper_matmul/matmul_v1.py new file mode 100644 index 00000000..9e2d9fd9 --- /dev/null +++ b/examples/hopper_matmul/matmul_v1.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import math + +import pandas +import tilus +import torch +from tilus import float16, float32, int32, uint32 +from tilus.utils import benchmark_func, cdiv + +tilus.option.cache_dir("./cache") +tilus.option.debug.dump_ir(True) +torch.set_printoptions(precision=3, sci_mode=False, linewidth=160) + + +@tilus.autotune( + "block_m, block_n", [(64, 128), (128, 128), (128, 256), (256, 128), (256, 256)] +) +@tilus.autotune("block_k", [16, 32, 64]) +class MatmulWGMMA(tilus.Script): + def __init__( + self, + block_m, + block_n, + block_k, + ): + super().__init__() + self.block_m = block_m + self.block_n = block_n + self.block_k = block_k + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + self.attrs.blocks = [ + cdiv(m_size, self.block_m), + cdiv(n_size, self.block_n), + ] + self.attrs.warps = 4 + + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + offset_m: int32 = block_m * self.blockIdx.x + offset_n: int32 = block_n * self.blockIdx.y + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size]) + sa = self.shared_tensor(dtype=float16, shape=[block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[block_n, block_k]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + tma_barrier = self.mbarrier.alloc(count=[2]) + phase: uint32 = 0 + + for offset_k in range(0, k_size, block_k): + # issue asynchronous copy instructions to load tiles of A and B + with self.single_thread(): + self.tma.global_to_shared( + src=ga, dst=sa, offsets=[offset_m, offset_k], mbarrier=tma_barrier + ) + self.tma.global_to_shared( + src=gb, dst=sb, offsets=[offset_n, offset_k], mbarrier=tma_barrier + ) + self.mbarrier.wait(tma_barrier, phase=phase) + + # synchronize threads in the block to ensure data is available in shared memory + self.sync() + + self.wgmma.fence() + self.wgmma.mma(sa, sb.transpose(), acc) + self.wgmma.commit_group() + self.wgmma.wait_group(0) + self.sync() + phase ^= 1 + + self.free_shared(sa) + self.free_shared(sb) + + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + +def main(): + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + workloads = [ + [4096, 4096, 4096], + # [128, 48, 16], + ] + + rows = [] + for m, n, k in workloads: + matmul = MatmulWGMMA() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b.T + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + # check correctness + torch.testing.assert_close(c_expect, c_actual) + + # benchmark + for name, func in [ + ("torch", lambda: torch.matmul(a, b.T, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, 
columns=headers) + print(df) + + +# %% + +if __name__ == "__main__": + main() diff --git a/python/tilus/backends/emitters/cuda/__init__.py b/python/tilus/backends/emitters/cuda/__init__.py index d5af74f1..df3ed155 100644 --- a/python/tilus/backends/emitters/cuda/__init__.py +++ b/python/tilus/backends/emitters/cuda/__init__.py @@ -24,4 +24,5 @@ semaphore, simt_dot, tcgen05, + wgmma, ) diff --git a/python/tilus/backends/emitters/cuda/wgmma.py b/python/tilus/backends/emitters/cuda/wgmma.py new file mode 100644 index 00000000..4435543f --- /dev/null +++ b/python/tilus/backends/emitters/cuda/wgmma.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass + +from hidet.ir.expr import Expr, Var +from hidet.ir.primitives.cuda.wgmma import WgmmaConfig, wgmma_async, wgmma_commit_group, wgmma_fence, wgmma_wait_group + +from tilus.backends.emitter import BaseInstEmitter, register_emitter +from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import Tcgen05SwizzleMode +from tilus.extensions.hidet.ir.primitives.cuda.wgmma import wgmma_encode_smem_descriptor +from tilus.ir.inst import Instruction +from tilus.ir.instructions.cuda.wgmma import ( + WgmmaCommitGroupInst, + WgmmaFenceInst, + WgmmaMmaSSInst, + WgmmaWaitGroupInst, +) +from tilus.ir.layout.cuda.tcgen05.smem import canonicalize_shared_layout +from tilus.ir.layout.utils.cute import CuteLayout +from tilus.ir.tensor import RegisterTensor, SharedTensor +from tilus.target import nvgpu_sm90a + + +def encode_swizzle_mode(swizzle_mode: Tcgen05SwizzleMode) -> int: + swizzle_mode_map = { + Tcgen05SwizzleMode.NO_SWIZZLE: 0, + Tcgen05SwizzleMode.B32_SWIZZLE: 3, + Tcgen05SwizzleMode.B64_SWIZZLE: 2, + Tcgen05SwizzleMode.B128_SWIZZLE: 1, + } + return swizzle_mode_map[swizzle_mode] + + +@dataclass +class SharedMatrixDescriptor: + addr: Expr | int + lbo: int + sbo: int + base_offset: int + swizzle_mode: int + + def encoded(self) -> Expr: + return wgmma_encode_smem_descriptor( + self.addr >> 4, + self.lbo >> 4, + self.sbo >> 4, + self.base_offset, + self.swizzle_mode, + ) + + @staticmethod + def decode(encoded: int) -> SharedMatrixDescriptor: + return SharedMatrixDescriptor( + addr=(encoded & 0x3FFF) << 4, + lbo=((encoded >> 16) & 0x3FFF) << 4, + sbo=((encoded >> 32) & 0x3FFF) << 4, + base_offset=(encoded >> 49) & 0x7, + swizzle_mode=(encoded >> 62) & 0x3, + ) + + +class WgmmaBaseEmitter(BaseInstEmitter): + def check_warp_group(self) -> None: + begin = self.current_thread_group_begin + end = self.current_thread_group_end + if begin % 128 != 0 or end - begin != 128: + raise ValueError("The number of threads in the current thread group must be 128") + + def emit(self, inst: Instruction) -> None: + self.check_warp_group() + self.emit_wgmma(inst) + + def emit_wgmma(self, inst: Instruction) -> None: + raise NotImplementedError("Subclasses must 
implement this method") + + +@register_emitter(WgmmaFenceInst, target=nvgpu_sm90a) +class WgmmaFenceEmitter(WgmmaBaseEmitter): + def emit_wgmma(self, inst: WgmmaFenceInst) -> None: + self.append(wgmma_fence()) + + +@register_emitter(WgmmaCommitGroupInst, target=nvgpu_sm90a) +class WgmmaCommitGroupEmitter(WgmmaBaseEmitter): + def emit_wgmma(self, inst: WgmmaCommitGroupInst) -> None: + self.append(wgmma_commit_group()) + + +@register_emitter(WgmmaWaitGroupInst, target=nvgpu_sm90a) +class WgmmaWaitGroupEmitter(WgmmaBaseEmitter): + def emit_wgmma(self, inst: WgmmaWaitGroupInst) -> None: + self.append(wgmma_wait_group(inst.n)) + + +@register_emitter(WgmmaMmaSSInst, target=nvgpu_sm90a) +class WgmmaMmaSSEmitter(WgmmaBaseEmitter): + def emit_wgmma(self, inst: WgmmaMmaSSInst) -> None: + a, b, d = inst.inputs + a_tensor: SharedTensor = a.as_shared_tensor() + b_tensor: SharedTensor = b.as_shared_tensor() + d_tensor: RegisterTensor = d.as_register_tensor() + + a_shape = a_tensor.shape + b_shape = b_tensor.shape + d_shape = d_tensor.shape + + if len(a_shape) != 2 or len(b_shape) != 2 or len(d_shape) != 2: + raise ValueError(f"MMA requires 2D tensors, but got shapes: a={a_shape}, b={b_shape}, d={d_shape}") + if a_shape[1] != b_shape[0] or a_shape[0] != d_shape[0] or b_shape[1] != d_shape[1]: + raise ValueError(f"Incompatible shapes for MMA: a={a_shape}, b={b_shape}, d={d_shape}") + m, n, k = d_shape[0], d_shape[1], a_shape[1] + + a_dtype = a_tensor.dtype + b_dtype = b_tensor.dtype + d_dtype = d_tensor.dtype + + inst_m, inst_n, inst_k = inst.get_inst_mnk(m, n, k, a_dtype, b_dtype, d_dtype) + wgmma_config = WgmmaConfig.get( + inst_m, inst_n, inst_k, a_dtype.short_name, b_dtype.short_name, d_dtype.short_name + ) + + repeat_m = m // inst_m + repeat_n = n // inst_n + repeat_k = k // inst_k + + a_canonical = canonicalize_shared_layout(a_tensor.layout, dtype=a_dtype) + b_canonical = canonicalize_shared_layout(b_tensor.layout.transpose(), dtype=b_dtype) + + if a_canonical is None: + raise ValueError(f"Cannot canonicalize the layout of a tensor: {a_tensor.layout}.") + if b_canonical is None: + raise ValueError(f"Cannot canonicalize the layout of b tensor: {b_tensor.layout}.") + + a_cute_layout: CuteLayout = a_canonical.swizzled_cute_layout.layout + b_cute_layout: CuteLayout = b_canonical.swizzled_cute_layout.layout + + a_shared_addr: Var = self.shared_tensor_shared_space_addr[a_tensor] + b_shared_addr: Var = self.shared_tensor_shared_space_addr[b_tensor] + d_register_addr: Var = ~(self.tensor2var[d_tensor][0]) + + d_local_stride = d_tensor.layout.local_size // repeat_m // repeat_n + + for k in range(repeat_k): + for i in range(repeat_m): + for j in range(repeat_n): + a_offset = a_cute_layout(i * inst_m, k * inst_k) + b_offset = b_cute_layout(j * inst_n, k * inst_k) + a_desc = SharedMatrixDescriptor( + addr=a_shared_addr + a_offset * a_tensor.dtype.nbytes, + lbo=a_canonical.LBO * a_tensor.dtype.nbytes, + sbo=a_canonical.SBO * a_tensor.dtype.nbytes, + base_offset=0, + swizzle_mode=encode_swizzle_mode(a_canonical.swizzle_mode), + ) + b_desc = SharedMatrixDescriptor( + addr=b_shared_addr + b_offset * b_tensor.dtype.nbytes, + lbo=b_canonical.LBO * b_tensor.dtype.nbytes, + sbo=b_canonical.SBO * b_tensor.dtype.nbytes, + base_offset=0, + swizzle_mode=encode_swizzle_mode(b_canonical.swizzle_mode), + ) + d_offset = (i * repeat_n + j) * d_local_stride + self.append( + wgmma_async( + wgmma_config, + a_desc.encoded(), + d_register_addr + d_offset, + b_desc.encoded(), + trans_a=0, + trans_b=0, + ) + ) diff --git 
a/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py b/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py new file mode 100644 index 00000000..248c9717 --- /dev/null +++ b/python/tilus/extensions/hidet/ir/primitives/cuda/wgmma.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import no_type_check + +from hidet.ir.dtypes import uint8, uint32, uint64 +from hidet.ir.expr import Expr +from hidet.ir.primitives.func import call_primitive_func +from hidet.utils import initialize + +from tilus.extensions.hidet.ir.primitives.utils import register_primitive_function_decorator + + +@initialize() +def register_wgmma_instructions(): + from hidet.lang import attrs + + from tilus.extensions.hidet.lang import script + + @register_primitive_function_decorator + @no_type_check + @script + def wgmma_encode_smem_descriptor( + smem_addr: uint32, # 14 bits + lbo: uint32, # 14 bits + sbo: uint32, # 14 bits + mbo: uint8, # 3 bits + swizzle_mode: uint8, # 2 bits + ) -> uint64: + attrs.func_name = "cuda_wgmma_encode_smem_descriptor" + attrs.func_kind = "cuda_internal" + desc: uint64 = uint64(0) + desc = desc | uint64(lbo & uint32(0x3FFF)) << 16 + desc = desc | uint64(sbo & uint32(0x3FFF)) << 32 + desc = desc | uint64(mbo & uint8(0b111)) << 49 + desc = desc | uint64(swizzle_mode & uint8(0b11)) << 62 + desc = desc | uint64(smem_addr & uint32(0x3FFF)) + return desc + + +def wgmma_encode_smem_descriptor( + smem_addr: Expr | int, + lbo: Expr | int, + sbo: Expr | int, + mbo: Expr | int, + swizzle_mode: Expr | int, +) -> Expr: + func_name = "cuda_wgmma_encode_smem_descriptor" + return call_primitive_func( + func_name, [uint32(smem_addr), uint32(lbo), uint32(sbo), uint8(mbo), uint8(swizzle_mode)] + ) diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index 9bd24804..6b362f17 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -69,6 +69,13 @@ Tcgen05ViewInst, Tcgen05WaitInst, ) +from tilus.ir.instructions.cuda.wgmma import ( + WgmmaCommitGroupInst, + WgmmaFenceInst, + WgmmaMmaRSInst, + WgmmaMmaSSInst, + WgmmaWaitGroupInst, +) from tilus.ir.instructions.generic import ( AddInst, AllocateGlobalInst, @@ -1324,6 +1331,29 @@ def tcgen05_mma_ts(self, a: TMemoryTensor, b: SharedTensor, d: TMemoryTensor) -> inst = Tcgen05MmaTSInst.create(a=a, b=b, d=d) self.append(inst) + # wgmma + def wgmma_fence(self) -> None: + inst = WgmmaFenceInst.create() + self.append(inst) + + def wgmma_commit_group(self) -> None: + inst = WgmmaCommitGroupInst.create() + self.append(inst) + + def wgmma_wait_group(self, n: Union[Expr, int]) -> None: + if isinstance(n, int): + n = as_expr(n) + inst = WgmmaWaitGroupInst.create(n=n) + self.append(inst) + + def wgmma_mma_ss(self, a: SharedTensor, b: SharedTensor, d: RegisterTensor) -> None: + inst = WgmmaMmaSSInst.create(a=a, b=b, d=d) + 
self.append(inst) + + def wgmma_mma_rs(self, a: RegisterTensor, b: SharedTensor, d: RegisterTensor) -> None: + inst = WgmmaMmaRSInst.create(a=a, b=b, d=d) + self.append(inst) + # annotations def annotate_layout(self, tensor: RegisterTensor | SharedTensor, layout: RegisterLayout | SharedLayout) -> None: inst = AnnotateLayoutInst.create(tensor=tensor, layout=layout) diff --git a/python/tilus/ir/instructions/cuda/wgmma.py b/python/tilus/ir/instructions/cuda/wgmma.py new file mode 100644 index 00000000..9b6306ce --- /dev/null +++ b/python/tilus/ir/instructions/cuda/wgmma.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass + +from hidet.ir.dtypes import bf16, f8e4m3, f8e5m2, f16, i8, tf32, u1, u8 +from hidet.ir.expr import Expr +from hidet.ir.type import DataType + +from tilus.ir.inst import Instruction +from tilus.ir.tensor import RegisterTensor, SharedTensor +from tilus.utils import gcd + + +@dataclass(frozen=True, eq=False) +class WgmmaFenceInst(Instruction): + @staticmethod + def create() -> WgmmaFenceInst: + return WgmmaFenceInst(output=None, inputs=()) + + +@dataclass(frozen=True, eq=False) +class WgmmaCommitGroupInst(Instruction): + @staticmethod + def create() -> WgmmaCommitGroupInst: + return WgmmaCommitGroupInst(output=None, inputs=()) + + +@dataclass(frozen=True, eq=False) +class WgmmaWaitGroupInst(Instruction): + n: Expr + + @staticmethod + def create(n: Expr) -> WgmmaWaitGroupInst: + return WgmmaWaitGroupInst(output=None, inputs=(), n=n) + + +@dataclass(frozen=True, eq=False) +class WgmmaMmaSSInst(Instruction): + @staticmethod + def get_inst_mnk( + m: int, n: int, k: int, a_dtype: DataType, b_dtype: DataType, d_dtype: DataType + ) -> tuple[int, int, int]: + inst_m = 64 + inst_n = gcd(n, 256) # wgmma supports instruction tiles with n up to 256; use the largest value that divides both n and 256
+ if a_dtype == b_dtype == f16: + inst_k = 16 + elif a_dtype == b_dtype == bf16: + inst_k = 16 + elif a_dtype == b_dtype == tf32: + inst_k = 8 + elif a_dtype in (f8e4m3, f8e5m2) and b_dtype in (f8e4m3, f8e5m2): + inst_k = 32 + elif a_dtype in (i8, u8) and b_dtype in (i8, u8): + inst_k = 32 + elif a_dtype == b_dtype == u1: + inst_k = 256 + else: + raise ValueError(f"Unsupported data types for MMA: a_dtype={a_dtype}, b_dtype={b_dtype}") + return inst_m, inst_n, inst_k + + @staticmethod + def create(a: SharedTensor, b: SharedTensor, d: RegisterTensor) -> WgmmaMmaSSInst: + return WgmmaMmaSSInst(output=None, inputs=(a, b, d)) + + +@dataclass(frozen=True, eq=False) +class WgmmaMmaRSInst(Instruction): + @staticmethod + def create(a: RegisterTensor, b: SharedTensor, d: RegisterTensor) -> WgmmaMmaRSInst: + return WgmmaMmaRSInst(output=None, inputs=(a, b, d)) diff --git a/python/tilus/ir/layout/inference/inference_rules/__init__.py b/python/tilus/ir/layout/inference/inference_rules/__init__.py index dfac471a..c35543f7 100644 --- a/python/tilus/ir/layout/inference/inference_rules/__init__.py +++ b/python/tilus/ir/layout/inference/inference_rules/__init__.py @@ -30,5 +30,6 @@ transform, transform_shared, transpose, + wgmma, where, ) diff --git a/python/tilus/ir/layout/inference/inference_rules/wgmma.py b/python/tilus/ir/layout/inference/inference_rules/wgmma.py new file mode 100644 index 00000000..630c99f6 --- /dev/null +++ b/python/tilus/ir/layout/inference/inference_rules/wgmma.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
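+# Layout inference for the shared-shared (SS) wgmma instruction: A and B receive canonical
+# K-major shared layouts, trying swizzle modes from 128B down to no swizzle and keeping the
+# first one that fits the tile; the D accumulator receives the wgmma register fragment layout
+# constructed below from the PTX register-fragment specification.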
+from tilus.ir.instructions.cuda.wgmma import WgmmaMmaSSInst +from tilus.ir.layout import RegisterLayout, SharedLayout +from tilus.ir.layout.cuda.tcgen05.smem import ( + Tcgen05SwizzleMode, + generate_canonical_layout, +) +from tilus.ir.layout.inference.rule import ( + LayoutInferenceContext, + LayoutInferenceError, + LayoutInferenceRule, + register_rule, +) +from tilus.ir.layout.ops.register_ops import local +from tilus.ir.tensor import RegisterTensor, SharedTensor + + +def generate_wgmma_register_layout(m: int, n: int, inst_m: int, inst_n: int, inst_k: int) -> RegisterLayout: + # See also: https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-register-fragment + T = inst_k // 8 + repeat_m = m // inst_m + repeat_n = n // inst_n + return ( + local(repeat_m, repeat_n) + .column_spatial(inst_m // 16, 1) + .column_local(2, inst_n // T // 4) + .spatial(8, 4) + .local(T) + ) + + +@register_rule(WgmmaMmaSSInst) +class WgmmaMmaSSRule(LayoutInferenceRule): + @staticmethod + def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedTensor, SharedLayout]: + a_tensor: SharedTensor = inst.inputs[0].as_shared_tensor() + b_tensor: SharedTensor = inst.inputs[1].as_shared_tensor() + d_tensor: RegisterTensor = inst.inputs[2].as_register_tensor() + + a_shape = a_tensor.shape + b_shape = b_tensor.shape + d_shape = d_tensor.shape + + if not len(a_shape) == len(b_shape) == len(d_shape) == 2: + raise LayoutInferenceError( + f"A, B, and D must have 2 dimensions, but got {len(a_shape)}, {len(b_shape)}, and {len(d_shape)}." + ) + if a_shape[1] != b_shape[0] or a_shape[0] != d_shape[0] or b_shape[1] != d_shape[1]: + raise LayoutInferenceError( + f"A, B, and D must have compatible shapes, but got {a_tensor.shape}, {b_tensor.shape}, and {d_tensor.shape}." 
+ ) + m, n, k = d_shape[0], d_shape[1], a_shape[1] + + ret = {} + if not a_tensor.has_layout(): + for swizzle_mode in [ + Tcgen05SwizzleMode.B128_SWIZZLE, + Tcgen05SwizzleMode.B64_SWIZZLE, + Tcgen05SwizzleMode.B32_SWIZZLE, + Tcgen05SwizzleMode.NO_SWIZZLE, + ]: + try: + a_layout_canonical = generate_canonical_layout( + shape=(m, k), dtype=a_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode + ) + ret[a_tensor] = a_layout_canonical.as_shared_layout().simplify() + except ValueError: + continue + else: + break + if not b_tensor.has_layout(): + for swizzle_mode in [ + Tcgen05SwizzleMode.B128_SWIZZLE, + Tcgen05SwizzleMode.B64_SWIZZLE, + Tcgen05SwizzleMode.B32_SWIZZLE, + Tcgen05SwizzleMode.NO_SWIZZLE, + ]: + try: + b_layout_canonical = generate_canonical_layout( + shape=(n, k), dtype=b_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode + ) + ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]).simplify() + except ValueError: + continue + else: + break + if not d_tensor.has_layout(): + inst_m, inst_n, inst_k = WgmmaMmaSSInst.get_inst_mnk( + m, n, k, a_tensor.dtype, b_tensor.dtype, d_tensor.dtype + ) + d_layout = generate_wgmma_register_layout(m, n, inst_m, inst_n, inst_k) + ret[d_tensor] = d_layout + return ret diff --git a/python/tilus/ir/layout/inference/order.py b/python/tilus/ir/layout/inference/order.py index bfcb5ef3..6a793247 100644 --- a/python/tilus/ir/layout/inference/order.py +++ b/python/tilus/ir/layout/inference/order.py @@ -43,6 +43,7 @@ from .inference_rules.transform import SqueezeRule, UnsqueezeRule from .inference_rules.transform_shared import PermuteSharedRule, SharedSliceRule from .inference_rules.transpose import TransposeRule +from .inference_rules.wgmma import WgmmaMmaSSRule from .inference_rules.where import WhereRule inference_order: list[list[Type[LayoutInferenceRule]]] = [ @@ -51,6 +52,7 @@ # register layout rules [SliceRegisterRule, SliceAssignRule, AllocBarrierRule], [MmaDotRule], + [WgmmaMmaSSRule], [Tcgen05LoadRule, Tcgen05StoreRule], [Tcgen05CopyRule], [BinaryRule, UnaryRule], diff --git a/python/tilus/ir/layout/inference/validation_rules/always_ok.py b/python/tilus/ir/layout/inference/validation_rules/always_ok.py index aebe5069..6d2c1248 100644 --- a/python/tilus/ir/layout/inference/validation_rules/always_ok.py +++ b/python/tilus/ir/layout/inference/validation_rules/always_ok.py @@ -48,9 +48,12 @@ ) from tilus.ir.instructions.cuda.ldmatrix import LoadMatrixInst from tilus.ir.instructions.cuda.tcgen05 import Tcgen05CopyInst, Tcgen05LoadInst, Tcgen05MmaSSInst, Tcgen05StoreInst +from tilus.ir.instructions.cuda.wgmma import WgmmaMmaRSInst, WgmmaMmaSSInst from tilus.ir.layout.inference.rule import LayoutValidationRule, register_rule +@register_rule(WgmmaMmaSSInst) # todo: should have its own rule +@register_rule(WgmmaMmaRSInst) # todo: should have its own rule @register_rule(Tcgen05LoadInst) # todo: should have its own rule @register_rule(Tcgen05StoreInst) # todo: should have its own rule @register_rule(Tcgen05CopyInst) # todo: should have its own rule diff --git a/python/tilus/lang/instructions/__init__.py b/python/tilus/lang/instructions/__init__.py index 0d7dafe0..58c7e6eb 100644 --- a/python/tilus/lang/instructions/__init__.py +++ b/python/tilus/lang/instructions/__init__.py @@ -22,6 +22,7 @@ from .root import RootInstructionGroup from .tcgen05 import Tcgen05InstructionGroup from .tma import TmaInstructionGroup +from .wgmma import WgmmaInstructionGroup class InstructionInterface(RootInstructionGroup): @@ -30,3 +31,4 @@ class 
InstructionInterface(RootInstructionGroup): mbarrier = BarrierInstructionGroup() clc = ClusterLaunchControlInstructionGroup() cluster = BlockClusterInstructionGroup() + wgmma = WgmmaInstructionGroup() diff --git a/python/tilus/lang/instructions/wgmma.py b/python/tilus/lang/instructions/wgmma.py new file mode 100644 index 00000000..bd414db7 --- /dev/null +++ b/python/tilus/lang/instructions/wgmma.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +from hidet.ir.expr import Expr + +from tilus.ir.inst import InstructionError +from tilus.ir.tensor import RegisterTensor, SharedTensor + +from .root import InstructionGroup + + +class WgmmaInstructionGroup(InstructionGroup): + def fence(self) -> None: + self._builder.wgmma_fence() + + def commit_group(self) -> None: + self._builder.wgmma_commit_group() + + def wait_group(self, n: Union[Expr, int]) -> None: + self._builder.wgmma_wait_group(n) + + def mma(self, a: SharedTensor | RegisterTensor, b: SharedTensor, d: RegisterTensor) -> None: + if any(len(tensor.shape) != 2 for tensor in (a, b, d)): + raise InstructionError( + "mma requires 2D tensors, got shapes {}".format([tensor.shape for tensor in (a, b, d)]) + ) + if isinstance(a, SharedTensor): + self._builder.wgmma_mma_ss(a, b, d) + elif isinstance(a, RegisterTensor): + self._builder.wgmma_mma_rs(a, b, d) + else: + raise InstructionError("Invalid type of a: {}, expected SharedTensor or RegisterTensor".format(type(a))) diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 84bdfd5b..79eb38c3 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -23,7 +23,7 @@ from typing import Optional import pytest -from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm100a +from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90, nvgpu_sm100a # Get the project root directory PROJECT_ROOT = Path(__file__).parent.parent.parent @@ -54,6 +54,9 @@ ("blackwell_matmul", "matmul_v3.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v4.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v5.py", nvgpu_sm100a), + # hopper matmul example (SM 9.0) + ("hopper_matmul", "matmul_v0.py", nvgpu_sm90), + ("hopper_matmul", "matmul_v1.py", nvgpu_sm90), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), # flash attention decode examples (SM 8.0+)
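
For reference, here is a minimal, standalone host-side sketch of the 64-bit shared-memory matrix descriptor layout that SharedMatrixDescriptor.encoded() and wgmma_encode_smem_descriptor assemble above (address in bits 0-13, LBO in bits 16-29, SBO in bits 32-45, base offset in bits 49-51, swizzle mode in bits 62-63, with byte quantities stored divided by 16). The helper names below are illustrative only and are not part of the diff:

def encode_smem_descriptor(addr: int, lbo: int, sbo: int, base_offset: int, swizzle_mode: int) -> int:
    # Pack the descriptor fields; byte quantities go into 14-bit fields as value // 16.
    desc = 0
    desc |= (addr >> 4) & 0x3FFF             # bits  0..13: shared memory byte address / 16
    desc |= ((lbo >> 4) & 0x3FFF) << 16      # bits 16..29: leading-dimension byte offset / 16
    desc |= ((sbo >> 4) & 0x3FFF) << 32      # bits 32..45: stride byte offset / 16
    desc |= (base_offset & 0x7) << 49        # bits 49..51: matrix base offset
    desc |= (swizzle_mode & 0x3) << 62       # bits 62..63: swizzle mode
    return desc


def decode_smem_descriptor(encoded: int) -> dict:
    # Mirror of SharedMatrixDescriptor.decode, for round-trip checking.
    return {
        "addr": (encoded & 0x3FFF) << 4,
        "lbo": ((encoded >> 16) & 0x3FFF) << 4,
        "sbo": ((encoded >> 32) & 0x3FFF) << 4,
        "base_offset": (encoded >> 49) & 0x7,
        "swizzle_mode": (encoded >> 62) & 0x3,
    }


if __name__ == "__main__":
    desc = encode_smem_descriptor(addr=0x400, lbo=128, sbo=256, base_offset=0, swizzle_mode=1)
    assert decode_smem_descriptor(desc) == {
        "addr": 0x400,
        "lbo": 128,
        "sbo": 256,
        "base_offset": 0,
        "swizzle_mode": 1,
    }
    print(hex(desc))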