[REQ] Add wp.tile_empty() builtin function #1312
Description
wp.tile_empty() -- Uninitialized Tile Allocation
Summary
Add a new builtin wp.tile_empty() that allocates a tile without initializing its elements. This mirrors the relationship between numpy.zeros() and numpy.empty() -- when a tile will be immediately overwritten (e.g., by a tile_load() or element-wise assignment), zero-initialization is wasted work.
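For readers less familiar with the NumPy pair mentioned above, the analogy can be sketched as follows. This is only an illustration of the zeros/empty relationship on the host side; the buffer names and the shape are made up for the example and have nothing to do with Warp's tile implementation.

```python
import numpy as np

SHAPE = (4, 4)  # illustrative size, not taken from the proposal

# zeros(): allocates and writes 0.0 to every element.
acc = np.zeros(SHAPE, dtype=np.float32)

# empty(): allocates only; contents are arbitrary until written.
scratch = np.empty(SHAPE, dtype=np.float32)

# Every element is overwritten before any read, so skipping the
# zero-fill loses nothing.
scratch.fill(7.0)

print(acc.sum(), scratch.sum())
```

The proposed tile_empty() would play the role of numpy.empty() here: correct only when every element is written before it is read.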
Motivation
wp.tile_zeros() is currently the standard way to allocate a tile for use as an accumulator or scratch space. However, many common patterns overwrite the tile contents immediately after allocation:
    @wp.kernel
    def my_kernel(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float)):
        i, j = wp.tid()
        # Tile is zeroed, then immediately overwritten by tile_load
        a = wp.tile_zeros(shape=(TILE_M, TILE_K), dtype=float)
        a = wp.tile_load(A, shape=(TILE_M, TILE_K))

In these cases, the zero-initialization is redundant. On GPU, zeroing shared memory or registers has a non-trivial cost, especially for large tiles. tile_empty() lets the user opt out of initialization when it is not needed.
API
    wp.tile_empty(shape: tuple[int, ...], dtype: type = float, storage: str = "register") -> Tile

Parameters are identical to wp.tile_zeros():
- shape -- Compile-time constant tile dimensions.
- dtype -- Element data type (default float).
- storage -- "register" (default) or "shared".
Returns an uninitialized tile. Reading from the tile before writing produces undefined values.
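The parameter contract above can be written out as a plain-Python check. This is a hypothetical helper for illustration only, not Warp's actual value function: the name check_tile_alloc_args and the rule that every dimension must be a positive int are assumptions made for the sketch.

```python
def check_tile_alloc_args(shape, dtype=float, storage="register"):
    """Hypothetical validator mirroring the tile_empty()/tile_zeros() parameters."""
    if not isinstance(shape, tuple) or len(shape) == 0:
        raise TypeError("shape must be a non-empty tuple of ints")
    for dim in shape:
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(dim, bool) or not isinstance(dim, int) or dim <= 0:
            raise ValueError(f"invalid tile dimension: {dim!r}")
    if storage not in ("register", "shared"):
        raise ValueError(f"storage must be 'register' or 'shared', got {storage!r}")
    return shape, dtype, storage


# Same call shape as the proposed signature, defaults included.
print(check_tile_alloc_args((64, 32)))
```

Because the signature matches tile_zeros(), a check like this would not need to distinguish between the two builtins; only the initialization step differs.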
Example
    import warp as wp

    TILE_M = 64
    TILE_K = 32
    TILE_N = 64

    @wp.kernel
    def gemm_kernel(
        A: wp.array2d(dtype=float),
        B: wp.array2d(dtype=float),
        C: wp.array2d(dtype=float),
    ):
        i, j = wp.tid()  # output tile indices
        sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=float)  # Accumulator -- must be zeroed
        for k in range(A.shape[1] // TILE_K):
            # tile_load/tile_store offsets are in elements, so scale tile indices by the tile size
            a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
            b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
            wp.tile_matmul(a, b, sum)
        wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))

In the kernel above, sum must be zero-initialized because it is used as a matmul accumulator. However, if we had a scratch tile that was loaded before use, tile_empty() would be appropriate:
    scratch = wp.tile_empty(shape=(TILE_M, TILE_K), dtype=float)
    scratch = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, 0))
    # scratch is now fully defined -- safe to read

Implementation Notes
- The native C++ template (tile_empty<T, Shape...>()) can simply declare the tile storage without writing to it.
- For register tiles, this means skipping the zero-fill loop. For shared memory tiles, this means advancing the shared memory allocator without a memset.
- The builtin registration in builtins.py can reuse tile_zeros_value_func and tile_zeros_dispatch_func (or trivially adapted copies), since the signature is identical.
- tile_empty() is not differentiable (same as tile_zeros()).
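The shared-memory note above (advance the allocator, skip the memset) can be modeled with a toy bump allocator in plain Python. This is a conceptual sketch, not Warp's native allocator: SharedArena, alloc_zeros, alloc_empty, and the 0xAB fill pattern standing in for stale data are all hypothetical.

```python
class SharedArena:
    """Toy model of a bump allocator over a fixed-size byte buffer."""

    def __init__(self, capacity: int):
        # Pre-fill with 0xAB to stand in for stale data from earlier use.
        self.buf = bytearray(b"\xab" * capacity)
        self.offset = 0
        self.bytes_written = 0  # counts initialization writes only

    def _bump(self, nbytes: int) -> int:
        if self.offset + nbytes > len(self.buf):
            raise MemoryError("arena exhausted")
        start = self.offset
        self.offset += nbytes
        return start

    def alloc_zeros(self, nbytes: int) -> memoryview:
        start = self._bump(nbytes)
        self.buf[start:start + nbytes] = bytes(nbytes)  # the memset
        self.bytes_written += nbytes
        return memoryview(self.buf)[start:start + nbytes]

    def alloc_empty(self, nbytes: int) -> memoryview:
        start = self._bump(nbytes)  # advance only, no memset
        return memoryview(self.buf)[start:start + nbytes]


arena = SharedArena(capacity=64)
zeroed = arena.alloc_zeros(16)
raw = arena.alloc_empty(16)
print(bytes(zeroed[:4]), bytes(raw[:4]), arena.bytes_written)
```

Both paths move the offset by the same amount; the only difference is the initialization write, which is the cost tile_empty() is meant to avoid. The stale bytes visible through raw illustrate why reading before writing yields undefined values.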