Skip to content

Commit 4559a61

Browse files
authored
Qualcomm AI Engine Direct - Mimi Enablement Stage 2 (#10098)
### Summary - Support Mimi Encoder - Support Mimi Decoder - Support OP: CDist - Workaround for the 0D tensor scenario, since QNN does not support 0D tensors. Commands to execute the model: 10 sec as a batch (1 inference for 10 sec of audio): `python examples/qualcomm/oss_scripts/moshi/mimi.py -b build-android -s $DEVICE -m SM8650 --chunks_per_batch 125` 80 ms as a batch (125 inferences for 10 sec of audio): `python examples/qualcomm/oss_scripts/moshi/mimi.py -b build-android -s $DEVICE -m SM8650` ### Test plan - UT for CDist - UT for 0D tensor
1 parent 6d1caca commit 4559a61

17 files changed

+701
-45
lines changed

backends/qualcomm/_passes/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
1212
from .convert_upsample_bicubic2d import ConvertUpsampleBicubicWithBilinear
1313
from .decompose_any import DecomposeAny
14+
from .decompose_cdist import DecomposeCDist
1415
from .decompose_einsum import DecomposeEinsum
1516
from .decompose_expm1 import DecomposeExpM1
1617
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
@@ -27,6 +28,7 @@
2728
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
2829
from .recompose_rms_norm import RecomposeRmsNorm
2930
from .reduce_dynamic_range import ReduceDynamicRange
31+
from .remove_0d_tensor import Remove0DTensor
3032
from .remove_redundancy import RemoveRedundancy
3133
from .replace_arange_args import ReplaceArangeArgs
3234
from .replace_index_put_input import ReplaceIndexPutInput
@@ -40,8 +42,9 @@
4042
AnnotateUnbind,
4143
ConvertBmmToMatmul,
4244
ConvertConv1dToConv2d,
43-
DecomposeAny,
4445
ConvertUpsampleBicubicWithBilinear,
46+
DecomposeAny,
47+
DecomposeCDist,
4548
DecomposeEinsum,
4649
DecomposeExpM1,
4750
DecomposeLinalgVectorNorm,
@@ -58,6 +61,7 @@
5861
RecomposePixelUnshuffle,
5962
RecomposeRmsNorm,
6063
ReduceDynamicRange,
64+
Remove0DTensor,
6165
RemoveRedundancy,
6266
ReplaceArangeArgs,
6367
ReplaceIndexPutInput,

backends/qualcomm/_passes/annotate_stack.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@ class AnnotateStack(ExportPass):
1717
generated after quantization process.
1818
"""
1919

20-
decomp_ops = [torch.ops.aten.unbind.int]
20+
decomp_ops = [torch.ops.aten.stack.default]
2121

2222
def __init__(self, edge_program: torch.export.ExportedProgram):
2323
super(AnnotateStack, self).__init__()
2424
self.edge_program = edge_program
2525

2626
def _annotate_stack(self, graph_module: torch.fx.GraphModule):
27-
partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"])
27+
partitions = get_source_partitions(
28+
graph_module.graph, [torch.stack, torch.ops.aten.stack.default, "stack"]
29+
)
2830
for _, src_partitions in partitions.items():
2931
for src_partition in src_partitions:
3032
output = src_partition.output_nodes[0]

backends/qualcomm/_passes/annotate_unbind.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
2424
self.edge_program = edge_program
2525

2626
def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
27-
partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"])
27+
partitions = get_source_partitions(
28+
graph_module.graph, [torch.unbind, torch.ops.aten.unbind.int, "unbind"]
29+
)
2830
for _, src_partitions in partitions.items():
2931
for src_partition in src_partitions:
3032
if src_partition.input_nodes[0].target in dq_ops:

backends/qualcomm/_passes/convert_conv1d_to_conv2d.py

+10
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torch.nn as nn
99
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
10+
from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
1011
from executorch.exir.dialects._ops import ops as exir_ops
1112
from executorch.exir.pass_base import ExportPass, PassResult
1213

@@ -43,6 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule):
4344
unsqueeze_node.meta = copy_meta(
4445
input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
4546
)
47+
4648
with graph_module.graph.inserting_after(unsqueeze_node):
4749

4850
filter_node = node.args[1]
@@ -92,6 +94,14 @@ def call(self, graph_module: torch.fx.GraphModule):
9294
),
9395
)
9496
squeeze_node.meta = copy_meta(node.meta)
97+
98+
if QCOM_REQUANTIZE in input_node.meta:
99+
input_node.meta.pop(QCOM_REQUANTIZE)
100+
if QCOM_REQUANTIZE in node.meta:
101+
squeeze_node.meta[QCOM_REQUANTIZE] = node.meta[
102+
QCOM_REQUANTIZE
103+
]
104+
conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
95105
for user in node.users.copy():
96106
user.replace_input_with(node, squeeze_node)
97107
graph.eliminate_dead_code()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
11+
class CDist(torch.nn.Module):
    """Reference module computing pairwise Euclidean (p=2) distances.

    It is built only from unsqueeze, subtraction, power, sum and sqrt so that
    exporting it yields a graph of primitive ops equivalent to cdist with p=2.
    """

    def forward(self, x, y):
        # Insert singleton axes so the rows of x broadcast against the rows of y:
        # (..., P, 1, M) - (..., 1, R, M) -> (..., P, R, M)
        pairwise_delta = torch.unsqueeze(x, -2) - torch.unsqueeze(y, -3)

        # Squared L2 norm of every difference vector, reduced over the feature axis.
        squared_norm = pairwise_delta.pow(2).sum(dim=-1)

        # The Euclidean distance is the root of the squared norm.
        return squared_norm.sqrt()
29+
30+
31+
class DecomposeCDist(ExportPass):
    """
    Decompose aten.cdist into mathematically equivalent primitive ops.

    Every ``torch.ops.aten.cdist.default`` node is replaced by a copy of the graph
    obtained from exporting the ``CDist`` reference module with the node's own
    input values. Only the Euclidean case (p=2) is supported.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Replace all cdist nodes in ``graph_module`` and return the pass result.

        Raises:
            AssertionError: if a cdist node requests a norm other than p=2.
        """
        graph = graph_module.graph
        # The reference module is stateless, so a single instance serves every node.
        model = CDist()
        # Walk a snapshot: the loop body inserts and erases nodes in this graph.
        for node in list(graph.nodes):
            if node.target != torch.ops.aten.cdist.default:
                continue

            # p may arrive positionally or as a keyword; when absent the
            # operator default (2) applies.
            p = node.args[2] if len(node.args) > 2 else node.kwargs.get("p", 2)
            assert (
                p == 2
            ), "Currently only p=2 is supported for CDist Decomposition"

            decomposed_module = torch.export.export(
                model,
                (node.args[0].meta["val"], node.args[1].meta["val"]),
                strict=True,
            ).module()
            with graph.inserting_before(node):
                # remap is used to map original node values to new node values,
                # which ensures that reference to nodes are correctly updated in the new graph
                remap = {"x": node.args[0], "y": node.args[1]}

                for decomposed_node in decomposed_module.graph.nodes:
                    # no need to copy existent 'output'
                    if decomposed_node.op == "output":
                        for user in node.users.copy():
                            # redirect users of cdist to the copied result node
                            user.replace_input_with(
                                node,
                                remap[decomposed_node.args[0][0]],
                            )
                    # no need to copy existent placeholders
                    elif decomposed_node.op == "placeholder":
                        # replace node map from string to graph node
                        remap[decomposed_node] = remap.pop(decomposed_node.name)
                    else:
                        remap[decomposed_node] = graph.node_copy(
                            decomposed_node,
                            arg_transform=lambda x, remap=remap: remap[x],
                        )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/lift_constant_scalar_operands.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@ class TensorOpInfo:
5353
}
5454

5555

56-
SKIP_LIFT_OPS = {aten.full_like.default, aten.arange.start_step}
56+
SKIP_LIFT_OPS = {
57+
aten.full_like.default,
58+
aten.arange.start_step,
59+
aten.arange.default,
60+
aten.scalar_tensor.default,
61+
aten.elu.default,
62+
}
5763

5864

5965
class LiftConstantScalarOperands(ExportPass):

backends/qualcomm/_passes/qnn_pass_manager.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ConvertConv1dToConv2d,
1717
ConvertUpsampleBicubicWithBilinear,
1818
DecomposeAny,
19+
DecomposeCDist,
1920
DecomposeEinsum,
2021
DecomposeExpM1,
2122
DecomposeLinalgVectorNorm,
@@ -32,6 +33,7 @@
3233
RecomposePixelUnshuffle,
3334
RecomposeRmsNorm,
3435
ReduceDynamicRange,
36+
Remove0DTensor,
3537
RemoveRedundancy,
3638
ReplaceArangeArgs,
3739
ReplaceIndexPutInput,
@@ -71,7 +73,7 @@ def get_capture_program_passes():
7173
# If a pass is activated, it will be executed by default.
7274
default_passes_and_setting = [
7375
(AnnotateQuantAttrs, True),
74-
(AnnotateStack, False),
76+
(AnnotateStack, True),
7577
(AnnotateUnbind, True),
7678
(ConvertBmmToMatmul, True),
7779
(ConvertConv1dToConv2d, True),
@@ -84,6 +86,7 @@ def get_capture_program_passes():
8486
(LayoutTransform, True),
8587
(RecomposePixelUnshuffle, True),
8688
(RecomposeRmsNorm, False),
89+
(Remove0DTensor, True),
8790
(RemoveRedundancy, True),
8891
(ReplaceIndexPutInput, True),
8992
(TagQuantIO, False),
@@ -176,7 +179,23 @@ def transform_for_to_edge_pipeline(
176179

177180
return exported_program
178181

182+
# Before quantizer
183+
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
184+
self.add_pass(ReduceDynamicRange())
185+
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
186+
self.add_pass(ReplaceArangeArgs())
187+
self.add_pass(DecomposeCDist())
188+
self.add_pass(DecomposeScaledDotProductAttention())
189+
self.add_pass(DecomposeSilu())
190+
self.add_pass(DecomposeEinsum())
191+
self.add_pass(DecomposeExpM1())
192+
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
193+
self.add_pass(ReplaceInfValues())
194+
self.add_pass(LiftConstantScalarOperands())
195+
return self._transform(graph_module)
196+
179197
def transform_for_export_pipeline(self, exported_program: ExportedProgram):
198+
self.add_pass(DecomposeCDist())
180199
self.add_pass(DecomposeScaledDotProductAttention())
181200
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
182201
self.add_pass(DecomposeExpM1())
@@ -191,16 +210,3 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
191210
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
192211
self.add_pass(FuseConsecutiveTranspose())
193212
return self._transform(exported_program.graph_module)
194-
195-
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
196-
self.add_pass(ReduceDynamicRange())
197-
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
198-
self.add_pass(ReplaceArangeArgs())
199-
self.add_pass(DecomposeScaledDotProductAttention())
200-
self.add_pass(DecomposeSilu())
201-
self.add_pass(DecomposeEinsum())
202-
self.add_pass(DecomposeExpM1())
203-
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
204-
self.add_pass(ReplaceInfValues())
205-
self.add_pass(LiftConstantScalarOperands())
206-
return self._transform(graph_module)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
11+
12+
class Remove0DTensor(ExportPass):
    """
    QNN does not allow 0D tensors, so nodes that would output a 0D tensor are removed.
    Users of a removed node are rewired to that node's first argument instead.
    Before adding an operation to remove_ops, make sure that bypassing it does not change the model logic.
    """

    # Ops that are bypassed when their output is a 0D tensor.
    remove_ops = {
        exir_ops.edge.aten.select.int,
        exir_ops.edge.aten.select_copy.int,
    }

    def __init__(self, quantization_capture=False) -> None:
        # quantization_capture is accepted but not used by this pass.
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        fx_graph = graph_module.graph
        for candidate in fx_graph.nodes:
            if candidate.target not in self.remove_ops:
                continue
            # Only outputs with an empty shape (rank 0) are affected.
            if len(candidate.meta["val"].shape) != 0:
                continue
            # Copy the user list first because rewiring mutates candidate.users.
            for consumer in list(candidate.users):
                consumer.replace_input_with(candidate, candidate.args[0])
            fx_graph.erase_node(candidate)

        fx_graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/partition/qnn_partitioner.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
not_supported_operator,
3535
to_be_implemented_operator,
3636
)
37-
from .utils import generate_qnn_executorch_option, get_skip_decomp_table
37+
from .utils import filter_fn, generate_qnn_executorch_option, get_skip_decomp_table
3838

3939

4040
class QnnOperatorSupport(OperatorSupportBase):
@@ -181,5 +181,4 @@ def ops_to_not_decompose(
181181
self, ep: ExportedProgram
182182
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
183183
do_not_decompose = get_skip_decomp_table()
184-
185-
return do_not_decompose, None
184+
return (do_not_decompose, filter_fn)

backends/qualcomm/partition/utils.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,21 @@ def generate_qnn_executorch_option(
2424
return qnn_compile_spec_buffer
2525

2626

27+
# Per-node check that decides whether decomposition is skipped; it takes priority over get_skip_decomp_table().
28+
def filter_fn(node: torch.fx.Node) -> bool:
    """Return True when ``node`` should be kept un-decomposed.

    QNN does not support int32/int64 IO for stack and unbind, so those nodes
    return False when any of their outputs has one of these dtypes. Every
    other node returns True.

    Args:
        node: graph node whose ``meta["val"]`` holds its output value(s).

    Returns:
        False for a stack/unbind node with an int32/int64 output, else True.
    """
    # QNN does not support int32/int64 IO for the following OPs.
    potential_i32_i64_io_ops = (
        torch.ops.aten.stack.default,
        torch.ops.aten.unbind.int,
    )
    if node.target not in potential_i32_i64_io_ops:
        return True

    # Multi-output ops such as unbind carry a list/tuple of values in
    # meta["val"]; normalize so both layouts are inspected the same way.
    val = node.meta["val"]
    outputs = val if isinstance(val, (list, tuple)) else (val,)
    unsupported_dtypes = (torch.int32, torch.int64)
    return not any(
        getattr(output, "dtype", None) in unsupported_dtypes for output in outputs
    )
40+
41+
2742
def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
2843
do_not_decompose = [
2944
torch.ops.aten.adaptive_avg_pool2d.default,
@@ -41,7 +56,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
4156
torch.ops.aten.stack.default,
4257
torch.ops.aten.upsample_bicubic2d.vec,
4358
# This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
44-
# torch.ops.aten.unbind.int,
59+
torch.ops.aten.unbind.int,
4560
torch.ops.pt2e_quant.quantize_affine.default,
4661
torch.ops.pt2e_quant.dequantize_affine.default,
4762
]

0 commit comments

Comments
 (0)