
Commit 9d01b43

Add Int4PlainInt32Tensor (#2845)
* Add Int4XPUTensorIntZP
* Add int4_xpu_tensor
* Update int4_xpu_tensor.py
* Fix typo
* Fix code format issue
* fix bug
* Fix code format
* Update int4_xpu_tensor.py
* change the pack format to plain
* fix typo
* Update quant_api.py
* merge main branch
* Update __init__.py
* Update __init__.py
* change Int4XPUTensorIntZP to Int4PlainInt32
* Update __init__.py
* Refine code
* Refine code
* Update __init__.py
* Update __init__.py
* Add more comments about the original weight dtype
* fix code format issue
* fix code format issue
* skip ut if no xpu
* Update test_int4_plain_int32_tensor.py
* Add assert for the original weight data type
1 parent 8776967 commit 9d01b43

File tree: 6 files changed, +287 −1 lines changed
test_int4_plain_int32_tensor.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int4WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)


def get_config(group_size):
    return Int4WeightOnlyConfig(
        group_size=group_size,
        packing_format="plain_int32",
        version=2,
    )


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
class Int4PlainInt32Tensor(TestCase):
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 12),
        ],
    )
    @parametrize("dtype", [torch.bfloat16, torch.half])
    @parametrize("group_size", [32, 64, 128])
    def test_linear(self, sizes, dtype, group_size):
        device = "xpu"
        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, get_config(group_size))
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

    @parametrize("dtype", [torch.bfloat16, torch.half])
    def test_module_path(self, dtype):
        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
        quantize_(linear, get_config(group_size=128))
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
            )


instantiate_parametrized_tests(Int4PlainInt32Tensor)


if __name__ == "__main__":
    run_tests()
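Note: the `compute_error(...) > 20` assertions above check the signal-to-quantized-noise ratio in decibels between the full-precision and quantized outputs. A minimal sketch of that metric, assuming the standard SQNR-in-dB definition behind torchao.quantization.utils.compute_error:

import torch

def sqnr_db(reference: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # 20 * log10(|signal| / |noise|); larger values mean the quantized
    # output tracks the full-precision output more closely.
    signal = torch.linalg.norm(reference.float())
    noise = torch.linalg.norm(reference.float() - result.float())
    return 20 * torch.log10(signal / noise)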

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,7 @@
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
+    Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
@@ -163,6 +164,7 @@
     "FbgemmConfig",
     # tensor subclasses
     "Int4Tensor",
+    "Int4PlainInt32Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
     "IntxOpaqueTensor",

torchao/quantization/quant_api.py

Lines changed: 7 additions & 1 deletion
@@ -74,6 +74,7 @@
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
+    Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
@@ -522,7 +523,6 @@ def quantize_(
     torch._C._log_api_usage_once("torchao.quantization.quantize_")

     filter_fn = _is_linear if filter_fn is None else filter_fn
-
     if isinstance(config, ModuleFqnToConfig):
         _replace_with_custom_fn_if_matches_filter_with_name(
             model,
@@ -1131,6 +1131,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
             block_size,
         )
         return new_weight
+    elif packing_format == PackingFormat.PLAIN_INT32:
+        new_weight = Int4PlainInt32Tensor.from_hp(
+            weight,
+            block_size,
+        )
+        return new_weight
     elif packing_format == PackingFormat.MARLIN_SPARSE:
         new_weight = Int4MarlinSparseTensor.from_hp(
             weight,
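For context, `_int4_weight_only_quantize_tensor` derives `block_size` from `Int4WeightOnlyConfig.group_size` before calling `from_hp`. A hypothetical sketch of that relationship for a 2D weight (the helper name here is illustrative, not torchao's actual code):

def block_size_for_2d_weight(weight_shape, group_size):
    # Illustrative helper: groupwise quantization runs along the input-channel
    # (K) dimension, one scale/zero_point per group_size elements of a row,
    # which is why the XPU linear op below asserts block_size[0] == 1.
    n, k = weight_shape
    assert k % group_size == 0, "K must be divisible by group_size"
    return (1, group_size)

print(block_size_for_2d_weight((256, 128), 32))  # -> (1, 32)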

torchao/quantization/quantize_/common/packing_format.py

Lines changed: 6 additions & 0 deletions
@@ -41,6 +41,12 @@ class PackingFormat(str, Enum):
     """
     UNPACKED_TO_INT8 = "unpacked_to_int8"

+    """
+    plain_int32 refers to the format used by int4 weight-only quantization, which is a
+    groupwise format: two int4 values are packed into one byte, and four such bytes (4 * (int4 * 2)) are stored in one int32.
+    """
+    PLAIN_INT32 = "plain_int32"
+
     """
     tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization
     """

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,9 @@
 from .int4.int4_opaque_tensor import (
     Int4OpaqueTensor,
 )
+from .int4.int4_plain_int32_tensor import (
+    Int4PlainInt32Tensor,
+)
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
 )
@@ -26,6 +29,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "Int4PlainInt32Tensor",
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)
from torchao.utils import (
    TorchAOBaseTensor,
)

__all__ = [
    "Int4PlainInt32Tensor",
]

aten = torch.ops.aten


class Int4PlainInt32Tensor(TorchAOBaseTensor):
    """
    int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)

    Tensor Attributes:
        qdata: (N, K/8) packed int4 weight, stored as int32 with eight int4 values
            per element (4 * (int4 * 2)); the original weight dtype can be half or bfloat16
        scale: (K/group_size, N), dtype is the same as the original Tensor dtype
        zero_point: (K/group_size, N), dtype is int8

    Non-Tensor Attributes:
        block_size: the block size for quantization, representing the granularity
        shape: shape of the original Tensor
    """

    tensor_data_names = ["qdata", "scale", "zero_point"]
    tensor_attribute_names = ["block_size", "shape"]

    def __new__(
        cls,
        qdata,
        scale,
        zero_point,
        block_size,
        shape,
    ):
        kwargs = {}
        kwargs["device"] = qdata.device
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, qdata, scale, zero_point, block_size, shape):
        self.qdata = qdata
        self.scale = scale
        self.zero_point = zero_point
        self.block_size = block_size

    def _quantization_type(self):
        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

    @classmethod
    def from_hp(
        cls,
        w: torch.Tensor,
        block_size: List[int],
    ):
        assert w.ndim == 2 and w.device.type == "xpu", (
            f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
        )
        assert len(block_size) == w.ndim
        assert w.dtype in [torch.float16, torch.bfloat16], (
            f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
        )
        original_shape = w.shape
        mapping_type = MappingType.ASYMMETRIC
        target_dtype = torch.int32
        quant_min = 0
        quant_max = 15
        eps = 1e-6
        scale_dtype = None
        zero_point_dtype = torch.int32
        scale, zero_point = choose_qparams_affine(
            w,
            mapping_type,
            block_size,
            target_dtype,
            quant_min,
            quant_max,
            eps,
            scale_dtype,
            zero_point_dtype,
        )
        int_data = quantize_affine(
            w,
            block_size,
            scale,
            zero_point,
            target_dtype,
            quant_min,
            quant_max,
        )
        assert int_data.dtype == torch.int32, (
            "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype"
        )
        # Pack pairs of int4 values into bytes (odd index in the high nibble),
        # then let the aten op group the bytes into the plain int32 layout.
        packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
            packed_weight.contiguous(), 8
        )
        scale = scale.reshape(int_data.shape[0], -1)
        zero_point = zero_point.reshape(int_data.shape[0], -1)
        return Int4PlainInt32Tensor(
            packed_weight,
            scale.transpose(0, 1).contiguous(),
            zero_point.transpose(0, 1).contiguous().to(torch.int8),
            block_size,
            original_shape,
        )
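As an aside, the groupwise asymmetric mapping that `from_hp` delegates to `choose_qparams_affine`/`quantize_affine` can be summarized by this simplified, self-contained sketch (illustrative only; the real primitives handle eps, dtypes, and degenerate groups differently):

import torch

def int4_groupwise_qparams_and_quantize(w: torch.Tensor, group_size: int):
    # Asymmetric int4: quant_min = 0, quant_max = 15, one (scale, zero_point)
    # per group of group_size elements along the K dimension.
    n, k = w.shape
    g = w.reshape(n, k // group_size, group_size).float()
    w_min, w_max = g.amin(-1, keepdim=True), g.amax(-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-6) / 15
    zero_point = (-w_min / scale).round().clamp(0, 15)
    q = (g / scale + zero_point).round().clamp(0, 15).to(torch.int32)
    return q.reshape(n, k), scale.squeeze(-1), zero_point.squeeze(-1)

w = torch.randn(256, 128, dtype=torch.bfloat16)
q, scale, zp = int4_groupwise_qparams_and_quantize(w, group_size=32)
# Dequantize: w_hat = (q - zero_point) * scale, groupwise.
w_hat = ((q.reshape(256, -1, 32) - zp.unsqueeze(-1)) * scale.unsqueeze(-1)).reshape(256, 128)
assert (w.float() - w_hat).abs().max() < 0.5  # coarse bound; error shrinks with smaller groups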
implements = Int4PlainInt32Tensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    assert input_tensor.device.type == "xpu", (
        f"For XPU device only but got: {input_tensor.device}"
    )
    assert isinstance(weight_tensor, Int4PlainInt32Tensor), (
        f"Expected weight_tensor to be Int4PlainInt32Tensor, got: {type(weight_tensor)}"
    )
    assert weight_tensor.block_size[0] == 1, (
        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
    )
    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
        f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
    )

    act_mat = input_tensor
    packed_weight = weight_tensor.qdata
    scale = weight_tensor.scale
    zero_point = weight_tensor.zero_point

    orig_act_size = act_mat.size()
    orig_dtype = act_mat.dtype

    # reshape to 2D
    act_mat = act_mat.reshape(-1, act_mat.shape[-1])

    # groupwise int4 quantization
    groupsize = weight_tensor.block_size[1]
    y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
        act_mat, packed_weight, groupsize, scale, zero_point
    )

    # remove out_feature padding
    assert weight_tensor.ndim == 2
    orig_out_features = weight_tensor.shape[-2]
    y = y[:, :orig_out_features]
    y = y.reshape(*orig_act_size[:-1], orig_out_features)

    if bias is not None:
        y += bias
    return y.to(orig_dtype)


Int4PlainInt32Tensor.__module__ = "torchao.quantization"

# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4PlainInt32Tensor])
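Because the class is registered via `add_safe_globals`, a state dict holding these weights can be reloaded with the safe `weights_only=True` loader; a small usage sketch (assuming an XPU device, mirroring the serialization check in the test file above):

import tempfile

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="xpu")
quantize_(linear, Int4WeightOnlyConfig(group_size=128, packing_format="plain_int32", version=2))

with tempfile.NamedTemporaryFile() as f:
    torch.save(linear.state_dict(), f)
    f.seek(0)
    # Succeeds under weights_only=True because Int4PlainInt32Tensor is an allowed global.
    state_dict = torch.load(f, weights_only=True)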
