2424 process_placeholder ,
2525)
2626from executorch .backends .arm .tosa .compile_spec import TosaCompileSpec
27+ from executorch .backends .arm .tosa .mapping import TOSA_TENSOR_NAME_META
2728from executorch .exir .backend .backend_details import BackendDetails , PreprocessResult
2829from executorch .exir .backend .compile_spec_schema import CompileSpec
30+ from executorch .exir .graph_module import get_control_flow_submodules
2931from torch .export .exported_program import ExportedProgram
30- from torch .fx import Graph , Node
32+ from torch .fx import Graph , GraphModule , Node
33+
3134
3235# TOSA backend debug functionality
3336logger = logging .getLogger (__name__ )
@@ -52,13 +55,39 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
5255 # Walk backwards so we touch every producer
5356 q .extend (n .all_input_nodes )
5457
55- out = next (n for n in ep_graph .nodes if n .op == "output" )
58+ out = ep_graph .output_node ()
59+ # First argument of output node is tuple of outputs
60+ output_list = cast (tuple , out .args [0 ])
5661 seen : Set [Node ] = set ()
57- for idx , val in enumerate (out . args [ 0 ] ):
62+ for idx , val in enumerate (output_list ):
5863 bfs_mark ([val ], idx , seen )
5964 return node2external_id
6065
6166
67+ def _sort_outputs (graph_module : GraphModule , node_to_id_map : dict [str , int ]):
68+ def _external_id (n : Node , node_2_id , fallback : int ) -> int :
69+ return node_2_id .get (n .name , fallback )
70+
71+ out_node = graph_module .graph .output_node ()
72+ out_list = cast (tuple , out_node .args [0 ])
73+ _counter = count ()
74+
75+ # sort nodes by the key that is id
76+ def _sort_key (t : Node ) -> int :
77+ return _external_id (t , node_to_id_map , next (_counter ))
78+
79+ orig_ord = tuple (sorted (out_list , key = _sort_key ))
80+
81+ current_order = tuple (out_list )
82+ if orig_ord != current_order :
83+ replacement = list (orig_ord ) if isinstance (out_node .args [0 ], list ) else orig_ord
84+ out_node .args = (replacement ,)
85+ graph_module .graph .lint ()
86+ graph_module .recompile ()
87+
88+ return graph_module
89+
90+
6291def arm_get_first_delegation_tag (graph_module ) -> str :
6392 """Get the first delegation tag from the graph_module or return empty string."""
6493 for node in graph_module .graph .nodes :
@@ -93,9 +122,9 @@ def _preprocess( # noqa: C901
93122 artifact_path = compile_spec .get_intermediate_path ()
94123 tosa_spec = compile_spec .tosa_spec
95124 dump_debug_info = compile_spec .tosa_debug_mode
96-
97- # Assign to every node external id
98- node_2_id = _annotate_external_ids ( edge_program . graph )
125+ debug_hook = None
126+ if dump_debug_info is not None :
127+ debug_hook = DebugHook ( dump_debug_info )
99128
100129 logger .info (f"Converting ExportedProgram to TOSA: { tosa_spec } " )
101130
@@ -116,45 +145,66 @@ def _preprocess( # noqa: C901
116145 f"doesn't match specification { tosa_spec } "
117146 )
118147
148+ TOSABackend ._preprocess_module (
149+ edge_program .graph_module ,
150+ edge_program ,
151+ compile_spec ,
152+ tosa_graph ,
153+ debug_hook ,
154+ )
155+ # Serialize and return the TOSA flatbuffer.
156+ binary = tosa_graph .serialize ()
157+
158+ if artifact_path :
159+ tag = arm_get_first_delegation_tag (edge_program .graph_module )
160+ debug_tosa_dump (
161+ binary ,
162+ artifact_path ,
163+ suffix = "{}" .format (f"_{ tag } " if tag else "" ) + (f"_{ tosa_spec } " ),
164+ )
165+
166+ if debug_hook is not None :
167+ if debug_hook .mode == ArmCompileSpec .DebugMode .JSON :
168+ json_output = debug_hook .serialize ()
169+ with open (f"{ artifact_path } /debug.json" , "w" ) as f :
170+ f .write (json_output )
171+
172+ return PreprocessResult (processed_bytes = binary )
173+
174+ @staticmethod
175+ def _preprocess_module ( # noqa: C901
176+ graph_module : GraphModule ,
177+ edge_program : ExportedProgram ,
178+ compile_spec : TosaCompileSpec ,
179+ tosa_graph : ts .TosaSerializer ,
180+ debug_hook : DebugHook | None ,
181+ submodule_name : str | None = None ,
182+ ):
183+ """Convert 'graph_module' to a tosa_graph"""
184+ tosa_spec = compile_spec .tosa_spec
185+ node_to_id_map = _annotate_external_ids (graph_module .graph )
186+ artifact_path = compile_spec .get_intermediate_path ()
187+
119188 # TODO: Fix the need to lazily import this.
120189 from executorch .backends .arm ._passes import ArmPassManager
121190
122191 graph_module = ArmPassManager (tosa_spec ).transform_to_backend_pipeline ( # type: ignore
123- exported_program = edge_program
192+ exported_program = edge_program , graph_module = graph_module
124193 )
125194
126- debug_hook = None
127- if dump_debug_info is not None :
128- debug_hook = DebugHook (dump_debug_info )
129-
130195 # TODO: Fix the need to lazily import this.
131196 from executorch .backends .arm .operators .node_visitor import get_node_visitors
132197
133198 node_visitors = get_node_visitors (edge_program , tosa_spec , debug_hook )
199+ graph_module = _sort_outputs (graph_module , node_to_id_map )
134200
135- # Re-shuffle output nodes to preserve author's order
136- def _external_id (n : Node , node_2_id , fallback : int ) -> int :
137- return node_2_id .get (n .name , fallback )
138-
139- out_node = next (n for n in graph_module .graph .nodes if n .op == "output" )
140- _counter = count ()
141-
142- # sort nodes by the key that is id
143- def _sort_key (t : Node ) -> int :
144- return _external_id (t , node_2_id , next (_counter ))
201+ if submodule_name is not None :
202+ tosa_graph .startRegion (submodule_name )
203+ tosa_graph .currRegion .addBasicBlock (submodule_name )
204+ suffix = f"_{ submodule_name } "
205+ for loop_node in graph_module .graph .nodes :
206+ loop_node .meta [TOSA_TENSOR_NAME_META ] = suffix
145207
146- orig_ord = tuple (sorted (out_node .args [0 ], key = _sort_key ))
147-
148- current_order = tuple (out_node .args [0 ])
149- if orig_ord != current_order :
150- replacement = (
151- list (orig_ord ) if isinstance (out_node .args [0 ], list ) else orig_ord
152- )
153- out_node .args = (replacement ,)
154- graph_module .graph .lint ()
155- graph_module .recompile ()
156-
157- input_count = 0
158208 for node in graph_module .graph .nodes :
159209 node = cast (Node , node )
160210 try :
@@ -164,37 +214,27 @@ def _sort_key(t: Node) -> int:
164214 if len (node .users ) == 0 :
165215 continue
166216 process_placeholder (node , tosa_graph , edge_program , tosa_spec )
167- if node .name in edge_program .graph_signature .user_inputs :
168- input_count += 1
169217 elif node .op == "output" :
170- process_output (node , tosa_graph )
218+ process_output (node , tosa_graph , tosa_spec )
171219 else :
172220 # This will only happen if an unpartitioned graph is passed without
173221 # any checking of compatibility.
174222 raise RuntimeError (f"{ node .name } is unsupported op { node .op } " )
175223 except Exception :
176- debug_fail (node , graph_module , tosa_graph . serialize () , artifact_path )
224+ debug_fail (node , graph_module , tosa_graph , artifact_path )
177225 raise
178226
179- # Serialize and return the TOSA flatbuffer .
180- binary = tosa_graph . serialize ()
181-
182- if artifact_path :
183- tag = arm_get_first_delegation_tag ( graph_module )
184- debug_tosa_dump (
185- binary ,
186- artifact_path ,
187- suffix = "{}" . format ( f"_ { tag } " if tag else "" ) + ( f"_ { tosa_spec } " ) ,
227+ # Recursively preprocess controlflow submodules .
228+ for name , submodule , _ in get_control_flow_submodules ( graph_module ):
229+ TOSABackend . _preprocess_module (
230+ submodule ,
231+ edge_program ,
232+ compile_spec ,
233+ tosa_graph ,
234+ debug_hook ,
235+ submodule_name = name ,
188236 )
189237
190- if debug_hook is not None :
191- if debug_hook .mode == ArmCompileSpec .DebugMode .JSON :
192- json_output = debug_hook .serialize ()
193- with open (f"{ artifact_path } /debug.json" , "w" ) as f :
194- f .write (json_output )
195-
196- return PreprocessResult (processed_bytes = binary )
197-
198238 @staticmethod
199239 def filter_tosa_compile_specs (
200240 compile_spec : ArmCompileSpec ,
0 commit comments