
Commit 616af15
Commit message: fix space
Parent: a02f7b2

1 file changed, +178 -178 lines


gemlite/helper.py (+178 -178)
@@ -7,219 +7,219 @@
####################################################################################################
#16-bit activations / 8-bit weigths
class A16W8:
    def __init__(self, device='cuda:0'):
        self.device = device

    def from_weights(self, weight, bias):
        #GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32

        scales = torch.abs(weight.float()).amax(axis=1, keepdim=True) / 127.0
        W_q    = torch.round(weight / scales).to(device=self.device, dtype=torch.int8)
        scales = scales.to(device=self.device, dtype=torch.float16)

        in_features, out_features = weight.shape[::-1]

        gemlite_linear = GemLiteLinearTriton(8,
                                             group_size=in_features,
                                             in_features=in_features,
                                             out_features=out_features,
                                             input_dtype=DType.FP16,
                                             output_dtype=DType.FP16,
                                             )

        gemlite_linear.pack(W_q, scales,
                            zeros=None,
                            bias=bias.to(device=self.device, dtype=torch.float16) if bias is not None else None,
                            contiguous=False)

        gemlite_linear.W_group_mode       = 2
        gemlite_linear.channel_scale_mode = 0
        gemlite_linear.default_gemv       = 'GEMV_SPLITK'
        return gemlite_linear

    def from_linear(self, linear_layer):
        return self.from_weights(linear_layer.weight.data, linear_layer.bias.data if linear_layer.bias is not None else None)
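Usage sketch for A16W8 (illustrative, not part of the commit): it assumes gemlite is importable as gemlite.helper, that a CUDA device is available, and that the packed GemLiteLinearTriton layer can be called like a regular nn.Linear.

import torch
from gemlite.helper import A16W8

# A plain fp16 linear layer to convert (sizes are arbitrary for the sketch).
linear = torch.nn.Linear(4096, 4096, bias=True, dtype=torch.float16, device='cuda:0')

# Round the weights to int8 with one scale per output channel and pack them.
gemlite_layer = A16W8(device='cuda:0').from_linear(linear)

x = torch.randn(1, 4096, dtype=torch.float16, device='cuda:0')
y = gemlite_layer(x)  # fp16 activations, int8 weights (assumes the layer is callable)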
####################################################################################################
#8-bit dynamic activations / 8-bit weights
class A8W8_dynamic:
    def __init__(self, device='cuda:0', fp8=False, weight_scale=1.):
        self.device       = device
        self.fp8          = fp8
        self.weight_scale = weight_scale

    def from_weights(self, weight, bias):
        if(self.fp8): #FP8
            w_dtype, input_dtype, max_val = torch.float8_e4m3fn, DType.FP8, 448
        else: #INT8
            w_dtype, input_dtype, max_val = torch.int8, DType.INT8, 127

        weight = weight.float() * self.weight_scale
        scales = torch.abs(weight).amax(axis=1, keepdim=True) / max_val
        W_q    = torch.round(weight / scales).to(device=self.device, dtype=w_dtype)
        scales = scales.to(device=self.device, dtype=torch.float16)#.float()

        in_features, out_features = weight.shape[::-1]

        gemlite_linear = GemLiteLinearTriton(8,
                                             group_size=in_features,
                                             in_features=in_features,
                                             out_features=out_features,
                                             input_dtype=input_dtype,
                                             output_dtype=DType.FP16,
                                             scaled_activations=True,
                                             )

        def scale_fct(x):
            x_shape  = x.shape
            out_x    = x.view(-1, x.shape[-1])
            scaled_x = torch.abs(out_x).amax(axis=1, keepdim=True) / max_val
            out_x    = torch.round(out_x / scaled_x).to(dtype=w_dtype)
            return out_x.view(x_shape), scaled_x

        gemlite_linear.scale_activations = scale_fct

        gemlite_linear.pack(W_q, scales / self.weight_scale,
                            zeros=None,
                            bias=bias.to(device=self.device, dtype=torch.float16) if bias is not None else None,
                            contiguous=False)

        gemlite_linear.W_group_mode       = 0
        gemlite_linear.channel_scale_mode = 3 #activation[:,None] + weight[None,:]
        gemlite_linear.meta_dtype         = DType.FP32
        gemlite_linear.default_gemv       = 'GEMV_SPLITK'
        return gemlite_linear

    def from_linear(self, linear_layer):
        return self.from_weights(linear_layer.weight.data, linear_layer.bias.data if linear_layer.bias is not None else None)
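Usage sketch for A8W8_dynamic (illustrative, same assumptions as the A16W8 example): with fp8=False the weights and activations use int8, with fp8=True they use float8_e4m3fn; activations are re-quantized at runtime through scale_fct.

import torch
from gemlite.helper import A8W8_dynamic

W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda:0')

# int8 weights with one scale per output channel; pass fp8=True for the e4m3 path instead.
layer_int8 = A8W8_dynamic(device='cuda:0', fp8=False).from_weights(W, bias=None)

x = torch.randn(2, 4096, dtype=torch.float16, device='cuda:0')
y = layer_int8(x)  # inputs are dynamically quantized per row before the matmul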
class A8W8_int8_dynamic(A8W8_dynamic):
    def __init__(self, device='cuda:0', weight_scale=1.):
        super().__init__()
        self.device       = device
        self.weight_scale = weight_scale
        self.fp8          = False

class A8W8_fp8_dynamic(A8W8_dynamic):
    def __init__(self, device='cuda:0', weight_scale=1.):
        super().__init__()
        self.device       = device
        self.weight_scale = weight_scale
        self.fp8          = True
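The two subclasses only preset the fp8 flag; a short sketch (illustrative, not part of the commit) of picking one or the other:

import torch
from gemlite.helper import A8W8_int8_dynamic, A8W8_fp8_dynamic

linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16, device='cuda:0')

layer_i8 = A8W8_int8_dynamic(device='cuda:0').from_linear(linear)  # fp8=False preset
layer_f8 = A8W8_fp8_dynamic(device='cuda:0').from_linear(linear)   # fp8=True preset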
####################################################################################################
#FP16 activations / Wn packed weights
class A16Wn:
    def __init__(self, device='cuda:0', packing_bitwidth=32, post_scale=False):
        self.packing_bitwidth = 32
        self.post_scale       = post_scale
        self.device           = device

    def from_weights(self, W_q, scales, zeros, W_nbits, group_size, bias):
        in_features, out_features = W_q.shape[::-1]

        gemlite_linear = GemLiteLinearTriton(W_nbits,
                                             group_size=group_size,
                                             in_features=in_features,
                                             out_features=out_features,
                                             input_dtype=DType.FP16,
                                             output_dtype=DType.FP16,
                                             scaled_activations=False,
                                             )

        gemlite_linear.pack(W_q.to(self.device),
                            scales.to(device=self.device, dtype=torch.float16),
                            zeros.to(device=self.device, dtype=torch.float16),
                            bias=bias.to(device=self.device, dtype=torch.float16) if bias is not None else None,
                            contiguous=True,
                            packing_bitwidth=self.packing_bitwidth)

        gemlite_linear.default_gemv = 'GEMV_REVSPLITK'

        if(group_size == in_features):
            if(self.post_scale):
                gemlite_linear.W_group_mode       = 1
                gemlite_linear.channel_scale_mode = 1
            else:
                gemlite_linear.W_group_mode       = 3
                gemlite_linear.channel_scale_mode = 0

        return gemlite_linear

    def from_hqqlinear(self, hqq_layer):
        assert hqq_layer.meta['axis'] == 1, 'Only axis==1 is supported.'

        self.device = hqq_layer.W_q.device

        W_nbits    = hqq_layer.meta['nbits']
        group_size = hqq_layer.meta["group_size"]
        if(group_size is None):
            group_size = hqq_layer.in_features

        W_q    = hqq_layer.unpack(dtype=torch.uint8).view(hqq_layer.meta['shape']) #Expects uint8 for Wn quantization!
        scales = hqq_layer.meta['scale'].clone()
        zeros  = hqq_layer.meta['zero'].clone()
        bias   = hqq_layer.bias.clone() if (hqq_layer.bias is not None) else None

        gemlite_linear = self.from_weights(W_q, scales, zeros, W_nbits, group_size, bias)

        del hqq_layer.W_q
        del hqq_layer.meta
        del hqq_layer
        torch.cuda.empty_cache()

        return gemlite_linear
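Usage sketch for A16Wn.from_hqqlinear (illustrative; the hqq import path, BaseQuantizeConfig arguments, and HQQLinear signature below are assumptions about the hqq package, not something this commit defines). The layer must be quantized along axis=1, matching the assert above.

import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear  # assumed hqq API
from gemlite.helper import A16Wn

linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16)

# 4-bit, group_size=64, axis=1 quantization (axis=1 is required by from_hqqlinear).
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
hqq_layer    = HQQLinear(linear, quant_config=quant_config, compute_dtype=torch.float16, device='cuda:0')

gemlite_layer = A16Wn(post_scale=False).from_hqqlinear(hqq_layer)  # frees the hqq layer afterwards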
####################################################################################################
#FP8 dynamic activations / W4 packed weights
class A8Wn_dynamic(A16Wn):
    def __init__(self, device='cuda:0', packing_bitwidth=32, post_scale=False):
        super().__init__()
        self.packing_bitwidth = 32
        self.post_scale       = post_scale
        self.device           = device

    def from_weights(self, W_q, scales, zeros, W_nbits, group_size, bias):
        w_dtype, input_dtype, max_val = torch.float8_e4m3fn, DType.FP8, 448

        in_features, out_features = W_q.shape[::-1]

        gemlite_linear = GemLiteLinearTriton(W_nbits,
                                             group_size=group_size,
                                             in_features=in_features,
                                             out_features=out_features,
                                             input_dtype=input_dtype,
                                             output_dtype=DType.FP16,
                                             scaled_activations=True,
                                             )

        gemlite_linear.pack(W_q.to(self.device),
                            scales.to(device=self.device, dtype=torch.float16),
                            zeros.to(device=self.device, dtype=torch.float16),
                            bias=bias.to(device=self.device, dtype=torch.float16) if bias is not None else None,
                            contiguous=True,
                            packing_bitwidth=self.packing_bitwidth)

        def scale_fct(x):
            x_shape  = x.shape
            out_x    = x.view(-1, x.shape[-1])
            scaled_x = torch.abs(out_x).amax(axis=1, keepdim=True) / max_val
            out_x    = torch.round(out_x / scaled_x).to(dtype=w_dtype)
            return out_x.view(x_shape), scaled_x

        gemlite_linear.scale_activations = scale_fct

        gemlite_linear.default_gemv = 'GEMV_REVSPLITK'

        if(group_size == in_features):
            if(self.post_scale):
                gemlite_linear.W_group_mode       = 1
                gemlite_linear.channel_scale_mode = 3
            else:
                gemlite_linear.W_group_mode       = 3
                gemlite_linear.channel_scale_mode = 2

        return gemlite_linear
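A8Wn_dynamic is intended as a drop-in for A16Wn when the activations should also be quantized to FP8 on the fly; a sketch under the same assumptions as the A16Wn example (the hqq calls are again assumed, not defined by this commit):

import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear  # assumed hqq API
from gemlite.helper import A8Wn_dynamic

linear    = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16)
hqq_layer = HQQLinear(linear, quant_config=BaseQuantizeConfig(nbits=4, group_size=64, axis=1),
                      compute_dtype=torch.float16, device='cuda:0')

# Same packing path as A16Wn, but inputs are cast to float8_e4m3fn per row at runtime.
layer_a8w4 = A8Wn_dynamic(post_scale=False).from_hqqlinear(hqq_layer)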
