Skip to content

Commit 4675292

Browse files
authored
Arm backend: Merge RetraceFoldedDtypesPass into FoldAndAnnotateQParam… (#15377)
The pass RetraceFoldedDtypesPass carries out extra processing after the output of FoldAndAnnotateQParamsPass, meaning that they are tightly coupled and always run in sequence. Merge these two passes together into FoldAndAnnotateQParamsPass. ### Test plan No behavior is modified. Tests already in place are sufficient. Signed-off-by: Martin Lindström <[email protected]>
1 parent a4e7475 commit 4675292

File tree

3 files changed

+11
-38
lines changed

3 files changed

+11
-38
lines changed

backends/arm/_passes/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@
7272
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
7373
FoldAndAnnotateQParamsPass,
7474
QuantizeOperatorArguments,
75-
RetraceFoldedDtypesPass,
7675
)
7776
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
7877
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@
8888
RemoveNoopPass,
8989
ReplaceInfValues,
9090
ReplaceScalarWithTensorByProfilePass,
91-
RetraceFoldedDtypesPass,
9291
RewriteConv2dPass,
9392
RewriteMatmulPass,
9493
RewriteUpsamplePass,
@@ -176,7 +175,6 @@ def _tosa_INT_pipeline(
176175
self.add_pass(QuantizeOperatorArguments())
177176
self.add_pass(ConvertELUParamsPass())
178177
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
179-
self.add_pass(RetraceFoldedDtypesPass())
180178
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
181179
self.add_pass(MatchArgRanksPass(exported_program))
182180
if self.tosa_spec.is_U55_subset:
@@ -271,7 +269,6 @@ def _tosa_FP_pipeline(
271269
self.add_pass(AnnotateDecomposedMatmulPass())
272270
self.add_pass(QuantizeOperatorArguments())
273271
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
274-
self.add_pass(RetraceFoldedDtypesPass())
275272
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
276273
self.add_pass(MatchArgRanksPass(exported_program))
277274
self.add_pass(DecomposeAdaptiveAvgPool2dPass())

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm._passes.arm_pass_utils import (
1414
get_param_tensor,
1515
is_param_node,
16+
set_node_arg,
1617
)
1718
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
1819

@@ -22,7 +23,6 @@
2223
from executorch.exir import ExportedProgram
2324

2425
from executorch.exir.dialects._ops import ops as exir_ops
25-
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2626

2727
from executorch.exir.pass_base import ExportPass, PassResult
2828
from torch.fx import GraphModule, Node
@@ -66,38 +66,6 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
6666
return output_qparams
6767

6868

69-
class RetraceFoldedDtypesPass(ArmPass):
70-
"""
71-
FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
72-
some operators are retraced to types that cannot be handled by TOSA. One
73-
such example is sum.dim_IntList:
74-
q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
75-
After folding it becomes:
76-
q (int8) -> sum (int64) -> ...
77-
This pass changes types of ops in self.targeted_ops, such as sum, so that
78-
the output type of that matches the type of the output_qparams.
79-
"""
80-
81-
_passes_required_after: Set[Type[ExportPass]] = set()
82-
83-
targeted_ops: Set[EdgeOpOverload] = {
84-
exir_ops.edge.aten.sum.dim_IntList,
85-
}
86-
87-
def call_operator(self, op, args, kwargs, meta):
88-
if op not in self.targeted_ops:
89-
return super().call_operator(op, args, kwargs, meta, False)
90-
91-
node_kwargs = kwargs.copy()
92-
output_qparams = meta["output_qparams"]
93-
if len(output_qparams) == 0:
94-
return super().call_operator(op, args, kwargs, meta, False)
95-
96-
output_dtype = output_qparams[0].dtype
97-
node_kwargs["dtype"] = output_dtype
98-
return super().call_operator(op, args, node_kwargs, meta, True)
99-
100-
10169
class FoldAndAnnotateQParamsPass(ArmPass):
10270
"""
10371
A pass that walks the graph and removes any DQ and Q nodes before and after the target
@@ -129,7 +97,6 @@ class FoldAndAnnotateQParamsPass(ArmPass):
12997
"""
13098

13199
_passes_required_after: Set[Type[ExportPass]] = {
132-
RetraceFoldedDtypesPass,
133100
InsertTableOpsPass,
134101
RemoveNoopPass,
135102
}
@@ -234,6 +201,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
234201
user.replace_all_uses_with(n)
235202
graph_module.graph.erase_node(user)
236203

204+
# Some op(s) contain a "dtype" key in their node kwargs. Set this
205+
# to the type of output qparams.
206+
output_qparams = n.meta["output_qparams"]
207+
if (
208+
n.target in {exir_ops.edge.aten.sum.dim_IntList}
209+
and len(output_qparams) > 0
210+
):
211+
output_dtype = output_qparams[0].dtype
212+
set_node_arg(n, "dtype", output_dtype)
213+
237214
# retrace the graph to update the fake tensor types
238215
graph_module = super().call(graph_module).graph_module
239216

0 commit comments

Comments
 (0)