Skip to content

Commit 07d8738

Browse files
Martin Lindström authored and oscarandersson8218 committed
Arm backend: Move rescales from SUB visitor to pass
Move the insertion of INT8/INT32 RESCALE ops from the SUB node visitor to the pass InsertRescaleInt32Pass. In practice this is a refactoring patch, but the resulting TOSA file still differs enough to make an Ethos-U55 test in test_var.py fail. The underlying issue was fixed in https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/commit/642f7517d3a6bd053032e1942822f6e38ccd546f, so we temporarily mark the failing test as xfail until the Ethos-U Vela compiler version we depend on is bumped to one that includes the fix. Signed-off-by: Martin Lindstroem <[email protected]> Co-authored-by: Oscar Andersson <[email protected]> Change-Id: I38d63015e03e59c267338c84d64731b050854d06
1 parent 25cf3e9 commit 07d8738

File tree

5 files changed

+27
-456
lines changed

5 files changed

+27
-456
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,7 @@ class InsertRescaleInt32Pass(ArmPass):
101101
exir_ops.edge.aten.maximum.default,
102102
exir_ops.edge.aten.minimum.default,
103103
exir_ops.edge.aten.mul.Tensor,
104+
exir_ops.edge.aten.sub.Tensor,
104105
exir_ops.edge.aten.sum.dim_IntList,
105106
]
106107

