Skip to content

Commit 18d4127

Browse files
authored
Add elu and kron op (#482)
1 parent 6c0c3cb commit 18d4127

File tree

7 files changed

+259
-2
lines changed

7 files changed

+259
-2
lines changed

benchmark/test_special_perf.py

+1
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def diagonal_backward_input_fn(shape, dtype, device):
381381
bench.run()
382382

383383

384+
@pytest.mark.skipif(vendor_name == "kunlunxin", reason="RESULT TODOFIX")
384385
@pytest.mark.kron
385386
def test_perf_kron():
386387
class KronBenchmark(GenericBenchmark2DOnly):

benchmark/test_unary_pointwise_perf.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66
from .attri_util import BOOL_DTYPES, DEFAULT_METRICS, FLOAT_DTYPES, INT_DTYPES
7-
from .performance_utils import Benchmark, generate_tensor_input
7+
from .performance_utils import Benchmark, generate_tensor_input, vendor_name
88

99

1010
class UnaryPointwiseBenchmark(Benchmark):
@@ -74,6 +74,8 @@ def get_tflops(self, op, *args, **kwargs):
7474
],
7575
)
7676
def test_general_unary_pointwise_perf(op_name, torch_op, dtypes):
77+
if vendor_name == "kunlunxin" and op_name == "elu":
78+
pytest.skip("RUNTIME TODOFIX")
7779
bench = UnaryPointwiseBenchmark(op_name=op_name, torch_op=torch_op, dtypes=dtypes)
7880
bench.run()
7981

