Skip to content

Commit a4e7475

Browse files
authored
Arm backend: Serialize controlflow submodules. (#15381)
Each conditional submodule in the graph_module gets its own region. The TOSA reference model requires all tensor names in one model to be unique, regardless of region. PyTorch's naming semantics, however, don't guarantee this. To fix this, attach a suffix containing the submodule name to tensors in submodules. Signed-off-by: Erik Lundell <[email protected]>
1 parent 5426918 commit a4e7475

File tree

4 files changed

+116
-74
lines changed

4 files changed

+116
-74
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ def _transform(self, graph_module: GraphModule):
153153
with TosaLoweringContext(self.tosa_spec):
154154
return self(graph_module).graph_module
155155

156-
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
156+
def _tosa_INT_pipeline(
157+
self, exported_program: ExportedProgram, graph_module: GraphModule
158+
) -> GraphModule:
157159
self.add_pass(AnnotateOutputDimOrderPass())
158160
self.add_pass(FuseQuantizedActivationPass())
159161
self.add_pass(RemoveGetItemPass())
160162
self.add_pass(ConvertSplitToSlicePass())
161163
self.add_pass(ConvertMmToBmmPass())
162-
self.add_pass(
163-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
164-
)
164+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
165165
self.add_pass(ConvertFullLikeToFullPass())
166166
self.add_pass(ConvertToClampPass())
167167
self.add_pass(ConvertMinMaxPass())
@@ -218,9 +218,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
218218
self.add_pass(InsertRescalePass())
219219

220220
self.validate_constraints_mandatory()
221-
return self._transform(exported_program.graph_module)
221+
return self._transform(graph_module)
222222

223-
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
223+
def _tosa_FP_pipeline(
224+
self, exported_program: ExportedProgram, graph_module: GraphModule
225+
) -> GraphModule:
224226
self.add_pass(AnnotateOutputDimOrderPass())
225227
self.add_pass(DecomposeExpm1Pass())
226228
self.add_pass(DecomposeLogitPass())
@@ -255,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
255257
self.add_pass(DecomposeLayerNormPass())
256258
self.add_pass(DecomposeBatchNormNoStatsPass())
257259
self.add_pass(DecomposeVarPass())
258-
self.add_pass(
259-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
260-
)
260+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
261261
self.add_pass(DecomposeNotEqualPass())
262262
self.add_pass(DecomposeDivPass())
263263
self.add_pass(DecomposeAddSubAlphaPass())
@@ -305,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
305305
self.add_pass(InsertRescalePass())
306306

307307
self.validate_constraints_mandatory()
308-
return self._transform(exported_program.graph_module)
308+
return self._transform(graph_module)
309309

310-
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
310+
def transform_to_backend_pipeline(
311+
self, exported_program: ExportedProgram, graph_module: GraphModule
312+
):
311313
"""Apply passes before transforming program to backend"""
312314
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
313-
return self._tosa_FP_pipeline(exported_program)
315+
return self._tosa_FP_pipeline(exported_program, graph_module)
314316
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
315-
return self._tosa_INT_pipeline(exported_program)
317+
return self._tosa_INT_pipeline(exported_program, graph_module)
316318
else:
317319
raise NotImplementedError(
318320
f"No pass pipeline implemented for {self.tosa_spec=}"

backends/arm/process_node.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def process_inputs_to_buffers(
158158
buffer_values = np.transpose(buffer_values, tosa_arg.dim_order)
159159

160160
tosa_graph.addConst(
161-
buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name
161+
buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name
162162
)
163163

164164

@@ -215,11 +215,9 @@ def process_placeholder(
215215
raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
216216

217217

218-
def process_output(
219-
node: torch.fx.Node,
220-
tosa_graph: Any,
221-
):
218+
def process_output(node: torch.fx.Node, tosa_graph: Any, tosa_spec: TosaSpecification):
222219
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
220+
output_arg = TosaArg(output, tosa_spec)
223221
tosa_graph.addOutputTensor(
224-
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
222+
tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name]
225223
)

backends/arm/tosa/backend.py

Lines changed: 94 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
process_placeholder,
2525
)
2626
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
27+
from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META
2728
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2829
from executorch.exir.backend.compile_spec_schema import CompileSpec
30+
from executorch.exir.graph_module import get_control_flow_submodules
2931
from torch.export.exported_program import ExportedProgram
30-
from torch.fx import Graph, Node
32+
from torch.fx import Graph, GraphModule, Node
33+
3134

3235
# TOSA backend debug functionality
3336
logger = logging.getLogger(__name__)
@@ -52,13 +55,39 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
5255
# Walk backwards so we touch every producer
5356
q.extend(n.all_input_nodes)
5457

55-
out = next(n for n in ep_graph.nodes if n.op == "output")
58+
out = ep_graph.output_node()
59+
# First argument of output node is tuple of outputs
60+
output_list = cast(tuple, out.args[0])
5661
seen: Set[Node] = set()
57-
for idx, val in enumerate(out.args[0]):
62+
for idx, val in enumerate(output_list):
5863
bfs_mark([val], idx, seen)
5964
return node2external_id
6065

6166

67+
def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]):
68+
def _external_id(n: Node, node_2_id, fallback: int) -> int:
69+
return node_2_id.get(n.name, fallback)
70+
71+
out_node = graph_module.graph.output_node()
72+
out_list = cast(tuple, out_node.args[0])
73+
_counter = count()
74+
75+
# sort nodes by the key that is id
76+
def _sort_key(t: Node) -> int:
77+
return _external_id(t, node_to_id_map, next(_counter))
78+
79+
orig_ord = tuple(sorted(out_list, key=_sort_key))
80+
81+
current_order = tuple(out_list)
82+
if orig_ord != current_order:
83+
replacement = list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
84+
out_node.args = (replacement,)
85+
graph_module.graph.lint()
86+
graph_module.recompile()
87+
88+
return graph_module
89+
90+
6291
def arm_get_first_delegation_tag(graph_module) -> str:
6392
"""Get the first delegation tag from the graph_module or return empty string."""
6493
for node in graph_module.graph.nodes:
@@ -93,9 +122,9 @@ def _preprocess( # noqa: C901
93122
artifact_path = compile_spec.get_intermediate_path()
94123
tosa_spec = compile_spec.tosa_spec
95124
dump_debug_info = compile_spec.tosa_debug_mode
96-
97-
# Assign to every node external id
98-
node_2_id = _annotate_external_ids(edge_program.graph)
125+
debug_hook = None
126+
if dump_debug_info is not None:
127+
debug_hook = DebugHook(dump_debug_info)
99128

100129
logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")
101130

@@ -116,45 +145,66 @@ def _preprocess( # noqa: C901
116145
f"doesn't match specification {tosa_spec}"
117146
)
118147

148+
TOSABackend._preprocess_module(
149+
edge_program.graph_module,
150+
edge_program,
151+
compile_spec,
152+
tosa_graph,
153+
debug_hook,
154+
)
155+
# Serialize and return the TOSA flatbuffer.
156+
binary = tosa_graph.serialize()
157+
158+
if artifact_path:
159+
tag = arm_get_first_delegation_tag(edge_program.graph_module)
160+
debug_tosa_dump(
161+
binary,
162+
artifact_path,
163+
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
164+
)
165+
166+
if debug_hook is not None:
167+
if debug_hook.mode == ArmCompileSpec.DebugMode.JSON:
168+
json_output = debug_hook.serialize()
169+
with open(f"{artifact_path}/debug.json", "w") as f:
170+
f.write(json_output)
171+
172+
return PreprocessResult(processed_bytes=binary)
173+
174+
@staticmethod
175+
def _preprocess_module( # noqa: C901
176+
graph_module: GraphModule,
177+
edge_program: ExportedProgram,
178+
compile_spec: TosaCompileSpec,
179+
tosa_graph: ts.TosaSerializer,
180+
debug_hook: DebugHook | None,
181+
submodule_name: str | None = None,
182+
):
183+
"""Convert 'graph_module' to a tosa_graph"""
184+
tosa_spec = compile_spec.tosa_spec
185+
node_to_id_map = _annotate_external_ids(graph_module.graph)
186+
artifact_path = compile_spec.get_intermediate_path()
187+
119188
# TODO: Fix the need to lazily import this.
120189
from executorch.backends.arm._passes import ArmPassManager
121190

122191
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore
123-
exported_program=edge_program
192+
exported_program=edge_program, graph_module=graph_module
124193
)
125194

