Skip to content

Commit 4da1550

Browse files
committed
Revert "Arm backend: Merge RetraceFoldedDtypesPass into FoldAndAnnotateQParamsPass (pytorch#15377)"
This reverts commit 4675292.
1 parent cd6f2e2 commit 4da1550

File tree

3 files changed

+38
-11
lines changed

3 files changed

+38
-11
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
7373
FoldAndAnnotateQParamsPass,
7474
QuantizeOperatorArguments,
75+
RetraceFoldedDtypesPass,
7576
)
7677
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
7778
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
RemoveNoopPass,
8989
ReplaceInfValues,
9090
ReplaceScalarWithTensorByProfilePass,
91+
RetraceFoldedDtypesPass,
9192
RewriteConv2dPass,
9293
RewriteMatmulPass,
9394
RewriteUpsamplePass,
@@ -175,6 +176,7 @@ def _tosa_INT_pipeline(
175176
self.add_pass(QuantizeOperatorArguments())
176177
self.add_pass(ConvertELUParamsPass())
177178
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
179+
self.add_pass(RetraceFoldedDtypesPass())
178180
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
179181
self.add_pass(MatchArgRanksPass(exported_program))
180182
if self.tosa_spec.is_U55_subset:
@@ -269,6 +271,7 @@ def _tosa_FP_pipeline(
269271
self.add_pass(AnnotateDecomposedMatmulPass())
270272
self.add_pass(QuantizeOperatorArguments())
271273
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
274+
self.add_pass(RetraceFoldedDtypesPass())
272275
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
273276
self.add_pass(MatchArgRanksPass(exported_program))
274277
self.add_pass(DecomposeAdaptiveAvgPool2dPass())

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from executorch.backends.arm._passes.arm_pass_utils import (
1414
get_param_tensor,
1515
is_param_node,
16-
set_node_arg,
1716
)
1817
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
1918

@@ -23,6 +22,7 @@
2322
from executorch.exir import ExportedProgram
2423

2524
from executorch.exir.dialects._ops import ops as exir_ops
25+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2626

2727
from executorch.exir.pass_base import ExportPass, PassResult
2828
from torch.fx import GraphModule, Node
@@ -66,6 +66,38 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
6666
return output_qparams
6767

6868

69+
class RetraceFoldedDtypesPass(ArmPass):
    """Force targeted ops to produce the dtype recorded in their output qparams.

    FoldAndAnnotateQParamsPass folds dq and q nodes into the operator they
    surround. When the graph is retraced afterwards, some operators end up with
    an output type that TOSA cannot handle. ``sum.dim_IntList`` is one example:

        q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...

    becomes, after folding:

        q (int8) -> sum (int64) -> ...

    For every op listed in ``targeted_ops`` this pass sets the ``dtype`` kwarg
    to the dtype of the first entry in ``meta["output_qparams"]``. Ops outside
    ``targeted_ops``, and targeted ops with no output qparams, are forwarded
    unchanged.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    # Ops whose ``dtype`` kwarg is overridden by this pass.
    targeted_ops: Set[EdgeOpOverload] = {
        exir_ops.edge.aten.sum.dim_IntList,
    }

    def call_operator(self, op, args, kwargs, meta):
        """Forward ``op`` to the parent pass, overriding ``dtype`` when needed.

        The trailing positional flag given to the parent is ``True`` only when
        the kwargs were rewritten here.
        NOTE(review): presumably this is ArmPass's "updated" marker; confirm
        against ``ArmPass.call_operator``.
        """
        if op in self.targeted_ops:
            qparams = meta["output_qparams"]
            if len(qparams) != 0:
                # Work on a copy so the caller's kwargs dict is left untouched.
                retyped_kwargs = kwargs.copy()
                retyped_kwargs["dtype"] = qparams[0].dtype
                return super().call_operator(op, args, retyped_kwargs, meta, True)

        return super().call_operator(op, args, kwargs, meta, False)
99+
100+
69101
class FoldAndAnnotateQParamsPass(ArmPass):
70102
"""
71103
A pass that walks the graph and removes any DQ and Q nodes before and after the target
@@ -97,6 +129,7 @@ class FoldAndAnnotateQParamsPass(ArmPass):
97129
"""
98130

99131
_passes_required_after: Set[Type[ExportPass]] = {
132+
RetraceFoldedDtypesPass,
100133
InsertTableOpsPass,
101134
RemoveNoopPass,
102135
}
@@ -201,16 +234,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
201234
user.replace_all_uses_with(n)
202235
graph_module.graph.erase_node(user)
203236

204-
# Some op(s) contain a "dtype" key in their node kwargs. Set this
205-
# to the type of output qparams.
206-
output_qparams = n.meta["output_qparams"]
207-
if (
208-
n.target in {exir_ops.edge.aten.sum.dim_IntList}
209-
and len(output_qparams) > 0
210-
):
211-
output_dtype = output_qparams[0].dtype
212-
set_node_arg(n, "dtype", output_dtype)
213-
214237
# retrace the graph to update the fake tensor types
215238
graph_module = super().call(graph_module).graph_module
216239

0 commit comments

Comments
 (0)