|  | 
| 13 | 13 | from executorch.backends.arm._passes.arm_pass_utils import ( | 
| 14 | 14 |     get_param_tensor, | 
| 15 | 15 |     is_param_node, | 
| 16 |  | -    set_node_arg, | 
| 17 | 16 | ) | 
| 18 | 17 | from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass | 
| 19 | 18 | 
 | 
|  | 
| 23 | 22 | from executorch.exir import ExportedProgram | 
| 24 | 23 | 
 | 
| 25 | 24 | from executorch.exir.dialects._ops import ops as exir_ops | 
|  | 25 | +from executorch.exir.dialects.edge._ops import EdgeOpOverload | 
| 26 | 26 | 
 | 
| 27 | 27 | from executorch.exir.pass_base import ExportPass, PassResult | 
| 28 | 28 | from torch.fx import GraphModule, Node | 
| @@ -66,6 +66,38 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]: | 
| 66 | 66 |     return output_qparams | 
| 67 | 67 | 
 | 
| 68 | 68 | 
 | 
class RetraceFoldedDtypesPass(ArmPass):
    """
    FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
    some operators are retraced to types that cannot be handled by TOSA. One
    such example is sum.dim_IntList:
        q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
    After folding it becomes:
        q (int8)              -> sum (int64) ->         ...
    This pass changes types of ops in self.targeted_ops, such as sum, so that
    the output type of that matches the type of the output_qparams.
    """

    # No follow-up passes are mandated by this pass.
    _passes_required_after: Set[Type[ExportPass]] = set()

    # Ops whose "dtype" kwarg must be forced to the quantized output dtype.
    targeted_ops: Set[EdgeOpOverload] = {
        exir_ops.edge.aten.sum.dim_IntList,
    }

    def call_operator(self, op, args, kwargs, meta):
        # Untargeted ops pass through untouched; "output_qparams" is only
        # consulted for ops we intend to rewrite.
        qparams = meta["output_qparams"] if op in self.targeted_ops else None
        if not qparams:
            # Either not a targeted op, or it carries no output qparams:
            # delegate without marking the node as modified.
            return super().call_operator(op, args, kwargs, meta, False)

        # Override the op's "dtype" kwarg with the dtype of the first output
        # qparam so the retraced fake tensor matches the quantized type.
        patched_kwargs = dict(kwargs)
        patched_kwargs["dtype"] = qparams[0].dtype
        return super().call_operator(op, args, patched_kwargs, meta, True)
|  | 100 | + | 
| 69 | 101 | class FoldAndAnnotateQParamsPass(ArmPass): | 
| 70 | 102 |     """ | 
| 71 | 103 |     A pass that walks the graph and removes any DQ and Q nodes before and after the target | 
| @@ -97,6 +129,7 @@ class FoldAndAnnotateQParamsPass(ArmPass): | 
| 97 | 129 |     """ | 
| 98 | 130 | 
 | 
| 99 | 131 |     _passes_required_after: Set[Type[ExportPass]] = { | 
|  | 132 | +        RetraceFoldedDtypesPass, | 
| 100 | 133 |         InsertTableOpsPass, | 
| 101 | 134 |         RemoveNoopPass, | 
| 102 | 135 |     } | 
| @@ -201,16 +234,6 @@ def call(self, graph_module: GraphModule) -> PassResult: | 
| 201 | 234 |                 user.replace_all_uses_with(n) | 
| 202 | 235 |                 graph_module.graph.erase_node(user) | 
| 203 | 236 | 
 | 
| 204 |  | -            # Some op(s) contain a "dtype" key in their node kwargs. Set this | 
| 205 |  | -            # to the type of output qparams. | 
| 206 |  | -            output_qparams = n.meta["output_qparams"] | 
| 207 |  | -            if ( | 
| 208 |  | -                n.target in {exir_ops.edge.aten.sum.dim_IntList} | 
| 209 |  | -                and len(output_qparams) > 0 | 
| 210 |  | -            ): | 
| 211 |  | -                output_dtype = output_qparams[0].dtype | 
| 212 |  | -                set_node_arg(n, "dtype", output_dtype) | 
| 213 |  | - | 
| 214 | 237 |         # retrace the graph to update the fake tensor types | 
| 215 | 238 |         graph_module = super().call(graph_module).graph_module | 
| 216 | 239 | 
 | 
|  | 
0 commit comments