
Commit dba6c9a

add splitK support
1 parent 7a1c613 commit dba6c9a

File tree

4 files changed: +287 −36 lines changed


README.md

+3 −1
@@ -66,7 +66,9 @@ We implement two versions of the Triton kernels:
 
 * <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py">GEMM</a></b>: This GEMM kernel is implemented similarly to <a href="https://github.com/fpgaminer/GPTQ-triton">GPTQ-triton</a>. Since it uses tensor cores, activations must be padded with zeros along the batch dimension to fit at least 16 rows. It supports both float32 and float16 accumulation for fp16 inputs, but only float32 accumulation for bfloat16.
 
-Both kernels are flexible, supporting 8, 4, 2, and 1-bit weight precisions.
+* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py">Split-K</a></b>: This Split-K GEMM kernel is implemented similarly to <a href="https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py">the GPTQ Split-K version</a>. We build on the GEMM kernel above and add an extra grid dimension that splits the K dimension into multiple jobs, each computing a partial sum; the partial sums are atomically added and then stored. Split-K performs very well for batch sizes between 1 and 32, which makes it a great fit for LLM decoding.
+
+All kernels are flexible, supporting 8, 4, 2, and 1-bit weight precisions.
 
 To achieve optimal performance, it’s crucial to configure the eviction policy correctly. This is especially important in memory-bound scenarios, where we aim to cache activations by setting `eviction_policy="evict_last"`. Float16 accumulation further improves performance in compute-bound scenarios.
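For reference, the Split-K scheme described in the new bullet boils down to the accumulation pattern below. This is a minimal PyTorch sketch for illustration only (dequantization and the actual Triton kernel added in this commit are omitted): each K-slice computes a partial product that is summed into the output, which is the step the kernel performs with `tl.atomic_add`.

```python
import torch

def splitk_matmul_reference(A: torch.Tensor, B: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    # A: (M, K), B: (K, N); K must be divisible by split_k, mirroring the
    # "K divisible by BLOCK_SIZE_K * SPLIT_K" constraint in the kernel's config pruner.
    M, K = A.shape
    N = B.shape[1]
    assert K % split_k == 0
    C = torch.zeros(M, N, dtype=A.dtype, device=A.device)
    chunk = K // split_k
    for pid_k in range(split_k):              # each slice maps to a separate program id in the kernel
        a = A[:, pid_k * chunk:(pid_k + 1) * chunk]
        b = B[pid_k * chunk:(pid_k + 1) * chunk, :]
        C += a @ b                            # the kernel performs this reduction with tl.atomic_add
    return C
```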

gemlite/core.py

+27 −34
@@ -152,37 +152,22 @@ def forward(self, x):
 ###################################################################################################################################
 # Triton backend
 ###################################################################################################################################
-def eval_time(fct, params, warmup=25, rep=200, fast_flush=True, return_mode="min"):
-    if isinstance(params, dict):
-        return do_bench(
-            lambda: fct(**params),
-            warmup=warmup,
-            rep=rep,
-            fast_flush=fast_flush,
-            return_mode=return_mode,
-        )
-    if isinstance(params, list):
-        return do_bench(
-            lambda: fct(*params),
-            warmup=warmup,
-            rep=rep,
-            fast_flush=fast_flush,
-            return_mode=return_mode,
-        )
+def eval_time_for_auto_mode(fct, params):
+    for _ in range(5):
+        _ = fct(*params) #Run first to kick-off Triton autotune
+    return do_bench(lambda: fct(*params), warmup=200, rep=50, fast_flush=True, return_mode='mean')
 
 
 GEMLITE_TRITON_CACHE = {}
 
 GEMLITE_TRITON_MAPPING = {
     ("fp16", "GEMV"): gemv_A16fWnO16f_int32packing,
     ("fp16", "GEMM"): gemm_A16fWnO16f_int32packing,
-    #("fp16", "GEMM_SPLITK"): gemm_splitK_A16fWnO16f_int32packing,
-
+    ("fp16", "GEMM_SPLITK"): gemm_splitK_A16fWnO16f_int32packing,
     ("bf16", "GEMM"): gemm_A16fWnO16f_int32packing,
 }
 
 def get_closest_m(M):
-    #return M if M <= 8 else 2 ** int(math.ceil(math.log2(M)))
     return 2 ** int(math.ceil(math.log2(M)))
 
 # Triton
@@ -196,6 +181,7 @@ def __init__(
         input_dtype = DType.FP16,
         output_dtype = DType.FP16,
         acc_dtype = DType.FP32,
+        exhaustive=False
     ):
         self._SUPPORTED_BITS_TRITON = [1, 2, 4, 8]
 
@@ -219,7 +205,7 @@ def __init__(
 
         self.compute_dtype = None
         if input_dtype == DType.FP16 and output_dtype == DType.FP16:
-            self.kernels = [gemm_A16fWnO16f_int32packing, gemv_A16fWnO16f_int32packing] #gemm_splitK_A16fWnO16f_int32packing
+            self.kernels = [gemm_A16fWnO16f_int32packing, gemv_A16fWnO16f_int32packing, gemm_splitK_A16fWnO16f_int32packing]
             self.compute_dtype = torch.float16
 
         if input_dtype == DType.BF16 and output_dtype == DType.BF16:
@@ -261,7 +247,10 @@ def __init__(
             ),
         )
 
-        self.forward = self.forward_auto
+        if(exhaustive):
+            self.forward = self.forward_auto_with_warmup
+        else:
+            self.forward = self.forward_auto_no_warmup
 
     # Pack data, adapted from: following the same logic as: https://github.com/LeiWang1999/AutoGPTQ.tvm/blob/dcd135b9784b9f98235fc91467fe3c3c8afa34fc/auto_gptq/nn_modules/qlinear_triton.py#L413-L419
     def pack(self, W_q, scales, zeros, bias=None):
@@ -290,20 +279,24 @@ def warmup(self, signature, args):
         global GEMLITE_TRITON_CACHE
         t = [np.inf] * len(self.kernels)
         for i, _kernel in enumerate(self.kernels):
-            if signature[0] >= 8 and _kernel.matmul_type == "GEMV": #skip gemvs for larger batch-sizes
+            if signature[0] > 1 and _kernel.matmul_type == "GEMV": #skip gemvs for larger batch-sizes
                 pass
+            if signature[0] > 32 and _kernel.matmul_type == "GEMM_SPLITK": #skip SPLIT_K for larger batch-sizes
+                pass
+            if signature[0] < 16 and _kernel.matmul_type == "GEMM": #skip GEMM for smaller matrices
+                pass
             else:
-                t[i] = eval_time(_kernel.forward, args)
+                t[i] = eval_time_for_auto_mode(_kernel.forward, args)
 
         indx = np.argmin(t)
         GEMLITE_TRITON_CACHE[signature] = {
             "forward": self.kernels[indx].forward,
             "time": t[indx],
+            "time_all": list(zip([k.matmul_type for k in self.kernels] , t))
         }
 
-    ################################################################################
-    #Main forward pass
-    def forward_auto(self, x):
+    #Exhaustive search
+    def forward_auto_with_warmup(self, x):
         global GEMLITE_TRITON_CACHE
         out_shape = x.shape[:-1] + (self.out_features,)
         x_input = x.view(-1, x.shape[-1])
@@ -329,13 +322,13 @@ def forward_auto(self, x):
             out += self.bias
         return out
 
-    # def forward_auto(self, x):
-    #     if(x.view(-1, x.shape[-1]).shape[0] == 1):
-    #         return self.forward_manual(x, matmul_type='GEMV') #GEMV / GEMM_SPLITK
-    #     else:
-    #         return self.forward_manual(x, matmul_type='GEMM')
-    #############################################################
-
+    def forward_auto_no_warmup(self, x):
+        if(x.view(-1, x.shape[-1]).shape[0] <= 16):
+            out = self.forward_manual(x, matmul_type='GEMM_SPLITK') #GEMV / GEMM_SPLITK
+        else:
+            out = self.forward_manual(x, matmul_type='GEMM')
+        return out
+
     def forward_manual(self, x, matmul_type="GEMM"):
         out_shape = x.shape[:-1] + (self.out_features,)
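For context, the auto-mode selection introduced above (`eval_time_for_auto_mode` plus `warmup()`) follows the pattern sketched below. This is a self-contained toy example using `triton.testing.do_bench` with plain matmul stand-ins rather than the gemlite kernels; the helper name `time_candidates` is illustrative only.

```python
import numpy as np
import torch
from triton.testing import do_bench

def time_candidates(candidates, args):
    # Mirror of the selection pattern: run each candidate a few times first so Triton
    # autotuning is triggered outside the measurement, then benchmark and keep the fastest.
    t = [np.inf] * len(candidates)
    for i, fct in enumerate(candidates):
        for _ in range(5):
            fct(*args)                       # warm up / trigger autotune
        t[i] = do_bench(lambda: fct(*args), warmup=200, rep=50, return_mode="mean")
    return candidates[int(np.argmin(t))], t

# Toy stand-ins for the real kernel backends:
x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
best, times = time_candidates([torch.matmul, lambda a, b: a @ b], (x, w))
```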

gemlite/triton_kernels/__init__.py

+2 −1
@@ -1,4 +1,5 @@
 from .gemm_A16fWnO16f_int32packing import gemm_A16fWnO16f_int32packing
 from .gemv_A16fWnO16f_int32packing import gemv_A16fWnO16f_int32packing
+from .gemm_splitK_A16fWnO16f_int32packing import gemm_splitK_A16fWnO16f_int32packing
 
-__all__ = ["gemm_A16fWnO16f_int32packing", "gemv_A16fWnO16f_int32packing"]
+__all__ = ["gemm_A16fWnO16f_int32packing", "gemv_A16fWnO16f_int32packing", "gemm_splitK_A16fWnO16f_int32packing"]
gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py

+255 −0
@@ -0,0 +1,255 @@
# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2024
#********************************************************
import torch, math
from torch import Tensor
import triton
import triton.language as tl

def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()

def is_divisible(dividend, divisor):
    return dividend % divisor == 0

def kernel_config_pruner(configs, nargs, **kwargs):
    m = nargs['M']
    n = nargs['N']
    k = nargs['K']
    g = nargs['group_size']

    used = set()
    for config in configs:
        group_size_m = config.kwargs['GROUP_SIZE_M']
        block_size_m = config.kwargs['BLOCK_SIZE_M'] #min(m, config.kwargs['BLOCK_SIZE_M'])
        block_size_n = config.kwargs['BLOCK_SIZE_N'] #min(n, config.kwargs['BLOCK_SIZE_N'])
        block_size_k = config.kwargs['BLOCK_SIZE_K'] #min(k, config.kwargs['BLOCK_SIZE_K'])
        split_k      = config.kwargs['SPLIT_K']

        #Constraints
        #BLOCK_SIZE_K >= group_size
        block_size_k = min(block_size_k, g)
        #K needs to be divisible by BLOCK_SIZE_K * SPLIT_K
        if(not is_divisible(k, block_size_k * split_k)):
            continue

        A_load_order      = config.kwargs['A_load_order']
        meta_evict_policy = config.kwargs['meta_evict_policy']
        atomic_mode       = config.kwargs['atomic_mode']

        _key = (block_size_m, block_size_n, block_size_k, group_size_m, split_k,
                A_load_order, meta_evict_policy, atomic_mode,
                config.num_stages, config.num_warps,
                )

        if _key in used:
            continue

        used.add(_key)
        yield triton.Config(
            {
                'BLOCK_SIZE_M': block_size_m,
                'BLOCK_SIZE_N': block_size_n,
                'BLOCK_SIZE_K': block_size_k,
                'GROUP_SIZE_M': group_size_m,
                'SPLIT_K'     : split_k,

                'A_load_order'      : A_load_order,
                'meta_evict_policy' : meta_evict_policy,
                'atomic_mode'       : atomic_mode,
            },
            num_stages=config.num_stages,
            num_warps=config.num_warps,
            pre_hook=config.pre_hook,
        )


def get_gemm_config():
    #Tuned on 4090 RTX
    _configs = []
    for _M in [16]: #This is fixed to 16 for skinny matrices
        for _N in [32, 64]:
            for _K in [32, 64, 128]: #[128], group_size >= 128
                for _w in [4]: #[4]
                    for _s in [2, 3]: #[2, 3]
                        for _sK in [2, 4, 8]: #[2, 4, 8]
                            for _a_load_order in [1, 2, 3]: #[1, 2, 3] - [1]: default 4090
                                for _meta_evict_policy in ['']: #['', 'evict_last'] - ['']: default 4090
                                    for _atomic_mode in ['release', 'relaxed']: #['release', 'relaxed']
                                        _configs.append(
                                            triton.Config(
                                                {'BLOCK_SIZE_M': _M, 'BLOCK_SIZE_N': _N, 'BLOCK_SIZE_K': _K,
                                                 'GROUP_SIZE_M': 8, 'SPLIT_K': _sK,
                                                 'A_load_order': _a_load_order, 'meta_evict_policy': _meta_evict_policy, 'atomic_mode': _atomic_mode,
                                                },
                                                num_stages=_s, num_warps=_w,
                                                pre_hook=init_to_zero("c_ptr"),
                                            )
                                        )
    return _configs



#@triton.heuristics(values={'CLOSEST_M': lambda args: 2 ** int(math.ceil(math.log2(args['M'])))})
@triton.autotune(
    configs = get_gemm_config(),
    key=['M', 'N', 'K', 'group_size', 'elements_per_sample'],
    prune_configs_by={
        'early_config_prune': kernel_config_pruner,
    },
    warmup=200,
    rep=50, #20 for faster tuning
)

@triton.jit
def gemm_splitK_A16fWnO16f_int32packing_kernel(
    a_ptr, b_ptr, c_ptr,
    scales_ptr, zeros_ptr,
    M, N, K,
    W_nbits: tl.constexpr, group_size: tl.constexpr, unpack_mask: tl.constexpr, elements_per_sample: tl.constexpr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_meta_g, stride_meta_n,
    acc_dtype: tl.constexpr,
    ######### tuning params #########
    #CLOSEST_M: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr,
    A_load_order: tl.constexpr, meta_evict_policy: tl.constexpr, atomic_mode: tl.constexpr,
):
    """
    Based on https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py
    GEMM for C = matmul(A, dequantize(B, scales, zeros))
    A is of shape (M, K): float16 or bfloat16
    B is of shape (K//elements_per_sample, N): int32 as a packed matrix
    C is of shape (M, N): float16 or bfloat16 depending on the input A
    scales and zeros are of shape (K//group_size, N): float16 or bfloat16

    BLOCK_SIZE_M >= 16
    BLOCK_SIZE_K * SPLIT_K <= group_size for imp1
    BLOCK_SIZE_K == SPLIT_K for imp2 (similar to original)
    """

    pid   = tl.program_id(axis=0)
    pid_k = tl.program_id(axis=1)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

    #Swizzle
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id         = pid // num_pid_in_group
    first_pid_m      = group_id * GROUP_SIZE_M
    group_size_m     = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m            = first_pid_m + (pid % group_size_m)
    pid_n            = (pid % num_pid_in_group) // group_size_m

    #Offsets
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    #Vectorized coalesced load
    offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)

    #Inputs
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    a_mask = offs_am[:, None] < M
    b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)

    #Meta data stuff
    q_shifts = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None]

    scales_ptrs = scales_ptr + offs_bn[None, :] * stride_meta_n
    zeros_ptrs  = zeros_ptr  + offs_bn[None, :] * stride_meta_n

    stride_mul: tl.constexpr     = BLOCK_SIZE_K / group_size
    BLOCK_SIZE_K_P: tl.constexpr = BLOCK_SIZE_K // elements_per_sample
    ####################################################################################

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

    for k in tl.range(0, num_pid_k, 1, num_stages=1):

        b = tl.load(b_ptrs, eviction_policy='evict_first')

        if(A_load_order == 1): #Early load
            a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')

        #Meta-data loading policy
        k_m    = ((k * SPLIT_K + pid_k) * stride_mul).to(tl.int32)
        scales = tl.load(scales_ptrs + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
        zeros  = tl.load(zeros_ptrs  + k_m * stride_meta_g, eviction_policy=meta_evict_policy)

        if(A_load_order == 2): #Mid load
            a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')

        # Unpack and dequantize
        b = (b >> q_shifts) & unpack_mask
        b = (b.to(scales.dtype) - zeros) * scales

        if(A_load_order == 3): #Late load
            a = tl.load(a_ptrs, mask=a_mask, other=0., eviction_policy='evict_last')

        #Dot
        acc = tl.dot(a, b.to(a.dtype), acc=acc, out_dtype=acc_dtype, input_precision="ieee")

        #Advance
        a_ptrs += BLOCK_SIZE_K   * SPLIT_K * stride_ak
        b_ptrs += BLOCK_SIZE_K_P * SPLIT_K * stride_bk

    #Output
    #acc = acc.to(tl.float16)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    tl.atomic_add(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), sem=atomic_mode) #release / relaxed


@torch.library.custom_op("gemlite::gemm_splitK_A16fWnO16f_int32packing_forward", mutates_args=())
def gemm_splitK_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor,
                                                W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int,
                                                acc_dtype: int,
                                                ) -> Tensor:

    M, K, N = x.shape[0], x.shape[1], W_q.shape[1]

    #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
    #assert group_size >= 128, "Only group_size >= 128 is currently supported"

    output = torch.empty((M, N), device=W_q.device, dtype=scales.dtype)

    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K'])

    gemm_splitK_A16fWnO16f_int32packing_kernel[grid](
        x, W_q, output,
        scales, zeros,
        M, N, K,
        W_nbits, group_size, unpack_mask, elements_per_sample,
        x.stride(0), x.stride(1),
        W_q.stride(0), W_q.stride(1),
        output.stride(0), output.stride(1),
        scales.stride(0), scales.stride(1),
        tl.float16 if (acc_dtype == 1) else tl.float32,
    )

    return output

@torch.library.register_fake("gemlite::gemm_splitK_A16fWnO16f_int32packing_forward")
def gemm_splitK_A16fWnO16f_int32packing_forward_fake(x: Tensor, W_q: Tensor, scales: Tensor, zeros: Tensor,
                                                     W_nbits: int, group_size: int, unpack_mask: int, elements_per_sample: int,
                                                     acc_dtype: int,
                                                     ) -> Tensor:

    M, K, N = x.shape[0], x.shape[1], W_q.shape[1]
    return torch.empty((M, N), device=W_q.device, dtype=scales.dtype)


class gemm_splitK_A16fWnO16f_int32packing:
    kernel      = gemm_splitK_A16fWnO16f_int32packing_kernel
    forward     = gemm_splitK_A16fWnO16f_int32packing_forward
    matmul_type = "GEMM_SPLITK"

__all__ = ["gemm_splitK_A16fWnO16f_int32packing"]
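A hedged usage sketch of the new op's host-side entry point, based only on the signature above. The shapes, the random packed-weight tensor, and the unit scales are assumptions for a smoke/shape test on a CUDA GPU; they are not the packing produced by `pack()` in `gemlite/core.py`. Calling the op through the class attribute mirrors how `core.py` invokes `self.kernels[i].forward`.

```python
import torch
from gemlite.triton_kernels import gemm_splitK_A16fWnO16f_int32packing

M, K, N = 4, 4096, 4096
W_nbits, group_size = 4, 128
elements_per_sample = 32 // W_nbits        # 8 packed 4-bit values per int32
unpack_mask = (1 << W_nbits) - 1           # 0xF

x      = torch.randn(M, K, dtype=torch.float16, device="cuda")
W_q    = torch.randint(0, 2**31 - 1, (K // elements_per_sample, N), dtype=torch.int32, device="cuda")
scales = torch.ones(K // group_size, N, dtype=torch.float16, device="cuda")
zeros  = torch.zeros(K // group_size, N, dtype=torch.float16, device="cuda")

out = gemm_splitK_A16fWnO16f_int32packing.forward(
    x, W_q, scales, zeros,
    W_nbits, group_size, unpack_mask, elements_per_sample,
    1,  # acc_dtype: 1 selects float16 accumulation, anything else float32
)
print(out.shape)  # torch.Size([4, 4096])
```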
