From bc67c1a2b27d2760838e7c37c54c5996b977943c Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 00:49:49 +0000 Subject: [PATCH 01/24] Move block_sparse_layout to prototype --- docs/source/api_ref_dtypes.rst | 1 - docs/source/api_ref_sparsity.rst | 10 + torchao/dtypes/__init__.py | 5 +- torchao/dtypes/affine_quantized_tensor_ops.py | 2 +- torchao/dtypes/uintx/__init__.py | 4 - torchao/dtypes/uintx/block_sparse_layout.py | 243 ++---------------- torchao/prototype/sparsity/__init__.py | 6 + .../prototype/sparsity/block_sparse_layout.py | 233 +++++++++++++++++ 8 files changed, 274 insertions(+), 230 deletions(-) create mode 100644 torchao/prototype/sparsity/block_sparse_layout.py diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 6cbec7465e..37e1407435 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -22,7 +22,6 @@ Layouts and Tensor Subclasses FloatxTensor FloatxTensorCoreLayout MarlinSparseLayout - BlockSparseLayout UintxLayout MarlinQQQTensor MarlinQQQLayout diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst index 9fc6644683..acb4cd3fb6 100644 --- a/docs/source/api_ref_sparsity.rst +++ b/docs/source/api_ref_sparsity.rst @@ -15,3 +15,13 @@ torchao.sparsity apply_fake_sparsity WandaSparsifier PerChannelNormObserver + +Prototype +--------- +.. currentmodule:: torchao.prototype.sparsity + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + BlockSparseLayout diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 07f03c7ed9..8033a4de66 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -13,8 +13,11 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 + +# Import BlockSparseLayout from prototype for backward compatibility +from torchao.prototype.sparsity import BlockSparseLayout + from .uintx import ( - BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index ffadece729..47f285a121 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,7 +25,7 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.dtypes.uintx.block_sparse_layout import ( +from torchao.prototype.sparsity.block_sparse_layout import ( _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 6d1bc95653..1d269fc4c4 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,3 @@ -from .block_sparse_layout import ( - BlockSparseLayout, -) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -39,7 +36,6 @@ __all__ = [ "UintxLayout", - "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 0c6046c313..7d6b76ac95 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -3,231 +3,28 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
-import logging -from dataclasses import dataclass -from typing import Optional, Tuple -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, -) -from torchao.dtypes.uintx.plain_layout import ( - PlainAQTTensorImpl, - _aqt_is_int8_reduced_range, +warnings.warn( + "Importing from torchao.dtypes.uintx.block_sparse_layout is deprecated. " + "Please use 'from torchao.prototype.sparsity import BlockSparseLayout' instead. " + "This import path will be removed in torchao v0.16.0.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.utils import ( - Layout, - PlainLayout, -) - -logger = logging.getLogger(__name__) - -aten = torch.ops.aten - - -@dataclass(frozen=True) -class BlockSparseLayout(Layout): - """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. - - Attributes: - blocksize (int): The size of the blocks in the sparse matrix. Default is 64. - """ - - blocksize: int = 64 - - -@register_layout(BlockSparseLayout) -class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): - bsr_crow_indices: Optional[torch.Tensor] - bsr_col_indices: Optional[torch.Tensor] - bsr_values: Optional[torch.Tensor] - scale: Optional[torch.Tensor] - zero_point: Optional[torch.Tensor] - - __slots__ = [ - "bsr_crow_indices", - "bsr_col_indices", - "bsr_values", - "scale", - "zero_point", - ] - - @staticmethod - def __new__( # noqa: PYI034 - cls, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - if bsr_values is None: - raise ValueError("bsr values must be provided!") - else: - previous_tensor = bsr_values - - kwargs = { - "device": previous_tensor.device, - "dtype": previous_tensor.dtype, - "layout": previous_tensor.layout, - "requires_grad": requires_grad, - } - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( # noqa: PYI034 - self, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - self.bsr_crow_indices = bsr_crow_indices - self.bsr_col_indices = bsr_col_indices - self.bsr_values = bsr_values - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __tensor_flatten__(self): - inner_tensors = list( - filter(lambda x: getattr(self, x) is not None, self.__slots__) - ) - tensor_meta = (self.shape, self._layout, self.requires_grad) - return inner_tensors, tensor_meta - @classmethod - def __tensor_unflatten__( - cls, - inner_tensors, - tensor_meta: Tuple[torch.Size, bool], - outer_size, - outer_stride, - ) -> torch.Tensor: - shape, _layout, requires_grad = tensor_meta - return cls( - shape=shape, - bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), - bsr_col_indices=inner_tensors.get("bsr_col_indices", None), - bsr_values=inner_tensors.get("bsr_values", None), - scale=inner_tensors.get("scale", None), - zero_point=inner_tensors.get("zero_point", None), - _layout=_layout, - requires_grad=requires_grad, - ) - - @classmethod - def 
from_plain(cls, int_data, scale, zero_point, _layout): - bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) - return cls( - shape=int_data.shape, - bsr_crow_indices=bsr_tensor.crow_indices(), - bsr_col_indices=bsr_tensor.col_indices(), - bsr_values=bsr_tensor.values(), - scale=scale, - zero_point=zero_point, - _layout=_layout, - requires_grad=False, - ) - - def get_plain(self): - int_data_expanded = torch.ops.blocksparse.bsr_to_dense( - self.crow_indices(), - self.col_indices(), - self.values(), - self.shape[0], - self.shape[1], - ) - return int_data_expanded, self.scale, self.zero_point - - def _apply_fn_to_data(self, func): - return self.__class__( - shape=self.shape, - bsr_crow_indices=func(self.bsr_crow_indices), - bsr_col_indices=func(self.bsr_col_indices), - bsr_values=func(self.bsr_values), - scale=self.scale, - zero_point=self.zero_point, - _layout=self._layout, - requires_grad=self.requires_grad, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - # Need the following for bsr specific functions - if func is aten.crow_indices.default: - return args[0].bsr_crow_indices.detach() - - if func is aten.col_indices.default: - return args[0].bsr_col_indices.detach() - - if func is aten.values.default: - return args[0].bsr_values.detach() - - if func is aten._nnz.default: - return args[0].bsr_values.shape[0] - - raise NotImplementedError( - f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - -def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.is_cuda - and input_tensor.dtype == weight_tensor.dtype - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor._layout, BlockSparseLayout) - ) - - -def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals = weight_tensor.tensor_impl - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - tmp_t = tmp.t() - - y = torch.ops.blocksparse.int_addmm( - w_vals.crow_indices(), - w_vals.col_indices(), - w_vals.values(), - tmp_t, - w_scales, - x_scales.reshape(-1), - ) - y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) - y = y.reshape(*y_shape) +from torchao.prototype.sparsity.block_sparse_layout import ( + BlockSparseLayout, + BlockSparseAQTTensorImpl, + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y +__all__ = [ + "BlockSparseLayout", + "BlockSparseAQTTensorImpl", + "_linear_int8_act_int8_weight_block_sparse_check", + "_linear_int8_act_int8_weight_block_sparse_impl", +] diff --git a/torchao/prototype/sparsity/__init__.py b/torchao/prototype/sparsity/__init__.py index 821e5049e0..403458ceae 100644 --- a/torchao/prototype/sparsity/__init__.py +++ 
b/torchao/prototype/sparsity/__init__.py @@ -19,6 +19,11 @@ WeightNormSparsifier, ) +# Block Sparse Layout +from torchao.prototype.sparsity.block_sparse_layout import ( + BlockSparseLayout, +) + __all__ = [ "BaseScheduler", "CubicSL", @@ -30,4 +35,5 @@ "get_arg_info_from_tensor_fqn", "module_to_fqn", "WeightNormSparsifier", + "BlockSparseLayout", ] diff --git a/torchao/prototype/sparsity/block_sparse_layout.py b/torchao/prototype/sparsity/block_sparse_layout.py new file mode 100644 index 0000000000..0c6046c313 --- /dev/null +++ b/torchao/prototype/sparsity/block_sparse_layout.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import logging +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + PlainAQTTensorImpl, + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import ( + Layout, + PlainLayout, +) + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class BlockSparseLayout(Layout): + """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. + + Attributes: + blocksize (int): The size of the blocks in the sparse matrix. Default is 64. + """ + + blocksize: int = 64 + + +@register_layout(BlockSparseLayout) +class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): + bsr_crow_indices: Optional[torch.Tensor] + bsr_col_indices: Optional[torch.Tensor] + bsr_values: Optional[torch.Tensor] + scale: Optional[torch.Tensor] + zero_point: Optional[torch.Tensor] + + __slots__ = [ + "bsr_crow_indices", + "bsr_col_indices", + "bsr_values", + "scale", + "zero_point", + ] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + if bsr_values is None: + raise ValueError("bsr values must be provided!") + else: + previous_tensor = bsr_values + + kwargs = { + "device": previous_tensor.device, + "dtype": previous_tensor.dtype, + "layout": previous_tensor.layout, + "requires_grad": requires_grad, + } + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( # noqa: PYI034 + self, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + self.bsr_crow_indices = bsr_crow_indices + self.bsr_col_indices = bsr_col_indices + self.bsr_values = bsr_values + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __tensor_flatten__(self): + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self._layout, self.requires_grad) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: Tuple[torch.Size, bool], + 
outer_size, + outer_stride, + ) -> torch.Tensor: + shape, _layout, requires_grad = tensor_meta + return cls( + shape=shape, + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), + bsr_values=inner_tensors.get("bsr_values", None), + scale=inner_tensors.get("scale", None), + zero_point=inner_tensors.get("zero_point", None), + _layout=_layout, + requires_grad=requires_grad, + ) + + @classmethod + def from_plain(cls, int_data, scale, zero_point, _layout): + bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) + return cls( + shape=int_data.shape, + bsr_crow_indices=bsr_tensor.crow_indices(), + bsr_col_indices=bsr_tensor.col_indices(), + bsr_values=bsr_tensor.values(), + scale=scale, + zero_point=zero_point, + _layout=_layout, + requires_grad=False, + ) + + def get_plain(self): + int_data_expanded = torch.ops.blocksparse.bsr_to_dense( + self.crow_indices(), + self.col_indices(), + self.values(), + self.shape[0], + self.shape[1], + ) + return int_data_expanded, self.scale, self.zero_point + + def _apply_fn_to_data(self, func): + return self.__class__( + shape=self.shape, + bsr_crow_indices=func(self.bsr_crow_indices), + bsr_col_indices=func(self.bsr_col_indices), + bsr_values=func(self.bsr_values), + scale=self.scale, + zero_point=self.zero_point, + _layout=self._layout, + requires_grad=self.requires_grad, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + # Need the following for bsr specific functions + if func is aten.crow_indices.default: + return args[0].bsr_crow_indices.detach() + + if func is aten.col_indices.default: + return args[0].bsr_col_indices.detach() + + if func is aten.values.default: + return args[0].bsr_values.detach() + + if func is aten._nnz.default: + return args[0].bsr_values.shape[0] + + raise NotImplementedError( + f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + +def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, BlockSparseLayout) + ) + + +def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals = weight_tensor.tensor_impl + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + tmp_t = tmp.t() + + y = torch.ops.blocksparse.int_addmm( + w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1), + ) + y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) + y = y.reshape(*y_shape) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y From dd2e7d613c8b3af0b9eeb05a0b15fd1867bcc795 Mon Sep 17 00:00:00 2001 From: jainapurva Date: 
Mon, 3 Nov 2025 04:03:24 +0000 Subject: [PATCH 02/24] test fixes --- test/sparsity/test_sparse_api.py | 2 +- torchao/dtypes/__init__.py | 4 +--- torchao/dtypes/uintx/__init__.py | 4 ++++ torchao/dtypes/uintx/block_sparse_layout.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 003a50c4d1..4bfcb5bb0a 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -253,7 +253,7 @@ def test_sparse(self, compile): quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) reference = model_copy(input) - from torchao.dtypes import BlockSparseLayout + from torchao.prototype.sparsity import BlockSparseLayout quantize_( model, diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8033a4de66..d2042d1fb9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,10 +14,8 @@ ) from .nf4tensor import NF4Tensor, to_nf4 -# Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.sparsity import BlockSparseLayout - from .uintx import ( + BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 1d269fc4c4..6d1bc95653 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,3 +1,6 @@ +from .block_sparse_layout import ( + BlockSparseLayout, +) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -36,6 +39,7 @@ __all__ = [ "UintxLayout", + "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 7d6b76ac95..edff0b0493 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -8,7 +8,7 @@ import warnings warnings.warn( - "Importing from torchao.dtypes.uintx.block_sparse_layout is deprecated. " + "Importing BlockSparseLayout from torchao.dtypes is deprecated. " "Please use 'from torchao.prototype.sparsity import BlockSparseLayout' instead. 
" "This import path will be removed in torchao v0.16.0.", DeprecationWarning, From efa74cdf1b5aa84384f0ddc983441fca32fcdd0c Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Sun, 2 Nov 2025 20:07:02 -0800 Subject: [PATCH 03/24] Remove unused import in __init__.py --- torchao/dtypes/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d2042d1fb9..07f03c7ed9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -13,7 +13,6 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 - from .uintx import ( BlockSparseLayout, CutlassInt4PackedLayout, From 37225378308bdf3aedb4aca5309520bf2de34fa4 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Sun, 2 Nov 2025 20:08:10 -0800 Subject: [PATCH 04/24] Clean up exports in block_sparse_layout.py Removed unused exports from block_sparse_layout.py --- torchao/dtypes/uintx/block_sparse_layout.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index edff0b0493..137f13d56f 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -21,10 +21,3 @@ _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) - -__all__ = [ - "BlockSparseLayout", - "BlockSparseAQTTensorImpl", - "_linear_int8_act_int8_weight_block_sparse_check", - "_linear_int8_act_int8_weight_block_sparse_impl", -] From 6c1b8ef35018107d692231858577e9f40a9e7ea1 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:06:41 +0000 Subject: [PATCH 05/24] test fixes --- torchao/dtypes/__init__.py | 5 ++++- torchao/dtypes/uintx/__init__.py | 4 ---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 07f03c7ed9..8033a4de66 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -13,8 +13,11 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 + +# Import BlockSparseLayout from prototype for backward compatibility +from torchao.prototype.sparsity import BlockSparseLayout + from .uintx import ( - BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 6d1bc95653..1d269fc4c4 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,3 @@ -from .block_sparse_layout import ( - BlockSparseLayout, -) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -39,7 +36,6 @@ __all__ = [ "UintxLayout", - "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", From 1702bdf932fb76b9d38f2ec1457bbee5236bf8e4 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:13:05 +0000 Subject: [PATCH 06/24] Fix ruff import sorting and add noqa for re-exported imports --- torchao/dtypes/__init__.py | 7 +++---- torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++++---- torchao/dtypes/uintx/block_sparse_layout.py | 8 ++++---- torchao/prototype/sparsity/__init__.py | 9 ++++----- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8033a4de66..33a22b4925 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,3 +1,6 @@ +# Import BlockSparseLayout from prototype for backward compatibility +from torchao.prototype.sparsity import BlockSparseLayout + from . 
import affine_quantized_tensor_ops from .affine_quantized_tensor import ( AffineQuantizedTensor, @@ -13,10 +16,6 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 - -# Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.sparsity import BlockSparseLayout - from .uintx import ( CutlassInt4PackedLayout, Int4CPULayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 47f285a121..8f6305d0c0 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.prototype.sparsity.block_sparse_layout import ( - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, -) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( _linear_int4_act_int4_weight_cutlass_check, _linear_int4_act_int4_weight_cutlass_impl, @@ -94,6 +90,10 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) +from torchao.prototype.sparsity.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 137f13d56f..ef162d3945 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -16,8 +16,8 @@ ) from torchao.prototype.sparsity.block_sparse_layout import ( - BlockSparseLayout, - BlockSparseAQTTensorImpl, - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, + BlockSparseAQTTensorImpl, # noqa: F401 + BlockSparseLayout, # noqa: F401 + _linear_int8_act_int8_weight_block_sparse_check, # noqa: F401 + _linear_int8_act_int8_weight_block_sparse_impl, # noqa: F401 ) diff --git a/torchao/prototype/sparsity/__init__.py b/torchao/prototype/sparsity/__init__.py index 403458ceae..20ea714577 100644 --- a/torchao/prototype/sparsity/__init__.py +++ b/torchao/prototype/sparsity/__init__.py @@ -1,5 +1,9 @@ # Sparsifier # Scheduler +# Block Sparse Layout +from torchao.prototype.sparsity.block_sparse_layout import ( + BlockSparseLayout, +) from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL @@ -19,11 +23,6 @@ WeightNormSparsifier, ) -# Block Sparse Layout -from torchao.prototype.sparsity.block_sparse_layout import ( - BlockSparseLayout, -) - __all__ = [ "BaseScheduler", "CubicSL", From 18d682a61dd3f0d16ec7e4c1cba7d9de4cb1419a Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:29:19 +0000 Subject: [PATCH 07/24] Move block_sparse_layout to prototype/dtypes/uintx and update all imports - Move torchao/prototype/sparsity/block_sparse_layout.py to torchao/prototype/dtypes/uintx/block_sparse_layout.py - Update all imports to use 'from torchao.prototype.dtypes import BlockSparseLayout' - Update deprecation warning message - Create torchao/prototype/dtypes/__init__.py and torchao/prototype/dtypes/uintx/__init__.py - Remove BlockSparseLayout from torchao/prototype/sparsity/__init__.py - Update documentation --- docs/source/api_ref_sparsity.rst | 10 
---------- test/sparsity/test_sparse_api.py | 2 +- torchao/dtypes/__init__.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 2 +- torchao/dtypes/uintx/block_sparse_layout.py | 4 ++-- torchao/prototype/dtypes/__init__.py | 12 ++++++++++++ torchao/prototype/dtypes/uintx/__init__.py | 12 ++++++++++++ .../uintx}/block_sparse_layout.py | 0 torchao/prototype/sparsity/__init__.py | 5 ----- 9 files changed, 29 insertions(+), 20 deletions(-) create mode 100644 torchao/prototype/dtypes/__init__.py create mode 100644 torchao/prototype/dtypes/uintx/__init__.py rename torchao/prototype/{sparsity => dtypes/uintx}/block_sparse_layout.py (100%) diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst index acb4cd3fb6..9fc6644683 100644 --- a/docs/source/api_ref_sparsity.rst +++ b/docs/source/api_ref_sparsity.rst @@ -15,13 +15,3 @@ torchao.sparsity apply_fake_sparsity WandaSparsifier PerChannelNormObserver - -Prototype ---------- -.. currentmodule:: torchao.prototype.sparsity - -.. autosummary:: - :toctree: generated/ - :nosignatures: - - BlockSparseLayout diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 4bfcb5bb0a..66cd032a9a 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -253,7 +253,7 @@ def test_sparse(self, compile): quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) reference = model_copy(input) - from torchao.prototype.sparsity import BlockSparseLayout + from torchao.prototype.dtypes import BlockSparseLayout quantize_( model, diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 33a22b4925..ebf6c98553 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,5 +1,5 @@ # Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.sparsity import BlockSparseLayout +from torchao.prototype.dtypes import BlockSparseLayout from . import affine_quantized_tensor_ops from .affine_quantized_tensor import ( diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 8f6305d0c0..2b6f47e692 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -90,7 +90,7 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) -from torchao.prototype.sparsity.block_sparse_layout import ( +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index ef162d3945..2840a9b7fa 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -9,13 +9,13 @@ warnings.warn( "Importing BlockSparseLayout from torchao.dtypes is deprecated. " - "Please use 'from torchao.prototype.sparsity import BlockSparseLayout' instead. " + "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. 
" "This import path will be removed in torchao v0.16.0.", DeprecationWarning, stacklevel=2, ) -from torchao.prototype.sparsity.block_sparse_layout import ( +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( BlockSparseAQTTensorImpl, # noqa: F401 BlockSparseLayout, # noqa: F401 _linear_int8_act_int8_weight_block_sparse_check, # noqa: F401 diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py new file mode 100644 index 0000000000..c8dfc67f70 --- /dev/null +++ b/torchao/prototype/dtypes/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .uintx import BlockSparseLayout + +__all__ = [ + "BlockSparseLayout", +] + diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py new file mode 100644 index 0000000000..2b8e2e66ba --- /dev/null +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .block_sparse_layout import BlockSparseLayout + +__all__ = [ + "BlockSparseLayout", +] + diff --git a/torchao/prototype/sparsity/block_sparse_layout.py b/torchao/prototype/dtypes/uintx/block_sparse_layout.py similarity index 100% rename from torchao/prototype/sparsity/block_sparse_layout.py rename to torchao/prototype/dtypes/uintx/block_sparse_layout.py diff --git a/torchao/prototype/sparsity/__init__.py b/torchao/prototype/sparsity/__init__.py index 20ea714577..821e5049e0 100644 --- a/torchao/prototype/sparsity/__init__.py +++ b/torchao/prototype/sparsity/__init__.py @@ -1,9 +1,5 @@ # Sparsifier # Scheduler -# Block Sparse Layout -from torchao.prototype.sparsity.block_sparse_layout import ( - BlockSparseLayout, -) from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL @@ -34,5 +30,4 @@ "get_arg_info_from_tensor_fqn", "module_to_fqn", "WeightNormSparsifier", - "BlockSparseLayout", ] From de054ab08cc45e244877a7f09bf91fe1c8d1e2b1 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:37:19 +0000 Subject: [PATCH 08/24] Apply ruff formatting (remove trailing newlines) --- torchao/prototype/dtypes/__init__.py | 1 - torchao/prototype/dtypes/uintx/__init__.py | 1 - 2 files changed, 2 deletions(-) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index c8dfc67f70..54d395e673 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -9,4 +9,3 @@ __all__ = [ "BlockSparseLayout", ] - diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index 2b8e2e66ba..107e6a344b 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -9,4 +9,3 @@ __all__ = [ "BlockSparseLayout", ] - From 4ef291fa13fdf89d636d84aecf3368a6584a3736 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 05:46:15 +0000 Subject: [PATCH 09/24] Add Prototype section to dtypes documentation with BlockSparseLayout --- docs/source/api_ref_dtypes.rst | 11 +++++++++++ 1 file changed, 11 
insertions(+) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 37e1407435..abf938c322 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -42,6 +42,17 @@ Quantization techniques to_affine_quantized_floatx_static to_marlinqqq_quantized_intx to_nf4 + +Prototype +--------- +.. currentmodule:: torchao.prototype.dtypes + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + BlockSparseLayout + .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation. From b492530d774e43554ae349b6e6cb0b30815857c9 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 06:02:57 +0000 Subject: [PATCH 10/24] Move cutlass_int4_packed_layout to prototype/dtypes/uintx - Moved cutlass_int4_packed_layout.py from torchao/dtypes/uintx/ to torchao/prototype/dtypes/uintx/ - Created torchao/prototype/dtypes/__init__.py and uintx/__init__.py to export CutlassInt4PackedLayout - Replaced old file with backward compatibility stub that imports from new location - Added deprecation warning for old import path (to be removed in v0.16.0) - Updated torchao/dtypes/__init__.py to re-export from prototype for backward compatibility - Updated internal imports in affine_quantized_tensor_ops.py to use new prototype location - Removed CutlassInt4PackedLayout from torchao/dtypes/uintx/__init__.py to avoid circular imports - Updated documentation to move CutlassInt4PackedLayout to Prototype section All import paths work: - New: from torchao.prototype.dtypes import CutlassInt4PackedLayout - Backward compat: from torchao.dtypes import CutlassInt4PackedLayout - Deprecated: from torchao.dtypes.uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout --- docs/source/api_ref_dtypes.rst | 12 +- torchao/dtypes/__init__.py | 8 +- torchao/dtypes/affine_quantized_tensor_ops.py | 20 +- torchao/dtypes/uintx/__init__.py | 4 - .../uintx/cutlass_int4_packed_layout.py | 238 ++---------------- torchao/prototype/dtypes/__init__.py | 12 + torchao/prototype/dtypes/uintx/__init__.py | 12 + .../uintx/cutlass_int4_packed_layout.py | 224 +++++++++++++++++ 8 files changed, 296 insertions(+), 234 deletions(-) create mode 100644 torchao/prototype/dtypes/__init__.py create mode 100644 torchao/prototype/dtypes/uintx/__init__.py create mode 100644 torchao/prototype/dtypes/uintx/cutlass_int4_packed_layout.py diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 37e1407435..e77ecbd899 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -26,7 +26,6 @@ Layouts and Tensor Subclasses MarlinQQQTensor MarlinQQQLayout Int4CPULayout - CutlassInt4PackedLayout CutlassSemiSparseLayout Quantization techniques @@ -42,6 +41,17 @@ Quantization techniques to_affine_quantized_floatx_static to_marlinqqq_quantized_intx to_nf4 + +Prototype +--------- +.. currentmodule:: torchao.prototype.dtypes + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + CutlassInt4PackedLayout + .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation. diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8033a4de66..2bd070af6b 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,3 +1,6 @@ +# Import layouts from prototype for backward compatibility +from torchao.prototype.dtypes import CutlassInt4PackedLayout + from . 
import affine_quantized_tensor_ops from .affine_quantized_tensor import ( AffineQuantizedTensor, @@ -13,12 +16,7 @@ Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 - -# Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.sparsity import BlockSparseLayout - from .uintx import ( - CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, Int8DynamicActInt4WeightCPULayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 47f285a121..a657db4031 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,16 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.prototype.sparsity.block_sparse_layout import ( - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, -) -from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( - _linear_int4_act_int4_weight_cutlass_check, - _linear_int4_act_int4_weight_cutlass_impl, - _linear_int8_act_int4_weight_cutlass_check, - _linear_int8_act_int4_weight_cutlass_impl, -) from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, @@ -94,6 +84,16 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) +from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, + _linear_int8_act_int4_weight_cutlass_check, + _linear_int8_act_int4_weight_cutlass_impl, +) +from torchao.prototype.sparsity.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 1d269fc4c4..b76e80e0fc 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,3 @@ -from .cutlass_int4_packed_layout import ( - CutlassInt4PackedLayout, -) from .dyn_int8_act_int4_wei_cpu_layout import ( Int8DynamicActInt4WeightCPULayout, ) @@ -43,7 +40,6 @@ "MarlinQQQLayout", "MarlinQQQTensor", "to_marlinqqq_quantized_intx", - "CutlassInt4PackedLayout", "PackedLinearInt8DynamicActivationIntxWeightLayout", "QDQLayout", "Int4XPULayout", diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index d680f4cf77..e0a0c8edf0 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -3,222 +3,32 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Optional -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, -) -from torchao.dtypes.uintx.plain_layout import ( - _aqt_is_int8, +warnings.warn( + "Importing from torchao.dtypes.uintx.cutlass_int4_packed_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes import CutlassInt4PackedLayout' instead. 
" + "This import path will be removed in torchao v0.16.0.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout - -aten = torch.ops.aten - - -def _aqt_is_int4(aqt): - """Check if an AffineQuantizedTensor is int4 quantized Tensor""" - # TODO: use torch.int4 - return ( - aqt.tensor_impl.dtype == torch.int8 - and aqt.quant_min == -8 - and aqt.quant_max == 7 - ) - - -def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool: - return ( - isinstance(self, Int4PackedTensorImpl) - and isinstance(src, Int4PackedTensorImpl) - and self.shape == src.shape - and self.int_data.shape == src.int_data.shape - and self.scale.shape == src.scale.shape - and type(self._layout) == type(src._layout) - ) - - -@dataclass(frozen=True) -class CutlassInt4PackedLayout(Layout): - """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel.""" - - pass - - -@register_layout(CutlassInt4PackedLayout) -class Int4PackedTensorImpl(AQTTensorImpl): - """ - TensorImpl storage class for int4 packed layout for affine quantized tensor. - """ - - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - _layout: Layout, - ): - self.int_data = int_data - self.scale = scale - self._layout = _layout - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - elif func is aten.copy_.default: - self = args[0] - src = args[1] - if _same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - return - raise ValueError( - f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" - ) - - raise NotImplementedError( - f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - return ["int_data", "scale"], [self._layout] - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data = tensor_data_dict["int_data"] - scale = tensor_data_dict["scale"] - (_layout,) = tensor_attributes - return cls(int_data, scale, _layout) - - def get_plain(self): - int_data = torch.stack( - ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1 - ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],)) - return int_data, self.scale, None - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert zero_point is None or torch.all(zero_point == 0) - int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF) - return cls( - int_data_s4, - scale, - _layout, - ) - - def get_layout(self) -> Layout: - return self._layout - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.scale = 
fn(self.scale) - return self - - -def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and isinstance(input_tensor._layout, PlainLayout) - and _aqt_is_int8(input_tensor) - and input_tensor.dtype in (torch.float16, torch.bfloat16) - and len(input_tensor.shape) >= 2 - and input_tensor.tensor_impl.scale.dtype == torch.float32 - and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 - and isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) - and _aqt_is_int4(weight_tensor) - and weight_tensor.dtype == input_tensor.dtype - and len(weight_tensor.shape) == 2 - and weight_tensor.tensor_impl.scale.dtype == torch.float32 - and len(weight_tensor.tensor_impl.scale.shape) == 1 - and (bias is None or bias.dtype == input_tensor.dtype) - and (bias is None or len(bias.shape) == 1) - ) - - -def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): - from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 - - weight = weight_tensor.tensor_impl.int_data - weight_scale = weight_tensor.tensor_impl.scale - input = input_tensor.tensor_impl.int_data - input_scale = input_tensor.tensor_impl.scale - out_dtype = input_tensor.dtype - - out = rowwise_scaled_linear_cutlass_s8s4( - input, input_scale, weight, weight_scale, bias, out_dtype - ) - - return out - - -def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and isinstance(input_tensor._layout, CutlassInt4PackedLayout) - and _aqt_is_int4(input_tensor) - and input_tensor.dtype in (torch.float16, torch.bfloat16) - and len(input_tensor.shape) >= 2 - and input_tensor.tensor_impl.scale.dtype == torch.float32 - and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 - and isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) - and _aqt_is_int4(weight_tensor) - and weight_tensor.dtype == input_tensor.dtype - and len(weight_tensor.shape) == 2 - and weight_tensor.tensor_impl.scale.dtype == torch.float32 - and len(weight_tensor.tensor_impl.scale.shape) == 1 - ) - - -def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): - from torchao.ops import rowwise_scaled_linear_cutlass_s4s4 - - weight = weight_tensor.tensor_impl.int_data - weight_scale = weight_tensor.tensor_impl.scale - input = input_tensor.tensor_impl.int_data - input_scale = input_tensor.tensor_impl.scale - out_dtype = input_tensor.dtype - - out = rowwise_scaled_linear_cutlass_s4s4( - input, input_scale, weight, weight_scale, bias, out_dtype - ) +from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import ( # noqa: F401 + CutlassInt4PackedLayout, + Int4PackedTensorImpl, + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, + _linear_int8_act_int4_weight_cutlass_check, + _linear_int8_act_int4_weight_cutlass_impl, +) - return out +__all__ = [ + "CutlassInt4PackedLayout", + "Int4PackedTensorImpl", + "_linear_int4_act_int4_weight_cutlass_check", + "_linear_int4_act_int4_weight_cutlass_impl", + "_linear_int8_act_int4_weight_cutlass_check", + "_linear_int8_act_int4_weight_cutlass_impl", +] diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py new file mode 100644 index 0000000000..abd9c2c476 --- /dev/null +++ 
b/torchao/prototype/dtypes/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .uintx import CutlassInt4PackedLayout + +__all__ = [ + "CutlassInt4PackedLayout", +] + diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py new file mode 100644 index 0000000000..135e643781 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .cutlass_int4_packed_layout import CutlassInt4PackedLayout + +__all__ = [ + "CutlassInt4PackedLayout", +] + diff --git a/torchao/prototype/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/prototype/dtypes/uintx/cutlass_int4_packed_layout.py new file mode 100644 index 0000000000..d680f4cf77 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/cutlass_int4_packed_layout.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + _aqt_is_int8, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout + +aten = torch.ops.aten + + +def _aqt_is_int4(aqt): + """Check if an AffineQuantizedTensor is int4 quantized Tensor""" + # TODO: use torch.int4 + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -8 + and aqt.quant_max == 7 + ) + + +def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool: + return ( + isinstance(self, Int4PackedTensorImpl) + and isinstance(src, Int4PackedTensorImpl) + and self.shape == src.shape + and self.int_data.shape == src.int_data.shape + and self.scale.shape == src.scale.shape + and type(self._layout) == type(src._layout) + ) + + +@dataclass(frozen=True) +class CutlassInt4PackedLayout(Layout): + """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel.""" + + pass + + +@register_layout(CutlassInt4PackedLayout) +class Int4PackedTensorImpl(AQTTensorImpl): + """ + TensorImpl storage class for int4 packed layout for affine quantized tensor. 
+ """ + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + self.int_data = int_data + self.scale = scale + self._layout = _layout + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + elif func is aten.copy_.default: + self = args[0] + src = args[1] + if _same_metadata(self, src): + self_tensors = self.__tensor_flatten__()[0] + for tensor_name in self_tensors: + getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + return + raise ValueError( + f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" + ) + + raise NotImplementedError( + f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "scale"], [self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + (_layout,) = tensor_attributes + return cls(int_data, scale, _layout) + + def get_plain(self): + int_data = torch.stack( + ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1 + ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],)) + return int_data, self.scale, None + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert zero_point is None or torch.all(zero_point == 0) + int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF) + return cls( + int_data_s4, + scale, + _layout, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.scale = fn(self.scale) + return self + + +def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and isinstance(input_tensor._layout, PlainLayout) + and _aqt_is_int8(input_tensor) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == torch.float32 + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) + and _aqt_is_int4(weight_tensor) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == torch.float32 + and len(weight_tensor.tensor_impl.scale.shape) == 1 + and (bias is None or bias.dtype == input_tensor.dtype) + and (bias is None or len(bias.shape) == 1) + ) + + +def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 + + weight = 
weight_tensor.tensor_impl.int_data + weight_scale = weight_tensor.tensor_impl.scale + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + out_dtype = input_tensor.dtype + + out = rowwise_scaled_linear_cutlass_s8s4( + input, input_scale, weight, weight_scale, bias, out_dtype + ) + + return out + + +def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and isinstance(input_tensor._layout, CutlassInt4PackedLayout) + and _aqt_is_int4(input_tensor) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == torch.float32 + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) + and _aqt_is_int4(weight_tensor) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == torch.float32 + and len(weight_tensor.tensor_impl.scale.shape) == 1 + ) + + +def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import rowwise_scaled_linear_cutlass_s4s4 + + weight = weight_tensor.tensor_impl.int_data + weight_scale = weight_tensor.tensor_impl.scale + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + out_dtype = input_tensor.dtype + + out = rowwise_scaled_linear_cutlass_s4s4( + input, input_scale, weight, weight_scale, bias, out_dtype + ) + + return out From d699ee0a418c913017be653a9b2c584eb0bf312a Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Mon, 3 Nov 2025 09:49:54 -0800 Subject: [PATCH 11/24] Clean up imports in affine_quantized_tensor_ops.py Removed redundant imports of block sparse layout functions. 
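
For context, an illustrative sketch (not part of this commit's diff) of the import
paths this series converges on for the block-sparse layout; the deprecated path
resolves through the backward-compatibility stub left at the old location:

    # Canonical path after the move:
    from torchao.prototype.dtypes import BlockSparseLayout

    # Deprecated path: still works via the stub, emits a DeprecationWarning,
    # and is slated for removal in torchao v0.16.0.
    from torchao.dtypes.uintx.block_sparse_layout import BlockSparseLayout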
--- torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 2b6f47e692..45e350068e 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,6 +25,10 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( _linear_int4_act_int4_weight_cutlass_check, _linear_int4_act_int4_weight_cutlass_impl, @@ -90,10 +94,6 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) -from torchao.prototype.dtypes.uintx.block_sparse_layout import ( - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, -) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, From ab8479957f0bbc74530bc0882d9c26881fe28185 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 11:11:58 -0800 Subject: [PATCH 12/24] Update internal links --- torchao/dtypes/__init__.py | 4 +--- torchao/dtypes/uintx/__init__.py | 4 ++++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index ebf6c98553..07f03c7ed9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,6 +1,3 @@ -# Import BlockSparseLayout from prototype for backward compatibility -from torchao.prototype.dtypes import BlockSparseLayout - from . import affine_quantized_tensor_ops from .affine_quantized_tensor import ( AffineQuantizedTensor, @@ -17,6 +14,7 @@ ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( + BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 1d269fc4c4..6d1bc95653 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,3 +1,6 @@ +from .block_sparse_layout import ( + BlockSparseLayout, +) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -36,6 +39,7 @@ __all__ = [ "UintxLayout", + "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", From 40ab18861fb676d462558276212d386bd9211177 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 21:56:08 +0000 Subject: [PATCH 13/24] test fixes --- torchao/dtypes/__init__.py | 2 +- torchao/dtypes/uintx/__init__.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 07f03c7ed9..6cfcc741fa 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,7 +14,6 @@ ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( - BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, @@ -33,6 +32,7 @@ Layout, PlainLayout, ) +from .uintx.block_sparse_layout import BlockSparseLayout __all__ = [ "NF4Tensor", diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 6d1bc95653..1d269fc4c4 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,3 @@ -from .block_sparse_layout import ( - BlockSparseLayout, -) from .cutlass_int4_packed_layout import ( 
From 40ab18861fb676d462558276212d386bd9211177 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 3 Nov 2025 21:56:08 +0000
Subject: [PATCH 13/24] test fixes

---
 torchao/dtypes/__init__.py       | 2 +-
 torchao/dtypes/uintx/__init__.py | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 07f03c7ed9..6cfcc741fa 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -14,7 +14,6 @@
 )
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
-    BlockSparseLayout,
     CutlassInt4PackedLayout,
     Int4CPULayout,
     Int4XPULayout,
@@ -33,6 +32,7 @@
     Layout,
     PlainLayout,
 )
+from .uintx.block_sparse_layout import BlockSparseLayout
 
 __all__ = [
     "NF4Tensor",
diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py
index 6d1bc95653..1d269fc4c4 100644
--- a/torchao/dtypes/uintx/__init__.py
+++ b/torchao/dtypes/uintx/__init__.py
@@ -1,6 +1,3 @@
-from .block_sparse_layout import (
-    BlockSparseLayout,
-)
 from .cutlass_int4_packed_layout import (
     CutlassInt4PackedLayout,
 )
@@ -39,7 +36,6 @@
 
 __all__ = [
     "UintxLayout",
-    "BlockSparseLayout",
     "MarlinSparseLayout",
     "SemiSparseLayout",
     "TensorCoreTiledLayout",

From 894857e3a6582842fbfb06090b057f1f32ce8b1a Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 3 Nov 2025 22:08:46 +0000
Subject: [PATCH 14/24] ruff fixes

---
 torchao/dtypes/__init__.py                    | 2 +-
 torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 6cfcc741fa..b1e7fc9875 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -28,11 +28,11 @@
     UintxLayout,
     to_marlinqqq_quantized_intx,
 )
+from .uintx.block_sparse_layout import BlockSparseLayout
 from .utils import (
     Layout,
     PlainLayout,
 )
-from .uintx.block_sparse_layout import BlockSparseLayout
 
 __all__ = [
     "NF4Tensor",
diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index 45e350068e..2b6f47e692 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -25,10 +25,6 @@
     _linear_f16_bf16_act_floatx_weight_check,
     _linear_f16_bf16_act_floatx_weight_impl,
 )
-from torchao.prototype.dtypes.uintx.block_sparse_layout import (
-    _linear_int8_act_int8_weight_block_sparse_check,
-    _linear_int8_act_int8_weight_block_sparse_impl,
-)
 from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
     _linear_int4_act_int4_weight_cutlass_check,
     _linear_int4_act_int4_weight_cutlass_impl,
@@ -94,6 +90,10 @@
     _linear_bf16_act_uint4_weight_check,
     _linear_bf16_act_uint4_weight_impl,
 )
+from torchao.prototype.dtypes.uintx.block_sparse_layout import (
+    _linear_int8_act_int8_weight_block_sparse_check,
+    _linear_int8_act_int8_weight_block_sparse_impl,
+)
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
     _dequantize_affine_no_zero_point,

From 1d51ebf104cb54edc24d83e3faf52cd507718c55 Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Mon, 3 Nov 2025 14:20:48 -0800
Subject: [PATCH 15/24] Remove unused import from uintx init file

Removed unused import for CutlassInt4PackedLayout.
---
 torchao/dtypes/uintx/__init__.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py
index 7c5f16494e..b76e80e0fc 100644
--- a/torchao/dtypes/uintx/__init__.py
+++ b/torchao/dtypes/uintx/__init__.py
@@ -1,6 +1,3 @@
-from .cutlass_int4_packed_layout import (
-    CutlassInt4PackedLayout,
-)
 from .dyn_int8_act_int4_wei_cpu_layout import (
     Int8DynamicActInt4WeightCPULayout,
 )

From 8c18d4d9fe3ad97bc40b815d2d341d74fb0c687e Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Mon, 3 Nov 2025 14:21:21 -0800
Subject: [PATCH 16/24] Remove __all__ exports from module

Removed __all__ exports from cutlass_int4_packed_layout.py

---
 torchao/dtypes/__init__.py              |  2 +-
 .../uintx/cutlass_int4_packed_layout.py | 23 ++++++-------------
 2 files changed, 8 insertions(+), 17 deletions(-)

diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index b1e7fc9875..252498bc97 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -14,7 +14,6 @@
 )
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
-    CutlassInt4PackedLayout,
     Int4CPULayout,
     Int4XPULayout,
     Int8DynamicActInt4WeightCPULayout,
@@ -29,6 +28,7 @@
     to_marlinqqq_quantized_intx,
 )
 from .uintx.block_sparse_layout import BlockSparseLayout
+from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
 from .utils import (
     Layout,
     PlainLayout,
diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
index e0a0c8edf0..48a6ffe96d 100644
--- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
+++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -8,7 +8,7 @@
 import warnings
 
 warnings.warn(
-    "Importing from torchao.dtypes.uintx.cutlass_int4_packed_layout is deprecated. "
+    "Importing from torchao.dtypes is deprecated. "
" "This import path will be removed in torchao v0.16.0.", DeprecationWarning, @@ -16,19 +16,10 @@ ) from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import ( # noqa: F401 - CutlassInt4PackedLayout, - Int4PackedTensorImpl, - _linear_int4_act_int4_weight_cutlass_check, - _linear_int4_act_int4_weight_cutlass_impl, - _linear_int8_act_int4_weight_cutlass_check, - _linear_int8_act_int4_weight_cutlass_impl, + CutlassInt4PackedLayout, # noqa: F401 + Int4PackedTensorImpl, # noqa: F401 + _linear_int4_act_int4_weight_cutlass_check, # noqa: F401 + _linear_int4_act_int4_weight_cutlass_impl, # noqa: F401 + _linear_int8_act_int4_weight_cutlass_check, # noqa: F401 + _linear_int8_act_int4_weight_cutlass_impl, # noqa: F401 ) - -__all__ = [ - "CutlassInt4PackedLayout", - "Int4PackedTensorImpl", - "_linear_int4_act_int4_weight_cutlass_check", - "_linear_int4_act_int4_weight_cutlass_impl", - "_linear_int8_act_int4_weight_cutlass_check", - "_linear_int8_act_int4_weight_cutlass_impl", -] From 081a4ed2a6716f18ab793827aa4629cddabea254 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 3 Nov 2025 22:41:27 +0000 Subject: [PATCH 17/24] Empty commit to trigger CI From 61c198683785282294c655726483d3dddf7cb652 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 4 Nov 2025 10:23:52 -0800 Subject: [PATCH 18/24] Lint fixes --- torchao/dtypes/affine_quantized_tensor_ops.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 379a3bb9d2..e46809059e 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,12 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import ( - _linear_int4_act_int4_weight_cutlass_check, - _linear_int4_act_int4_weight_cutlass_impl, - _linear_int8_act_int4_weight_cutlass_check, - _linear_int8_act_int4_weight_cutlass_impl, -) from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, @@ -94,6 +88,12 @@ _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) +from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, + _linear_int8_act_int4_weight_cutlass_check, + _linear_int8_act_int4_weight_cutlass_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, From 8ebbb9cc416833dda7da01b664e5855b0e8eb633 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 4 Nov 2025 11:22:07 -0800 Subject: [PATCH 19/24] Add test cases --- test/sparsity/test_sparse_api.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 66cd032a9a..5946df0a50 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -267,6 +267,31 @@ def test_sparse(self, compile): torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1) + # TODO: Remove this test once the deprecated API has been removed + def test_sparse_deprecated(self): + import sys + import warnings + + # We need to clear the cache to force re-importing and trigger the warning again. 
From 8ebbb9cc416833dda7da01b664e5855b0e8eb633 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 4 Nov 2025 11:22:07 -0800
Subject: [PATCH 19/24] Add test cases

---
 test/sparsity/test_sparse_api.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 66cd032a9a..5946df0a50 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -267,6 +267,31 @@ def test_sparse(self, compile):
 
         torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
 
+    # TODO: Remove this test once the deprecated API has been removed
+    def test_sparse_deprecated(self):
+        import sys
+        import warnings
+
+        # We need to clear the cache to force re-importing and trigger the warning again.
+        modules_to_clear = [
+            "torchao.dtypes.uintx.block_sparse_layout",
+            "torchao.dtypes",
+        ]
+        for mod in modules_to_clear:
+            if mod in sys.modules:
+                del sys.modules[mod]
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")  # Ensure all warnings are captured
+            self.assertTrue(
+                any(
+                    issubclass(warning.category, DeprecationWarning)
+                    and "BlockSparseLayout" in str(warning.message)
+                    for warning in w
+                ),
+                f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}",
+            )
+
 
 common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
 common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)

From 78f5e4c4e92d0eb4971d9446d7918a3d2a813f37 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 4 Nov 2025 11:22:07 -0800
Subject: [PATCH 20/24] Add test cases

---
 test/sparsity/test_sparse_api.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 66cd032a9a..5946df0a50 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -267,6 +267,31 @@ def test_sparse(self, compile):
 
         torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
 
+    # TODO: Remove this test once the deprecated API has been removed
+    def test_sparse_deprecated(self):
+        import sys
+        import warnings
+
+        # We need to clear the cache to force re-importing and trigger the warning again.
+        modules_to_clear = [
+            "torchao.dtypes.uintx.block_sparse_layout",
+            "torchao.dtypes",
+        ]
+        for mod in modules_to_clear:
+            if mod in sys.modules:
+                del sys.modules[mod]
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")  # Ensure all warnings are captured
+            self.assertTrue(
+                any(
+                    issubclass(warning.category, DeprecationWarning)
+                    and "BlockSparseLayout" in str(warning.message)
+                    for warning in w
+                ),
+                f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}",
+            )
+
 
 common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
 common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
From 1a756898a4950f0c5c505bafb1d3a30d6c0bbc0e Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 4 Nov 2025 11:22:07 -0800
Subject: [PATCH 21/24] Add test cases

---
 test/sparsity/test_sparse_api.py            | 26 +++++++++++++++
 torchao/dtypes/uintx/block_sparse_layout.py |  3 ++-
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 66cd032a9a..b50026dacc 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -267,6 +267,32 @@ def test_sparse(self, compile):
 
         torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
 
+    # TODO: Remove this test once the deprecated API has been removed
+    def test_sparse_deprecated(self):
+        import sys
+        import warnings
+
+        # We need to clear the cache to force re-importing and trigger the warning again.
+        modules_to_clear = [
+            "torchao.dtypes.uintx.block_sparse_layout",
+            "torchao.dtypes",
+        ]
+        for mod in modules_to_clear:
+            if mod in sys.modules:
+                del sys.modules[mod]
+
+        with warnings.catch_warnings(record=True) as w:
+            from torchao.dtypes import BlockSparseLayout
+            warnings.simplefilter("always")  # Ensure all warnings are captured
+            self.assertTrue(
+                any(
+                    issubclass(warning.category, DeprecationWarning)
+                    and "BlockSparseLayout" in str(warning.message)
+                    for warning in w
+                ),
+                f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}",
+            )
+
 
 common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
 common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py
index 2840a9b7fa..c1d6436268 100644
--- a/torchao/dtypes/uintx/block_sparse_layout.py
+++ b/torchao/dtypes/uintx/block_sparse_layout.py
@@ -10,7 +10,8 @@
 warnings.warn(
     "Importing BlockSparseLayout from torchao.dtypes is deprecated. "
     "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. "
-    "This import path will be removed in torchao v0.16.0.",
+    "This import path will be removed in torchao v0.16.0. "
+    "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ",
     DeprecationWarning,
     stacklevel=2,
 )
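A note on the stacklevel=2 argument that every stub in this series passes to warnings.warn(): it attributes the warning to the caller's frame, so the reported file and line point at the import site in user code rather than at the stub itself. A self-contained illustration, with a hypothetical function standing in for the stub module:

    import warnings

    def old_api():
        # stacklevel=2 makes the reported filename and lineno point at the
        # caller of old_api(), not at this warn() call.
        warnings.warn("old_api() is deprecated", DeprecationWarning, stacklevel=2)

    warnings.simplefilter("always")
    old_api()  # the emitted warning is attributed to this line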
From 42acafa6da749e8d37168e2c0a204153b2cb3653 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 4 Nov 2025 21:50:35 -0800
Subject: [PATCH 22/24] Updates

---
 test/integration/test_integration.py    | 27 +++++++++++++++
 test/sparsity/test_sparse_api.py         |  3 ++-
 .../uintx/cutlass_int4_packed_layout.py  |  3 ++-
 3 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index dc58470526..2d05426d73 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1946,5 +1946,32 @@ def test_benchmark_model_cpu(self):
         assert self.run_benchmark_model("cpu") is not None
 
 
+# TODO: Remove this test once the deprecated API has been removed
+def test_cutlass_int4_packed_layout_deprecated():
+    import sys
+    import warnings
+
+    # We need to clear the cache to force re-importing and trigger the warning again.
+    modules_to_clear = [
+        "torchao.dtypes.uintx.cutlass_int4_packed_layout",
+        "torchao.dtypes",
+    ]
+    for mod in modules_to_clear:
+        if mod in sys.modules:
+            del sys.modules[mod]
+
+    with warnings.catch_warnings(record=True) as w:
+        from torchao.dtypes import CutlassInt4PackedLayout  # noqa: F401
+
+        warnings.simplefilter("always")  # Ensure all warnings are captured
+        assert any(
+            issubclass(warning.category, DeprecationWarning)
+            and "CutlassInt4PackedLayout" in str(warning.message)
+            for warning in w
+        ), (
+            f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}"
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index b50026dacc..c9d41a98a9 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -282,7 +282,8 @@ def test_sparse_deprecated(self):
                 del sys.modules[mod]
 
         with warnings.catch_warnings(record=True) as w:
-            from torchao.dtypes import BlockSparseLayout
+            from torchao.dtypes import BlockSparseLayout  # noqa: F401
+
             warnings.simplefilter("always")  # Ensure all warnings are captured
             self.assertTrue(
                 any(
diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
index 48a6ffe96d..19c55a8993 100644
--- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
+++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -10,7 +10,8 @@
 warnings.warn(
     "Importing from torchao.dtypes is deprecated. "
     "Please use 'from torchao.prototype.dtypes import CutlassInt4PackedLayout' instead. "
-    "This import path will be removed in torchao v0.16.0.",
+    "This import path will be removed in torchao v0.16.0. "
+    "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ",
     DeprecationWarning,
     stacklevel=2,
 )

From 40297baf40d18bc61cbba10e3ebd3670b2156397 Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Wed, 5 Nov 2025 13:55:29 -0800
Subject: [PATCH 23/24] Update block_sparse_layout.py

---
 torchao/dtypes/uintx/block_sparse_layout.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py
index c1d6436268..6ca4e8745a 100644
--- a/torchao/dtypes/uintx/block_sparse_layout.py
+++ b/torchao/dtypes/uintx/block_sparse_layout.py
@@ -10,7 +10,7 @@
 warnings.warn(
     "Importing BlockSparseLayout from torchao.dtypes is deprecated. "
     "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. "
-    "This import path will be removed in torchao v0.16.0. "
+    "This import path will be removed in a future torchao release. "
     "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ",
     DeprecationWarning,
     stacklevel=2,
 )
" "Please use 'from torchao.prototype.dtypes import CutlassInt4PackedLayout' instead. " - "This import path will be removed in torchao v0.16.0. " + "This import path will be removed in a future torchao release. " "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ", DeprecationWarning, stacklevel=2,