Skip to content

Commit b324f8b

Browse files
author
morelos
committed
Update base for Update on "[ET-VK][Ops] choose_qparams op shaders and impl"
Creating the choose_qparams per_tensor and per_token logic shaders and impl, which are linked with the testing framework. Differential Revision: [D76436933](https://our.internmc.facebook.com/intern/diff/D76436933/) [ghstack-poisoned]
2 parents efa2755 + 3b1c7fd commit b324f8b

File tree

64 files changed

+1552
-437
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+1552
-437
lines changed

.github/workflows/android-perf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ jobs:
342342
git clone https://github.com/huggingface/optimum-executorch
343343
pushd optimum-executorch
344344
# There is no release yet, for CI stability, always test from the same commit on main
345-
git checkout 1c653dc49812fc431a22312c7295d97005d22e12
346-
python install_dev.py
345+
git checkout 4c3b18f6cca68c5ccff809131d570062723d7188
346+
python install_dev.py --skip_override_torch
347347
pip list
348348
349349
ARGS=(

.github/workflows/apple-perf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,8 @@ jobs:
347347
git clone https://github.com/huggingface/optimum-executorch
348348
pushd optimum-executorch
349349
# There is no release yet, for CI stability, always test from the same commit on main
350-
git checkout 1c653dc49812fc431a22312c7295d97005d22e12
351-
${CONDA_RUN} python install_dev.py
350+
git checkout 4c3b18f6cca68c5ccff809131d570062723d7188
351+
${CONDA_RUN} python install_dev.py --skip_override_torch
352352
pip list
353353
354354
ARGS=(

.github/workflows/trunk.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,8 @@ jobs:
571571
git clone https://github.com/huggingface/optimum-executorch
572572
pushd optimum-executorch
573573
# There is no release yet, for CI stability, always test from the same commit on main
574-
git checkout 1c653dc49812fc431a22312c7295d97005d22e12
575-
pip install .[tests]
576-
pip install transformers==4.52.4
574+
git checkout 4c3b18f6cca68c5ccff809131d570062723d7188
575+
python install_dev.py --skip_override_torch
577576
popd
578577
pip list
579578
echo "::endgroup::"

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
3030
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
3131
from .decompose_linear_pass import DecomposeLinearPass # noqa
32+
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
3233
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
3334
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
3435
from .decompose_select import DecomposeSelectPass # noqa

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,12 @@
55

66
# pyre-unsafe
77

8-
from typing import cast
98

109
import torch
1110
from executorch.backends.arm._passes.arm_pass_utils import (
1211
create_node,
1312
get_first_fake_tensor,
14-
insert_q_dq_pair,
1513
)
16-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
1714
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1815
from executorch.exir.dialects._ops import ops as exir_ops
1916
from executorch.exir.pass_base import ExportPass, PassResult
@@ -59,20 +56,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
5956

6057
def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
6158
"""
62-
returns True for dq and w in the following sequences;
59+
returns True for w in the following sequence;
6360
w -> depthwise_conv2d -> ...
64-
w -> dq -> depthwise_conv2d -> ...
6561
"""
66-
if node.op == "call_function":
67-
if node.target != dq_op:
68-
return False
69-
prev_node = node.args[0]
70-
if cast(torch.fx.Node, prev_node).op != "placeholder":
71-
return False
72-
if is_consumer_node_depthwise_conv2d(node):
73-
consumer_node = list(node.users)[0]
74-
return consumer_node.args[1] == node
75-
elif node.op == "placeholder":
62+
if node.op == "placeholder":
7663
# node is an input, weight or bias node
7764
consumer_node = list(node.users)[0]
7865
if self.is_weight_node_for_depthwise_conv2d(consumer_node):
@@ -129,8 +116,6 @@ def is_channel_reshape(input_shape, output_shape):
129116

130117
@staticmethod
131118
def insert_input_transpose(node, input_node, graph_module):
132-
quantize = input_node.target == dq_op
133-
q_params = input_node.args[1:] if quantize else None
134119
with graph_module.graph.inserting_before(node):
135120
permute_node = create_node(
136121
graph_module.graph,
@@ -143,8 +128,6 @@ def insert_input_transpose(node, input_node, graph_module):
143128
else AnnotateChannelsLastDimOrder.NHWC_inverse_order
144129
),
145130
),
146-
quantize=quantize,
147-
q_params=q_params,
148131
)
149132
node.replace_input_with(input_node, permute_node)
150133

@@ -185,11 +168,6 @@ def insert_output_transpose(node, graph_module):
185168
for user in users:
186169
user.replace_input_with(node, permute_node)
187170

188-
quantize = node.args[0] == q_op
189-
if quantize:
190-
q_params = node.args[0].args[1:]
191-
insert_q_dq_pair(graph_module.graph, node, q_params)
192-
193171
@staticmethod
194172
def _insert_view_transpose(
195173
input_shape, output_shape, node, input_node, graph_module

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88
import itertools
99
import operator
10-
from typing import List
10+
from typing import cast, List
1111

1212
import torch
1313
from executorch.backends.arm._passes.arm_pass_utils import create_node
1414

15-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
15+
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
1616
from executorch.exir.dialects._ops import ops as exir_ops
17+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1718
from executorch.exir.pass_base import ExportPass, PassResult
1819
from torch.fx import GraphModule
1920
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -61,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6162
}
6263
for partition in matmul_partitions:
6364
quantized_input = all(
64-
input_node.target == dq_op for input_node in partition.input_nodes
65+
input_node.target in dq_ops for input_node in partition.input_nodes
6566
)
6667
matmul_node = [
6768
node for node in partition.nodes if node.target in matmul_targets
@@ -74,17 +75,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
7475
input_node = self._match_partition_to_node(
7576
node, partition.input_nodes
7677
)
77-
input_node_qargs = QuantArgs.from_operator(
78-
input_node.target, input_node.args
79-
)
8078
# Insert new dq-node just before the mm/bmm with input_node's qparams
8179
with graph_module.graph.inserting_before(matmul_node):
8280
# Create new dq-node before matmul
8381
dq_node = create_node(
8482
graph=graph_module.graph,
85-
op_target=dq_op,
83+
op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type]
8684
)
87-
dq_node.args = (node, *input_node_qargs)
85+
dq_node.args = (node, *input_node.args[1:])
8886
matmul_node.replace_input_with(node, dq_node)
8987

9088
for partition_input in partition.input_nodes:
@@ -95,19 +93,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
9593
graph_module.graph.erase_node(partition_input)
9694

9795
partition_output = list(partition.output_nodes[0].users)[0]
98-
quantized_output = partition_output.target == q_op
96+
quantized_output = partition_output.target in q_ops
9997
if quantized_output:
100-
output_node_qargs = QuantArgs.from_operator(
101-
partition_output.target, partition_output.args
102-
)
10398
with graph_module.graph.inserting_after(matmul_node):
10499
# Create q-node after matmul
105100
q_node = create_node(
106101
graph=graph_module.graph,
107-
op_target=q_op,
102+
op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type]
108103
)
109104
matmul_node.replace_all_uses_with(q_node)
110-
q_node.args = (matmul_node, *output_node_qargs)
105+
q_node.args = (matmul_node, *partition_output.args[1:])
111106
# Remove partition output q-node
112107
partition_output.replace_all_uses_with(
113108
partition_output.all_input_nodes[0]

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DecomposeLeakyReLUPass,
3333
DecomposeLinearPass,
3434
DecomposeLinearVectorNormPass,
35+
DecomposeMaxPool2DPass,
3536
DecomposeMeanDimPass,
3637
DecomposeNotEqualPass,
3738
DecomposeSelectPass,
@@ -92,7 +93,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
9293
self.add_pass(RemoveGetItemPass())
9394
self.add_pass(ConvertSplitToSlicePass())
9495
self.add_pass(ConvertMmToBmmPass())
95-
self.add_pass(DecomposeLinearPass())
9696
self.add_pass(DecomposeLinearVectorNormPass())
9797
self.add_pass(
9898
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
@@ -108,12 +108,13 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
108108
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
109109
self.add_pass(AnnotateDecomposedMatmulPass())
110110
self.add_pass(QuantizeOperatorArguments())
111-
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
111+
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
112112
self.add_pass(RetraceFoldedDtypesPass())
113113
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
114114
self.add_pass(MatchArgRanksPass(exported_program))
115115
if self.tosa_spec.is_U55_subset:
116116
self.add_pass(BroadcastArgsPass())
117+
self.add_pass(DecomposeLinearPass())
117118
self.add_pass(ComputeConstantOpsAOT(exported_program))
118119

119120
self.add_pass(RemoveClonePass())
@@ -123,6 +124,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
123124
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
124125
self.add_pass(DecomposeSumPass())
125126
self.add_pass(Conv1dUnsqueezePass())
127+
self.add_pass(DecomposeMaxPool2DPass())
126128
self.add_pass(DecomposeSelectPass())
127129
self.add_pass(ConvertSqueezesToViewPass())
128130

@@ -166,7 +168,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
166168

167169
self.add_pass(AnnotateDecomposedMatmulPass())
168170
self.add_pass(QuantizeOperatorArguments())
169-
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
171+
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
170172
self.add_pass(RetraceFoldedDtypesPass())
171173
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
172174
self.add_pass(MatchArgRanksPass(exported_program))
@@ -179,6 +181,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
179181
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
180182
self.add_pass(DecomposeSumPass())
181183
self.add_pass(Conv1dUnsqueezePass())
184+
self.add_pass(DecomposeMaxPool2DPass())
182185
self.add_pass(DecomposeSelectPass())
183186
self.add_pass(ConvertSqueezesToViewPass())
184187

backends/arm/_passes/cast_int64_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
3535

3636
def _to_int32(self, graph_module: torch.fx.GraphModule):
3737
for node in graph_module.graph.nodes:
38+
if len(node.users) == 0:
39+
continue
3840
fake_tensor = node.meta["val"]
3941
if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
4042
continue

backends/arm/_passes/decompose_linalg_vector_norm_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@ def call_operator(self, op, args, kwargs, meta):
5151
f"is not supported for linalg_vector_norm operator"
5252
)
5353

54+
# Sum over all dimensions if dim is None
5455
if norm_dim is None:
55-
raise ValueError("The norm_dim for linalg_vector_norm is None.")
56-
57-
dims = [norm_dim] if isinstance(norm_dim, int) else list(norm_dim)
56+
rank = input_tensor.data.dim()
57+
dims = list(range(rank))
58+
else:
59+
dims = [norm_dim] if isinstance(norm_dim, int) else list(norm_dim)
5860

5961
# Decomposition based on norm order.
6062
if norm_order == 1:

backends/arm/_passes/decompose_linear_pass.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,28 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

98
import numpy as np
9+
from executorch.backends.arm._passes import ArmPass
1010
from executorch.backends.arm._passes.arm_pass_utils import (
1111
create_node,
1212
get_first_fake_tensor,
1313
)
14-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
1514
from executorch.exir.dialects._ops import ops as exir_ops
16-
from executorch.exir.pass_base import ExportPass, PassResult
15+
from executorch.exir.pass_base import PassResult
1716

1817

19-
class DecomposeLinearPass(ExportPass):
18+
class DecomposeLinearPass(ArmPass):
2019
"""
2120
This pass decomposes linear into a Conv2D with the required view operations.
2221
linear(x, weights, bias) becomes:
2322
x_reshaped = view(x)
2423
weights_reshaped = view(weights)
2524
conv2d = conv2d(x_reshaped, weights_reshaped, bias)
2625
output = view(conv2d)
27-
It also inserts q/dq pairs if the linear node was quantized.
2826
"""
2927

3028
def call(self, graph_module):
@@ -47,35 +45,22 @@ def call(self, graph_module):
4745
weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]
4846

