|  | 
| 13 | 13 | from executorch.backends.arm._passes.arm_pass_utils import ( | 
| 14 | 14 |     get_param_tensor, | 
| 15 | 15 |     is_param_node, | 
|  | 16 | +    set_node_arg, | 
| 16 | 17 | ) | 
| 17 | 18 | from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass | 
| 18 | 19 | 
 | 
|  | 
| 22 | 23 | from executorch.exir import ExportedProgram | 
| 23 | 24 | 
 | 
| 24 | 25 | from executorch.exir.dialects._ops import ops as exir_ops | 
| 25 |  | -from executorch.exir.dialects.edge._ops import EdgeOpOverload | 
| 26 | 26 | 
 | 
| 27 | 27 | from executorch.exir.pass_base import ExportPass, PassResult | 
| 28 | 28 | from torch.fx import GraphModule, Node | 
| @@ -66,38 +66,6 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]: | 
| 66 | 66 |     return output_qparams | 
| 67 | 67 | 
 | 
| 68 | 68 | 
 | 
| 69 |  | -class RetraceFoldedDtypesPass(ArmPass): | 
| 70 |  | -    """ | 
| 71 |  | -    FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced | 
| 72 |  | -    some operators are retraced to types that cannot be handled by TOSA. One | 
| 73 |  | -    such example is sum.dim_IntList: | 
| 74 |  | -        q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ... | 
| 75 |  | -    After folding it becomes: | 
| 76 |  | -        q (int8)              -> sum (int64) ->         ... | 
| 77 |  | -    This pass changes types of ops in self.targeted_ops, such as sum, so that | 
| 78 |  | -    the output type of that matches the type of the output_qparams. | 
| 79 |  | -    """ | 
| 80 |  | - | 
| 81 |  | -    _passes_required_after: Set[Type[ExportPass]] = set() | 
| 82 |  | - | 
| 83 |  | -    targeted_ops: Set[EdgeOpOverload] = { | 
| 84 |  | -        exir_ops.edge.aten.sum.dim_IntList, | 
| 85 |  | -    } | 
| 86 |  | - | 
| 87 |  | -    def call_operator(self, op, args, kwargs, meta): | 
| 88 |  | -        if op not in self.targeted_ops: | 
| 89 |  | -            return super().call_operator(op, args, kwargs, meta, False) | 
| 90 |  | - | 
| 91 |  | -        node_kwargs = kwargs.copy() | 
| 92 |  | -        output_qparams = meta["output_qparams"] | 
| 93 |  | -        if len(output_qparams) == 0: | 
| 94 |  | -            return super().call_operator(op, args, kwargs, meta, False) | 
| 95 |  | - | 
| 96 |  | -        output_dtype = output_qparams[0].dtype | 
| 97 |  | -        node_kwargs["dtype"] = output_dtype | 
| 98 |  | -        return super().call_operator(op, args, node_kwargs, meta, True) | 
| 99 |  | - | 
| 100 |  | - | 
| 101 | 69 | class FoldAndAnnotateQParamsPass(ArmPass): | 
| 102 | 70 |     """ | 
| 103 | 71 |     A pass that walks the graph and removes any DQ and Q nodes before and after the target | 
| @@ -129,7 +97,6 @@ class FoldAndAnnotateQParamsPass(ArmPass): | 
| 129 | 97 |     """ | 
| 130 | 98 | 
 | 
| 131 | 99 |     _passes_required_after: Set[Type[ExportPass]] = { | 
| 132 |  | -        RetraceFoldedDtypesPass, | 
| 133 | 100 |         InsertTableOpsPass, | 
| 134 | 101 |         RemoveNoopPass, | 
| 135 | 102 |     } | 
| @@ -234,6 +201,16 @@ def call(self, graph_module: GraphModule) -> PassResult: | 
| 234 | 201 |                 user.replace_all_uses_with(n) | 
| 235 | 202 |                 graph_module.graph.erase_node(user) | 
| 236 | 203 | 
 | 
|  | 204 | +            # Some op(s) contain a "dtype" key in their node kwargs. Set this | 
|  | 205 | +            # to the type of output qparams. | 
|  | 206 | +            output_qparams = n.meta["output_qparams"] | 
|  | 207 | +            if ( | 
|  | 208 | +                n.target in {exir_ops.edge.aten.sum.dim_IntList} | 
|  | 209 | +                and len(output_qparams) > 0 | 
|  | 210 | +            ): | 
|  | 211 | +                output_dtype = output_qparams[0].dtype | 
|  | 212 | +                set_node_arg(n, "dtype", output_dtype) | 
|  | 213 | + | 
| 237 | 214 |         # retrace the graph to update the fake tensor types | 
| 238 | 215 |         graph_module = super().call(graph_module).graph_module | 
| 239 | 216 | 
 | 
|  | 
0 commit comments