Commit ba2426c

support awq with qbits, only support sym (#402)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7714796 commit ba2426c

4 files changed: +226 -18 lines
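
The new CPU AWQ path is exercised through the patched transformers quantizer (see the last hunk of auto_round/auto_quantizer.py below). A minimal usage sketch, assuming a checkpoint quantized by AutoRound, saved in AWQ packing format with symmetric quantization, and intel-extension-for-transformers installed for the QBits kernels; the checkpoint path is a placeholder:

from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round.auto_quantizer import AutoHfQuantizer  # noqa: F401  # importing applies the transformers patch shown below

model_path = "./opt-125m-autoround-awq-sym"  # placeholder: AWQ-format, sym-quantized checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))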

auto_round/auto_quantizer.py (+19 -18)

@@ -368,7 +368,6 @@ def detect_device(self, target_backend, orig_backend):
             return device
         else:
             return "cpu"
-
 
     def convert_model(self, model: nn.Module):
         """Converts the given model to an AutoRound model by replacing its layers with quantized layers.
@@ -397,7 +396,7 @@ def convert_model(self, model: nn.Module):
             quantization_config.target_backend = quantization_config.backend
 
         target_device = self.detect_device(quantization_config.target_backend, quantization_config.backend)
-
+
         self.target_device = target_device
 
         if hasattr(quantization_config, "backend"): # pragma: no cover
@@ -416,7 +415,7 @@ def convert_model(self, model: nn.Module):
 
         quant_block_list = quantization_config.quant_block_list if hasattr(quantization_config,
                                                                            "quant_block_list") else None
-
+
         if quant_block_list is None:
             to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config,
                                                                                        "to_quant_block_names") else None
@@ -564,7 +563,16 @@ def remove_device_str(s, device_str):
             layer_device = get_device(layer)
 
             bias = layer.bias is not None
-            if "awq" in layer_backend:
+            from auto_round_extension.qbits.qbits_awq import QuantLinear as QBitsAWQQuantLinear
+            if "awq" in layer_backend and isinstance(QuantLinear, QBitsAWQQuantLinear):
+                new_layer = QuantLinear.from_linear( # pylint: disable=E1123
+                    layer,
+                    bits,
+                    group_size,
+                    init_only=True,
+                    has_zero_points=not sym
+                )
+            elif "awq" in layer_backend:
                 new_layer = QuantLinear.from_linear( # pylint: disable=E1123
                     layer,
                     bits,
@@ -596,23 +604,18 @@ def remove_device_str(s, device_str):
             set_module(module, layer_name, new_layer)
 
     def cpu_post_init(self, model):
-        dep_check = True
         message = "Repacking to CPU format"
+        from auto_round_extension.qbits import qbits_qlinear_classes, qbits_awq_classes
+        from auto_round_extension.ipex import ipex_qlinear_classes
+        cpu_layers = tuple(list(qbits_qlinear_classes) + list(ipex_qlinear_classes) + list(qbits_awq_classes))
         layers = [] ## ipex post_init will add one more layer
         for n, m in model.named_modules():
-            layers.append((n, m))
-
+            if isinstance(m, cpu_layers):
+                layers.append((n, m))
         for n, layer in tqdm(layers, desc=message, total=len(layers),
                              leave=True):
-            from auto_round_extension.qbits import qbits_qlinear_classes
-            from auto_round_extension.ipex import ipex_qlinear_classes
-            if isinstance(layer, qbits_qlinear_classes):
-                if dep_check:
-                    layer.req_check()
-                layer.post_init()
-                dep_check = False
-            if isinstance(layer, ipex_qlinear_classes):
-                layer.post_init()
+            layer.post_init()
+
 
         return model
 
@@ -758,5 +761,3 @@ def is_serializable(self):
 
 transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
 transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
-
-
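
The new branch above creates the QBits AWQ layer in init-only mode with has_zero_points=not sym (this commit only supports the symmetric path), and the rewritten cpu_post_init then repacks every collected CPU layer through its post_init(). A minimal sketch of that init-only construction, assuming intel-extension-for-transformers is installed; the layer shape, bits, and group size are placeholders:

import torch.nn as nn
from auto_round_extension.qbits.qbits_awq import QuantLinear as QBitsAWQQuantLinear

fp_layer = nn.Linear(4096, 4096, bias=False)   # placeholder shapes
sym = True                                     # only the symmetric scheme is supported on this path
qlayer = QBitsAWQQuantLinear.from_linear(
    fp_layer,
    4,                        # bits
    128,                      # group_size
    init_only=True,           # only allocate qweight/qzeros/scales buffers for state-dict loading
    has_zero_points=not sym,  # sym -> False, so the later repack runs without zero points
)
print(qlayer)                 # extra_repr shows in/out features, w_bit, group_size

Loading the state dict afterwards fills qweight/qzeros/scales, and cpu_post_init() triggers the actual repack.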

auto_round/backend.py (+11)

@@ -168,6 +168,14 @@ def check_auto_round_exllamav2_installed():
                                                    requirements=["intel-extension-for-transformers"]
                                                    )
 
+BackendInfos['auto_round:qbits_awq'] = BackendInfo(device=["cpu"], sym=[True],
+                                                   packing_format="awq",
+                                                   bits=[2, 4, 8], group_size=None,
+                                                   priority=0 if "intel" in get_cpu_manufacturer() else 5,
+                                                   feature_checks=[],
+                                                   requirements=["intel-extension-for-transformers"]
+                                                   )
+
 BackendInfos['auto_round:ipex_gptq'] = BackendInfo(device=["cpu"], sym=[True, False],
                                                    packing_format="ipex_gptq",
                                                    bits=[4], group_size=None,
@@ -317,6 +325,9 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym):
         if "zp" in backend:
             import auto_round_extension.qbits.qlinear_qbits_gptq as qlinear_qbits_gptq
             return qlinear_qbits_gptq.QuantLinear
+        elif "awq" in backend:
+            import auto_round_extension.qbits.qbits_awq as qlinear_qbits_awq
+            return qlinear_qbits_awq.QuantLinear
         else:  # auto_round must be at the end
             import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits_autoround
             return qlinear_qbits_autoround.QuantLinear
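
The new entry restricts the CPU AWQ backend to symmetric schemes (sym=[True]) and 2/4/8 bits with awq packing format, and ranks it higher on Intel CPUs. An illustrative check, assuming this version of auto_round is installed and that BackendInfo exposes its constructor fields as attributes; the second import mirrors the dispatch branch added above:

from auto_round.backend import BackendInfos

info = BackendInfos["auto_round:qbits_awq"]
print(info.packing_format, info.bits, info.sym)  # awq [2, 4, 8] [True] -- asymmetric AWQ is not served by this backend

import auto_round_extension.qbits.qbits_awq as qlinear_qbits_awq  # what the "awq" branch above returns
print(qlinear_qbits_awq.QuantLinear)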

auto_round_extension/qbits/__init__.py (+3)

@@ -2,5 +2,8 @@
 from auto_round_extension.qbits.qlinear_qbits_gptq import (
     QuantLinear as QBitsGPTQQuantLinear,
 )
+from auto_round_extension.qbits.qbits_awq import QuantLinear as QBitsAWQQuantLinear
 
 qbits_qlinear_classes = (QBitsQuantLinear, QBitsGPTQQuantLinear)
+
+qbits_awq_classes = (QBitsAWQQuantLinear,)

auto_round_extension/qbits/qbits_awq.py (new file, +193)

@@ -0,0 +1,193 @@
+import torch
+import torch.nn as nn
+AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
+    shifts = torch.arange(0, 32, bits, device="cpu")
+
+    # unpacking columnwise
+    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
+        torch.int8  # smallest dtype available
+    )
+    iweights = iweights.view(iweights.shape[0], -1)
+
+    # unpacking columnwise
+    if qzeros is not None:
+        izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
+            torch.int8  # smallest dtype available
+        )
+        izeros = izeros.view(izeros.shape[0], -1)
+    else:
+        izeros = qzeros
+
+    return iweights, izeros
+
+
+def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
+    reverse_order_tensor = torch.arange(
+        iweights.shape[-1],
+        dtype=torch.int32,
+        device="cpu",
+    )
+    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
+    reverse_order_tensor = reverse_order_tensor.view(-1)
+
+    if izeros is not None:
+        izeros = izeros[:, reverse_order_tensor]
+    iweights = iweights[:, reverse_order_tensor]
+    return iweights, izeros
+
+
+
+
+try:
+    from intel_extension_for_transformers import qbits  # with QBits kernels ()
+
+    QBITS_INSTALLED = True
+except:
+    QBITS_INSTALLED = False
+
+BITS_DTYPE_MAPPING = {
+    4: "int4_clip",
+    8: "int8",
+}
+
+
+def convert_dtype_torch2str(dtype):
+    if dtype == torch.int8:
+        return "int8"
+    elif dtype == torch.float:
+        return "fp32"
+    elif dtype == torch.float16:
+        return "fp16"
+    elif dtype == torch.bfloat16:
+        return "bf16"
+    elif isinstance(dtype, str) and dtype in ["int8", "fp32", "fp16", "bf16"]:
+        return dtype
+    else:
+        assert False, "Unsupported pytorch dtype {} to str dtype".format(dtype)
+
+
+class QuantLinear(nn.Module):
+
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
+        super().__init__()
+        assert QBITS_INSTALLED, \
+            "Please install ITREX qbits package with `pip install intel-extension-for-transformers`."
+
+        self.use_bf16 = qbits.check_isa_supported("AMX")
+
+        if w_bit not in [2, 3, 4, 8]:
+            raise NotImplementedError("Only 2, 3, 4, 8 bits are supported for now.")
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else in_features
+        self.zero_point = zero_point
+        self.scale_dtype = torch.float32
+
+        # quick sanity check (make sure alignment)
+        assert self.in_features % self.group_size == 0
+        assert out_features % (32 // self.w_bit) == 0
+        self.pack_num = 32 // self.w_bit
+        self.register_buffer(
+            "qzeros",
+            torch.zeros(
+                (in_features // self.group_size, out_features // self.pack_num),
+                dtype=torch.int8,
+                device=dev,
+            )
+        )
+        self.register_buffer(
+            "scales",
+            torch.zeros(
+                (in_features // self.group_size, out_features),
+                dtype=torch.bfloat16 if self.use_bf16 else torch.float32,
+                device=dev,
+            ))
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.zeros((out_features), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev),
+            )
+        else:
+            self.register_buffer(
+                "bias",
+                None,
+            )
+        qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev)
+        self.register_buffer("qweight", qweight)
+
+    def post_init(self):
+        assert self.qweight.device.type == "cpu"
+
+        intweight, zeros = unpack_awq(self.qweight, self.qzeros, self.w_bit)  # weight: k x n zeros: k / group_size x n
+        intweight, zeros = reverse_awq_order(intweight, zeros, self.w_bit)  # weight: k x n zeros: k / group_size x n
+        if self.zero_point:  ## asym has accuracy issue, have not root caused yet
+            intweight = torch.bitwise_and(intweight, (2 ** self.w_bit) - 1) - (2 ** (self.w_bit - 1))
+            zeros = torch.bitwise_and(zeros, (2 ** self.w_bit) - 1) - (2 ** (self.w_bit - 1))
+        else:
+            ##symmetric, our default zp is 8
+            intweight = torch.bitwise_and(intweight, (2 ** self.w_bit) - 1) - (2 ** (self.w_bit - 1))
+        g_idx = torch.empty(0, dtype=torch.int32)
+        self.qweight = qbits.repack_quantized_weight(intweight, self.scales.float(), zeros, g_idx,
+                                                     BITS_DTYPE_MAPPING[self.w_bit],
+                                                     convert_dtype_torch2str(self.scale_dtype),
+                                                     convert_dtype_torch2str(self.scales.dtype), self.zero_point,
+                                                     self.group_size)
+
+
+
+    @classmethod
+    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False):
+        awq_linear = cls(
+            w_bit,
+            group_size,
+            linear.in_features,
+            linear.out_features,
+            linear.bias is not None,
+            has_zero_points,
+            linear.weight.device,
+        )
+        if init_only:  # just prepare for loading sd
+            return awq_linear
+
+        raise NotImplementedError("Only inference is supported for Exllama kernels")
+
+    @torch.no_grad()
+    def forward(self, x):
+        assert QBITS_INSTALLED, (
+            "QBits kernels could not be loaded. "
+            "Please install with `pip install intel-extension-for-transformers` and "
+            "refer to the detail https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md")
+
+        input_dtype = x.dtype
+        out_shape = x.shape[:-1] + (self.out_features,)
+        x = x.view(-1, x.shape[-1])  # convert xd to 2d
+        out_2d_shape = x.shape[:-1] + (self.out_features,)
+
+        outputs = torch.zeros(out_2d_shape, dtype=input_dtype)
+        bias = self.bias if self.bias is not None else torch.empty(
+            0, dtype=torch.bfloat16 if self.use_bf16 else torch.float32)
+
+        qbits.woq_linear(x, self.qweight, bias, outputs, convert_dtype_torch2str(input_dtype),
+                         BITS_DTYPE_MAPPING[self.w_bit], convert_dtype_torch2str(self.scale_dtype), True)
+
+        return outputs.view(out_shape)
+
+    def extra_repr(self) -> str:
+        return ("in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
+            self.in_features,
+            self.out_features,
+            self.bias is not None,
+            self.w_bit,
+            self.group_size,
+        ))
+
+
+def qbits_post_init(model):
+    for _, submodule in model.named_modules():
+        if isinstance(submodule, QuantLinear):
+            submodule.post_init()
+
+    return model
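
post_init() above does the actual repack: unpack the AWQ int32 words, undo the interleaved column order, recenter the values around zero, and hand the result to qbits.repack_quantized_weight. A small worked example of the two unpacking helpers on toy data, assuming this version of auto_round is installed; shapes and values are arbitrary:

import torch
from auto_round_extension.qbits.qbits_awq import unpack_awq, reverse_awq_order

bits = 4
pack_num = 32 // bits                                   # 8 nibbles per int32 word, as in QuantLinear.pack_num
in_features, out_features = 16, 8
qweight = torch.randint(-2**31, 2**31 - 1, (in_features, out_features // pack_num), dtype=torch.int32)
qzeros = torch.zeros((1, out_features // pack_num), dtype=torch.int32)   # one group over in_features

iweights, izeros = unpack_awq(qweight, qzeros, bits)          # shift out each 4-bit field -> (16, 8)
iweights, izeros = reverse_awq_order(iweights, izeros, bits)  # undo AWQ's interleaved column order

# post_init then keeps the low `bits` bits and recenters around zero (sym default zp is 8):
intweight = torch.bitwise_and(iweights, (2 ** bits) - 1) - (2 ** (bits - 1))
print(intweight.shape, int(intweight.min()), int(intweight.max()))       # torch.Size([16, 8]), values in [-8, 7]

The repacked qweight produced from these values is what forward() later feeds to qbits.woq_linear.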
