import torch
import torch.nn as nn

# AWQ packs eight 4-bit values per int32 in the interleaved order
# [0, 2, 4, 6, 1, 3, 5, 7]; indexing with this permutation restores the
# natural column order after unpacking.
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
    shifts = torch.arange(0, 32, bits, device="cpu")

    # unpack the weights columnwise: each int32 holds 32 // bits values
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8  # smallest dtype available
    )
    iweights = iweights.view(iweights.shape[0], -1)

    # unpack the zero points columnwise in the same way
    if qzeros is not None:
        izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
            torch.int8  # smallest dtype available
        )
        izeros = izeros.view(izeros.shape[0], -1)
    else:
        izeros = qzeros

    return iweights, izeros

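# Shape sketch for unpack_awq, assuming 4-bit packing (illustrative only,
# not part of the original module): eight values per int32 column, so the
# unpacked tensor is 8x wider. High bits still hold packing garbage until
# they are masked in post_init; only the shape is checked here.
#
#   qweight = torch.randint(
#       torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max,
#       (512, 512 // 8), dtype=torch.int32,
#   )
#   iweights, izeros = unpack_awq(qweight, None, bits=4)
#   assert iweights.shape == (512, 512)  # one int8 per original nibble
#   assert izeros is None

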
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
    reverse_order_tensor = torch.arange(
        iweights.shape[-1],
        dtype=torch.int32,
        device="cpu",
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    if izeros is not None:
        izeros = izeros[:, reverse_order_tensor]
    iweights = iweights[:, reverse_order_tensor]
    return iweights, izeros

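# Worked example of the column permutation for bits=4 (a sketch for clarity):
#   torch.arange(16).view(-1, 8)  ->  [[0, 1, ..., 7], [8, 9, ..., 15]]
#   [:, AWQ_REVERSE_ORDER]        ->  [[0, 4, 1, 5, 2, 6, 3, 7],
#                                      [8, 12, 9, 13, 10, 14, 11, 15]]
#   .view(-1) flattens this into the index applied to iweights and izeros.

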
try:
    from intel_extension_for_transformers import qbits  # QBits CPU kernels

    QBITS_INSTALLED = True
except ImportError:
    QBITS_INSTALLED = False

BITS_DTYPE_MAPPING = {
    4: "int4_clip",
    8: "int8",
}

def convert_dtype_torch2str(dtype):
    if dtype == torch.int8:
        return "int8"
    elif dtype == torch.float:
        return "fp32"
    elif dtype == torch.float16:
        return "fp16"
    elif dtype == torch.bfloat16:
        return "bf16"
    elif isinstance(dtype, str) and dtype in ["int8", "fp32", "fp16", "bf16"]:
        return dtype
    else:
        raise ValueError("Unsupported pytorch dtype {} to str dtype".format(dtype))

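# Example: convert_dtype_torch2str(torch.bfloat16) returns "bf16"; strings
# already in the supported set pass through unchanged.

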
class QuantLinear(nn.Module):

    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
        super().__init__()
        assert QBITS_INSTALLED, \
            "Please install ITREX qbits package with `pip install intel-extension-for-transformers`."

        self.use_bf16 = qbits.check_isa_supported("AMX")

        if w_bit not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2, 3, 4, 8 bits are supported for now.")

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
        self.zero_point = zero_point
        self.scale_dtype = torch.float32

        # quick sanity checks to make sure the shapes are aligned
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
        self.pack_num = 32 // self.w_bit

        self.register_buffer(
            "qzeros",
            torch.zeros(
                (in_features // self.group_size, out_features // self.pack_num),
                dtype=torch.int8,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (in_features // self.group_size, out_features),
                dtype=torch.bfloat16 if self.use_bf16 else torch.float32,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros((out_features,), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev),
            )
        else:
            self.register_buffer(
                "bias",
                None,
            )
        qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev)
        self.register_buffer("qweight", qweight)

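    # Buffer shapes for, e.g., in_features=4096, out_features=4096, w_bit=4,
    # group_size=128 (illustrative numbers, not from the original code):
    #   qweight: (4096, 512)  int32, eight 4-bit values per element
    #   qzeros:  (32, 512)    one row of packed zero points per group
    #   scales:  (32, 4096)   bf16 when AMX is supported, fp32 otherwise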
    def post_init(self):
        assert self.qweight.device.type == "cpu"

        # weight: k x n, zeros: k / group_size x n
        intweight, zeros = unpack_awq(self.qweight, self.qzeros, self.w_bit)
        intweight, zeros = reverse_awq_order(intweight, zeros, self.w_bit)

        # mask to the low w_bit bits, then shift from the unsigned range
        # [0, 2 ** w_bit) to the signed range centered on 2 ** (w_bit - 1)
        if self.zero_point:  # asym has an accuracy issue that has not been root-caused yet
            intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1))
            zeros = torch.bitwise_and(zeros, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1))
        else:
            # symmetric, our default zero point is 8
            intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - (2 ** (self.w_bit - 1))
        g_idx = torch.empty(0, dtype=torch.int32)

        self.qweight = qbits.repack_quantized_weight(
            intweight, self.scales.float(), zeros, g_idx,
            BITS_DTYPE_MAPPING[self.w_bit],
            convert_dtype_torch2str(self.scale_dtype),
            convert_dtype_torch2str(self.scales.dtype),
            self.zero_point,
            self.group_size,
        )

    @classmethod
    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            has_zero_points,
            linear.weight.device,
        )
        if init_only:  # just prepare for loading the state dict
            return awq_linear

        raise NotImplementedError("Only inference is supported for QBits kernels")

    @torch.no_grad()
    def forward(self, x):
        assert QBITS_INSTALLED, (
            "QBits kernels could not be loaded. "
            "Please install with `pip install intel-extension-for-transformers` and "
            "refer to the details in https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md"
        )

        input_dtype = x.dtype
        out_shape = x.shape[:-1] + (self.out_features,)
        x = x.view(-1, x.shape[-1])  # flatten the leading dimensions to make x 2D
        out_2d_shape = x.shape[:-1] + (self.out_features,)

        outputs = torch.zeros(out_2d_shape, dtype=input_dtype)
        bias = self.bias if self.bias is not None else torch.empty(
            0, dtype=torch.bfloat16 if self.use_bf16 else torch.float32)

        qbits.woq_linear(x, self.qweight, bias, outputs, convert_dtype_torch2str(input_dtype),
                         BITS_DTYPE_MAPPING[self.w_bit], convert_dtype_torch2str(self.scale_dtype), True)

        return outputs.view(out_shape)

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
            self.in_features,
            self.out_features,
            self.bias is not None,
            self.w_bit,
            self.group_size,
        )


def qbits_post_init(model):
    # repack every QuantLinear in the model for the QBits runtime
    for _, submodule in model.named_modules():
        if isinstance(submodule, QuantLinear):
            submodule.post_init()

    return model
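

# Minimal usage sketch (illustrative, not part of the original file; assumes
# the ITREX qbits package is installed, since __init__ asserts on it):
if __name__ == "__main__":
    linear = nn.Linear(256, 256, bias=True)
    # init_only=True just allocates the packed buffers for state-dict loading
    qlinear = QuantLinear.from_linear(linear, w_bit=4, group_size=32, init_only=True)
    print(qlinear)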