126-
debug_hook = None
127-
if dump_debug_info is not None:
128-
debug_hook = DebugHook(dump_debug_info)
129-
130195
# TODO: Fix the need to lazily import this.
131196
from executorch.backends.arm.operators.node_visitor import get_node_visitors
132197

133198
node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
199+
graph_module = _sort_outputs(graph_module, node_to_id_map)
134200

135-
# Re-shuffle output nodes to preserve author's order
136-
def _external_id(n: Node, node_2_id, fallback: int) -> int:
137-
return node_2_id.get(n.name, fallback)
138-
139-
out_node = next(n for n in graph_module.graph.nodes if n.op == "output")
140-
_counter = count()
141-
142-
# sort nodes by the key that is id
143-
def _sort_key(t: Node) -> int:
144-
return _external_id(t, node_2_id, next(_counter))
201+
if submodule_name is not None:
202+
tosa_graph.startRegion(submodule_name)
203+
tosa_graph.currRegion.addBasicBlock(submodule_name)
204+
suffix = f"_{submodule_name}"
205+
for loop_node in graph_module.graph.nodes:
206+
loop_node.meta[TOSA_TENSOR_NAME_META] = suffix
145207

146-
orig_ord = tuple(sorted(out_node.args[0], key=_sort_key))
147-
148-
current_order = tuple(out_node.args[0])
149-
if orig_ord != current_order:
150-
replacement = (
151-
list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
152-
)
153-
out_node.args = (replacement,)
154-
graph_module.graph.lint()
155-
graph_module.recompile()
156-
157-
input_count = 0
158208
for node in graph_module.graph.nodes:
159209
node = cast(Node, node)
160210
try:
@@ -164,37 +214,27 @@ def _sort_key(t: Node) -> int:
164214
if len(node.users) == 0:
165215
continue
166216
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
167-
if node.name in edge_program.graph_signature.user_inputs:
168-
input_count += 1
169217
elif node.op == "output":
170-
process_output(node, tosa_graph)
218+
process_output(node, tosa_graph, tosa_spec)
171219
else:
172220
# This will only happen if an unpartitioned graph is passed without
173221
# any checking of compatibility.
174222
raise RuntimeError(f"{node.name} is unsupported op {node.op}")
175223
except Exception:
176-
debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path)
224+
debug_fail(node, graph_module, tosa_graph, artifact_path)
177225
raise
178226

