diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py
index f6e575fd9f3..7439d2c8272 100644
--- a/neural_compressor/adaptor/ox_utils/weight_only.py
+++ b/neural_compressor/adaptor/ox_utils/weight_only.py
@@ -40,7 +40,7 @@ ONNXRT1161_VERSION = Version("1.16.1")



-def get_blob_size(group_size, has_zp):  # pragma: no cover
+def get_blob_size(group_size, num_bits, has_zp):  # pragma: no cover
     """Get blob_size.

     Args:
@@ -48,11 +48,11 @@ def get_blob_size(group_size, has_zp):  # pragma: no cover
         has_zp (bool): whether zero_point is None
     """
     if Version(ort.__version__) > ONNXRT1161_VERSION:
-        blob_size = group_size // 2
+        blob_size = group_size * num_bits // 8
     elif has_zp:
-        blob_size = group_size // 2 + 4 + 1
+        blob_size = group_size * num_bits // 8 + 4 + 1
     else:
-        blob_size = group_size // 2 + 4
+        blob_size = group_size * num_bits // 8 + 4
     return blob_size


@@ -86,7 +86,7 @@ def make_matmul_weight_only_node(
         matmul_weight_only_node: MatMulFpQ4 or MatMulNBits node
         new_inits: initializers of the new node
     """
-    blob_size = get_blob_size(group_size, zero_point is not None)
+    blob_size = get_blob_size(group_size, num_bits, zero_point is not None)
     packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8")
     q_weight_name = node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size))
     input_names = [node.input[0], q_weight_name]
@@ -97,8 +97,14 @@ def make_matmul_weight_only_node(
         op_type = "MatMulNBits"

         # pack quantized weight
-        q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
-        packed[:, :] = q_weight_pairs[:, :blob_size]
+        if num_bits == 4:
+            q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
+            packed[:, :] = q_weight_pairs[:, :blob_size]
+        elif num_bits == 8:
+            packed = q_weight
+        else:
+            logger.error("MatMulNBits does not have kernel support for num_bits = {}.".format(num_bits))
+
         packed = np.reshape(packed, (-1, k_blocks, blob_size))

         # build scale tensor
@@ -115,7 +121,9 @@ def make_matmul_weight_only_node(

         # build zero_point tensor
         if zero_point is not None:
-            if num_bits > 4:
+            if num_bits == 8:
+                packed_zp = zero_point.astype("uint8")
+            elif num_bits > 4:
                 packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8")
             else:
                 packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
@@ -128,6 +136,7 @@ def make_matmul_weight_only_node(
                 packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
                 packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)

+            packed_zp = np.reshape(packed_zp, (weight_shape[1], -1))
             zp_tensor = onnx.helper.make_tensor(
                 name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
             )
@@ -247,6 +256,170 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra
     return q_weight, scale, zero_point


+def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    data = np.reshape(data, (-1, group_size)).astype(np.float32)  # nb = data.shape[0], (nb, group_size)
+    maxq = 2**num_bits - 1
+    minq = 0
+    sum_x2 = np.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+    av_x = np.sqrt(sum_x2 / group_size)  # (nb, 1)
+    weights = np.add(av_x, np.abs(data))  # (nb, group_size)
+    rmin = np.min(data, axis=1, keepdims=True)  # (nb, 1)
+    rmax = np.max(data, axis=1, keepdims=True)  # (nb, 1)
+    sum_w = np.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+    sum_x = np.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+    iscale = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+    mask = rmin != rmax
+    iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+    scale = 1 / iscale
+    quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+    diff = scale * quant_data + rmin - data  # (nb, group_size)
+    best_mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+    nstep = 20
+    rdelta = 0.1
+    # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1
+    rrmin = -1
+    for is_ in range(nstep):
+        iscale_new = np.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+        factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+        mask = rmin != rmax
+        iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+        quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+        mul_weights_quant_data_new = weights * quant_data_new
+        sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+        sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+        D = np.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+        this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+        diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+        mad = np.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+        mad_1 = np.array(mad)
+        best_mad_1 = np.array(best_mad)
+        idx_to_replace = np.where(mad_1 < best_mad_1)[0]
+        quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+        best_mad[idx_to_replace] = mad[idx_to_replace]
+        scale[idx_to_replace] = this_scale[idx_to_replace]
+        rmin[idx_to_replace] = this_min[idx_to_replace]
+
+    zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+    scale = scale.astype(np.float64)
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point
+
+
+def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32):
+    """Quantize tensor per group based on k quant.
+
+    Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+
+    Returns:
+        output: quantized weight
+        scale: scale
+        zero_point: zero point
+    """
+    try:
+        import cupy as cp
+        import torch
+
+        if torch.cuda.is_available():
+            data = cp.asarray(data)
+            data = data.reshape((-1, group_size)).astype(cp.float32)  # nb = data.shape[0], (nb, group_size)
+            maxq = 2**num_bits - 1
+            minq = 0
+            sum_x2 = cp.sum(data**2, axis=1, keepdims=True)  # (nb, 1)
+            av_x = cp.sqrt(sum_x2 / group_size)  # (nb, 1)
+            weights = cp.add(av_x, cp.abs(data))  # (nb, group_size)
+            rmin = cp.min(data, axis=1, keepdims=True)  # (nb, 1)
+            rmax = cp.max(data, axis=1, keepdims=True)  # (nb, 1)
+            sum_w = cp.sum(weights, axis=1, keepdims=True)  # (nb, 1)
+            sum_x = cp.sum(weights * data, axis=1, keepdims=True)  # (nb, 1)
+            iscale = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+            mask = rmin != rmax
+            iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask])
+            scale = 1 / iscale
+            quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq)  # (nb, group_size)
+            diff = scale * quant_data + rmin - data  # (nb, group_size)
+            best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+            nstep = 20
+            rdelta = 0.1
+            rrmin = -1
+            for is_ in range(nstep):
+                iscale_new = cp.ones(rmax.shape, dtype=data.dtype)  # (nb, 1)
+                factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0]
+                mask = rmin != rmax
+                iscale_new[mask] = factor / (rmax[mask] - rmin[mask])
+                quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq)  # (nb, group_size)
+                mul_weights_quant_data_new = weights * quant_data_new
+                sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True)  # (nb, 1)
+                sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True)  # (nb, 1)
+                D = cp.subtract(sum_w * sum_l2, sum_l**2)  # (nb, 1)
+
+                this_scale = (sum_w * sum_xl - sum_x * sum_l) / D  # (nb, 1)
+                this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D  # (nb, 1)
+
+                diff = this_scale * quant_data_new + this_min - data  # (nb, group_size)
+                mad = cp.sum(weights * diff**2, axis=1, keepdims=True)  # (nb, 1)
+
+                mad_1 = cp.array(mad)
+                best_mad_1 = cp.array(best_mad)
+                idx_to_replace = cp.where(mad_1 < best_mad_1)[0]
+                quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :]
+                best_mad[idx_to_replace] = mad[idx_to_replace]
+                scale[idx_to_replace] = this_scale[idx_to_replace]
+                rmin[idx_to_replace] = this_min[idx_to_replace]
+
+            zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8")
+            scale = scale.astype(cp.float64)
+            q_weight = cp.empty_like(data, dtype=scale.dtype)
+            cp.divide(data, scale, out=q_weight)
+            cp.add(q_weight, zero_point, out=q_weight)
+            cp.round(q_weight, out=q_weight)
+            cp.clip(q_weight, minq, maxq, out=q_weight)
+
+            return q_weight.get(), scale.get(), zero_point.get()
+        else:
+            logger.warning(
+                "Tried to use k-quant quantization on CUDA, but CUDA is not available. "
+                "Falling back to k-quant quantization on CPU."
+            )
+            return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+    except ImportError:
+        logger.info(
+            "Now we are using k-quant quantization on CPU, which is time-consuming. "
+            "Please consider installing cupy to speed it up on CUDA. See https://cupy.dev/. "
+            "Please also install torch so that CUDA availability can be checked."
+        )
+        return quant_tensor_k_quant_cpu(data, num_bits, group_size)
+
+
 def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
     """Quant dequant tensor per group.

@@ -299,6 +472,7 @@ def rtn_quantize(
     ratios={},
     accuracy_level=0,
     providers=["CPUExecutionProvider"],
+    algorithm="k_quant",
 ):
     """Quant the model with round to nearst method.

@@ -362,7 +536,10 @@ def rtn_quantize(

             weight = pad_tensor(weight, group_size, k_blocks)

-            satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4
+            enable_MatMulNBits_8bits = True
+            satisfy_MatMulNBits_condition = (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or (
+                enable_MatMulNBits_8bits and num_bits == 8
+            )
             satisfy_MatMulFpQ4_condition = (
                 Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32
             )
@@ -372,9 +549,13 @@ def rtn_quantize(
             ):  # pragma: no cover
                 # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP
                 # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP
-                q_weight, scale, zp = quant_tensor(
-                    weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
-                )
+                if algorithm == "k_quant":
+                    q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size)
+                else:
+                    q_weight, scale, zp = quant_tensor(
+                        weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1)
+                    )
+
             q_matmul_node, new_inits = make_matmul_weight_only_node(
                 node=node,
                 weight_shape=org_w_shape,
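
Usage sketch (not part of the patch above), assuming the patch is applied and NumPy is available: the new CPU k-quant helper can be exercised on its own. For an input reshaped to (nb, group_size), it returns the quantized weight as (nb, group_size) together with a per-group scale and zero_point of shape (nb, 1), so a group-wise dequantization is enough to check the round-trip error.

    import numpy as np

    from neural_compressor.adaptor.ox_utils.weight_only import quant_tensor_k_quant_cpu

    group_size, num_bits = 32, 4
    weight = np.random.randn(128, 128).astype(np.float32)

    # k-quant refines the per-group scale/zero_point with the 20-step grid search shown in the patch.
    q_weight, scale, zero_point = quant_tensor_k_quant_cpu(weight, num_bits, group_size)

    # Dequantize group-wise and compare against the original weight.
    dequant = (q_weight - zero_point) * scale               # (nb, group_size)
    ref = weight.reshape(-1, group_size).astype(np.float32)
    print("max abs error:", np.abs(dequant - ref).max())
    print("quantized range:", q_weight.min(), q_weight.max())  # stays within [0, 2**num_bits - 1]

End to end, the same path is selected by calling rtn_quantize with its new algorithm argument (defaulting to "k_quant"): eligible MatMul weights are routed through quant_tensor_k_quant_cuda, which falls back to quant_tensor_k_quant_cpu when cupy/torch are missing or CUDA is unavailable.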