@@ -71,9 +71,10 @@ def __init__(
         if in_features % 128 != 0 or out_features % 128 != 0:
             raise NotImplementedError("Invalid input shapes")
 
-        group_size = 1 if (group_size is None) else group_size
+        if(group_size is not None):
+            assert group_size >= 32, "Only group_size >= 32 is supported."
 
-        assert group_size >= 32, "Only group_size >= 32 is supported."
+        group_size = 1 if (group_size is None) else group_size
 
         self.in_features = in_features
         self.out_features = out_features
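Note: a minimal standalone sketch (not part of the commit) of what the reordering above changes when group_size is None, which presumably denotes channel-wise scaling with no grouping:

    # Illustration only: old vs. new validation order for group_size.
    def validate_old(group_size):
        # Old order: None is first mapped to 1, so the assert always runs and rejects it.
        group_size = 1 if (group_size is None) else group_size
        assert group_size >= 32, "Only group_size >= 32 is supported."
        return group_size

    def validate_new(group_size):
        # New order: the assert only applies when a group size is explicitly provided.
        if group_size is not None:
            assert group_size >= 32, "Only group_size >= 32 is supported."
        return 1 if (group_size is None) else group_size

    validate_new(None)  # -> 1, accepted
    validate_new(64)    # -> 64, accepted
    # validate_old(None) would raise AssertionError, since None -> 1 and 1 < 32.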
@@ -162,7 +163,7 @@ def pack_weights_over_cols(self, W_q, W_nbits, packing_bitwidth=32, transpose=Tr
         return W_q_out, elements_per_sample
 
     #Make sure to feed UINT8 W_q for packing
-    def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Union[Tensor, None] = None, fma_mode: bool = False, contiguous: bool = True, packing_bitwidth: int = 32):
+    def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Union[Tensor, None] = None, fma_mode: bool = False, contiguous: Union[int, None] = None, packing_bitwidth: int = 32):
 
         #Unpacked weights
         self.W_q = None
@@ -175,9 +176,12 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
             self.W_q = W_q.t() #row-major
             self.elements_per_sample = 1
 
+            if(contiguous is None): contiguous = False
+
         if(W_q.dtype == torch.uint8): #Packed weigths
             self.W_q, self.elements_per_sample = self.pack_weights_over_cols(W_q.view(self.orig_shape), W_nbits=self.W_nbits, packing_bitwidth=packing_bitwidth, transpose=True) #Over-K
             #self.W_q, self.elements_per_sample = self.pack_weights_over_rows(W_q.view(self.orig_shape), W_nbits=self.W_nbits, packing_bitwidth=packing_bitwidth, transpose=True) #Over-N
+            if(contiguous is None): contiguous = True
 
         if(self.W_q is None):
             raise Exception('Weights were not packed, please check your W_q.dtype')
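Note: with contiguous now defaulting to None, the two added lines above pick a layout-dependent default while still honoring an explicit caller value; a small illustration (helper name is hypothetical, not in the commit):

    # Illustration only: how contiguous=None is resolved in the hunk above.
    def resolve_contiguous(contiguous, packed):
        # Unpacked (non-uint8) weights default to False, packed uint8 weights default to True;
        # any value passed explicitly by the caller is left untouched.
        if contiguous is None:
            contiguous = True if packed else False
        return contiguous

    assert resolve_contiguous(None, packed=False) is False
    assert resolve_contiguous(None, packed=True) is True
    assert resolve_contiguous(False, packed=True) is False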
@@ -196,8 +200,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
             self.scales = None
             self.W_group_mode = 0
             self.channel_scale_mode = 2 if self.scaled_activations else 0
-            return
-
+
         #The rest of the use-cases require some kind of meta-data
         if(scales is not None):
             self.scales = scales.view((self.out_features, -1)).t()
@@ -207,7 +210,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
         #Symmetric no shift
         if(zeros is None):
             self.zeros = None
-            self.W_group_mode = 2
+            self.W_group_mode = 2 if(self.scales is not None) else 0
         else:
             #Asymmetric or Symmetric with shift
             if(isinstance(zeros, torch.Tensor)):
@@ -253,9 +256,9 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
         if(isinstance(self.zeros, int)): #Union[Tensor, int] not supported by custom op
             self.zeros = torch.tensor(self.zeros, dtype=torch.int32, device=self.device)
         if(self.zeros is None):
-            self.zeros = torch.tensor([], dtype=torch.int32, device=self.device)
+            self.zeros = torch.tensor([[]], dtype=torch.int32, device=self.device)
         if(self.scales is None):
-            self.scales = torch.tensor([], dtype=torch.int32, device=self.device)
+            self.scales = torch.tensor([[]], dtype=torch.int32, device=self.device)
 
         if(self.scales is not None):
             self.meta_dtype = DType.FP32 if self.scales.dtype == torch.float32 else DType.FP16
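Note: the placeholder change above swaps an empty 1-D tensor for an empty 2-D one, presumably so the dummy zeros/scales keep the same rank as real meta-data (scales are reshaped to 2-D earlier in pack()); a quick shape check (illustration only):

    import torch

    # Illustration only: old vs. new empty placeholders.
    old = torch.tensor([], dtype=torch.int32)    # shape (0,),   ndim = 1
    new = torch.tensor([[]], dtype=torch.int32)  # shape (1, 0), ndim = 2

    print(old.shape, old.ndim)  # torch.Size([0]) 1
    print(new.shape, new.ndim)  # torch.Size([1, 0]) 2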