3 | 3 | #
4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | | -from dataclasses import dataclass |
7 | | -from typing import Optional |
8 | 6 |
9 | | -import torch |
10 | | -from torch.utils._python_dispatch import ( |
11 | | - return_and_correct_aliasing, |
12 | | -) |
| 7 | +# Backward compatibility stub - imports from the new location |
| 8 | +import warnings |
13 | 9 |
14 | | -from torchao.dtypes.affine_quantized_tensor import ( |
15 | | - AffineQuantizedTensor, |
16 | | - register_layout, |
17 | | -) |
18 | | -from torchao.dtypes.uintx.plain_layout import ( |
19 | | - _aqt_is_int8, |
| 10 | +warnings.warn( |
| 11 | + "Importing from torchao.dtypes is deprecated. " |
| 12 | + "Please use 'from torchao.prototype.dtypes import CutlassInt4PackedLayout' instead. " |
| 13 | + "This import path will be removed in a future torchao release. " |
| 14 | + "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ", |
| 15 | + DeprecationWarning, |
| 16 | + stacklevel=2, |
20 | 17 | ) |
21 | | -from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout |
22 | | - |
23 | | -aten = torch.ops.aten |
24 | | - |
25 | | - |
26 | | -def _aqt_is_int4(aqt): |
27 | | - """Check if an AffineQuantizedTensor is int4 quantized Tensor""" |
28 | | - # TODO: use torch.int4 |
29 | | - return ( |
30 | | - aqt.tensor_impl.dtype == torch.int8 |
31 | | - and aqt.quant_min == -8 |
32 | | - and aqt.quant_max == 7 |
33 | | - ) |
34 | | - |
35 | | - |
36 | | -def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool: |
37 | | - return ( |
38 | | - isinstance(self, Int4PackedTensorImpl) |
39 | | - and isinstance(src, Int4PackedTensorImpl) |
40 | | - and self.shape == src.shape |
41 | | - and self.int_data.shape == src.int_data.shape |
42 | | - and self.scale.shape == src.scale.shape |
43 | | - and type(self._layout) == type(src._layout) |
44 | | - ) |
45 | | - |
46 | | - |
47 | | -@dataclass(frozen=True) |
48 | | -class CutlassInt4PackedLayout(Layout): |
49 | | - """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel.""" |
50 | | - |
51 | | - pass |
52 | | - |
53 | | - |
54 | | -@register_layout(CutlassInt4PackedLayout) |
55 | | -class Int4PackedTensorImpl(AQTTensorImpl): |
56 | | - """ |
57 | | - TensorImpl storage class for int4 packed layout for affine quantized tensor. |
58 | | - """ |
59 | | - |
60 | | - @staticmethod |
61 | | - def __new__( |
62 | | - cls, |
63 | | - int_data: torch.Tensor, |
64 | | - scale: torch.Tensor, |
65 | | - _layout: Layout, |
66 | | - ): |
67 | | - kwargs = {} |
68 | | - kwargs["device"] = int_data.device |
69 | | - kwargs["layout"] = ( |
70 | | - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout |
71 | | - ) |
72 | | - kwargs["dtype"] = int_data.dtype |
73 | | - kwargs["requires_grad"] = False |
74 | | - shape = int_data.shape |
75 | | - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
76 | | - |
77 | | - def __init__( |
78 | | - self, |
79 | | - int_data: torch.Tensor, |
80 | | - scale: torch.Tensor, |
81 | | - _layout: Layout, |
82 | | - ): |
83 | | - self.int_data = int_data |
84 | | - self.scale = scale |
85 | | - self._layout = _layout |
86 | | - |
87 | | - @classmethod |
88 | | - def __torch_dispatch__(cls, func, types, args, kwargs): |
89 | | - kwargs = {} if kwargs is None else kwargs |
90 | | - |
91 | | - if func is aten.detach.default: |
92 | | - return return_and_correct_aliasing( |
93 | | - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) |
94 | | - ) |
95 | | - |
96 | | - elif func is aten.copy_.default: |
97 | | - self = args[0] |
98 | | - src = args[1] |
99 | | - if _same_metadata(self, src): |
100 | | - self_tensors = self.__tensor_flatten__()[0] |
101 | | - for tensor_name in self_tensors: |
102 | | - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) |
103 | | - return |
104 | | - raise ValueError( |
105 | | - f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" |
106 | | - ) |
107 | | - |
108 | | - raise NotImplementedError( |
109 | | - f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported" |
110 | | - ) |
111 | | - |
112 | | - def __tensor_flatten__(self): |
113 | | - return ["int_data", "scale"], [self._layout] |
114 | 18 |
115 | | - @classmethod |
116 | | - def __tensor_unflatten__( |
117 | | - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride |
118 | | - ): |
119 | | - int_data = tensor_data_dict["int_data"] |
120 | | - scale = tensor_data_dict["scale"] |
121 | | - (_layout,) = tensor_attributes |
122 | | - return cls(int_data, scale, _layout) |
123 | | - |
124 | | - def get_plain(self): |
125 | | - int_data = torch.stack( |
126 | | - ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1 |
127 | | - ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],)) |
128 | | - return int_data, self.scale, None |
129 | | - |
130 | | - @classmethod |
131 | | - def from_plain( |
132 | | - cls, |
133 | | - int_data: torch.Tensor, |
134 | | - scale: torch.Tensor, |
135 | | - zero_point: Optional[torch.Tensor], |
136 | | - _layout: Layout, |
137 | | - ): |
138 | | - assert zero_point is None or torch.all(zero_point == 0) |
139 | | - int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF) |
140 | | - return cls( |
141 | | - int_data_s4, |
142 | | - scale, |
143 | | - _layout, |
144 | | - ) |
145 | | - |
146 | | - def get_layout(self) -> Layout: |
147 | | - return self._layout |
148 | | - |
149 | | - def _apply_fn_to_data(self, fn): |
150 | | - self.int_data = fn(self.int_data) |
151 | | - self.scale = fn(self.scale) |
152 | | - return self |
153 | | - |
154 | | - |
155 | | -def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): |
156 | | - return ( |
157 | | - isinstance(input_tensor, AffineQuantizedTensor) |
158 | | - and isinstance(input_tensor._layout, PlainLayout) |
159 | | - and _aqt_is_int8(input_tensor) |
160 | | - and input_tensor.dtype in (torch.float16, torch.bfloat16) |
161 | | - and len(input_tensor.shape) >= 2 |
162 | | - and input_tensor.tensor_impl.scale.dtype == torch.float32 |
163 | | - and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 |
164 | | - and isinstance(weight_tensor, AffineQuantizedTensor) |
165 | | - and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) |
166 | | - and _aqt_is_int4(weight_tensor) |
167 | | - and weight_tensor.dtype == input_tensor.dtype |
168 | | - and len(weight_tensor.shape) == 2 |
169 | | - and weight_tensor.tensor_impl.scale.dtype == torch.float32 |
170 | | - and len(weight_tensor.tensor_impl.scale.shape) == 1 |
171 | | - and (bias is None or bias.dtype == input_tensor.dtype) |
172 | | - and (bias is None or len(bias.shape) == 1) |
173 | | - ) |
174 | | - |
175 | | - |
176 | | -def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): |
177 | | - from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 |
178 | | - |
179 | | - weight = weight_tensor.tensor_impl.int_data |
180 | | - weight_scale = weight_tensor.tensor_impl.scale |
181 | | - input = input_tensor.tensor_impl.int_data |
182 | | - input_scale = input_tensor.tensor_impl.scale |
183 | | - out_dtype = input_tensor.dtype |
184 | | - |
185 | | - out = rowwise_scaled_linear_cutlass_s8s4( |
186 | | - input, input_scale, weight, weight_scale, bias, out_dtype |
187 | | - ) |
188 | | - |
189 | | - return out |
190 | | - |
191 | | - |
192 | | -def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): |
193 | | - return ( |
194 | | - isinstance(input_tensor, AffineQuantizedTensor) |
195 | | - and isinstance(input_tensor._layout, CutlassInt4PackedLayout) |
196 | | - and _aqt_is_int4(input_tensor) |
197 | | - and input_tensor.dtype in (torch.float16, torch.bfloat16) |
198 | | - and len(input_tensor.shape) >= 2 |
199 | | - and input_tensor.tensor_impl.scale.dtype == torch.float32 |
200 | | - and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 |
201 | | - and isinstance(weight_tensor, AffineQuantizedTensor) |
202 | | - and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) |
203 | | - and _aqt_is_int4(weight_tensor) |
204 | | - and weight_tensor.dtype == input_tensor.dtype |
205 | | - and len(weight_tensor.shape) == 2 |
206 | | - and weight_tensor.tensor_impl.scale.dtype == torch.float32 |
207 | | - and len(weight_tensor.tensor_impl.scale.shape) == 1 |
208 | | - ) |
209 | | - |
210 | | - |
211 | | -def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): |
212 | | - from torchao.ops import rowwise_scaled_linear_cutlass_s4s4 |
213 | | - |
214 | | - weight = weight_tensor.tensor_impl.int_data |
215 | | - weight_scale = weight_tensor.tensor_impl.scale |
216 | | - input = input_tensor.tensor_impl.int_data |
217 | | - input_scale = input_tensor.tensor_impl.scale |
218 | | - out_dtype = input_tensor.dtype |
219 | | - |
220 | | - out = rowwise_scaled_linear_cutlass_s4s4( |
221 | | - input, input_scale, weight, weight_scale, bias, out_dtype |
222 | | - ) |
223 | | - |
224 | | - return out |
| 19 | +from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import ( # noqa: F401 |
| 20 | + CutlassInt4PackedLayout, # noqa: F401 |
| 21 | + Int4PackedTensorImpl, # noqa: F401 |
| 22 | + _linear_int4_act_int4_weight_cutlass_check, # noqa: F401 |
| 23 | + _linear_int4_act_int4_weight_cutlass_impl, # noqa: F401 |
| 24 | + _linear_int8_act_int4_weight_cutlass_check, # noqa: F401 |
| 25 | + _linear_int8_act_int4_weight_cutlass_impl, # noqa: F401 |
| 26 | +) |
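For downstream code, the migration is a one-line import change. A minimal sketch, assuming the old public path was "from torchao.dtypes import CutlassInt4PackedLayout" (implied by the warning text, not shown in this diff):

import warnings

# DeprecationWarning is filtered out by default outside __main__; enable it so the
# stub's message is visible while migrating.
warnings.simplefilter("default", DeprecationWarning)

# Old path (deprecated): still resolves through this stub and triggers the warning above.
# from torchao.dtypes import CutlassInt4PackedLayout

# New path, as recommended by the deprecation message:
from torchao.prototype.dtypes import CutlassInt4PackedLayout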
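For context on the moved implementation: from_plain and get_plain (removed above) pack two signed int4 values into each int8 byte, even indices in the low nibble and odd indices in the high nibble, and recover them with sign-extending shifts. A small round-trip sketch of that scheme (illustrative only, not part of the stub):

import torch

# int4 values in [-8, 7], held one-per-int8 before packing.
unpacked = torch.tensor([[-8, 7, 3, -1]], dtype=torch.int8)

# Pack pairs: even index -> low nibble, odd index -> high nibble
# (same expression as Int4PackedTensorImpl.from_plain).
packed = ((unpacked[..., 1::2] & 0xF) << 4) | (unpacked[..., 0::2] & 0xF)

# Unpack (same as get_plain): left-then-arithmetic-right shift sign-extends the
# low nibble; a plain arithmetic right shift recovers the high nibble.
restored = torch.stack(((packed << 4) >> 4, packed >> 4), dim=-1).view(unpacked.shape)

assert torch.equal(restored, unpacked)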