Skip to content

Commit

Permalink
Fix shape infer of onnx GroupNorm (#23477)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Fix shape infer of onnx GroupNorm.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Unable to run shape inference for onnx `GroupNorm`.


[model.onnx](https://raw.githubusercontent.com/onnx/onnx/refs/heads/main/onnx/backend/test/data/node/test_group_normalization_example/model.onnx)

> python
D:\source\cognition\onnxruntime\onnxruntime\python\tools\symbolic_shape_infer.py
--input model.onnx
Traceback (most recent call last):
File
"D:\source\cognition\onnxruntime\onnxruntime\python\tools\symbolic_shape_infer.py",
line 2999, in <module>
    out_mp = SymbolicShapeInference.infer_shapes(
File
"D:\source\cognition\onnxruntime\onnxruntime\python\tools\symbolic_shape_infer.py",
line 2935, in infer_shapes
    raise Exception("Incomplete symbolic shape inference")
  • Loading branch information
toothache authored Jan 26, 2025
1 parent 1fc9c48 commit 97c2bbe
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"GemmFastGelu": self._infer_GemmFastGelu,
"GemmFloat8": self._infer_GemmFloat8,
"GroupNorm": self._infer_GroupNorm,
"GroupNormalization": self._infer_GroupNorm,
"GroupQueryAttention": self._infer_GroupQueryAttention,
"LayerNormalization": self._infer_LayerNormalization,
"LongformerAttention": self._infer_LongformerAttention,
Expand Down Expand Up @@ -474,6 +475,7 @@ def _onnx_infer_single_node(self, node):
"PythonOp",
"MultiHeadAttention",
"GroupNorm",
"GroupNormalization",
"GroupQueryAttention",
"SparseAttention",
"SkipGroupNorm",
Expand Down

0 comments on commit 97c2bbe

Please sign in to comment.