Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix max/min other value when int dtype and add int dtype tests #494

Merged
merged 8 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/flag_gems/ops/amax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..runtime import torch_device_fn
from ..utils import dim_compress, libentry
from ..utils import triton_lang_extension as tle
from ..utils.limits import get_dtype_min


@libentry()
Expand All @@ -24,7 +25,8 @@ def amax_kernel_1(
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
inp_ptrs = inp + offset
mask = offset < M
inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
min_value = get_dtype_min(inp.type.element_ty)
inp_val = tl.load(inp_ptrs, mask=mask, other=min_value)
amax_val = tl.max(inp_val)
mid_ptr = mid + pid
tl.store(mid_ptr, amax_val)
Expand All @@ -36,7 +38,8 @@ def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
offset = tl.arange(0, BLOCK_MID)
mid_ptrs = mid + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf"))
min_value = get_dtype_min(mid.type.element_ty)
mid_val = tl.load(mid_ptrs, mask=mask, other=min_value)
amax_val = tl.max(mid_val)
tl.store(out, amax_val)

Expand All @@ -52,20 +55,23 @@ def amax_kernel(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
dtype = inp.type.element_ty
min_value = get_dtype_min(dtype)

# Map the program id to the row of inp it should compute.
pid = tle.program_id(0)
rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
inp = inp + rows * N
out = out + rows
row_mask = rows < M

_all = tl.full([BLOCK_M, BLOCK_N], value=-float("inf"), dtype=tl.float32)
acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
_all = tl.full([BLOCK_M, BLOCK_N], value=min_value, dtype=acc_type)
for off in range(0, N, BLOCK_N):
cols = off + tl.arange(0, BLOCK_N)[None, :]
col_mask = cols < N
mask = row_mask and col_mask

a = tl.load(inp + cols, mask, other=-float("inf")).to(tl.float32)
a = tl.load(inp + cols, mask, other=min_value)
_all = tl.maximum(_all, a)
all = tl.max(_all, axis=1)[:, None]
tl.store(out, all, row_mask)
Expand Down
14 changes: 10 additions & 4 deletions src/flag_gems/ops/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle
from ..utils.limits import get_dtype_min


@libentry()
Expand All @@ -24,7 +25,8 @@ def argmax_kernel_1(
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
inp_ptrs = inp + offset
mask = offset < M
inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
min_value = get_dtype_min(inp.type.element_ty)
inp_val = tl.load(inp_ptrs, mask=mask, other=min_value)
max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
max_index = max_index + pid * BLOCK_SIZE
mid_value_ptr = mid_value + pid
Expand All @@ -39,7 +41,8 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr
offset = tl.arange(0, BLOCK_MID)
mid_ptrs = mid_value + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf"))
min_value = get_dtype_min(mid_value.type.element_ty)
mid_val = tl.load(mid_ptrs, mask=mask, other=min_value)
index_val = tl.argmax(mid_val, axis=0)
mid_index_ptrs = mid_index + index_val
out_val = tl.load(mid_index_ptrs)
Expand All @@ -63,14 +66,17 @@ def argmax_kernel(
pid_k = tle.program_id(1)
m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf"))
dtype = inp.type.element_ty
acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
min_value = get_dtype_min(dtype)
max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
for start_n in range(0, N, BLOCK_N):
n_offset = start_n + tl.arange(0, BLOCK_N)
offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
mask = m_offset[:, None] < M and n_offset[None, :] < N
inp_ptrs = inp + offset
inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
local_max, local_argmax = tl.max(
inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
)
Expand Down
39 changes: 12 additions & 27 deletions src/flag_gems/ops/argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,7 @@
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle

torch_dtype_to_tl_dtype_and_max_value = {
torch.int16: (tl.int16, torch.iinfo(torch.int16).max),
torch.int32: (tl.int32, torch.iinfo(torch.int32).max),
torch.float16: (tl.float16, torch.finfo(torch.float16).max),
torch.float32: (tl.float32, torch.finfo(torch.float32).max),
torch.bfloat16: (tl.float32, torch.finfo(torch.float32).max),
}
from ..utils.limits import get_dtype_max


@libentry()
Expand All @@ -27,13 +20,14 @@ def argmin_kernel_1(
mid_index,
M,
BLOCK_SIZE: tl.constexpr,
dtype_max_value: tl.constexpr,
):
pid = tle.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
inp_ptrs = inp + offset
mask = offset < M
inp_val = tl.load(inp_ptrs, mask=mask, other=dtype_max_value)

max_value = get_dtype_max(inp.type.element_ty)
inp_val = tl.load(inp_ptrs, mask=mask, other=max_value)
min_val, min_index = tl.min(inp_val, axis=0, return_indices=True)
min_index = min_index + pid * BLOCK_SIZE
mid_value_ptr = mid_value + pid
Expand All @@ -50,12 +44,12 @@ def argmin_kernel_2(
out,
mid_size,
BLOCK_MID: tl.constexpr,
dtype_max_value: tl.constexpr,
):
offset = tl.arange(0, BLOCK_MID)
mid_ptrs = mid_value + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=dtype_max_value)
max_value = get_dtype_max(mid_value.type.element_ty)
mid_val = tl.load(mid_ptrs, mask=mask, other=max_value)
index_val = tl.argmin(mid_val, axis=0)
mid_index_ptrs = mid_index + index_val
out_val = tl.load(mid_index_ptrs)
Expand All @@ -75,8 +69,6 @@ def argmin_kernel(
M,
N,
K,
tl_dtype: tl.constexpr,
dtype_max_value: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
Expand All @@ -85,18 +77,18 @@ def argmin_kernel(
pid_k = tle.program_id(1)
m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

# min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
if tl_dtype is tl.int16:
tl_dtype = tl.int32
min_values = tl.full([BLOCK_M], dtype=tl_dtype, value=dtype_max_value)
dtype = inp.type.element_ty
acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
max_value = get_dtype_max(dtype)
min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
for start_n in range(0, N, BLOCK_N):
n_offset = start_n + tl.arange(0, BLOCK_N)
offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
mask = m_offset[:, None] < M and n_offset[None, :] < N
inp_ptrs = inp + offset
# inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_max_value)
inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
# tl.bfloat16 is promoted to tl.float32 by tl.min
local_min, local_argmin = tl.min(
inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
)
Expand Down Expand Up @@ -132,23 +124,20 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
else:
out = torch.empty([], dtype=torch.int64, device=inp.device)

tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]
with torch_device_fn.device(inp.device):
argmin_kernel_1[(mid_size, 1, 1)](
inp,
mid_value,
mid_index,
M,
block_size,
dtype_max_value,
)
argmin_kernel_2[(1, 1, 1)](
mid_value,
mid_index,
out,
mid_size,
block_mid,
dtype_max_value,
)
return out
else:
Expand All @@ -167,8 +156,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
if not keepdim:
out_index = torch.squeeze(out_index, dim)

tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]

grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
K,
Expand All @@ -180,8 +167,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
M,
N,
K,
tl_dtype,
dtype_max_value,
)

return out_index
7 changes: 5 additions & 2 deletions src/flag_gems/ops/cummin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle
from ..utils.limits import get_dtype_max


@triton.jit
Expand Down Expand Up @@ -76,8 +77,9 @@ def scan_part_min_kernel(
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < n_elements

max_value = get_dtype_max(inp.type.element_ty)
inp_ptrs = inp + offset
inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
if (
tl.constexpr(inp_vals.dtype.is_int64())
or tl.constexpr(inp_vals.dtype.is_uint64())
Expand Down Expand Up @@ -169,7 +171,8 @@ def scan_part_min_abc_kernel(

mask = b_idx < B
inp_ptrs = inp + offset
inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
max_value = get_dtype_max(inp.type.element_ty)
inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
if (
tl.constexpr(inp_vals.dtype.is_int64())
or tl.constexpr(inp_vals.dtype.is_uint64())
Expand Down
5 changes: 3 additions & 2 deletions src/flag_gems/ops/index_put.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
code.writeline("from flag_gems.utils import libentry")
code.writeline("from flag_gems import runtime")
code.writeline("from flag_gems.utils.shape_utils import volume")
code.writeline("from flag_gems.utils import triton_lang_extension as tle")

code.newline()
code.newline()
Expand Down Expand Up @@ -74,8 +75,8 @@ def generate_index_put_kernel(
code.writeline("):")

with code.indent():
code.writeline("pid0 = tl.program_id(axis=0)")
code.writeline("pid1 = tl.program_id(axis=1)")
code.writeline("pid0 = tle.program_id(axis=0)")
code.writeline("pid1 = tle.program_id(axis=1)")
code.writeline(
"offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]"
)
Expand Down
15 changes: 11 additions & 4 deletions src/flag_gems/ops/max.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle
from ..utils.limits import get_dtype_min


@libentry()
Expand All @@ -24,7 +25,8 @@ def max_kernel_1(
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
inp_ptrs = inp + offset
mask = offset < M
inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
min_value = get_dtype_min(inp.type.element_ty)
inp_val = tl.load(inp_ptrs, mask=mask, other=min_value)
max_val = tl.max(inp_val)
mid_ptr = mid + pid
tl.store(mid_ptr, max_val)
Expand All @@ -36,7 +38,8 @@ def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
offset = tl.arange(0, BLOCK_MID)
mid_ptrs = mid + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf"))
min_value = get_dtype_min(mid.type.element_ty)
mid_val = tl.load(mid_ptrs, mask=mask, other=min_value)
max_val = tl.max(mid_val)
tl.store(out, max_val)

Expand Down Expand Up @@ -68,15 +71,19 @@ def max_kernel(
pid_m = tle.program_id(0)
pid_k = tle.program_id(1)
m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
result_value = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)

dtype = inp.type.element_ty
acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
min_value = get_dtype_min(dtype)
result_value = tl.full([BLOCK_M], value=min_value, dtype=acc_type)
result_index = tl.zeros([BLOCK_M], dtype=tl.int64)
for i in range(0, N, BLOCK_N):
n_offset = i + tl.arange(0, BLOCK_N)
offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
# set mask
mask = m_offset[:, None] < M and n_offset[None, :] < N
inp_ptrs = inp + offset
inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf"))
inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
max_value, max_index = tl.max(inp_vals, axis=1, return_indices=True)
update_mask = max_value > result_value
result_value = tl.where(update_mask, max_value, result_value)
Expand Down
15 changes: 11 additions & 4 deletions src/flag_gems/ops/min.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle
from ..utils.limits import get_dtype_max


@libentry()
Expand All @@ -24,7 +25,8 @@ def min_kernel_1(
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
inp_ptrs = inp + offset
mask = offset < M
inp_val = tl.load(inp_ptrs, mask=mask, other=float("inf"))
max_value = get_dtype_max(inp.type.element_ty)
inp_val = tl.load(inp_ptrs, mask=mask, other=max_value)
min_val = tl.min(inp_val)
mid_ptr = mid + pid
tl.store(mid_ptr, min_val)
Expand All @@ -36,7 +38,8 @@ def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
offset = tl.arange(0, BLOCK_MID)
mid_ptrs = mid + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=float("inf"))
max_value = get_dtype_max(mid.type.element_ty)
mid_val = tl.load(mid_ptrs, mask=mask, other=max_value)
min_val = tl.min(mid_val)
tl.store(out, min_val)

Expand Down Expand Up @@ -69,14 +72,18 @@ def min_kernel(
pid_k = tle.program_id(1)
m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
dtype = inp.type.element_ty
# Triton does not allow a jit function to return a tl.dtype, so the accumulator type is chosen inline
acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
max_value = get_dtype_max(dtype)
min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
for start_n in range(0, N, BLOCK_N):
n_offset = start_n + tl.arange(0, BLOCK_N)
offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
mask = m_offset[:, None] < M and n_offset[None, :] < N
inp_ptrs = inp + offset
inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True)
# if return_indices is not supported, call tl.argmin in addition
# local_argmin = tl.argmin(inp_vals, 1)
Expand Down
Loading