22 changes: 8 additions & 14 deletions .pre-commit-config.yaml
@@ -17,6 +17,14 @@ repos:
language: system
types: [python]
pass_filenames: false
- id: mypy
name: MyPy type checking
entry: bash -c 'mypy --version >&2; mypy "$@"' --
language: system
types_or: [python, pyi]
args: [--show-error-codes, --show-error-context]
files: ^(python|examples|tests)/
require_serial: true
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version
rev: v0.12.3
@@ -30,17 +38,3 @@ repos:
- id: ruff-format
name: Ruff formatter
types_or: [python, pyi]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
name: MyPy type checking
# Uses [tool.mypy] configuration from pyproject.toml
additional_dependencies: [
types-tabulate,
types-tqdm,
]
# Check files in the python package, examples, and tests folders
files: ^(python|examples|tests)/
# Exclude hidet extensions (they have special mypy overrides in pyproject.toml)
exclude: ^python/tilus/extensions/hidet/
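
Note on the hook change above: the pre-commit-managed mirrors-mypy hook is replaced by a local `language: system` hook, so mypy now runs from the project environment (no `additional_dependencies` needed) and picks up `[tool.mypy]` from pyproject.toml directly. A minimal sketch of exercising the new hook, assuming `pre-commit` and `mypy` are installed in that environment; the hook id `mypy` comes from the config above:

```python
# Sketch only, not part of the PR: invoke the new local mypy hook the same way
# CI or a contributor would, and surface its output.
import subprocess

result = subprocess.run(
    ["pre-commit", "run", "mypy", "--all-files"],
    capture_output=True,
    text=True,
)
print(result.stdout)
print(result.stderr)  # the hook writes `mypy --version` to stderr before checking
```
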
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -7,5 +7,6 @@
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"python.analysis.supportDocstringTemplate": true
}
2 changes: 1 addition & 1 deletion examples/attention/flash_attention_v3.py
@@ -293,7 +293,7 @@ def store_back(
shape=[num_q_blocks, batch_size, self.num_heads, self.block_q],
requires_clean=False,
)
semaphore = ~semaphores[self.blockIdx.x, bs, head]
semaphore = semaphores[self.blockIdx.x, bs, head].item_ptr()

sm = self.shared_tensor(dtype=f32, shape=[self.block_q])
sl = self.shared_tensor(dtype=f32, shape=[self.block_q])
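
The change above swaps the prefix `~` operator for an explicit `.item_ptr()` call when taking the address of one semaphore element. A toy illustration of the design choice in plain Python (not the tilus implementation): an explicit, annotated method gives a type checker a concrete return type, while an overloaded `__invert__` is easy to leave effectively untyped.

```python
# Toy sketch only; Pointer/TensorElement are stand-ins, not tilus classes.
from dataclasses import dataclass


@dataclass
class Pointer:
    address: int


class TensorElement:
    def __init__(self, address: int) -> None:
        self.address = address

    def __invert__(self):  # unannotated operator: call sites see Any
        return Pointer(self.address)

    def item_ptr(self) -> Pointer:  # explicit method with a precise return type
        return Pointer(self.address)


elem = TensorElement(0x1000)
assert (~elem).address == elem.item_ptr().address
```
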
4 changes: 2 additions & 2 deletions examples/flash_attention_decode/tilus_kernel.py
@@ -152,7 +152,7 @@ def __call__(
dims=[2, 3],
)
else:
state_idx = -1
state_idx = -1 # type: ignore
r_h = self.register_tensor(dtype=float32, shape=[K, self.BV], init=0.0)

# H' = alpha * H : [K, BV] = [] * [K, BV]
@@ -388,7 +388,7 @@ def __call__(
dims=[2, 3],
)
else:
state_idx = -1
state_idx = -1 # type: ignore
r_h = self.register_tensor(dtype=float32, shape=[K, self.BV], init=0.0)

# Apply gating to hidden state: H' = alpha * H
22 changes: 13 additions & 9 deletions examples/quantization/matmul_a16wx.py
@@ -22,7 +22,7 @@
int32,
uint8,
)
from tilus.ir.layout.ops import concat, local, reduce, spatial
from tilus.ir.layout.ops import concat, local, reduce, shared_row_major_swizzle, spatial
from tilus.utils import benchmark_func, cdiv, dtype_to_torch, gcd
from torch import nn

@@ -197,8 +197,8 @@ def __init__(
self.block_k = self.atomic_mma.k * wrk
self.num_warps = wsm * wsn

k_tiles = wrk // tk
n_tiles = wsn * wrn // tn
self.k_tiles = wrk // tk
self.n_tiles = wsn * wrn // tn

# we make sure that each weight_tile will be loaded by one warp
assert wrk * self.atomic_mma.k % weight_tile[0] == 0
@@ -217,14 +217,15 @@ def __init__(
)
self.layout_rs = reduce(self.mma.lb, dims=[0], keepdims=True)

self.layout_sa = self.cuda.swizzled_shared_layout(
self.a_dtype, shape=[num_stages, self.block_m, self.block_k]
self.layout_sa = shared_row_major_swizzle(
dtype_nbytes=self.a_dtype.nbytes,
shape=[num_stages, self.block_m, self.block_k],
)
self.layout_sb = self.cuda.shared_layout(
shape=[self.num_stages, k_tiles, n_tiles, self.tile_bytes]
shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes]
)
self.layout_sc = self.cuda.swizzled_shared_layout(
self.a_dtype, shape=[self.block_m, self.block_n]
self.layout_sc = shared_row_major_swizzle(
dtype_nbytes=self.a_dtype.nbytes, shape=[self.block_m, self.block_n]
)
self.layout_ss = self.cuda.shared_layout(shape=[self.num_stages, 1, self.block_n])

@@ -275,7 +276,10 @@ def __call__(
sa = self.shared_tensor(
dtype=self.a_dtype, shape=[self.num_stages, block_m, block_k]
)
sb = self.shared_tensor(dtype=uint8, shape=self.layout_sb.shape)
sb = self.shared_tensor(
dtype=uint8,
shape=[self.num_stages, self.k_tiles, self.n_tiles, self.tile_bytes],
)
ss = self.shared_tensor(dtype=self.a_dtype, shape=[self.num_stages, 1, block_n])
acc = self.register_tensor(
dtype=float32,
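
In this file, `cuda.swizzled_shared_layout(dtype, ...)` calls are replaced with the free function `shared_row_major_swizzle(dtype_nbytes=..., shape=...)` from `tilus.ir.layout.ops`, and `k_tiles`/`n_tiles` become attributes so the shared-tensor shape can be written out explicitly in `__call__`. A condensed sketch of the new construction; only the call signature is taken from the diff, and the sizes below are made-up placeholders:

```python
# Sketch with placeholder sizes; concrete values depend on the kernel config.
from tilus.ir.layout.ops import shared_row_major_swizzle

num_stages, block_m, block_n, block_k = 3, 128, 128, 64
a_dtype_nbytes = 2  # e.g. float16

layout_sa = shared_row_major_swizzle(
    dtype_nbytes=a_dtype_nbytes,
    shape=[num_stages, block_m, block_k],
)
layout_sc = shared_row_major_swizzle(
    dtype_nbytes=a_dtype_nbytes,
    shape=[block_m, block_n],
)
```
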
16 changes: 13 additions & 3 deletions pyproject.toml
@@ -66,6 +66,7 @@ Documentation = "https://nvidia.github.io/tilus"
"" = "python"

[tool.setuptools.package-data]
"tilus" = ["py.typed"]
"tilus.extensions.hidet" = ["include/**/*.h"]

[tool.ruff]
@@ -106,6 +107,11 @@ disallow_incomplete_defs = true
allow_redefinition = true
strict_optional = false
explicit_package_bases = true
mypy_path = "python"
namespace_packages = true
disable_error_code = [
"import-untyped", # we used some untyped third-party libraries
]

[[tool.mypy.overrides]]
module = ["tilus.*"]
@@ -120,9 +126,13 @@ disable_error_code = [
"override",
]

[[tool.mypy.overrides]]
module = ["examples.*", "tests.*"]
disable_error_code = ["override", "valid-type", "call-arg", "no-untyped-def", "has-type", "no-redef"]
[[tool.mypy.overrides]] # disable type checking in these modules that might define tilus scripts
module = [
"examples.*",
"tests.*",
"tilus.lang.classes.pipeline"
]
disable_error_code = ["override", "valid-type", "call-arg", "no-untyped-def", "has-type", "no-redef", "assignment", "import-not-found"]

[[tool.mypy.overrides]]
module = ["hidet.*"]
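
With `py.typed` shipped as package data and `mypy_path = "python"` plus `namespace_packages = true` set, mypy resolves `tilus` from the in-tree sources when run at the repository root, while the new overrides relax checking for `examples.*`, `tests.*`, and `tilus.lang.classes.pipeline`. A hedged way to exercise the configuration programmatically, assuming mypy is installed (`mypy.api.run` is part of mypy's public API):

```python
# Run mypy through its Python API against the paths the pre-commit hook covers;
# the settings come from [tool.mypy] in pyproject.toml.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["python", "examples", "tests", "--show-error-codes", "--show-error-context"]
)
print(stdout or stderr)
print("mypy exit status:", exit_status)
```
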
8 changes: 4 additions & 4 deletions python/tilus/backends/codegen.py
@@ -143,15 +143,15 @@ def check_emitter_existence(self) -> None:
if failed_instructions:
rows = [f"Failed to find emitter for the following instructions (target: {get_current_target()}):"]
required_targets: list[str] = []
for inst in failed_instructions:
for inst_cls in failed_instructions:
for registry_inst_cls, emitter_classes in BaseInstEmitter.REGISTRY.items():
if issubclass(inst, registry_inst_cls):
if issubclass(inst_cls, registry_inst_cls):
required_targets.extend([str(target) for target in emitter_classes.keys()])
break
if not required_targets:
rows.append(f" - {inst.__name__} (no registered emitters)")
rows.append(f" - {inst_cls.__name__} (no registered emitters)")
else:
rows.append(f" - {inst.__name__} (registered targets: {', '.join(required_targets)})")
rows.append(f" - {inst_cls.__name__} (registered targets: {', '.join(required_targets)})")
raise CodeGenerationFailed("\n".join(rows))

def launch_kernel(self, kernel_func: HidetFunction) -> None:
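
The loop variable is renamed from `inst` to `inst_cls` because it iterates over instruction classes (the `issubclass` operands), not instruction instances. A self-contained illustration of the same registry-lookup pattern in generic Python (not the tilus registry itself):

```python
# Generic sketch of the lookup in check_emitter_existence: for each failing
# instruction class, collect the targets that do have a registered emitter.
from __future__ import annotations

REGISTRY: dict[type, dict[str, type]] = {}  # instruction class -> {target: emitter}


class Instruction: ...
class CopyInst(Instruction): ...


REGISTRY[CopyInst] = {"sm_90": object}  # pretend an emitter exists for sm_90

failed = [CopyInst]
for inst_cls in failed:
    targets: list[str] = []
    for registry_cls, emitters in REGISTRY.items():
        if issubclass(inst_cls, registry_cls):
            targets.extend(emitters.keys())
            break
    print(inst_cls.__name__, "->", targets or "no registered emitters")
```
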
3 changes: 2 additions & 1 deletion python/tilus/backends/emitters/cuda/cp_async.py
@@ -47,7 +47,8 @@ def emit(self, inst: CopyAsyncGenericInst) -> None:

# get shared, global, and mask info
inst_mask = inst.mask if inst.mask is not None else boolean.true
shared_info: TensorInfo = analyze_grid(shape=shape, axes=layout.axes, analysis=analysis, expr=layout.offset)
axes, offset = layout.as_axes_mapping()
shared_info: TensorInfo = analyze_grid(shape=shape, axes=axes, analysis=analysis, expr=offset)
mask_info: TensorInfo = analyze_grid(shape=shape, axes=inst.axes, analysis=analysis, expr=inst_mask)
global_info: TensorInfo = analyze_grid(shape=shape, axes=inst.axes, analysis=analysis, expr=inst.offset)

5 changes: 1 addition & 4 deletions python/tilus/backends/emitters/cuda/cp_async_tensor.py
@@ -202,12 +202,9 @@ def resolve_shared_tensor_info(self, shared_tensor: SharedTensor) -> SharedTenso
range_indices: list[np.ndarray] = []
for dim, extent in enumerate(shared_tensor.shape):
range_indices.append(np.arange(extent, dtype=np.int32))
grid = np.meshgrid(*range_indices, indexing="ij")
layout: SharedLayout = shared_tensor.layout

offset_grid: np.ndarray = vectorized_evaluate(
expr=layout.offset, var2value={axis: grid[i] for i, axis in enumerate(layout.axes)}
)
offset_grid: np.ndarray = layout.as_numpy_grid()
for swizzle in [
TensorMapSwizzle.NONE,
TensorMapSwizzle.B32,
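
The hand-rolled `np.meshgrid` + `vectorized_evaluate` offset computation is folded into `SharedLayout.as_numpy_grid()`. For reference, a NumPy-only sketch of what the removed code computed, using a plain row-major offset as a stand-in for the real `layout.offset` expression:

```python
# NumPy-only reference for the removed code path: evaluate the layout's offset
# at every coordinate of the shared tensor.
import numpy as np


def offset_grid_row_major(shape: list[int]) -> np.ndarray:
    grid = np.meshgrid(*[np.arange(e, dtype=np.int32) for e in shape], indexing="ij")
    strides = [int(np.prod(shape[i + 1:])) for i in range(len(shape))]
    return sum(g * s for g, s in zip(grid, strides))


print(offset_grid_row_major([2, 4]))
# [[0 1 2 3]
#  [4 5 6 7]]
```
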
9 changes: 2 additions & 7 deletions python/tilus/backends/emitters/cuda/tcgen05/copy.py
@@ -188,17 +188,12 @@ def generate_instructions(

raise ValueError("No valid instructions generated")

def check_warp_group(self) -> None:
begin = self.current_thread_group_begin
end = self.current_thread_group_end
if begin % 128 != 0 or end - begin != 128:
raise ValueError("The number of threads in the current thread group must be 128")

def emit(self, inst: Tcgen05CopyInst) -> None:
shared_tensor = inst.inputs[1].as_shared_tensor()
tmem_tensor = inst.inputs[0].as_tmemory_tensor()

self.check_warp_group()
if self.current_num_threads != 1:
raise ValueError("Tcgen05CopyInst can only be emitted in thread group with a single thread")

if len(shared_tensor.shape) != 2:
raise ValueError("The shared tensor must be a 2D tensor, got shape {}".format(shared_tensor.shape))
4 changes: 2 additions & 2 deletions python/tilus/backends/emitters/cuda/wgmma.py
@@ -180,7 +180,7 @@ def emit_wgmma(self, inst: WgmmaMmaSSInst) -> None:
a_desc.encoded(),
d_register_addr + d_offset,
b_desc.encoded(),
trans_a=0,
trans_b=0,
trans_a=0, # type: ignore
trans_b=0, # type: ignore
)
)
4 changes: 2 additions & 2 deletions python/tilus/backends/emitters/reduce.py
@@ -284,8 +284,8 @@ def inter_warp_reduce(self, inst: ReduceInst) -> None:
smem_ctx = self.contexts.smem_alloc_ctx
smem_buf = self.declare_var(
"smem_buf",
tensor_pointer_type(dtype=dst.dtype, shape=[shared_layout.size]),
init=cast(smem_ctx.request_shared_workspace(dst.dtype.nbytes * shared_layout.size), ~dst.dtype),
tensor_pointer_type(dtype=dst.dtype, shape=[shared_layout.count_size()]),
init=cast(smem_ctx.request_shared_workspace(dst.dtype.nbytes * shared_layout.count_size()), ~dst.dtype),
)

reduced_mode_shape = [
5 changes: 2 additions & 3 deletions python/tilus/extensions/hidet/ir/primitives/cuda/mbarrier.py
@@ -207,9 +207,8 @@ def mbarrier_arrive_shared(mbarrier_addr: Expr, count: Expr | int) -> Expr:
--------
mbarrier.arrive : PTX ISA documentation section 9.7.13.15.13
"""
if isinstance(count, int):
count = u32(count)
return call_primitive_func("cuda_mbarrier_arrive_shared", args=[mbarrier_addr, count])
count_expr = count if isinstance(count, Expr) else u32(count)
return call_primitive_func("cuda_mbarrier_arrive_shared", args=[mbarrier_addr, count_expr])


def mbarrier_arrive_remote_shared(
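
Rather than rebinding the `count` parameter (declared `Expr | int`) in place, the diff introduces `count_expr`, so each name keeps a single type throughout the function. A minimal stand-alone illustration of the pattern; `Expr` and `u32` below are stand-ins, not the hidet primitives:

```python
# Stand-alone sketch of the normalization pattern.
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Expr:
    value: int


def u32(value: int) -> Expr:
    return Expr(value & 0xFFFFFFFF)


def arrive(count: Expr | int) -> Expr:
    # Normalize into a separately named local instead of rebinding `count`,
    # so the parameter keeps its declared type in the rest of the body.
    count_expr = count if isinstance(count, Expr) else u32(count)
    return count_expr


assert arrive(3) == Expr(3)
assert arrive(Expr(7)) == Expr(7)
```
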
27 changes: 16 additions & 11 deletions python/tilus/ir/builders/stmt_builder.py
@@ -165,21 +165,15 @@ def __init__(
iter_vars: List[Var],
extents: List[Expr],
unrolls: List[Optional[int]],
unwrap: bool = False,
):
super().__init__(vb)
self.iter_vars: List[Var] = iter_vars
self.extents: List[Expr] = extents
self.unrolls: List[Optional[int]] = unrolls
self.unwrap: bool = unwrap

def __enter__(self):
def __enter__(self) -> List[Var]:
self.enter()
if self.unwrap:
assert len(self.iter_vars) == 1
return self.iter_vars[0]
else:
return self.iter_vars
return self.iter_vars

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
@@ -193,6 +187,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.append(body)


class ForRangeContext(ForContext):
def __enter__(self):
iter_vars = super().__enter__()
assert len(iter_vars) == 1
return iter_vars[0]


class IfContext(StmtContext):
def __init__(self, vb: StmtBuilderCore, cond: Expr):
super().__init__(vb)
@@ -315,11 +316,13 @@ def is_empty(self):

def for_range(
self, extent: Union[Expr, int], iter_name_hint: str = "i", unroll_factor: Optional[int] = None
) -> ForContext:
) -> ForRangeContext:
iter_var = Var(iter_name_hint, type=int32)
return ForContext(self, [iter_var], [as_expr(extent)], [unroll_factor], unwrap=True)
return ForRangeContext(self, [iter_var], [as_expr(extent)], [unroll_factor])

def for_grid(self, extents: List[Union[Expr, int]], iter_name_hints: Optional[List[str]] = None) -> ForContext:
def for_grid(
self, extents: Sequence[Union[Expr, int]], iter_name_hints: Optional[Sequence[str]] = None
) -> ForContext:
expr_extents = [as_expr(extent) for extent in extents]
if iter_name_hints is None:
names = "ijkpqrstuvw"
@@ -1231,6 +1234,7 @@ def wait_barrier(self, barrier: Expr | RegisterTensor, phase: Expr | int | Regis
phase = self.tensor_item_value(phase)
elif isinstance(phase, int):
phase = uint32(phase)
assert isinstance(phase, Expr)
inst = WaitBarrierInst.create(barrier=barrier, phase=phase)
self.append(inst)

@@ -1245,6 +1249,7 @@ def cluster_launch_control_try_cancel(
mbarrier = self.tensor_item_value(mbarrier)
if isinstance(multicast, bool):
multicast = boolean(multicast)
assert isinstance(multicast, Expr)
inst = ClusterLaunchControlTryCancelInst.create(response=response, mbarrier=mbarrier, multicast=multicast)
self.append(inst)

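
Splitting `ForRangeContext` out of `ForContext` replaces the `unwrap` flag with a subclass whose `__enter__` returns a single `Var`, so `for_range` and `for_grid` each advertise a precise return type. A generic sketch of the same pattern with plain context managers (not the tilus builder API):

```python
# Generic sketch: a base context manager yielding a list, and a subclass that
# unwraps it to a single element, replacing a boolean `unwrap` flag.
from typing import List


class GridContext:
    def __init__(self, names: List[str]) -> None:
        self.names = names

    def __enter__(self) -> List[str]:
        return self.names

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        pass


class RangeContext(GridContext):
    # Left unannotated, as in the diff, since narrowing List[str] to str would
    # be an incompatible override for a type checker.
    def __enter__(self):
        names = super().__enter__()
        assert len(names) == 1
        return names[0]


with GridContext(["i", "j"]) as ij:
    print(ij)  # ['i', 'j']
with RangeContext(["i"]) as i:
    print(i)   # 'i'
```
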
8 changes: 2 additions & 6 deletions python/tilus/ir/functors/functor.py
@@ -435,11 +435,7 @@ def visit_RegisterLayout(self, layout: RegisterLayout) -> RegisterLayout:
return layout

def visit_SharedLayout(self, layout: SharedLayout) -> SharedLayout:
offset = self.visit(layout.offset)
if offset is layout.offset:
return layout
else:
return SharedLayout(shape=layout.shape, size=layout.size, axes=layout.axes, offset=offset)
return layout

def visit_GlobalLayout(self, layout: GlobalLayout) -> GlobalLayout:
shape = self.visit(layout.shape)
@@ -573,7 +569,7 @@ def visit_RegisterLayout(self, layout: RegisterLayout) -> None:
pass

def visit_SharedLayout(self, layout: SharedLayout) -> None:
self.visit(layout.offset)
pass

def visit_GlobalLayout(self, layout: GlobalLayout) -> None:
self.visit(layout.shape)