Skip to content

Commit 328800a

Browse files
mx-flaggems-user
authored and committed
[Operator]Fix metax backend bugs (#432)
1. Update multi-backend code. 2. Fix argmin op, whose tests might fail with int types. Co-authored-by: mx-flaggems-user <[email protected]>
1 parent 89ed8e7 commit 328800a

File tree

5 files changed

+25
-8
lines changed

5 files changed

+25
-8
lines changed

src/flag_gems/ops/argmin.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ def argmin_kernel_1(
2727
mid_index,
2828
M,
2929
BLOCK_SIZE: tl.constexpr,
30+
dtype_max_value: tl.constexpr,
3031
):
3132
pid = tle.program_id(0)
3233
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
3334
inp_ptrs = inp + offset
3435
mask = offset < M
35-
inp_val = tl.load(inp_ptrs, mask=mask, other=float("inf"))
36+
inp_val = tl.load(inp_ptrs, mask=mask, other=dtype_max_value)
3637
min_val, min_index = tl.min(inp_val, axis=0, return_indices=True)
3738
min_index = min_index + pid * BLOCK_SIZE
3839
mid_value_ptr = mid_value + pid
@@ -43,11 +44,18 @@ def argmin_kernel_1(
4344

4445
@libentry()
4546
@triton.jit
46-
def argmin_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
47+
def argmin_kernel_2(
48+
mid_value,
49+
mid_index,
50+
out,
51+
mid_size,
52+
BLOCK_MID: tl.constexpr,
53+
dtype_max_value: tl.constexpr,
54+
):
4755
offset = tl.arange(0, BLOCK_MID)
4856
mid_ptrs = mid_value + offset
4957
mask = offset < mid_size
50-
mid_val = tl.load(mid_ptrs, mask=mask, other=float("inf"))
58+
mid_val = tl.load(mid_ptrs, mask=mask, other=dtype_max_value)
5159
index_val = tl.argmin(mid_val, axis=0)
5260
mid_index_ptrs = mid_index + index_val
5361
out_val = tl.load(mid_index_ptrs)
@@ -122,15 +130,24 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
122130
else:
123131
out = torch.empty([], dtype=torch.int64, device=inp.device)
124132

133+
tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]
125134
with torch_device_fn.device(inp.device):
126135
argmin_kernel_1[(mid_size, 1, 1)](
127136
inp,
128137
mid_value,
129138
mid_index,
130139
M,
131140
block_size,
141+
dtype_max_value,
142+
)
143+
argmin_kernel_2[(1, 1, 1)](
144+
mid_value,
145+
mid_index,
146+
out,
147+
mid_size,
148+
block_mid,
149+
dtype_max_value,
132150
)
133-
argmin_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
134151
return out
135152
else:
136153
assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"

src/flag_gems/runtime/backend/_metax/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
vendor_info = VendorInfoBase(
44
vendor_name="metax", device_name="cuda", device_query_cmd="mx-smi"
55
)
6-
6+
CUSTOMIZED_UNUSED_OPS = ()
77
__all__ = ["vendor_info"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__all__ = []

src/flag_gems/runtime/backend/_metax/ops/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .prod import prod, prod_dim
1313
from .sigmoid import sigmoid
1414
from .tanh import tanh
15-
from .unique import unique
15+
from .unique import _unique2
1616
from .zeros import zeros
1717

1818
__all__ = [
@@ -34,6 +34,6 @@
3434
"prod_dim",
3535
"sigmoid",
3636
"tanh",
37-
"unique",
37+
"_unique2",
3838
"zeros",
3939
]

src/flag_gems/runtime/backend/_metax/tune_configs.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -982,4 +982,3 @@ batch_norm:
982982
- 4
983983
- 8
984984
- 16
985-
- 32

0 commit comments

Comments
 (0)