9
9
from ..runtime import torch_device_fn
10
10
from ..utils import libentry
11
11
from ..utils import triton_lang_extension as tle
12
-
13
- torch_dtype_to_tl_dtype_and_max_value = {
14
- torch .int16 : (tl .int16 , torch .iinfo (torch .int16 ).max ),
15
- torch .int32 : (tl .int32 , torch .iinfo (torch .int32 ).max ),
16
- torch .float16 : (tl .float16 , torch .finfo (torch .float16 ).max ),
17
- torch .float32 : (tl .float32 , torch .finfo (torch .float32 ).max ),
18
- torch .bfloat16 : (tl .float32 , torch .finfo (torch .float32 ).max ),
19
- }
12
+ from ..utils .limits import get_dtype_max
20
13
21
14
22
15
@libentry ()
@@ -27,13 +20,14 @@ def argmin_kernel_1(
27
20
mid_index ,
28
21
M ,
29
22
BLOCK_SIZE : tl .constexpr ,
30
- dtype_max_value : tl .constexpr ,
31
23
):
32
24
pid = tle .program_id (0 )
33
25
offset = pid * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
34
26
inp_ptrs = inp + offset
35
27
mask = offset < M
36
- inp_val = tl .load (inp_ptrs , mask = mask , other = dtype_max_value )
28
+
29
+ max_value = get_dtype_max (inp .type .element_ty )
30
+ inp_val = tl .load (inp_ptrs , mask = mask , other = max_value )
37
31
min_val , min_index = tl .min (inp_val , axis = 0 , return_indices = True )
38
32
min_index = min_index + pid * BLOCK_SIZE
39
33
mid_value_ptr = mid_value + pid
@@ -50,12 +44,12 @@ def argmin_kernel_2(
50
44
out ,
51
45
mid_size ,
52
46
BLOCK_MID : tl .constexpr ,
53
- dtype_max_value : tl .constexpr ,
54
47
):
55
48
offset = tl .arange (0 , BLOCK_MID )
56
49
mid_ptrs = mid_value + offset
57
50
mask = offset < mid_size
58
- mid_val = tl .load (mid_ptrs , mask = mask , other = dtype_max_value )
51
+ max_value = get_dtype_max (mid_value .type .element_ty )
52
+ mid_val = tl .load (mid_ptrs , mask = mask , other = max_value )
59
53
index_val = tl .argmin (mid_val , axis = 0 )
60
54
mid_index_ptrs = mid_index + index_val
61
55
out_val = tl .load (mid_index_ptrs )
@@ -75,8 +69,6 @@ def argmin_kernel(
75
69
M ,
76
70
N ,
77
71
K ,
78
- tl_dtype : tl .constexpr ,
79
- dtype_max_value : tl .constexpr ,
80
72
BLOCK_M : tl .constexpr ,
81
73
BLOCK_N : tl .constexpr ,
82
74
):
@@ -85,18 +77,18 @@ def argmin_kernel(
85
77
pid_k = tle .program_id (1 )
86
78
m_offset = pid_m * BLOCK_M + tl .arange (0 , BLOCK_M )
87
79
88
- # min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
89
- if tl_dtype is tl .int16 :
90
- tl_dtype = tl . int32
91
- min_values = tl .full ([BLOCK_M ], dtype = tl_dtype , value = dtype_max_value )
80
+ dtype = inp . type . element_ty
81
+ acc_type = tl . float32 if dtype is tl .bfloat16 else dtype
82
+ max_value = get_dtype_max ( dtype )
83
+ min_values = tl .full ([BLOCK_M ], dtype = acc_type , value = max_value )
92
84
argmin_values = tl .full ([BLOCK_M ], dtype = tl .int64 , value = 0 )
93
85
for start_n in range (0 , N , BLOCK_N ):
94
86
n_offset = start_n + tl .arange (0 , BLOCK_N )
95
87
offset = m_offset [:, None ] * N * K + n_offset [None , :] * K + pid_k
96
88
mask = m_offset [:, None ] < M and n_offset [None , :] < N
97
89
inp_ptrs = inp + offset
98
- # inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf") )
99
- inp_vals = tl .load ( inp_ptrs , mask = mask , other = dtype_max_value )
90
+ inp_vals = tl .load (inp_ptrs , mask = mask , other = max_value )
91
+ # tl.bfloat is promoted to tl.float32 by tl.min
100
92
local_min , local_argmin = tl .min (
101
93
inp_vals , 1 , return_indices = True , return_indices_tie_break_left = True
102
94
)
@@ -132,23 +124,20 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
132
124
else :
133
125
out = torch .empty ([], dtype = torch .int64 , device = inp .device )
134
126
135
- tl_dtype , dtype_max_value = torch_dtype_to_tl_dtype_and_max_value [inp .dtype ]
136
127
with torch_device_fn .device (inp .device ):
137
128
argmin_kernel_1 [(mid_size , 1 , 1 )](
138
129
inp ,
139
130
mid_value ,
140
131
mid_index ,
141
132
M ,
142
133
block_size ,
143
- dtype_max_value ,
144
134
)
145
135
argmin_kernel_2 [(1 , 1 , 1 )](
146
136
mid_value ,
147
137
mid_index ,
148
138
out ,
149
139
mid_size ,
150
140
block_mid ,
151
- dtype_max_value ,
152
141
)
153
142
return out
154
143
else :
@@ -167,8 +156,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
167
156
if not keepdim :
168
157
out_index = torch .squeeze (out_index , dim )
169
158
170
- tl_dtype , dtype_max_value = torch_dtype_to_tl_dtype_and_max_value [inp .dtype ]
171
-
172
159
grid = lambda meta : (
173
160
triton .cdiv (M , meta ["BLOCK_M" ]),
174
161
K ,
@@ -180,8 +167,6 @@ def argmin(inp, dim=None, keepdim=False, *, dtype=None):
180
167
M ,
181
168
N ,
182
169
K ,
183
- tl_dtype ,
184
- dtype_max_value ,
185
170
)
186
171
187
172
return out_index
0 commit comments