4947
with graph_module.graph.inserting_before(node):
50-
quantize = input.op == "call_function" and input.target == dq_op
51-
q_params = input.args[1:] if quantize else None
5248
# Reshape input to 4D with shape (N, Ci, 1, 1)
5349
input_reshaped = create_node(
5450
graph=graph_module.graph,
5551
op_target=exir_ops.edge.aten.view_copy.default,
5652
args=(input, input_reshaped_shape),
5753
kwargs={},
58-
quantize=quantize,
59-
q_params=q_params,
6054
)
6155

62-
quantize = weights.op == "call_function" and weights.target == dq_op
63-
q_params = weights.args[1:] if quantize else None
6456
# Reshape weights to 4D with shape (Co, Ci, 1, 1)
6557
weights_reshaped = create_node(
6658
graph=graph_module.graph,
6759
op_target=exir_ops.edge.aten.view_copy.default,
6860
args=(weights, weights_reshaped_shape),
6961
kwargs={},
70-
quantize=quantize,
71-
q_params=q_params,
7262
)
7363

74-
consumer_node = list(node.users)[0]
75-
quantize = (
76-
consumer_node.op == "call_function" and consumer_node.target == q_op
77-
)
78-
q_params = consumer_node.args[1:] if quantize else None
7964
conv = create_node(
8065
graph=graph_module.graph,
8166
op_target=exir_ops.edge.aten.convolution.default,
@@ -91,8 +76,7 @@ def call(self, graph_module):
9176
1, # groups
9277
),
9378
kwargs={},
94-
quantize=quantize,
95-
q_params=q_params,
79+
from_node=node,
9680
)
9781

9882
with graph_module.graph.inserting_after(conv):

0 commit comments

Comments
 (0)