Skip to content

Commit 096233a

Browse files
committed
issue/170: add signed quant
1 parent a5d1924 commit 096233a

File tree

2 files changed

+71
-33
lines changed

2 files changed

+71
-33
lines changed

src/infiniop/ops/quantize_gptq/cpu/quantize_gptq_cpu.cc

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,25 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle_, Descriptor **desc_pt
4242
return INFINI_STATUS_SUCCESS;
4343
}
4444

45-
float quantize(float x, float s, float z, float maxq) {
45+
float quantize(float x, float s, float z, float minq, float maxq) {
4646
float q = std::roundf(x / s + z);
47-
q = std::max(0.0f, std::min(maxq, q));
47+
q = std::max(minq, std::min(maxq, q));
4848
return s * (q - z);
4949
}
5050

5151
template <typename T>
5252
void find_params(T *x, T *b_scale, T *zero, int N, int K,
5353
int bits = 4, bool sym = false, bool mse = false,
54-
float norm = 2.4f, int grid = 100, float maxshrink = 0.8f) {
55-
float maxq = static_cast<float>(std::pow(2, bits) - 1);
54+
float norm = 2.4f, int grid = 100, float maxshrink = 0.8f, bool sign_ed = false) {
55+
float maxq;
56+
float minq;
57+
if (sign_ed) { // 如果有符号量化
58+
maxq = static_cast<float>(std::pow(2, bits - 1) - 1);
59+
minq = -static_cast<float>(std::pow(2, bits - 1));
60+
} else {
61+
maxq = static_cast<float>(std::pow(2, bits) - 1);
62+
minq = 0.0f;
63+
}
5664
#pragma omp parallel for
5765
for (int n = 0; n < N; n++) {
5866
float x_min = FLT_MAX;
@@ -76,16 +84,16 @@ void find_params(T *x, T *b_scale, T *zero, int N, int K,
7684
x_max = 1;
7785
}
7886
if constexpr (std::is_same<T, fp16_t>::value) {
79-
b_scale[n] = utils::cast<fp16_t>((x_max - x_min) / maxq);
87+
b_scale[n] = utils::cast<fp16_t>((x_max - x_min) / (maxq - minq));
8088
if (sym) {
81-
zero[n] = utils::cast<fp16_t>((maxq + 1.0f) * 0.5f);
89+
zero[n] = utils::cast<fp16_t>((maxq + minq + 1.0f) * 0.5f);
8290
} else {
83-
zero[n] = utils::cast<fp16_t>(-x_min * maxq / (x_max - x_min));
91+
zero[n] = utils::cast<fp16_t>(-x_min * (maxq - minq) / (x_max - x_min));
8492
}
8593
} else if constexpr (std::is_same<T, float>::value) {
86-
b_scale[n] = (x_max - x_min) / maxq;
94+
b_scale[n] = (x_max - x_min) / (maxq - minq);
8795
if (sym) {
88-
zero[n] = (maxq + 1.0f) * 0.5f;
96+
zero[n] = (maxq + minq + 1.0f) * 0.5f;
8997
} else {
9098
zero[n] = -x_min / b_scale[n];
9199
}
@@ -96,11 +104,11 @@ void find_params(T *x, T *b_scale, T *zero, int N, int K,
96104
float p = 1 - static_cast<float>(i) / static_cast<float>(grid);
97105
float x_min_1 = p * x_min;
98106
float x_max_1 = p * x_max;
99-
float scale_1 = (x_max_1 - x_min_1) / maxq;
107+
float scale_1 = (x_max_1 - x_min_1) / (maxq - minq);
100108
float zero_1 = (sym ? utils::cast<float>(zero[n]) : std::roundf(-x_min_1 / scale_1));
101109
float err = 0.0f;
102110
for (int k = 0; k < K; k++) {
103-
float q = quantize(utils::cast<float>(x[n * K + k]), scale_1, zero_1, maxq);
111+
float q = quantize(utils::cast<float>(x[n * K + k]), scale_1, zero_1, minq, maxq);
104112
q -= utils::cast<float>(x[n * K + k]);
105113
q = std::abs(q);
106114
q = static_cast<float>(std::pow(q, norm));
@@ -344,12 +352,20 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess,
344352
int M, int K, int N,
345353
int block_size = 128, float percdamp = 0.01, int group_size = -1,
346354
int bits = 4, bool sym = false, bool mse = false,
347-
float norm = 2.4, int grid = 100, float maxshrink = 0.8) {
348-
float maxq = static_cast<float>(std::pow(2, bits) - 1);
355+
float norm = 2.4, int grid = 100, float maxshrink = 0.8, bool sign_ed = false) {
356+
float maxq;
357+
float minq;
358+
if (sign_ed) { // 如果有符号量化
359+
maxq = static_cast<float>(std::pow(2, bits - 1) - 1);
360+
minq = -static_cast<float>(std::pow(2, bits - 1));
361+
} else {
362+
maxq = static_cast<float>(std::pow(2, bits) - 1);
363+
minq = 0.0f;
364+
}
349365
int num_groups = (group_size == -1 ? 1 : K / group_size);
350366

351367
if (group_size == -1) {
352-
find_params(weight, b_scale, zero, N, K, bits, sym, mse, norm, grid, maxshrink);
368+
find_params(weight, b_scale, zero, N, K, bits, sym, mse, norm, grid, maxshrink, sign_ed);
353369
}
354370
float damp = 0.0f;
355371

@@ -388,13 +404,13 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess,
388404
if ((index * block_size + i) % group_size == 0) {
389405
int ind = (index * block_size + i) / group_size;
390406
for (int n = 0; n < N; n++) {
391-
find_params(&weight[n * K + index * block_size + i], &b_scale[n * num_groups + ind], &zero[n * num_groups + ind], 1, group_size, bits, sym, mse, norm, grid, maxshrink);
407+
find_params(&weight[n * K + index * block_size + i], &b_scale[n * num_groups + ind], &zero[n * num_groups + ind], 1, group_size, bits, sym, mse, norm, grid, maxshrink, sign_ed);
392408
}
393409
}
394410
}
395411
int ind = (group_size != -1 ? (index * block_size + i) / group_size : 0);
396412
for (int n = 0; n < N; n++) {
397-
float q = quantize(utils::cast<float>(weight[n * K + index * block_size + i]), utils::cast<float>(b_scale[n * num_groups + ind]), utils::cast<float>(zero[n * num_groups + ind]), maxq);
413+
float q = quantize(utils::cast<float>(weight[n * K + index * block_size + i]), utils::cast<float>(b_scale[n * num_groups + ind]), utils::cast<float>(zero[n * num_groups + ind]), minq, maxq);
398414
if constexpr (std::is_same<T, fp16_t>::value) {
399415
Q[n * K + index * block_size + i] = utils::cast<fp16_t>(q);
400416
} else if constexpr (std::is_same<T, float>::value) {
@@ -435,8 +451,16 @@ void fasterquant(T *weight, T *Q, float *Err, T *b_scale, T *zero, float *Hess,
435451
}
436452

437453
void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero,
438-
int32_t *packed_weight, int K, int N, int group_size, int bits = 4) {
439-
int maxq = int(std::pow(2, bits) - 1);
454+
int32_t *packed_weight, int K, int N, int group_size, int bits = 4, bool sign_ed = false) {
455+
int maxq;
456+
int minq;
457+
if (sign_ed) { // 如果有符号量化
458+
maxq = int(std::pow(2, bits - 1) - 1);
459+
minq = -int(std::pow(2, bits - 1));
460+
} else {
461+
maxq = int(std::pow(2, bits) - 1);
462+
minq = 0;
463+
}
440464
int num_groups = (group_size == -1) ? 1 : K / group_size;
441465
int blocks_per_group = (group_size == -1) ? K / 8 : group_size / 8;
442466

@@ -458,7 +482,7 @@ void PackQuantizedWeight(fp16_t *Q, fp16_t *b_scale, fp16_t *zero,
458482
int k = row_base + i;
459483
float val = utils::cast<float>(Q[n * K + k]); // Q: [N, K]
460484
int q = static_cast<int>(std::roundf(val / scale + zero_f));
461-
q = std::max(0, std::min(maxq, q)); // clamp to [0, maxq]
485+
q = std::max(minq, std::min(maxq, q)); // clamp to [minq, maxq]
462486
packed |= (q & 0xF) << (i * 4);
463487
}
464488

@@ -518,6 +542,7 @@ void quantWeights(void *workspace, int32_t *packed_weights,
518542
int grid = 100;
519543
float maxshrink = 0.8f;
520544
float nsamples = 0.0f;
545+
bool sign_ed = false;
521546

522547
char *tmp = (char *)workspace + (K * K + N * block_size) * sizeof(float);
523548
float *Hess = (float *)workspace; //[K, K]
@@ -535,9 +560,9 @@ void quantWeights(void *workspace, int32_t *packed_weights,
535560
M, K, N,
536561
block_size, percdamp, group_size,
537562
bits, sym, mse,
538-
norm, grid, maxshrink);
563+
norm, grid, maxshrink, sign_ed);
539564

540-
PackQuantizedWeight(Q, b_scale, zero, packed_weights, K, N, group_size, bits);
565+
PackQuantizedWeight(Q, b_scale, zero, packed_weights, K, N, group_size, bits, sign_ed);
541566
}
542567

543568
void caculate(void *workspace, fp16_t *C, const fp16_t *A,

test/infiniop/quantize_gptq.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ class QuantizeGPTQDescriptor(Structure):
6464
infiniopQuantizeGPTQDescriptor_t = POINTER(QuantizeGPTQDescriptor)
6565

6666

67-
def quantize(x, scale, zero, maxq):
67+
def quantize(x, scale, zero, minq, maxq):
6868
if scale.shape[1] == 1:
69-
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
69+
q = torch.clamp(torch.round(x / scale) + zero, minq, maxq)
7070
return scale * (q - zero)
7171
else:
7272
group_size = x.shape[1] // scale.shape[1]
@@ -77,7 +77,7 @@ def quantize(x, scale, zero, maxq):
7777
x[:, j * group_size : (j + 1) * group_size] / scale[:, j : j + 1]
7878
)
7979
+ zero[:, j : j + 1],
80-
0,
80+
minq,
8181
maxq,
8282
)
8383
y[:, j * group_size : (j + 1) * group_size] = scale[:, j : j + 1] * (
@@ -91,6 +91,7 @@ class Quantizer(nn.Module):
9191
def __init__(self, shape=1):
9292
super(Quantizer, self).__init__()
9393
self.register_buffer("maxq", torch.tensor(0))
94+
self.register_buffer("minq", torch.tensor(0))
9495
self.register_buffer("scale", torch.zeros(shape))
9596
self.register_buffer("zero", torch.zeros(shape))
9697

@@ -103,8 +104,14 @@ def configure(
103104
norm=2.4,
104105
grid=100,
105106
maxshrink=0.8,
107+
sign_ed=False,
106108
):
107-
self.maxq = torch.tensor(2**bits - 1)
109+
if sign_ed: # 有符号量化,范围是[-8,7]
110+
self.maxq = torch.tensor(2 ** (bits - 1) - 1)
111+
self.minq = -torch.tensor(2 ** (bits - 1))
112+
else: # 无符号量化,范围是[0,15]
113+
self.maxq = torch.tensor(2**bits - 1)
114+
self.minq = -torch.tensor(0)
108115
self.perchannel = perchannel
109116
self.sym = sym
110117
self.mse = mse
@@ -115,6 +122,7 @@ def configure(
115122
def find_params(self, x, weight=False):
116123
dev = x.device
117124
self.maxq = self.maxq.to(dev)
125+
self.minq = self.minq.to(dev)
118126

119127
shape = x.shape
120128
if self.perchannel:
@@ -139,9 +147,9 @@ def find_params(self, x, weight=False):
139147
xmin[tmp] = -1
140148
xmax[tmp] = +1
141149

142-
self.scale = (xmax - xmin) / self.maxq
150+
self.scale = (xmax - xmin) / (self.maxq - self.minq)
143151
if self.sym:
144-
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
152+
self.zero = torch.full_like(self.scale, (self.maxq + self.minq + 1) / 2)
145153
else:
146154
self.zero = torch.round(-xmin / self.scale)
147155

@@ -151,9 +159,11 @@ def find_params(self, x, weight=False):
151159
p = 1 - i / self.grid
152160
xmin1 = p * xmin
153161
xmax1 = p * xmax
154-
scale1 = (xmax1 - xmin1) / self.maxq
162+
scale1 = (xmax1 - xmin1) / (self.maxq - self.minq)
155163
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
156-
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
164+
q = quantize(
165+
x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.minq, self.maxq
166+
)
157167
q -= x
158168
q.abs_()
159169
q.pow_(self.norm)
@@ -190,7 +200,7 @@ def find_params(self, x, weight=False):
190200

191201
def quantize(self, x):
192202
if self.ready():
193-
return quantize(x, self.scale, self.zero, self.maxq)
203+
return quantize(x, self.scale, self.zero, self.minq, self.maxq)
194204
return x
195205

196206
def enabled(self):
@@ -292,6 +302,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1):
292302
w.unsqueeze(1),
293303
self.quantizer.scale,
294304
self.quantizer.zero,
305+
self.quantizer.minq,
295306
self.quantizer.maxq,
296307
).flatten()
297308
Q1[:, i] = q
@@ -313,13 +324,13 @@ def fasterquant(self, blocksize=128, percdamp=0.01, group_size=-1):
313324
self.zero = zero.to(self.weight.dtype)
314325

315326

316-
def get_scale_zero(b, a, c, group_size):
327+
def get_scale_zero(b, a, c, group_size, sign_ed):
317328
weight = b.clone()
318329
inp = a.clone()
319330
out = c.clone()
320331
gptq = GPTQ(weight)
321332
gptq.quantizer = Quantizer()
322-
gptq.quantizer.configure(perchannel=True, sym=False, mse=False)
333+
gptq.quantizer.configure(perchannel=True, sym=False, mse=False, signed=sign_ed)
323334
gptq.add_batch(inp, out)
324335
gptq.fasterquant(group_size=group_size)
325336

@@ -383,7 +394,9 @@ def test(
383394
s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
384395
z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
385396
if torch_device == "cuda":
386-
b_ref, s, z = get_scale_zero(b, a.t(), c, group_size)
397+
b_ref, s, z = get_scale_zero(
398+
b, a.t(), c, group_size, signed=False
399+
) # 无符号量化
387400
z = torch.zeros_like(s)
388401
packed_weights = pack(b_ref, s, z)
389402
# print(s)

0 commit comments

Comments
 (0)