@@ -27,12 +27,13 @@ def argmin_kernel_1(
27
27
mid_index ,
28
28
M ,
29
29
BLOCK_SIZE : tl .constexpr ,
30
+ dtype_max_value : tl .constexpr ,
30
31
):
31
32
pid = tle .program_id (0 )
32
33
offset = pid * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
33
34
inp_ptrs = inp + offset
34
35
mask = offset < M
35
- inp_val = tl .load (inp_ptrs , mask = mask , other = float ( "inf" ) )
36
+ inp_val = tl .load (inp_ptrs , mask = mask , other = dtype_max_value )
36
37
min_val , min_index = tl .min (inp_val , axis = 0 , return_indices = True )
37
38
min_index = min_index + pid * BLOCK_SIZE
38
39
mid_value_ptr = mid_value + pid
@@ -43,11 +44,18 @@ def argmin_kernel_1(
43
44
44
45
@libentry ()
45
46
@triton .jit
46
- def argmin_kernel_2 (mid_value , mid_index , out , mid_size , BLOCK_MID : tl .constexpr ):
47
+ def argmin_kernel_2 (
48
+ mid_value ,
49
+ mid_index ,
50
+ out ,
51
+ mid_size ,
52
+ BLOCK_MID : tl .constexpr ,
53
+ dtype_max_value : tl .constexpr ,
54
+ ):
47
55
offset = tl .arange (0 , BLOCK_MID )
48
56
mid_ptrs = mid_value + offset
49
57
mask = offset < mid_size
50
- mid_val = tl .load (mid_ptrs , mask = mask , other = float ( "inf" ) )
58
+ mid_val = tl .load (mid_ptrs , mask = mask , other = dtype_max_value )
51
59
index_val = tl .argmin (mid_val , axis = 0 )
52
60
mid_index_ptrs = mid_index + index_val
53
61
out_val = tl .load (mid_index_ptrs )
@@ -122,15 +130,24 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
122
130
else :
123
131
out = torch .empty ([], dtype = torch .int64 , device = inp .device )
124
132
133
+ tl_dtype , dtype_max_value = torch_dtype_to_tl_dtype_and_max_value [inp .dtype ]
125
134
with torch_device_fn .device (inp .device ):
126
135
argmin_kernel_1 [(mid_size , 1 , 1 )](
127
136
inp ,
128
137
mid_value ,
129
138
mid_index ,
130
139
M ,
131
140
block_size ,
141
+ dtype_max_value ,
142
+ )
143
+ argmin_kernel_2 [(1 , 1 , 1 )](
144
+ mid_value ,
145
+ mid_index ,
146
+ out ,
147
+ mid_size ,
148
+ block_mid ,
149
+ dtype_max_value ,
132
150
)
133
- argmin_kernel_2 [(1 , 1 , 1 )](mid_value , mid_index , out , mid_size , block_mid )
134
151
return out
135
152
else :
136
153
assert dim >= - inp .ndim and dim < inp .ndim , "Invalid dim"
0 commit comments