@@ -144,6 +145,7 @@ def _get_inputs_rescaled_qparams(
144145
}
145146
elif target in [
146147
exir_ops.edge.aten.add.Tensor,
148+
exir_ops.edge.aten.sub.Tensor,
147149
]:
148150
if input_qparams[0].dtype != input_qparams[1].dtype:
149151
raise ValueError(
@@ -196,6 +198,7 @@ def _get_output_qparams(
196198
exir_ops.edge.aten.minimum.default,
197199
exir_ops.edge.aten.sum.dim_IntList,
198200
exir_ops.edge.aten.add.Tensor,
201+
exir_ops.edge.aten.sub.Tensor,
199202
]:
200203
# The op has not altered the scale; the output scale is equal to
201204
# the operands' scales.

backends/arm/operators/op_sub.py

Lines changed: 8 additions & 97 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,6 @@
77

88
from typing import Any, List
99

10-
import executorch.backends.arm.tosa.quant_utils as tqutils
11-
import executorch.backends.arm.tosa.utils as tutils
1210
import tosa_serializer as ts
1311

1412
from executorch.backends.arm.operators.node_visitor import (
@@ -20,22 +18,20 @@
2018
validate_same_dtype,
2119
validate_valid_dtype,
2220
)
23-
from executorch.backends.arm.tosa import TosaSpecification
2421
from executorch.backends.arm.tosa.mapping import TosaArg
22+
from executorch.backends.arm.tosa.specification import TosaSpecification
2523
from torch.fx import Node
2624

2725

2826
@register_node_visitor
29-
class SubVisitor_INT(NodeVisitor):
27+
class SubVisitor(NodeVisitor):
3028
target = "aten.sub.Tensor"
3129

3230
tosa_specs = [
3331
TosaSpecification.create_from_string("TOSA-1.0+INT"),
32+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3433
]
3534

36-
def __init__(self, *args):
37-
super().__init__(*args)
38-
3935
def define_node(
4036
self,
4137
node: Node,
@@ -48,106 +44,21 @@ def define_node(
4844
validate_valid_dtype(
4945
self.target,
5046
[*inputs, output],
51-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
47+
[ts.DType.INT32, ts.DType.FP32],
5248
output.tosa_spec,
5349
)
5450

55-
scale_back = 1.0
56-
if inputs[0].dtype == ts.DType.INT8:
57-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
58-
tosa_graph, inputs, node, self.tosa_spec
59-
)
60-
elif inputs[0].dtype == ts.DType.INT16:
61-
rescaled_inputs, scale_back = (
62-
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
63-
tosa_graph, inputs, node, self.tosa_spec
64-
)
65-
)
66-
else:
67-
# input[0].dtype == ts.DType.INT32
68-
# Non quantized input, natively support by TOSA.SUB
69-
rescaled_inputs = inputs
70-
71-
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
72-
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
73-
sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
74-
else:
75-
# output.dtype == ts.DType.INT32
76-
sub_output = output
77-
78-
# Do the INT32 Sub
7951
attr = ts.TosaSerializerAttribute()
8052
attr.SubAttribute()
53+
8154
self._serialize_operator(
8255
node,
8356
tosa_graph,
8457
ts.Op.SUB,
8558
[
86-
rescaled_inputs[0].name,
87-
rescaled_inputs[1].name,
59+
inputs[0].name,
60+
inputs[1].name,
8861
],
89-
[sub_output.name],
62+
[output.name],
9063
attr,
9164
)
92-
93-
if output.dtype == ts.DType.INT8:
94-
# Scale output back to 8 bit
95-
# pyre-ignore
96-
tqutils.insert_rescale_op_to_int8(
97-
tosa_graph,
98-
sub_output,
99-
scale_back,
100-
node,
101-
compute_rescale=False,
102-
tosa_spec=self.tosa_spec,
103-
) # type: ignore[possibly-undefined]
104-
elif output.dtype == ts.DType.INT16:
105-
tqutils.insert_rescale_op_to_int16(
106-
tosa_graph,
107-
sub_output,
108-
scale_back,
109-
node,
110-
compute_rescale=False,
111-
tosa_spec=self.tosa_spec,
112-
) # type: ignore[possibly-undefined]
113-
114-
115-
@register_node_visitor
116-
class SubVisitor_FP(SubVisitor_INT):
117-
# inheriting 'target' from INT class
118-
119-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
120-
121-
def __init__(self, *args):
122-
super().__init__(*args)
123-
124-
def define_node(
125-
self,
126-
node: Node,
127-
tosa_graph: Any,
128-
inputs: List[TosaArg],
129-
output: TosaArg,
130-
) -> None:
131-
validate_num_inputs(self.target, inputs, 2)
132-
validate_same_dtype(self.target, [*inputs, output], ts)
133-
134-
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
135-
# Call the inherited define_node for handling integers
136-
super().define_node(node, tosa_graph, inputs, output)
137-
else:
138-
# FP32 Sub lowering
139-
validate_valid_dtype(
140-
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
141-
)
142-
143-
# MI lowering
144-
attr = ts.TosaSerializerAttribute()
145-
attr.SubAttribute()
146-
self._serialize_operator(
147-
node,
148-
tosa_graph,
149-
ts.Op.SUB,
150-
[inputs[0].name, inputs[1].name],
151-
[output.name],
152-
attr,
153-
)

backends/arm/test/ops/test_var.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,17 @@ def test_var_dim_tosa_INT_correction(test_data: Tuple):
344344
pipeline.run()
345345

346346

347-
@common.parametrize("test_data", VarCorrection.test_parameters)
347+
# TODO: Xfail "var_3d_dims_keep_dim_0_correction" until the Ethos-U Vela compiler ships commit
348+
# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that causes the test to fail.
349+
@common.parametrize(
350+
"test_data",
351+
VarCorrection.test_parameters,
352+
xfails={
353+
"var_3d_dims_keep_dim_0_correction": (
354+
"Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f"
355+
),
356+
},
357+
)
348358
@common.XfailIfNoCorstone300
349359
def test_var_dim_u55_INT_correction(test_data: Tuple):
350360
test_data, dim, keepdim, correction = test_data()

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class MultipleOpsModel(torch.nn.Module):
1919
input_t = Tuple[torch.Tensor, torch.Tensor]
2020

2121
def forward(self, x, y):
22-
a = x + y
22+
a = x - y
2323
b = x * a
2424
c = torch.maximum(a, b)
2525
d = torch.abs(b)

0 commit comments

Comments
 (0)