diff --git a/src/flag_gems/ops/amax.py b/src/flag_gems/ops/amax.py index bfce1c2a39..3d25ce6767 100644 --- a/src/flag_gems/ops/amax.py +++ b/src/flag_gems/ops/amax.py @@ -9,6 +9,7 @@ from ..runtime import torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_min @libentry() @@ -24,7 +25,8 @@ def amax_kernel_1( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=min_value) amax_val = tl.max(inp_val) mid_ptr = mid + pid tl.store(mid_ptr, amax_val) @@ -36,7 +38,8 @@ def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(mid.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=min_value) amax_val = tl.max(mid_val) tl.store(out, amax_val) @@ -52,6 +55,9 @@ def amax_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): + dtype = inp.type.element_ty + min_value = get_dtype_min(dtype) + # Map the program id to the row of inp it should compute. 
pid = tle.program_id(0) rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] @@ -59,13 +65,13 @@ def amax_kernel( out = out + rows row_mask = rows < M - _all = tl.full([BLOCK_M, BLOCK_N], value=-float("inf"), dtype=tl.float32) + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + _all = tl.full([BLOCK_M, BLOCK_N], value=min_value, dtype=acc_type) for off in range(0, N, BLOCK_N): cols = off + tl.arange(0, BLOCK_N)[None, :] col_mask = cols < N mask = row_mask and col_mask - - a = tl.load(inp + cols, mask, other=-float("inf")).to(tl.float32) + a = tl.load(inp + cols, mask, other=min_value) _all = tl.maximum(_all, a) all = tl.max(_all, axis=1)[:, None] tl.store(out, all, row_mask) diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py index 86e2b6bd80..01bb00b4ed 100644 --- a/src/flag_gems/ops/argmax.py +++ b/src/flag_gems/ops/argmax.py @@ -9,6 +9,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_min @libentry() @@ -24,7 +25,8 @@ def argmax_kernel_1( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=min_value) max_val, max_index = tl.max(inp_val, axis=0, return_indices=True) max_index = max_index + pid * BLOCK_SIZE mid_value_ptr = mid_value + pid @@ -39,7 +41,8 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid_value + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(mid_value.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=min_value) index_val = tl.argmax(mid_val, axis=0) mid_index_ptrs = mid_index + index_val out_val = tl.load(mid_index_ptrs) @@ -63,14 +66,17 @@ def 
argmax_kernel( pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf")) + dtype = inp.type.element_ty + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + min_value = get_dtype_min(dtype) + max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value) argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) for start_n in range(0, N, BLOCK_N): n_offset = start_n + tl.arange(0, BLOCK_N) offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf")).to(tl.float32) + inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value) local_max, local_argmax = tl.max( inp_vals, 1, return_indices=True, return_indices_tie_break_left=True ) diff --git a/src/flag_gems/ops/argmin.py b/src/flag_gems/ops/argmin.py index 2136c6cf24..566d0fbe5c 100644 --- a/src/flag_gems/ops/argmin.py +++ b/src/flag_gems/ops/argmin.py @@ -9,14 +9,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle - -torch_dtype_to_tl_dtype_and_max_value = { - torch.int16: (tl.int16, torch.iinfo(torch.int16).max), - torch.int32: (tl.int32, torch.iinfo(torch.int32).max), - torch.float16: (tl.float16, torch.finfo(torch.float16).max), - torch.float32: (tl.float32, torch.finfo(torch.float32).max), - torch.bfloat16: (tl.float32, torch.finfo(torch.float32).max), -} +from ..utils.limits import get_dtype_max @libentry() @@ -27,13 +20,14 @@ def argmin_kernel_1( mid_index, M, BLOCK_SIZE: tl.constexpr, - dtype_max_value: tl.constexpr, ): pid = tle.program_id(0) offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=dtype_max_value) + + max_value = get_dtype_max(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, 
other=max_value) min_val, min_index = tl.min(inp_val, axis=0, return_indices=True) min_index = min_index + pid * BLOCK_SIZE mid_value_ptr = mid_value + pid @@ -50,12 +44,12 @@ def argmin_kernel_2( out, mid_size, BLOCK_MID: tl.constexpr, - dtype_max_value: tl.constexpr, ): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid_value + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=dtype_max_value) + max_value = get_dtype_max(mid_value.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=max_value) index_val = tl.argmin(mid_val, axis=0) mid_index_ptrs = mid_index + index_val out_val = tl.load(mid_index_ptrs) @@ -75,8 +69,6 @@ def argmin_kernel( M, N, K, - tl_dtype: tl.constexpr, - dtype_max_value: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -85,18 +77,18 @@ def argmin_kernel( pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - # min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf")) - if tl_dtype is tl.int16: - tl_dtype = tl.int32 - min_values = tl.full([BLOCK_M], dtype=tl_dtype, value=dtype_max_value) + dtype = inp.type.element_ty + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + max_value = get_dtype_max(dtype) + min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value) argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) for start_n in range(0, N, BLOCK_N): n_offset = start_n + tl.arange(0, BLOCK_N) offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - # inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf")) - inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_max_value) + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) + # tl.bfloat is promoted to tl.float32 by tl.min local_min, local_argmin = tl.min( inp_vals, 1, return_indices=True, return_indices_tie_break_left=True ) @@ -132,7 +124,6 @@ def argmin(inp, dim=None, keepdim=False, *, 
dtype=None): else: out = torch.empty([], dtype=torch.int64, device=inp.device) - tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype] with torch_device_fn.device(inp.device): argmin_kernel_1[(mid_size, 1, 1)]( inp, @@ -140,7 +131,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): mid_index, M, block_size, - dtype_max_value, ) argmin_kernel_2[(1, 1, 1)]( mid_value, @@ -148,7 +138,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): out, mid_size, block_mid, - dtype_max_value, ) return out else: @@ -167,8 +156,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): if not keepdim: out_index = torch.squeeze(out_index, dim) - tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype] - grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_M"]), K, @@ -180,8 +167,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None): M, N, K, - tl_dtype, - dtype_max_value, ) return out_index diff --git a/src/flag_gems/ops/cummin.py b/src/flag_gems/ops/cummin.py index bb6472f5a3..c2858b253d 100644 --- a/src/flag_gems/ops/cummin.py +++ b/src/flag_gems/ops/cummin.py @@ -8,6 +8,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_max @triton.jit @@ -76,8 +77,9 @@ def scan_part_min_kernel( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offset < n_elements + max_value = get_dtype_max(inp.type.element_ty) inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf")) + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) if ( tl.constexpr(inp_vals.dtype.is_int64()) or tl.constexpr(inp_vals.dtype.is_uint64()) @@ -169,7 +171,8 @@ def scan_part_min_abc_kernel( mask = b_idx < B inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf")) + max_value = get_dtype_max(inp.type.element_ty) + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) if ( 
tl.constexpr(inp_vals.dtype.is_int64()) or tl.constexpr(inp_vals.dtype.is_uint64()) diff --git a/src/flag_gems/ops/index_put.py b/src/flag_gems/ops/index_put.py index 5c4ab7b2f1..ecc9245898 100644 --- a/src/flag_gems/ops/index_put.py +++ b/src/flag_gems/ops/index_put.py @@ -35,6 +35,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("from flag_gems.utils import libentry") code.writeline("from flag_gems import runtime") code.writeline("from flag_gems.utils.shape_utils import volume") + code.writeline("from flag_gems.utils import triton_lang_extension as tle") code.newline() code.newline() @@ -74,8 +75,8 @@ def generate_index_put_kernel( code.writeline("):") with code.indent(): - code.writeline("pid0 = tl.program_id(axis=0)") - code.writeline("pid1 = tl.program_id(axis=1)") + code.writeline("pid0 = tle.program_id(axis=0)") + code.writeline("pid1 = tle.program_id(axis=1)") code.writeline( "offset0 = pid0 * BLOCK_SIZE0 + tl.arange(0, BLOCK_SIZE0)[:, None]" ) diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index ba0b25d5f2..7c3f3985d2 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -10,6 +10,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_min @libentry() @@ -24,7 +25,8 @@ def max_kernel_1( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=min_value) max_val = tl.max(inp_val) mid_ptr = mid + pid tl.store(mid_ptr, max_val) @@ -36,7 +38,8 @@ def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf")) + min_value = get_dtype_min(mid.type.element_ty) + 
mid_val = tl.load(mid_ptrs, mask=mask, other=min_value) max_val = tl.max(mid_val) tl.store(out, max_val) @@ -68,7 +71,11 @@ def max_kernel( pid_m = tle.program_id(0) pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - result_value = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) + + dtype = inp.type.element_ty + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + min_value = get_dtype_min(dtype) + result_value = tl.full([BLOCK_M], value=min_value, dtype=acc_type) result_index = tl.zeros([BLOCK_M], dtype=tl.int64) for i in range(0, N, BLOCK_N): n_offset = i + tl.arange(0, BLOCK_N) @@ -76,7 +83,7 @@ def max_kernel( # set mask mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf")) + inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value) max_value, max_index = tl.max(inp_vals, axis=1, return_indices=True) update_mask = max_value > result_value result_value = tl.where(update_mask, max_value, result_value) diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py index 9426853b7f..b80fe795f7 100644 --- a/src/flag_gems/ops/min.py +++ b/src/flag_gems/ops/min.py @@ -10,6 +10,7 @@ from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle +from ..utils.limits import get_dtype_max @libentry() @@ -24,7 +25,8 @@ def min_kernel_1( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=float("inf")) + max_value = get_dtype_max(inp.type.element_ty) + inp_val = tl.load(inp_ptrs, mask=mask, other=max_value) min_val = tl.min(inp_val) mid_ptr = mid + pid tl.store(mid_ptr, min_val) @@ -36,7 +38,8 @@ def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=float("inf")) + 
max_value = get_dtype_max(mid.type.element_ty) + mid_val = tl.load(mid_ptrs, mask=mask, other=max_value) min_val = tl.min(mid_val) tl.store(out, min_val) @@ -69,14 +72,18 @@ def min_kernel( pid_k = tle.program_id(1) m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf")) + dtype = inp.type.element_ty + # you just cannot create a function that return a tl.dtype in triton lang + acc_type = tl.float32 if dtype is tl.bfloat16 else dtype + max_value = get_dtype_max(dtype) + min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value) argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0) for start_n in range(0, N, BLOCK_N): n_offset = start_n + tl.arange(0, BLOCK_N) offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf")) + inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value) local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True) # if return indices is not supported, call a tl.argmax in addition # local_argmin = tl.argmin(inp_vals, 1) diff --git a/src/flag_gems/utils/limits.py b/src/flag_gems/utils/limits.py new file mode 100644 index 0000000000..27a4f9f5cd --- /dev/null +++ b/src/flag_gems/utils/limits.py @@ -0,0 +1,36 @@ +import triton +from triton import language as tl + + +@triton.jit +def get_dtype_max(dtype: tl.constexpr): + """get a value which is greater that all other values of that dtype""" + # extract the tl.dtype from tl.constexpr so as to use its methods + dtype_ = dtype.value + if dtype_.is_floating(): + value: tl.constexpr = float("inf") + return value + if dtype_.is_int_signed(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2 ** (width - 1) - 1 + return value + if dtype_.is_int_unsigned(): + width: tl.constexpr = dtype_.int_bitwidth + value: tl.constexpr = 2**width - 1 + return value + + 
+@triton.jit +def get_dtype_min(dtype): +    """get a value which is less than all other values of that dtype""" +    dtype_ = dtype.value  # tl.dtype +    if dtype_.is_floating(): +        value: tl.constexpr = float("-inf") +        return value +    if dtype_.is_int_signed(): +        width: tl.constexpr = dtype_.int_bitwidth +        value: tl.constexpr = -1 * 2 ** (width - 1) +        return value +    if dtype_.is_int_unsigned(): +        value: tl.constexpr = 0 +        return value diff --git a/tests/test_general_reduction_ops.py b/tests/test_general_reduction_ops.py index 9fa5c9070c..27b7e40a5e 100644 --- a/tests/test_general_reduction_ops.py +++ b/tests/test_general_reduction_ops.py @@ -112,9 +112,28 @@ def test_accuracy_any_dims(shape, dim, keepdim, dtype, kind): @pytest.mark.max @pytest.mark.parametrize("shape", REDUCTION_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_max_without_dim(shape, dtype): -    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) +    if dtype in FLOAT_DTYPES: +        inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) +    else: +        inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) +    ref_inp = to_reference(inp) + +    ref_out = torch.max(ref_inp) +    with flag_gems.use_gems(): +        res_out = torch.max(inp) + +    gems_assert_equal(res_out, ref_out) + + +@pytest.mark.max +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_max_without_dim_all_neg_inf(shape, dtype): +    inp = torch.full( +        shape, fill_value=float("-inf"), dtype=dtype, device=flag_gems.device +    ) ref_inp = to_reference(inp) ref_out = torch.max(ref_inp) @@ -144,9 +163,14 @@ def test_accuracy_max_int(shape, dtype): @pytest.mark.max @pytest.mark.parametrize("shape", REDUCTION_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_max_without_dim_uncontiguous(shape, 
dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)[::2, ::2] + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)[::2, ::2] + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device)[ + ::2, ::2 + ] ref_inp = to_reference(inp) ref_out = torch.max(ref_inp) @@ -160,9 +184,12 @@ def test_accuracy_max_without_dim_uncontiguous(shape, dtype): @pytest.mark.max @pytest.mark.parametrize("shape", REDUCTION_SMALL_SHAPES) @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_max_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.max(ref_inp, dim=dim, keepdim=keepdim) @@ -176,9 +203,12 @@ def test_accuracy_max_dim(shape, dim, keepdim, dtype): @pytest.mark.max @pytest.mark.parametrize("shape", [(4, 1048577, 4)]) @pytest.mark.parametrize("keepdim, dim", [(True, 1), (False, 1)]) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_max_dim_big_shape(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.max(ref_inp, dim=dim, keepdim=keepdim) @@ -220,9 +250,29 @@ def test_accuracy_mean_dim(shape, dim, keepdim, dtype): @pytest.mark.min @pytest.mark.parametrize("shape", REDUCTION_SHAPES) 
-@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_min_without_dim(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) + ref_inp = to_reference(inp) + + ref_out = torch.min(ref_inp) + with flag_gems.use_gems(): + res_out = torch.min(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.min +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_min_without_dim_all_inf(shape, dtype): + # ensure that padding value used in min is inf, not max value + inp = torch.full( + shape, fill_value=float("inf"), dtype=dtype, device=flag_gems.device + ) ref_inp = to_reference(inp) ref_out = torch.min(ref_inp) @@ -236,9 +286,12 @@ def test_accuracy_min_without_dim(shape, dtype): @pytest.mark.min @pytest.mark.parametrize("shape", REDUCTION_SMALL_SHAPES) @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES) def test_accuracy_min_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + if dtype in FLOAT_DTYPES: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + else: + inp = torch.randint(-10000, 10000, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.min(ref_inp, dim=dim, keepdim=keepdim) diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 2a9234b185..0ad9af256f 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -1062,9 +1062,9 @@ def test_index_put_acc_true(input_shape, indices_shape, values_shape, dtype): values_shape, dtype=dtype, 
device=flag_gems.device, requires_grad=False ) - ref_inp = to_reference(inp) + ref_inp = to_reference(inp, upcast=True) ref_indices = [to_reference(index) for index in indices] - ref_values = to_reference(values) + ref_values = to_reference(values, upcast=True) ref_out = torch.index_put(ref_inp, ref_indices, ref_values, accumulate) out = flag_gems.index_put(inp, indices, values, accumulate) gems_assert_close(out, ref_out, dtype)