22 changes: 8 additions & 14 deletions .pre-commit-config.yaml
@@ -17,6 +17,14 @@ repos:
language: system
types: [python]
pass_filenames: false
- id: mypy
name: MyPy type checking
entry: bash -c 'mypy --version >&2; mypy "$@"' --
language: system
types_or: [python, pyi]
args: [--show-error-codes, --show-error-context]
files: ^(python|examples|tests)/
require_serial: true
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version
rev: v0.12.3
@@ -30,17 +38,3 @@ repos:
- id: ruff-format
name: Ruff formatter
types_or: [python, pyi]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
name: MyPy type checking
# Uses [tool.mypy] configuration from pyproject.toml
additional_dependencies: [
types-tabulate,
types-tqdm,
]
# Check files in the python package, examples, and tests folders
files: ^(python|examples|tests)/
# Exclude hidet extensions (they have special mypy overrides in pyproject.toml)
exclude: ^python/tilus/extensions/hidet/
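To exercise the updated hook locally (assuming pre-commit is installed in the same environment that provides mypy, since the hook uses language: system):

    pre-commit run mypy --all-files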
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -7,5 +7,6 @@
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"python.analysis.supportDocstringTemplate": true
}
4 changes: 4 additions & 0 deletions README.md
@@ -11,6 +11,10 @@ It also includes automatic tuning, caching, and a Pythonic interface for ease of

Tilus is pronounced as tie-lus, /ˈtaɪləs/.

## Status

Tilus currently supports the Ampere architecture, and we are actively working on support for Hopper and Blackwell GPUs (see the [roadmap](https://github.com/NVIDIA/tilus/issues/49)). If you would like to contribute to Tilus (documentation, examples, tutorials, or new instructions), please open an issue and let us know.

## Getting Started

### Installation
2 changes: 1 addition & 1 deletion examples/attention/flash_attention_v3.py
@@ -293,7 +293,7 @@ def store_back(
shape=[num_q_blocks, batch_size, self.num_heads, self.block_q],
requires_clean=False,
)
semaphore = ~semaphores[self.blockIdx.x, bs, head]
semaphore = semaphores[self.blockIdx.x, bs, head].item_ptr()

sm = self.shared_tensor(dtype=f32, shape=[self.block_q])
sl = self.shared_tensor(dtype=f32, shape=[self.block_q])
4 changes: 2 additions & 2 deletions examples/flash_attention_decode/tilus_kernel.py
@@ -152,7 +152,7 @@ def __call__(
dims=[2, 3],
)
else:
state_idx = -1
state_idx = -1 # type: ignore
r_h = self.register_tensor(dtype=float32, shape=[K, self.BV], init=0.0)

# H' = alpha * H : [K, BV] = [] * [K, BV]
@@ -388,7 +388,7 @@ def __call__(
dims=[2, 3],
)
else:
state_idx = -1
state_idx = -1 # type: ignore
r_h = self.register_tensor(dtype=float32, shape=[K, self.BV], init=0.0)

# Apply gating to hidden state: H' = alpha * H
4 changes: 3 additions & 1 deletion examples/hopper_matmul/matmul_v2.py
@@ -118,7 +118,9 @@ def main():
headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
workloads = [
[4096, 4096, 4096],
# [128, 48, 16],
[4096, 4096, 14336],
[8192, 8192, 8192],
[10240, 10240, 10240],
]

rows = []
151 changes: 151 additions & 0 deletions examples/hopper_matmul/matmul_v3.py
@@ -0,0 +1,151 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import math

import pandas
import tilus
import torch
from tilus import float16, float32, int32, uint32
from tilus.utils import benchmark_func, cdiv


@tilus.autotune("num_stages", [2, 3, 4, 5, 6, 7])
@tilus.autotune(
"block_m, block_n", [[128, 64], [128, 128], [128, 256], [256, 128], [256, 256]]
)
@tilus.autotune("block_k", [16, 32, 64])
class MatmulWGMMAV3(tilus.Script):
def __init__(
self,
num_stages,
block_m,
block_n,
block_k,
):
super().__init__()
self.num_stages = num_stages
self.block_m = block_m
self.block_n = block_n
self.block_k = block_k

def __call__(
self,
m_size: int32,
n_size: int,
k_size: int,
a_ptr: ~float16,
b_ptr: ~float16,
c_ptr: ~float16,
):
self.attrs.blocks = [
cdiv(m_size, self.block_m),
cdiv(n_size, self.block_n),
]
self.attrs.warps = 5

block_m, block_n, block_k = self.block_m, self.block_n, self.block_k
offset_m: int32 = block_m * self.blockIdx.x
offset_n: int32 = block_n * self.blockIdx.y

ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
gb = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k])
sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_n, block_k])
acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0)

consumer_barriers = self.mbarrier.alloc(count=[2 for _ in range(self.num_stages)])
producer_barriers = self.mbarrier.alloc(
count=[128 for _ in range(self.num_stages)]
)

with self.thread_group(thread_begin=128, num_threads=32):
stage: int32 = 0
producer_phases = self.register_tensor(
dtype=uint32, shape=[self.num_stages], init=1
)
for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages):
self.mbarrier.wait(producer_barriers[stage], phase=producer_phases[stage])
producer_phases[stage] ^= 1
with self.single_thread():
self.tma.global_to_shared(
src=ga,
dst=sa[stage],
offsets=[offset_m, offset_k],
mbarrier=consumer_barriers[stage],
)
self.tma.global_to_shared(
src=gb,
dst=sb[stage],
offsets=[offset_n, offset_k],
mbarrier=consumer_barriers[stage],
)
stage = (stage + 1) % self.num_stages

for _ in self.range(min(self.num_stages, cdiv(k_size, self.block_k))):
self.mbarrier.wait(
producer_barriers[stage], phase=producer_phases[stage]
) # wait until the stage is ready to be filled
producer_phases[stage] ^= 1
stage = (stage + 1) % self.num_stages

with self.thread_group(thread_begin=0, num_threads=128):
consumer_phases = self.register_tensor(
dtype=uint32, shape=[self.num_stages], init=0
)
stage: int32 = 0
for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages):
self.mbarrier.wait(consumer_barriers[stage], phase=consumer_phases[stage])
consumer_phases[stage] ^= 1
self.wgmma.fence()
self.wgmma.mma(sa[stage], sb[stage].transpose(), acc)
self.wgmma.commit_group()
self.wgmma.wait_group(0)
self.mbarrier.arrive(producer_barriers[stage])
stage = (stage + 1) % self.num_stages
self.sync()
casted_acc = self.cast(acc, dtype=float16)
gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
self.store_global(gc, casted_acc, offsets=[offset_m, offset_n])


def main():
headers = ["m", "n", "k", "name", "latency (ms)", "tflops"]
workloads = [
[4096, 4096, 4096],
[4096, 4096, 14336],
[8192, 8192, 8192],
[10240, 10240, 10240],
]

rows = []
for m, n, k in workloads:
matmul = MatmulWGMMAV3()

a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
b = (torch.rand(n, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k)
c_actual = torch.empty(m, n, dtype=torch.float16).cuda()
c_expect = a @ b.T
matmul(m, n, k, a, b, c_actual)
torch.cuda.synchronize()

# check correctness
torch.testing.assert_close(c_expect, c_actual)

# benchmark
for name, func in [
("torch", lambda: torch.matmul(a, b.T, out=c_expect)),
("tilus", lambda: matmul(m, n, k, a, b, c_actual)),
]:
latency = benchmark_func(func, warmup=5, repeat=20)
tflops = 2 * m * n * k / latency * 1e-9
rows.append([m, n, k, name, latency, tflops])

df = pandas.DataFrame(rows, columns=headers)
print(df)


# %%

if __name__ == "__main__":
main()
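To try the new example end to end (assuming a Hopper GPU, a CUDA build of PyTorch, and the tilus package installed):

    python examples/hopper_matmul/matmul_v3.py

The script autotunes over num_stages, block_m/block_n, and block_k, verifies the result against torch.matmul, and prints a latency/TFLOPS table for both the torch and tilus backends.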
22 changes: 13 additions & 9 deletions examples/quantization/matmul_a16wx.py
@@ -22,7 +22,7 @@
int32,
uint8,
)
from tilus.ir.layout.ops import concat, local, reduce, spatial
from tilus.ir.layout.ops import concat, local, reduce, shared_row_major_swizzle, spatial
from tilus.utils import benchmark_func, cdiv, dtype_to_torch, gcd
from torch import nn

@@ -197,8 +197,8 @@ def __init__(
self.block_k = self.atomic_mma.k * wrk
self.num_warps = wsm * wsn

k_tiles = wrk // tk
n_tiles = wsn * wrn // tn
self.k_tiles = wrk // tk
self.n_tiles = wsn * wrn // tn

# we make sure that each weight_tile will be loaded by one warp
assert wrk * self.atomic_mma.k % weight_tile[0] == 0
@@ -217,14 +217,15 @@
)
self.layout_rs = reduce(self.mma.lb, dims=[0], keepdims=True)

self.layout_sa = self.cuda.swizzled_shared_layout(
self.a_dtype, shape=[num_stages, self.block_m, self.block_k]
self.layout_sa = shared_row_major_swizzle(
dtype_nbytes=self.a_dtype.nbytes,
shape=[num_stages, self.block_m, self.block_k],
)
self.layout_sb = self.cuda.shared_layout(
shape=[self.num_stages, k_tiles, n_tiles, self.tile_bytes]
shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes]
)
self.layout_sc = self.cuda.swizzled_shared_layout(
self.a_dtype, shape=[self.block_m, self.block_n]
self.layout_sc = shared_row_major_swizzle(
dtype_nbytes=self.a_dtype.nbytes, shape=[self.block_m, self.block_n]
)
self.layout_ss = self.cuda.shared_layout(shape=[self.num_stages, 1, self.block_n])

@@ -275,7 +276,10 @@ def __call__(
sa = self.shared_tensor(
dtype=self.a_dtype, shape=[self.num_stages, block_m, block_k]
)
sb = self.shared_tensor(dtype=uint8, shape=self.layout_sb.shape)
sb = self.shared_tensor(
dtype=uint8,
shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes],
)
ss = self.shared_tensor(dtype=self.a_dtype, shape=[self.num_stages, 1, block_n])
acc = self.register_tensor(
dtype=float32,
16 changes: 13 additions & 3 deletions pyproject.toml
@@ -66,6 +66,7 @@ Documentation = "https://nvidia.github.io/tilus"
"" = "python"

[tool.setuptools.package-data]
"tilus" = ["py.typed"]
"tilus.extensions.hidet" = ["include/**/*.h"]

[tool.ruff]
@@ -106,6 +107,11 @@ disallow_incomplete_defs = true
allow_redefinition = true
strict_optional = false
explicit_package_bases = true
mypy_path = "python"
namespace_packages = true
disable_error_code = [
"import-untyped", # we used some untyped third-party libraries
]

[[tool.mypy.overrides]]
module = ["tilus.*"]
@@ -120,9 +126,13 @@ disable_error_code = [
"override",
]

[[tool.mypy.overrides]]
module = ["examples.*", "tests.*"]
disable_error_code = ["override", "valid-type", "call-arg", "no-untyped-def", "has-type", "no-redef"]
[[tool.mypy.overrides]] # relax type checking in modules that may define tilus scripts
module = [
"examples.*",
"tests.*",
"tilus.lang.classes.pipeline"
]
disable_error_code = ["override", "valid-type", "call-arg", "no-untyped-def", "has-type", "no-redef", "assignment", "import-not-found"]

[[tool.mypy.overrides]]
module = ["hidet.*"]
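The same checks can be run outside pre-commit (a sketch, assuming mypy is installed in the project environment; the [tool.mypy] settings above are picked up from pyproject.toml automatically):

    mypy python examples tests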
9 changes: 4 additions & 5 deletions python/tilus/backends/codegen.py
@@ -143,15 +143,15 @@ def check_emitter_existence(self) -> None:
if failed_instructions:
rows = [f"Failed to find emitter for the following instructions (target: {get_current_target()}):"]
required_targets: list[str] = []
for inst in failed_instructions:
for inst_cls in failed_instructions:
for registry_inst_cls, emitter_classes in BaseInstEmitter.REGISTRY.items():
if issubclass(inst, registry_inst_cls):
if issubclass(inst_cls, registry_inst_cls):
required_targets.extend([str(target) for target in emitter_classes.keys()])
break
if not required_targets:
rows.append(f" - {inst.__name__} (no registered emitters)")
rows.append(f" - {inst_cls.__name__} (no registered emitters)")
else:
rows.append(f" - {inst.__name__} (registered targets: {', '.join(required_targets)})")
rows.append(f" - {inst_cls.__name__} (registered targets: {', '.join(required_targets)})")
raise CodeGenerationFailed("\n".join(rows))

def launch_kernel(self, kernel_func: HidetFunction) -> None:
@@ -278,7 +278,6 @@ def visit_ForStmt(self, stmt: ForStmt) -> None:
def visit_ThreadGroupStmt(self, stmt: ThreadGroupStmt) -> None:
# check the validity of the thread group
parent_num_threads = self.thread_group_stack.num_threads[-1]
assert parent_num_threads % stmt.num_threads == 0
assert 0 <= stmt.thread_begin and stmt.thread_begin + stmt.num_threads <= parent_num_threads

self.builder.comment(
3 changes: 2 additions & 1 deletion python/tilus/backends/emitters/cuda/cp_async.py
@@ -47,7 +47,8 @@ def emit(self, inst: CopyAsyncGenericInst) -> None:

# get shared, global, and mask info
inst_mask = inst.mask if inst.mask is not None else boolean.true
shared_info: TensorInfo = analyze_grid(shape=shape, axes=layout.axes, analysis=analysis, expr=layout.offset)
axes, offset = layout.as_axes_mapping()
shared_info: TensorInfo = analyze_grid(shape=shape, axes=axes, analysis=analysis, expr=offset)
mask_info: TensorInfo = analyze_grid(shape=shape, axes=inst.axes, analysis=analysis, expr=inst_mask)
global_info: TensorInfo = analyze_grid(shape=shape, axes=inst.axes, analysis=analysis, expr=inst.offset)

5 changes: 1 addition & 4 deletions python/tilus/backends/emitters/cuda/cp_async_tensor.py
@@ -202,12 +202,9 @@ def resolve_shared_tensor_info(self, shared_tensor: SharedTensor) -> SharedTenso
range_indices: list[np.ndarray] = []
for dim, extent in enumerate(shared_tensor.shape):
range_indices.append(np.arange(extent, dtype=np.int32))
grid = np.meshgrid(*range_indices, indexing="ij")
layout: SharedLayout = shared_tensor.layout

offset_grid: np.ndarray = vectorized_evaluate(
expr=layout.offset, var2value={axis: grid[i] for i, axis in enumerate(layout.axes)}
)
offset_grid: np.ndarray = layout.as_numpy_grid()
for swizzle in [
TensorMapSwizzle.NONE,
TensorMapSwizzle.B32,