|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | + |
| 8 | +from typing import List |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | +from torchao.quantization.quant_primitives import ( |
| 13 | + MappingType, |
| 14 | + choose_qparams_affine, |
| 15 | + quantize_affine, |
| 16 | +) |
| 17 | +from torchao.utils import ( |
| 18 | + TorchAOBaseTensor, |
| 19 | +) |
| 20 | + |
| 21 | +__all__ = [ |
| 22 | + "Int4PlainInt32Tensor", |
| 23 | +] |
| 24 | + |
| 25 | +aten = torch.ops.aten |
| 26 | + |
| 27 | + |
| 28 | +class Int4PlainInt32Tensor(TorchAOBaseTensor): |
| 29 | + """ |
| 30 | + int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only) |
| 31 | +
|
| 32 | + Tensor Attributes: |
| 33 | + qdata: (N, K/8), packed int4 weight, the data type is int32 here with 4*(int4*2), the original data type can be half and bfloat16 |
| 34 | + scale: (K/group_size, N), dtype is the same as the original Tensor dtype |
| 35 | + zero_point: (K/group_size, N), dtype is int8 |
| 36 | +
|
| 37 | + Non-Tensor Attributes: |
| 38 | + block_size: the block size for quantization, representing the granularity. |
| 39 | + shape: shape of the original Tensor |
| 40 | +
|
| 41 | + """ |
| 42 | + |
| 43 | + tensor_data_names = ["qdata", "scale", "zero_point"] |
| 44 | + tensor_attribute_names = ["block_size", "shape"] |
| 45 | + |
| 46 | + def __new__( |
| 47 | + cls, |
| 48 | + qdata, |
| 49 | + scale, |
| 50 | + zero_point, |
| 51 | + block_size, |
| 52 | + shape, |
| 53 | + ): |
| 54 | + kwargs = {} |
| 55 | + kwargs["device"] = qdata.device |
| 56 | + kwargs["dtype"] = scale.dtype |
| 57 | + kwargs["requires_grad"] = False |
| 58 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| 59 | + |
| 60 | + def __init__(self, qdata, scale, zero_point, block_size, shape): |
| 61 | + self.qdata = qdata |
| 62 | + self.scale = scale |
| 63 | + self.zero_point = zero_point |
| 64 | + self.block_size = block_size |
| 65 | + |
| 66 | + def _quantization_type(self): |
| 67 | + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" |
| 68 | + |
| 69 | + @classmethod |
| 70 | + def from_hp( |
| 71 | + cls, |
| 72 | + w: torch.Tensor, |
| 73 | + block_size: List[int], |
| 74 | + ): |
| 75 | + assert w.ndim == 2 and w.device.type == "xpu", ( |
| 76 | + f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}" |
| 77 | + ) |
| 78 | + assert len(block_size) == w.ndim |
| 79 | + assert w.dtype in [torch.float16, torch.bfloat16], ( |
| 80 | + f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}" |
| 81 | + ) |
| 82 | + original_shape = w.shape |
| 83 | + mapping_type = MappingType.ASYMMETRIC |
| 84 | + target_dtype = torch.int32 |
| 85 | + quant_min = 0 |
| 86 | + quant_max = 15 |
| 87 | + eps = 1e-6 |
| 88 | + scale_dtype = None |
| 89 | + zero_point_dtype = torch.int32 |
| 90 | + scale, zero_point = choose_qparams_affine( |
| 91 | + w, |
| 92 | + mapping_type, |
| 93 | + block_size, |
| 94 | + target_dtype, |
| 95 | + quant_min, |
| 96 | + quant_max, |
| 97 | + eps, |
| 98 | + scale_dtype, |
| 99 | + zero_point_dtype, |
| 100 | + ) |
| 101 | + int_data = quantize_affine( |
| 102 | + w, |
| 103 | + block_size, |
| 104 | + scale, |
| 105 | + zero_point, |
| 106 | + target_dtype, |
| 107 | + quant_min, |
| 108 | + quant_max, |
| 109 | + ) |
| 110 | + assert int_data.dtype == torch.int32, ( |
| 111 | + "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype" |
| 112 | + ) |
| 113 | + packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) |
| 114 | + packed_weight = torch.ops.aten._convert_weight_to_int4pack( |
| 115 | + packed_weight.contiguous(), 8 |
| 116 | + ) |
| 117 | + scale = scale.reshape(int_data.shape[0], -1) |
| 118 | + zero_point = zero_point.reshape(int_data.shape[0], -1) |
| 119 | + return Int4PlainInt32Tensor( |
| 120 | + packed_weight, |
| 121 | + scale.transpose(0, 1).contiguous(), |
| 122 | + zero_point.transpose(0, 1).contiguous().to(torch.int8), |
| 123 | + block_size, |
| 124 | + original_shape, |
| 125 | + ) |
| 126 | + |
| 127 | + |
| 128 | +implements = Int4PlainInt32Tensor.implements |
| 129 | + |
| 130 | + |
| 131 | +@implements([torch.nn.functional.linear, aten.linear.default]) |
| 132 | +def _(func, types, args, kwargs): |
| 133 | + input_tensor, weight_tensor, bias = ( |
| 134 | + args[0], |
| 135 | + args[1], |
| 136 | + args[2] if len(args) > 2 else None, |
| 137 | + ) |
| 138 | + assert input_tensor.device.type == "xpu", ( |
| 139 | + f"For XPU device only but got: {input_tensor.device}" |
| 140 | + ) |
| 141 | + assert isinstance(weight_tensor, Int4PlainInt32Tensor), ( |
| 142 | + f"Expected weight_tensor to be Int4PlainInt32Tensor, got: {type(weight_tensor)}" |
| 143 | + ) |
| 144 | + assert weight_tensor.block_size[0] == 1, ( |
| 145 | + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" |
| 146 | + ) |
| 147 | + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( |
| 148 | + f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" |
| 149 | + ) |
| 150 | + |
| 151 | + act_mat = input_tensor |
| 152 | + packed_weight = weight_tensor.qdata |
| 153 | + scale = weight_tensor.scale |
| 154 | + zero_point = weight_tensor.zero_point |
| 155 | + |
| 156 | + orig_act_size = act_mat.size() |
| 157 | + orig_dtype = act_mat.dtype |
| 158 | + |
| 159 | + # reshape to 2D |
| 160 | + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) |
| 161 | + |
| 162 | + # groupwise int4 quantization |
| 163 | + groupsize = weight_tensor.block_size[1] |
| 164 | + y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( |
| 165 | + act_mat, packed_weight, groupsize, scale, zero_point |
| 166 | + ) |
| 167 | + |
| 168 | + # remove out_feature padding |
| 169 | + assert weight_tensor.ndim == 2 |
| 170 | + orig_out_features = weight_tensor.shape[-2] |
| 171 | + y = y[:, :orig_out_features] |
| 172 | + y = y.reshape(*orig_act_size[:-1], orig_out_features) |
| 173 | + |
| 174 | + if bias is not None: |
| 175 | + y += bias |
| 176 | + return y.to(orig_dtype) |
| 177 | + |
| 178 | + |
| 179 | +Int4PlainInt32Tensor.__module__ = "torchao.quantization" |
| 180 | + |
| 181 | +# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True` |
| 182 | +torch.serialization.add_safe_globals([Int4PlainInt32Tensor]) |
0 commit comments