Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
FoldAndAnnotateQParamsPass,
QuantizeOperatorArguments,
RetraceFoldedDtypesPass,
)
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
Expand Down
3 changes: 0 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
RemoveNoopPass,
ReplaceInfValues,
ReplaceScalarWithTensorByProfilePass,
RetraceFoldedDtypesPass,
RewriteConv2dPass,
RewriteMatmulPass,
RewriteUpsamplePass,
Expand Down Expand Up @@ -176,7 +175,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(QuantizeOperatorArguments())
self.add_pass(ConvertELUParamsPass())
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
if self.tosa_spec.is_U55_subset:
Expand Down Expand Up @@ -271,7 +269,6 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
Expand Down
45 changes: 11 additions & 34 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from executorch.backends.arm._passes.arm_pass_utils import (
get_param_tensor,
is_param_node,
set_node_arg,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass

Expand All @@ -22,7 +23,6 @@
from executorch.exir import ExportedProgram

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node
Expand Down Expand Up @@ -66,38 +66,6 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
return output_qparams


class RetraceFoldedDtypesPass(ArmPass):
"""
FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
some operators are retraced to types that cannot be handled by TOSA. One
such example is sum.dim_IntList:
q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
After folding it becomes:
q (int8) -> sum (int64) -> ...
This pass changes types of ops in self.targeted_ops, such as sum, so that
the output type of that matches the type of the output_qparams.
"""

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops: Set[EdgeOpOverload] = {
exir_ops.edge.aten.sum.dim_IntList,
}

def call_operator(self, op, args, kwargs, meta):
if op not in self.targeted_ops:
return super().call_operator(op, args, kwargs, meta, False)

node_kwargs = kwargs.copy()
output_qparams = meta["output_qparams"]
if len(output_qparams) == 0:
return super().call_operator(op, args, kwargs, meta, False)

output_dtype = output_qparams[0].dtype
node_kwargs["dtype"] = output_dtype
return super().call_operator(op, args, node_kwargs, meta, True)


class FoldAndAnnotateQParamsPass(ArmPass):
"""
A pass that walks the graph and removes any DQ and Q nodes before and after the target
Expand Down Expand Up @@ -129,7 +97,6 @@ class FoldAndAnnotateQParamsPass(ArmPass):
"""

_passes_required_after: Set[Type[ExportPass]] = {
RetraceFoldedDtypesPass,
InsertTableOpsPass,
RemoveNoopPass,
}
Expand Down Expand Up @@ -234,6 +201,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
user.replace_all_uses_with(n)
graph_module.graph.erase_node(user)

# Some op(s) contain a "dtype" key in their node kwargs. Set this
# to the type of output qparams.
output_qparams = n.meta["output_qparams"]
if (
n.target in {exir_ops.edge.aten.sum.dim_IntList}
and len(output_qparams) > 0
):
output_dtype = output_qparams[0].dtype
set_node_arg(n, "dtype", output_dtype)

# retrace the graph to update the fake tensor types
graph_module = super().call(graph_module).graph_module

Expand Down
Loading