Add index select backward #359

Open · wants to merge 38 commits into FlagOpen:master from add_index_select_backward

Commits (38)
51e16fd  init-flaggems (ph0375, Dec 11, 2024)
f5dd979  init-ops (ph0375, Dec 11, 2024)
0afb6bc  index_select.py (ph0375, Dec 11, 2024)
00acd52  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 11, 2024)
d347788  update (ph0375, Dec 11, 2024)
f47272b  init_update (ph0375, Dec 11, 2024)
0caa0a8  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 11, 2024)
414fe76  Update __init__.py (ph0375, Dec 11, 2024)
3aa39b3  autotune (ph0375, Dec 11, 2024)
4b6ad02  Update tune_configs.yaml (ph0375, Dec 11, 2024)
fa976f0  index_select_backward unit test (shuailong616, Dec 12, 2024)
a2e6e01  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 12, 2024)
3ef71f4  Update test_select_and_slice_perf.py (ph0375, Dec 12, 2024)
1ac8878  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 12, 2024)
1376320  change autotune config (shuailong616, Dec 13, 2024)
b685f3c  code format change (shuailong616, Dec 13, 2024)
6dfb8b7  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 16, 2024)
378cecf  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 17, 2024)
084e409  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 17, 2024)
acb1e59  add_index_select_backward (henghengxiedaima, Dec 17, 2024)
1f94774  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 18, 2024)
a19634d  add_index_select_backward (henghengxiedaima, Dec 18, 2024)
80777b3  add_index_select_backward (henghengxiedaima, Dec 18, 2024)
d8b1441  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 18, 2024)
3d59b87  solve conflict (shuailong616, Dec 19, 2024)
3a27f33  add_index_select_backward (henghengxiedaima, Dec 20, 2024)
c3fd06f  add_index_select_backward (henghengxiedaima, Dec 20, 2024)
3c64858  add_index_select_backward (henghengxiedaima, Dec 20, 2024)
7770b1a  add_index_select (henghengxiedaima, Dec 24, 2024)
bafa7ef  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Dec 30, 2024)
6515c14  add_index_select_backward (henghengxiedaima, Dec 30, 2024)
815360e  add_index_select_backward (henghengxiedaima, Dec 31, 2024)
08992ec  add_index_select_backward (henghengxiedaima, Jan 14, 2025)
386fbdd  Merge branch 'FlagOpen:master' into add_index_select_backward (AdvancedCompiler, Jan 22, 2025)
58b624a  Merge branch 'master' into add_index_select_backward (AdvancedCompiler, Mar 5, 2025)
150c55e  Update tune_configs.yaml (AdvancedCompiler, Mar 5, 2025)
5294358  Update __init__.py (AdvancedCompiler, Mar 5, 2025)
d27e4bc  Update __init__.py (AdvancedCompiler, Mar 5, 2025)
39 changes: 39 additions & 0 deletions benchmark/test_select_and_slice_perf.py
@@ -66,6 +66,19 @@ def index_select_gbps(bench_fn_args, latency):
    return io_amount * 1e-9 / (latency * 1e-3)


def index_select_backward_gbps(bench_fn_args, latency):
    inp = bench_fn_args[0]
    dim = bench_fn_args[1]
    index = bench_fn_args[2]
    index_unique = torch.unique(index)
    io_amount = (
        shape_utils.size_in_bytes(inp)
        * (index.size(0) + index_unique.size(0))
        // inp.size(dim)
    )
    return io_amount * 1e-9 / (latency * 1e-3)
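For reference, the I/O model above charges one slice of `inp` (along `dim`) read per index entry for the incoming gradient, plus one slice written per unique index, since duplicates accumulate into the same output row. A worked instance of the formula, with shapes and latency made up purely for illustration:

import torch

# Illustrative numbers only; not part of this PR.
inp = torch.randn(1024, 1024, dtype=torch.float32)                # 4 MiB input
index = torch.randint(0, 1024, (100,))
bytes_per_slice = inp.numel() * inp.element_size() // inp.size(0)  # 4 KiB per row
io_amount = bytes_per_slice * (index.size(0) + torch.unique(index).size(0))
latency_ms = 0.05                                                  # assumed kernel latency
print(io_amount * 1e-9 / (latency_ms * 1e-3), "GB/s")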


@pytest.mark.parametrize(
    "op_name, torch_op, input_fn, gbps_fn, dtypes",
    [
@@ -272,6 +285,32 @@ def select_scatter_input_fn(shape, dtype, device):
    bench.run()


@pytest.mark.skipif(vendor_name == "kunlunxin", reason="RESULT TODOFIX")
@pytest.mark.index_select_backward
def test_perf_index_select_backward():
    def index_select_backward_input_fn(shape, dtype, device):
        inp = generate_tensor_input(shape, dtype, device)
        threshold = 0.1
        dim = 0
        index_size = inp.size(dim)
        from math import floor

        index = torch.randint(
            0, index_size, [floor(index_size * threshold)], device=device
        )
        yield inp, dim, index

    bench = TensorSelectBenchmark(
        input_fn=index_select_backward_input_fn,
        op_name="index_select_backward",
        torch_op=torch.index_select,
        dtypes=FLOAT_DTYPES,
        get_gbps=index_select_backward_gbps,
        is_backward=True,
    )
    bench.run()


@pytest.mark.skipif(vendor_name == "kunlunxin", reason="RESULT TODOFIX")
@pytest.mark.index_add
def test_index_add_perf():
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -181,6 +181,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
        ("slice_scatter", slice_scatter, Autograd.disable),
        ("select_scatter", select_scatter, Autograd.disable),
        ("index_select", index_select, Autograd.disable),
        ("index_select_backward", index_select_backward, Autograd.disable),
        ("tile", tile, Autograd.disable),
        ("masked_fill.Tensor", masked_fill, Autograd.disable),
        ("masked_fill.Scalar", masked_fill, Autograd.disable),
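With this registration in place, enabling flag_gems routes autograd's call to `aten::index_select_backward` to the Triton implementation added below. A minimal usage sketch, following the `flag_gems.use_gems()` pattern used by the tests in this PR (device and shapes are illustrative):

import torch
import flag_gems

inp = torch.randn(128, 64, device="cuda", requires_grad=True)
index = torch.randint(0, 128, (32,), device="cuda")

with flag_gems.use_gems():
    out = torch.index_select(inp, 0, index)   # forward runs the gems kernel
    out.backward(torch.ones_like(out))        # backward dispatches to index_select_backward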
3 changes: 2 additions & 1 deletion src/flag_gems/ops/__init__.py
@@ -49,7 +49,7 @@
from .hstack import hstack
from .index_add import index_add
from .index_put import index_put
-from .index_select import index_select
+from .index_select import index_select, index_select_backward
from .instancenorm import instance_norm
from .isclose import allclose, isclose
from .isfinite import isfinite
@@ -203,6 +203,7 @@
    "gt",
    "gt_scalar",
    "index_select",
    "index_select_backward",
    "instance_norm",
    "isclose",
    "isfinite",
87 changes: 87 additions & 0 deletions src/flag_gems/ops/index_select.py
@@ -1,4 +1,5 @@
import logging
import math

import torch
import triton
@@ -62,3 +63,89 @@ def index_select(inp, dim, index):
        return out.permute(order).contiguous()
    else:
        return out


def dim_compress_backward(inp, dims):
    if isinstance(dims, int):
        dims = [dims]
    dim = inp.ndim
    stride = inp.stride()
    batch_dim = [i for i in range(dim) if i not in dims]
    sorted_reduction_dim = sorted(dims, key=lambda x: stride[x], reverse=True)
    order = sorted_reduction_dim + batch_dim
    return inp.permute(order).contiguous()
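`dim_compress_backward` permutes the indexed dim(s) to the front, ordered by descending stride, and materializes the result contiguously, so the kernel can treat the gradient as a flat `(index_dim, everything_else)` layout. A small sketch of the effect, assuming a contiguous input (not part of this PR):

import torch

x = torch.arange(4 * 5 * 6).reshape(4, 5, 6)   # strides (30, 6, 1)
# dims=[1]: batch dims are [0, 2], so order = [1, 0, 2]
y = x.permute(1, 0, 2).contiguous()
print(y.shape)                                 # torch.Size([5, 4, 6])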


# kernel
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("index_select_backward"),
    key=["M", "N"],
    reset_to_zero=["out"],
)
@triton.jit
def index_select_backward_kernel(
    grad,
    out,
    M,
    N,
    num_blocks_per_CTA,
    index,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Axis 0 walks the M index entries in steps of BLOCK_M; axis 1 tiles the
    # innermost dimension of size N in steps of BLOCK_N.
    pid_x = tle.program_id(axis=0)
    pid_y = tle.program_id(axis=1)
    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)

    grad_mask = (rows_offsets < M) and (cols_offsets < N)
    indices = tl.load(index + rows_offsets, mask=(rows_offsets < M), other=0)

    # One program per index entry along axis 0 (the tuned configs pin BLOCK_M
    # to 1, which the flat addressing below relies on). Each entry owns
    # num_blocks_per_CTA contiguous row-blocks of the dim-compressed gradient.
    for i in range(0, num_blocks_per_CTA):
        grad_off = (pid_x * num_blocks_per_CTA + i) * N + cols_offsets
        out_off = (indices * num_blocks_per_CTA + i) * N + cols_offsets
        selected = tl.load(grad + grad_off, mask=grad_mask, other=0.0)
        # Duplicate indices hit the same output row, so accumulate atomically.
        tl.atomic_add(out + out_off, selected, mask=grad_mask)
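The atomic accumulation matters because the forward select may repeat an index, and every repeat contributes gradient to the same source row. A plain-PyTorch illustration of the accumulation the kernel performs (a sketch, not part of this PR):

import torch

grad = torch.ones(3, 4)            # upstream gradient: one row per index entry
index = torch.tensor([0, 2, 0])    # row 0 selected twice in the forward pass
out = torch.zeros(5, 4)
out.index_add_(0, index, grad)
print(out[0])                      # row 0 accumulated twice -> all 2.0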


# function
def index_select_backward(grad, self_sizes, dim, index):
    logging.debug("GEMS INDEX SELECT BACKWARD")
    assert dim >= -len(self_sizes) and dim < len(self_sizes), "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"
    if index.ndim == 0:
        index = index.unsqueeze(0)
    grad_init = grad.ndim
    if grad_init == 1:
        grad = grad.unsqueeze(1)
    index_shape = list(index.shape)
    dim = dim % len(self_sizes)
    grad_shape = list(grad.shape)
    assert grad_shape[dim] == index_shape[0], "grad and index size mismatch along dim"
    # Move the indexed dim to the front so every index entry owns a contiguous slice.
    grad = dim_compress_backward(grad, dim)
    grad_shape = list(grad.shape)
    out_shape = list(grad.shape)
    # Row-blocks per index entry: product of the batch dims excluding the last,
    # which the kernel tiles with BLOCK_N instead.
    shape_for_block_counting = list(grad.shape[1:])
    shape_for_block_counting[-1] = 1
    num_blocks_per_CTA = math.prod(shape_for_block_counting)
    N = grad_shape[grad.ndim - 1]
    # M is the number of index entries (the leading dim after compression).
    M = grad.numel() // N // num_blocks_per_CTA
    out_shape[0] = self_sizes[dim]
    grad_type = grad.dtype
    # Accumulate in float32 so the atomic adds use a widely supported dtype and
    # duplicate-index sums avoid half-precision rounding.
    grad = grad.to(torch.float32)
    out = torch.zeros(out_shape, dtype=torch.float32, device=grad.device)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    index_select_backward_kernel[grid](grad, out, M, N, num_blocks_per_CTA, index)
    out = out.to(grad_type)
    if grad_init == 1:
        out = out.squeeze(1)
    if dim != 0:
        # Undo the dim compression: move the front dim back to position `dim`.
        order = [i for i in range(1, out.ndim)]
        order.insert(dim, 0)
        return out.permute(order)
    else:
        return out
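Semantically, this wrapper should agree with `index_add_` onto a zero tensor of the original shape. A quick cross-check sketch, assuming a CUDA device and an installed flag_gems (not part of this PR):

import torch
from flag_gems.ops.index_select import index_select_backward

self_sizes = [8, 6]
dim = 1
index = torch.tensor([1, 4, 1], device="cuda")
grad = torch.randn(8, 3, device="cuda")   # grad of index_select(inp, 1, index)

ref = torch.zeros(self_sizes, device="cuda")
ref.index_add_(dim, index, grad)

res = index_select_backward(grad, self_sizes, dim, index)
torch.testing.assert_close(res, ref)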
13 changes: 12 additions & 1 deletion src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -1139,7 +1139,18 @@ batch_norm:
  - 8
  - 16
  - 32
index_select_backward:
- gen: true
  param_map:
    META:
      BLOCK_M: block_m
      BLOCK_N: block_n
    num_warps: 4
  block_m:
  - 1
  block_n:
  - 256
  - 512
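Assuming the runtime expands `gen: true` entries into the cross product of the listed values (an inference from the `param_map` layout, not verified here), the block above amounts to two candidate configs. Note that `block_m` is pinned to 1, consistent with the kernel addressing one index entry per program along axis 0:

import triton

# Hypothetical expansion of the YAML entry above.
configs = [
    triton.Config({"BLOCK_M": 1, "BLOCK_N": bn}, num_warps=4)
    for bn in (256, 512)
]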
index_put:
- gen: true
  param_map:
30 changes: 30 additions & 0 deletions tests/test_reduction_ops.py
@@ -795,6 +795,36 @@ def test_accuracy_index_select(shape, dim, dtype):
    gems_assert_equal(res_out, ref_out)


@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX")
@pytest.mark.index_select_backward
@pytest.mark.parametrize("shape", REDUCTION_SMALL_SHAPES)
@pytest.mark.parametrize("dim", DIM_LIST)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_index_select_backward(shape, dim, dtype):
    inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    ref_inp = to_reference(inp)
    from math import floor

    index_size = inp.size(dim)
    index = torch.randint(0, index_size, [floor(index_size * 0.1)], device="cuda")
    if len(index) == 0:
        # floor(index_size * 0.1) can be 0 for small shapes; nothing to check then.
        pass
    else:
        ref_index = to_reference(index)
        ref_out = torch.index_select(ref_inp, dim, ref_index)
        with flag_gems.use_gems():
            res_out = torch.index_select(inp, dim, index)
        out_grad = torch.randn_like(res_out)
        ref_grad = to_reference(out_grad)
        (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
        with flag_gems.use_gems():
            (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
        res_out = to_reference(res_out)
        res_in_grad = to_reference(res_in_grad)
        gems_assert_close(res_out, ref_out, dtype)
        gems_assert_close(res_in_grad, ref_in_grad, dtype)


@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX")
@pytest.mark.masked_select
@pytest.mark.parametrize("threshold, shape", THRESHOLD_SHAPE)