Commit 2eeabf4

add limits to get max/min for integer types and inf/-inf for floating-point types; fix max/min/argmax/argmin/cummin
1 parent 06a1845 commit 2eeabf4

7 files changed: +126 −70 lines
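The diffs below import get_dtype_min / get_dtype_max from a new ..utils.limits module. That module is among the 7 changed files but is not shown on this page, so what follows is only a plausible sketch: it assumes the helpers are @triton.jit functions (they are called from inside kernels with a compile-time element type) that hardcode integer limits and return inf/-inf for floating-point types, per the commit message. Names and structure beyond the two function names are hypothetical.

    import triton
    import triton.language as tl


    @triton.jit
    def get_dtype_min(dtype):
        # `dtype` arrives as a compile-time element type such as
        # inp.type.element_ty, so these branches fold away at compile time.
        if dtype is tl.int8:
            return -128
        if dtype is tl.int16:
            return -32768
        if dtype is tl.int32:
            return -2147483648
        if dtype is tl.int64:
            return -9223372036854775808
        # floating-point types (incl. float16/bfloat16) use -inf
        return -float("inf")


    @triton.jit
    def get_dtype_max(dtype):
        if dtype is tl.int8:
            return 127
        if dtype is tl.int16:
            return 32767
        if dtype is tl.int32:
            return 2147483647
        if dtype is tl.int64:
            return 9223372036854775807
        return float("inf")

Whatever the exact implementation, the point of the commit is that the fill value used for masked loads now matches the input's element type instead of always being a float32 -inf/inf.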

src/flag_gems/ops/argmax.py (+10 −4)

@@ -9,6 +9,7 @@
 from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle
+from ..utils.limits import get_dtype_min


 @libentry()
@@ -24,7 +25,8 @@ def argmax_kernel_1(
     offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     inp_ptrs = inp + offset
     mask = offset < M
-    inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
+    min_value = get_dtype_min(inp.type.element_ty)
+    inp_val = tl.load(inp_ptrs, mask=mask, other=min_value)
     max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
     max_index = max_index + pid * BLOCK_SIZE
     mid_value_ptr = mid_value + pid
@@ -39,7 +41,8 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr
     offset = tl.arange(0, BLOCK_MID)
     mid_ptrs = mid_value + offset
     mask = offset < mid_size
-    mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf"))
+    min_value = get_dtype_min(mid_value.type.element_ty)
+    mid_val = tl.load(mid_ptrs, mask=mask, other=min_value)
     index_val = tl.argmax(mid_val, axis=0)
     mid_index_ptrs = mid_index + index_val
     out_val = tl.load(mid_index_ptrs)
@@ -63,14 +66,17 @@ def argmax_kernel(
     pid_k = tle.program_id(1)
     m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

-    max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf"))
+    dtype = inp.type.element_ty
+    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
+    min_value = get_dtype_min(dtype)
+    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
     argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
     for start_n in range(0, N, BLOCK_N):
         n_offset = start_n + tl.arange(0, BLOCK_N)
         offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
         mask = m_offset[:, None] < M and n_offset[None, :] < N
         inp_ptrs = inp + offset
-        inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
+        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
         local_max, local_argmax = tl.max(
             inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
         )
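Note what the last hunk removes: the old kernel kept its running maximum in tl.float32 and cast every loaded block with .to(tl.float32), which is lossy for wide integer types; the new code accumulates in the input dtype (promoting only bfloat16). A host-side illustration of the hazard, not taken from the diff:

    import torch

    x = torch.tensor([2**53, 2**53 + 1], dtype=torch.int64)
    print(x.max().item())          # 9007199254740993 -- exact in int64
    print(x.float().max().item())  # 9007199254740992.0 -- the +1 is lost in float32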

src/flag_gems/ops/argmin.py (+12 −27)

@@ -9,14 +9,7 @@
 from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle
-
-torch_dtype_to_tl_dtype_and_max_value = {
-    torch.int16: (tl.int16, torch.iinfo(torch.int16).max),
-    torch.int32: (tl.int32, torch.iinfo(torch.int32).max),
-    torch.float16: (tl.float16, torch.finfo(torch.float16).max),
-    torch.float32: (tl.float32, torch.finfo(torch.float32).max),
-    torch.bfloat16: (tl.float32, torch.finfo(torch.float32).max),
-}
+from ..utils.limits import get_dtype_max


 @libentry()
@@ -27,13 +20,14 @@ def argmin_kernel_1(
     mid_index,
     M,
     BLOCK_SIZE: tl.constexpr,
-    dtype_max_value: tl.constexpr,
 ):
     pid = tle.program_id(0)
     offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     inp_ptrs = inp + offset
     mask = offset < M
-    inp_val = tl.load(inp_ptrs, mask=mask, other=dtype_max_value)
+
+    max_value = get_dtype_max(inp.type.element_ty)
+    inp_val = tl.load(inp_ptrs, mask=mask, other=max_value)
     min_val, min_index = tl.min(inp_val, axis=0, return_indices=True)
     min_index = min_index + pid * BLOCK_SIZE
     mid_value_ptr = mid_value + pid
@@ -50,12 +44,12 @@ def argmin_kernel_2(
     out,
     mid_size,
     BLOCK_MID: tl.constexpr,
-    dtype_max_value: tl.constexpr,
 ):
     offset = tl.arange(0, BLOCK_MID)
     mid_ptrs = mid_value + offset
     mask = offset < mid_size
-    mid_val = tl.load(mid_ptrs, mask=mask, other=dtype_max_value)
+    max_value = get_dtype_max(mid_value.type.element_ty)
+    mid_val = tl.load(mid_ptrs, mask=mask, other=max_value)
     index_val = tl.argmin(mid_val, axis=0)
     mid_index_ptrs = mid_index + index_val
     out_val = tl.load(mid_index_ptrs)
@@ -75,8 +69,6 @@ def argmin_kernel(
     M,
     N,
     K,
-    tl_dtype: tl.constexpr,
-    dtype_max_value: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -85,18 +77,18 @@ def argmin_kernel(
     pid_k = tle.program_id(1)
     m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

-    # min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
-    if tl_dtype is tl.int16:
-        tl_dtype = tl.int32
-    min_values = tl.full([BLOCK_M], dtype=tl_dtype, value=dtype_max_value)
+    dtype = inp.type.element_ty
+    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
+    max_value = get_dtype_max(dtype)
+    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
     argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
     for start_n in range(0, N, BLOCK_N):
         n_offset = start_n + tl.arange(0, BLOCK_N)
         offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
         mask = m_offset[:, None] < M and n_offset[None, :] < N
         inp_ptrs = inp + offset
-        # inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
-        inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_max_value)
+        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
+        # tl.bfloat16 is promoted to tl.float32 by tl.min
         local_min, local_argmin = tl.min(
             inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
         )
@@ -132,23 +124,20 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
         else:
             out = torch.empty([], dtype=torch.int64, device=inp.device)

-        tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]
         with torch_device_fn.device(inp.device):
             argmin_kernel_1[(mid_size, 1, 1)](
                 inp,
                 mid_value,
                 mid_index,
                 M,
                 block_size,
-                dtype_max_value,
             )
             argmin_kernel_2[(1, 1, 1)](
                 mid_value,
                 mid_index,
                 out,
                 mid_size,
                 block_mid,
-                dtype_max_value,
             )
         return out
     else:
@@ -167,8 +156,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
         if not keepdim:
             out_index = torch.squeeze(out_index, dim)

-        tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]
-
         grid = lambda meta: (
             triton.cdiv(M, meta["BLOCK_M"]),
             K,
@@ -180,8 +167,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
             M,
             N,
             K,
-            tl_dtype,
-            dtype_max_value,
         )

     return out_index
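The one dtype that still changes representation is bfloat16: as the new in-kernel comment notes, tl.min promotes bf16 operands to float32, so the accumulator (acc_type) must be tl.float32 for bf16 inputs. This is safe because every bf16 value is exactly representable in float32, while bf16 itself carries only 8 significand bits. A small host-side illustration of those dtype properties (not code from the diff):

    import torch

    # 1 + 2**-9 is below bfloat16's resolution near 1.0 and rounds back to 1.0 ...
    print(torch.tensor(1.0 + 2**-9).to(torch.bfloat16).item())      # 1.0
    # ... while the float32 round-trip of any bfloat16 value is exact:
    print(torch.tensor(1.5, dtype=torch.bfloat16).float().item())   # 1.5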

src/flag_gems/ops/cummin.py (+5 −2)

@@ -8,6 +8,7 @@
 from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle
+from ..utils.limits import get_dtype_max


 @triton.jit
@@ -76,8 +77,9 @@ def scan_part_min_kernel(
     offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = offset < n_elements

+    max_value = get_dtype_max(inp.type.element_ty)
     inp_ptrs = inp + offset
-    inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
+    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
     if (
         tl.constexpr(inp_vals.dtype.is_int64())
         or tl.constexpr(inp_vals.dtype.is_uint64())
@@ -169,7 +171,8 @@ def scan_part_min_abc_kernel(

     mask = b_idx < B
     inp_ptrs = inp + offset
-    inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
+    max_value = get_dtype_max(inp.type.element_ty)
+    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
     if (
         tl.constexpr(inp_vals.dtype.is_int64())
         or tl.constexpr(inp_vals.dtype.is_uint64())
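The fill value for masked-off lanes in a min-scan must be the identity element of min, i.e. the dtype's maximum: min(x, MAX) == x for every representable x, whereas float("inf") does not even exist in integer dtypes. A quick host-side check of the identity property (illustrative):

    import torch

    x = torch.tensor([5, 3, 9], dtype=torch.int32)
    pad = torch.full((2,), torch.iinfo(torch.int32).max, dtype=torch.int32)
    # Padding with the dtype max leaves the running minimum untouched:
    print(torch.cummin(x, 0).values)                    # tensor([5, 3, 3], ...)
    print(torch.cummin(torch.cat([x, pad]), 0).values)  # tensor([5, 3, 3, 3, 3], ...)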

src/flag_gems/ops/max.py (+15 −19)

@@ -10,6 +10,7 @@
 from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle
+from ..utils.limits import get_dtype_min


 @libentry()
@@ -18,26 +19,27 @@ def max_kernel_1(
     inp,
     mid,
     M,
-    DTYPE_MIN,
     BLOCK_SIZE: tl.constexpr,
 ):
     pid = tle.program_id(0)
     offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     inp_ptrs = inp + offset
     mask = offset < M
-    inp_val = tl.load(inp_ptrs, mask=mask, other=DTYPE_MIN)
+    min_value = get_dtype_min(inp.type.element_ty)
+    inp_val = tl.load(inp_ptrs, mask=mask, other=min_value)
     max_val = tl.max(inp_val)
     mid_ptr = mid + pid
     tl.store(mid_ptr, max_val)


 @libentry()
 @triton.jit
-def max_kernel_2(mid, out, mid_size, DTYPE_MIN, BLOCK_MID: tl.constexpr):
+def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
     offset = tl.arange(0, BLOCK_MID)
     mid_ptrs = mid + offset
     mask = offset < mid_size
-    mid_val = tl.load(mid_ptrs, mask=mask, other=DTYPE_MIN)
+    min_value = get_dtype_min(mid.type.element_ty)
+    mid_val = tl.load(mid_ptrs, mask=mask, other=min_value)
     max_val = tl.max(mid_val)
     tl.store(out, max_val)

@@ -62,23 +64,26 @@ def max_kernel(
     M,
     N,
     K,
-    DTYPE_MIN,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
     # set offset
     pid_m = tle.program_id(0)
     pid_k = tle.program_id(1)
     m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    result_value = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
+
+    dtype = inp.type.element_ty
+    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
+    min_value = get_dtype_min(dtype)
+    result_value = tl.full([BLOCK_M], value=min_value, dtype=acc_type)
     result_index = tl.zeros([BLOCK_M], dtype=tl.int64)
     for i in range(0, N, BLOCK_N):
         n_offset = i + tl.arange(0, BLOCK_N)
         offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
         # set mask
         mask = m_offset[:, None] < M and n_offset[None, :] < N
         inp_ptrs = inp + offset
-        inp_vals = tl.load(inp_ptrs, mask=mask, other=DTYPE_MIN)
+        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
         max_value, max_index = tl.max(inp_vals, axis=1, return_indices=True)
         update_mask = max_value > result_value
         result_value = tl.where(update_mask, max_value, result_value)
@@ -104,13 +109,9 @@ def max(inp):
     mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
     out = torch.empty([], dtype=dtype, device=inp.device)

-    if torch.is_floating_point(inp):
-        dtype_min = torch.finfo(inp.dtype).min
-    else:
-        dtype_min = torch.iinfo(inp.dtype).min
     with torch_device_fn.device(inp.device):
-        max_kernel_1[(mid_size, 1, 1)](inp, mid, M, dtype_min, block_size)
-        max_kernel_2[(1, 1, 1)](mid, out, mid_size, dtype_min, block_mid)
+        max_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
+        max_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
     return out


@@ -134,17 +135,12 @@ def max_dim(inp, dim=None, keepdim=False):
         out_value = torch.squeeze(out_value, dim)
         out_index = torch.squeeze(out_index, dim)

-    if torch.is_floating_point(inp):
-        dtype_min = torch.finfo(inp.dtype).min
-    else:
-        dtype_min = torch.iinfo(inp.dtype).min
-
     grid = lambda meta: (
         triton.cdiv(M, meta["BLOCK_M"]),
         K,
     )
     with torch_device_fn.device(inp.device):
-        max_kernel[grid](inp, out_value, out_index, M, N, K, dtype_min)
+        max_kernel[grid](inp, out_value, out_index, M, N, K)
     Max_out = namedtuple("max", ["values", "indices"])
     out = Max_out(values=out_value, indices=out_index)
     return out
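With the float sentinels gone, the integer paths of max/max_dim stay entirely in the input dtype. A sanity check of the behavior these fixes target (hypothetical usage: assumes a CUDA device and that flag_gems.enable() patches the corresponding aten ops, as in the project's usage instructions):

    import torch
    import flag_gems

    flag_gems.enable()  # route supported aten ops to the Triton kernels

    x = torch.tensor([[-7, -2], [-9, -4]], dtype=torch.int32, device="cuda")
    # With a float32 -inf sentinel these could miscompute on integer inputs;
    # the dtype-aware minimum keeps the whole reduction in int32.
    print(torch.max(x))                # tensor(-2, ..., dtype=torch.int32)
    print(torch.max(x, dim=1).values)  # tensor([-2, -4], ..., dtype=torch.int32)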
