 import torch
 import torch.nn as nn
+import numpy as np
 import math
 import ctypes
 from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
@@ -354,9 +355,170 @@ def pack(weight, scale, zero, minq, maxq):
     return qweight
 
 
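+# The helpers below are adapted from the Marlin reference implementation: the
+# kernel consumes 4-bit weights in a tile-interleaved layout, so _get_perms()
+# precomputes the index permutations applied to the packed weights (`perm`) and
+# to the quantization scales (`scale_perm`, `scale_perm_single`) at pack time.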
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm = perm.reshape((-1, 8))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm, scale_perm, scale_perm_single
+
+
+_perm, _scale_perm, _scale_perm_single = _get_perms()
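+# Sanity check (illustrative): `_perm` holds 32 * 32 = 1024 indices and is a
+# permutation of 0..1023, one index per 4-bit value in each 1024-element pack
+# unit used in pack() below:
+#   assert torch.equal(torch.sort(_perm)[0], torch.arange(1024))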
+
+
+class MarlinLayer(nn.Module):
+    """PyTorch compatible Marlin layer; 4-bit (symmetric grouped) linear layer without bias."""
+
+    def __init__(self, infeatures, outfeatures, groupsize=-1):
+        """Create an empty Marlin layer.
+        @infeatures: number of input features (must be divisible by 128)
+        @outfeatures: number of output features (must be divisible by 256)
+        @groupsize: quantization groupsize (must be -1 or 128)
+        """
+        super().__init__()
+        if groupsize not in [-1, 128]:
+            raise ValueError("Only groupsize -1 and 128 are supported.")
+        if infeatures % 128 != 0 or outfeatures % 256 != 0:
+            raise ValueError(
+                "`infeatures` must be divisible by 128 and `outfeatures` by 256."
+            )
+        if groupsize == -1:
+            groupsize = infeatures
+        if infeatures % groupsize != 0:
+            raise ValueError("`infeatures` must be divisible by `groupsize`.")
+        self.k = infeatures
+        self.n = outfeatures
+        self.groupsize = groupsize
+        self.register_buffer(
+            "B", torch.empty((self.k // 16, self.n * 16 // 8), dtype=torch.int)
+        )
+        self.register_buffer(
+            "s", torch.empty((self.k // groupsize, self.n), dtype=torch.half)
+        )
+
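+    # forward() is a thin wrapper: it allocates the output in the input's dtype
+    # and invokes marlin_matmul on 2D views, so batched activations are
+    # flattened to (tokens, infeatures) before the kernel call.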
+    def forward(self, A):
+        C = torch.empty(
+            A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device
+        )
+        marlin_matmul(
+            A.view((-1, A.shape[-1])),
+            self.B,
+            C.view((-1, C.shape[-1])),
+            self.s,
+        )
+        return C
+
+    def pack(self, linear, scales):
+        """Pack a fake-quantized linear layer into this actual Marlin representation.
+        @linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`)
+        @scales: corresponding quantization scales of shape `(outfeatures, groups)`
+        """
+        if linear.weight.dtype != torch.half:
+            raise ValueError("Only `torch.half` weights are supported.")
+        tile = 16
+        maxq = 2**4 - 1
+        s = scales.t()
+        w = linear.weight.data.t()
+        if self.groupsize != self.k:
+            w = w.reshape((-1, self.groupsize, self.n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((self.groupsize, -1))
+            s = s.reshape((1, -1))
+        w = torch.round(w / s).int()
+        w += (maxq + 1) // 2
+        w = torch.clamp(w, 0, maxq)
+        if self.groupsize != self.k:
+            w = w.reshape((self.groupsize, -1, self.n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((self.k, self.n)).contiguous()
+            s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+        else:
+            s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
+        s = s.reshape((-1, self.n)).contiguous()
+        w = w.reshape((self.k // tile, tile, self.n // tile, tile))
+        w = w.permute((0, 2, 1, 3))
+        w = w.reshape((self.k // tile, self.n * tile))
+        res = w
+        res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
+        q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
+        res = res.cpu().numpy().astype(np.uint32)
+        for i in range(8):
+            q |= res[:, i::8] << 4 * i
+        q = torch.from_numpy(q.astype(np.int32)).to(w.device)
+        self.B[:, :] = q.to(self.B.device)
+        self.s[:, :] = s.to(self.s.device)
+
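+# Usage sketch (illustrative, not part of the test flow): pack a fake-quantized
+# fp16 Linear into a MarlinLayer and run it; shapes follow the constraints above.
+#   linear = nn.Linear(4096, 4096, bias=False).half().cuda()    # fake-quantized weights
+#   scales = torch.rand((4096, 4096 // 128), dtype=torch.half)  # (outfeatures, groups)
+#   layer = MarlinLayer(4096, 4096, groupsize=128).cuda()
+#   layer.pack(linear, scales)
+#   y = layer(torch.randn(8, 4096, dtype=torch.half, device="cuda"))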
+
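+# gen_quant4 generates a random fp16 weight, fake-quantizes it to 4 bits, and
+# packs it through MarlinLayer; it returns ref (m, n) half dequantized weights,
+# q (m // 8, n) int32 packed weights (8 nibbles per int32), and s
+# (m // groupsize, n) half scales. Used below to set up the CUDA test path.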
+def gen_quant4(m, n, groupsize=-1):
+    DEV = torch.device("cuda:0")
+    tile = 16
+    maxq = 2**4 - 1
+    w = torch.randn((m, n), dtype=torch.half, device=DEV)
+    if groupsize != -1:
+        w = w.reshape((-1, groupsize, n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((groupsize, -1))
+    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
+    s *= 2 / maxq
+    w = torch.round(w / s).int()
+    w += (maxq + 1) // 2
+    w = torch.clamp(w, 0, maxq)
+    ref = (w - (maxq + 1) // 2).half() * s
+    if groupsize != -1:
+
+        def reshape(w):
+            w = w.reshape((groupsize, -1, n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((m, n)).contiguous()
+            return w
+
+        ref = reshape(ref)
+        w = reshape(w)
+    s = s.reshape((-1, n)).contiguous()
+    linear = nn.Linear(m, n)
+    linear.weight.data = ref.t()
+    # Workaround to test some special cases that are forbidden by the API
+    layer = MarlinLayer(256, 256, groupsize=groupsize)
+    if groupsize == -1:
+        groupsize = m
+    layer.k = m
+    layer.n = n
+    layer.groupsize = groupsize
+    layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device=DEV)
+    layer.s = torch.empty((m // groupsize, n), dtype=torch.half, device=DEV)
+    layer.pack(linear, s.t())
+    q = layer.B.reshape(m // 8, n)
+    s = layer.s
+    return ref, q, s
+
+
 # PyTorch implementation for matrix multiplication
-def quantize_gptq(a, b):  # the Ascend chip's CPU does not support transposed computation
-    ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype)
+def quantize_gptq(a, b, is_weight_transposed):  # the Ascend chip's CPU does not support transposed computation
+    if is_weight_transposed:
+        ans = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(b.dtype)
+    else:
+        ans = torch.matmul(b.to(torch.float32), a.to(torch.float32)).to(b.dtype)
     return ans
 
 
@@ -379,7 +541,7 @@ def test(
     # Initialize tensors
     a = 1e0 * torch.randn([K, M], dtype=dtype).to(torch_device)
     layer = nn.Linear(K, N)
-    b = 1e-3 * layer.weight.data.to(dtype).to(torch_device)
+    b = 1e0 * layer.weight.data.to(dtype).to(torch_device)
     c = torch.zeros([N, M], dtype=dtype).to(torch_device)
     is_weight_transposed = False
     sign_ed = False
@@ -393,10 +555,6 @@ def test(
         num_groups = 1
     else:
         num_groups = K // group_size
-    if is_weight_transposed:
-        ans = quantize_gptq(a.t(), b.t())
-    else:
-        ans = quantize_gptq(b, a)
     packed_weights = torch.zeros([N, K // 8], dtype=torch.int32).to(torch_device)
     s = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
     z = torch.zeros([N, num_groups], dtype=dtype).to(torch_device)
@@ -409,36 +567,28 @@ def test(
     minq = -(2 ** (bits - 1))
 
     if torch_device == "cuda":
-        b_ref, s, z = get_scale_zero(
-            b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed
-        )  # unsigned quantization
-
-        packed_weights = pack(b_ref, s, z, minq, maxq)
+        b, packed_weights, s = gen_quant4(K, N, groupsize=group_size)
+        a = 1e0 * torch.randn([M, K], dtype=dtype).to(
+            torch_device
+        )  # unclear why, but a = a.t(), c = c.t() cannot be used here
+        c = torch.zeros([M, N], dtype=dtype).to(torch_device)
+        z = torch.zeros_like(s).to(torch_device)
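+        # Note (assumption): .t() returns a non-contiguous view, which likely
+        # violates the downstream kernel's row-major layout expectations, so
+        # the CUDA path allocates fresh (M, K) / (M, N) tensors instead.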
 
     # if torch_device == "cpu":
     #     b_ref, s, z = get_scale_zero(
     #         b, a.t(), c, group_size, bits, sym, sign_ed=sign_ed
     #     )  # unsigned quantization
 
     #     packed_weights = pack(b_ref, s, z, minq, maxq)
-    if is_weight_transposed:
-        a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
-            to_tensor(a.t(), lib),
-            to_tensor(b.t(), lib),
-            to_tensor(c.t(), lib),
-            to_tensor(s.t(), lib),
-            to_tensor(z.t(), lib),
-            to_tensor(packed_weights.t(), lib),
-        )
-    else:
-        a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
-            to_tensor(a, lib),
-            to_tensor(b, lib),
-            to_tensor(c, lib),
-            to_tensor(s, lib),
-            to_tensor(z, lib),
-            to_tensor(packed_weights, lib),
-        )
+    ans = quantize_gptq(a, b, is_weight_transposed)
+    a_tensor, b_tensor, c_tensor, s_tensor, z_tensor, packed_weights_tensor = (
+        to_tensor(a, lib),
+        to_tensor(b, lib),
+        to_tensor(c, lib),
+        to_tensor(s, lib),
+        to_tensor(z, lib),
+        to_tensor(packed_weights, lib),
+    )
 
     descriptor = infiniopQuantizeGPTQDescriptor_t()
     check_error(
@@ -522,10 +672,7 @@ def lib_quantize_gptq():
     # Profiling workflow
     if PROFILE:
         # fmt: off
-        if (is_weight_transposed):
-            profile_operation("PyTorch", lambda: quantize_gptq(a.t(), b.t()), torch_device, NUM_PRERUN, NUM_ITERATIONS)
-        else:
-            profile_operation("PyTorch", lambda: quantize_gptq(b, a), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("PyTorch", lambda: quantize_gptq(a, b, is_weight_transposed), torch_device, NUM_PRERUN, NUM_ITERATIONS)
         profile_operation(" lib", lambda: lib_quantize_gptq(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
         # fmt: on
     check_error(lib.infiniopDestroyQuantizeGPTQDescriptor(descriptor))