Skip to content

Commit aadfded

Browse files
authored
Move cutlass_int4_packed_layout to prototype (#3277)
1 parent 6c78c4d commit aadfded

File tree

9 files changed

+280
-228
lines changed

9 files changed

+280
-228
lines changed

docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ Layouts and Tensor Subclasses
2626
MarlinQQQTensor
2727
MarlinQQQLayout
2828
Int4CPULayout
29-
CutlassInt4PackedLayout
3029
CutlassSemiSparseLayout
3130

3231
Quantization techniques
@@ -52,6 +51,7 @@ Prototype
5251
:nosignatures:
5352

5453
BlockSparseLayout
54+
CutlassInt4PackedLayout
5555

5656
..
5757
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring

test/integration/test_integration.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1946,5 +1946,32 @@ def test_benchmark_model_cpu(self):
19461946
assert self.run_benchmark_model("cpu") is not None
19471947

19481948

1949+
# TODO: Remove this test once the deprecated API has been removed
1950+
def test_cutlass_int4_packed_layout_deprecated():
1951+
import sys
1952+
import warnings
1953+
1954+
# We need to clear the cache to force re-importing and trigger the warning again.
1955+
modules_to_clear = [
1956+
"torchao.dtypes.uintx.cutlass_int4_packed_layout",
1957+
"torchao.dtypes",
1958+
]
1959+
for mod in modules_to_clear:
1960+
if mod in sys.modules:
1961+
del sys.modules[mod]
1962+
1963+
with warnings.catch_warnings(record=True) as w:
1964+
from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401
1965+
1966+
warnings.simplefilter("always") # Ensure all warnings are captured
1967+
assert any(
1968+
issubclass(warning.category, DeprecationWarning)
1969+
and "CutlassInt4PackedLayout" in str(warning.message)
1970+
for warning in w
1971+
), (
1972+
f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}"
1973+
)
1974+
1975+
19491976
if __name__ == "__main__":
19501977
unittest.main()

