Merge branch 'main' into hengguo/update0114
n1ck-guo committed Jan 16, 2025
2 parents 3075914 + 937d019 commit 2e2b1ac
Showing 7 changed files with 262 additions and 85 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -6,7 +6,10 @@ AutoRound

[![python](https://img.shields.io/badge/python-3.9%2B-blue)](https://github.com/intel/auto-round)
[![version](https://img.shields.io/badge/release-0.4.4-green)](https://github.com/intel/auto-round)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/intel/auto-round/blob/main/LICENSE)
[![license](https://img.shields.io/badge/license-Apache%202-9C27B0)](https://github.com/intel/auto-round/blob/main/LICENSE)
<a href="https://huggingface.co/OPEA">
<img alt="Model Checkpoints" src="https://img.shields.io/badge/%F0%9F%A4%97%20HF-Models-F57C00">
</a>
---
<div align="left">

1 change: 1 addition & 0 deletions auto_round/data_type/mxfp.py
@@ -25,6 +25,7 @@
"mx_fp8": (4, 5, 8, 448.0, 0.015625),
"mx_fp8e4m3": (4, 5, 8, 448.0, 0.015625),
"mx_fp6e3m2": (3, 4, 4, 28.0, 0.25),
"mx_fp6": (2, 5, 2, 7.5, 1.0),
"mx_fp6e2m3": (2, 5, 2, 7.5, 1.0),
"mx_fp4": (2, 3, 2, 6.0, 1.0),
"mx_fp4e2m1": (2, 3, 2, 6.0, 1.0),
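For readers scanning the table above: the five-tuple appears to encode (ebits, mbits, emax, max_norm, min_norm) — a layout inferred from the values, not stated in the diff. A small sketch that reproduces the max/min normals for the FP6/FP4 entries shown here:

```python
# Sketch under the assumed tuple layout (ebits, mbits, emax, max_norm, min_norm).
def fp_max_norm(emax: int, mantissa_bits: int) -> float:
    # Largest normal of a format whose top exponent/mantissa codes are all usable:
    # 2**emax * (2 - 2**-mantissa_bits).
    return 2.0 ** emax * (2.0 - 2.0 ** -mantissa_bits)

def fp_min_norm(ebits: int) -> float:
    # Smallest positive normal: 2**(1 - bias) with bias = 2**(ebits - 1) - 1.
    return 2.0 ** (1 - (2 ** (ebits - 1) - 1))

assert fp_max_norm(emax=2, mantissa_bits=3) == 7.5    # mx_fp6 / mx_fp6e2m3
assert fp_min_norm(ebits=2) == 1.0                    # mx_fp6 / mx_fp6e2m3 / mx_fp4
assert fp_max_norm(emax=4, mantissa_bits=2) == 28.0   # mx_fp6e3m2
assert fp_min_norm(ebits=3) == 0.25                   # mx_fp6e3m2
assert fp_max_norm(emax=2, mantissa_bits=1) == 6.0    # mx_fp4 / mx_fp4e2m1
# fp8e4m3 is the exception: its top mantissa code is NaN, so max_norm is 448, not 480.
```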
11 changes: 6 additions & 5 deletions auto_round/export/export_to_autogptq/export.py
@@ -36,6 +36,7 @@
# SOFTWARE.
import torch

import auto_round.export.export_to_autogptq.qlinear_triton
from auto_round.utils import check_to_quantized, get_block_names, \
get_module, logger, set_module
import copy
@@ -102,7 +103,11 @@ def pack_layer(name, model, layer_config, backend, pbar):
# so far can only pack layer on CPU
qlayer.to("cpu")
##force to float32 to be compatible with torch 2.0
layer, scale, zero = layer.to("cpu"), scale.to("cpu"), zero.to("cpu").to(torch.float32)
if sym and isinstance(new_layer, auto_round.export.export_to_autogptq.qlinear_triton.QuantLinear):
layer, scale = layer.to("cpu"), scale.to("cpu")
zero = 2 ** (bits - 1)
else:
layer, scale, zero = layer.to("cpu"), scale.to("cpu"), zero.to("cpu").to(torch.float32)
sig = inspect.signature(qlayer.pack)
param_count = len(sig.parameters)
if param_count == 2:
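The sym branch added above drops the per-group zero-point tensor and uses the constant 2 ** (bits - 1) instead, which is exactly the zero point symmetric quantization produces for every group. A minimal, self-contained sketch of that idea (illustrative shapes and scale choice, not the repository's API):

```python
import torch

def sym_quant_dequant(w: torch.Tensor, bits: int = 4, group_size: int = 128):
    """Symmetric per-group quantization with the fixed zero point 2 ** (bits - 1)."""
    zp = 2 ** (bits - 1)                                      # 8 for 4-bit
    qmax = 2 ** bits - 1                                      # 15
    w_g = w.reshape(-1, group_size)
    scale = (w_g.abs().amax(dim=1, keepdim=True) / zp).clamp(min=1e-9)
    q = torch.clamp(torch.round(w_g / scale) + zp, 0, qmax)   # unsigned codes, as stored
    deq = (q - zp) * scale                                    # same constant on the way back
    return q.reshape_as(w), deq.reshape_as(w)

w = torch.randn(256, 128)
q, deq = sym_quant_dequant(w)
print((w - deq).abs().max())  # small; bounded by scale except at the clamped positive edge
```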
@@ -138,7 +143,6 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
logger.error(f"auto-gptq format may not support loading this quantized model")
quantization_config['block_name_to_quantize'] = common_prefix
quantization_config.pop("to_quant_block_names", None)


all_to_quantized = True
modules_in_block_to_quantize = []
@@ -215,6 +219,3 @@ def save(model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", saf
if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
json.dump(model.config.quantization_config, f, indent=2)



135 changes: 115 additions & 20 deletions auto_round/export/export_to_autogptq/qlinear_triton.py
@@ -18,6 +18,29 @@
import torch
import torch.nn as nn
import transformers
import numba


##TODO different bits
# @numba.jit(nopython=True, parallel=True)
# def pack_array_with_numba_b4_c32(
# raw_array: np.ndarray, packed_array: np.ndarray
# ) -> np.ndarray:
# """Pack the array with numba when bits=4 and compress_bits=32."""
# bits = 4
# n_pack = 32 // bits
#
# for row in range(packed_array.shape[0]):
# packed_array[row] = ((((raw_array[row * n_pack + 7]) << 28)
# | ((raw_array[row * n_pack + 6]) << 24)
# | ((raw_array[row * n_pack + 5]) << 20)
# | ((raw_array[row * n_pack + 4]) << 16)
# | ((raw_array[row * n_pack + 3]) << 12)
# | (raw_array[row * n_pack + 2]) << 8)
# | ((raw_array[row * n_pack + 1]) << 4)
# | ((raw_array[row * n_pack]) << 0))
#
# return packed_array
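The commented-out helper above sketches a numba path for packing eight 4-bit values into one uint32 (note the TODO for other bit widths and the stray parenthesis on the `<< 8` term). A working version of the same idea might look like the sketch below — an assumption about where the TODO is heading, not code from the commit:

```python
import numba
import numpy as np

@numba.njit(parallel=True)
def pack_b4_c32(raw: np.ndarray, packed: np.ndarray) -> np.ndarray:
    """Pack each group of eight 4-bit rows of `raw` into one uint32 row of `packed`."""
    n_pack = 32 // 4
    for row in numba.prange(packed.shape[0]):
        base = row * n_pack
        for col in range(packed.shape[1]):
            v = np.uint32(0)
            for k in range(n_pack):
                v |= np.uint32(raw[base + k, col]) << np.uint32(4 * k)
            packed[row, col] = v
    return packed

# Tiny self-check with random 4-bit values.
raw = np.random.randint(0, 16, size=(16, 3)).astype(np.uint32)
packed = np.zeros((raw.shape[0] // 8, raw.shape[1]), dtype=np.uint32)
pack_b4_c32(raw, packed)
```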


class TritonModuleMixin:
@@ -76,7 +99,7 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
def post_init(self):
pass

def pack(self, linear, scales, zeros, g_idx=None):
def pack_cpu(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
@@ -86,8 +109,11 @@ def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
if isinstance(zeros, torch.Tensor):
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
else:
scale_zeros = scales * zeros
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
@@ -107,34 +133,103 @@ def pack(self, linear, scales, zeros, g_idx=None):
row = 0
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1

qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

if isinstance(zeros, torch.Tensor):
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
col += 1

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
else:
zeros -= 1
shape = scales.shape
value = 0
for j in range(0, (32 // self.bits)):
value |= zeros << (self.bits * j)
qzeros = np.ones((shape[0], shape[1] // 32 * self.bits), dtype=np.uint32) * value
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)

def pack(self, linear, scales, zeros, g_idx):
if torch.cuda.is_available():
return self.pack_cuda(linear, scales, zeros, g_idx)
else:
return self.pack_cpu(linear, scales, zeros, g_idx)

def pack_cuda(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
scales_t = scales.t().contiguous()
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
if linear.bias is not None:
self.bias = linear.bias.clone().half()
self.scales = scales_t.clone().half()

repeat_scales = scales.to("cuda:0").repeat_interleave(self.group_size, 1)
if isinstance(zeros, torch.Tensor):
repeat_zeros = zeros.to("cuda:0").repeat_interleave(self.group_size, 1)
else:
repeat_zeros = zeros

intweight = torch.round(W.to("cuda:0") / repeat_scales + repeat_zeros).to(torch.int).t().contiguous().to("cpu")
intweight = intweight.numpy().astype(np.uint32)
del repeat_scales

i = 0
row = 0
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
# pack_array_with_numba_b4_c32(intweight, qweight)
while row < qweight.shape[0]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1

qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
if isinstance(zeros, torch.Tensor):
zeros = zeros.t().contiguous()
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
else:
zeros -= 1
shape = scales_t.shape
value = 0
for j in range(0, (32 // self.bits)):
value |= zeros << (self.bits * j)
qzeros = np.ones((shape[0], shape[1] // 32 * self.bits), dtype=np.uint32) * value
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)


__all__ = ["QuantLinear"]
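One detail worth spelling out: in the scalar-zero branches above, qzeros collapses to a single repeated 32-bit constant — the zero point is decremented by one (the storage convention this kernel uses) and replicated into every 4-bit field. A standalone arithmetic check for the symmetric 4-bit case:

```python
bits = 4
zp = 2 ** (bits - 1)          # symmetric zero point from the export path: 8
stored = zp - 1               # stored as zero_point - 1: 7

# Scalar branch: replicate the stored value into every 4-bit field of an int32.
value = 0
for j in range(32 // bits):
    value |= stored << (bits * j)
print(hex(value))             # 0x77777777

# The tensor branch yields the same word when every group shares that zero point.
packed = 0
for j, z in enumerate([stored] * (32 // bits)):
    packed |= z << (bits * j)
assert packed == value
```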
5 changes: 2 additions & 3 deletions auto_round/export/export_to_awq/export.py
@@ -54,6 +54,8 @@ def pack_layer(name, model, layer_config, backend, pbar):
bits = config["bits"]
group_size = config["group_size"]
linear_layer = get_module(model, name)
if config["sym"]:
zp = 2 ** (config["bits"] - 1)
q_linear = WQLinear_GEMM.from_linear(
linear=linear_layer,
w_bit=bits,
@@ -62,10 +64,7 @@ def pack_layer(name, model, layer_config, backend, pbar):
scales=scale,
zeros=zp,
)
linear_layer.cpu()
q_linear.to("cpu")
set_module(model, name, q_linear)
clear_memory()
pbar.update(1)


79 changes: 43 additions & 36 deletions auto_round/export/export_to_awq/utils.py
@@ -221,23 +221,29 @@ def from_linear(

pack_num = 32 // awq_linear.w_bit

intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[idx // group_size])
/ awq_linear.scales[idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)

best_device = get_best_device()
if torch.cuda.is_available():

# Avoid: The operator 'aten::__lshift__.Scalar' is not currently implemented for the MPS device
if "mps" in best_device:
intweight = intweight.to("cpu")
repeat_scales = scales.to("cuda:0").t().repeat_interleave(group_size, 1)
if isinstance(zeros, torch.Tensor):
repeat_zeros = zeros.to("cuda:0").t().repeat_interleave(group_size, 1)
else:
repeat_zeros = zeros
intweight = torch.round(linear.weight.to("cuda:0") / repeat_scales + repeat_zeros).to(
torch.int).t().contiguous().to("cpu")
intweight = intweight.to(dtype=torch.int32)
del repeat_scales
else:
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[idx // group_size])
/ awq_linear.scales[idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)

qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
@@ -246,34 +252,35 @@ def from_linear(
)

for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight

zeros = zeros.to(dtype=torch.int32, device=best_device)

if "mps" in best_device:
zeros = zeros.to("cpu")

qzeros = torch.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=zeros.device,
)
if isinstance(zeros, torch.Tensor):
qzeros = torch.zeros(
(scales.shape[0], scales.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=zeros.device,
)
zeros = zeros.to(dtype=torch.int32, device="cpu")

for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
for col in range(zeros.shape[1] // pack_num):
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
else:
value = 0
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
value |= zeros << (i * awq_linear.w_bit)
qzeros = torch.ones(
(scales.shape[0], scales.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=qweight.device,
) * value

awq_linear.qzeros = qzeros

return awq_linear
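The CUDA branch above replaces the per-column Python loop with one vectorized round over scales and zero points expanded by repeat_interleave; the old loop added scales * zeros before dividing, which is the same computation up to floating-point rounding. A small CPU-only sketch of that equivalence (shapes and layouts here are illustrative assumptions):

```python
import torch

out_f, in_f, group_size = 8, 16, 4
n_groups = in_f // group_size

w = torch.randn(out_f, in_f)
scales = torch.rand(out_f, n_groups) + 0.05       # assumed (out_features, n_groups) layout
zeros = torch.randint(0, 16, (out_f, n_groups))   # per-(row, group) zero points

# Vectorized form (what the CUDA branch does, run here on CPU).
rep_s = scales.repeat_interleave(group_size, dim=1)
rep_z = zeros.repeat_interleave(group_size, dim=1)
int_vec = torch.round(w / rep_s + rep_z).to(torch.int)

# Column-by-column form (the shape of the old loop).
cols = []
for idx in range(in_f):
    g = idx // group_size
    cols.append(torch.round(w[:, idx] / scales[:, g] + zeros[:, g]).to(torch.int)[:, None])
int_loop = torch.cat(cols, dim=1)

assert torch.equal(int_vec, int_loop)
```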
