
Commit b4c8206

fix tests
1 parent 616af15 commit b4c8206

5 files changed (+134 -104)

gemlite/core.py (+11 -8)
@@ -71,9 +71,10 @@ def __init__(
         if in_features % 128 != 0 or out_features % 128 != 0:
             raise NotImplementedError("Invalid input shapes")
 
-        group_size = 1 if (group_size is None) else group_size
+        if(group_size is not None):
+            assert group_size >= 32, "Only group_size >= 32 is supported."
 
-        assert group_size >= 32, "Only group_size >= 32 is supported."
+        group_size = 1 if (group_size is None) else group_size
 
         self.in_features = in_features
         self.out_features = out_features
@@ -162,7 +163,7 @@ def pack_weights_over_cols(self, W_q, W_nbits, packing_bitwidth=32, transpose=Tr
         return W_q_out, elements_per_sample
 
     #Make sure to feed UINT8 W_q for packing
-    def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Union[Tensor, None]=None, fma_mode: bool=False, contiguous: bool=True, packing_bitwidth: int=32):
+    def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Union[Tensor, None]=None, fma_mode: bool=False, contiguous: Union[int,None]=None, packing_bitwidth: int=32):
 
         #Unpacked weights
         self.W_q = None
@@ -175,9 +176,12 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
             self.W_q = W_q.t() #row-major
             self.elements_per_sample = 1
 
+            if(contiguous is None): contiguous = False
+
         if(W_q.dtype == torch.uint8): #Packed weigths
             self.W_q, self.elements_per_sample = self.pack_weights_over_cols(W_q.view(self.orig_shape), W_nbits=self.W_nbits, packing_bitwidth=packing_bitwidth, transpose=True) #Over-K
             #self.W_q, self.elements_per_sample = self.pack_weights_over_rows(W_q.view(self.orig_shape), W_nbits=self.W_nbits, packing_bitwidth=packing_bitwidth, transpose=True) #Over-N
+            if(contiguous is None): contiguous = True
 
         if(self.W_q is None):
             raise Exception('Weights were not packed, please check your W_q.dtype')
@@ -196,8 +200,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
             self.scales = None
             self.W_group_mode = 0
             self.channel_scale_mode = 2 if self.scaled_activations else 0
-            return
-
+
         #The rest of the use-cases require some kind of meta-data
         if(scales is not None):
             self.scales = scales.view((self.out_features, -1)).t()
@@ -207,7 +210,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
         #Symmetric no shift
         if(zeros is None):
             self.zeros = None
-            self.W_group_mode = 2
+            self.W_group_mode = 2 if(self.scales is not None) else 0
         else:
             #Asymmetric or Symmetric with shift
             if(isinstance(zeros, torch.Tensor)):
@@ -253,9 +256,9 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
         if(isinstance(self.zeros, int)): #Union[Tensor, int] not supported by custom op
             self.zeros = torch.tensor(self.zeros, dtype=torch.int32, device=self.device)
         if(self.zeros is None):
-            self.zeros = torch.tensor([], dtype=torch.int32, device=self.device)
+            self.zeros = torch.tensor([[]], dtype=torch.int32, device=self.device)
         if(self.scales is None):
-            self.scales = torch.tensor([], dtype=torch.int32, device=self.device)
+            self.scales = torch.tensor([[]], dtype=torch.int32, device=self.device)
 
         if(self.scales is not None):
             self.meta_dtype = DType.FP32 if self.scales.dtype == torch.float32 else DType.FP16
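
For reference, a minimal standalone sketch of the two behavioural changes in this file: group_size is now range-checked only when it is explicitly provided, and contiguous defaults to None and is resolved from whether the weights arrive pre-quantized as uint8. The helper names resolve_group_size and resolve_contiguous are illustrative only, not part of the gemlite API.

import torch

def resolve_group_size(group_size):
    # Mirrors the reordered check in __init__: validate only when a value is
    # given, then fall back to 1 for the channel-wise (no grouping) case.
    if group_size is not None:
        assert group_size >= 32, "Only group_size >= 32 is supported."
    return 1 if (group_size is None) else group_size

def resolve_contiguous(W_q, contiguous=None):
    # Mirrors the new default in pack(): an explicit value wins; otherwise
    # unpacked (fp16) weights resolve to False and packed uint8 weights to True.
    if contiguous is not None:
        return contiguous
    return W_q.dtype == torch.uint8

# Examples (hypothetical values):
print(resolve_group_size(None))                                    # -> 1
print(resolve_group_size(128))                                     # -> 128
print(resolve_contiguous(torch.zeros(4, 4, dtype=torch.uint8)))    # -> True
print(resolve_contiguous(torch.zeros(4, 4, dtype=torch.float16)))  # -> False
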

gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py (+1 -1)
@@ -115,7 +115,7 @@ def get_autotune_config():
 def get_default_config():
     #small batch, not sure what is the right default cnnfig here.
     return [triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':128, 'GROUP_SIZE_M':8, 'A_load_order':2, 'meta_evict_policy':''},
-                          num_warps=4, num_stages=2),]
+                          num_warps=4, num_stages=1),]
 
 ENABLE_AUTOTUNE = AUTOTUNE_ENABLE.GEMM

gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py (+3 -3)
@@ -137,17 +137,17 @@ def get_default_config():
     #4090: default
     config = triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':32, 'BLOCK_SIZE_K':32, 'SPLIT_K':2, 'GROUP_SIZE_M':8,
                             'A_load_order':2, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
-                            num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr"))
+                            num_warps=4, num_stages=1, pre_hook=init_to_zero("c_ptr"))
 
     if(compute_capability == (8, 0)): #A100
         config = triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'SPLIT_K':2, 'GROUP_SIZE_M':8,
                                 'A_load_order':0, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
-                                num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr"))
+                                num_warps=4, num_stages=1, pre_hook=init_to_zero("c_ptr"))
 
     if(compute_capability == (9, 0)): #H100
         config = triton.Config({'BLOCK_SIZE_M':16, 'BLOCK_SIZE_N':64, 'BLOCK_SIZE_K':32, 'SPLIT_K':2, 'GROUP_SIZE_M':8,
                                 'A_load_order':0, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
-                                num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr"))
+                                num_warps=4, num_stages=1, pre_hook=init_to_zero("c_ptr"))
 
     return [config]

gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py (+3 -3)
@@ -122,15 +122,15 @@ def get_autotune_config():
 
 def get_default_config():
     # #4090: default
-    config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':32, 'A_load_order':1, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
+    config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':32, 'A_load_order':1, 'meta_evict_policy':'', 'atomic_mode':'relaxed', 'dot_prod_mode':0},
                             num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr"))
 
     if(compute_capability == (8, 0)): #A100
-        config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':16, 'A_load_order':0, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
+        config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':512, 'BLOCK_SIZE_K':16, 'A_load_order':0, 'meta_evict_policy':'', 'atomic_mode':'relaxed', 'dot_prod_mode':0},
                                 num_warps=2, num_stages=2, pre_hook=init_to_zero("c_ptr"))
 
     if(compute_capability == (9, 0)): #H100
-        config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':16, 'A_load_order':1, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
+        config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':16, 'A_load_order':1, 'meta_evict_policy':'', 'atomic_mode':'relaxed', 'dot_prod_mode':0},
                                 num_warps=2, num_stages=1, pre_hook=init_to_zero("c_ptr"))
 
     return [config]
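
As a side note, here is a minimal, hypothetical sketch of how a default config list like the ones above plugs into triton.autotune. The kernel, the init_to_zero helper, and the meta-keys below are illustrative stand-ins, not the repo's actual GEMV kernel; the point is that extra dictionary keys such as 'dot_prod_mode' are forwarded to the decorated kernel as tl.constexpr parameters, just like the block sizes.

import torch
import triton
import triton.language as tl

def init_to_zero(name):
    # Stand-in for a pre_hook helper (assumption): zero the named output
    # buffer before the selected config runs the kernel.
    return lambda nargs: nargs[name].zero_()

def get_default_config():
    # Meta-keys ('BLOCK_SIZE', 'dot_prod_mode') become constexpr kernel args.
    return [triton.Config({'BLOCK_SIZE': 256, 'dot_prod_mode': 0},
                          num_warps=4, num_stages=2,
                          pre_hook=init_to_zero("y_ptr"))]

@triton.autotune(configs=get_default_config(), key=['n_elements'])
@triton.jit
def double_kernel(x_ptr, y_ptr, n_elements,
                  BLOCK_SIZE: tl.constexpr, dot_prod_mode: tl.constexpr):
    # dot_prod_mode is unused here; it only illustrates how a config key is
    # threaded through to the kernel signature.
    pid  = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, x * 2, mask=mask)

# Usage (assumes a CUDA device; BLOCK_SIZE/dot_prod_mode are supplied by the autotuner):
# x = torch.randn(4096, device='cuda'); y = torch.empty_like(x)
# double_kernel[lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)](x, y, x.numel())
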
