
Commit 830daeb

issue/170: success marlin
1 parent c909d0e commit 830daeb

File tree

1 file changed: +181 −34 lines


test/infiniop/quantize_gptq.py

Lines changed: 181 additions & 34 deletions
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import numpy as np
 import math
 import ctypes
 from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
@@ -354,9 +355,170 @@ def pack(weight, scale, zero, minq, maxq):
     return qweight


+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm = perm.reshape((-1, 8))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm, scale_perm, scale_perm_single
+
+
+_perm, _scale_perm, _scale_perm_single = _get_perms()
+
+
+class MarlinLayer(nn.Module):
+    """PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias."""
+
+    def __init__(self, infeatures, outfeatures, groupsize=-1):
+        """Create an empty Marlin layer.
+        @infeatures: number of input features (must be divisible by 128)
+        @outfeatures: number of output features (must be divisible by 256)
+        @groupsize: quantization groupsize (must be -1 or 128)
+        """
+        super().__init__()
+        if groupsize not in [-1, 128]:
+            raise ValueError("Only groupsize -1 and 128 are supported.")
+        if infeatures % 128 != 0 or outfeatures % 256 != 0:
+            raise ValueError(
+                "`infeatures` must be divisible by 128 and `outfeatures` by 256."
+            )
+        if groupsize == -1:
+            groupsize = infeatures
+        if infeatures % groupsize != 0:
+            raise ValueError("`infeatures` must be divisible by `groupsize`.")
+        self.k = infeatures
+        self.n = outfeatures
+        self.groupsize = groupsize
+        self.register_buffer(
+            "B", torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int)
+        )
+        self.register_buffer(
+            "s", torch.empty((self.k // groupsize, self.n), dtype=torch.half)
+        )
+
+    def forward(self, A):
+        C = torch.empty(
+            A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device
+        )
+        marlin_matmul(
+            A.view((-1, A.shape[-1])),
+            self.B,
+            C.view((-1, C.shape[-1])),
+            self.s,
+        )
+        return C
+
+    def pack(self, linear, scales):
+        """Pack a fake-quantized linear layer into this actual Marlin representation.
+        @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`)
+        @scales: corresponding quantization scales of shape `(infeatures, groups)`
+        """
+        if linear.weight.dtype != torch.half:
+            raise ValueError("Only `torch.half` weights are supported.")
+        tile = 16
+        maxq = 2**4 - 1
+        s = scales.t()
+        w = linear.weight.data.t()
+        if self.groupsize != self.k:
+            w = w.reshape((-1, self.groupsize, self.n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((self.groupsize, -1))
+            s = s.reshape((1, -1))
+        w = torch.round(w / s).int()
+        w += (maxq + 1) // 2
+        w = torch.clamp(w, 0, maxq)
+        if self.groupsize != self.k:
+            w = w.reshape((self.groupsize, -1, self.n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((self.k, self.n)).contiguous()
+            s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+        else:
+            s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
+        s = s.reshape((-1, self.n)).contiguous()
+        w = w.reshape((self.k // tile, tile, self.n // tile, tile))
+        w = w.permute((0, 2, 1, 3))
+        w = w.reshape((self.k // tile, self.n * tile))
+        res = w
+        res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
+        q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
+        res = res.cpu().numpy().astype(np.uint32)
+        for i in range(8):
+            q |= res[:, i::8] << 4 * i
+        q = torch.from_numpy(q.astype(np.int32)).to(w.device)
+        self.B[:, :] = q.to(self.B.device)
+        self.s[:, :] = s.to(self.s.device)
+
+
+def gen_quant4(m, n, groupsize=-1):
+    DEV = torch.device("cuda:0")
+    tile = 16
+    maxq = 2**4 - 1
+    w = torch.randn((m, n), dtype=torch.half, device=DEV)
+    if groupsize != -1:
+        w = w.reshape((-1, groupsize, n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((groupsize, -1))
+    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
+    s *= 2 / maxq
+    w = torch.round(w / s).int()
+    w += (maxq + 1) // 2
+    w = torch.clamp(w, 0, maxq)
+    ref = (w - (maxq + 1) // 2).half() * s
+    if groupsize != -1:
+
+        def reshape(w):
+            w = w.reshape((groupsize, -1, n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((m, n)).contiguous()
+            return w
+
+        ref = reshape(ref)
+        w = reshape(w)
+    s = s.reshape((-1, n)).contiguous()
+    linear = nn.Linear(m, n)
+    linear.weight.data = ref.t()
+    # Workaround to test some special cases that are forbidden by the API
+    layer = MarlinLayer(256, 256, groupsize=groupsize)
+    if groupsize == -1:
+        groupsize = m
+    layer.k = m
+    layer.n = n
+    layer.groupsize = groupsize
+    layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device=DEV)
+    layer.s = torch.empty((m // groupsize, n), dtype=torch.half, device=DEV)
+    layer.pack(linear, s.t())
+    q = layer.B.reshape(m // 8, n)
+    s = layer.s
+    return ref, q, s
+
+
 # PyTorch implementation for matrix multiplication
-def quantize_gptq(a, b):  # the CPU on Ascend chips does not support transposed computation
-    ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype)
+def quantize_gptq(a, b, is_weight_transposed):  # the CPU on Ascend chips does not support transposed computation
+    if is_weight_transposed:
+        ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype)
+    else:
+        ans = torch.matmul(b.to(torch.float32), a.to(torch.float32)).to(b.dtype)
     return ans

@@ -379,7 +541,7 @@ def test(
     # Initialize tensors
     a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device)
     layer = nn.Linear(K, N)
-    b = 1e-3 * layer.weight.data.to(dtype).to(torch_device)
+    b = 1e0 * layer.weight.data.to(dtype).to(torch_device)
     c = torch.zeros([N, M], dtype=dtype).to(torch_device)
     is_weight_transposed = False
     sign_ed = False
@@ -393,10 +555,6 @@ def test(
         num_groups = 1
     else:
        num_groups = K // group_size
-    if is_weight_transposed:
-        ans = quantize_gptq(a.t(), b.t())
-    else:
-        ans = quantize_gptq(b, a)
     packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device)
     s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
     z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
@@ -409,36 +567,28 @@ def test(
     minq = -(2 ** (bits - 1))

     if torch_device == "cuda":
-        b_ref, s, z = get_scale_zero(
-            b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed
-        )  # unsigned quantization
-
-        packed_weights = pack(b_ref, s, z, minq, maxq)
+        b, packed_weights, s = gen_quant4(K, N, groupsize=group_size)
+        a = 1e0 * torch.randn([M, K], dtype=dtype).to(
+            torch_device
+        )  # unclear why, but a = a.t(), c = c.t() cannot be used here
+        c = torch.zeros([M, N], dtype=dtype).to(torch_device)
+        z = torch.zeros_like(s).to(torch_device)

     # if torch_device == "cpu":
     #     b_ref, s, z = get_scale_zero(
     #         b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed
     #     )  # unsigned quantization

     #     packed_weights = pack(b_ref, s, z, minq, maxq)
-    if is_weight_transposed:
-        a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
-            to_tensor(a.t(), lib),
-            to_tensor(b.t(), lib),
-            to_tensor(c.t(), lib),
-            to_tensor(s.t(), lib),
-            to_tensor(z.t(), lib),
-            to_tensor(packed_weights.t(), lib),
-        )
-    else:
-        a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
-            to_tensor(a, lib),
-            to_tensor(b, lib),
-            to_tensor(c, lib),
-            to_tensor(s, lib),
-            to_tensor(z, lib),
-            to_tensor(packed_weights, lib),
-        )
+    ans = quantize_gptq(a, b, is_weight_transposed)
+    a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
+        to_tensor(a, lib),
+        to_tensor(b, lib),
+        to_tensor(c, lib),
+        to_tensor(s, lib),
+        to_tensor(z, lib),
+        to_tensor(packed_weights, lib),
+    )

     descriptor = infiniopQuantizeGPTQDescriptor_t()
     check_error(
@@ -522,10 +672,7 @@ def lib_quantize_gptq():
     # Profiling workflow
     if PROFILE:
         # fmt: off
-        if(is_weight_transposed):
-            profile_operation("PyTorch", lambda: quantize_gptq(a.t(), b.t()), torch_device, NUM_PRERUN, NUM_ITERATIONS)
-        else:
-            profile_operation("PyTorch", lambda: quantize_gptq(b, a), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("PyTorch", lambda: quantize_gptq(a, b, is_weight_transposed), torch_device, NUM_PRERUN, NUM_ITERATIONS)
         profile_operation(" lib", lambda: lib_quantize_gptq(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
         # fmt: on
     check_error(lib.infiniopDestroyQuantizeGPTQDescriptor(descriptor))
