Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion backends/test/suite/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
UNSUPPORTED_PORTABLE_OPS = {
"aten::_embedding_bag",
"aten::_adaptive_avg_pool2d",
"aten::adaptive_max_pool2d",
"aten::median",
"aten::median.dim",
"aten::round.decimals",
Expand All @@ -34,6 +35,7 @@
TestResult,
)
from executorch.exir import EdgeProgramManager
from executorch.exir.dialects._ops import ops as exir_ops


# A list of all runnable test suites and the corresponding python package.
Expand All @@ -43,6 +45,24 @@
}


def _graph_has_unsupported_patterns(program: torch.export.ExportedProgram) -> bool:
# Returns true if the model contains patterns that will fail when running on the ET
# portable kernel library.

# Check for 3d convolutions. All convs (1d, 2d, 3d) use the same op, so we need to look at
# the input meta to determine the rank.
for node in program.graph.nodes:
if (
node.op == "call_function"
and node.target == exir_ops.edge.aten.convolution.default
):
in_rank = node.args[0].meta["val"].dim()
if in_rank != 4:
return True

return False


def _get_test_seed(test_base_name: str) -> int:
# Set the seed based on the test base name to give consistent inputs between backends. Add the
# run seed to allow for reproducible results, but still allow for run-to-run variation.
Expand Down Expand Up @@ -162,7 +182,7 @@ def build_result(
# Check if any undelegated ops are in the unsupported ops set.
has_unsupported_ops = any(
op in UNSUPPORTED_PORTABLE_OPS for op in undelegated_op_counts.keys()
)
) or _graph_has_unsupported_patterns(edge_manager._etrecord.edge_dialect_program)

# Skip the test if there are unsupported portable ops remaining.
if has_unsupported_ops:
Expand Down
Loading