Merged

Commits (30)
bc67c1a  Move block_sparse_layout to prototype  (jainapurva, Nov 3, 2025)
dd2e7d6  test fixes  (jainapurva, Nov 3, 2025)
efa74cd  Remove unused import in __init__.py  (jainapurva, Nov 3, 2025)
3722537  Clean up exports in block_sparse_layout.py  (jainapurva, Nov 3, 2025)
6c1b8ef  test fixes  (jainapurva, Nov 3, 2025)
1702bdf  Fix ruff import sorting and add noqa for re-exported imports  (jainapurva, Nov 3, 2025)
18d682a  Move block_sparse_layout to prototype/dtypes/uintx and update all imp…  (jainapurva, Nov 3, 2025)
de054ab  Apply ruff formatting (remove trailing newlines)  (jainapurva, Nov 3, 2025)
4ef291f  Add Prototype section to dtypes documentation with BlockSparseLayout  (jainapurva, Nov 3, 2025)
b492530  Move cutlass_int4_packed_layout to prototype/dtypes/uintx  (jainapurva, Nov 3, 2025)
d699ee0  Clean up imports in affine_quantized_tensor_ops.py  (jainapurva, Nov 3, 2025)
ab84799  Update internal links  (jainapurva, Nov 3, 2025)
83795a6  <Replace this line with a title. Use 1 line only, 67 chars or less>  (jainapurva, Nov 3, 2025)
40ab188  test fixes  (jainapurva, Nov 3, 2025)
894857e  ruff fixes  (jainapurva, Nov 3, 2025)
5b45869  Fixes  (jainapurva, Nov 3, 2025)
1d51ebf  Remove unused import from uintx init file  (jainapurva, Nov 3, 2025)
8c18d4d  Remove __all__ exports from module  (jainapurva, Nov 3, 2025)
081a4ed  Empty commit to trigger CI  (jainapurva, Nov 3, 2025)
61c1986  Lint fixes  (jainapurva, Nov 4, 2025)
8ebbb9c  Add test cases  (jainapurva, Nov 4, 2025)
b199f7b  Merge remote-tracking branch 'origin/move_block_sparsity' into move_c…  (jainapurva, Nov 4, 2025)
78f5e4c  Add test cases  (jainapurva, Nov 4, 2025)
ffcaca6  Merge remote-tracking branch 'origin/move_block_sparsity' into move_c…  (jainapurva, Nov 5, 2025)
1a75689  Add test cases  (jainapurva, Nov 4, 2025)
8222741  Merge remote-tracking branch 'origin/move_block_sparsity' into move_c…  (jainapurva, Nov 5, 2025)
42acafa  Updates  (jainapurva, Nov 5, 2025)
dc1e447  Merge branch 'main' into move_cutlass_int4_packed_layout  (jainapurva, Nov 5, 2025)
40297ba  Update block_sparse_layout.py  (jainapurva, Nov 5, 2025)
531b07b  Modify deprecation warning for import path  (jainapurva, Nov 5, 2025)
14 changes: 12 additions & 2 deletions docs/source/api_ref_dtypes.rst
@@ -22,12 +22,10 @@ Layouts and Tensor Subclasses
FloatxTensor
FloatxTensorCoreLayout
MarlinSparseLayout
BlockSparseLayout
Contributor:

looks like it's both? maybe try to split would be better, or update the summary to say it's two things

@jainapurva (Contributor, Author), Nov 4, 2025:

It's in two separate PRs:
PR for BlockSparseLayout: #3276

This one is for CutlassInt4PackedLayout

UintxLayout
MarlinQQQTensor
MarlinQQQLayout
Int4CPULayout
CutlassInt4PackedLayout
CutlassSemiSparseLayout

Quantization techniques
@@ -43,6 +41,18 @@ Quantization techniques
to_affine_quantized_floatx_static
to_marlinqqq_quantized_intx
to_nf4

Prototype
---------
.. currentmodule:: torchao.prototype.dtypes

.. autosummary::
:toctree: generated/
:nosignatures:

BlockSparseLayout
CutlassInt4PackedLayout

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation.
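
The new Prototype section documents both relocated layouts under torchao.prototype.dtypes. As a minimal sketch of the new import path (assuming a torchao build with this PR applied; the blocksize default of 64 comes from the BlockSparseLayout dataclass further down in this diff):

# Sketch of the relocated import path introduced by this PR.
from torchao.prototype.dtypes import BlockSparseLayout, CutlassInt4PackedLayout

# BlockSparseLayout is a frozen dataclass; blocksize defaults to 64.
layout = BlockSparseLayout(blocksize=64)
print(layout)
print(CutlassInt4PackedLayout())
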
2 changes: 1 addition & 1 deletion test/sparsity/test_sparse_api.py
@@ -253,7 +253,7 @@ def test_sparse(self, compile):
quantize_(model_copy, Int8DynamicActivationInt8WeightConfig())
reference = model_copy(input)

from torchao.dtypes import BlockSparseLayout
from torchao.prototype.dtypes import BlockSparseLayout

quantize_(
model,
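
The hunk above is cut off by the diff view. As a separate, hedged sketch of how the relocated layout is typically passed through the quantization API (the layout= keyword on Int8DynamicActivationInt8WeightConfig is an assumption here; only the config class and the new import appear in the visible lines):

import torch

from torchao.prototype.dtypes import BlockSparseLayout
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

# Toy model; the block-sparse kernel path checks for CUDA weights
# (see _linear_int8_act_int8_weight_block_sparse_check in this PR).
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda()

# layout= is assumed; blocksize=64 matches the BlockSparseLayout default.
quantize_(
    model,
    Int8DynamicActivationInt8WeightConfig(layout=BlockSparseLayout(blocksize=64)),
)
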
4 changes: 2 additions & 2 deletions torchao/dtypes/__init__.py
@@ -14,8 +14,6 @@
)
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
BlockSparseLayout,
CutlassInt4PackedLayout,
Int4CPULayout,
Int4XPULayout,
Int8DynamicActInt4WeightCPULayout,
@@ -29,6 +27,8 @@
UintxLayout,
to_marlinqqq_quantized_intx,
)
from .uintx.block_sparse_layout import BlockSparseLayout
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
from .utils import (
Layout,
PlainLayout,
10 changes: 5 additions & 5 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -25,11 +25,7 @@
_linear_f16_bf16_act_floatx_weight_check,
_linear_f16_bf16_act_floatx_weight_impl,
)
from torchao.dtypes.uintx.block_sparse_layout import (
_linear_int8_act_int8_weight_block_sparse_check,
_linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import (
_linear_int4_act_int4_weight_cutlass_check,
_linear_int4_act_int4_weight_cutlass_impl,
_linear_int8_act_int4_weight_cutlass_check,
@@ -94,6 +90,10 @@
_linear_bf16_act_uint4_weight_check,
_linear_bf16_act_uint4_weight_impl,
)
from torchao.prototype.dtypes.uintx.block_sparse_layout import (
_linear_int8_act_int8_weight_block_sparse_check,
_linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
_dequantize_affine_no_zero_point,
8 changes: 0 additions & 8 deletions torchao/dtypes/uintx/__init__.py
@@ -1,9 +1,3 @@
from .block_sparse_layout import (
BlockSparseLayout,
)
from .cutlass_int4_packed_layout import (
CutlassInt4PackedLayout,
)
from .dyn_int8_act_int4_wei_cpu_layout import (
Int8DynamicActInt4WeightCPULayout,
)
@@ -39,15 +33,13 @@

__all__ = [
"UintxLayout",
"BlockSparseLayout",
"MarlinSparseLayout",
"SemiSparseLayout",
"TensorCoreTiledLayout",
"Int4CPULayout",
"MarlinQQQLayout",
"MarlinQQQTensor",
"to_marlinqqq_quantized_intx",
"CutlassInt4PackedLayout",
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"QDQLayout",
"Int4XPULayout",
238 changes: 14 additions & 224 deletions torchao/dtypes/uintx/block_sparse_layout.py
@@ -3,231 +3,21 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import (
return_and_correct_aliasing,
)
# Backward compatibility stub - imports from the new location
import warnings

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
register_layout,
)
from torchao.dtypes.uintx.plain_layout import (
PlainAQTTensorImpl,
_aqt_is_int8_reduced_range,
)
from torchao.dtypes.utils import (
Layout,
PlainLayout,
warnings.warn(
"Importing BlockSparseLayout from torchao.dtypes is deprecated. "
"Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. "
"This import path will be removed in torchao v0.16.0.",
DeprecationWarning,
stacklevel=2,
)

logger = logging.getLogger(__name__)

aten = torch.ops.aten


@dataclass(frozen=True)
class BlockSparseLayout(Layout):
"""BlockSparseLayout is a data class that represents the layout of a block sparse matrix.
Attributes:
blocksize (int): The size of the blocks in the sparse matrix. Default is 64.
"""

blocksize: int = 64


@register_layout(BlockSparseLayout)
class BlockSparseAQTTensorImpl(PlainAQTTensorImpl):
bsr_crow_indices: Optional[torch.Tensor]
bsr_col_indices: Optional[torch.Tensor]
bsr_values: Optional[torch.Tensor]
scale: Optional[torch.Tensor]
zero_point: Optional[torch.Tensor]

__slots__ = [
"bsr_crow_indices",
"bsr_col_indices",
"bsr_values",
"scale",
"zero_point",
]

@staticmethod
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
zero_point: Optional[torch.Tensor],
_layout: Layout,
requires_grad: bool = False,
):
if bsr_values is None:
raise ValueError("bsr values must be provided!")
else:
previous_tensor = bsr_values

kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__( # noqa: PYI034
self,
shape: torch.Size,
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
zero_point: Optional[torch.Tensor],
_layout: Layout,
requires_grad: bool = False,
):
self.bsr_crow_indices = bsr_crow_indices
self.bsr_col_indices = bsr_col_indices
self.bsr_values = bsr_values
self.scale = scale
self.zero_point = zero_point
self._layout = _layout

def __tensor_flatten__(self):
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (self.shape, self._layout, self.requires_grad)
return inner_tensors, tensor_meta

@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta: Tuple[torch.Size, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, _layout, requires_grad = tensor_meta
return cls(
shape=shape,
bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
bsr_values=inner_tensors.get("bsr_values", None),
scale=inner_tensors.get("scale", None),
zero_point=inner_tensors.get("zero_point", None),
_layout=_layout,
requires_grad=requires_grad,
)

@classmethod
def from_plain(cls, int_data, scale, zero_point, _layout):
bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize)
return cls(
shape=int_data.shape,
bsr_crow_indices=bsr_tensor.crow_indices(),
bsr_col_indices=bsr_tensor.col_indices(),
bsr_values=bsr_tensor.values(),
scale=scale,
zero_point=zero_point,
_layout=_layout,
requires_grad=False,
)

def get_plain(self):
int_data_expanded = torch.ops.blocksparse.bsr_to_dense(
self.crow_indices(),
self.col_indices(),
self.values(),
self.shape[0],
self.shape[1],
)
return int_data_expanded, self.scale, self.zero_point

def _apply_fn_to_data(self, func):
return self.__class__(
shape=self.shape,
bsr_crow_indices=func(self.bsr_crow_indices),
bsr_col_indices=func(self.bsr_col_indices),
bsr_values=func(self.bsr_values),
scale=self.scale,
zero_point=self.zero_point,
_layout=self._layout,
requires_grad=self.requires_grad,
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
if func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

# Need the following for bsr specific functions
if func is aten.crow_indices.default:
return args[0].bsr_crow_indices.detach()

if func is aten.col_indices.default:
return args[0].bsr_col_indices.detach()

if func is aten.values.default:
return args[0].bsr_values.detach()

if func is aten._nnz.default:
return args[0].bsr_values.shape[0]

raise NotImplementedError(
f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
)


def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8_reduced_range(input_tensor)
and isinstance(weight_tensor, AffineQuantizedTensor)
and weight_tensor.is_cuda
and input_tensor.dtype == weight_tensor.dtype
and isinstance(input_tensor._layout, PlainLayout)
and isinstance(weight_tensor._layout, BlockSparseLayout)
)


def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias):
x_vals_int8 = input_tensor.tensor_impl.int_data
x_scales = input_tensor.tensor_impl.scale
w_vals = weight_tensor.tensor_impl
w_scales = weight_tensor.tensor_impl.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
tmp_t = tmp.t()

y = torch.ops.blocksparse.int_addmm(
w_vals.crow_indices(),
w_vals.col_indices(),
w_vals.values(),
tmp_t,
w_scales,
x_scales.reshape(-1),
)
y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
y = y.reshape(*y_shape)

# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
from torchao.prototype.dtypes.uintx.block_sparse_layout import (
BlockSparseAQTTensorImpl, # noqa: F401
BlockSparseLayout, # noqa: F401
_linear_int8_act_int8_weight_block_sparse_check, # noqa: F401
_linear_int8_act_int8_weight_block_sparse_impl, # noqa: F401
)
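
With this stub in place, the old module path keeps working but warns on first import. A minimal sketch of the migration, assuming a torchao build with this PR:

import warnings

# DeprecationWarning is hidden by default outside __main__; surface it explicitly.
warnings.simplefilter("default", DeprecationWarning)

# Old path: resolves through the backward-compatibility stub above and emits the
# DeprecationWarning the first time the module is imported in a process.
from torchao.dtypes.uintx.block_sparse_layout import BlockSparseLayout

# New path: the import the warning asks callers to migrate to before torchao v0.16.0.
from torchao.prototype.dtypes import BlockSparseLayout  # noqa: F811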