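#NB: this section assumes `torch` and the GemLite classes (`GemLiteLinearTriton`, `DType`) are imported
#earlier in the file, e.g. `from gemlite.core import GemLiteLinearTriton, DType` (import path assumed).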
####################################################################################################
#16-bit activations / 8-bit weights
class A16W8:
    def __init__(self, device='cuda:0'):
        self.device = device

    def from_weights(self, weight, bias):
        #GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32

        scales = torch.abs(weight.float()).amax(axis=1, keepdim=True) / 127.0
        W_q = torch.round(weight / scales).to(device=self.device, dtype=torch.int8)
        scales = scales.to(device=self.device, dtype=torch.float16)

        in_features, out_features = weight.shape[::-1]

        gemlite_linear = GemLiteLinearTriton(8,
                                             group_size=in_features,
                                             in_features=in_features,
                                             out_features=out_features,
                                             input_dtype=DType.FP16,
                                             output_dtype=DType.FP16,
                                             )

        gemlite_linear.pack(W_q, scales,
                            zeros=None,
                            bias=bias.to(device=self.device, dtype=torch.float16) if bias is not None else None,
                            contiguous=False)

        gemlite_linear.W_group_mode = 2
        gemlite_linear.channel_scale_mode = 0
        gemlite_linear.default_gemv = 'GEMV_SPLITK'
        return gemlite_linear

    def from_linear(self, linear_layer):
        return self.from_weights(linear_layer.weight.data, linear_layer.bias.data if linear_layer.bias is not None else None)

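#Usage sketch (illustrative only; `linear` is a placeholder torch.nn.Linear on the target GPU):
#   gemlite_layer = A16W8(device='cuda:0').from_linear(linear)
#   #the returned GemLiteLinearTriton module is meant to replace `linear` in the model
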
####################################################################################################
#8-bit dynamic activations / 8-bit weights
class A8W8_dynamic:
    def __init__(self, device='cuda:0', fp8=False, weight_scale=1.):
        self.device = device
        self.fp8 = fp8
        self.weight_scale = weight_scale

    def from_weights(self, weight, bias):
        if(self.fp8): #FP8
            w_dtype, input_dtype, max_val = torch.float8_e4m3fn, DType.FP8, 448
        else: #INT8
            w_dtype, input_dtype, max_val = torch.int8, DType.INT8, 127

        weight = weight.float() * self.weight_scale
        scales = torch.abs(weight).amax(axis=1, keepdim=True) / max_val
        W_q = torch.round(weight / scales).to(device=self.device, dtype=w_dtype)
        scales = scales.to(device=self.device, dtype=torch.float16) #.float()

        in_features, out_features = weight.shape[::-1]

        gemlite_linear = GemLiteLinearTriton(8,
                                             group_size=in_features,
                                             in_features=in_features,
                                             out_features=out_features,
                                             input_dtype=input_dtype,
                                             output_dtype=DType.FP16,
                                             scaled_activations=True,
                                             )

        def scale_fct(x):
            x_shape = x.shape
            out_x = x.view(-1, x.shape[-1])
            scaled_x = torch.abs(out_x).amax(axis=1, keepdim=True) / max_val
            out_x = torch.round(out_x / scaled_x).to(dtype=w_dtype)
            return out_x.view(x_shape), scaled_x

        gemlite_linear.scale_activations = scale_fct

        gemlite_linear.pack(W_q, scales / self.weight_scale,
                            zeros=None,
                            bias=bias.to(device=self.device, dtype=torch.float16) if bias is not None else None,
                            contiguous=False)

        gemlite_linear.W_group_mode = 0
        gemlite_linear.channel_scale_mode = 3 #activation[:,None] + weight[None,:]
        gemlite_linear.meta_dtype = DType.FP32
        gemlite_linear.default_gemv = 'GEMV_SPLITK'
        return gemlite_linear

    def from_linear(self, linear_layer):
        return self.from_weights(linear_layer.weight.data, linear_layer.bias.data if linear_layer.bias is not None else None)

class A8W8_int8_dynamic(A8W8_dynamic):
    def __init__(self, device='cuda:0', weight_scale=1.):
        super().__init__()
        self.device = device
        self.weight_scale = weight_scale
        self.fp8 = False

class A8W8_fp8_dynamic(A8W8_dynamic):
    def __init__(self, device='cuda:0', weight_scale=1.):
        super().__init__()
        self.device = device
        self.weight_scale = weight_scale
        self.fp8 = True

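#Usage sketch (illustrative only; `linear` is a placeholder torch.nn.Linear):
#   gemlite_layer = A8W8_int8_dynamic(device='cuda:0').from_linear(linear) #int8 activations / int8 weights
#   gemlite_layer = A8W8_fp8_dynamic(device='cuda:0').from_linear(linear)  #fp8 activations / fp8 weights (assumes an fp8-capable GPU)
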
####################################################################################################
#FP16 activations / Wn packed weights
class A16Wn:
    def __init__(self, device='cuda:0', packing_bitwidth=32, post_scale=False):
        self.packing_bitwidth = packing_bitwidth
        self.post_scale = post_scale
        self.device = device

    def from_weights(self, W_q, scales, zeros, W_nbits, group_size, bias):
        in_features, out_features = W_q.shape[::-1]

        gemlite_linear = GemLiteLinearTriton(W_nbits,
                                             group_size=group_size,
                                             in_features=in_features,
                                             out_features=out_features,
                                             input_dtype=DType.FP16,
                                             output_dtype=DType.FP16,
                                             scaled_activations=False,
                                             )

        gemlite_linear.pack(W_q.to(self.device),
                            scales.to(device=self.device, dtype=torch.float16),
                            zeros.to(device=self.device, dtype=torch.float16),
                            bias=bias.to(device=self.device, dtype=torch.float16) if bias is not None else None,
                            contiguous=True,
                            packing_bitwidth=self.packing_bitwidth)

        gemlite_linear.default_gemv = 'GEMV_REVSPLITK'

        if(group_size == in_features):
            if(self.post_scale):
                gemlite_linear.W_group_mode = 1
                gemlite_linear.channel_scale_mode = 1
            else:
                gemlite_linear.W_group_mode = 3
                gemlite_linear.channel_scale_mode = 0

        return gemlite_linear

    def from_hqqlinear(self, hqq_layer):
        assert hqq_layer.meta['axis'] == 1, 'Only axis==1 is supported.'

        self.device = hqq_layer.W_q.device

        W_nbits = hqq_layer.meta['nbits']
        group_size = hqq_layer.meta["group_size"]
        if(group_size is None):
            group_size = hqq_layer.in_features

        W_q = hqq_layer.unpack(dtype=torch.uint8).view(hqq_layer.meta['shape']) #Expects uint8 for Wn quantization!
        scales = hqq_layer.meta['scale'].clone()
        zeros = hqq_layer.meta['zero'].clone()
        bias = hqq_layer.bias.clone() if (hqq_layer.bias is not None) else None

        gemlite_linear = self.from_weights(W_q, scales, zeros, W_nbits, group_size, bias)

        del hqq_layer.W_q
        del hqq_layer.meta
        del hqq_layer
        torch.cuda.empty_cache()

        return gemlite_linear

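#Usage sketch (illustrative only; `hqq_layer` is a placeholder hqq HQQLinear layer quantized along axis=1):
#   gemlite_layer = A16Wn(post_scale=False).from_hqqlinear(hqq_layer)
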
####################################################################################################
#FP8 dynamic activations / W4 packed weights
class A8Wn_dynamic(A16Wn):
    def __init__(self, device='cuda:0', packing_bitwidth=32, post_scale=False):
        super().__init__()
        self.packing_bitwidth = packing_bitwidth
        self.post_scale = post_scale
        self.device = device

    def from_weights(self, W_q, scales, zeros, W_nbits, group_size, bias):
        w_dtype, input_dtype, max_val = torch.float8_e4m3fn, DType.FP8, 448

        in_features, out_features = W_q.shape[::-1]

        gemlite_linear = GemLiteLinearTriton(W_nbits,
                                             group_size=group_size,
                                             in_features=in_features,
                                             out_features=out_features,
                                             input_dtype=input_dtype,
                                             output_dtype=DType.FP16,
                                             scaled_activations=True,
                                             )

        gemlite_linear.pack(W_q.to(self.device),
                            scales.to(device=self.device, dtype=torch.float16),
                            zeros.to(device=self.device, dtype=torch.float16),
                            bias=bias.to(device=self.device, dtype=torch.float16) if bias is not None else None,
                            contiguous=True,
                            packing_bitwidth=self.packing_bitwidth)

        def scale_fct(x):
            x_shape = x.shape
            out_x = x.view(-1, x.shape[-1])
            scaled_x = torch.abs(out_x).amax(axis=1, keepdim=True) / max_val
            out_x = torch.round(out_x / scaled_x).to(dtype=w_dtype)
            return out_x.view(x_shape), scaled_x

        gemlite_linear.scale_activations = scale_fct

        gemlite_linear.default_gemv = 'GEMV_REVSPLITK'

        if(group_size == in_features):
            if(self.post_scale):
                gemlite_linear.W_group_mode = 1
                gemlite_linear.channel_scale_mode = 3
            else:
                gemlite_linear.W_group_mode = 3
                gemlite_linear.channel_scale_mode = 2

        return gemlite_linear
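
#Usage sketch (illustrative only; reuses A16Wn.from_hqqlinear with a placeholder hqq HQQLinear layer,
#but the activations are dynamically quantized to fp8 at runtime):
#   gemlite_layer = A8Wn_dynamic(post_scale=False).from_hqqlinear(hqq_layer)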