@@ -175,7 +175,7 @@ def get_closest_m(M):
# Triton
_GROUP_SIZE_WARNED = False;
class GemLiteLinearTriton(torch.nn.Module):
-    SUPPORTED_BITS_TRITON = [1, 2, 4, 8]
+    SUPPORTED_BITS_TRITON = [1, 2, 4, 8, 16]
    SUPPORTED_DTYPES = [DType.FP16, DType.FP8, DType.INT8]

    def __init__(
@@ -196,15 +196,18 @@ def __init__(
        if in_features % 128 != 0 or out_features % 128 != 0:
            raise NotImplementedError("Invalid input shapes")

+        group_size = 1 if (group_size is None) else group_size
+
        if (group_size < 128 and (_GROUP_SIZE_WARNED is False)):
            warnings.warn("Make sure to enable autotuning for group_size lower than 128: `set_autotune({'GEMV_REVSPLITK':True, 'GEMV':True, 'GEMM_SPLITK':True, 'GEMM':True})`")
            _GROUP_SIZE_WARNED = True

+
        self.in_features = in_features
        self.out_features = out_features
        self.orig_shape = (out_features, in_features)
        self.W_nbits = W_nbits
-        self.group_size = group_size if group_size != -1 else in_features
+        self.group_size = group_size
        self.unpack_mask = 2**self.W_nbits - 1
        self.elements_per_sample = 32 // self.W_nbits
        self.signature = (in_features, out_features, W_nbits, group_size)
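
Reviewer note: the hunk above makes two changes in `__init__`: a `group_size` of `None` now falls back to `1`, and `self.group_size` is stored as passed instead of being remapped from `-1` to `in_features`. The warning string spells out the autotune flags needed for small group sizes; below is a minimal sketch of acting on it, assuming `set_autotune` is importable from the same module as the layer class (the import path is not shown in this diff).

# Hypothetical import path; only the function name appears in the warning above.
from gemlite.core import set_autotune

# For group_size < 128, enable the autotune flags exactly as the warning suggests:
set_autotune({'GEMV_REVSPLITK': True, 'GEMV': True, 'GEMM_SPLITK': True, 'GEMM': True})
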
@@ -259,6 +262,20 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
            col += 1

        self.W_q = self.W_q.t().contiguous() #row-major contiguous()
+
+        #Bias / device
+        self.bias = None if (bias is None) else torch.nn.Parameter(bias.to(device=self.W_q.device, dtype=self.compute_dtype))
+        self.device = self.W_q.device
+
+        #FP16 x FP16 / FP8 x FP8 / INT8 x INT8 - no meta-data case
+        if ((scales is None) and (zeros is None)):
+            self.zeros = torch.tensor([[0,]]).cuda()
+            self.scales = torch.tensor([[1,]]).cuda()
+            self.W_group_mode = 0
+            self.channel_scale_mode = 2 if self.scaled_activations else 0
+            return
+
+        #The rest of the use-cases require some kind of meta-data

        if (scales is not None):
            assert scales.dtype == self.meta_dtype, "Unsupported scales/zeros dtype. Only FP16 is supported."
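
Reviewer note: the block added above gives `pack()` an early-exit path when both `scales` and `zeros` are `None`: dummy meta-data tensors are registered, `W_group_mode` is set to 0, and `channel_scale_mode` becomes 2 only when activations are scaled. A hedged sketch of exercising that path follows; the constructor keyword names (`W_nbits`, `group_size`, `input_dtype`, `output_dtype`) and the idea that raw FP16 weights are accepted as `W_q` in this mode are assumptions pieced together from this diff, not the full upstream signature.

import torch
from gemlite.core import GemLiteLinearTriton, DType  # assumed import path

# 16-bit "quantization" = plain FP16 weights, enabled by the new SUPPORTED_BITS_TRITON entry.
layer = GemLiteLinearTriton(in_features=4096, out_features=4096,
                            W_nbits=16, group_size=None,            # None now falls back to 1
                            input_dtype=DType.FP16, output_dtype=DType.FP16)

W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')     # assumption: raw FP16 weights as W_q
layer.pack(W_q=W, scales=None, zeros=None, bias=None)                # hits the new no-meta-data branch (W_group_mode == 0)
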
@@ -270,8 +287,9 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
        self.W_group_mode = -1

        #Symmetric no shift
-        if (zeros is None):
+        if (zeros is None and self.group_size > 1):
            assert self.scales is not None, "Zeros and scales and can't be both None for W_group_mode = 2."
+            self.zeros = zeros
            self.W_group_mode = 2
        else:
            #Asymmetric or Symmetric with shift
@@ -283,7 +301,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
                    self.zeros = zeros.view((self.out_features, -1)).t().contiguous()
                    self.W_group_mode = 3
            else: #Integer
-                self.zeros = int(zeros)
+                self.zeros = int(zeros) if (zeros is not None) else None
                if (self.scales is not None):
                    self.W_group_mode = 3 #Symmetric with shift
                else:
@@ -293,7 +311,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni

        #channel-wise scaling
        self.channel_scale_mode = 0
-        self.meta_is_chanenlwise = self.scales.numel() == self.out_features
+        self.meta_is_chanenlwise = False if (self.scales is None) else self.scales.numel() == self.out_features

        #weight-only
        if ((self.scaled_activations == False) and (self.meta_is_chanenlwise == True)):
@@ -309,14 +327,20 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
            self.channel_scale_mode = 3
            self.W_group_mode = 1 if (self.zeros is not None) else 0 #only with fma_mode=False

+        if (isinstance(self.zeros, int)): #Union[Tensor, int] not supported by custom op
+            self.zeros = torch.tensor(self.zeros, dtype=torch.int32)
+
        if (self.channel_scale_mode in [1, 3]):
            assert self.W_group_mode not in [3, 4], "Can't use channel_scale_mode with W_group_mode == 3 or 4."

        if (self.input_dtype == DType.INT8):
            assert self.W_group_mode in [1], "Only channel-wise symmetric quantization is supported for INT8 inputs."

-        self.bias = None if (bias is None) else torch.nn.Parameter(bias.to(device=self.W_q.device, dtype=self.compute_dtype))
-        self.device = self.W_q.device
+        #Dummy values
+        if (self.zeros is None):
+            self.zeros = torch.tensor([[0,]]).cuda()
+        if (self.scales is None):
+            self.scales = torch.tensor([[1,]]).cuda()

        #TODO: Register buffers

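
Reviewer note: taken together, these hunks ensure `self.zeros` and `self.scales` always end up as tensors (integer zero-points are promoted to `torch.int32` tensors, and missing meta-data gets dummy `[[0]]` / `[[1]]` placeholders), which is what a custom-op schema can accept. The comment block below is one possible reading of the mode flags as they appear in this diff, not upstream documentation.

# Interpretation (hedged, derived from this diff only):
#   W_group_mode: -1 = unset, 0 = no meta-data, 1 = shift only,
#                  2 = symmetric scale without shift, 3 = asymmetric / symmetric with shift
#                  (4 appears only in the assert, presumably an fma variant).
#   channel_scale_mode: 0 = none, 1 = weight channel-wise, 2 = activation-only, 3 = both;
#                       1 and 3 are asserted incompatible with W_group_mode 3 or 4.

# Example of the promotion added above: a plain Python int zero-point becomes a tensor
# so the custom op only ever sees Tensor arguments.
import torch
zeros = 8                                         # e.g. the usual mid-point for 4-bit weights
if isinstance(zeros, int):
    zeros = torch.tensor(zeros, dtype=torch.int32)
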
@@ -419,4 +443,4 @@ def forward_manual(self, x: Tensor, matmul_type: str="GEMM") -> Tensor:

###################################################################################################################################
###################################################################################################################################
-GemLiteLinear = GemLiteLinearTriton # Triton by default
+GemLiteLinear = GemLiteLinearTriton # Triton by default
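
Reviewer note: the module-level alias still points `GemLiteLinear` at the Triton backend. Continuing the sketch from the `pack()` example earlier, the only call signature visible in this diff is `forward_manual(x, matmul_type="GEMM")`, so a hedged usage line looks like this (whether `layer(x)` dispatches automatically is an assumption, not shown here).

x = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
y = layer.forward_manual(x, matmul_type="GEMM")   # matmul_type per the hunk header above
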