torchao/dtypes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
)
1515
from .nf4tensor import NF4Tensor, to_nf4
1616
from .uintx import (
17-
CutlassInt4PackedLayout,
1817
Int4CPULayout,
1918
Int4XPULayout,
2019
Int8DynamicActInt4WeightCPULayout,
@@ -29,6 +28,7 @@
2928
to_marlinqqq_quantized_intx,
3029
)
3130
from .uintx.block_sparse_layout import BlockSparseLayout
31+
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
3232
from .utils import (
3333
Layout,
3434
PlainLayout,

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,6 @@
2525
_linear_f16_bf16_act_floatx_weight_check,
2626
_linear_f16_bf16_act_floatx_weight_impl,
2727
)
28-
from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
29-
_linear_int4_act_int4_weight_cutlass_check,
30-
_linear_int4_act_int4_weight_cutlass_impl,
31-
_linear_int8_act_int4_weight_cutlass_check,
32-
_linear_int8_act_int4_weight_cutlass_impl,
33-
)
3428
from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
3529
_linear_int8_act_int4_weight_cpu_check,
3630
_linear_int8_act_int4_weight_cpu_impl,
@@ -94,6 +88,12 @@
9488
_linear_int8_act_int8_weight_block_sparse_check,
9589
_linear_int8_act_int8_weight_block_sparse_impl,
9690
)
91+
from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import (
92+
_linear_int4_act_int4_weight_cutlass_check,
93+
_linear_int4_act_int4_weight_cutlass_impl,
94+
_linear_int8_act_int4_weight_cutlass_check,
95+
_linear_int8_act_int4_weight_cutlass_impl,
96+
)
9797
from torchao.quantization.quant_primitives import (
9898
ZeroPointDomain,
9999
_dequantize_affine_no_zero_point,

torchao/dtypes/uintx/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from .cutlass_int4_packed_layout import (
2-
CutlassInt4PackedLayout,
3-
)
41
from .dyn_int8_act_int4_wei_cpu_layout import (
52
Int8DynamicActInt4WeightCPULayout,
63
)
@@ -43,7 +40,6 @@
4340
"MarlinQQQLayout",
4441
"MarlinQQQTensor",
4542
"to_marlinqqq_quantized_intx",
46-
"CutlassInt4PackedLayout",
4743
"PackedLinearInt8DynamicActivationIntxWeightLayout",
4844
"QDQLayout",
4945
"Int4XPULayout",

torchao/dtypes/uintx/cutlass_int4_packed_layout.py

Lines changed: 17 additions & 215 deletions
Original file line numberDiff line numberDiff line change
@@ -3,222 +3,24 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6-
from dataclasses import dataclass
7-
from typing import Optional
86

9-
import torch
10-
from torch.utils._python_dispatch import (
11-
return_and_correct_aliasing,
12-
)
7+
# Backward compatibility stub - imports from the new location
8+
import warnings
139

14-
from torchao.dtypes.affine_quantized_tensor import (
15-
AffineQuantizedTensor,
16-
register_layout,
17-
)
18-
from torchao.dtypes.uintx.plain_layout import (
19-
_aqt_is_int8,
10+
warnings.warn(
11+
"Importing from torchao.dtypes is deprecated. "
12+
"Please use 'from torchao.prototype.dtypes import CutlassInt4PackedLayout' instead. "
13+
"This import path will be removed in a future torchao release. "
14+
"Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ",
15+
DeprecationWarning,
16+
stacklevel=2,
2017
)
21-
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
22-
23-
aten = torch.ops.aten
24-
25-
26-
def _aqt_is_int4(aqt):
27-
"""Check if an AffineQuantizedTensor is int4 quantized Tensor"""
28-
# TODO: use torch.int4
29-
return (
30-
aqt.tensor_impl.dtype == torch.int8
31-
and aqt.quant_min == -8
32-
and aqt.quant_max == 7
33-
)
34-
35-
36-
def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
37-
return (
38-
isinstance(self, Int4PackedTensorImpl)
39-
and isinstance(src, Int4PackedTensorImpl)
40-
and self.shape == src.shape
41-
and self.int_data.shape == src.int_data.shape
42-
and self.scale.shape == src.scale.shape
43-
and type(self._layout) == type(src._layout)
44-
)
45-
46-
47-
@dataclass(frozen=True)
48-
class CutlassInt4PackedLayout(Layout):
49-
"""Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
50-
51-
pass
52-
53-
54-
@register_layout(CutlassInt4PackedLayout)
55-
class Int4PackedTensorImpl(AQTTensorImpl):
56-
"""
57-
TensorImpl storage class for int4 packed layout for affine quantized tensor.
58-
"""
59-
60-
@staticmethod
61-
def __new__(
62-
cls,
63-
int_data: torch.Tensor,
64-
scale: torch.Tensor,
65-
_layout: Layout,
66-
):
67-
kwargs = {}
68-
kwargs["device"] = int_data.device
69-
kwargs["layout"] = (
70-
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
71-
)
72-
kwargs["dtype"] = int_data.dtype
73-
kwargs["requires_grad"] = False
74-
shape = int_data.shape
75-
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
76-
77-
def __init__(
78-
self,
79-
int_data: torch.Tensor,
80-
scale: torch.Tensor,
81-
_layout: Layout,
82-
):
83-
self.int_data = int_data
84-
self.scale = scale
85-
self._layout = _layout
86-
87-
@classmethod
88-
def __torch_dispatch__(cls, func, types, args, kwargs):
89-
kwargs = {} if kwargs is None else kwargs
90-
91-
if func is aten.detach.default:
92-
return return_and_correct_aliasing(
93-
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
94-
)
95-
96-
elif func is aten.copy_.default:
97-
self = args[0]
98-
src = args[1]
99-
if _same_metadata(self, src):
100-
self_tensors = self.__tensor_flatten__()[0]
101-
for tensor_name in self_tensors:
102-
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
103-
return
104-
raise ValueError(
105-
f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
106-
)
107-
108-
raise NotImplementedError(
109-
f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
110-
)
111-
112-
def __tensor_flatten__(self):
113-
return ["int_data", "scale"], [self._layout]
11418

115-
@classmethod
116-
def __tensor_unflatten__(
117-
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
118-
):
119-
int_data = tensor_data_dict["int_data"]
120-
scale = tensor_data_dict["scale"]
121-
(_layout,) = tensor_attributes
122-
return cls(int_data, scale, _layout)
123-
124-
def get_plain(self):
125-
int_data = torch.stack(
126-
((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1
127-
).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],))
128-
return int_data, self.scale, None
129-
130-
@classmethod
131-
def from_plain(
132-
cls,
133-
int_data: torch.Tensor,
134-
scale: torch.Tensor,
135-
zero_point: Optional[torch.Tensor],
136-
_layout: Layout,
137-
):
138-
assert zero_point is None or torch.all(zero_point == 0)
139-
int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
140-
return cls(
141-
int_data_s4,
142-
scale,
143-
_layout,
144-
)
145-
146-
def get_layout(self) -> Layout:
147-
return self._layout
148-
149-
def _apply_fn_to_data(self, fn):
150-
self.int_data = fn(self.int_data)
151-
self.scale = fn(self.scale)
152-
return self
153-
154-
155-
def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
156-
return (
157-
isinstance(input_tensor, AffineQuantizedTensor)
158-
and isinstance(input_tensor._layout, PlainLayout)
159-
and _aqt_is_int8(input_tensor)
160-
and input_tensor.dtype in (torch.float16, torch.bfloat16)
161-
and len(input_tensor.shape) >= 2
162-
and input_tensor.tensor_impl.scale.dtype == torch.float32
163-
and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
164-
and isinstance(weight_tensor, AffineQuantizedTensor)
165-
and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
166-
and _aqt_is_int4(weight_tensor)
167-
and weight_tensor.dtype == input_tensor.dtype
168-
and len(weight_tensor.shape) == 2
169-
and weight_tensor.tensor_impl.scale.dtype == torch.float32
170-
and len(weight_tensor.tensor_impl.scale.shape) == 1
171-
and (bias is None or bias.dtype == input_tensor.dtype)
172-
and (bias is None or len(bias.shape) == 1)
173-
)
174-
175-
176-
def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
177-
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
178-
179-
weight = weight_tensor.tensor_impl.int_data
180-
weight_scale = weight_tensor.tensor_impl.scale
181-
input = input_tensor.tensor_impl.int_data
182-
input_scale = input_tensor.tensor_impl.scale
183-
out_dtype = input_tensor.dtype
184-
185-
out = rowwise_scaled_linear_cutlass_s8s4(
186-
input, input_scale, weight, weight_scale, bias, out_dtype
187-
)
188-
189-
return out
190-
191-
192-
def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
193-
return (
194-
isinstance(input_tensor, AffineQuantizedTensor)
195-
and isinstance(input_tensor._layout, CutlassInt4PackedLayout)
196-
and _aqt_is_int4(input_tensor)
197-
and input_tensor.dtype in (torch.float16, torch.bfloat16)
198-
and len(input_tensor.shape) >= 2
199-
and input_tensor.tensor_impl.scale.dtype == torch.float32
200-
and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
201-
and isinstance(weight_tensor, AffineQuantizedTensor)
202-
and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
203-
and _aqt_is_int4(weight_tensor)
204-
and weight_tensor.dtype == input_tensor.dtype
205-
and len(weight_tensor.shape) == 2
206-
and weight_tensor.tensor_impl.scale.dtype == torch.float32
207-
and len(weight_tensor.tensor_impl.scale.shape) == 1
208-
)
209-
210-
211-
def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
212-
from torchao.ops import rowwise_scaled_linear_cutlass_s4s4
213-
214-
weight = weight_tensor.tensor_impl.int_data
215-
weight_scale = weight_tensor.tensor_impl.scale
216-
input = input_tensor.tensor_impl.int_data
217-
input_scale = input_tensor.tensor_impl.scale
218-
out_dtype = input_tensor.dtype
219-
220-
out = rowwise_scaled_linear_cutlass_s4s4(
221-
input, input_scale, weight, weight_scale, bias, out_dtype
222-
)
223-
224-
return out
19+
from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import ( # noqa: F401
20+
CutlassInt4PackedLayout, # noqa: F401
21+
Int4PackedTensorImpl, # noqa: F401
22+
_linear_int4_act_int4_weight_cutlass_check, # noqa: F401
23+
_linear_int4_act_int4_weight_cutlass_impl, # noqa: F401
24+
_linear_int8_act_int4_weight_cutlass_check, # noqa: F401
25+
_linear_int8_act_int4_weight_cutlass_impl, # noqa: F401
26+
)

torchao/prototype/dtypes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from .uintx import BlockSparseLayout
7+
from .uintx import BlockSparseLayout, CutlassInt4PackedLayout
88

99
__all__ = [
1010
"BlockSparseLayout",
11+
"CutlassInt4PackedLayout",
1112
]

torchao/prototype/dtypes/uintx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from .block_sparse_layout import BlockSparseLayout
8+
from .cutlass_int4_packed_layout import CutlassInt4PackedLayout
89

910
__all__ = [
1011
"BlockSparseLayout",
12+
"CutlassInt4PackedLayout",
1113
]

0 commit comments

Comments
 (0)