Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] index_add optimized #427

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions benchmark/test_select_and_slice_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,23 +254,31 @@ def select_scatter_input_fn(shape, dtype, device):
bench.run()


def index_add_gbps(bench_fn_args, latency):
    """Estimate the memory throughput of one ``index_add`` benchmark run.

    Args:
        bench_fn_args: benchmark call arguments; position 2 is read as the
            index tensor and position 3 as the source tensor.
        latency: measured latency. It is scaled by ``1e-3`` below, so it is
            presumably in milliseconds -- TODO confirm against the benchmark
            harness.

    Returns:
        Bytes moved, scaled by ``1e-9``, divided by the scaled latency.
    """
    index, src = bench_fn_args[2], bench_fn_args[3]
    # `src` is counted twice on purpose, exactly as in the traffic model this
    # replaces (presumably one read of src plus an equally sized update of the
    # output -- NOTE(review): confirm the intended model).
    io_amount = sum(
        shape_utils.size_in_bytes(tensor) for tensor in (index, src, src)
    )
    return io_amount * 1e-9 / (latency * 1e-3)


@pytest.mark.index_add
def test_index_add_perf():
    """Benchmark ``torch.index_add`` through ``TensorSelectBenchmark``.

    Inputs are created on the device handed in by the benchmark harness
    rather than on a hard-coded ``"cuda"`` device, and ``dtypes`` is passed
    exactly once (a repeated keyword argument is a SyntaxError).
    """

    def index_add_input_fn(shape, dtype, device):
        """Yield one ``(inp, dim, index, src)`` argument tuple for ``shape``."""
        inp = torch.randn(shape, dtype=dtype, device=device)
        # 1-D inputs only have dim 0; everything else is exercised along dim 1.
        dim = 0 if len(shape) == 1 else 1
        src_shape = list(inp.shape)
        index_max = src_shape[dim]
        # Use half of the indexed dimension, but never an empty index.
        index_len = index_max // 2 if index_max >= 2 else 1
        # randperm gives each value in [0, index_len) exactly once, so the
        # indices are unique and always within the indexed dimension.
        index = torch.randperm(index_len, device=device)
        src_shape[dim] = index_len
        src = torch.randn(src_shape, dtype=dtype, device=device)
        yield inp, dim, index, src

    bench = TensorSelectBenchmark(
        op_name="index_add",
        torch_op=torch.index_add,
        input_fn=index_add_input_fn,
        dtypes=[torch.float16, torch.float32],
        get_gbps=index_add_gbps,
    )
    bench.run()
317 changes: 243 additions & 74 deletions src/flag_gems/ops/index_add.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,233 @@
import importlib
import logging
import os
from typing import Any, Callable, List, Mapping, Tuple

import torch
import triton
import triton.language as tl

from ..utils import dim_compress, libentry