src/flag_gems/runtime/backend/_kunlunxin/ops/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .diagonal import diagonal_backward
3333
from .div import div_mode, floor_divide, remainder, true_divide
3434
from .dropout import native_dropout
35+
from .elu import elu
3536
from .embedding import embedding
3637
from .eq import eq, eq_scalar
3738
from .erf import erf
@@ -56,6 +57,7 @@
5657
from .isin import isin
5758
from .isinf import isinf
5859
from .isnan import isnan
60+
from .kron import kron
5961
from .layernorm import layer_norm
6062
from .le import le, le_scalar
6163
from .log_sigmoid import log_sigmoid
@@ -170,6 +172,7 @@
170172
"diag",
171173
"diag_embed",
172174
"diagonal_backward",
175+
"elu",
173176
"pad",
174177
"constant_pad_nd",
175178
"cummin",
@@ -305,6 +308,7 @@
305308
"logical_xor",
306309
"logical_not",
307310
"sort",
311+
"kron",
308312
"nll_loss_forward",
309313
"nll_loss_backward",
310314
"nll_loss2d_forward",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import logging
2+
3+
import triton
4+
import triton.language as tl
5+
6+
from ..utils.pointwise_dynamic import pointwise_dynamic
7+
8+
9+
@pointwise_dynamic(
10+
is_tensor=[True, False, False, False], promotion_methods=[(0, "DEFAULT")]
11+
)
12+
@triton.jit
13+
def elu_forward_kernel(x, alpha, scale, input_scale):
14+
return tl.where(
15+
x.to(tl.float32) > 0,
16+
scale * input_scale * x.to(tl.float32),
17+
scale * alpha * (tl.exp(x.to(tl.float32) * input_scale) - 1),
18+
)
19+
20+
21+
def elu(A, alpha=1.0, scale=1.0, input_scale=1.0):
22+
logging.debug("GEMS ELU")
23+
return elu_forward_kernel(A, alpha, scale, input_scale)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import math
2+
3+
import torch
4+
import triton
5+
import triton.language as tl
6+
7+
# from flag_gems import runtime
8+
from flag_gems.runtime import torch_device_fn
9+
from flag_gems.utils import triton_lang_extension as tle
10+
11+
12+
def prepare_tensor_for_kron(tensor_a, tensor_b):
13+
a_shape = list(tensor_a.shape)
14+
b_shape = list(tensor_b.shape)
15+
16+
if tensor_a.numel() == 0 or tensor_b.numel() == 0:
17+
if not a_shape:
18+
a_shape = [0]
19+
if not b_shape:
20+
b_shape = [0]
21+
22+
if len(a_shape) > len(b_shape):
23+
b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape
24+
elif len(b_shape) > len(a_shape):
25+
a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape
26+
27+
out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
28+
return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape
29+
30+
if len(a_shape) < 2:
31+
a_shape = [1] * (2 - len(a_shape)) + a_shape
32+
if len(b_shape) < 2:
33+
b_shape = [1] * (2 - len(b_shape)) + b_shape
34+
35+
if len(a_shape) > len(b_shape):
36+
b_shape = [1] * (len(a_shape) - len(b_shape)) + b_shape
37+
elif len(b_shape) > len(a_shape):
38+
a_shape = [1] * (len(b_shape) - len(a_shape)) + a_shape
39+
40+
out_shape = tuple(a * b for a, b in zip(a_shape, b_shape))
41+
return tensor_a.reshape(*a_shape), tensor_b.reshape(*b_shape), out_shape
42+
43+
44+
def calculate_indices(batch_idx, shape_a, shape_b):
45+
a_batch_dims = shape_a[:-2] or (1,)
46+
b_batch_dims = shape_b[:-2] or (1,)
47+
out_batch_dims = tuple(a * b for a, b in zip(a_batch_dims, b_batch_dims))
48+
49+
out_indices = []
50+
remaining = batch_idx
51+
for dim_size in out_batch_dims[::-1]:
52+
out_indices.insert(0, remaining % dim_size)
53+
remaining //= dim_size
54+
55+
a_idx = b_idx = 0
56+
for out_idx, (a_dim, b_dim) in zip(out_indices, zip(a_batch_dims, b_batch_dims)):
57+
a_idx = a_idx * a_dim + (out_idx // b_dim)
58+
b_idx = b_idx * b_dim + (out_idx % b_dim)
59+
60+
return a_idx, b_idx
61+
62+
63+
def heur_block_n(args):
64+
import builtins
65+
66+
return builtins.min(args["N"], 8192)
67+
68+
69+
def heur_block_m(args):
70+
return triton.next_power_of_2(triton.cdiv(args["M"], 12))
71+
72+
73+
# @triton.autotune(configs=runtime.get_tuned_config("kron"), key=["M", "N"])
74+
@triton.heuristics(
75+
{
76+
"BLOCK_M": heur_block_m,
77+
"BLOCK_N": heur_block_n,
78+
}
79+
)
80+
@triton.jit
81+
def kron_kernel(
82+
a_ptr,
83+
b_ptr,
84+
c_ptr,
85+
map_ptr,
86+
batch_size: tl.int64,
87+
M: tl.int64,
88+
N: tl.int64,
89+
M1: tl.int64,
90+
M2: tl.int64,
91+
N1: tl.int64,
92+
N2: tl.int64,
93+
a_stride_0: tl.int64,
94+
a_stride_1: tl.int64,
95+
b_stride_0: tl.int64,
96+
b_stride_1: tl.int64,
97+
c_stride_0: tl.int64,
98+
c_stride_1: tl.int64,
99+
a_batch_stride: tl.int64,
100+
b_batch_stride: tl.int64,
101+
c_batch_stride: tl.int64,
102+
BLOCK_M: tl.constexpr,
103+
BLOCK_N: tl.constexpr,
104+
):
105+
pid = tle.program_id(0)
106+
num_blocks_n = tl.cdiv(N, BLOCK_N)
107+
num_blocks_m = tl.cdiv(M, BLOCK_M)
108+
num_blocks_per_batch = num_blocks_m * num_blocks_n
109+
110+
batch_id = pid // num_blocks_per_batch
111+
local_pid = pid % num_blocks_per_batch
112+
block_m = local_pid // num_blocks_n
113+
block_n = local_pid % num_blocks_n
114+
115+
offs_m = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
116+
offs_n = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
117+
118+
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) & (batch_id < batch_size)
119+
120+
offset = batch_id * 2
121+
is_valid = batch_id < batch_size
122+
a_batch_idx = tl.load(map_ptr + offset, mask=is_valid)
123+
b_batch_idx = tl.load(map_ptr + offset + 1, mask=is_valid)
124+
125+
a_row = offs_m[:, None] // M2
126+
a_col = offs_n[None, :] // N2
127+
b_row = offs_m[:, None] % M2
128+
b_col = offs_n[None, :] % N2
129+
130+
a_idx = a_batch_idx * a_batch_stride + a_row * a_stride_0 + a_col * a_stride_1
131+
b_idx = b_batch_idx * b_batch_stride + b_row * b_stride_0 + b_col * b_stride_1
132+
133+
a = tl.load(a_ptr + a_idx, mask=mask)
134+
b = tl.load(b_ptr + b_idx, mask=mask)
135+
c = a * b
136+
137+
c_idx = (
138+
batch_id * c_batch_stride
139+
+ offs_m[:, None] * c_stride_0
140+
+ offs_n[None, :] * c_stride_1
141+
)
142+
tl.store(c_ptr + c_idx, c, mask=mask)
143+
144+
145+
def kron(A, B):
146+
if A.dim() == 0 and B.dim() == 0:
147+
return A * B
148+
149+
if A.numel() == 0 or B.numel() == 0:
150+
A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
151+
output_dtype = torch.promote_types(A.dtype, B.dtype)
152+
return torch.empty(out_shape, device=A.device, dtype=output_dtype)
153+
154+
if A.dim() == 0:
155+
return A.unsqueeze(0) * B
156+
if B.dim() == 0:
157+
return A * B.unsqueeze(0)
158+
159+
A_prepared, B_prepared, out_shape = prepare_tensor_for_kron(A, B)
160+
M1, N1 = A_prepared.shape[-2:]
161+
M2, N2 = B_prepared.shape[-2:]
162+
M, N = M1 * M2, N1 * N2
163+
164+
batch_size = math.prod(out_shape[:-2]) if out_shape[:-2] else 1
165+
166+
output_dtype = torch.promote_types(A.dtype, B.dtype)
167+
C = torch.empty(out_shape, device=A.device, dtype=output_dtype)
168+
169+
C_reshaped = C.view(-1, M, N)
170+
A_view = A_prepared.reshape(-1, M1, N1)
171+
B_view = B_prepared.reshape(-1, M2, N2)
172+
173+
if not A_view.is_contiguous():
174+
A_view = A_view.contiguous()
175+
if not B_view.is_contiguous():
176+
B_view = B_view.contiguous()
177+
178+
batch_indices = torch.empty(batch_size * 2, device=A.device, dtype=torch.int64)
179+
for i in range(batch_size):
180+
a_idx, b_idx = calculate_indices(i, A_prepared.shape, B_prepared.shape)
181+
batch_indices[i * 2] = a_idx
182+
batch_indices[i * 2 + 1] = b_idx
183+
184+
a_batch_stride = M1 * N1
185+
b_batch_stride = M2 * N2
186+
c_batch_stride = M * N
187+
with torch_device_fn.device(A.device):
188+
grid = lambda meta: (
189+
batch_size
190+
* triton.cdiv(M, meta["BLOCK_M"])
191+
* triton.cdiv(N, meta["BLOCK_N"]),
192+
)
193+
194+
kron_kernel[grid](
195+
A_view,
196+
B_view,
197+
C_reshaped,
198+
batch_indices,
199+
batch_size,
200+
M,
201+
N,
202+
M1,
203+
M2,
204+
N1,
205+
N2,
206+
A_view.stride(1),
207+
A_view.stride(2),
208+
B_view.stride(1),
209+
B_view.stride(2),
210+
C_reshaped.stride(1),
211+
C_reshaped.stride(2),
212+
a_batch_stride,
213+
b_batch_stride,
214+
c_batch_stride,
215+
)
216+
217+
if A.dim() <= 1 and B.dim() <= 1:
218+
return C.reshape(-1)
219+
220+
return C

