QC SeqMSE Draft #12700

Status: Draft. Wants to merge 1 commit into main.

3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -37,6 +37,7 @@
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_inf_values import ReplaceInfValues
from .seq_mse import InsertSeqMse, RemoveSeqMse
from .tag_quant_io import TagQuantIO


@@ -65,13 +66,15 @@
I64toI32,
InsertIOQDQ,
InsertRequantize,
InsertSeqMse,
LayoutTransform,
LiftConstantScalarOperands,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
ReduceDynamicRange,
Remove0DTensor,
RemoveRedundancy,
RemoveSeqMse,
ReplaceArangeArgs,
ReplaceInfValues,
TagQuantIO,
203 changes: 203 additions & 0 deletions backends/qualcomm/_passes/seq_mse.py
@@ -0,0 +1,203 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import types

import torch
import torchao
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
PerBlockParamObserver,
)
from executorch.exir.pass_base import ExportPass, PassResult

class SeqMseModule(torch.nn.Module):
    """
    Grid-searches scaled versions of the observed weight encoding and keeps the
    scale candidate that minimizes the MSE between the nominal output and the
    output produced with the fake-quantized weight.

    Args:
        nominal_weight: Tensor
            nominal weight parameter taken from the operator
        nominal_bias: Tensor
            nominal bias parameter taken from the operator
        operator: fx.Node
            operator to be executed
        observer: UniformQuantizationObserverBase
            parameter observer (specific to the weight)
        num_candidates: int
            number of grid points to search for the minimal MSE loss
    """

    def __init__(
        self,
        nominal_weight,
        nominal_bias,
        operator,
        observer,
        num_candidates,
    ):
        super().__init__()
        self.nominal_weight = nominal_weight
        self.nominal_bias = nominal_bias
        self.observer = observer
        step = 1 / num_candidates
        self.steps = torch.arange(start=step, end=1 + step, step=step).tolist()
        self.operator = self._make_operator(operator)
        self.best_candidate = 1.0

    def _make_operator(self, aten_op):
        if aten_op.target == torch.ops.aten.conv2d.default:
            stride = [1, 1] if len(aten_op.args) < 4 else aten_op.args[3]
            padding = [0, 0] if len(aten_op.args) < 5 else aten_op.args[4]
            dilation = [1, 1] if len(aten_op.args) < 6 else aten_op.args[5]
            groups = 1 if len(aten_op.args) < 7 else aten_op.args[6]
            has_bias = self.nominal_bias is not None
            module = torch.nn.Conv2d(
                in_channels=self.nominal_weight.shape[1],
                out_channels=self.nominal_weight.shape[0],
                kernel_size=self.nominal_weight.shape[-2:],
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=has_bias,
            )
            module.weight.data = self.nominal_weight
            if has_bias:
                module.bias.data = self.nominal_bias
            return module
        else:
            raise NotImplementedError(
                f"target of {aten_op.target} is not implemented"
            )

    def _per_block_qdq(self, scale, zero_point):
        return torchao.quantization.quant_primitives._fake_quantize_affine(
            input=self.nominal_weight,
            block_size=self.observer.block_size,
            scale=scale,
            zero_point=zero_point,
            quant_dtype=self.observer.dtype,
            quant_min=self.observer.quant_min,
            quant_max=self.observer.quant_max,
        )

    def _per_channel_qdq(self, scale, zero_point):
        return torch.fake_quantize_per_channel_affine(
            input=self.nominal_weight,
            scale=scale,
            zero_point=zero_point,
            axis=0,
            quant_min=self.observer.quant_min,
            quant_max=self.observer.quant_max,
        )

    def _fake_quant(self, scale, zero_point):
        dispatcher = {
            PerChannelMinMaxObserver: self._per_channel_qdq,
            PerBlockParamObserver: self._per_block_qdq,
        }
        return dispatcher[type(self.observer)](scale, zero_point)

    def _find_best_candidate(self, nominal_input, nominal_output):
        # calculate current baseline
        scale, zero_point = self.observer.calculate_qparams()
        zero_point = zero_point.to(torch.int32)
        self.operator.weight.data = self._fake_quant(scale, zero_point)
        candidate, current_loss = 1, torch.nn.functional.mse_loss(
            self.operator(nominal_input), nominal_output
        ).item()
        for step in self.steps:
            self.operator.weight.data = self._fake_quant(scale * step, zero_point)
            loss = torch.nn.functional.mse_loss(
                self.operator(nominal_input), nominal_output
            ).item()
            if loss < current_loss:
                candidate, current_loss = step, loss
        return candidate

    def forward(self, nominal_input, nominal_output):
        self.best_candidate = self._find_best_candidate(
            nominal_input=nominal_input, nominal_output=nominal_output
        )


class InsertSeqMse(ExportPass):
    """
    Insert SeqMse observer modules to search for the best weight quantization
    encoding of supported nodes.
    """

    seq_mse_ops = {
        torch.ops.aten.conv2d.default,
    }

    def __init__(self, num_candidates=1000):
        super(InsertSeqMse, self).__init__()
        self.num_candidates = num_candidates

    def _insert_seq_mse(self, graph_module: torch.fx.GraphModule) -> None:
        count = 0
        for node in graph_module.graph.nodes:
            if node.target in self.seq_mse_ops:
                # extract observer
                weight_node_obs = node.args[1]
                observer = getattr(graph_module, weight_node_obs.name)
                # extract parameters
                weight_node = weight_node_obs.args[0]
                weight_tensor = graph_module.get_parameter(
                    weight_node.target
                ).detach()
                bias_tensor = None
                if len(node.args) > 2 and node.args[2] is not None:
                    bias_tensor = graph_module.get_parameter(
                        node.args[2].args[0].target
                    ).detach()

                with graph_module.graph.inserting_after(node):
                    seq_mse_mod = SeqMseModule(
                        nominal_weight=weight_tensor,
                        nominal_bias=bias_tensor,
                        operator=node,
                        observer=observer,
                        num_candidates=self.num_candidates,
                    )
                    module_name = f"seq_mse_{count}"
                    count += 1
                    setattr(graph_module, module_name, seq_mse_mod)
                    input_nodes = (node.args[0], node)
                    graph_module.graph.create_node(
                        "call_module", module_name, input_nodes, {}
                    )

    def call(self, graph_module: torch.fx.GraphModule):
        self._insert_seq_mse(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)


class RemoveSeqMse(ExportPass):
    """
    Remove SeqMse modules before invoking convert_pt2e and update the final
    quantization encoding.
    """

    def __init__(self):
        super(RemoveSeqMse, self).__init__()

    def _remove_seq_mse(self, graph_module: torch.fx.GraphModule) -> None:
        node_to_erase = []
        for node in graph_module.graph.nodes:
            if node.op == "call_module" and "seq_mse" in node.name:
                # extract SeqMse module
                module = getattr(graph_module, node.target)
                # rewrite observer method for pre-calculated scale
                scale, zero_point = module.observer.calculate_qparams()
                module.observer.updated_encoding = (
                    scale * module.best_candidate,
                    zero_point,
                )
                module.observer.calculate_qparams = types.MethodType(
                    lambda s: s.updated_encoding, module.observer
                )
                node_to_erase.append(node)

        for node in node_to_erase:
            graph_module.graph.erase_node(node)

    def call(self, graph_module: torch.fx.GraphModule):
        self._remove_seq_mse(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)
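
Taken together, the two passes are meant to bracket PT2E calibration: insert the observers, run calibration data through the prepared module, then strip the observers and freeze the tuned encodings. A minimal sketch of that flow, assuming an exportable model, example inputs, and a ready QnnQuantizer instance; the prepare_pt2e/convert_pt2e import path may differ across torchao/ExecuTorch versions:

import torch
from executorch.backends.qualcomm._passes import InsertSeqMse, RemoveSeqMse
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_with_seq_mse(model, example_inputs, quantizer):
    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    # Insert a SeqMseModule after every supported op (currently only conv2d).
    prepared = InsertSeqMse(num_candidates=1000)(prepared).graph_module
    # Calibration: each SeqMseModule grid-searches scaled weight encodings and
    # keeps the scale factor that minimizes the MSE against the nominal output.
    prepared(*example_inputs)
    # Drop the helper modules and freeze the tuned qparams on the weight observers.
    prepared = RemoveSeqMse()(prepared).graph_module
    return convert_pt2e(prepared)
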
8 changes: 4 additions & 4 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -285,9 +285,9 @@ def annotate_matmul_input1(node: Node):
quantization_config_8a8w = get_8a8w_qnn_ptq_config(
act_symmetric=True, act_observer=MinMaxObserver
)
quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
act_dtype=torch.uint8,
weight_dtype=torch.int4,
quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
act_dtype=torch.uint16,
weight_dtype=torch.int8,
act_observer=MinMaxObserver,
act_symmetric=True,
)
@@ -318,7 +318,7 @@ def annotate_matmul_input1(node: Node):
node = node.args[0][1]
elif node.target == torch.ops.aten.conv2d.default:
annotate_conv2d(
node, quantization_config=quantization_config_8a4w_per_channel
node, quantization_config=quantization_config_16a8w_per_channel
)
break
elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
backends/qualcomm/quantizer/observers/per_block_param_observer.py
@@ -34,6 +34,7 @@ def __init__(
eps=eps,
**kwargs,
)
self.dtype = dtype
self.block_size = block_size
self.calibrated = False

11 changes: 11 additions & 0 deletions backends/qualcomm/quantizer/quantizer.py
@@ -3,12 +3,14 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from dataclasses import dataclass
from enum import IntEnum, unique
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

import torch
from executorch.backends.qualcomm._passes import InsertSeqMse, RemoveSeqMse
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

from torch._ops import OpOverload
@@ -427,3 +429,12 @@ def predicate(node):
return False

return predicate


@contextmanager
def qnn_ptq_manager(prepared_gm):
    prepared_gm = InsertSeqMse()(prepared_gm).graph_module
    try:
        yield
    finally:
        prepared_gm = RemoveSeqMse()(prepared_gm).graph_module
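
The context manager is intended to wrap the calibration loop of a graph module that has already gone through prepare_pt2e; a minimal sketch, with prepared_gm and calibration_dataset as placeholders:

# SeqMse modules are inserted on entry and removed (with tuned qparams) on exit.
with qnn_ptq_manager(prepared_gm):
    for sample in calibration_dataset:
        prepared_gm(*sample)
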
27 changes: 27 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -3981,6 +3981,33 @@ def test_qnn_backend_generate_optrace(self):
qhas_data = json.load(qhas_file)
self.assertIn("data", qhas_data)

    def test_qnn_backend_seq_mse(self):
        from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager

        o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
        module = Conv2dSingle(  # noqa: F405
            in_channel=i_ch,
            out_channel=o_ch,
            kernel_size=kernel,
            padding=padding,
        )
        sample_input = (torch.randn(1, i_ch, 1, o_ch),)
        # per-channel / per-block
        quantizers = [
            make_quantizer(),
            make_quantizer(quant_dtype=QuantDtype.use_16a4w_block),
        ]
        quantizers[-1].set_block_size_map({"conv2d": (1, 32, 1, 1)})

        for i, quantizer in enumerate(quantizers):
            with self.subTest(i=i):
                ep = torch.export.export(module, sample_input).module()
                prepared = prepare_pt2e(ep, quantizer)
                with qnn_ptq_manager(prepared):
                    prepared(*sample_input)
                converted = convert_pt2e(prepared)
                self.lower_module_and_test_output(converted, sample_input)


class TestExampleLLMScript(TestQNN):
def test_llama3_2_1b(self):
22 changes: 11 additions & 11 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -142,6 +142,9 @@ def _kv_calibrate(
updater=smart_mask_updater,
use_i64_token=False,
):
from contextlib import nullcontext
from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager

_, atten_mask, _, k_caches, v_caches = example_inputs

# TODO: change criteria & support batch inputs if necessary
@@ -191,13 +194,14 @@
dim=-1,
)

logits, new_k_caches, new_v_caches = module(
tmp_token_list,
tmp_atten_mask,
tmp_pos,
*k_caches,
*v_caches,
)
with qnn_ptq_manager(module) if pos == max_seq_len-1 else nullcontext():
logits, new_k_caches, new_v_caches = module(
tmp_token_list,
tmp_atten_mask,
tmp_pos,
*k_caches,
*v_caches,
)
atten_mask, pos, k_caches, v_caches = updater(
ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
)
@@ -647,10 +651,6 @@ def permute(w, heads):
if args.ptq:
start_quantize_ts = time.time()
custom_annotations = (annotate_matmul_16a8w,)
if args.llama_model == "stories110m":
custom_annotations = custom_annotations + (
annotate_linear_16a8w_in_affine_layer,
)
kv_quant_attrs = {}
for i, llama_instance in enumerate(llama_instance_list):
llama_instance.quantize(
8 changes: 7 additions & 1 deletion examples/qualcomm/utils.py
@@ -374,6 +374,7 @@ def build_executorch_binary(
online_prepare=False,
optrace=False,
op_package_options: QnnExecuTorchOpPackageOptions = None,
use_seq_mse=False,
):
"""
A function to generate an ExecuTorch binary for Qualcomm platforms.
@@ -397,10 +398,14 @@
optrace (bool, optional): Enable optrace mode for performance analysis if set to True.
op_package_options: Optional structure to specify op packages
loaded and used by the backend.
use_seq_mse (bool, optional): Optional flag to run SeqMSE during calibration, searching weight-quantization scales that minimize the layer output MSE.

Returns:
None: The function writes the output to a specified .pte file.
"""
from contextlib import nullcontext
from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager

backend_options = generate_htp_compiler_spec(
use_fp16=False if quant_dtype else True
)
@@ -426,7 +431,8 @@
else:
quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype)
# ptq calibration
annotated_model = ptq_calibrate(captured_model, quantizer, dataset)
with qnn_ptq_manager(captured_model) if use_seq_mse else nullcontext():
annotated_model = ptq_calibrate(captured_model, quantizer, dataset)

quantized_model = convert_pt2e(annotated_model)
edge_prog_mgr = to_edge_transform_and_lower_to_qnn(