diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 01710aa8d80..5b10a223866 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -37,6 +37,7 @@ from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
 from .replace_inf_values import ReplaceInfValues
+from .seq_mse import InsertSeqMse, RemoveSeqMse
 from .tag_quant_io import TagQuantIO
@@ -65,6 +66,7 @@
     I64toI32,
     InsertIOQDQ,
     InsertRequantize,
+    InsertSeqMse,
     LayoutTransform,
     LiftConstantScalarOperands,
     RecomposePixelUnshuffle,
@@ -72,6 +74,7 @@
     ReduceDynamicRange,
     Remove0DTensor,
     RemoveRedundancy,
+    RemoveSeqMse,
     ReplaceArangeArgs,
     ReplaceInfValues,
     TagQuantIO,
diff --git a/backends/qualcomm/_passes/seq_mse.py b/backends/qualcomm/_passes/seq_mse.py
new file mode 100644
index 00000000000..dc18e3a03e6
--- /dev/null
+++ b/backends/qualcomm/_passes/seq_mse.py
@@ -0,0 +1,203 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import types
+
+import torch
+import torchao
+from torchao.quantization.pt2e import PerChannelMinMaxObserver
+from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
+    PerBlockParamObserver,
+)
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class SeqMseModule(torch.nn.Module):
+    """
+    Searches scaled weight encodings and records the scaling factor that yields
+    the minimal MSE between the float and fake-quantized operator outputs.
+
+    Args:
+        nominal_weight: Tensor
+            nominal weight parameter taken from the operator
+        nominal_bias: Tensor
+            nominal bias parameter taken from the operator
+        operator: fx.Node
+            operator to be executed
+        observer: UniformQuantizationObserverBase
+            parameter observer (specific to the weight)
+        num_candidates: int
+            number of grid points to search for the minimal MSE loss
+    """
+
+    def __init__(
+        self,
+        nominal_weight,
+        nominal_bias,
+        operator,
+        observer,
+        num_candidates,
+    ):
+        super().__init__()
+        self.nominal_weight = nominal_weight
+        self.nominal_bias = nominal_bias
+        self.observer = observer
+        step = 1 / num_candidates
+        self.steps = torch.arange(start=step, end=1 + step, step=step).tolist()
+        self.operator = self._make_operator(operator)
+        self.best_candidate = 1.0
+
+    def _make_operator(self, aten_op):
+        if aten_op.target == torch.ops.aten.conv2d.default:
+            stride = [1, 1] if len(aten_op.args) < 4 else aten_op.args[3]
+            padding = [0, 0] if len(aten_op.args) < 5 else aten_op.args[4]
+            dilation = [1, 1] if len(aten_op.args) < 6 else aten_op.args[5]
+            groups = 1 if len(aten_op.args) < 7 else aten_op.args[6]
+            has_bias = self.nominal_bias is not None
+            module = torch.nn.Conv2d(
+                in_channels=self.nominal_weight.shape[1],
+                out_channels=self.nominal_weight.shape[0],
+                kernel_size=self.nominal_weight.shape[-2:],
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                bias=has_bias,
+            )
+            module.weight.data = self.nominal_weight
+            if has_bias:
+                module.bias.data = self.nominal_bias
+            return module
+        else:
+            raise NotImplementedError(f"target of {aten_op.target} is not implemented")
+
+    def _per_block_qdq(self, scale, zero_point):
+        return torchao.quantization.quant_primitives._fake_quantize_affine(
+            input=self.nominal_weight,
+            block_size=self.observer.block_size,
+            scale=scale,
+            zero_point=zero_point,
+            quant_dtype=self.observer.dtype,
+            quant_min=self.observer.quant_min,
+            quant_max=self.observer.quant_max,
+        )
+
+    def _per_channel_qdq(self, scale, zero_point):
+        return torch.fake_quantize_per_channel_affine(
+            input=self.nominal_weight,
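+            # axis=0 fake-quantizes per output channel, matching the observer's
+            # channel axis for conv weights laid out as (out_ch, in_ch, kH, kW)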
+            scale=scale,
+            zero_point=zero_point,
+            axis=0,
+            quant_min=self.observer.quant_min,
+            quant_max=self.observer.quant_max,
+        )
+
+    def _fake_quant(self, scale, zero_point):
+        dispatcher = {
+            PerChannelMinMaxObserver: self._per_channel_qdq,
+            PerBlockParamObserver: self._per_block_qdq,
+        }
+        return dispatcher[type(self.observer)](scale, zero_point)
+
+    def _find_best_candidate(self, nominal_input, nominal_output):
+        # calculate the current baseline loss with the observer's own encodings
+        scale, zero_point = self.observer.calculate_qparams()
+        zero_point = zero_point.to(torch.int32)
+        self.operator.weight.data = self._fake_quant(scale, zero_point)
+        candidate, current_loss = 1, torch.nn.functional.mse_loss(
+            self.operator(nominal_input), nominal_output
+        ).item()
+        for step in self.steps:
+            self.operator.weight.data = self._fake_quant(scale * step, zero_point)
+            loss = torch.nn.functional.mse_loss(
+                self.operator(nominal_input), nominal_output
+            ).item()
+            if loss < current_loss:
+                candidate, current_loss = step, loss
+        return candidate
+
+    def forward(self, nominal_input, nominal_output):
+        self.best_candidate = self._find_best_candidate(
+            nominal_input=nominal_input, nominal_output=nominal_output
+        )
+
+
+class InsertSeqMse(ExportPass):
+    """
+    Insert SeqMse observer modules to find the best quantization encodings for
+    the weights of supported nodes.
+    """
+
+    seq_mse_ops = {
+        torch.ops.aten.conv2d.default
+    }
+
+    def __init__(self, num_candidates=1000):
+        super(InsertSeqMse, self).__init__()
+        self.num_candidates = num_candidates
+
+    def _insert_seq_mse(self, graph_module: torch.fx.GraphModule) -> None:
+        count = 0
+        for node in graph_module.graph.nodes:
+            if node.target in self.seq_mse_ops:
+                # extract observer
+                weight_node_obs = node.args[1]
+                observer = getattr(graph_module, weight_node_obs.name)
+                # extract parameters
+                weight_node = weight_node_obs.args[0]
+                weight_tensor = graph_module.get_parameter(
+                    weight_node.target
+                ).detach()
+                bias_tensor = None
+                if len(node.args) > 2 and node.args[2] is not None:
+                    bias_tensor = graph_module.get_parameter(
+                        node.args[2].args[0].target
+                    ).detach()
+
+                with graph_module.graph.inserting_after(node):
+                    seq_mse_mod = SeqMseModule(
+                        nominal_weight=weight_tensor,
+                        nominal_bias=bias_tensor,
+                        operator=node,
+                        observer=observer,
+                        num_candidates=self.num_candidates,
+                    )
+                    module_name = f"seq_mse_{count}"
+                    count += 1
+                    setattr(graph_module, module_name, seq_mse_mod)
+                    input_nodes = (node.args[0], node)
+                    graph_module.graph.create_node(
+                        "call_module", module_name, input_nodes, {}
+                    )
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._insert_seq_mse(graph_module)
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+class RemoveSeqMse(ExportPass):
+    """
+    Remove SeqMse modules before invoking convert_pt2e and update the final
+    quantization encodings.
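+
+    For each observed weight, the pre-computed scale is rescaled by the best
+    candidate found during calibration, and the helper call_module nodes are
+    erased from the graph.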
+ """ + def __init__(self): + super(RemoveSeqMse, self).__init__() + + def _remove_seq_mse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + node_to_erase = [] + for node in graph_module.graph.nodes: + if node.op == "call_module" and "seq_mse" in node.name: + # extract SeqMse module + module = getattr(graph_module, node.target) + # rewrite observer method for pre-calculated scale + scale, zero_point = module.observer.calculate_qparams() + module.observer.updated_encoding = ( + scale * module.best_candidate, zero_point + ) + module.observer.calculate_qparams = ( + types.MethodType(lambda s: s.updated_encoding, module.observer) + ) + node_to_erase.append(node) + + for node in node_to_erase: + graph_module.graph.erase_node(node) + + def call(self, graph_module: torch.fx.GraphModule): + self._remove_seq_mse(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 057d3ea93d2..e227375d3bf 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -285,9 +285,9 @@ def annotate_matmul_input1(node: Node): quantization_config_8a8w = get_8a8w_qnn_ptq_config( act_symmetric=True, act_observer=MinMaxObserver ) - quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, + quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config( + act_dtype=torch.uint16, + weight_dtype=torch.int8, act_observer=MinMaxObserver, act_symmetric=True, ) @@ -318,7 +318,7 @@ def annotate_matmul_input1(node: Node): node = node.args[0][1] elif node.target == torch.ops.aten.conv2d.default: annotate_conv2d( - node, quantization_config=quantization_config_8a4w_per_channel + node, quantization_config=quantization_config_16a8w_per_channel ) break elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]: diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index 7d605b12cf8..8106cff8ff3 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -34,6 +34,7 @@ def __init__( eps=eps, **kwargs, ) + self.dtype = dtype self.block_size = block_size self.calibrated = False diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index e14d73f521d..8b74aae8903 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -3,12 +3,14 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
 
 import torch
+from executorch.backends.qualcomm._passes import InsertSeqMse, RemoveSeqMse
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
 
 from torch._ops import OpOverload
@@ -427,3 +429,12 @@ def predicate(node):
             return False
 
         return predicate
+
+
+@contextmanager
+def qnn_ptq_manager(prepared_gm):
+    prepared_gm = InsertSeqMse()(prepared_gm).graph_module
+    try:
+        yield
+    finally:
+        prepared_gm = RemoveSeqMse()(prepared_gm).graph_module
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index b697e81f2d1..f8fbaad00be 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -3981,6 +3981,33 @@ def test_qnn_backend_generate_optrace(self):
             qhas_data = json.load(qhas_file)
             self.assertIn("data", qhas_data)
 
+    def test_qnn_backend_seq_mse(self):
+        from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager
+
+        o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
+        module = Conv2dSingle(  # noqa: F405
+            in_channel=i_ch,
+            out_channel=o_ch,
+            kernel_size=kernel,
+            padding=padding,
+        )
+        sample_input = (torch.randn(1, i_ch, 1, o_ch),)
+        # per-channel / per-block
+        quantizers = [
+            make_quantizer(),
+            make_quantizer(quant_dtype=QuantDtype.use_16a4w_block),
+        ]
+        quantizers[-1].set_block_size_map({"conv2d": (1, 32, 1, 1)})
+
+        for i, quantizer in enumerate(quantizers):
+            with self.subTest(i=i):
+                ep = torch.export.export(module, sample_input).module()
+                prepared = prepare_pt2e(ep, quantizer)
+                with qnn_ptq_manager(prepared):
+                    prepared(*sample_input)
+                converted = convert_pt2e(prepared)
+                self.lower_module_and_test_output(converted, sample_input)
+
 
 class TestExampleLLMScript(TestQNN):
     def test_llama3_2_1b(self):
diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index db533986119..05e593f501a 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -142,6 +142,9 @@ def _kv_calibrate(
     updater=smart_mask_updater,
     use_i64_token=False,
 ):
+    from contextlib import nullcontext
+    from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager
+
     _, atten_mask, _, k_caches, v_caches = example_inputs
 
     # TODO: change criteria & support batch inputs if necessary
@@ -191,13 +194,14 @@ def _kv_calibrate(
             dim=-1,
         )
 
-        logits, new_k_caches, new_v_caches = module(
-            tmp_token_list,
-            tmp_atten_mask,
-            tmp_pos,
-            *k_caches,
-            *v_caches,
-        )
+        with qnn_ptq_manager(module) if pos == max_seq_len - 1 else nullcontext():
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                tmp_atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
+            )
         atten_mask, pos, k_caches, v_caches = updater(
             ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
         )
@@ -647,10 +651,6 @@ def permute(w, heads):
     if args.ptq:
         start_quantize_ts = time.time()
         custom_annotations = (annotate_matmul_16a8w,)
-        if args.llama_model == "stories110m":
-            custom_annotations = custom_annotations + (
-                annotate_linear_16a8w_in_affine_layer,
-            )
         kv_quant_attrs = {}
         for i, llama_instance in enumerate(llama_instance_list):
             llama_instance.quantize(
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 1a2d9e4f26b..b1c1471b26b 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -374,6 +374,7 @@ def build_executorch_binary(
     online_prepare=False,
     optrace=False,
     op_package_options: QnnExecuTorchOpPackageOptions = None,
+    use_seq_mse=False,
 ):
     """
     A function to generate an ExecuTorch binary for Qualcomm platforms.
@@ -397,10 +398,14 @@ def build_executorch_binary(
         optrace (bool, optional): Enable optrace mode for performance analysis if set to True.
         op_package_options: Optional structure to specify op packages loaded and used by the backend.
+        use_seq_mse (bool, optional): If set to True, apply Seq MSE during calibration to pick weight encodings that minimize output MSE.
 
     Returns:
         None: The function writes the output to a specified .pte file.
     """
+    from contextlib import nullcontext
+    from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager
+
     backend_options = generate_htp_compiler_spec(
         use_fp16=False if quant_dtype else True
     )
@@ -426,7 +431,8 @@ def build_executorch_binary(
     else:
         quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype)
         # ptq calibration
-        annotated_model = ptq_calibrate(captured_model, quantizer, dataset)
+        with qnn_ptq_manager(captured_model) if use_seq_mse else nullcontext():
+            annotated_model = ptq_calibrate(captured_model, quantizer, dataset)
         quantized_model = convert_pt2e(annotated_model)
 
         edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
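
For reference, a minimal sketch of how the new qnn_ptq_manager context is meant to wrap calibration, modeled on test_qnn_backend_seq_mse above. The toy module, the QnnQuantizer defaults, and the prepare_pt2e/convert_pt2e import path (torchao vs. torch.ao, depending on the tree this patch targets) are placeholders, not part of the change itself:

import torch
from executorch.backends.qualcomm.quantizer.quantizer import (
    QnnQuantizer,
    qnn_ptq_manager,
)
# Exact pt2e import path depends on the PyTorch/torchao revision in use.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Placeholder model and calibration data; any module lowering to aten.conv2d works.
model = torch.nn.Sequential(torch.nn.Conv2d(512, 32, 1)).eval()
sample_input = (torch.randn(1, 512, 16, 16),)

exported = torch.export.export(model, sample_input).module()
prepared = prepare_pt2e(exported, QnnQuantizer())
# InsertSeqMse attaches SeqMseModule nodes on entry; RemoveSeqMse folds the
# best scale candidates back into the weight observers on exit.
with qnn_ptq_manager(prepared):
    prepared(*sample_input)  # calibration runs inside the context
quantized = convert_pt2e(prepared)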