tests/accuracy_utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from .conftest import QUICK_MODE, TO_CPU
99

10+
fp64_is_supported = flag_gems.runtime.device.support_fp64
11+
1012

1113
def SkipVersion(module_name, skip_pattern):
1214
cmp = skip_pattern[0]
@@ -146,7 +148,7 @@ def SkipVersion(module_name, skip_pattern):
146148
]
147149
# Add some test cases with zeor-dimensional tensor and zero-sized tensors.
148150
FLOAT_DTYPES = [torch.float16, torch.float32, torch.bfloat16]
149-
ALL_FLOAT_DTYPES = FLOAT_DTYPES + [torch.float64]
151+
ALL_FLOAT_DTYPES = FLOAT_DTYPES + [torch.float64] if fp64_is_supported else FLOAT_DTYPES
150152
INT_DTYPES = [torch.int16, torch.int32]
151153
ALL_INT_DTYPES = INT_DTYPES + [torch.int64]
152154
BOOL_TYPES = [torch.bool]

tests/test_special_ops.py

+5
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,11 @@ def test_accuracy_kron(shape, dtype):
10751075
inp1 = torch.randint(0, 2, size=shape[0], dtype=dtype, device=flag_gems.device)
10761076
inp2 = torch.randint(0, 2, size=shape[1], dtype=dtype, device=flag_gems.device)
10771077

1078+
if flag_gems.vendor_name == "kunlunxin" and dtype == torch.bfloat16:
1079+
# Pytorch 2.0.1 Bfloat16 CPU Backend Precision Failed
1080+
inp1 = torch.randn(shape[0], dtype=torch.float32, device=flag_gems.device)
1081+
inp2 = torch.randn(shape[1], dtype=torch.float32, device=flag_gems.device)
1082+
10781083
ref_inp1 = to_reference(inp1)
10791084
ref_inp2 = to_reference(inp2)
10801085

0 commit comments

Comments
 (0)