1 change: 1 addition & 0 deletions .github/workflows/regression_test_aarch64.yml
@@ -55,6 +55,7 @@ jobs:
pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py
pytest -s test/prototype/test_embedding.py
pytest -s test/prototype/test_int8_lut_tensor.py
pytest -s test/prototype/test_tensor_conversion.py
pytest -s test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py
pytest -s test/prototype/test_parq.py
- name: torchao/csrc/cpu - build and run C++ tests
180 changes: 180 additions & 0 deletions test/prototype/test_tensor_conversion.py
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch

from torchao.prototype.parq.quant import (
StretchedIntxWeightConfig,
StretchedUnifTorchaoQuantizer,
)
from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor
from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
from torchao.quantization import MappingType
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
quantize_,
)
from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import (
IntxOpaqueTensor,
_is_kernel_library_loaded,
)
from torchao.quantization.utils import compute_error


class ToyLinearModelWithTiedEmbedding(torch.nn.Module):
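    # Toy model with two tied-weight groups: {embedding1, embedding2, lm_head1, lm_head2}
    # and {embedding3, lm_head3}; used to exercise tied-parameter handling in conversion.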
def __init__(self, d0=512, d1=512, d2=256, d3=128, d4=32):
super().__init__()
self.embedding1 = torch.nn.Embedding(d0, d1)
self.embedding2 = torch.nn.Embedding(d0, d1)
self.embedding3 = torch.nn.Embedding(d0, d1)

self.linear1 = torch.nn.Linear(d1, d2, bias=False)
self.linear2 = torch.nn.Linear(d2, d3, bias=True)
self.linear3 = torch.nn.Linear(d3, d4, bias=False)
self.linear4 = torch.nn.Linear(d4, d1, bias=False)

self.lm_head1 = torch.nn.Linear(d1, d0, bias=False)
self.lm_head2 = torch.nn.Linear(d1, d0, bias=False)
self.lm_head3 = torch.nn.Linear(d1, d0, bias=False)

# Tie weights
# lm_head1 / lm_head2 form one tied weight group
self.embedding2.weight = self.embedding1.weight
self.lm_head1.weight = self.embedding1.weight
self.lm_head2.weight = self.embedding1.weight

# lm_head3 forms a separate tied weight group
self.lm_head3.weight = self.embedding3.weight

def example_inputs(
self,
lead_dim=(1,),
dtype=torch.bfloat16,
):
return (
torch.randint(
0,
self.embedding1.num_embeddings,
size=lead_dim,
dtype=torch.int64,
device="cpu",
),
)

def forward(self, x):
x = self.embedding1(x) + self.embedding2(x) + self.embedding3(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.linear4(x)
x = self.lm_head1(x) + self.lm_head2(x) + self.lm_head3(x)
return x


@pytest.fixture(autouse=True)
def run_before_and_after_tests():
yield
torch._dynamo.reset() # reset cache between tests


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
@pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
@pytest.mark.parametrize(
"lead_dim",
[
(1,),
(5,),
(7, 2),
],
)
@pytest.mark.skipif(
not _is_kernel_library_loaded(), reason="Kernel library is not loaded"
)
def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim):
torch.manual_seed(0)

model = ToyLinearModelWithTiedEmbedding()
model = model.to(dtype)
example_inputs = model.example_inputs(lead_dim, dtype)

# Quantize linear 2 and 3 with PARQ
quantizer = StretchedUnifTorchaoQuantizer(bit_width)
config = StretchedIntxWeightConfig(
b=bit_width,
quant_min=quantizer.quant_min,
quant_max=quantizer.quant_max,
granularity=granularity,
activation_quantization="int8_asym_per_token",
)
quantize_(model, config, filter_fn=lambda m, fqn: fqn in ["linear2", "linear3"])

    # Quantize linear 1, linear 4, and the lm_heads with int8 dynamic activation
config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=granularity,
weight_mapping_type=MappingType.SYMMETRIC,
)
quantize_(
model,
config,
filter_fn=lambda m, fqn: fqn
in ["linear1", "linear4", "lm_head1", "lm_head2", "lm_head3"],
)

# Quantize embedding 1, 2, and 3 with weight only
config = IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=granularity,
mapping_type=MappingType.SYMMETRIC,
)
quantize_(
model,
config,
filter_fn=lambda m, fqn: fqn in ["embedding1", "embedding2", "embedding3"],
)
model_out = model(*example_inputs)

# Convert to optimized model
_convert_model_for_aarch64(model)

# Check expected tensor subclass
assert isinstance(model.linear2.weight, Int8LutTensor)
assert isinstance(model.linear3.weight, Int8LutTensor)
assert isinstance(model.linear1.weight, IntxOpaqueTensor)
assert isinstance(model.linear4.weight, IntxOpaqueTensor)

# Assert tied params
tied_group1_id = id(model.embedding1.weight)
assert id(model.embedding2.weight) == tied_group1_id
assert id(model.lm_head1.weight) == tied_group1_id
assert id(model.lm_head2.weight) == tied_group1_id

assert id(model.lm_head3.weight) == id(model.embedding3.weight)
assert id(model.lm_head3.weight) != tied_group1_id

# Compare converted out with original out
converted_out = model(*example_inputs)
sqnr = compute_error(model_out, converted_out)
sqnr_threshold = 30
assert sqnr > sqnr_threshold, f"sqnr: {sqnr}"

# Check exported graph for correct ops
ep = torch.export.export(model, example_inputs)
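    # All 7 linear calls (linear1-4 and the 3 lm_heads) and all 3 embedding lookups should
    # lower to torchao kernels, leaving no aten.linear / aten.embedding ops in the graph.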
expected_counts = {
"torch.ops.torchao._shared_embedding_": 3,
"torch.ops.torchao._linear_8bit_act_": 7,
"torch.ops.aten.linear.default": 0,
"torch.ops.aten.embedding.default": 0,
}
for line, cnt in expected_counts.items():
assert ep.graph_module.code.count(line) == cnt, (
f"expected {cnt} {line} in {ep.graph_module.code}"
)
120 changes: 115 additions & 5 deletions torchao/prototype/tensor_conversion/api.py
@@ -7,6 +7,8 @@
import torch
import torch.nn as nn

from torchao.quantization.quantize_.workflows import IntxUnpackedToInt8Tensor


def _convert_linear_weight_to_int8_lut_tensor(module):
from torchao.prototype.quantization.int8_lut_tensor import Int8LutTensor
@@ -20,17 +22,116 @@ def _convert_linear_weight_to_int8_lut_tensor(module):
module.bias = None


def _convert_module_weight_to_intx_opaque_tensor(module, intx_packing_format):
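    # Repack a module's IntxUnpackedToInt8Tensor weight (Linear or Embedding) into an
    # IntxOpaqueTensor with the requested packing format; the bias, if any, is folded
    # into the packed weights and cleared on the module.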
from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import (
IntxOpaqueTensor,
)

assert isinstance(module, nn.Linear) or isinstance(module, nn.Embedding)
weight = module.weight
new_weight = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor(
weight,
bias=module.bias if hasattr(module, "bias") else None,
intx_packing_format=intx_packing_format,
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
if hasattr(module, "bias"):
module.bias = None


def _find_tied_module_names_for_embedding(embedding_weight, model):
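    # Return the names of Linear/Embedding modules whose IntxUnpackedToInt8Tensor weight
    # is numerically identical to embedding_weight (same shape, block_size, dtype, qdata,
    # scale, and zero_point); only bias-free, dynamically quantized linears qualify.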
assert isinstance(embedding_weight, IntxUnpackedToInt8Tensor)
tied_names = []
for name, module in model.named_modules():
is_linear = isinstance(module, nn.Linear)
is_embedding = isinstance(module, nn.Embedding)
if not (is_linear or is_embedding):
continue

weight = module.weight
if not isinstance(weight, IntxUnpackedToInt8Tensor):
continue

# We only have tied kernels for dynamically quantized linears
if is_linear and weight.activation_quantization != "int8_asym_per_token":
continue

# We only have tied kernels for linear layers with no bias
if is_linear and module.bias is not None:
continue

are_tied = (
(embedding_weight.shape == weight.shape)
and (embedding_weight.block_size == weight.block_size)
and (embedding_weight.dtype == weight.dtype)
and (embedding_weight.qdata == weight.qdata).all()
and (embedding_weight.scale == weight.scale).all()
and (embedding_weight.zero_point == weight.zero_point).all()
)

if are_tied:
tied_names.append(name)

return tied_names


def _find_tied_params(model):
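    # Build a mapping from module name to a shared IntxOpaqueTensor parameter for every
    # tied embedding/lm_head group, so tied modules keep pointing at one packed weight
    # after conversion.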
from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import (
IntxOpaqueTensor,
)

module_name_to_tied_param = {}
for name, module in model.named_modules():
if not isinstance(module, nn.Embedding):
continue

weight = module.weight
if not isinstance(weight, IntxUnpackedToInt8Tensor):
continue

tied_module_names = _find_tied_module_names_for_embedding(weight, model)
if not tied_module_names:
continue

if name in module_name_to_tied_param:
tied_param = module_name_to_tied_param[name]
else:
# Construct a new tied param
# IntxOpaqueTensor requires activation_quantization = int8_asym_per_token
prev = weight.activation_quantization
weight.activation_quantization = "int8_asym_per_token"
tied_param = IntxOpaqueTensor.from_intx_unpacked_to_int8_tensor(
weight,
bias=None,
intx_packing_format="opaque_torchao_lowbit",
)
weight.activation_quantization = prev
tied_param = nn.Parameter(tied_param, requires_grad=False)
module_name_to_tied_param[name] = tied_param

for t in tied_module_names:
if t not in module_name_to_tied_param:
module_name_to_tied_param[t] = tied_param

return module_name_to_tied_param


def _convert_model_for_aarch64(
    model, *, tensor_type="auto", intx_packing_format="opaque_torchao_auto"
):
module_name_to_tied_param = _find_tied_params(model)

    # Iterate through the model's modules and convert IntxUnpackedToInt8Tensor weights
    # to Int8LutTensor or IntxOpaqueTensor
for name, module in model.named_modules():
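        # Modules in a tied group all point at one pre-built packed parameter; assign it
        # and skip the per-module conversion below.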
if name in module_name_to_tied_param:
module.weight = module_name_to_tied_param[name]
continue

if isinstance(module, nn.Embedding):
print("Skipping converting nn.Embedding {name} because it is not tied")
continue

if not isinstance(module, nn.Linear):
print(f"Skipping converting {name} because it is not a linear layer")
continue

weight = module.weight
@@ -42,6 +143,15 @@ def _convert_model_for_aarch64(

if tensor_type == "int8_lut_tensor":
_convert_linear_weight_to_int8_lut_tensor(module)
elif tensor_type == "intx_opaque_tensor":
_convert_module_weight_to_intx_opaque_tensor(module, intx_packing_format)
elif tensor_type == "auto":
if weight._has_float_zero_point() and isinstance(module, nn.Linear):
_convert_linear_weight_to_int8_lut_tensor(module)
else:
_convert_module_weight_to_intx_opaque_tensor(
module, intx_packing_format
)
else:
raise ValueError(f"Unexpected tensor_type={tensor_type}")

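For reviewers, a minimal usage sketch of the new conversion API (illustrative only, not part of this diff): it assumes torchao is built with the aarch64 lowbit kernel library loaded, mirrors the quantization configs used in the test above, and the Tiny model is hypothetical.

import torch
import torch.nn as nn

from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
from torchao.quantization import MappingType
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)


# Hypothetical toy model: lm_head shares its weight with the embedding.
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(512, 256)
        self.lm_head = nn.Linear(256, 512, bias=False)
        self.lm_head.weight = self.embedding.weight

    def forward(self, x):
        return self.lm_head(self.embedding(x))


model = Tiny()

# Quantize the linear with int8 dynamic activation and the embedding weight-only,
# mirroring the configs in the test above.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
        weight_mapping_type=MappingType.SYMMETRIC,
    ),
    filter_fn=lambda m, fqn: fqn == "lm_head",
)
quantize_(
    model,
    IntxWeightOnlyConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        mapping_type=MappingType.SYMMETRIC,
    ),
    filter_fn=lambda m, fqn: fqn == "embedding",
)

# tensor_type="auto" (the default) picks Int8LutTensor for linears with a float zero
# point and IntxOpaqueTensor otherwise; the tied embedding/lm_head weights end up
# sharing one packed parameter.
_convert_model_for_aarch64(model)

# The converted model runs through torchao's aarch64 kernels.
out = model(torch.randint(0, 512, (2,)))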
torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py
@@ -335,6 +335,35 @@ def _(func, types, args, kwargs):
return res


@implements([torch.nn.functional.embedding, aten.embedding.default])
def _(func, types, args, kwargs):
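    # Embedding lookup for OPAQUE_TORCHAO_LOWBIT packed weights: flatten the indices,
    # call the shared-embedding kernel for this bit width, then restore the leading shape.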
assert len(args) == 2
indices, weight_tensor = (
args[0],
args[1],
)
assert isinstance(weight_tensor, IntxOpaqueTensor)
assert weight_tensor.intx_packing_format == IntxPackingFormat.OPAQUE_TORCHAO_LOWBIT
packed_weights = weight_tensor.packed_weights

assert len(weight_tensor.block_size) == 2
assert weight_tensor.block_size[0] == 1
group_size = weight_tensor.block_size[1]

n, k = weight_tensor.shape
bit_width = weight_tensor.bit_width

shape = indices.shape
out = getattr(torch.ops.torchao, f"_shared_embedding_{bit_width}bit")(
packed_weights,
group_size,
n,
k,
indices.reshape(-1),
).reshape(*shape, -1)
return out


IntxOpaqueTensor.__module__ = "torchao.quantization"

torch.serialization.add_safe_globals([IntxOpaqueTensor])