179-
# Serialize and return the TOSA flatbuffer.
180-
binary = tosa_graph.serialize()
181-
182-
if artifact_path:
183-
tag = arm_get_first_delegation_tag(graph_module)
184-
debug_tosa_dump(
185-
binary,
186-
artifact_path,
187-
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
227+
# Recursively preprocess controlflow submodules.
228+
for name, submodule, _ in get_control_flow_submodules(graph_module):
229+
TOSABackend._preprocess_module(
230+
submodule,
231+
edge_program,
232+
compile_spec,
233+
tosa_graph,
234+
debug_hook,
235+
submodule_name=name,
188236
)
189237

190-
if debug_hook is not None:
191-
if debug_hook.mode == ArmCompileSpec.DebugMode.JSON:
192-
json_output = debug_hook.serialize()
193-
with open(f"{artifact_path}/debug.json", "w") as f:
194-
f.write(json_output)
195-
196-
return PreprocessResult(processed_bytes=binary)
197-
198238
@staticmethod
199239
def filter_tosa_compile_specs(
200240
compile_spec: ArmCompileSpec,

backends/arm/tosa/mapping.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import tosa_serializer as ts
1818
from executorch.backends.arm.tosa.specification import TosaSpecification
1919

20+
TOSA_TENSOR_NAME_META = "tosa_tensor_name"
21+
2022
UNSUPPORTED_DTYPES = (
2123
torch.float64,
2224
torch.double,
@@ -144,7 +146,7 @@ def __process_node(self, argument: torch.fx.Node):
144146
argument (torch.fx.Node): FX node to inspect.
145147
146148
"""
147-
self.name: str = argument.name
149+
self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "")
148150
output_dtype, self.shape, self.dim_order = extract_tensor_meta(
149151
argument.meta, self.tosa_spec
150152
)

0 commit comments

Comments
 (0)