[Operator] index_add optimized #427

Open · wants to merge 3 commits into base: master · Changes from 2 commits
20 changes: 14 additions & 6 deletions benchmark/test_select_and_slice_perf.py
@@ -254,23 +254,31 @@ def select_scatter_input_fn(shape, dtype, device):
    bench.run()


def index_add_gbps(bench_fn_args, latency):
    index = bench_fn_args[2]
    src = bench_fn_args[3]
    io_amount = sum([shape_utils.size_in_bytes(item) for item in [index, src, src]])
    return io_amount * 1e-9 / (latency * 1e-3)


@pytest.mark.index_add
def test_index_add_perf():
    def index_add_input_fn(shape, dtype, device):
        inp = torch.randn(shape, dtype=dtype, device="cuda")
        dim = 0
        inp = torch.randn(shape, dtype=dtype, device=device)
        dim = 0 if len(shape) == 1 else 1
        src_shape = list(inp.shape)
        index_max = src_shape[dim]
        index_len = index_max // 2
        index = torch.randint(0, index_max, (index_len,), device="cuda")
        index_len = index_max // 2 if index_max >= 2 else 1
        index = torch.randperm(index_len, device=device)
        src_shape[dim] = index_len
        src = torch.randn(src_shape, dtype=dtype, device="cuda")
        src = torch.randn(src_shape, dtype=dtype, device=device)
        yield inp, dim, index, src

    bench = TensorSelectBenchmark(
        op_name="index_add",
        torch_op=torch.index_add,
        input_fn=index_add_input_fn,
        dtypes=FLOAT_DTYPES,
        dtypes=[torch.float16, torch.float32],
        get_gbps=index_add_gbps,
    )
    bench.run()
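For context, `index_add_gbps` above estimates achieved bandwidth from the bytes the operator touches (the index plus `src` counted twice, presumably to approximate reading `src` and updating the corresponding slice of the output) divided by the measured latency. A minimal standalone sketch of the same arithmetic, assuming `size_in_bytes(t)` is `t.numel() * t.element_size()` (the helper name below is illustrative, not part of the benchmark):

```python
import torch


def approx_index_add_gbps(index: torch.Tensor, src: torch.Tensor, latency_ms: float) -> float:
    # Assumed equivalent of shape_utils.size_in_bytes: element count times element width.
    size_in_bytes = lambda t: t.numel() * t.element_size()
    # Bytes moved: read `index`, read `src`, and update the touched slice of the
    # output (approximated by counting `src` a second time, as in the benchmark).
    io_amount = size_in_bytes(index) + 2 * size_in_bytes(src)
    # bytes -> GB (1e-9) and milliseconds -> seconds (1e-3)
    return io_amount * 1e-9 / (latency_ms * 1e-3)


# Example: index_add along dim=1 of a (1024, 4096) tensor with 2048 indices.
index = torch.arange(2048, dtype=torch.int64)
src = torch.randn(1024, 2048, dtype=torch.float16)
print(f"{approx_index_add_gbps(index, src, latency_ms=0.05):.1f} GB/s")
```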
315 changes: 241 additions & 74 deletions src/flag_gems/ops/index_add.py
@@ -1,62 +1,231 @@
import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch
import triton
import triton.language as tl

from ..utils import dim_compress, libentry


def cfggen():
    block_m = [1, 2, 4]
    block_n = [128, 1024, 2048, 4096]
    configs = [
        triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)
        for m in block_m
        for n in block_n
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def index_add_kernel(
    inp,
    out,
    index,
    src,
    M,
    N,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)

    rows_mask = rows_offsets < M
    index_mask = cols_offsets < N
    block_mask = rows_mask and index_mask

    cur_indices = tl.load(index + cols_offsets, mask=index_mask, other=0)
    inp_off = rows_offsets * inp_len + cur_indices[None, :]
    cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0)
    src_off = rows_offsets * N + cols_offsets[None, :]
    cur_src = tl.load(src + src_off, mask=block_mask, other=0.0)
    cur_inp += alpha * cur_src

    tl.store(out + inp_off, cur_inp, mask=block_mask)

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.writeline("from flag_gems.utils import libentry")

    code.newline()
    code.newline()

    return code


def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    function_ns = NameSpace()
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")
            code.writeline("inp_shape_dim,")
            code.writeline("src_shape_dim,")
            code.writeline("delta,")
            code.writeline("alpha,")

            function_ns.create_name("index")
            function_ns.create_name("src")
            function_ns.create_name("out")

            for i in range(rank):
                function_ns.create_name(f"src_stride_{i}")
            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            for i in range(rank):
                function_ns.create_name(f"src_shape_{i}")
            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

        code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # 1. Calculate src_offsets
        code.writeline("src_idx = tl.zeros((BLOCK_SIZE, ), dtype=tl.int64)")
        # snippets
        for i in range(rank):
            code.writeline(f"mod = offsets % src_shape_{i}")
            code.writeline(f"src_idx += mod * src_stride_{i}")
            if i != (rank - 1):
                code.writeline(f"offsets = offsets // src_shape_{i}")

        # 2. index add
        code.writeline(
            "pre_idx = (src_idx // (inp_stride_dim * src_shape_dim)).to(tl.int64)"
        )
        code.writeline(
            "dim_idx = (src_idx % (inp_stride_dim * src_shape_dim) // inp_stride_dim).to(tl.int64)"
        )
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        code.writeline(
            "input_idx = (src_idx + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = input_idx < inp_numel")
        code.writeline(
            "add_on = tl.load(src + src_idx, mask=mask, other=0) * alpha"
        )
        code.writeline("tl.atomic_add(out + input_idx, add_on, mask=input_mask)")
Reviewer comment (Contributor): It's safe to set sem='relaxed'

        # TODO: tl.atomic_add doesn't support bfloat16! The following method may be unsafe.
        # code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
        # code.writeline("tl.store(out + input_idx, cur_out + add_on, mask=input_mask)")

    code.newline()
    code.newline()
    return code
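To make the generated index arithmetic concrete: each program covers `BLOCK_SIZE` flat positions of `src`, rebuilds the flat offset `src_idx` from `src`'s shape and stride, then relocates it into `out` by `(delta * pre_idx + index[dim_idx] - dim_idx) * inp_stride_dim`, where `delta = inp.size(dim) - src.size(dim)`. Below is a small host-side sketch of that same offset mapping (an illustrative check only, assuming contiguous `inp` and `src`, which the `index_add` wrapper further down enforces):

```python
import torch


def check_index_add_offsets(inp, src, index, dim):
    """Recompute the kernel's input_idx formula on the host and verify it against
    offsets derived directly from coordinates (contiguous tensors assumed)."""
    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    delta = inp.size(dim) - src_shape_dim
    for src_idx in range(src.numel()):
        pre_idx = src_idx // (inp_stride_dim * src_shape_dim)   # block of dims before `dim`
        dim_idx = src_idx % (inp_stride_dim * src_shape_dim) // inp_stride_dim
        src_dim_idx = int(index[dim_idx])                       # target position along `dim` in `inp`
        input_idx = src_idx + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim

        # Reference: decompose the flat src offset into coordinates, swap the
        # coordinate along `dim` for index[dim_idx], re-linearize with inp's strides.
        coords, rest = [], src_idx
        for s in reversed(src.shape):
            coords.append(rest % s)
            rest //= s
        coords = list(reversed(coords))
        coords[dim] = src_dim_idx
        assert input_idx == sum(c * s for c, s in zip(coords, inp.stride()))


inp = torch.zeros(3, 8, 5)
index = torch.tensor([1, 6, 2, 7])
src = torch.ones(3, 4, 5)
check_index_add_offsets(inp, src, index, dim=1)
```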


def parameter_for_wrapper() -> str:
    # out, index, src, dim, inp_stride_dim, src_shape_dim, delta, N, inp.numel(), alpha
    parameters: List[str] = []
    parameters.append("out")
    parameters.append("index")
    parameters.append("src")
    parameters.append("dim")
    parameters.append("inp_stride_dim")
    parameters.append("inp_shape_dim")
    parameters.append("src_shape_dim")
    parameters.append("delta")
    parameters.append("N")
    parameters.append("inp_numel")
    parameters.append("alpha")

    return ", ".join(parameters)


def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name} ({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("src_strides = list(src.stride())")
        code.writeline("src_shapes = list(src.shape)")

        # kernel launch
        code.writeline("BLOCK_SIZE = 128") # BLOCK_SIZE setting
        code.writeline("grid = (triton.cdiv(N, BLOCK_SIZE),)")
        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)
        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            if rank > 0:
                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")
        code.writeline(")")
        code.writeline("return out")

    return code
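For reference, running the two generators above with `rank=2` yields roughly the following module (this is what `IndexAddFunction` below writes to the code cache, reconstructed here modulo whitespace purely as a reading aid, not the literal emitted text):

```python
import triton
import triton.language as tl
from flag_gems.utils import libentry


@libentry()
@triton.jit
def _index_add_jit_function(
    index,
    src,
    out,
    N,
    inp_numel,
    inp_stride_dim,
    inp_shape_dim,
    src_shape_dim,
    delta,
    alpha,
    src_stride_0: int, src_stride_1: int,  # stride for src
    src_shape_0: int, src_shape_1: int,  # shape for src
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    # reconstruct a flat offset into src from the linear id
    src_idx = tl.zeros((BLOCK_SIZE, ), dtype=tl.int64)
    mod = offsets % src_shape_0
    src_idx += mod * src_stride_0
    offsets = offsets // src_shape_0
    mod = offsets % src_shape_1
    src_idx += mod * src_stride_1
    # relocate that offset into out through the index tensor
    pre_idx = (src_idx // (inp_stride_dim * src_shape_dim)).to(tl.int64)
    dim_idx = (src_idx % (inp_stride_dim * src_shape_dim) // inp_stride_dim).to(tl.int64)
    src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)
    assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"
    input_idx = (src_idx + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)
    input_mask = input_idx < inp_numel
    add_on = tl.load(src + src_idx, mask=mask, other=0) * alpha
    tl.atomic_add(out + input_idx, add_on, mask=input_mask)


def _index_add_wrapper(out, index, src, dim, inp_stride_dim, inp_shape_dim,
                       src_shape_dim, delta, N, inp_numel, alpha):
    src_strides = list(src.stride())
    src_shapes = list(src.shape)
    BLOCK_SIZE = 128
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    _index_add_jit_function[grid](
        index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim,
        src_shape_dim, delta, alpha,
        src_strides[0], src_strides[1],
        src_shapes[0], src_shapes[1],
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out
```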


def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    # inputs: [out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp.numel(), alpha]
    shape = inputs[2].shape
    rank = len(shape)

    code = generate_imports(code)
    code = generate_index_add_kernel(rank, kernel_name, code)
    code = generate_destination_passing_wrapper(rank, wrapper_name, kernel_name, code)
    return code


class IndexAddFunction:
    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            code = IndentedBuffer()
            code = generate_code(
                args,
                "_index_add_wrapper",
                "_index_add_jit_function",
                code,
            )

            file_name = f"index_add_rank_{key}_pid_{self.pid}.py"

            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_add_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank


_index_add_func = IndexAddFunction()


def index_add(inp, dim, index, src, alpha=1):
    logging.debug("GEMS INDEX ADD")
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device="cuda")
    ), "0 <= index < self.size(dim)"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
@@ -68,28 +237,26 @@ def index_add(inp, dim, index, src, alpha=1):
        ((inp.size(i) == src.size(i)) or i == dim) for i in range(0, inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()
    M = src.numel() // N
    fine_dim = inp.ndim - 1
    if dim != fine_dim:
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)
    out = inp.clone()

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    dim %= inp.ndim
    inp_stride_dim = inp.stride(dim)
    src_shape_dim = src.size(dim)
    inp_shape_dim = inp.size(dim)
    delta = inp.size(dim) - src_shape_dim
    N = src.numel()

    _index_add_func(
        out,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        inp.numel(),
        alpha,
    )
    index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)
    if dim != fine_dim:
        order = [i for i in range(out.ndim - 1)]
        order.insert(dim, fine_dim)
        return out.permute(order).contiguous()
    else:
        return out
    return out
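A quick way to sanity-check the new wrapper against the eager reference (a usage sketch: the import path is assumed from this file's location under `src/flag_gems/ops/`, and a CUDA device is required):

```python
import torch

from flag_gems.ops.index_add import index_add  # path assumed from this PR's file layout

torch.manual_seed(0)
inp = torch.randn(64, 128, dtype=torch.float32, device="cuda")
index = torch.randperm(128, device="cuda")[:32]       # 32 distinct positions along dim=1
src = torch.randn(64, 32, dtype=torch.float32, device="cuda")

res = index_add(inp, 1, index, src, alpha=2.0)        # Triton path from this PR
ref = torch.index_add(inp, 1, index, src, alpha=2.0)  # eager reference
torch.testing.assert_close(res, ref, rtol=1e-4, atol=1e-4)
```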
2 changes: 1 addition & 1 deletion src/flag_gems/utils/shape_utils.py
@@ -341,7 +341,7 @@ def offset_calculator(inp, idx, strides, dim, isInp):
    return offsets if not isInp else (offsets - idx_dim)


def offsetCalculator(inp, idx, strides, dim, isInp):
def offsetCalculator(inp, idx, strides, dim, isInp=False):
    ndim = inp.ndim
    shape = list(inp.shape)
    offsets = 0
12 changes: 5 additions & 7 deletions tests/test_reduction_ops.py
@@ -714,20 +714,18 @@ def test_accuracy_slice_scatter_with_self_overlapping_input():
    gems_assert_equal(res_out, ref_out)


# TODO: failed at (200, 40999, 3)
@pytest.mark.index_add
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dim", DIM_LIST)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_accuracy_index_add(shape, dim, dtype):
    inp = torch.randn(shape, dtype=dtype, device="cuda")

    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    src_shape = list(inp.shape)
    index_max = src_shape[dim]
    index_len = index_max
    index = torch.randperm(index_len, device="cuda")
    index_len = index_max // 2 if index_max >= 2 else 1
    index = torch.randperm(index_len, device=flag_gems.device)
    src_shape[dim] = index_len
    src = torch.randn(src_shape, dtype=dtype, device="cuda")
    src = torch.randn(src_shape, dtype=dtype, device=flag_gems.device)
    alpha = 2

    ref_inp = to_reference(inp)