diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index abe35007..cb99c2e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,6 +17,14 @@ repos: language: system types: [python] pass_filenames: false + - id: mypy + name: MyPy type checking + entry: bash -c 'mypy --version >&2; mypy "$@"' -- + language: system + types_or: [python, pyi] + args: [--show-error-codes, --show-error-context] + files: ^(python|examples|tests)/ + require_serial: true - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version rev: v0.12.3 @@ -30,17 +38,3 @@ repos: - id: ruff-format name: Ruff formatter types_or: [python, pyi] -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 - hooks: - - id: mypy - name: MyPy type checking - # Uses [tool.mypy] configuration from pyproject.toml - additional_dependencies: [ - types-tabulate, - types-tqdm, - ] - # Check files in the python package, examples, and tests folders - files: ^(python|examples|tests)/ - # Exclude hidet extensions (they have special mypy overrides in pyproject.toml) - exclude: ^python/tilus/extensions/hidet/ diff --git a/.vscode/settings.json b/.vscode/settings.json index c41362fd..870ed23e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,6 @@ "tests" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "python.analysis.supportDocstringTemplate": true } diff --git a/examples/attention/flash_attention_v3.py b/examples/attention/flash_attention_v3.py index 9a44e758..d79ddf87 100644 --- a/examples/attention/flash_attention_v3.py +++ b/examples/attention/flash_attention_v3.py @@ -293,7 +293,7 @@ def store_back( shape=[num_q_blocks, batch_size, self.num_heads, self.block_q], requires_clean=False, ) - semaphore = ~semaphores[self.blockIdx.x, bs, head] + semaphore = semaphores[self.blockIdx.x, bs, head].item_ptr() sm = self.shared_tensor(dtype=f32, shape=[self.block_q]) sl = self.shared_tensor(dtype=f32, shape=[self.block_q]) diff --git a/examples/flash_attention_decode/tilus_kernel.py b/examples/flash_attention_decode/tilus_kernel.py index c0885326..bf670c35 100644 --- a/examples/flash_attention_decode/tilus_kernel.py +++ b/examples/flash_attention_decode/tilus_kernel.py @@ -152,7 +152,7 @@ def __call__( dims=[2, 3], ) else: - state_idx = -1 + state_idx = -1 # type: ignore r_h = self.register_tensor(dtype=float32, shape=[K, self.BV], init=0.0) # H' = alpha * H : [K, BV] = [] * [K, BV] @@ -388,7 +388,7 @@ def __call__( dims=[2, 3], ) else: - state_idx = -1 + state_idx = -1 # type: ignore r_h = self.register_tensor(dtype=float32, shape=[K, self.BV], init=0.0) # Apply gating to hidden state: H' = alpha * H diff --git a/examples/quantization/matmul_a16wx.py b/examples/quantization/matmul_a16wx.py index cb1deb77..aeb54684 100644 --- a/examples/quantization/matmul_a16wx.py +++ b/examples/quantization/matmul_a16wx.py @@ -22,7 +22,7 @@ int32, uint8, ) -from tilus.ir.layout.ops import concat, local, reduce, spatial +from tilus.ir.layout.ops import concat, local, reduce, shared_row_major_swizzle, spatial from tilus.utils import benchmark_func, cdiv, dtype_to_torch, gcd from torch import nn @@ -197,8 +197,8 @@ def __init__( self.block_k = self.atomic_mma.k * wrk self.num_warps = wsm * wsn - k_tiles = wrk // tk - n_tiles = wsn * wrn // tn + self.k_tiles = wrk // tk + self.n_tiles = wsn * wrn // tn # we make sure that each weight_tile will be loaded by one warp assert wrk * self.atomic_mma.k % weight_tile[0] == 0 @@ -217,14 
+217,15 @@ def __init__( ) self.layout_rs = reduce(self.mma.lb, dims=[0], keepdims=True) - self.layout_sa = self.cuda.swizzled_shared_layout( - self.a_dtype, shape=[num_stages, self.block_m, self.block_k] + self.layout_sa = shared_row_major_swizzle( + dtype_nbytes=self.a_dtype.nbytes, + shape=[num_stages, self.block_m, self.block_k], ) self.layout_sb = self.cuda.shared_layout( - shape=[self.num_stages, k_tiles, n_tiles, self.tile_bytes] + shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes] ) - self.layout_sc = self.cuda.swizzled_shared_layout( - self.a_dtype, shape=[self.block_m, self.block_n] + self.layout_sc = shared_row_major_swizzle( + dtype_nbytes=self.a_dtype.nbytes, shape=[self.block_m, self.block_n] ) self.layout_ss = self.cuda.shared_layout(shape=[self.num_stages, 1, self.block_n]) @@ -275,7 +276,10 @@ def __call__( sa = self.shared_tensor( dtype=self.a_dtype, shape=[self.num_stages, block_m, block_k] ) - sb = self.shared_tensor(dtype=uint8, shape=self.layout_sb.shape) + sb = self.shared_tensor( + dtype=uint8, + shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes], + ) ss = self.shared_tensor(dtype=self.a_dtype, shape=[self.num_stages, 1, block_n]) acc = self.register_tensor( dtype=float32, diff --git a/pyproject.toml b/pyproject.toml index 36376f74..ce8125d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ Documentation = "https://nvidia.github.io/tilus" "" = "python" [tool.setuptools.package-data] +"tilus" = ["py.typed"] "tilus.extensions.hidet" = ["include/**/*.h"] [tool.ruff] @@ -106,6 +107,11 @@ disallow_incomplete_defs = true allow_redefinition = true strict_optional = false explicit_package_bases = true +mypy_path = "python" +namespace_packages = true +disable_error_code = [ + "import-untyped", # we used some untyped third-party libraries +] [[tool.mypy.overrides]] module = ["tilus.*"] @@ -120,9 +126,13 @@ disable_error_code = [ "override", ] -[[tool.mypy.overrides]] -module = ["examples.*", "tests.*"] -disable_error_code = ["override", "valid-type", "call-arg", "no-untyped-def", "has-type", "no-redef"] +[[tool.mypy.overrides]] # disable type checking in these modules that might define tilus scripts +module = [ + "examples.*", + "tests.*", + "tilus.lang.classes.pipeline" +] +disable_error_code = ["override", "valid-type", "call-arg", "no-untyped-def", "has-type", "no-redef", "assignment", "import-not-found"] [[tool.mypy.overrides]] module = ["hidet.*"] diff --git a/python/tilus/backends/codegen.py b/python/tilus/backends/codegen.py index aa66394c..129dc6a9 100644 --- a/python/tilus/backends/codegen.py +++ b/python/tilus/backends/codegen.py @@ -143,15 +143,15 @@ def check_emitter_existence(self) -> None: if failed_instructions: rows = [f"Failed to find emitter for the following instructions (target: {get_current_target()}):"] required_targets: list[str] = [] - for inst in failed_instructions: + for inst_cls in failed_instructions: for registry_inst_cls, emitter_classes in BaseInstEmitter.REGISTRY.items(): - if issubclass(inst, registry_inst_cls): + if issubclass(inst_cls, registry_inst_cls): required_targets.extend([str(target) for target in emitter_classes.keys()]) break if not required_targets: - rows.append(f" - {inst.__name__} (no registered emitters)") + rows.append(f" - {inst_cls.__name__} (no registered emitters)") else: - rows.append(f" - {inst.__name__} (registered targets: {', '.join(required_targets)})") + rows.append(f" - {inst_cls.__name__} (registered targets: {', '.join(required_targets)})") raise 
CodeGenerationFailed("\n".join(rows)) def launch_kernel(self, kernel_func: HidetFunction) -> None: diff --git a/python/tilus/backends/emitters/cuda/cp_async.py b/python/tilus/backends/emitters/cuda/cp_async.py index 3d1976af..2ae2955e 100644 --- a/python/tilus/backends/emitters/cuda/cp_async.py +++ b/python/tilus/backends/emitters/cuda/cp_async.py @@ -47,7 +47,8 @@ def emit(self, inst: CopyAsyncGenericInst) -> None: # get shared, global, and mask info inst_mask = inst.mask if inst.mask is not None else boolean.true - shared_info: TensorInfo = analyze_grid(shape=shape, axes=layout.axes, analysis=analysis, expr=layout.offset) + axes, offset = layout.as_axes_mapping() + shared_info: TensorInfo = analyze_grid(shape=shape, axes=axes, analysis=analysis, expr=offset) mask_info: TensorInfo = analyze_grid(shape=shape, axes=inst.axes, analysis=analysis, expr=inst_mask) global_info: TensorInfo = analyze_grid(shape=shape, axes=inst.axes, analysis=analysis, expr=inst.offset) diff --git a/python/tilus/backends/emitters/cuda/cp_async_tensor.py b/python/tilus/backends/emitters/cuda/cp_async_tensor.py index 98629b66..2f7e4006 100644 --- a/python/tilus/backends/emitters/cuda/cp_async_tensor.py +++ b/python/tilus/backends/emitters/cuda/cp_async_tensor.py @@ -202,12 +202,9 @@ def resolve_shared_tensor_info(self, shared_tensor: SharedTensor) -> SharedTenso range_indices: list[np.ndarray] = [] for dim, extent in enumerate(shared_tensor.shape): range_indices.append(np.arange(extent, dtype=np.int32)) - grid = np.meshgrid(*range_indices, indexing="ij") layout: SharedLayout = shared_tensor.layout - offset_grid: np.ndarray = vectorized_evaluate( - expr=layout.offset, var2value={axis: grid[i] for i, axis in enumerate(layout.axes)} - ) + offset_grid: np.ndarray = layout.as_numpy_grid() for swizzle in [ TensorMapSwizzle.NONE, TensorMapSwizzle.B32, diff --git a/python/tilus/backends/emitters/cuda/tcgen05/copy.py b/python/tilus/backends/emitters/cuda/tcgen05/copy.py index ab0d4ae8..0d64253c 100644 --- a/python/tilus/backends/emitters/cuda/tcgen05/copy.py +++ b/python/tilus/backends/emitters/cuda/tcgen05/copy.py @@ -188,17 +188,12 @@ def generate_instructions( raise ValueError("No valid instructions generated") - def check_warp_group(self) -> None: - begin = self.current_thread_group_begin - end = self.current_thread_group_end - if begin % 128 != 0 or end - begin != 128: - raise ValueError("The number of threads in the current thread group must be 128") - def emit(self, inst: Tcgen05CopyInst) -> None: shared_tensor = inst.inputs[1].as_shared_tensor() tmem_tensor = inst.inputs[0].as_tmemory_tensor() - self.check_warp_group() + if self.current_num_threads != 1: + raise ValueError("Tcgen05CopyInst can only be emitted in thread group with a single thread") if len(shared_tensor.shape) != 2: raise ValueError("The shared tensor must be a 2D tensor, got shape {}".format(shared_tensor.shape)) diff --git a/python/tilus/backends/emitters/cuda/wgmma.py b/python/tilus/backends/emitters/cuda/wgmma.py index 4435543f..d2e891a3 100644 --- a/python/tilus/backends/emitters/cuda/wgmma.py +++ b/python/tilus/backends/emitters/cuda/wgmma.py @@ -180,7 +180,7 @@ def emit_wgmma(self, inst: WgmmaMmaSSInst) -> None: a_desc.encoded(), d_register_addr + d_offset, b_desc.encoded(), - trans_a=0, - trans_b=0, + trans_a=0, # type: ignore + trans_b=0, # type: ignore ) ) diff --git a/python/tilus/backends/emitters/reduce.py b/python/tilus/backends/emitters/reduce.py index 32ac8dfc..a2aba1f3 100644 --- a/python/tilus/backends/emitters/reduce.py +++ 
b/python/tilus/backends/emitters/reduce.py @@ -284,8 +284,8 @@ def inter_warp_reduce(self, inst: ReduceInst) -> None: smem_ctx = self.contexts.smem_alloc_ctx smem_buf = self.declare_var( "smem_buf", - tensor_pointer_type(dtype=dst.dtype, shape=[shared_layout.size]), - init=cast(smem_ctx.request_shared_workspace(dst.dtype.nbytes * shared_layout.size), ~dst.dtype), + tensor_pointer_type(dtype=dst.dtype, shape=[shared_layout.count_size()]), + init=cast(smem_ctx.request_shared_workspace(dst.dtype.nbytes * shared_layout.count_size()), ~dst.dtype), ) reduced_mode_shape = [ diff --git a/python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py b/python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py index 81a988a1..180a08e0 100644 --- a/python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py +++ b/python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py @@ -207,9 +207,8 @@ def mbarrier_arrive_shared(mbarrier_addr: Expr, count: Expr | int) -> Expr: -------- mbarrier.arrive : PTX ISA documentation section 9.7.13.15.13 """ - if isinstance(count, int): - count = u32(count) - return call_primitive_func("cuda_mbarrier_arrive_shared", args=[mbarrier_addr, count]) + count_expr = count if isinstance(count, Expr) else u32(count) + return call_primitive_func("cuda_mbarrier_arrive_shared", args=[mbarrier_addr, count_expr]) def mbarrier_arrive_remote_shared( diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index 6b362f17..d50f37c1 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -165,21 +165,15 @@ def __init__( iter_vars: List[Var], extents: List[Expr], unrolls: List[Optional[int]], - unwrap: bool = False, ): super().__init__(vb) self.iter_vars: List[Var] = iter_vars self.extents: List[Expr] = extents self.unrolls: List[Optional[int]] = unrolls - self.unwrap: bool = unwrap - def __enter__(self): + def __enter__(self) -> List[Var]: self.enter() - if self.unwrap: - assert len(self.iter_vars) == 1 - return self.iter_vars[0] - else: - return self.iter_vars + return self.iter_vars def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None: @@ -193,6 +187,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.append(body) +class ForRangeContext(ForContext): + def __enter__(self): + iter_vars = super().__enter__() + assert len(iter_vars) == 1 + return iter_vars[0] + + class IfContext(StmtContext): def __init__(self, vb: StmtBuilderCore, cond: Expr): super().__init__(vb) @@ -315,11 +316,13 @@ def is_empty(self): def for_range( self, extent: Union[Expr, int], iter_name_hint: str = "i", unroll_factor: Optional[int] = None - ) -> ForContext: + ) -> ForRangeContext: iter_var = Var(iter_name_hint, type=int32) - return ForContext(self, [iter_var], [as_expr(extent)], [unroll_factor], unwrap=True) + return ForRangeContext(self, [iter_var], [as_expr(extent)], [unroll_factor]) - def for_grid(self, extents: List[Union[Expr, int]], iter_name_hints: Optional[List[str]] = None) -> ForContext: + def for_grid( + self, extents: Sequence[Union[Expr, int]], iter_name_hints: Optional[Sequence[str]] = None + ) -> ForContext: expr_extents = [as_expr(extent) for extent in extents] if iter_name_hints is None: names = "ijkpqrstuvw" @@ -1231,6 +1234,7 @@ def wait_barrier(self, barrier: Expr | RegisterTensor, phase: Expr | int | Regis phase = self.tensor_item_value(phase) elif isinstance(phase, int): phase = uint32(phase) + assert isinstance(phase, Expr) inst = WaitBarrierInst.create(barrier=barrier, 
phase=phase) self.append(inst) @@ -1245,6 +1249,7 @@ def cluster_launch_control_try_cancel( mbarrier = self.tensor_item_value(mbarrier) if isinstance(multicast, bool): multicast = boolean(multicast) + assert isinstance(multicast, Expr) inst = ClusterLaunchControlTryCancelInst.create(response=response, mbarrier=mbarrier, multicast=multicast) self.append(inst) diff --git a/python/tilus/ir/functors/functor.py b/python/tilus/ir/functors/functor.py index dcc3ec16..aa77f669 100644 --- a/python/tilus/ir/functors/functor.py +++ b/python/tilus/ir/functors/functor.py @@ -435,11 +435,7 @@ def visit_RegisterLayout(self, layout: RegisterLayout) -> RegisterLayout: return layout def visit_SharedLayout(self, layout: SharedLayout) -> SharedLayout: - offset = self.visit(layout.offset) - if offset is layout.offset: - return layout - else: - return SharedLayout(shape=layout.shape, size=layout.size, axes=layout.axes, offset=offset) + return layout def visit_GlobalLayout(self, layout: GlobalLayout) -> GlobalLayout: shape = self.visit(layout.shape) @@ -573,7 +569,7 @@ def visit_RegisterLayout(self, layout: RegisterLayout) -> None: pass def visit_SharedLayout(self, layout: SharedLayout) -> None: - self.visit(layout.offset) + pass def visit_GlobalLayout(self, layout: GlobalLayout) -> None: self.visit(layout.shape) diff --git a/python/tilus/ir/layout/cuda/tcgen05/smem.py b/python/tilus/ir/layout/cuda/tcgen05/smem.py index c4f788c9..c58daf1a 100644 --- a/python/tilus/ir/layout/cuda/tcgen05/smem.py +++ b/python/tilus/ir/layout/cuda/tcgen05/smem.py @@ -18,13 +18,10 @@ from typing import Literal, Optional, Sequence, cast import numpy as np -from hidet.ir.expr import Expr, Var from hidet.ir.type import DataType -from hidet.utils.py import prod from tilus.extensions.hidet.ir.primitives.cuda.tcgen05 import Tcgen05SwizzleMode from tilus.ir.layout.shared_layout import SharedLayout from tilus.ir.layout.utils.cute import CuteLayout, CuteSwizzle, IntTuple, SwizzledCuteLayout, cute_layout, tuple_product -from tilus.ir.utils.veceval import meshgrid, vectorized_evaluate from tilus.utils import floor_log2 # class Tcgen05SwizzleMode(Enum): @@ -163,11 +160,7 @@ def _generate_atom_grid(major_kind: Literal["MN", "K"], swizzle_mode: Tcgen05Swi major_kind=major_kind, swizzle_mode=swizzle_mode, SBO=0, LBO=0, m=1, k=1, T=t ) atom_layout = get_shared_layout_from_canonical(canonical_layout) - grid_axes = meshgrid(atom_layout.shape) - atom_grid = vectorized_evaluate( - expr=atom_layout.offset, var2value={axis: grid_axes[i] for i, axis in enumerate(atom_layout.axes)} - ) - return atom_grid + return atom_layout.as_numpy_grid() def canonicalize_shared_layout(shared_layout: SharedLayout, dtype: DataType) -> Optional[CanonicalSharedLayout]: @@ -188,6 +181,7 @@ def canonicalize_shared_layout(shared_layout: SharedLayout, dtype: DataType) -> ret: Optional[CanonicalSharedLayout] The canonical form if found, None otherwise """ + # todo: simplify the implementation of this function since we used a similar layout system as cute now if len(shared_layout.shape) != 2: return None @@ -197,10 +191,7 @@ def canonicalize_shared_layout(shared_layout: SharedLayout, dtype: DataType) -> T = 128 // dtype.nbits # Create meshgrid for the entire layout - grid_axes = meshgrid(shared_layout.shape) - entire_grid = vectorized_evaluate( - expr=shared_layout.offset, var2value={axis: grid_axes[i] for i, axis in enumerate(shared_layout.axes)} - ) + entire_grid = shared_layout.as_numpy_grid() entire_shape = shared_layout.shape # Try each swizzle mode and majorness using direct 
pattern analysis @@ -303,17 +294,14 @@ def get_shared_layout_from_canonical(canonical_layout: CanonicalSharedLayout) -> else: raise ValueError(f"Unsupported major_kind: {canonical_layout.major_kind}") - def f_offset(axes: Sequence[Var]) -> Expr | int: - nbytes = 16 // canonical_layout.T - swizzle = CuteSwizzle(bbits=bbits, mbase=mbase - floor_log2(nbytes), sshift=sshift) - return swizzle(layout(*axes)) + nbytes = 16 // canonical_layout.T + swizzle = CuteSwizzle(bbits=bbits, mbase=mbase - floor_log2(nbytes), sshift=sshift) + swizzled_cute_layout = SwizzledCuteLayout(layout, swizzle) - if not isinstance(layout.shape, Sequence): - smem_shape = [int(layout.shape)] - else: - smem_shape = [int(tuple_product(item)) for item in layout.shape] + assert isinstance(layout.shape, Sequence) + shape = [int(tuple_product(item)) for item in layout.shape] - return SharedLayout.create(shape=smem_shape, size=prod(smem_shape), f_offset=f_offset) + return swizzled_cute_layout.as_shared_layout(shape) def generate_canonical_layout( diff --git a/python/tilus/ir/layout/inference/inference.py b/python/tilus/ir/layout/inference/inference.py index fc9bb454..6f63a355 100644 --- a/python/tilus/ir/layout/inference/inference.py +++ b/python/tilus/ir/layout/inference/inference.py @@ -278,6 +278,12 @@ def pair_sort_key(pair: tuple[Instruction, Type[LayoutInferenceRule]]) -> tuple[ SharedTensor | RegisterTensor | TMemoryTensor, SharedTensor | RegisterTensor | TMemoryTensor ] = {} for tensor, layout in mapping.items(): + assert isinstance(tensor, (RegisterTensor, SharedTensor, TMemoryTensor)), ( + f"Invalid tensor type {type(tensor)} for rule {rule.__name__} " + ) + assert isinstance(layout, (RegisterLayout, SharedLayout, TMemoryLayout)), ( + f"Invalid layout type {type(layout)} for rule {rule.__name__} " + ) assert same_list(tensor.shape, layout.shape), ( f"Layout shape does not match tensor shape: {tensor.shape} vs {layout.shape} for rule {rule.__name__} " ) diff --git a/python/tilus/ir/layout/inference/inference_rules/load_shared.py b/python/tilus/ir/layout/inference/inference_rules/load_shared.py index e9c36fed..ce09e96d 100644 --- a/python/tilus/ir/layout/inference/inference_rules/load_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/load_shared.py @@ -19,6 +19,7 @@ from tilus.ir.instructions.cuda.ldmatrix import LoadMatrixConfig from tilus.ir.layout import LayoutOperationError, ops from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule +from tilus.ir.layout.ops import shared_row_major_swizzle from tilus.utils import gcd @@ -45,9 +46,8 @@ def inference( continue # use swizzle layout since we are using ldmatrix instruction - from tilus.lang.modules.cuda import cuda - return {a: cuda.swizzled_shared_layout(dtype=a.dtype, shape=a.shape)} + return {a: shared_row_major_swizzle(dtype_nbytes=a.dtype.nbytes, shape=a.shape)} return {} @@ -85,8 +85,7 @@ def inference( if not (shared.has_layout() and not register.has_layout()): return {} - axes = shared.layout.axes - offset = shared.layout.offset + axes, offset = shared.layout.as_axes_mapping() info = analyze_grid( shape=shared.shape, diff --git a/python/tilus/ir/layout/inference/inference_rules/store_shared.py b/python/tilus/ir/layout/inference/inference_rules/store_shared.py index be450b65..5d6d61b7 100644 --- a/python/tilus/ir/layout/inference/inference_rules/store_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/store_shared.py @@ -18,6 +18,7 @@ from tilus.ir.instructions.cuda.ldmatrix import 
LoadMatrixConfig from tilus.ir.layout import LayoutOperationError, ops from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule +from tilus.ir.layout.ops import shared_row_major_swizzle @register_rule(StoreSharedGenericInst) @@ -42,8 +43,7 @@ def inference( continue # use swizzle layout since we are using ldmatrix instruction - from tilus.lang.modules.cuda import cuda - return {a: cuda.swizzled_shared_layout(dtype=a.dtype, shape=a.shape)} + return {a: shared_row_major_swizzle(dtype_nbytes=a.dtype.nbytes, shape=a.shape)} return {} diff --git a/python/tilus/ir/layout/inference/inference_rules/tcgen05/mma.py b/python/tilus/ir/layout/inference/inference_rules/tcgen05/mma.py index 60243e1f..e35e6142 100644 --- a/python/tilus/ir/layout/inference/inference_rules/tcgen05/mma.py +++ b/python/tilus/ir/layout/inference/inference_rules/tcgen05/mma.py @@ -61,7 +61,7 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05MmaSSInst) -> dict[Share a_layout_canonical = generate_canonical_layout( shape=(m, k), dtype=a_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode ) - ret[a_tensor] = a_layout_canonical.as_shared_layout().simplify() + ret[a_tensor] = a_layout_canonical.as_shared_layout() except ValueError: continue else: @@ -77,7 +77,7 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05MmaSSInst) -> dict[Share b_layout_canonical = generate_canonical_layout( shape=(n, k), dtype=b_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode ) - ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]).simplify() + ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]) except ValueError: continue else: @@ -123,7 +123,6 @@ def inference(ctx: LayoutInferenceContext, inst: Tcgen05MmaTSInst) -> dict[Share ) .as_shared_layout() .permute(dims=[1, 0]) - .simplify() ) ret[b_tensor] = b_layout except ValueError: diff --git a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py index 172d8627..98244436 100644 --- a/python/tilus/ir/layout/inference/inference_rules/transform_shared.py +++ b/python/tilus/ir/layout/inference/inference_rules/transform_shared.py @@ -28,13 +28,20 @@ def inference(ctx: LayoutInferenceContext, inst: SliceSharedInst) -> dict[Shared if a.optional_layout is not None and b.optional_layout is not None: return {} elif a.optional_layout is not None: - return {b: a.layout.slice(offsets=inst.offsets, slice_dims=inst.dims, slice_shape=b.shape).simplify()} + if inst.dims is None: + dims = list(range(len(a.shape))) + else: + dims = list(inst.dims) + return {b: a.layout.slice(retain_dims=dims)} elif b.optional_layout is not None: b_layout = b.layout.unsqueeze(dims=range(len(a.shape) - len(b.shape))) outer_shape = [] for i in range(len(a.shape)): outer_shape.append(a.shape[i] // b_layout.shape[i]) - return {a: shared_compose(shared_row_major(*outer_shape), b_layout).simplify()} + layout = shared_compose(shared_row_major(*outer_shape), b_layout) + if b.layout.optional_swizzle is not None: + layout = layout.apply_swizzle(b.layout.swizzle) + return {a: layout} else: return {} diff --git a/python/tilus/ir/layout/inference/inference_rules/wgmma.py b/python/tilus/ir/layout/inference/inference_rules/wgmma.py index 630c99f6..f887fc34 100644 --- a/python/tilus/ir/layout/inference/inference_rules/wgmma.py +++ b/python/tilus/ir/layout/inference/inference_rules/wgmma.py @@ -45,7 +45,9 @@ def generate_wgmma_register_layout(m: int, n: int, 
inst_m: int, inst_n: int, ins @register_rule(WgmmaMmaSSInst) class WgmmaMmaSSRule(LayoutInferenceRule): @staticmethod - def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedTensor, SharedLayout]: + def inference( + ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst + ) -> dict[SharedTensor | RegisterTensor, SharedLayout | RegisterLayout]: a_tensor: SharedTensor = inst.inputs[0].as_shared_tensor() b_tensor: SharedTensor = inst.inputs[1].as_shared_tensor() d_tensor: RegisterTensor = inst.inputs[2].as_register_tensor() @@ -64,7 +66,7 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT ) m, n, k = d_shape[0], d_shape[1], a_shape[1] - ret = {} + ret: dict[SharedTensor | RegisterTensor, SharedLayout | RegisterLayout] = {} if not a_tensor.has_layout(): for swizzle_mode in [ Tcgen05SwizzleMode.B128_SWIZZLE, @@ -76,7 +78,7 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT a_layout_canonical = generate_canonical_layout( shape=(m, k), dtype=a_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode ) - ret[a_tensor] = a_layout_canonical.as_shared_layout().simplify() + ret[a_tensor] = a_layout_canonical.as_shared_layout() except ValueError: continue else: @@ -92,7 +94,7 @@ def inference(ctx: LayoutInferenceContext, inst: WgmmaMmaSSInst) -> dict[SharedT b_layout_canonical = generate_canonical_layout( shape=(n, k), dtype=b_tensor.dtype, major_kind="K", swizzle_mode=swizzle_mode ) - ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]).simplify() + ret[b_tensor] = b_layout_canonical.as_shared_layout().permute(dims=[1, 0]) except ValueError: continue else: diff --git a/python/tilus/ir/layout/ops/__init__.py b/python/tilus/ir/layout/ops/__init__.py index 3bb326ba..db421638 100644 --- a/python/tilus/ir/layout/ops/__init__.py +++ b/python/tilus/ir/layout/ops/__init__.py @@ -38,12 +38,7 @@ squeeze, unsqueeze, ) -from .shared_ops import ( - shared_column_major, - shared_compose, - shared_permute, - shared_row_major, -) +from .shared_ops import shared_column_major, shared_compose, shared_permute, shared_row_major, shared_row_major_swizzle from .tmemory_ops import ( tmemory_row_major, tmemory_slice, diff --git a/python/tilus/ir/layout/ops/shared_ops.py b/python/tilus/ir/layout/ops/shared_ops.py index 31bb7bc8..bf3cba51 100644 --- a/python/tilus/ir/layout/ops/shared_ops.py +++ b/python/tilus/ir/layout/ops/shared_ops.py @@ -17,42 +17,33 @@ from typing import List, Sequence import tabulate -from hidet.ir.dtypes import int32 -from hidet.ir.expr import Expr, Var -from hidet.utils import prod +from hidet.utils import gcd, prod -from tilus.extensions.hidet.ir.utils.index_transform import vector_mul -from tilus.ir.layout.ops.utils import LayoutOperationError -from tilus.ir.layout.shared_layout import SharedLayout -from tilus.ir.utils.veceval import meshgrid, vectorized_evaluate +from tilus.ir.layout.ops.utils import LayoutOperationError, get_mode_groups +from tilus.ir.layout.shared_layout import SharedLayout, Swizzle, shared_layout -def _generic_repeat(shape: List[int], ranks: List[int]) -> SharedLayout: - assert len(shape) == len(ranks) - assert len(ranks) == len(set(ranks)) and all(0 <= d < len(shape) for d in ranks) - strides: List[int] = [prod([s for j, s in enumerate(shape) if ranks[j] > ranks[i]]) for i in range(len(shape))] - - def f_offset(axes: Sequence[Var]) -> Expr: - return sum([axes[i] * strides[i] for i in range(len(shape))], start=int32.zero) - - return SharedLayout.create(shape=shape, 
size=prod(shape), f_offset=f_offset) - - -def _shared_compose(lhs: SharedLayout, rhs: SharedLayout) -> SharedLayout: - assert len(lhs.shape) == len(rhs.shape) - ndims = len(lhs.shape) - - def f_offset(axes: Sequence[Var]) -> Expr: - lhs_axes = [axes[i] // rhs.shape[i] for i in range(ndims)] - rhs_axes = [axes[i] % rhs.shape[i] for i in range(ndims)] - lhs_offset = lhs(*lhs_axes) - rhs_offset = rhs(*rhs_axes) - return lhs_offset * rhs.size + rhs_offset +def strides_from_ranks(shape: Sequence[int], ranks: Sequence[int]) -> list[int]: + """ + Compute the strides from the ranks of each dimension. - shape = vector_mul(lhs.shape, rhs.shape) - size = lhs.size * rhs.size + Parameters + ---------- + shape: Sequence[int] + The shape of the tensor. + ranks: Sequence[int] + The ranks of each dimension. The length of ranks must be equal to the length of shape + and all elements in ranks must be unique and in the range [0, len(shape)). - return SharedLayout.create(shape=shape, size=size, f_offset=f_offset) + Returns + ------- + ret: list[int] + The strides of each dimension. + """ + assert len(shape) == len(ranks) + assert len(ranks) == len(set(ranks)) and all(0 <= d < len(shape) for d in ranks) + strides: list[int] = [prod([s for j, s in enumerate(shape) if ranks[j] > ranks[i]]) for i in range(len(shape))] + return strides def shared_row_major(*shape: int) -> SharedLayout: @@ -68,7 +59,9 @@ def shared_row_major(*shape: int) -> SharedLayout: ret: SharedLayout A shared layout with the specified shape in row-major order. """ - return _generic_repeat(shape=list(shape), ranks=list(range(len(shape)))) + mode_shape = shape + mode_strides = strides_from_ranks(shape=mode_shape, ranks=list(range(len(mode_shape)))) + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None) def shared_column_major(*shape: int) -> SharedLayout: @@ -84,10 +77,12 @@ def shared_column_major(*shape: int) -> SharedLayout: ret: SharedLayout A shared layout with the specified shape in column-major order. """ - return _generic_repeat(shape=list(shape), ranks=list(reversed(range(len(shape))))) + mode_shape = shape + mode_strides = strides_from_ranks(shape=mode_shape, ranks=list(reversed(range(len(mode_shape))))) + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None) -def shared_compose(lhs: SharedLayout, rhs: SharedLayout, *others: SharedLayout) -> SharedLayout: +def shared_compose(lhs: SharedLayout, rhs: SharedLayout) -> SharedLayout: """Compose multiple shared layouts together. Parameters @@ -96,18 +91,34 @@ def shared_compose(lhs: SharedLayout, rhs: SharedLayout, *others: SharedLayout) The first shared layout to compose. rhs: SharedLayout The second shared layout to compose. - others: Sequence[SharedLayout] - The additional shared layouts to compose with the first two. It can be empty. Returns ------- ret: SharedLayout The composed shared layout. 
""" - if len(others) == 0: - return _shared_compose(lhs, rhs) - else: - return shared_compose(_shared_compose(lhs, rhs), *others) + assert len(lhs.shape) == len(rhs.shape) + ndims = len(lhs.shape) + + # shape + shape = tuple(lhs.shape[i] * rhs.shape[i] for i in range(ndims)) + + # mode shape + lhs_mode_groups = get_mode_groups(lhs.shape, lhs.mode_shape) + rhs_mode_groups = get_mode_groups(rhs.shape, rhs.mode_shape) + mode_shape: list[int] = [] + for lhs_group, rhs_group in zip(lhs_mode_groups, rhs_mode_groups): + mode_shape.extend([lhs.mode_shape[i] for i in lhs_group]) + mode_shape.extend([rhs.mode_shape[i] for i in rhs_group]) + + # mode strides + mode_strides: list[int] = [] + rhs_size = rhs.count_size() + for lhs_group, rhs_group in zip(lhs_mode_groups, rhs_mode_groups): + mode_strides.extend([stride * rhs_size for stride in (lhs.mode_strides[i] for i in lhs_group)]) + mode_strides.extend([rhs.mode_strides[i] for i in rhs_group]) + + return shared_layout(shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=None) def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: @@ -127,11 +138,245 @@ def shared_permute(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: ret: SharedLayout The permuted layout. """ - if set(dims) != set(range(len(layout.shape))): - raise LayoutOperationError("Dims must be a permutation of {}, got {}".format(range(len(layout.shape)), dims)) + assert len(dims) == len(layout.shape) and set(dims) == set(range(len(layout.shape))) + + # shape shape = tuple(layout.shape[d] for d in dims) - axes = tuple(layout.axes[d] for d in dims) - return SharedLayout(shape=shape, size=layout.size, axes=axes, offset=layout.offset) + + # mode shape and mode strides + layout_mode_groups = get_mode_groups(layout.shape, layout.mode_shape) + mode_shape: list[int] = [] + mode_strides: list[int] = [] + for d in dims: + mode_shape.extend([layout.mode_shape[i] for i in layout_mode_groups[d]]) + mode_strides.extend([layout.mode_strides[i] for i in layout_mode_groups[d]]) + + return shared_layout( + shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=layout.optional_swizzle + ) + + +def shared_slice(layout: SharedLayout, retain_dims: Sequence[int]) -> SharedLayout: + """Slice the shared layout by removing specified dimensions. + + Parameters + ---------- + layout: SharedLayout + The layout to slice. + dims: Sequence[int] + The dimensions to slice. Each dimension should be in the range [0, len(layout.shape)). The dimensions will + be kept in the output layout. + + Returns + ------- + ret: SharedLayout + The sliced layout. + """ + assert all(0 <= d < len(layout.shape) for d in retain_dims) and len(retain_dims) == len(set(retain_dims)) + shape: List[int] = [] + mode_shape: List[int] = [] + mode_strides: List[int] = [] + layout_mode_groups = get_mode_groups(layout.shape, layout.mode_shape) + for i in retain_dims: + shape.append(layout.shape[i]) + mode_shape.extend([layout.mode_shape[j] for j in layout_mode_groups[i]]) + mode_strides.extend([layout.mode_strides[j] for j in layout_mode_groups[i]]) + + return shared_layout( + shape=shape, + mode_shape=mode_shape, + mode_strides=mode_strides, + optional_swizzle=layout.optional_swizzle, + ) + + +def shared_unsqueeze(layout: SharedLayout, dims: Sequence[int]) -> SharedLayout: + """Unsqueeze the shared layout by adding new dimensions of size 1. + + Parameters + ---------- + layout: SharedLayout + The layout to unsqueeze. + dims: Sequence[int] + The dimensions to unsqueeze. 
Each dimension should be in the range [0, len(layout.shape)]. + + Returns + ------- + ret: SharedLayout + The unsqueezed layout. + """ + assert all(0 <= d <= len(layout.shape) for d in dims) and len(dims) == len(set(dims)) + shape: List[int] = list(layout.shape) + for d in sorted(dims): + shape.insert(d, 1) + return shared_layout( + shape=shape, + mode_shape=layout.mode_shape, + mode_strides=layout.mode_strides, + optional_swizzle=layout.optional_swizzle, + ) + + +def shared_row_major_swizzle(shape: Sequence[int], dtype_nbytes: int) -> SharedLayout: + """ + Generate a shared layout that can be used to generate ldmatrix instructions when using LoadSharedInst. + + Here m and n denote the last two dimensions of the shape; m must be a multiple of 8 and n * dtype_nbytes must be a multiple of 16. + + We divide each row into bank groups, where each bank group has 16 bytes (16 x uint8, 8 x fp16, or 4 x fp32, etc.). + Each bank group corresponds to 4 banks in shared memory. For example, if we have m = n = 8 and dtype=fp16, we can represent + bank groups as + + 0 # bank group 0, banks from 0 to 3 + 1 # bank group 1, banks from 4 to 7 + 2 # ... + 3 + 4 + 5 + 6 + 7 # bank group 7, banks from 28 to 31 + + Given m and n, we need to find a proper way to organize the m x (n / 8) bank groups in shared memory, so that + 1) each row has different bank groups + 2) each column has different bank groups + + When we have m = 8 and n = 64, we have 8 x 8 bank groups. If we store the elements in row-major order, we will + have the bank groups as + + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + 0 1 2 3 4 5 6 7 + + If we use ldmatrix to load the above 8 x 64 shared memory tile, we will need 8 ldmatrix.v1 instructions. Each instruction + loads one column (8 x 8 elements, or 8 x 1 bank groups). Since each instruction will access the same bank group, + severe bank conflicts will occur. Thus, we need to change the layout of shared memory to avoid bank conflicts. + + Let layout(i, j) be the shared memory address of logical elements (each element has 16 bytes) when we use + a specific `layout`. For example, the row-major layout row-major(i, j) = i * n + j * 8 (we assume the dtype has 2 + bytes). If we use the swizzled layout swizzled(i, j) = row-major(i, j ^ i) = i * n + (j ^ i) * 8, we can have the + following bank groups in shared memory. + + 0 1 2 3 4 5 6 7 + 1 0 3 2 5 4 7 6 + 2 3 0 1 6 7 4 5 + 3 2 1 0 7 6 5 4 + 4 5 6 7 0 1 2 3 + 5 4 7 6 1 0 3 2 + 6 7 4 5 2 3 0 1 + 7 6 5 4 3 2 1 0 + + (the reader may need some time to work out the above layout...) + + This layout has two benefits: + 1) Each row has different bank groups. In the above example, we have 32 banks per row. + 2) Each column has different bank groups. In the above example, we have 32 banks per column. + + Benefit 1 ensures that when we load data from global memory to shared memory, we can store it efficiently. + Benefit 2 ensures that when we load data from shared memory to register memory, we can load it efficiently. + + We can always generate the swizzled layout for arbitrary m and n as long as the constraints above are satisfied. See the + implementation for more details. + + Parameters + ---------- + shape: Sequence[int] + The shape of the shared memory. The shape must have at least two dimensions. + + dtype_nbytes: int + The element data type size in bytes. + + Returns + ------- + ret: SharedLayout + The shared layout that can be used to generate ldmatrix instructions when using LoadSharedInst.
+ """ + if len(shape) < 2: + raise ValueError("The shape of swizzled shared layout must have at least two dimensions.") + head, m, n = tuple(shape[:-2]), shape[-2], shape[-1] + + if m % 8 != 0 or n * dtype_nbytes % 16 != 0: + raise ValueError("m must be a multiple of 8, and n * dtype_nbytes must be a multiple of 16.") + + n_vector_size: int = gcd(n, 128 // dtype_nbytes) + n_num_vectors: int = n // n_vector_size + + mode_shape = head + (m, n_num_vectors, n_vector_size) + + # use the order of head, columns_vectors, rows, columns_vec_size to compute the strides + ranks = list(range(len(head))) + [len(head) + 1, len(head), len(head) + 2] + mode_strides = strides_from_ranks(shape=mode_shape, ranks=ranks) + + log2 = { + 1: 0, + 2: 1, + 4: 2, + 8: 3, + 16: 4, + } + + if n_vector_size * dtype_nbytes == 128: + """ + (each number represents a 16-byte group of elements) + 0 1 2 3 4 5 6 7 + 1 0 3 2 5 4 7 6 + 2 3 0 1 6 7 4 5 + 3 2 1 0 7 6 5 4 + 4 5 6 7 0 1 2 3 + 5 4 7 6 1 0 3 2 + 6 7 4 5 2 3 0 1 + 7 6 5 4 3 2 1 0 + """ + swizzle = Swizzle(base=log2[16 // dtype_nbytes], bits=3, shift=3) + elif n_vector_size * dtype_nbytes == 64: + """ + 0 1 2 3 + 4 5 6 7 + 1 0 3 2 + 5 4 7 6 + 2 3 0 1 + 6 7 4 5 + 3 2 1 0 + 7 6 5 4 + """ + swizzle = Swizzle(base=log2[16 // dtype_nbytes], bits=2, shift=3) + elif n_vector_size * dtype_nbytes == 32: + """ + 0 1 + 2 3 + 4 5 + 6 7 + 1 0 + 3 2 + 5 4 + 7 6 + """ + swizzle = Swizzle(base=log2[16 // dtype_nbytes], bits=1, shift=3) + elif n_vector_size * dtype_nbytes == 16: + """ + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + """ + swizzle = None + else: + assert False + + return shared_layout( + shape=shape, + mode_shape=mode_shape, + mode_strides=mode_strides, + optional_swizzle=swizzle, + ) def visualize_layout(layout: SharedLayout, tablefmt: str = "simple_grid") -> str: @@ -169,12 +414,11 @@ def visualize_layout(layout: SharedLayout, tablefmt: str = "simple_grid") -> str if len(layout.shape) != 2: raise LayoutOperationError(f"Shared layout with shape {layout.shape} is not supported for visualization.") - grid = meshgrid(layout.shape) - offset_grid = vectorized_evaluate(layout.offset, var2value={axis: grid[i] for i, axis in enumerate(layout.axes)}) + grid = layout.as_numpy_grid() table = [] for i in range(layout.shape[0]): row = [] for j in range(layout.shape[1]): - row.append(f"{offset_grid[i, j]}") + row.append(f"{grid[i, j]}") table.append(row) return head + "\n" + tabulate.tabulate(table, tablefmt=tablefmt) diff --git a/python/tilus/ir/layout/ops/utils.py b/python/tilus/ir/layout/ops/utils.py index 4e0699f8..bd34fb43 100644 --- a/python/tilus/ir/layout/ops/utils.py +++ b/python/tilus/ir/layout/ops/utils.py @@ -25,6 +25,12 @@ def get_mode_groups(shape: Sequence[int], mode_shape: Sequence[int]) -> list[lis """ Get the groups of modes based on the shape and mode_shape. 
+ Example: + >>> shape = [64, 32] + >>> mode_shape = [8, 8, 16, 2] + >>> get_mode_groups(shape, mode_shape) + [[0, 1], [2, 3]] + + Parameters ---------- shape: Sequence[int] diff --git a/python/tilus/ir/layout/shared_layout.py b/python/tilus/ir/layout/shared_layout.py index 3636c59b..c9c8dbe2 100644 --- a/python/tilus/ir/layout/shared_layout.py +++ b/python/tilus/ir/layout/shared_layout.py @@ -15,40 +15,78 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Callable, Dict, List, Sequence +from typing import Optional, Sequence +import numpy as np from hidet.ir.expr import Expr, Var, as_expr +from hidet.ir.utils.index_transform import index_deserialize +from hidet.utils import prod from tilus.extensions.hidet.ir.expr import index_vars from tilus.ir.node import IRNode +from tilus.ir.utils.veceval import vectorized_evaluate + + +@dataclass(frozen=True, eq=True) +class Swizzle: + """ + A swizzle function. + + 0xxxxYYYxxZZZxxxx + z_mask = ((1 << self.bits) - 1) << self.base + y_mask = ((1 << self.bits) - 1) << (self.base + self.shift) + return offset ^ ((offset & y_mask) >> self.shift) + """ + + base: int + bits: int + shift: int + + def __call__(self, index: Expr) -> Expr: + # we use a primitive function here + # todo: use a general computation after refactoring to a cute-like shared layout + from tilus.extensions.hidet.ir.primitives.swizzle import swizzle + + if self.bits == 0: + return index + return swizzle(index, self.base, self.bits, self.shift) + + def __str__(self): + return f"Swizzle(base={self.base}, bits={self.bits}, shift={self.shift})" @dataclass(frozen=True, eq=False) class SharedLayout(IRNode): """The layout for shared tensor. + We use three components to describe a shared tensor layout: the shape, the mode shape, and the mode strides. + + The mode shape and mode strides describe how to split each dimension into multiple sub-dimensions (modes), + and the stride of each mode. + + For example, consider a shape of (64, 32): we can split the first dimension into two sub-dimensions (modes) of size 8 and 8, + and the second dimension into two sub-dimensions (modes) of size 16 and 2. The mode shape would be (8, 8, 16, 2). We can + have strides for each mode, for example, (256, 2, 16, 1). Then given the indices (i, j), we can compute the indices in the + sub-dimensions (i1, i2, j1, j2) where i1 = i // 8, i2 = i % 8, j1 = j // 2, j2 = j % 2. The offset can be computed as: + offset = i1 * 256 + i2 * 2 + j1 * 16 + j2 * 1. To get the final offset in the shared tensor, we can use the formula: + (i, j) => ((i // 8) * 256) + ((i % 8) * 2) + ((j // 2) * 16) + ((j % 2) * 1). + Attributes ---------- shape: tuple[int, ...] The shape of the shared tensor. Each dimension is a constant integer. - size: int - The storage size of the shared tensor, in number of elements. If the layout is a `compact` layout, size - should be equal to the product of the shape dimensions. Otherwise, it can be either larger (in case of padding) - or smaller (in case of sharing data for different elements) than the product of the shape dimensions. The - size must be a constant integer. - axes: tuple[Var, ...] - The axes of the shared tensor. Each axis is a variable that represents the index of the corresponding dimension. - It should have the same length as the shape. - offset: Expr - The offset expression of the shared tensor based on the axes. It is an expression that computes the offset - of the shared tensor based on the axes.
Only the axes and variables that are invariant in the lifetime of the - given corresponding shared tensor with this layout can be used in the expression. + mode_shape: tuple[int, ...] + We can split each dimension into multiple sub-dimensions (modes). + mode_strides: tuple[int, ...] + The strides of each mode. + swizzle: Optional[Swizzle] + The swizzle function to apply on the final offset. If None, no swizzling is applied. """ shape: tuple[int, ...] - size: int - axes: tuple[Var, ...] - offset: Expr + mode_shape: tuple[int, ...] + mode_strides: tuple[int, ...] + optional_swizzle: Optional[Swizzle] def __call__(self, *indices: Expr) -> Expr: """Compute the offset on given indices. @@ -65,89 +103,122 @@ def __call__(self, *indices: Expr) -> Expr: ret: Expr The computed offset of the shared tensor element at the given indices. """ - assert len(indices) == len(self.axes) - from hidet.ir.tools import rewrite + from tilus.ir.layout.ops.utils import get_mode_groups - return rewrite(self.offset, rewrite_map={axis: index for axis, index in zip(self.axes, indices)}) + # get the stride-based index + group_modes = get_mode_groups(self.shape, self.mode_shape) + mode_indices: list[Expr] = [] + for index, modes in zip(indices, group_modes): + mode_indices.extend(index_deserialize(index, shape=[self.mode_shape[m] for m in modes])) + total_index: Expr = as_expr(sum(index * stride for index, stride in zip(mode_indices, self.mode_strides))) - @staticmethod - def create(shape: Sequence[int], size: int, f_offset: Callable[[Sequence[Var]], Expr | int]) -> SharedLayout: - """Create a shared layout. + # apply swizzle if exists + if self.optional_swizzle is not None: + total_index = self.optional_swizzle(total_index) + + return total_index + + @property + def swizzle(self) -> Swizzle: + if self.optional_swizzle is None: + raise ValueError("No swizzle is applied on this layout.") + return self.optional_swizzle - This method creates a shared layout with the given shape, size, and a function to compute the offset based on - the axes. The shape must be a sequence of constant integers, and the size must be a constant integer that is - larger than the maximum possible offset computed by the `f_offset` function. + @staticmethod + def create( + shape: Sequence[int], + mode_shape: Sequence[int], + mode_strides: Sequence[int], + optional_swizzle: Optional[Swizzle], + ) -> SharedLayout: + """ + Create a SharedLayout from shape, mode_shape, and mode_strides. Parameters ---------- shape: Sequence[int] - The shape of the shared tensor. Each dimension is a constant integer. - size: int - The storage size of the shared tensor, in number of elements. - f_offset: Callable[[Sequence[Var]], Expr] - The function that computes the offset of the shared tensor based on the axes. It takes a sequence of - axes (variables) and returns an expression that computes the offset. The function must ensure that the - size is larger than the maximum possible offset computed by this function. + The shape of the shared tensor. + mode_shape: Sequence[int] + The mode shape of the shared tensor. + mode_strides: Sequence[int] + The mode strides of the shared tensor. + swizzle: Optional[Swizzle] + The swizzle function to apply on the final offset. If None, no swizzling is applied. Returns ------- ret: SharedLayout - A shared layout with the specified shape, size, axes, and offset. + The created SharedLayout. 
""" - axes: List[Var] = index_vars(num_vars=len(shape)) - return SharedLayout(shape=tuple(shape), size=size, axes=tuple(axes), offset=as_expr(f_offset(axes))) - - def slice(self, offsets: Sequence[Expr], slice_dims: Sequence[int], slice_shape: Sequence[int]) -> SharedLayout: - assert len(set(slice_dims)) == len(slice_dims), "slice_dims must be unique" - assert len(slice_shape) == len(slice_dims), "slice_dims and slice_shape must have the same length" - assert len(slice_dims) <= len(self.shape), "slice_dims must be less than or equal to the number of dimensions" - - def f_offset(axes: Sequence[Var]) -> Expr: - indices: List[Expr] = list(offsets) - for dim, axis in zip(slice_dims, axes): - indices[dim] = indices[dim] + axis - return self(*indices) - self(*offsets) - - return SharedLayout.create(shape=slice_shape, size=self.size, f_offset=f_offset) - - def simplify(self) -> SharedLayout: - from tilus.extensions.hidet.ir.tools import simplify_expr - from tilus.extensions.hidet.transforms.rule_based_simplifier import BoundInfo, RuleBasedSimplifier - - var2bound: Dict[Var, BoundInfo] = { - axis: BoundInfo(min_value=0, max_value=extent - 1) for axis, extent in zip(self.axes, self.shape) - } - simplifier = RuleBasedSimplifier(var2bound=var2bound) + if any(s < 1 for s in shape): + raise ValueError("All dimensions in shape must be positive integers.") + if len(mode_shape) != len(mode_strides): + raise ValueError("mode_shape and mode_strides must have the same length.") + if prod(mode_shape) != prod(shape): + raise ValueError("The product of mode_shape must equal to the product of shape.") return SharedLayout( - shape=self.shape, size=self.size, axes=self.axes, offset=simplify_expr(simplifier(self.offset)) + shape=tuple(shape), + mode_shape=tuple(mode_shape), + mode_strides=tuple(mode_strides), + optional_swizzle=optional_swizzle, ) - def swizzle(self, dim: int, regards_dim: int, log_step: int) -> SharedLayout: - ndims = len(self.shape) - assert 0 <= dim < ndims and 0 <= regards_dim < ndims and dim != regards_dim + def as_numpy_grid(self) -> np.ndarray: + grid_axes = np.meshgrid(*[np.arange(extent) for extent in self.shape], indexing="ij") + axes = index_vars(num_vars=len(self.shape)) + offset = self(*axes) + atom_grid = vectorized_evaluate(expr=offset, var2value={axis: grid_axes[i] for i, axis in enumerate(axes)}) + return atom_grid - def get_xor_index(indices: Sequence[Expr]) -> Expr: - indices = list(indices) # copy - step = 2**log_step - regards_index = indices[regards_dim] // step - regards_extent = self.shape[regards_dim] // step - if regards_extent > self.shape[dim]: - regards_index = regards_index % self.shape[dim] - return regards_index + def as_axes_mapping(self) -> tuple[list[Var], Expr]: + axes = index_vars(num_vars=len(self.shape)) + offset = self(*axes) + return axes, offset - def f_offset(axes: Sequence[Var]) -> Expr: - swizzled_indices: List[Expr] = [axis for axis in axes] - swizzled_indices[dim] = swizzled_indices[dim] ^ get_xor_index(axes) - return self(*swizzled_indices) + def count_size(self) -> int: + """Count the total size of the shared layout. - return SharedLayout.create(shape=self.shape, size=self.size, f_offset=f_offset) + It is the minimum number of elements required to store the tensor in shared memory. - def prepend_dim(self, extent: int) -> SharedLayout: - def f_offset(axes: Sequence[Var]) -> Expr: - tile_offset = axes[0] * self.size - return tile_offset + self(*axes[1:]) + Returns + ------- + ret: int + The total size of the shared layout. 
+ """ + indices = [extent - 1 for extent in self.mode_shape] + max_index = sum(a * b for a, b in zip(indices, self.mode_strides)) + return max_index + 1 + + def slice(self, retain_dims: Sequence[int]) -> SharedLayout: + from tilus.ir.layout.ops.shared_ops import shared_slice + + return shared_slice(self, retain_dims) + + def apply_swizzle(self, swizzle: Swizzle) -> SharedLayout: + if self.optional_swizzle is not None: + raise RuntimeError("Chained swizzle is not supported.") + return SharedLayout.create( + shape=self.shape, + mode_shape=self.mode_shape, + mode_strides=self.mode_strides, + optional_swizzle=swizzle, + ) - return SharedLayout.create(shape=(extent,) + self.shape, size=extent * self.size, f_offset=f_offset) + def prepend_dim(self, extent: int) -> SharedLayout: + shape = (extent,) + self.shape + if extent > 1: + mode_shape = (extent,) + self.mode_shape + mode_strides = (self.count_size(),) + self.mode_strides + else: + mode_shape = self.mode_shape + mode_strides = self.mode_strides + + return SharedLayout.create( + shape=shape, + mode_shape=mode_shape, + mode_strides=mode_strides, + optional_swizzle=self.optional_swizzle, + ) def transpose(self) -> SharedLayout: assert len(self.shape) == 2 @@ -159,22 +230,55 @@ def permute(self, dims: Sequence[int]) -> SharedLayout: return shared_permute(self, dims) def unsqueeze(self, dims: Sequence[int]) -> SharedLayout: - shape = [] - cur_dim = 0 - for i in range(len(self.shape) + len(dims)): - if i in dims: - shape.append(1) - else: - shape.append(self.shape[cur_dim]) - cur_dim += 1 - - def f_offset(axes: Sequence[Var]) -> Expr: - base_axes = [axis for i, axis in enumerate(axes) if i not in dims] - return self(*base_axes) + from tilus.ir.layout.ops.shared_ops import shared_unsqueeze - return SharedLayout.create(shape=shape, size=self.size, f_offset=f_offset) + return shared_unsqueeze(self, dims) def visualize(self, tablefmt: str = "simple_grid") -> str: from tilus.ir.layout.ops.shared_ops import visualize_layout return visualize_layout(self, tablefmt=tablefmt) + + +def shared_layout( + shape: Sequence[int], + mode_shape: Sequence[int], + mode_strides: Sequence[int], + optional_swizzle: Optional[Swizzle] = None, +) -> SharedLayout: + """Create a SharedLayout from shape, mode_shape, and mode_strides. + + Parameters + ---------- + shape: Sequence[int] + The shape of the shared tensor. + mode_shape: Sequence[int] + The mode shape of the shared tensor. + mode_strides: Sequence[int] + The mode strides of the shared tensor. + swizzle: Optional[Swizzle] + The swizzle function to apply on the final offset. If None, no swizzling is applied. + + Returns + ------- + ret: SharedLayout + The created SharedLayout. 
+ """ + # canonicalize mode shape: clean up mode_shape and mode_strides by removing size 1 modes + if any(s <= 1 for s in mode_shape): + updated_mode_shape = [] + updated_mode_strides = [] + for ms, stride in zip(mode_shape, mode_strides): + if ms > 1: + updated_mode_shape.append(ms) + updated_mode_strides.append(stride) + mode_shape = updated_mode_shape + mode_strides = updated_mode_strides + + # canonicalize swizzle: if swizzle has 0 bits, set it to None (both mean no swizzle) + if optional_swizzle is not None and optional_swizzle.bits == 0: + optional_swizzle = None + + return SharedLayout.create( + shape=shape, mode_shape=mode_shape, mode_strides=mode_strides, optional_swizzle=optional_swizzle + ) diff --git a/python/tilus/ir/layout/utils/cute.py b/python/tilus/ir/layout/utils/cute.py index 6093c6b1..09327b16 100644 --- a/python/tilus/ir/layout/utils/cute.py +++ b/python/tilus/ir/layout/utils/cute.py @@ -21,6 +21,7 @@ from hidet.utils import prod from tilus.extensions.hidet.ir.primitives.swizzle import swizzle from tilus.extensions.hidet.ir.utils.index_transform import index_deserialize +from tilus.ir.layout.shared_layout import SharedLayout, Swizzle, shared_layout Int = Union[Expr, int] IntTuple = Int | Sequence[Union[Int, "IntTuple"]] @@ -79,8 +80,7 @@ def __str__(self) -> str: return f"{self.shape}:{self.strides}" def __call__(self, *coords: IntTuple) -> Int: - coords = specialize(coords, self.shape) - ret = tuple_sum(tuple_multiply(coords, self.strides)) + ret = tuple_sum(tuple_multiply(specialize(coords, self.shape), self.strides)) return ret @property @@ -110,6 +110,9 @@ def __call__(self, offset: Int) -> Int: # return offset ^ ((offset & y_mask) >> self.sshift) return swizzle(int32(offset), self.mbase, self.bbits, self.sshift) + def as_swizzle(self) -> Swizzle: + return Swizzle(base=self.mbase, bits=self.bbits, shift=self.sshift) + class SwizzledCuteLayout: def __init__(self, layout: CuteLayout, swizzle: CuteSwizzle): @@ -122,6 +125,43 @@ def __str__(self) -> str: def __call__(self, *coords: IntTuple) -> Int: return self.swizzle(self.layout(*coords)) + def as_shared_layout(self, tensor_shape: Sequence[int]) -> SharedLayout: + # since cute layout use column-major order when splitting modes, we need to reverse the shape and strides + def reverse_int_tuple(t: IntTuple) -> IntTuple: + if isinstance(t, Sequence): + return tuple(reverse_int_tuple(item) for item in reversed(t)) + else: + return t + + assert isinstance(self.layout.shape, Sequence) + assert isinstance(self.layout.strides, Sequence) + + rev_shape = [reverse_int_tuple(item) for item in self.layout.shape] + rev_strides = [reverse_int_tuple(item) for item in self.layout.strides] + + # then, we flatten them into 1D lists + def flatten_int_tuple(t: IntTuple) -> list[Int]: + if isinstance(t, Sequence): + result = [] + for item in t: + result.extend(flatten_int_tuple(item)) + return result + else: + return [t] + + flat_shape = flatten_int_tuple(rev_shape) + flat_strides = flatten_int_tuple(rev_strides) + + mode_shape = [int(s) for s in flat_shape] + mode_strides = [int(s) for s in flat_strides] + + return shared_layout( + shape=tensor_shape, + mode_shape=mode_shape, + mode_strides=mode_strides, + optional_swizzle=self.swizzle.as_swizzle(), + ) + def cute_layout(shape: IntTuple, strides: IntTuple) -> CuteLayout: return CuteLayout(shape, strides) diff --git a/python/tilus/ir/tensor.py b/python/tilus/ir/tensor.py index 7361edd3..10d1dc16 100644 --- a/python/tilus/ir/tensor.py +++ b/python/tilus/ir/tensor.py @@ -371,7 +371,23 @@ def 
__eq__(self, other): """ raise RuntimeError("tensor == tensor could only be used in Tilus Script.") - def __xor__(self, other): + def __ne__(self, value): + """ + Not equal to comparison. + + Parameters + ---------- + value: RegisterTensor | int | float | Expr + The tensor or scalar to compare with this tensor. + + Returns + ------- + ret: RegisterTensor + A new tensor that is the result of the comparison. + """ + raise RuntimeError("tensor != tensor could only be used in Tilus Script.") + + def __xor__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """Bitwise XOR operation. Parameters @@ -419,8 +435,23 @@ def __rsub__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor """ raise RuntimeError("tensor - tensor could only be used in Tilus Script.") + def __rtruediv__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: + """Perform right-side division with another tensor or a scalar. + + Parameters + ---------- + other: RegisterTensor | int | float | Expr + The tensor or scalar to divide this tensor by. + + Returns + ------- + ret: RegisterTensor + A new tensor that is the result of the division. + """ + raise RuntimeError("tensor / tensor could only be used in Tilus Script.") + # i-version of operator - def __iadd__(self, other: RegisterTensor | int | float | Expr) -> None: + def __iadd__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place addition operation. Parameters @@ -430,7 +461,7 @@ def __iadd__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor += tensor could only be used in Tilus Script.") - def __isub__(self, other: RegisterTensor | int | float | Expr) -> None: + def __isub__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place subtraction operation. Parameters @@ -440,7 +471,7 @@ def __isub__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor -= tensor could only be used in Tilus Script.") - def __imul__(self, other: RegisterTensor | int | float | Expr) -> None: + def __imul__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place multiplication operation. Parameters @@ -450,7 +481,7 @@ def __imul__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor *= tensor could only be used in Tilus Script.") - def __itruediv__(self, other: RegisterTensor | int | float | Expr) -> None: + def __itruediv__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place division operation. Parameters @@ -460,7 +491,7 @@ def __itruediv__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor /= tensor could only be used in Tilus Script.") - def __imod__(self, other: RegisterTensor | int | float | Expr) -> None: + def __imod__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place modulus operation. Parameters @@ -470,7 +501,7 @@ def __imod__(self, other: RegisterTensor | int | float | Expr) -> None: """ raise RuntimeError("tensor %= tensor could only be used in Tilus Script.") - def __ixor__(self, other: RegisterTensor | int | float | Expr) -> None: + def __ixor__(self, other: RegisterTensor | int | float | Expr) -> RegisterTensor: """In-place bitwise XOR operation. Parameters @@ -509,7 +540,7 @@ def transpose(self) -> RegisterTensor: def to(self, dtype: DataType) -> RegisterTensor: raise RuntimeError("tensor.to(...) 
could only be used in Tilus Script.") - def tolist(self) -> Expr | list: + def tolist(self) -> list: raise RuntimeError("tensor.tolist() could only be used in Tilus Script.") @@ -584,7 +615,7 @@ def size(self) -> int: ret: int The size of the SharedTensor, which is the number of elements it contains. """ - return self.layout.size + return self.layout.count_size() @property def nbytes(self) -> int: diff --git a/python/tilus/ir/tools/printer.py b/python/tilus/ir/tools/printer.py index fe51c446..d94d33d9 100644 --- a/python/tilus/ir/tools/printer.py +++ b/python/tilus/ir/tools/printer.py @@ -114,7 +114,7 @@ def get_tensor_type(self, tensor: Tensor) -> Doc: doc = Text("shared, ") doc += self.printer(tensor.dtype) + "[" + self.visit(tensor.shape) + "]" if tensor.optional_layout is not None: - doc += ", size={}".format(tensor.layout.size) + doc += ", size={}".format(tensor.size) doc += ", {}".format(self.visit(tensor.layout)) return doc elif isinstance(tensor, GlobalTensor): @@ -448,12 +448,11 @@ def visit_RegisterLayout(self, layout: RegisterLayout) -> Doc: return self.add_key_comment("layout", doc) def visit_SharedLayout(self, node: SharedLayout) -> Doc: - for i, axis in enumerate(node.axes): - self.set_var_name(axis, "u" + str(i)) items = [ "shape=[" + self(node.shape) + "]", - "axes=[" + self(node.axes) + "]", - "offset=" + self(node.offset), + "mode_shape=[" + self(node.mode_shape) + "]", + "mode_strides=[" + self(node.mode_strides) + "]", + "swizzle=" + (str(node.swizzle) if node.optional_swizzle is not None else "None"), ] doc = Text("SharedLayout(") + doc_join(items, ", ") + ")" return self.add_key_comment("shared_layout", doc) diff --git a/python/tilus/lang/instructions/root.py b/python/tilus/lang/instructions/root.py index ea9cccda..77898a8c 100644 --- a/python/tilus/lang/instructions/root.py +++ b/python/tilus/lang/instructions/root.py @@ -34,12 +34,12 @@ class RootInstructionGroup(InstructionGroup): @property def blockIdx(self) -> Dim3: """Get the block index of the current thread block.""" - return Dim3(blockIdx.x, blockIdx.y, blockIdx.z) + return Dim3(blockIdx.x, blockIdx.y, blockIdx.z) # type: ignore[attr-defined] @property def gridDim(self) -> Dim3: """Get the grid dimension of the kernel.""" - return Dim3(gridDim.x, gridDim.y, gridDim.z) + return Dim3(gridDim.x, gridDim.y, gridDim.z) # type: ignore[attr-defined] @property def current_thread_begin(self) -> int: @@ -1627,7 +1627,7 @@ def print_tensor(self, msg: str, tensor: Tensor, fmt: Optional[str] = None) -> N """ self._builder.print_tensor(msg=msg, tensor=tensor, fmt=fmt) - def printf(self, fstring: str, *args: Expr | int | float) -> None: + def printf(self, fstring: str, *args: Expr | int | float | str) -> None: """Print a formatted string. This instruction prints a formatted string to the standard output. The `fstring` parameter is a format string diff --git a/python/tilus/lang/instructions/tcgen05.py b/python/tilus/lang/instructions/tcgen05.py index 187aa816..a1aa770d 100644 --- a/python/tilus/lang/instructions/tcgen05.py +++ b/python/tilus/lang/instructions/tcgen05.py @@ -53,6 +53,7 @@ def alloc( "The thread group used to allocate with initialization must start at a multiple of 128 " "and have at least 128 threads." 
) + ctx: contextlib.AbstractContextManager if thread_end - thread_begin == 128: ctx = contextlib.nullcontext() else: diff --git a/python/tilus/lang/modules/cuda.py b/python/tilus/lang/modules/cuda.py index 2d2e13bf..64b47949 100644 --- a/python/tilus/lang/modules/cuda.py +++ b/python/tilus/lang/modules/cuda.py @@ -18,14 +18,13 @@ import cuda.bindings.runtime as cudart from hidet.ir.dtypes import DataType, bfloat16, float16, float32, int8, int32 -from hidet.ir.expr import as_expr from tilus import RegisterLayout from tilus.backends.emitters.cuda.mma_dot import AtomicMmaConfig from tilus.ir.layout import SharedLayout -from tilus.ir.layout.ops import auto_local_spatial, reduce, shared_compose, shared_row_major, spatial +from tilus.ir.layout.ops import auto_local_spatial, reduce, shared_row_major, spatial from tilus.ir.utils import vector -from tilus.utils import gcd, idiv, prod +from tilus.utils import gcd, prod @dataclass(frozen=True, eq=False) @@ -206,186 +205,6 @@ def shared_layout(shape: Sequence[int]) -> SharedLayout: """ return shared_row_major(*shape) - @staticmethod - def swizzled_shared_layout(dtype: DataType, *, shape: Sequence[int]) -> SharedLayout: - """ - Generate a shared layout that could be used to generate ldmatrix instruction when using LoadSharedInst. - - Both m and n must be a multiple of 8. - - We will divide each row into bank groups, and bank group has 16 bytes (16 x uint8, 8 x fp16, or 4 x fp32, etc.). - They correspond to 4 banks in shared memory. For example, if we have m = n = 8, we can represent bank groups as - - 0 # bank group 0, banks from 0 to 3 - 1 # bank group 1, banks from 4 to 7 - 2 # ... - 3 - 4 - 5 - 6 - 7 # bank groups 7, banks from 28 to 31 - - Given m, and n, we need to find a proper way to organize the m x (n / 8) bank groups in shared memory, so that - 1) each row has different bank groups - 2) each column has different bank groups - - When we have m = 8 and n = 64, we have 8 x 8 bank groups. If we store the elements in row-major order, we will - have the bank groups as - - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - 0 1 2 3 4 5 6 7 - - If we use ldmatrix to load the above 8 x 64 shared memory, we will need 8 ldmatrix.v1 instructions. Each instruction - loads one column (8 x 8 elements, or 8 x 1 bank groups). Since each instruction will access the same bank group, - severe bank conflicts will occur. Thus, we need to change the layout of shared memory to avoid bank conflicts. - - Let layout(i, j) be the shared memory address of logical elements (each element has 16 bytes) when we use - a specific `layout`. For example, the row-major layout row-major(i, j) = i * n + j * 8 (we assume the dtype has 2 - bytes). If we use the swizzled layout swizzled(i, j) = row-major(i, j ^ i) = i * n + (j ^ i) * 8, we can have the - following bank groups in shared memory. - - 0 1 2 3 4 5 6 7 - 1 0 3 2 5 4 7 6 - 2 3 0 1 6 7 4 5 - 3 2 1 0 7 6 5 4 - 4 5 6 7 0 1 2 3 - 5 4 7 6 1 0 3 2 - 6 7 4 5 2 3 0 1 - 7 6 5 4 3 2 1 0 - - (reader may need some time to figure out the above layout...) - - This layout has two benefits: - 1) Each row has different bank groups. In above example, we have 32 banks per row. - 2) Each column has different bank groups. In above example, we have 32 banks per column. - - The benefit 1 makes sure that when we load data from global memory to shared memory, we can store efficiently. 
- The benefit 2 makes sure that when we load data from shared memory to register memory, we can load efficiently. - - We can always generate the swizzled layout for arbitrary m and n as long as they are multiple of 8. See the - implementation for more details. - - Parameters - ---------- - dtype: DataType - The element data type for both the shared memory and the register memory. - - shape: Sequence[int] - The shape of the shared memory. The shape must have at least two dimensions. - - Returns - ------- - shared_layout: SharedLayout - The shared layout that could be used to generate ldmatrix instruction when using LoadSharedInst. - """ - return cuda._swizzled_shared_layout_new(dtype, shape=tuple(shape)) - - @staticmethod - @functools.lru_cache - def _swizzled_shared_layout(dtype: DataType, shape: tuple[int, ...]) -> SharedLayout: - if len(shape) < 2: - raise ValueError("The shape of swizzled shared layout must have at least two dimensions.") - m, n = shape[-2:] - group_elements = idiv(16, dtype.nbytes) - if m % 8 != 0 or n % group_elements != 0: - raise ValueError("m must be a multiple of 8, and n must be a multiple of dtype.nbytes * 8.") - rows = m - columns = n // group_elements - - if columns % 8 == 0: - """ - 0 1 2 3 4 5 6 7 - 1 0 3 2 5 4 7 6 - 2 3 0 1 6 7 4 5 - 3 2 1 0 7 6 5 4 - 4 5 6 7 0 1 2 3 - 5 4 7 6 1 0 3 2 - 6 7 4 5 2 3 0 1 - 7 6 5 4 3 2 1 0 - """ - core = shared_row_major(rows, columns).swizzle(dim=1, regards_dim=0, log_step=0) - elif columns % 4 == 0: - """ - 0 1 2 3 - 4 5 6 7 - 1 0 3 2 - 5 4 7 6 - 2 3 0 1 - 6 7 4 5 - 3 2 1 0 - 7 6 5 4 - """ - core = shared_row_major(rows, 4).swizzle(dim=1, regards_dim=0, log_step=1) - elif columns % 2 == 0: - """ - 0 1 - 2 3 - 4 5 - 6 7 - 1 0 - 3 2 - 5 4 - 7 6 - """ - core = shared_row_major(rows, 2).swizzle(dim=1, regards_dim=0, log_step=2) - else: - """ - 0 - 1 - 2 - 3 - 4 - 5 - 6 - 7 - """ - core = shared_row_major(rows, 1) - layout = shared_compose(core, shared_row_major(1, group_elements)) - if m > layout.shape[0] or n > layout.shape[1]: - layout = shared_compose(shared_row_major(m // layout.shape[0], n // layout.shape[1]), layout) - if len(shape) > 2: - for extent in reversed(shape[:-2]): - layout = layout.prepend_dim(extent=extent) - return layout - - @staticmethod - @functools.lru_cache - def _swizzled_shared_layout_new(dtype: DataType, shape: tuple[int, ...]) -> SharedLayout: - if len(shape) < 2: - raise ValueError("The shape of swizzled shared layout must have at least two dimensions.") - m, n = shape[-2:] - group_elements = idiv(16, dtype.nbytes) - if m % 8 != 0 or n % group_elements != 0: - raise ValueError("m must be a multiple of 8, and n must be a multiple of dtype.nbytes * 8.") - - def f_offset(axes): - strides: list[int] = [prod(shape[i + 1 :]) for i in range(len(shape))] - - columns: int = n // group_elements - columns_vec_size: int = gcd(columns, 8) - - i, j = axes[-2:] - if columns_vec_size == 8: - i, j = i, j ^ ((i % 8) * group_elements) - elif columns_vec_size == 4: - i, j = i, j ^ (i // 2 % 4 * group_elements) - elif columns_vec_size == 2: - i, j = i, j ^ (i // 4 % 2 * group_elements) - else: - i, j = i, j - swizzled_axes = axes[:-2] + [i, j] - offset = as_expr(sum(axis * stride for axis, stride in zip(swizzled_axes, strides))) - return offset - - layout = SharedLayout.create(shape=shape, size=prod(shape), f_offset=f_offset).simplify() - return layout - @staticmethod def default_register_layout( num_warps: int, dtype: DataType, shape: Sequence[int], vector_size: Optional[int] = None diff --git 
a/python/tilus/lang/script.py b/python/tilus/lang/script.py index 63657abf..9f7e4893 100644 --- a/python/tilus/lang/script.py +++ b/python/tilus/lang/script.py @@ -43,7 +43,7 @@ class Script(InstructionInterface): # specify the schedule used for debugging. it will override any autotune space debug_schedule: Optional[dict[str, Any]] = None - def __new__(cls, *args, **kwargs) -> InstantiatedScript: # type: ignore[no-untyped-def] + def __new__(cls, *args, **kwargs) -> InstantiatedScript: # type: ignore[no-untyped-def, misc] from tilus.lang.instantiated_script import InstantiatedScriptCache instantiated_script: InstantiatedScript = InstantiatedScriptCache.get( diff --git a/python/tilus/lang/transpiler/transpiler.py b/python/tilus/lang/transpiler/transpiler.py index 3b2aec13..09fcf678 100644 --- a/python/tilus/lang/transpiler/transpiler.py +++ b/python/tilus/lang/transpiler/transpiler.py @@ -203,7 +203,7 @@ def transpile( metadata = Metadata( grid_blocks=normalize_grid_blocks(script.attrs.blocks), cluster_blocks=normalize_cluster_blocks(script.attrs.cluster_blocks), - block_indices=(blockIdx.x, blockIdx.y, blockIdx.z), + block_indices=(blockIdx.x, blockIdx.y, blockIdx.z), # type: ignore[attr-defined] num_warps=script.attrs.warps, param2divisibility=frozendict(param2divisibility), analysis=None, @@ -686,7 +686,7 @@ def visit_Call(self, expr: ast.Call) -> Any: # case 4 class_cls: Type[Class] = func obj = object.__new__(class_cls) - self.transpile_call(obj.__init__, args, kwargs) + self.transpile_call(obj.__init__, args, kwargs) # type: ignore ret = obj elif func is super: # case 5 diff --git a/python/tilus/py.typed b/python/tilus/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index c284d40a..28881c0a 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -23,7 +23,7 @@ from typing import Optional import pytest -from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90, nvgpu_sm100a +from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90a, nvgpu_sm100a # Get the project root directory PROJECT_ROOT = Path(__file__).parent.parent.parent @@ -55,9 +55,9 @@ ("blackwell_matmul", "matmul_v4.py", nvgpu_sm100a), ("blackwell_matmul", "matmul_v5.py", nvgpu_sm100a), # hopper matmul example (SM 9.0) - ("hopper_matmul", "matmul_v0.py", nvgpu_sm90), - ("hopper_matmul", "matmul_v1.py", nvgpu_sm90), - ("hopper_matmul", "matmul_v2.py", nvgpu_sm90), + ("hopper_matmul", "matmul_v0.py", nvgpu_sm90a), + ("hopper_matmul", "matmul_v1.py", nvgpu_sm90a), + ("hopper_matmul", "matmul_v2.py", nvgpu_sm90a), # quantization examples (SM 8.0+) ("quantization", "matmul_a16wx.py", nvgpu_sm80), # flash attention decode examples (SM 8.0+) diff --git a/tests/instructions/test_tcgen05_copy.py b/tests/instructions/test_tcgen05_copy.py index 6426fd82..4b1383cc 100644 --- a/tests/instructions/test_tcgen05_copy.py +++ b/tests/instructions/test_tcgen05_copy.py @@ -58,8 +58,9 @@ def __call__(self, m_size: int, n_size: int, x_ptr: ~int32, y_ptr: ~int32): self.sync() # copy x from shared to tmem - self.tcgen05.copy(src=s_x, dst=t_x) - self.tcgen05.commit(mbarrier=barriers[0]) + with self.single_thread(): + self.tcgen05.copy(src=s_x, dst=t_x) + self.tcgen05.commit(mbarrier=barriers[0]) self.mbarrier.wait(barriers[0], phase=0) # load y from tmem to register diff --git a/tests/instructions/test_tcgen05_mma.py b/tests/instructions/test_tcgen05_mma.py index f8db7bdf..c249c511 100644 --- 
a/tests/instructions/test_tcgen05_mma.py +++ b/tests/instructions/test_tcgen05_mma.py @@ -78,8 +78,9 @@ def __call__(self, a_ptr: void_p, b_ptr: void_p, d_ptr: void_p) -> None: self.mbarrier.wait(tma_mbarrier, phase=0) # perform mma - self.tcgen05.mma(a=s_a, b=s_b.transpose(), d=t_d) - self.tcgen05.commit(mma_mbarrier) + with self.single_thread(): + self.tcgen05.mma(a=s_a, b=s_b.transpose(), d=t_d) + self.tcgen05.commit(mma_mbarrier) self.mbarrier.wait(mma_mbarrier, phase=0) # store d from t_d to global diff --git a/tests/ir/tools/verifier/test_verify_load_shared.py b/tests/ir/tools/verifier/test_verify_load_shared.py index 693671ae..0a90c567 100644 --- a/tests/ir/tools/verifier/test_verify_load_shared.py +++ b/tests/ir/tools/verifier/test_verify_load_shared.py @@ -41,7 +41,7 @@ def __call__(self): ) def test_verify_load_shared(shared_shape, register_shape, success): script = DemoLoadShared(shared_shape=shared_shape, register_shape=register_shape) - program = script._jit_instance_for().transpiled_programs[0] + program = script._jit_instance_for().transpiled_programs[0] # type: ignore if success: verify(program)
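
The `canonicalize` change at the top of this diff drops size-1 modes and normalizes a zero-bit swizzle to `None` before rebuilding the `SharedLayout`. The sketch below is a minimal standalone rendering of that logic, operating on plain Python lists and an integer bit count instead of the real `SharedLayout`/`Swizzle` objects (an assumption for illustration only):

```python
from typing import Optional


def canonicalize_modes(
    mode_shape: list[int],
    mode_strides: list[int],
    swizzle_bits: Optional[int],
) -> tuple[list[int], list[int], Optional[int]]:
    # drop size-1 modes: a mode of extent 1 never contributes to the offset
    if any(s <= 1 for s in mode_shape):
        kept = [(s, d) for s, d in zip(mode_shape, mode_strides) if s > 1]
        mode_shape = [s for s, _ in kept]
        mode_strides = [d for _, d in kept]
    # a swizzle with 0 bits permutes nothing, so treat it the same as "no swizzle"
    if swizzle_bits == 0:
        swizzle_bits = None
    return mode_shape, mode_strides, swizzle_bits


print(canonicalize_modes([1, 4, 1, 8], [0, 8, 0, 1], 0))
# ([4, 8], [8, 1], None)
```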
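`SwizzledCuteLayout.as_shared_layout` reverses each nested mode (because cute splits modes in column-major order) and then flattens the result into the row-major `mode_shape`/`mode_strides` that `SharedLayout` expects. Below is a self-contained sketch of just that reshaping step on a made-up layout `((2, 4), 8):((8, 16), 1)`; the helper names mirror the diff, but the example values and the bare lists standing in for `SharedLayout` are assumptions for illustration:

```python
from collections.abc import Sequence
from typing import Union

IntTuple = Union[int, Sequence["IntTuple"]]


def reverse_int_tuple(t: IntTuple) -> IntTuple:
    # cute splits modes in column-major order, so reverse every nested mode recursively
    if isinstance(t, Sequence):
        return tuple(reverse_int_tuple(item) for item in reversed(t))
    return t


def flatten_int_tuple(t: IntTuple) -> list[int]:
    # flatten the (already reversed) nested modes into a flat list of mode extents/strides
    if isinstance(t, Sequence):
        out: list[int] = []
        for item in t:
            out.extend(flatten_int_tuple(item))
        return out
    return [t]


# hypothetical cute layout ((2, 4), 8):((8, 16), 1) describing an 8 x 8 tensor
shape = ((2, 4), 8)
strides = ((8, 16), 1)

rev_shape = [reverse_int_tuple(s) for s in shape]      # [(4, 2), 8]
rev_strides = [reverse_int_tuple(s) for s in strides]  # [(16, 8), 1]

print(flatten_int_tuple(rev_shape))    # [4, 2, 8]  -> mode_shape
print(flatten_int_tuple(rev_strides))  # [16, 8, 1] -> mode_strides
```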
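Finally, the bank-group table in the removed `swizzled_shared_layout` docstring (the one the comment admits a "reader may need some time to figure out") follows from the `j ^ i` swizzle alone. This is a plain-Python check of that documented pattern, not the library's swizzle implementation:

```python
# Bank group of element (i, j): j in a row-major layout, j ^ i after swizzling.
rows, cols = 8, 8
table = [[j ^ i for j in range(cols)] for i in range(rows)]

for row in table:
    print(" ".join(str(g) for g in row))
# 0 1 2 3 4 5 6 7
# 1 0 3 2 5 4 7 6
# ... (matches the table in the removed docstring)

# Property 1: every row hits all 8 bank groups (conflict-free global -> shared stores).
assert all(sorted(row) == list(range(cols)) for row in table)
# Property 2: every column hits all 8 bank groups (conflict-free ldmatrix loads).
assert all(sorted(table[i][j] for i in range(rows)) == list(range(rows)) for j in range(cols))
```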