
Commit e6835ec

[cortex-m] initial commit

Pull Request resolved: #10265

Just placeholder q/dq AoT ops in a new namespace, with a test.

ghstack-source-id: 279606183
@exported-using-ghexport
Differential Revision: [D72987759](https://our.internmc.facebook.com/intern/diff/D72987759/)

Parent: b5e8ee9

File tree

7 files changed: +426 −0 lines


backends/cortex_m/README.md

+3

@@ -0,0 +1,3 @@

# Cortex-M Backend

WIP. This is a temporary/placeholder backend for Cortex-M CPUs. It is not intended to be used in production, but rather as a proof of concept. Things will change without notice.

backends/cortex_m/ops/TARGETS

+21

@@ -0,0 +1,21 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")

oncall("executorch")

python_library(
    name = "ops",
    srcs = [
        "operators.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
    ],
)

backends/cortex_m/ops/operators.py

+98

@@ -0,0 +1,98 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import (
    ops as exir_ops,
)  # To provide the implementation of the operators
from torch.library import impl, Library, register_fake

# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")

###
# quantize_per_tensor
###

lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

lib.define(
    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=dtype)


@impl(lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    The implementation of the quantize_per_tensor operator is the same as the
    quantize_per_tensor operator in the edge dialect.
    """
    return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )


###
# dequantize_per_tensor
###

lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
lib.define(
    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Dequantization always produces a float tensor, regardless of the input dtype.
    return torch.empty_like(input, dtype=torch.float)


@impl(lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    The implementation of the dequantize_per_tensor operator is the same as the
    dequantize_per_tensor operator in the edge dialect.
    """
    return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
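
Once these registrations are imported, the new ops are callable eagerly through torch.ops.cortex_m. A minimal smoke-test sketch, not part of the commit (values are illustrative; it assumes the quantized_decomposed ops backing the implementations are registered, e.g. by importing torch.ao.quantization.fx._decomposed):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers quantized_decomposed ops)
import executorch.backends.cortex_m.ops.operators  # noqa: F401  (registers cortex_m ops)

scale, zero_point = 1 / 127, 0
x = torch.linspace(-1.0, 1.0, steps=8)
q = torch.ops.cortex_m.quantize_per_tensor(x, scale, zero_point, -128, 127, torch.int8)
d = torch.ops.cortex_m.dequantize_per_tensor(q, scale, zero_point, -128, 127, torch.int8)
# Round-trip error is bounded by half a quantization step.
assert torch.allclose(x, d, atol=scale / 2)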

backends/cortex_m/passes/TARGETS

+21

@@ -0,0 +1,21 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
    name = "replace_quant_nodes_pass",
    srcs = ["replace_quant_nodes_pass.py"],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
backends/cortex_m/passes/replace_quant_nodes_pass.py

+62

@@ -0,0 +1,62 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, Tuple

import executorch.backends.cortex_m.ops.operators  # noqa
import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue


class ReplaceQuantNodesPass(ExportPass):
    """
    Replace quantize and dequantize nodes with the corresponding
    cortex_m.quantize_per_tensor and cortex_m.dequantize_per_tensor nodes.
    """

    @staticmethod
    def _is_qualified_int8_node(args) -> bool:
        return (
            args[3] >= torch.iinfo(torch.int8).min  # qmin
            and args[4] <= torch.iinfo(torch.int8).max  # qmax
            and args[5] == torch.int8  # dtype
        )

    def __init__(self):
        super().__init__()
        self.op_replacements = {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: {
                "new_target": exir_ops.edge.cortex_m.quantize_per_tensor.default,
                "qualifier": self._is_qualified_int8_node,
            },
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: {
                "new_target": exir_ops.edge.cortex_m.dequantize_per_tensor.default,
                "qualifier": self._is_qualified_int8_node,
            },
        }

    def call_operator(
        self,
        op: Callable[..., object],
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
        meta: NodeMetadata,
    ) -> ProxyValue:
        assert isinstance(
            op, EdgeOpOverload
        ), "Op must be an EdgeOpOverload. Run this pass after to_edge()."

        if op in self.op_replacements and self.op_replacements[op]["qualifier"](args):
            return super().call_operator(
                self.op_replacements[op]["new_target"],
                args,
                kwargs,
                meta,
            )
        return super().call_operator(op, args, kwargs, meta)
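
For context, the pass plugs into the standard exir flow after to_edge(). A minimal wiring sketch, not part of the commit (the toy module is illustrative; replacements only fire on graphs that actually contain qualifying int8 quantized_decomposed nodes, e.g. from a PT2E quantization flow):

import torch
from executorch.exir import to_edge
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
    ReplaceQuantNodesPass,
)

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + x

exported = torch.export.export(Toy().eval(), (torch.randn(4),))
edge = to_edge(exported)
# The pass must run after to_edge(), as its EdgeOpOverload assert enforces;
# transform() applies it to every method in the edge program.
edge = edge.transform([ReplaceQuantNodesPass()])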

backends/cortex_m/test/TARGETS

+18

@@ -0,0 +1,18 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

python_unittest(
    name = "test_replace_quant_nodes",
    srcs = ["test_replace_quant_nodes.py"],
    deps = [
        "//pytorch/ao:torchao",  # @manual
        "//caffe2:torch",
        "//executorch/backends/cortex_m/passes:replace_quant_nodes_pass",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
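
The test source itself (test_replace_quant_nodes.py, the seventh changed file) is not shown on this page. Purely as a hypothetical illustration of what this target could exercise, the qualifier logic above might be unit-tested along these lines:

import unittest

import torch
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
    ReplaceQuantNodesPass,
)

class TestReplaceQuantNodes(unittest.TestCase):
    def test_int8_qualifier(self):
        # int8 bounds and dtype should qualify for replacement.
        args = (None, 0.1, 0, -128, 127, torch.int8)
        self.assertTrue(ReplaceQuantNodesPass._is_qualified_int8_node(args))
        # An int16 range should not.
        args = (None, 0.1, 0, -32768, 32767, torch.int16)
        self.assertFalse(ReplaceQuantNodesPass._is_qualified_int8_node(args))

if __name__ == "__main__":
    unittest.main()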
