Skip to content

Milestone3.1: Support tanh op in XNNPACK backend #11364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/xnnpack/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,6 @@
op_static_constant_pad,
op_static_resize_bilinear_2d,
op_sub,
op_tanh,
op_to_copy,
)
52 changes: 52 additions & 0 deletions backends/xnnpack/operators/op_tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNGraph,
XNNTanh,
XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class TanhVisitor(NodeVisitor):
target = "aten.tanh.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
xnn_graph: XNNGraph,
vals_to_ids: Dict[torch.fx.Node, int],
debug_handle: int,
) -> None:
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

# input
input_id = vals_to_ids[get_input_node(node, 0)]

# output
output_id = vals_to_ids[node]

ser_node = XNode(
xnode_union=XNNTanh(
input_id=input_id,
output_id=output_id,
flags=0,
),
debug_handle=debug_handle,
)
xnn_graph.xnodes.append(ser_node)
2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
SoftmaxConfig,
SquareRootConfig,
SubConfig,
TanhConfig,
UpsampleBilinear2dConfig,
)
from executorch.backends.xnnpack.partition.config.node_configs import (
Expand Down Expand Up @@ -99,6 +100,7 @@
PreluConfig,
ReciprocalSquareRootConfig,
ReLUConfig,
TanhConfig,
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
SigmoidConfig,
SliceCopyConfig,
Expand Down
7 changes: 7 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class TanhConfig(GenericNodePartitionerConfig):
target_name = "tanh.default"

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class MeanDimConfig(GenericNodePartitionerConfig):
target_name = "mean.dim"

Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/partition/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.gelu.default,
exir_ops.edge.aten.tanh.default,
]

SUPPORTED_MODULES = [
Expand Down
31 changes: 31 additions & 0 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,36 @@ Error defineGeluNode(
return Error::Ok;
}

/*
Define serialized tanh node into the subgraph, using the remapped ids
to map the serialized ids, to the new ids generated when defining the
tensor value
*/
Error defineTanhNode(
xnn_subgraph_t subgraph_ptr,
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
const NodePtr node,
const fb_xnnpack::XNNGraph* graph) noexcept {
MAYBE_UNUSED(graph);

auto graph_node = node->xnode_union_as_XNNTanh();

xnn_status status = xnn_define_tanh(
subgraph_ptr,
remapped_ids.at(graph_node->input_id()),
remapped_ids.at(graph_node->output_id()),
graph_node->flags());

ET_CHECK_OR_RETURN_ERROR(
status == xnn_status_success,
Internal,
"Failed to create tanh node %i with code: %s",
node->debug_handle(),
xnn_status_to_string(status));

return Error::Ok;
}

/*
Define serialized ceiling node into the subgraph, using the remapped ids
to map the serialized ids, to the new ids generated when defining the
Expand Down Expand Up @@ -2078,6 +2108,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
_DEFINE(Hardswish)
_DEFINE(LeakyReLU)
_DEFINE(Log)
_DEFINE(Tanh)
_DEFINE(Maximum)
_DEFINE(Negate)
_DEFINE(Square)
Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/runtime_schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ union XNodeUnion {
XNNReciprocalSquareRoot: _XNNNode1x1,
XNNLog: _XNNNode1x1,
XNNGelu: _XNNNode1x1,
XNNTanh: _XNNNode1x1,
}

union XValueUnion {
Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ union XNodeUnion {
XNNReciprocalSquareRoot: _XNNNode1x1,
XNNLog: _XNNNode1x1,
XNNGelu: _XNNNode1x1,
XNNTanh: _XNNNode1x1,
}

union XValueUnion {
Expand Down
6 changes: 6 additions & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ class XNNLog(XNNNode1x1):
pass


@dataclass
class XNNTanh(XNNNode1x1):
pass


@dataclass
class XNNMaximum(XNNNode2x1):
pass
Expand Down Expand Up @@ -391,6 +396,7 @@ class XNNScaledDotProductAttention:
XNNReciprocalSquareRoot,
XNNLog,
XNNGelu,
XNNTanh,
]


Expand Down
43 changes: 43 additions & 0 deletions backends/xnnpack/test/ops/test_tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestTanh(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()

class Tanh(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.tanh(x)

def run_tanh_test(self, inputs):
(
Tester(self.Tanh(), inputs)
.export()
.check_count({"torch.ops.aten.tanh.default": 1})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_fp16_tanh(self):
inputs = (torch.randn(20).to(torch.float16),)
self.run_tanh_test(inputs)

def test_fp32_tanh(self):
inputs = (torch.randn(20),)
self.run_tanh_test(inputs)
Loading