Skip to content

Commit 4f572c5

Browse files
authored
[MLIR][TORCH] Support unknown sized input during decomposition of aten.as_strided op (#4303)
#4269 adds the e2e support for the aten.as_strided op by decomposing it into a series of other torch operations. This change extents #4269 to relax a restrictive check and support unknown sized inputs during decomposition
1 parent e8733c7 commit 4f572c5

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12664,8 +12664,8 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
1266412664
Value input = op.getSelf();
1266512665
auto inputType = dyn_cast<BaseTensorType>(input.getType());
1266612666

12667-
if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
12668-
return rewriter.notifyMatchFailure(op, "input must have known sizes");
12667+
if (!inputType || !inputType.hasSizes())
12668+
return rewriter.notifyMatchFailure(op, "input must have sizes");
1266912669

1267012670
SmallVector<int64_t> sizesInts;
1267112671
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizesInts)))
@@ -12695,8 +12695,13 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
1269512695
// If the input is not a 1-d tensor, we need to flatten it
1269612696
// to a 1D tensor before applying the strided indexing.
1269712697
int64_t flattenedInputSize = 1;
12698-
for (int64_t size : inputSizes)
12698+
for (int64_t size : inputSizes) {
12699+
if (size == kUnknownSize) {
12700+
flattenedInputSize = kUnknownSize;
12701+
break;
12702+
}
1269912703
flattenedInputSize *= size;
12704+
}
1270012705

1270112706
auto flattenedInputTy =
1270212707
cast<BaseTensorType>(inputType.getWithSizesAndDtype(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,7 @@
984984
"MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
985985
"AtenAsStridedModule_basic",
986986
"AtenAsStridedNoStorageOffsetModule_basic",
987+
"AtenAsStridedUnknownSizeModule_basic",
987988
# error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
988989
"ElementwiseClampMaxModule_bfloat16",
989990
"ElementwiseClampMinModule_bfloat16",
@@ -3991,6 +3992,7 @@
39913992
"ReplicationPad3dModuleSingleIntPad_basic",
39923993
"AtenAsStridedModule_basic",
39933994
"AtenAsStridedNoStorageOffsetModule_basic",
3995+
"AtenAsStridedUnknownSizeModule_basic",
39943996
"ChunkListUnpackDynamic_Module_basic",
39953997
"ChunkListUnpackUnevenDynamic_Module_basic",
39963998
"ChunkListUnpackUneven_Module_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6974,3 +6974,23 @@ def forward(self, x):
69746974
@register_test_case(module_factory=lambda: AtenAsStridedNoStorageOffsetModule())
69756975
def AtenAsStridedNoStorageOffsetModule_basic(module, tu: TestUtils):
69766976
module.forward(torch.randn(12, 13))
6977+
6978+
6979+
class AtenAsStridedUnknownSizeModule(torch.nn.Module):
6980+
def __init__(self):
6981+
super().__init__()
6982+
6983+
@export
6984+
@annotate_args(
6985+
[
6986+
None,
6987+
([-1, 13], torch.float32, True),
6988+
]
6989+
)
6990+
def forward(self, x):
6991+
return torch.ops.aten.as_strided(x, size=(3, 4), stride=(2, 2))
6992+
6993+
6994+
@register_test_case(module_factory=lambda: AtenAsStridedUnknownSizeModule())
6995+
def AtenAsStridedUnknownSizeModule_basic(module, tu: TestUtils):
6996+
module.forward(torch.randn(12, 13))

0 commit comments

Comments
 (0)