Commit 7da572d

fix FP8/INT8 + add detailed tests

1 parent de0759c · commit 7da572d
6 files changed: +357 -20 lines changed

gemlite/core.py (+32 -8)

@@ -175,7 +175,7 @@ def get_closest_m(M):
 # Triton
 _GROUP_SIZE_WARNED = False;
 class GemLiteLinearTriton(torch.nn.Module):
-    SUPPORTED_BITS_TRITON = [1, 2, 4, 8]
+    SUPPORTED_BITS_TRITON = [1, 2, 4, 8, 16]
     SUPPORTED_DTYPES = [DType.FP16, DType.FP8, DType.INT8]
 
     def __init__(
@@ -196,15 +196,18 @@ def __init__(
         if in_features % 128 != 0 or out_features % 128 != 0:
             raise NotImplementedError("Invalid input shapes")
 
+        group_size = 1 if (group_size is None) else group_size
+
         if(group_size < 128 and (_GROUP_SIZE_WARNED is False)):
             warnings.warn("Make sure to enable autotuning for group_size lower than 128: `set_autotune({'GEMV_REVSPLITK':True, 'GEMV':True, 'GEMM_SPLITK':True, 'GEMM':True})`")
             _GROUP_SIZE_WARNED = True
 
+
         self.in_features = in_features
         self.out_features = out_features
         self.orig_shape = (out_features, in_features)
         self.W_nbits = W_nbits
-        self.group_size = group_size if group_size != -1 else in_features
+        self.group_size = group_size
         self.unpack_mask = 2**self.W_nbits - 1
         self.elements_per_sample = 32 // self.W_nbits
         self.signature = (in_features, out_features, W_nbits, group_size)
@@ -259,6 +262,20 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
             col += 1
 
         self.W_q = self.W_q.t().contiguous() #row-major contiguous()
+
+        #Bias / device
+        self.bias = None if (bias is None) else torch.nn.Parameter(bias.to(device=self.W_q.device, dtype=self.compute_dtype))
+        self.device = self.W_q.device
+
+        #FP16 x FP16 / FP8 x FP8 / INT8 x INT8 - no meta-data case
+        if((scales is None) and (zeros is None)):
+            self.zeros = torch.tensor([[0,]]).cuda()
+            self.scales = torch.tensor([[1,]]).cuda()
+            self.W_group_mode = 0
+            self.channel_scale_mode = 2 if self.scaled_activations else 0
+            return
+
+        #The rest of the use-cases require some kind of meta-data
 
         if(scales is not None):
             assert scales.dtype == self.meta_dtype, "Unsupported scales/zeros dtype. Only FP16 is supported."
@@ -270,8 +287,9 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
             self.W_group_mode = -1
 
         #Symmetric no shift
-        if(zeros is None):
+        if(zeros is None and self.group_size > 1):
             assert self.scales is not None, "Zeros and scales and can't be both None for W_group_mode = 2."
+            self.zeros = zeros
             self.W_group_mode = 2
         else:
             #Asymmetric or Symmetric with shift
@@ -283,7 +301,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
                 self.zeros = zeros.view((self.out_features, -1)).t().contiguous()
                 self.W_group_mode = 3
             else: #Integer
-                self.zeros = int(zeros)
+                self.zeros = int(zeros) if(zeros is not None) else None
                 if(self.scales is not None):
                     self.W_group_mode = 3 #Symmetric with shift
                 else:
@@ -293,7 +311,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
 
         #channel-wise scaling
         self.channel_scale_mode = 0
-        self.meta_is_chanenlwise = self.scales.numel() == self.out_features
+        self.meta_is_chanenlwise = False if(self.scales is None) else self.scales.numel() == self.out_features
 
         #weight-only
         if((self.scaled_activations == False) and (self.meta_is_chanenlwise == True)):
@@ -309,14 +327,20 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
             self.channel_scale_mode = 3
             self.W_group_mode = 1 if(self.zeros is not None) else 0 #only with fma_mode=False
 
+        if(isinstance(self.zeros, int)): #Union[Tensor, int] not supported by custom op
+            self.zeros = torch.tensor(self.zeros, dtype=torch.int32)
+
         if(self.channel_scale_mode in [1, 3]):
             assert self.W_group_mode not in [3, 4], "Can't use channel_scale_mode with W_group_mode == 3 or 4."
 
         if(self.input_dtype == DType.INT8):
             assert self.W_group_mode in [1], "Only channel-wise symmetric quantization is supported for INT8 inputs."
 
-        self.bias = None if (bias is None) else torch.nn.Parameter(bias.to(device=self.W_q.device, dtype=self.compute_dtype))
-        self.device = self.W_q.device
+        #Dummy values
+        if(self.zeros is None):
+            self.zeros = torch.tensor([[0,]]).cuda()
+        if(self.scales is None):
+            self.scales = torch.tensor([[1,]]).cuda()
 
         #TODO: Register buffers
 
@@ -419,4 +443,4 @@ def forward_manual(self, x: Tensor, matmul_type: str="GEMM") -> Tensor:
 
 ###################################################################################################################################
 ###################################################################################################################################
-GemLiteLinear = GemLiteLinearTriton # Triton by default
+GemLiteLinear = GemLiteLinearTriton # Triton by default
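Below is a minimal usage sketch (not part of the commit) of the new no-meta-data path in pack(): with scales=None and zeros=None the layer now substitutes dummy scale/zero tensors, sets W_group_mode = 0, and skips group-wise dequantization, which is what the new 16-bit entry in SUPPORTED_BITS_TRITON is for. The constructor keyword arguments are assumptions inferred from the attributes referenced in this diff (in_features, out_features, W_nbits, group_size, input_dtype, output_dtype); check the actual signature before relying on it.

    import torch
    from gemlite.core import GemLiteLinearTriton, DType

    # Hypothetical FP16 x FP16 layer with no quantization meta-data (assumed constructor arguments).
    layer = GemLiteLinearTriton(W_nbits=16, group_size=None,
                                in_features=4096, out_features=4096,
                                input_dtype=DType.FP16, output_dtype=DType.FP16)

    W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    layer.pack(W, scales=None, zeros=None, bias=None)  # hits the new (scales is None) and (zeros is None) branch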

gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py (+1)

@@ -233,6 +233,7 @@ def gemm_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: Tensor,
 
     #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
     output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype])
+    zeros = zeros.item() if (zeros.numel()==1) else zeros
 
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
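The same one-line change is repeated in each forward wrapper below. A small sketch of the scalar zero-point round-trip it relies on: pack() stores an integer zero-point as a tensor (the comment in core.py notes that Union[Tensor, int] is not supported by the custom op), and the wrapper unwraps any 1-element tensor back into a Python scalar before the Triton launch.

    import torch

    zeros = torch.tensor(8, dtype=torch.int32)               # as stored by pack() when zeros was an int
    zeros = zeros.item() if (zeros.numel() == 1) else zeros  # back to a plain Python scalar for the kernel
    print(type(zeros), zeros)                                # <class 'int'> 8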

gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py (+1)

@@ -268,6 +268,7 @@ def gemm_splitK_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales:
     #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
     #assert group_size >= 128, "Only group_size >= 128 is currently supported"
     output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype])
+    zeros = zeros.item() if (zeros.numel()==1) else zeros
 
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K'])
 
gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py (+1 -4)

@@ -222,13 +222,10 @@ def gemv_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: Tensor,
 
     #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
     output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype])
+    zeros = zeros.item() if (zeros.numel()==1) else zeros
 
     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), triton.cdiv(K, meta['BLOCK_SIZE_K']))
 
-    #faster to do channel-wise like this for this kernel
-    if(channel_scale_mode == 1 and W_group_mode == 1):
-        channel_scale_mode, W_group_mode = 0, 3
-
     gemv_A16fWnO16f_int32packing_kernel[grid](
         x, W_q, output,
         scales, zeros, scales_x,
gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py (+7 -8)

@@ -85,10 +85,13 @@ def get_autotune_config():
 compute_capability = torch.cuda.get_device_capability(0)
 
 def get_default_config():
-    #4090: default
+    # #4090: default
     config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':32, 'A_load_order':2, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
                             num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr"))
 
+    #4090: default
+    #config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':32}, num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr"))
+
     if(compute_capability == (8, 0)): #A100
         config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':16, 'A_load_order':2, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
                                 num_warps=2, num_stages=1, pre_hook=init_to_zero("c_ptr"))
@@ -157,21 +160,20 @@ def gemv_revsplitK_A16fWnO16f_int32packing_kernel(
     a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
     b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
     q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None]
-
     ####################################################################
     #Load meta data first, for two passes
     k_m = (pid_k * (BLOCK_SIZE_K / group_size)).to(tl.int32)
 
     if(W_group_mode >= 2): #[2, 3, 4]
-        scales = tl.load(scales_ptr + offs_bn[None, :] + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
+        scales = tl.load(scales_ptr + offs_bn[None, :] * stride_meta_n + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
     else:
         scales = None
 
     if(W_group_mode == 1 or W_group_mode >= 3): #[1, 3, 4]
         if(zero_is_scalar):
             zeros = zeros_ptr
         else:
-            zeros = tl.load(zeros_ptr + offs_bn[None, :] + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
+            zeros = tl.load(zeros_ptr + offs_bn[None, :] * stride_meta_n + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
     else:
         zeros = None
 
@@ -236,13 +238,10 @@ def gemv_revsplitK_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scale
 
     #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
     output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype])
+    zeros = zeros.item() if (zeros.numel()==1) else zeros
 
     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), triton.cdiv(K, meta['BLOCK_SIZE_K'] * 2))
 
-    #faster to do channel-wise like this for this kernel
-    if(channel_scale_mode == 1 and W_group_mode == 1):
-        channel_scale_mode, W_group_mode = 0, 3
-
     gemv_revsplitK_A16fWnO16f_int32packing_kernel[grid](
         x, W_q, output,
         scales, zeros, scales_x,
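A short illustration of the addressing fix in the two tl.load calls above: the per-output-channel offset into the scales/zeros meta tensors is now multiplied by stride_meta_n instead of implicitly assuming a unit stride, which matches standard strided indexing. The shapes below are illustrative assumptions, not values from the commit.

    import torch

    # Assumed group-wise meta-data layout of shape (num_groups, out_features).
    scales = torch.randn(32, 4096, dtype=torch.float16)
    stride_meta_g, stride_meta_n = scales.stride()

    # Element [k_m, n] lives at k_m*stride_meta_g + n*stride_meta_n, which is what the
    # kernel now computes with offs_bn * stride_meta_n + k_m * stride_meta_g.
    k_m, n = 3, 17
    assert scales.reshape(-1)[k_m * stride_meta_g + n * stride_meta_n] == scales[k_m, n]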
