Skip to content

Commit ffb50fb

Browse files
YibinLiu666pytorchmergebot
authored andcommitted
[ONNX] Add onnx::Gelu support for version 20 (pytorch#128773)
Fixes pytorch#128772 Pull Request resolved: pytorch#128773 Approved by: https://github.com/justinchuby
1 parent 3397d5e commit ffb50fb

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

test/onnx/test_utility_funs.py

+4
Original file line numberDiff line numberDiff line change
@@ -1358,6 +1358,8 @@ def forward(self, input, other):
13581358
iter = graph.nodes()
13591359
self.assertEqual(next(iter).kind(), "custom_namespace::custom_op")
13601360

1361+
# gelu is exported as onnx::Gelu for opset >= 20
1362+
@skipIfUnsupportedMaxOpsetVersion(19)
13611363
def test_custom_opsets_gelu(self):
13621364
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)
13631365

@@ -1382,6 +1384,8 @@ def gelu(g, self, approximate):
13821384
self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
13831385
self.assertEqual(graph.opset_import[1].version, 1)
13841386

1387+
# gelu is exported as onnx::Gelu for opset >= 20
1388+
@skipIfUnsupportedMaxOpsetVersion(19)
13851389
def test_register_aten_custom_op_symbolic(self):
13861390
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)
13871391

torch/onnx/symbolic_opset20.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
# EDITING THIS FILE? READ THIS FIRST!
3333
# see Note [Edit Symbolic Files] in symbolic_helper.py
3434

35-
__all__ = ["_grid_sampler", "_affine_grid_generator"]
35+
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
3636

3737

3838
def convert_grid_sample_mode(mode_s):
@@ -84,3 +84,10 @@ def _affine_grid_generator(
8484
size,
8585
align_corners_i=int(align_corners),
8686
)
87+
88+
89+
@_onnx_symbolic("aten::gelu")
90+
@symbolic_helper.parse_args("v", "s")
91+
@_beartype.beartype
92+
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
93+
return g.op("Gelu", self, approximate_s=approximate)

0 commit comments

Comments
 (0)