diff --git a/benchmark/test_select_and_slice_perf.py b/benchmark/test_select_and_slice_perf.py
index f2ea75a3e..8663dff26 100644
--- a/benchmark/test_select_and_slice_perf.py
+++ b/benchmark/test_select_and_slice_perf.py
@@ -254,23 +254,31 @@ def select_scatter_input_fn(shape, dtype, device):
     bench.run()
 
 
+def index_add_gbps(bench_fn_args, latency):
+    index = bench_fn_args[2]
+    src = bench_fn_args[3]
+    io_amount = sum([shape_utils.size_in_bytes(item) for item in [index, src, src]])
+    return io_amount * 1e-9 / (latency * 1e-3)
+
+
 @pytest.mark.index_add
 def test_index_add_perf():
     def index_add_input_fn(shape, dtype, device):
-        inp = torch.randn(shape, dtype=dtype, device="cuda")
-        dim = 0
+        inp = torch.randn(shape, dtype=dtype, device=device)
+        dim = 0 if len(shape) == 1 else 1
         src_shape = list(inp.shape)
         index_max = src_shape[dim]
-        index_len = index_max // 2
-        index = torch.randint(0, index_max, (index_len,), device="cuda")
+        index_len = index_max // 2 if index_max >= 2 else 1
+        index = torch.randperm(index_len, device=device)
         src_shape[dim] = index_len
-        src = torch.randn(src_shape, dtype=dtype, device="cuda")
+        src = torch.randn(src_shape, dtype=dtype, device=device)
         yield inp, dim, index, src
 
     bench = TensorSelectBenchmark(
         op_name="index_add",
         torch_op=torch.index_add,
         input_fn=index_add_input_fn,
-        dtypes=FLOAT_DTYPES,
+        dtypes=[torch.float16, torch.float32],
+        get_gbps=index_add_gbps,
     )
     bench.run()
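
For review context: `index_add_gbps` counts the index buffer once and `src` twice (one read of `src` plus one write of the updated region of `out`), so the reported GB/s reflects only the traffic the scatter itself generates. A standalone sketch of the same arithmetic, assuming `shape_utils.size_in_bytes` is simply `numel() * element_size()` for dense tensors:

    import torch

    def size_in_bytes(t: torch.Tensor) -> int:
        # assumed equivalent of shape_utils.size_in_bytes for dense tensors
        return t.numel() * t.element_size()

    def index_add_gbps(bench_fn_args, latency):
        index = bench_fn_args[2]
        src = bench_fn_args[3]
        # src is counted twice: read once, written once into out
        io_amount = sum(size_in_bytes(t) for t in (index, src, src))
        return io_amount * 1e-9 / (latency * 1e-3)  # bytes -> GB, ms -> s

    index = torch.randperm(256)
    src = torch.randn(1024, 256)
    print(index_add_gbps((None, 1, index, src), latency=0.05))  # 0.05 ms, illustrative
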
diff --git a/src/flag_gems/ops/index_add.py b/src/flag_gems/ops/index_add.py
index 611afe77d..cce78182c 100644
--- a/src/flag_gems/ops/index_add.py
+++ b/src/flag_gems/ops/index_add.py
@@ -1,62 +1,233 @@
+import importlib
 import logging
+import os
+from typing import Any, Callable, Dict, List, Tuple
 
 import torch
-import triton
-import triton.language as tl
-
-from ..utils import dim_compress, libentry
-
-
-def cfggen():
-    block_m = [1, 2, 4]
-    block_n = [128, 1024, 2048, 4096]
-    configs = [
-        triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)
-        for m in block_m
-        for n in block_n
-    ]
-    return configs
-
-
-@libentry()
-@triton.autotune(configs=cfggen(), key=["M", "N"])
-@triton.jit
-def index_add_kernel(
-    inp,
-    out,
-    index,
-    src,
-    M,
-    N,
-    alpha,
-    inp_len,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-):
-    pid_x = tl.program_id(axis=0)
-    pid_y = tl.program_id(axis=1)
-    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
-    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    rows_mask = rows_offsets < M
-    index_mask = cols_offsets < N
-    block_mask = rows_mask and index_mask
-
-    cur_indices = tl.load(index + cols_offsets, mask=index_mask, other=0)
-    inp_off = rows_offsets * inp_len + cur_indices[None, :]
-    cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0)
-    src_off = rows_offsets * N + cols_offsets[None, :]
-    cur_src = tl.load(src + src_off, mask=block_mask, other=0.0)
-    cur_inp += alpha * cur_src
-
-    tl.store(out + inp_off, cur_inp, mask=block_mask)
+
+from flag_gems.utils.code_cache import code_cache_dir
+from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
+
+
+def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
+    code.writeline("import triton")
+    code.writeline("import triton.language as tl")
+    code.writeline("from flag_gems.utils import libentry")
+
+    code.newline()
+    code.newline()
+
+    return code
+
+
+def generate_index_add_kernel(
+    rank: int,
+    kernel_name: str,
+    code: IndentedBuffer,
+) -> IndentedBuffer:
+    # the decorators
+    code.writeline("@libentry()")
+    code.writeline("@triton.jit")
+
+    # signature
+    code.writeline(f"def {kernel_name}(")
+    function_ns = NameSpace()
+    with code.indent():
+        if rank > 0:
+            code.writeline("index,")
+            code.writeline("src,")
+            code.writeline("out,")
+            code.writeline("N,")
+            code.writeline("inp_numel,")
+            code.writeline("inp_stride_dim,")
+            code.writeline("inp_shape_dim,")
+            code.writeline("src_shape_dim,")
+            code.writeline("delta,")
+            code.writeline("alpha,")
+
+            function_ns.create_name("index")
+            function_ns.create_name("src")
+            function_ns.create_name("out")
+
+            for i in range(rank):
+                function_ns.create_name(f"src_stride_{i}")
+            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
+            code.writeline(f"{stride_args}, # stride for src")
+
+            for i in range(rank):
+                function_ns.create_name(f"src_shape_{i}")
+            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
+            code.writeline(f"{shape_args}, # shape for src")
+
+            code.writeline("BLOCK_SIZE: tl.constexpr,")
+
+    code.writeline("):")
+
+    # Kernel Code
+    with code.indent():
+        code.writeline("pid = tl.program_id(axis=0)")
+        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
+        code.writeline("mask = offsets < N")
+
+        # 1. Calculate src_offsets
+        code.writeline("src_idx = tl.zeros((BLOCK_SIZE, ), dtype=tl.int64)")
+        # snippets
+        for i in range(rank):
+            code.writeline(f"mod = offsets % src_shape_{i}")
+            code.writeline(f"src_idx += mod * src_stride_{i}")
+            if i != (rank - 1):
+                code.writeline(f"offsets = offsets // src_shape_{i}")
+
+        # 2. index add
+        code.writeline(
+            "pre_idx = (src_idx // (inp_stride_dim * src_shape_dim)).to(tl.int64)"
+        )
+        code.writeline(
+            "dim_idx = (src_idx % (inp_stride_dim * src_shape_dim) // inp_stride_dim).to(tl.int64)"
+        )
+        code.writeline(
+            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
+        )
+        code.writeline(
+            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
+        )
+        code.writeline(
+            "input_idx = (src_idx + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
+        )
+
+        code.writeline("input_mask = input_idx < inp_numel")
+        code.writeline(
+            "add_on = tl.load(src + src_idx, mask=mask, other=0) * alpha"
+        )
+        code.writeline(
+            "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
+        )
+        # TODO: tl.atomic_add doesn't support bfloat16! The load/store fallback below may be unsafe.
+        # code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
+        # code.writeline("tl.store(out + input_idx, cur_out + add_on, mask=input_mask)")
+
+    code.newline()
+    code.newline()
+    return code
+
+
+def parameter_for_wrapper() -> str:
+    # out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp_numel, alpha
+    parameters: List[str] = []
+    parameters.append("out")
+    parameters.append("index")
+    parameters.append("src")
+    parameters.append("dim")
+    parameters.append("inp_stride_dim")
+    parameters.append("inp_shape_dim")
+    parameters.append("src_shape_dim")
+    parameters.append("delta")
+    parameters.append("N")
+    parameters.append("inp_numel")
+    parameters.append("alpha")
+
+    return ", ".join(parameters)
+
+
+def generate_destination_passing_wrapper(
+    rank: int,
+    wrapper_name: str,
+    kernel_name: str,
+    code: IndentedBuffer,
+) -> IndentedBuffer:
+    parameters: str = parameter_for_wrapper()
+    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
+    code.writeline(wrapper_signature)
+
+    with code.indent():
+        code.writeline("src_strides = list(src.stride())")
+        code.writeline("src_shapes = list(src.shape)")
+
+        # kernel launch
+        code.writeline("BLOCK_SIZE = 128")  # BLOCK_SIZE setting
+        code.writeline("grid = (triton.cdiv(N, BLOCK_SIZE),)")
+        kernel_launch: str = f"{kernel_name}[grid]("
+        code.writeline(kernel_launch)
+        with code.indent():
+            code.writeline(
+                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
+            )
+            if rank > 0:
+                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
+                code.writeline(f"{s},")
+
+                s = ", ".join(f"src_shapes[{i}]" for i in range(rank))
+                code.writeline(f"{s},")
+            code.writeline("BLOCK_SIZE=BLOCK_SIZE")
+        code.writeline(")")
+        code.writeline("return out")
+
+    return code
+
+
+def generate_code(
+    inputs: Tuple[Any, ...],
+    wrapper_name: str,
+    kernel_name: str,
+    code: IndentedBuffer,
+) -> IndentedBuffer:
+    # inputs: [out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp_numel, alpha]
+    shape = inputs[2].shape
+    rank = len(shape)
+
+    code = generate_imports(code)
+    code = generate_index_add_kernel(rank, kernel_name, code)
+    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
+    return code
+
+
+class IndexAddFunction:
+    def __init__(self):
+        self.pid = os.getpid()
+        self.overloads: Dict[str, Callable] = {}
+
+    def __call__(self, *args, **kwargs):
+        key = f"{self.arg_key(*args)}"
+        if key in self.overloads:
+            overload = self.overloads[key]
+        else:
+            code = IndentedBuffer()
+            code = generate_code(
+                args,
+                "_index_add_wrapper",
+                "_index_add_jit_function",
+                code,
+            )
+
+            file_name = f"index_add_rank_{key}_pid_{self.pid}.py"
+
+            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
+                f.write(code.getvalue())
+
+            # load
+            spec = importlib.util.spec_from_file_location(
+                f"_gen_module_rank_{key}_pid_{self.pid}",
+                f.name,
+            )
+
+            m = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(m)
+            overload = getattr(m, "_index_add_wrapper")
+            self.overloads[key] = overload
+
+        return overload(*args, **kwargs)
+
+    def arg_key(self, *args):
+        tensors = [item for item in args if torch.is_tensor(item)]
+        max_rank = max(item.ndim for item in tensors)
+        return max_rank
+
+
+_index_add_func = IndexAddFunction()
 
 
 def index_add(inp, dim, index, src, alpha=1):
     logging.debug("GEMS INDEX ADD")
-    assert ((0 <= index) * (index < inp.size(dim))).equal(
-        torch.ones(tuple(index.shape), dtype=torch.bool, device="cuda")
-    ), "0 <= index < self.size(dim)"
     assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
     assert index.numel() == src.size(
         dim
@@ -68,28 +239,26 @@ def index_add(inp, dim, index, src, alpha=1):
         ((inp.size(i) == src.size(i)) or i == dim) for i in range(0, inp.ndim)
     ), "src.size(d) == self.size(d) for all dimensions d != dim"
 
-    inp = inp.contiguous()
-    index = index.contiguous()
-    src = src.contiguous()
-
-    dim = dim % inp.ndim
-    inp_len = inp.size(dim)
-    N = index.numel()
-    M = src.numel() // N
-    fine_dim = inp.ndim - 1
-    if dim != fine_dim:
-        inp = dim_compress(inp, dim)
-        src = dim_compress(src, dim)
     out = inp.clone()
-    grid = lambda meta: (
-        triton.cdiv(M, meta["BLOCK_M"]),
-        triton.cdiv(N, meta["BLOCK_N"]),
+    dim %= inp.ndim
+    inp_stride_dim = inp.stride(dim)
+    src_shape_dim = src.size(dim)
+    inp_shape_dim = inp.size(dim)
+    delta = inp.size(dim) - src_shape_dim
+    N = src.numel()
+
+    _index_add_func(
+        out,
+        index,
+        src,
+        dim,
+        inp_stride_dim,
+        inp_shape_dim,
+        src_shape_dim,
+        delta,
+        N,
+        inp.numel(),
+        alpha,
     )
-    index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)
-    if dim != fine_dim:
-        order = [i for i in range(out.ndim - 1)]
-        order.insert(dim, fine_dim)
-        return out.permute(order).contiguous()
-    else:
-        return out
+    return out
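
To make the codegen above easier to review: hand-tracing the writelines for a rank-2 `src` yields a cached module roughly like the following (reconstructed by hand with explanatory comments added; indentation approximated, not copied from an actual cache file):

    import triton
    import triton.language as tl
    from flag_gems.utils import libentry


    @libentry()
    @triton.jit
    def _index_add_jit_function(
        index,
        src,
        out,
        N,
        inp_numel,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        alpha,
        src_stride_0: int, src_stride_1: int, # stride for src
        src_shape_0: int, src_shape_1: int, # shape for src
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < N
        # decompose the flat element id dim by dim into a physical src offset
        src_idx = tl.zeros((BLOCK_SIZE, ), dtype=tl.int64)
        mod = offsets % src_shape_0
        src_idx += mod * src_stride_0
        offsets = offsets // src_shape_0
        mod = offsets % src_shape_1
        src_idx += mod * src_stride_1
        # translate the src offset into an out offset through the index tensor
        pre_idx = (src_idx // (inp_stride_dim * src_shape_dim)).to(tl.int64)
        dim_idx = (src_idx % (inp_stride_dim * src_shape_dim) // inp_stride_dim).to(tl.int64)
        src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)
        assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"
        input_idx = (src_idx + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)
        input_mask = input_idx < inp_numel
        add_on = tl.load(src + src_idx, mask=mask, other=0) * alpha
        # atomic because several src slices may target the same out slice
        tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')


    def _index_add_wrapper(out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp_numel, alpha):
        src_strides = list(src.stride())
        src_shapes = list(src.shape)
        BLOCK_SIZE = 128
        grid = (triton.cdiv(N, BLOCK_SIZE),)
        _index_add_jit_function[grid](
            index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha,
            src_strides[0], src_strides[1],
            src_shapes[0], src_shapes[1],
            BLOCK_SIZE=BLOCK_SIZE
        )
        return out
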
diff --git a/src/flag_gems/utils/shape_utils.py b/src/flag_gems/utils/shape_utils.py
index 7653f1621..e4c5d20ec 100644
--- a/src/flag_gems/utils/shape_utils.py
+++ b/src/flag_gems/utils/shape_utils.py
@@ -341,7 +341,7 @@ def offset_calculator(inp, idx, strides, dim, isInp):
     return offsets if not isInp else (offsets - idx_dim)
 
 
-def offsetCalculator(inp, idx, strides, dim, isInp):
+def offsetCalculator(inp, idx, strides, dim, isInp=False):
     ndim = inp.ndim
     shape = list(inp.shape)
     offsets = 0
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 5847d1eb4..29986b6da 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -714,20 +714,18 @@ def test_accuracy_slice_scatter_with_self_overlapping_input():
     gems_assert_equal(res_out, ref_out)
 
 
-# TODO: failed at (200, 40999, 3)
 @pytest.mark.index_add
 @pytest.mark.parametrize("shape", REDUCTION_SHAPES)
 @pytest.mark.parametrize("dim", DIM_LIST)
-@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_accuracy_index_add(shape, dim, dtype):
-    inp = torch.randn(shape, dtype=dtype, device="cuda")
-
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
     src_shape = list(inp.shape)
     index_max = src_shape[dim]
-    index_len = index_max
-    index = torch.randperm(index_len, device="cuda")
+    index_len = index_max // 2 if index_max >= 2 else 1
+    index = torch.randperm(index_len, device=flag_gems.device)
     src_shape[dim] = index_len
-    src = torch.randn(src_shape, dtype=dtype, device="cuda")
+    src = torch.randn(src_shape, dtype=dtype, device=flag_gems.device)
     alpha = 2
 
     ref_inp = to_reference(inp)
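
End to end, the new path can be exercised outside pytest the same way the test does. A minimal sketch, assuming a CUDA build of flag_gems and that `flag_gems.use_gems()` routes `torch.index_add` to the generated kernel as it does for the other registered ops:

    import torch
    import flag_gems

    inp = torch.randn(32, 16, dtype=torch.float32, device=flag_gems.device)
    dim = 1
    index_len = inp.size(dim) // 2  # mirrors index_max // 2 above
    index = torch.randperm(index_len, device=flag_gems.device)
    src_shape = list(inp.shape)
    src_shape[dim] = index_len
    src = torch.randn(src_shape, dtype=torch.float32, device=flag_gems.device)

    ref = torch.index_add(inp, dim, index, src, alpha=2)
    with flag_gems.use_gems():
        res = torch.index_add(inp, dim, index, src, alpha=2)
    torch.testing.assert_close(res, ref)
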