From e053ae0b58635be744134340645423a10ee2a6d7 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Mon, 17 Mar 2025 17:57:12 +0800 Subject: [PATCH 1/8] Fix max other value --- src/flag_gems/ops/max.py | 28 +++++++++++++++++++-------- tests/test_general_reduction_ops.py | 30 +++++++++++++++++++++-------- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index ba0b25d5f..34f9a7ed8 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -18,25 +18,27 @@ def max_kernel_1( inp, mid, M, + DTYPE_MIN, BLOCK_SIZE: tl.constexpr, ): pid = tle.program_id(0) offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf")) + inp_val = tl.load(inp_ptrs, mask=mask, other=DTYPE_MIN) max_val = tl.max(inp_val) + # tl.device_print("max_val: ", max_val) mid_ptr = mid + pid tl.store(mid_ptr, max_val) @libentry() @triton.jit -def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): +def max_kernel_2(mid, out, mid_size, DTYPE_MIN, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf")) + mid_val = tl.load(mid_ptrs, mask=mask, other=DTYPE_MIN) max_val = tl.max(mid_val) tl.store(out, max_val) @@ -61,6 +63,7 @@ def max_kernel( M, N, K, + DTYPE_MIN, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -68,7 +71,7 @@ def max_kernel( pid_m = tle.program_id(0) pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - result_value = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) + result_value = tl.full([BLOCK_M], value=DTYPE_MIN, dtype=tl.float32) result_index = tl.zeros([BLOCK_M], dtype=tl.int64) for i in range(0, N, BLOCK_N): n_offset = i + tl.arange(0, BLOCK_N) @@ -76,7 +79,7 @@ def max_kernel( # set mask mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf")) + inp_vals = tl.load(inp_ptrs, mask=mask, other=DTYPE_MIN) max_value, max_index = tl.max(inp_vals, axis=1, return_indices=True) update_mask = max_value > result_value result_value = tl.where(update_mask, max_value, result_value) @@ -102,9 +105,13 @@ def max(inp): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) + if torch.is_floating_point(inp): + dtype_min = torch.finfo(inp.dtype).min + else: + dtype_min = torch.iinfo(inp.dtype).min with torch_device_fn.device(inp.device): - max_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size) - max_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid) + max_kernel_1[(mid_size, 1, 1)](inp, mid, M, dtype_min, block_size) + max_kernel_2[(1, 1, 1)](mid, out, mid_size, dtype_min, block_mid) return out @@ -128,12 +135,17 @@ def max_dim(inp, dim=None, keepdim=False): out_value = torch.squeeze(out_value, dim) out_index = torch.squeeze(out_index, dim) + if torch.is_floating_point(inp): + dtype_min = torch.finfo(inp.dtype).min + else: + dtype_min = torch.iinfo(inp.dtype).min + grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_M"]), K, ) with torch_device_fn.device(inp.device): - max_kernel[grid](inp, out_value, out_index, M, N, K) + max_kernel[grid](inp, out_value, out_index, M, N, K, dtype_min) Max_out = namedtuple("max", ["values", "indices"]) out = Max_out(values=out_value, indices=out_index) return out diff --git a/tests/test_general_reduction_ops.py 
b/tests/test_general_reduction_ops.py index 9fa5c9070..af4020789 100644 --- a/tests/test_general_reduction_ops.py +++ b/tests/test_general_reduction_ops.py @@ -112,9 +112,12 @@ def test_accuracy_any_dims(shape, dim, keepdim, dtype, kind): @pytest.mark.max @pytest.mark.parametrize("shape", REDUCTION_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_max_without_dim(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.max(ref_inp) @@ -144,9 +147,14 @@ def test_accuracy_max_int(shape, dtype): @pytest.mark.max @pytest.mark.parametrize("shape", REDUCTION_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_max_without_dim_uncontiguous(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)[::2, ::2] + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)[::2, ::2] + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device)[ + ::2, ::2 + ] ref_inp = to_reference(inp) ref_out = torch.max(ref_inp) @@ -160,9 +168,12 @@ def test_accuracy_max_without_dim_uncontiguous(shape, dtype): @pytest.mark.max @pytest.mark.parametrize("shape", REDUCTION_SMALL_SHAPES) @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_max_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.max(ref_inp, dim=dim, keepdim=keepdim) @@ -176,9 +187,12 @@ def test_accuracy_max_dim(shape, dim, keepdim, dtype): @pytest.mark.max @pytest.mark.parametrize("shape", [(4, 1048577, 4)]) @pytest.mark.parametrize("keepdim, dim", [(True, 1), (False, 1)]) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_max_dim_big_shape(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.max(ref_inp, dim=dim, keepdim=keepdim) From 6b8120d8204ab3f34db49a88fd952370246574ba Mon Sep 17 00:00:00 2001 From: 0x45f Date: Mon, 17 Mar 2025 17:59:59 +0800 Subject: [PATCH 2/8] Format code --- src/flag_gems/ops/max.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index 34f9a7ed8..5695bf3fa 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -27,7 +27,6 @@ def max_kernel_1( mask = offset < M inp_val = tl.load(inp_ptrs, mask=mask, other=DTYPE_MIN) max_val = tl.max(inp_val) - # tl.device_print("max_val: ", max_val) mid_ptr = mid + pid tl.store(mid_ptr, max_val) From f5bd9d6fba0bd6508c80059241e54421b751705d 
Mon Sep 17 00:00:00 2001 From: 0x45f Date: Mon, 17 Mar 2025 18:01:45 +0800 Subject: [PATCH 3/8] Fix code --- src/flag_gems/ops/max.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index 5695bf3fa..695a9f3ae 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -70,7 +70,7 @@ def max_kernel( pid_m = tle.program_id(0) pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - result_value = tl.full([BLOCK_M], value=DTYPE_MIN, dtype=tl.float32) + result_value = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) result_index = tl.zeros([BLOCK_M], dtype=tl.int64) for i in range(0, N, BLOCK_N): n_offset = i + tl.arange(0, BLOCK_N) From 1236c980e9b88dabf3a093ac4ff3fb59da3b3aa0 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Tue, 18 Mar 2025 10:03:55 +0800 Subject: [PATCH 4/8] Fix min --- src/flag_gems/ops/min.py | 24 +++++++++++++++++------- tests/test_general_reduction_ops.py | 14 ++++++++++---- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py index 9426853b7..378a60ec9 100644 --- a/src/flag_gems/ops/min.py +++ b/src/flag_gems/ops/min.py @@ -18,13 +18,14 @@ def min_kernel_1( inp, mid, M, + DTYPE_MAX, BLOCK_SIZE: tl.constexpr, ): pid = tle.program_id(0) offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=float("inf")) + inp_val = tl.load(inp_ptrs, mask=mask, other=DTYPE_MAX) min_val = tl.min(inp_val) mid_ptr = mid + pid tl.store(mid_ptr, min_val) @@ -32,11 +33,11 @@ def min_kernel_1( @libentry() @triton.jit -def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): +def min_kernel_2(mid, out, mid_size, DTYPE_MAX, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=float("inf")) + mid_val = tl.load(mid_ptrs, mask=mask, other=DTYPE_MAX) min_val = tl.min(mid_val) tl.store(out, min_val) @@ -61,6 +62,7 @@ def min_kernel( M, N, K, + DTYPE_MAX, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -76,7 +78,7 @@ def min_kernel( offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf")) + inp_vals = tl.load(inp_ptrs, mask=mask, other=DTYPE_MAX) local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True) # if return indices is not supported, call a tl.argmax in addition # local_argmin = tl.argmin(inp_vals, 1) @@ -103,9 +105,13 @@ def min(inp): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) + if torch.is_floating_point(inp): + dtype_max = torch.finfo(inp.dtype).max + else: + dtype_max = torch.iinfo(inp.dtype).max with torch_device_fn.device(inp.device): - min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size) - min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid) + min_kernel_1[(mid_size, 1, 1)](inp, mid, M, dtype_max, block_size) + min_kernel_2[(1, 1, 1)](mid, out, mid_size, dtype_max, block_mid) return out @@ -133,8 +139,12 @@ def min_dim(inp, dim=None, keepdim=False): triton.cdiv(M, meta["BLOCK_M"]), K, ) + if torch.is_floating_point(inp): + dtype_max = torch.finfo(inp.dtype).max + else: + dtype_max = torch.iinfo(inp.dtype).max with torch_device_fn.device(inp.device): - min_kernel[grid](inp, out_value, 
out_index, M, N, K) + min_kernel[grid](inp, out_value, out_index, M, N, K, dtype_max) Min_out = namedtuple("min", ["values", "indices"]) out = Min_out(values=out_value, indices=out_index) return out diff --git a/tests/test_general_reduction_ops.py b/tests/test_general_reduction_ops.py index af4020789..2a5e2d63e 100644 --- a/tests/test_general_reduction_ops.py +++ b/tests/test_general_reduction_ops.py @@ -234,9 +234,12 @@ def test_accuracy_mean_dim(shape, dim, keepdim, dtype): @pytest.mark.min @pytest.mark.parametrize("shape", REDUCTION_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_min_without_dim(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.min(ref_inp) @@ -250,9 +253,12 @@ def test_accuracy_min_without_dim(shape, dtype): @pytest.mark.min @pytest.mark.parametrize("shape", REDUCTION_SMALL_SHAPES) @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_min_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.min(ref_inp, dim=dim, keepdim=keepdim) From 13ccb4747d9ed87bb0a33c4fc2d4780be0c5e7df Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Tue, 18 Mar 2025 22:48:44 +0800 Subject: [PATCH 5/8] add limits to get max/min for integer types and inf/-inf for floating point types, fix max/min/argmax/argmin/cummin --- src/flag_gems/ops/argmax.py | 14 ++++++++--- src/flag_gems/ops/argmin.py | 39 +++++++++-------------------- src/flag_gems/ops/cummin.py | 7 ++++-- src/flag_gems/ops/max.py | 34 +++++++++++-------------- src/flag_gems/ops/min.py | 33 +++++++++++------------- src/flag_gems/utils/limits.py | 36 ++++++++++++++++++++++++++ tests/test_general_reduction_ops.py | 33 ++++++++++++++++++++++++ 7 files changed, 126 insertions(+), 70 deletions(-) create mode 100644 src/flag_gems/utils/limits.py diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py index 86e2b6bd8..01bb00b4e 100644 --- a/src/flag_gems/ops/argmax.py +++ b/src/flag_gems/ops/argmax.py @@ -9,6 +9,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_min @libentry() @@ -24,7 +25,8 @@ def argmax_kernel_1( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=min_value) max_val, max_index = tl.max(inp_val, axis=0, return_indices=True) max_index = max_index + pid * BLOCK_SIZE mid_value_ptr = mid_value + pid @@ -39,7 +41,8 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid_value + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf")) + 
min_value = get_dtype_min(mid_value.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=min_value) index_val = tl.argmax(mid_val, axis=0) mid_index_ptrs = mid_index + index_val out_val = tl.load(mid_index_ptrs) @@ -63,14 +66,17 @@ def argmax_kernel( pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf")) + dtype = inp.type.element_ty + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + min_value = get_dtype_min(dtype) + max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value) argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) for start_n in range(0, N, BLOCK_N): n_offset = start_n + tl.arange(0, BLOCK_N) offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf")).to(tl.float32) + inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value) local_max, local_argmax = tl.max( inp_vals, 1, return_indices=True, return_indices_tie_break_left=True ) diff --git a/src/flag_gems/ops/argmin.py b/src/flag_gems/ops/argmin.py index 2136c6cf2..566d0fbe5 100644 --- a/src/flag_gems/ops/argmin.py +++ b/src/flag_gems/ops/argmin.py @@ -9,14 +9,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle - -torch_dtype_to_tl_dtype_and_max_value = { - torch.int16: (tl.int16, torch.iinfo(torch.int16).max), - torch.int32: (tl.int32, torch.iinfo(torch.int32).max), - torch.float16: (tl.float16, torch.finfo(torch.float16).max), - torch.float32: (tl.float32, torch.finfo(torch.float32).max), - torch.bfloat16: (tl.float32, torch.finfo(torch.float32).max), -} +from ..utils.limits import get_dtype_max @libentry() @@ -27,13 +20,14 @@ def argmin_kernel_1( mid_index, M, BLOCK_SIZE: tl.constexpr, - dtype_max_value: tl.constexpr, ): pid = tle.program_id(0) offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=dtype_max_value) + + max_value = get_dtype_max(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=max_value) min_val, min_index = tl.min(inp_val, axis=0, return_indices=True) min_index = min_index + pid * BLOCK_SIZE mid_value_ptr = mid_value + pid @@ -50,12 +44,12 @@ def argmin_kernel_2( out, mid_size, BLOCK_MID: tl.constexpr, - dtype_max_value: tl.constexpr, ): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid_value + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=dtype_max_value) + max_value = get_dtype_max(mid_value.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=max_value) index_val = tl.argmin(mid_val, axis=0) mid_index_ptrs = mid_index + index_val out_val = tl.load(mid_index_ptrs) @@ -75,8 +69,6 @@ def argmin_kernel( M, N, K, - tl_dtype: tl.constexpr, - dtype_max_value: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -85,18 +77,18 @@ def argmin_kernel( pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - # min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf")) - if tl_dtype is tl.int16: - tl_dtype = tl.int32 - min_values = tl.full([BLOCK_M], dtype=tl_dtype, value=dtype_max_value) + dtype = inp.type.element_ty + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + max_value = get_dtype_max(dtype) + min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value) 
argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) for start_n in range(0, N, BLOCK_N): n_offset = start_n + tl.arange(0, BLOCK_N) offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - # inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf")) - inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_max_value) + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) + # tl.bfloat is promoted to tl.float32 by tl.min local_min, local_argmin = tl.min( inp_vals, 1, return_indices=True, return_indices_tie_break_left=True ) @@ -132,7 +124,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): else: out = torch.empty([], dtype=torch.int64, device=inp.device) - tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype] with torch_device_fn.device(inp.device): argmin_kernel_1[(mid_size, 1, 1)]( inp, @@ -140,7 +131,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): mid_index, M, block_size, - dtype_max_value, ) argmin_kernel_2[(1, 1, 1)]( mid_value, @@ -148,7 +138,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): out, mid_size, block_mid, - dtype_max_value, ) return out else: @@ -167,8 +156,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): if not keepdim: out_index = torch.squeeze(out_index, dim) - tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype] - grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_M"]), K, @@ -180,8 +167,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): M, N, K, - tl_dtype, - dtype_max_value, ) return out_index diff --git a/src/flag_gems/ops/cummin.py b/src/flag_gems/ops/cummin.py index bb6472f5a..c2858b253 100644 --- a/src/flag_gems/ops/cummin.py +++ b/src/flag_gems/ops/cummin.py @@ -8,6 +8,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_max @triton.jit @@ -76,8 +77,9 @@ def scan_part_min_kernel( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offset < n_elements + max_value = get_dtype_max(inp.type.element_ty) inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf")) + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) if ( tl.constexpr(inp_vals.dtype.is_int64()) or tl.constexpr(inp_vals.dtype.is_uint64()) @@ -169,7 +171,8 @@ def scan_part_min_abc_kernel( mask = b_idx < B inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf")) + max_value = get_dtype_max(inp.type.element_ty) + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) if ( tl.constexpr(inp_vals.dtype.is_int64()) or tl.constexpr(inp_vals.dtype.is_uint64()) diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index 695a9f3ae..7c3f3985d 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -10,6 +10,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_min @libentry() @@ -18,14 +19,14 @@ def max_kernel_1( inp, mid, M, - DTYPE_MIN, BLOCK_SIZE: tl.constexpr, ): pid = tle.program_id(0) offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=DTYPE_MIN) + min_value = get_dtype_min(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=min_value) max_val = tl.max(inp_val) mid_ptr = mid + pid tl.store(mid_ptr, 
max_val) @@ -33,11 +34,12 @@ def max_kernel_1( @libentry() @triton.jit -def max_kernel_2(mid, out, mid_size, DTYPE_MIN, BLOCK_MID: tl.constexpr): +def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=DTYPE_MIN) + min_value = get_dtype_min(mid.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=min_value) max_val = tl.max(mid_val) tl.store(out, max_val) @@ -62,7 +64,6 @@ def max_kernel( M, N, K, - DTYPE_MIN, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -70,7 +71,11 @@ def max_kernel( pid_m = tle.program_id(0) pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - result_value = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) + + dtype = inp.type.element_ty + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + min_value = get_dtype_min(dtype) + result_value = tl.full([BLOCK_M], value=min_value, dtype=acc_type) result_index = tl.zeros([BLOCK_M], dtype=tl.int64) for i in range(0, N, BLOCK_N): n_offset = i + tl.arange(0, BLOCK_N) @@ -78,7 +83,7 @@ def max_kernel( # set mask mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=DTYPE_MIN) + inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value) max_value, max_index = tl.max(inp_vals, axis=1, return_indices=True) update_mask = max_value > result_value result_value = tl.where(update_mask, max_value, result_value) @@ -104,13 +109,9 @@ def max(inp): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) - if torch.is_floating_point(inp): - dtype_min = torch.finfo(inp.dtype).min - else: - dtype_min = torch.iinfo(inp.dtype).min with torch_device_fn.device(inp.device): - max_kernel_1[(mid_size, 1, 1)](inp, mid, M, dtype_min, block_size) - max_kernel_2[(1, 1, 1)](mid, out, mid_size, dtype_min, block_mid) + max_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size) + max_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid) return out @@ -134,17 +135,12 @@ def max_dim(inp, dim=None, keepdim=False): out_value = torch.squeeze(out_value, dim) out_index = torch.squeeze(out_index, dim) - if torch.is_floating_point(inp): - dtype_min = torch.finfo(inp.dtype).min - else: - dtype_min = torch.iinfo(inp.dtype).min - grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_M"]), K, ) with torch_device_fn.device(inp.device): - max_kernel[grid](inp, out_value, out_index, M, N, K, dtype_min) + max_kernel[grid](inp, out_value, out_index, M, N, K) Max_out = namedtuple("max", ["values", "indices"]) out = Max_out(values=out_value, indices=out_index) return out diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py index 378a60ec9..b80fe795f 100644 --- a/src/flag_gems/ops/min.py +++ b/src/flag_gems/ops/min.py @@ -10,6 +10,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_max @libentry() @@ -18,14 +19,14 @@ def min_kernel_1( inp, mid, M, - DTYPE_MAX, BLOCK_SIZE: tl.constexpr, ): pid = tle.program_id(0) offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=DTYPE_MAX) + max_value = get_dtype_max(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=max_value) min_val = tl.min(inp_val) mid_ptr = mid + pid tl.store(mid_ptr, min_val) @@ -33,11 
+34,12 @@ def min_kernel_1( @libentry() @triton.jit -def min_kernel_2(mid, out, mid_size, DTYPE_MAX, BLOCK_MID: tl.constexpr): +def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=DTYPE_MAX) + max_value = get_dtype_max(mid.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=max_value) min_val = tl.min(mid_val) tl.store(out, min_val) @@ -62,7 +64,6 @@ def min_kernel( M, N, K, - DTYPE_MAX, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -71,14 +72,18 @@ def min_kernel( pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf")) + dtype = inp.type.element_ty + # you just cannot create a function that return a tl.dtype in triton lang + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + max_value = get_dtype_max(dtype) + min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value) argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) for start_n in range(0, N, BLOCK_N): n_offset = start_n + tl.arange(0, BLOCK_N) offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=DTYPE_MAX) + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True) # if return indices is not supported, call a tl.argmax in addition # local_argmin = tl.argmin(inp_vals, 1) @@ -105,13 +110,9 @@ def min(inp): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) - if torch.is_floating_point(inp): - dtype_max = torch.finfo(inp.dtype).max - else: - dtype_max = torch.iinfo(inp.dtype).max with torch_device_fn.device(inp.device): - min_kernel_1[(mid_size, 1, 1)](inp, mid, M, dtype_max, block_size) - min_kernel_2[(1, 1, 1)](mid, out, mid_size, dtype_max, block_mid) + min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size) + min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid) return out @@ -139,12 +140,8 @@ def min_dim(inp, dim=None, keepdim=False): triton.cdiv(M, meta["BLOCK_M"]), K, ) - if torch.is_floating_point(inp): - dtype_max = torch.finfo(inp.dtype).max - else: - dtype_max = torch.iinfo(inp.dtype).max with torch_device_fn.device(inp.device): - min_kernel[grid](inp, out_value, out_index, M, N, K, dtype_max) + min_kernel[grid](inp, out_value, out_index, M, N, K) Min_out = namedtuple("min", ["values", "indices"]) out = Min_out(values=out_value, indices=out_index) return out diff --git a/src/flag_gems/utils/limits.py b/src/flag_gems/utils/limits.py new file mode 100644 index 000000000..27a4f9f5c --- /dev/null +++ b/src/flag_gems/utils/limits.py @@ -0,0 +1,36 @@ +import triton +from triton import language as tl + + +@triton.jit +def get_dtype_max(dtype: tl.constexpr): + """get a value which is greater that all other values of that dtype""" + # extract the tl.dtype from tl.constexpr so as to use its methods + dtype_ = dtype.value + if dtype_.is_floating(): + value: tl.constexpr = float("inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2 ** (width - 1) - 1 + return value + if dtype_.is_int_unsigned(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2**width - 1 + return value + + +@triton.jit +def get_dtype_min(dtype): + 
"""get a value which is less that all other values of that dtype""" + dtype_ = dtype.value # tl.dtype + if dtype_.is_floating(): + value: tl.constexpr = float("-inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = -1 * 2 ** (width - 1) + return value + if dtype_.is_int_unsigned(): + value: tl.constexpr = 0 + return value diff --git a/tests/test_general_reduction_ops.py b/tests/test_general_reduction_ops.py index 2a5e2d63e..27b7e40a5 100644 --- a/tests/test_general_reduction_ops.py +++ b/tests/test_general_reduction_ops.py @@ -127,6 +127,22 @@ def test_accuracy_max_without_dim(shape, dtype): gems_assert_equal(res_out, ref_out) +@pytest.mark.max +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_max_without_dim_all_neg_inf(shape, dtype): + inp = torch.full( + shape, fill_value=float("-inf"), dtype=dtype, device=flag_gems.device + ) + ref_inp = to_reference(inp) + + ref_out = torch.max(ref_inp) + with flag_gems.use_gems(): + res_out = torch.max(inp) + + gems_assert_equal(res_out, ref_out) + + # cambricon add @pytest.mark.max @pytest.mark.skipif(flag_gems.vendor_name != "cambricon", reason="cambricon test only") @@ -249,6 +265,23 @@ def test_accuracy_min_without_dim(shape, dtype): gems_assert_equal(res_out, ref_out) +@pytest.mark.min +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_min_without_dim_all_inf(shape, dtype): + # ensure that padding value used in min is inf, not max value + inp = torch.full( + shape, fill_value=float("inf"), dtype=dtype, device=flag_gems.device + ) + ref_inp = to_reference(inp) + + ref_out = torch.min(ref_inp) + with flag_gems.use_gems(): + res_out = torch.min(inp) + + gems_assert_equal(res_out, ref_out) + + # TODO: failed at (200, 40999, 3), while successed at this shape in mean_dim @pytest.mark.min @pytest.mark.parametrize("shape", REDUCTION_SMALL_SHAPES) From ac1587bd998f2fc9472fafb7898e67aa579a9afc Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Thu, 20 Mar 2025 21:28:45 +0800 Subject: [PATCH 6/8] fix amax --- src/flag_gems/ops/amax.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/flag_gems/ops/amax.py b/src/flag_gems/ops/amax.py index bfce1c2a3..3d25ce676 100644 --- a/src/flag_gems/ops/amax.py +++ b/src/flag_gems/ops/amax.py @@ -9,6 +9,7 @@ from ..runtime import torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_min @libentry() @@ -24,7 +25,8 @@ def amax_kernel_1( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=min_value) amax_val = tl.max(inp_val) mid_ptr = mid + pid tl.store(mid_ptr, amax_val) @@ -36,7 +38,8 @@ def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(mid.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=min_value) amax_val = tl.max(mid_val) tl.store(out, amax_val) @@ -52,6 +55,9 @@ def amax_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): + dtype = inp.type.element_ty + min_value = get_dtype_min(dtype) + 
# Map the program id to the row of inp it should compute. pid = tle.program_id(0) rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] @@ -59,13 +65,13 @@ def amax_kernel( out = out + rows row_mask = rows < M - _all = tl.full([BLOCK_M, BLOCK_N], value=-float("inf"), dtype=tl.float32) + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + _all = tl.full([BLOCK_M, BLOCK_N], value=min_value, dtype=acc_type) for off in range(0, N, BLOCK_N): cols = off + tl.arange(0, BLOCK_N)[None, :] col_mask = cols < N mask = row_mask and col_mask - - a = tl.load(inp + cols, mask, other=-float("inf")).to(tl.float32) + a = tl.load(inp + cols, mask, other=min_value) _all = tl.maximum(_all, a) all = tl.max(_all, axis=1)[:, None] tl.store(out, all, row_mask) From 7ef23414527165d61a504d5ccb2c0031421f97f4 Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Fri, 21 Mar 2025 14:00:16 +0800 Subject: [PATCH 7/8] use int64 index for index_put --- src/flag_gems/ops/index_put.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/flag_gems/ops/index_put.py b/src/flag_gems/ops/index_put.py index 5c4ab7b2f..ecc924589 100644 --- a/src/flag_gems/ops/index_put.py +++ b/src/flag_gems/ops/index_put.py @@ -35,6 +35,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("from flag_gems.utils import libentry") code.writeline("from flag_gems import runtime") code.writeline("from flag_gems.utils.shape_utils import volume") + code.writeline("from flag_gems.utils import triton_lang_extension as tle") code.newline() code.newline() @@ -74,8 +75,8 @@ def generate_index_put_kernel( code.writeline("):") with code.indent(): - code.writeline("pid0 = tl.program_id(axis=0)") - code.writeline("pid1 = tl.program_id(axis=1)") + code.writeline("pid0 = tle.program_id(axis=0)") + code.writeline("pid1 = tle.program_id(axis=1)") code.writeline( "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]" ) From c853eb0f77f18e32e3ebf4980fa871da3266049a Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Fri, 21 Mar 2025 14:15:46 +0800 Subject: [PATCH 8/8] upcast inputs for index_put --- tests/test_reduction_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 2a9234b18..0ad9af256 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -1062,9 +1062,9 @@ def test_index_put_acc_true(input_shape, indices_shape, values_shape, dtype): values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False ) - ref_inp = to_reference(inp) + ref_inp = to_reference(inp, upcast=True) ref_indices = [to_reference(index) for index in indices] - ref_values = to_reference(values) + ref_values = to_reference(values, upcast=True) ref_out = torch.index_put(ref_inp, ref_indices, ref_values, accumulate) out = flag_gems.index_put(inp, indices, values, accumulate) gems_assert_close(out, ref_out, dtype)
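The change that runs through the whole series is the padding value handed to masked loads: -float("inf") is only a valid identity for floating-point max, so integer inputs need the dtype's own minimum. Below is a minimal sketch (not part of the patches) of how the new helpers in src/flag_gems/utils/limits.py from PATCH 5/8 are meant to be used inside a reduction kernel; the kernel, tensor names, and the CUDA device are illustrative assumptions.

import torch
import triton
import triton.language as tl

from flag_gems.utils.limits import get_dtype_min  # module added in PATCH 5/8


@triton.jit
def rowwise_max_kernel(inp, out, N, BLOCK_N: tl.constexpr):
    # One program per row; the tail of the row is masked off, so the padding
    # value must act as the identity of max() for the input's element type.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    pad = get_dtype_min(inp.type.element_ty)  # -inf for floats, INT_MIN for ints
    vals = tl.load(inp + row * N + cols, mask=mask, other=pad)
    tl.store(out + row, tl.max(vals, axis=0))


x = torch.randint(-10000, 10000, (4, 1000), dtype=torch.int32, device="cuda")
res = torch.empty(4, dtype=torch.int32, device="cuda")
rowwise_max_kernel[(4,)](x, res, 1000, BLOCK_N=1024)
assert torch.equal(res, x.max(dim=1).values)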
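For floating-point inputs the series also moves between two candidate padding values: PATCH 1/8 used torch.finfo(dtype).min on the host side, while PATCH 5/8 has get_dtype_min return -inf and adds the all -inf / all inf regression tests. A host-side sketch (illustrative values only) of why the finite minimum is not a safe identity for max:

import torch

x = torch.full((7,), float("-inf"), dtype=torch.float16)   # torch.max(x) is -inf

# emulate one out-of-bounds lane filled with each candidate padding value
pad_finfo = torch.tensor([torch.finfo(torch.float16).min], dtype=torch.float16)  # -65504.0
pad_inf = torch.tensor([float("-inf")], dtype=torch.float16)

print(torch.cat([x, pad_finfo]).max())  # tensor(-65504., dtype=torch.float16) -- padding leaks into the result
print(torch.cat([x, pad_inf]).max())    # tensor(-inf, dtype=torch.float16)    -- matches torch.max(x)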
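End to end, the widened test parametrization (FLOAT_DTYPES + ALL_INT_DTYPES) is what exercises the integer path. A standalone check in the same spirit, assuming flag_gems is installed with a supported device; the shape and dtype here are arbitrary:

import torch
import flag_gems

inp = torch.randint(-10000, 10000, (200, 2560), dtype=torch.int32, device=flag_gems.device)
ref = torch.max(inp.cpu())        # eager CPU reference
with flag_gems.use_gems():
    res = torch.max(inp)          # dispatches to the patched max kernels
assert res.item() == ref.item()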
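PATCH 7/8 is independent of the padding fix: per its commit message the generated index_put kernel should index with int64, and the diff swaps tl.program_id for the tle.program_id wrapper, presumably so that offsets such as pid0 * BLOCK_SIZE0 are computed in 64-bit. A host-side illustration (hypothetical sizes) of the wraparound a 32-bit program id would permit:

import torch

pid0 = torch.tensor(70000, dtype=torch.int32)
BLOCK_SIZE0 = torch.tensor(65536, dtype=torch.int32)

print(pid0 * BLOCK_SIZE0)                  # int32 wraps: tensor(292552704, dtype=torch.int32)
print(pid0.to(torch.int64) * BLOCK_SIZE0)  # tensor(4587520000) -- the intended flat offset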