def cfggen():
    """Return the autotuning candidates for ``index_add_kernel``.

    One ``triton.Config`` is produced per (BLOCK_M, BLOCK_N) pair, BLOCK_M
    varying slowest, each with ``num_warps=4``.
    """
    configs = []
    for m in (1, 2, 4):
        for n in (128, 1024, 2048, 4096):
            configs.append(
                triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)
            )
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def index_add_kernel(
    inp,
    out,
    index,
    src,
    M,
    N,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program handles one BLOCK_M x BLOCK_N tile: grid axis 0 walks rows
    # (bounded by M), grid axis 1 walks positions of `index` (bounded by N).
    row_ids = tl.program_id(axis=0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    col_ids = tl.program_id(axis=1) * BLOCK_N + tl.arange(0, BLOCK_N)

    col_ok = col_ids < N
    tile_ok = (row_ids < M) and col_ok

    # Look up the destination position for every index slot; lanes past N
    # read the placeholder value 0 and are masked out of the tile accesses.
    targets = tl.load(index + col_ids, mask=col_ok, other=0)

    # Rows of `inp`/`out` are addressed with a pitch of `inp_len`, rows of
    # `src` with a pitch of N.
    dst_off = row_ids * inp_len + targets[None, :]
    src_off = row_ids * N + col_ids[None, :]

    base = tl.load(inp + dst_off, mask=tile_ok, other=0.0)
    addend = tl.load(src + src_off, mask=tile_ok, other=0.0)
    base += alpha * addend

    tl.store(out + dst_off, base, mask=tile_ok)

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of a generated kernel module.

    Args:
        code: buffer receiving the generated source.

    Returns:
        The same ``code`` buffer, with three import lines and two blank
        lines appended.
    """
    for statement in (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    ):
        code.writeline(statement)

    # Two blank lines separate the imports from the first definition.
    for _ in range(2):
        code.newline()

    return code


def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the source of a rank-specialised Triton ``index_add`` kernel.

    The emitted function is decorated with ``@libentry()`` and
    ``@triton.jit``. For ``rank > 0`` its parameters are ten fixed arguments
    (``index`` ... ``alpha``), then one ``src_stride_<i>`` and one
    ``src_shape_<i>`` per dimension, then the ``BLOCK_SIZE`` constexpr. The
    body turns each linear offset into a ``src`` offset, derives the matching
    offset in ``out`` from the looked-up index, and accumulates with
    ``tl.atomic_add``.

    Args:
        rank: number of ``src_stride_*`` / ``src_shape_*`` parameters to emit.
        kernel_name: name of the generated kernel function.
        code: buffer receiving the generated source.

    Returns:
        The same ``code`` buffer.
    """
    dims = range(rank)

    # Decorators and the opening of the signature.
    for header_line in ("@libentry()", "@triton.jit", f"def {kernel_name}("):
        code.writeline(header_line)

    function_ns = NameSpace()
    with code.indent():
        # NOTE(review): the whole parameter list is guarded by `rank > 0`;
        # a rank-0 call would emit a parameterless kernel -- confirm callers
        # never pass rank 0.
        if rank > 0:
            for param in (
                "index",
                "src",
                "out",
                "N",
                "inp_numel",
                "inp_stride_dim",
                "inp_shape_dim",
                "src_shape_dim",
                "delta",
                "alpha",
            ):
                code.writeline(f"{param},")

            for reserved in ("index", "src", "out"):
                function_ns.create_name(reserved)

            for i in dims:
                function_ns.create_name(f"src_stride_{i}")
            stride_params = ", ".join(f"src_stride_{i}: int" for i in dims)
            code.writeline(f"{stride_params}, # stride for src")

            for i in dims:
                function_ns.create_name(f"src_shape_{i}")
            shape_params = ", ".join(f"src_shape_{i}: int" for i in dims)
            code.writeline(f"{shape_params}, # shape for src")

            code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel body.
    with code.indent():
        for line in (
            "pid = tl.program_id(axis=0)",
            "offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)",
            "mask = offsets < N",
            # Step 1: accumulator for the offset into src.
            "src_idx = tl.zeros((BLOCK_SIZE, ), dtype=tl.int64)",
        ):
            code.writeline(line)

        # Peel one coordinate per dimension off the linear offset and weight
        # it with that dimension's stride; the last dimension needs no
        # further division.
        last_dim = rank - 1
        for i in dims:
            code.writeline(f"mod = offsets % src_shape_{i}")
            code.writeline(f"src_idx += mod * src_stride_{i}")
            if i != last_dim:
                code.writeline(f"offsets = offsets // src_shape_{i}")

        # Step 2: map the src offset to an out offset and accumulate.
        for line in (
            "pre_idx = (src_idx // (inp_stride_dim * src_shape_dim)).to(tl.int64)",
            "dim_idx = (src_idx % (inp_stride_dim * src_shape_dim) // inp_stride_dim).to(tl.int64)",
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)",
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"',
            "input_idx = (src_idx + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)",
            "input_mask = input_idx < inp_numel",
            "add_on = tl.load(src + src_idx, mask=mask, other=0) * alpha",
            "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')",
        ):
            code.writeline(line)
        # TODO: tl.atomic_add doesn't support bfloat16! A plain
        # load / add / store on `out + input_idx` under `input_mask` would
        # avoid the atomic but may be unsafe.

    code.newline()
    code.newline()
    return code


def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated ``index_add`` wrapper.

    The names are joined with ``", "`` in this order: out, index, src, dim,
    inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp_numel, alpha.
    This is the positional order ``index_add`` uses when it calls the
    generated wrapper, so the two must be kept in sync.
    """
    parameters = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
        "alpha",
    )
    return ", ".join(parameters)


def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the Python wrapper that launches the generated kernel.

    The emitted wrapper takes the parameters from ``parameter_for_wrapper()``,
    reads the strides and shape of ``src``, launches ``kernel_name`` on a 1-D
    grid of ``triton.cdiv(N, BLOCK_SIZE)`` programs with ``BLOCK_SIZE = 128``,
    and returns ``out``.

    Args:
        rank: number of per-dimension stride/shape arguments to forward.
        wrapper_name: name of the generated wrapper function.
        kernel_name: name of the kernel the wrapper launches.
        code: buffer receiving the generated source.

    Returns:
        The same ``code`` buffer.
    """
    code.writeline(f"def {wrapper_name} ({parameter_for_wrapper()}):")

    with code.indent():
        for line in (
            "src_strides = list(src.stride())",
            "src_shapes = list(src.shape)",
            # Launch configuration: fixed block size, 1-D grid over N.
            "BLOCK_SIZE = 128",
            "grid = (triton.cdiv(N, BLOCK_SIZE),)",
            f"{kernel_name}[grid](",
        ):
            code.writeline(line)

        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            # NOTE(review): the per-dimension arguments and BLOCK_SIZE are
            # only emitted for rank > 0 -- confirm rank 0 is never requested.
            if rank > 0:
                for seq_name in ("src_strides", "src_shapes"):
                    forwarded = ", ".join(
                        f"{seq_name}[{i}]" for i in range(rank)
                    )
                    code.writeline(f"{forwarded},")
                code.writeline("BLOCK_SIZE=BLOCK_SIZE")

        code.writeline(")")
        code.writeline("return out")

    return code


def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate a complete module: imports, kernel, then wrapper.

    Args:
        inputs: the wrapper call arguments, laid out as
            [out, index, src, dim, inp_stride_dim, inp_shape_dim,
            src_shape_dim, delta, N, inp.numel(), alpha]. Only ``inputs[2]``
            (src) is inspected; its number of dimensions is the rank the
            code is specialised for.
        wrapper_name: name of the generated wrapper function.
        kernel_name: name of the generated kernel function.
        code: buffer receiving the generated source.

    Returns:
        The buffer returned by the last generation step.
    """
    src = inputs[2]
    rank = len(src.shape)

    with_imports = generate_imports(code)
    with_kernel = generate_index_add_kernel(rank, kernel_name, with_imports)
    return generate_destination_passing_wrapper(
        rank, wrapper_name, kernel_name, with_kernel
    )


class IndexAddFunction:
    """Generate, cache and dispatch rank-specialised ``index_add`` wrappers.

    On the first call for a given key the source of a kernel plus wrapper is
    generated, written to the code cache directory, imported as a module,
    and its ``_index_add_wrapper`` is stored in ``overloads``. Later calls
    with the same key reuse the stored wrapper.
    """

    def __init__(self):
        # The process id is embedded in generated file and module names.
        self.pid = os.getpid()
        # Maps the stringified arg_key() to the generated wrapper callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Run the wrapper matching ``args``, generating it if necessary.

        Args:
            *args: positional arguments forwarded unchanged to the wrapper;
                the tensors among them determine the cache key.
            **kwargs: keyword arguments forwarded unchanged to the wrapper.

        Returns:
            Whatever the generated wrapper returns.
        """
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            return self.overloads[key](*args, **kwargs)

        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        code = generate_code(
            args,
            "_index_add_wrapper",
            "_index_add_jit_function",
            IndentedBuffer(),
        )

        file_name = f"index_add_rank_{key}_pid_{self.pid}.py"
        with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())

        # Load only after the file is closed so the source is fully flushed.
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}_pid_{self.pid}",
            f.name,
        )
        m = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(m)

        overload = getattr(m, "_index_add_wrapper")
        self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor arguments.

        Non-tensor arguments are ignored.

        Raises:
            ValueError: if ``args`` contains no tensor.
        """
        return max(item.ndim for item in args if torch.is_tensor(item))


# Module-level dispatcher instance; its `overloads` cache of generated
# wrappers therefore lives for as long as this module stays loaded.
_index_add_func = IndexAddFunction()


def index_add(inp, dim, index, src, alpha=1):
logging.debug("GEMS INDEX ADD")
assert ((0 <= index) * (index < inp.size(dim))).equal(
torch.ones(tuple(index.shape), dtype=torch.bool, device="cuda")
), "0 <= index < self.size(dim)"
assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
assert index.numel() == src.size(
dim
Expand All @@ -68,28 +239,26 @@ def index_add(inp, dim, index, src, alpha=1):
((inp.size(i) == src.size(i)) or i == dim) for i in range(0, inp.ndim)
), "src.size(d) == self.size(d) for all dimensions d != dim"

inp = inp.contiguous()
index = index.contiguous()
src = src.contiguous()

dim = dim % inp.ndim
inp_len = inp.size(dim)
N = index.numel()
M = src.numel() // N
fine_dim = inp.ndim - 1
if dim != fine_dim:
inp = dim_compress(inp, dim)
src = dim_compress(src, dim)
out = inp.clone()

grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(N, meta["BLOCK_N"]),
dim %= inp.ndim
inp_stride_dim = inp.stride(dim)
src_shape_dim = src.size(dim)
inp_shape_dim = inp.size(dim)
delta = inp.size(dim) - src_shape_dim
N = src.numel()

_index_add_func(
out,
index,
src,
dim,
inp_stride_dim,
inp_shape_dim,
src_shape_dim,
delta,
N,
inp.numel(),
alpha,
)
index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)
if dim != fine_dim:
order = [i for i in range(out.ndim - 1)]
order.insert(dim, fine_dim)
return out.permute(order).contiguous()
else:
return out
return out
2 changes: 1 addition & 1 deletion src/flag_gems/utils/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def offset_calculator(inp, idx, strides, dim, isInp):
return offsets if not isInp else (offsets - idx_dim)


def offsetCalculator(inp, idx, strides, dim, isInp):
def offsetCalculator(inp, idx, strides, dim, isInp=False):
ndim = inp.ndim
shape = list(inp.shape)
offsets = 0
Expand Down
Loading
Loading