3 changes: 1 addition & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -12855,8 +12855,7 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
Value index = rewriter.create<Torch::AtenArangeOp>(
loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);

- // Set the current dimension to -1 for broadcasting
- viewShapeInts[dim] = -1;
+ viewShapeInts[dim] = size;
viewShapeListElems[dim] = cstMinusOne;

Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
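The change keeps the concrete `size` in `viewShapeInts` (the static result type of each per-dimension `view`) while `viewShapeListElems[dim]` stays `cstMinusOne`, so the `view` still broadcasts along `dim` at runtime but the IR type no longer degrades to a dynamic dimension. For orientation, here is a minimal eager-PyTorch sketch of the index-gather decomposition this pattern performs (illustrative only — the pass emits Torch-dialect ops, and the helper name is made up):

```python
import torch

def as_strided_via_gather(x, size, stride, storage_offset=0):
    # Gather-based equivalent of as_strided on a contiguous input.
    flat = x.flatten()
    index = torch.tensor(storage_offset)
    for dim, (sz, st) in enumerate(zip(size, stride)):
        view_shape = [1] * len(size)
        view_shape[dim] = -1  # the runtime -1 (cstMinusOne): broadcast along `dim`
        # The arange's *static* extent is fully known here: `sz`, not -1.
        index = index + torch.arange(sz).view(view_shape) * st
    return flat[index.flatten()].view(size)

x = torch.arange(32.0).reshape(4, 8)
assert torch.equal(as_strided_via_gather(x, (2, 3), (1, 1)),
                   torch.as_strided(x, (2, 3), (1, 1)))
```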
40 changes: 40 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1281,6 +1281,46 @@ def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 12, 3))


class UnflattenIntDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, 12], torch.float32, True),
]
)
def forward(self, inputs):
return torch.ops.aten.unflatten(inputs, 1, [3, 4])


@register_test_case(module_factory=lambda: UnflattenIntDynamicModule())
def UnflattenIntDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 12))


class UnflattenIntDynamicWithInferredSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, 20], torch.float32, True),
]
)
def forward(self, inputs):
return torch.ops.aten.unflatten(inputs, 1, [4, -1])


@register_test_case(module_factory=lambda: UnflattenIntDynamicWithInferredSizeModule())
def UnflattenIntDynamicWithInferredSizeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 20))
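For reference, both tests mirror eager semantics; a quick eager check (assuming a recent PyTorch) behaves the same way:

```python
import torch

x = torch.rand(2, 12)
assert torch.ops.aten.unflatten(x, 1, [3, 4]).shape == (2, 3, 4)

# A single -1 entry is inferred from the remaining factors: 20 / 4 = 5.
y = torch.rand(3, 20)
assert torch.ops.aten.unflatten(y, 1, [4, -1]).shape == (3, 4, 5)
```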


# ==============================================================================


74 changes: 74 additions & 0 deletions test/Conversion/TorchToLinalg/unflatten.mlir
@@ -0,0 +1,74 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.unflatten.int$static
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.expand_shape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$static(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
return %1 : !torch.vtensor<[2,2,3,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.unflatten.int$negative_dim
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.expand_shape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$negative_dim(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
%int-2 = torch.constant.int -2
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.unflatten.int %arg0, %int-2, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
return %1 : !torch.vtensor<[2,2,3,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.unflatten.int$inferred_size
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.expand_shape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$inferred_size(%arg0: !torch.vtensor<[3,12],f32>) -> !torch.vtensor<[3,2,6],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int-1 = torch.constant.int -1
%0 = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[3,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,2,6],f32>
return %1 : !torch.vtensor<[3,2,6],f32>
}
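The inferred-size case relies on the usual single-`-1` rule; a hedged sketch of that computation (the helper is illustrative, not the pass's code):

```python
def infer_unflatten_sizes(dim_size: int, sizes: list[int]) -> list[int]:
    # Product of the explicitly given factors.
    known = 1
    for s in sizes:
        if s != -1:
            known *= s
    if -1 not in sizes:
        assert known == dim_size, "factors must multiply to the dim size"
        return sizes
    assert dim_size % known == 0, "known factors must divide the dim size"
    return [dim_size // known if s == -1 else s for s in sizes]

assert infer_unflatten_sizes(12, [2, -1]) == [2, 6]  # matches the test above
```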

// -----

// CHECK-LABEL: func.func @torch.aten.unflatten.int$dynamic_input
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.expand_shape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$dynamic_input(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,2,3],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,2,3],f32>
return %1 : !torch.vtensor<[?,2,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.unflatten.int$two_dynamic_dims
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.from_elements
// CHECK: tensor.reshape
// CHECK: torch_c.from_builtin_tensor
func.func @torch.aten.unflatten.int$two_dynamic_dims(%arg0: !torch.vtensor<[?,12],f32>) -> !torch.vtensor<[?,?,?],f32> {
%int1 = torch.constant.int 1
%2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,12],f32>, !torch.int -> !torch.int
%0 = torch.prim.ListConstruct %2, %2 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
return %1 : !torch.vtensor<[?,?,?],f32>
}
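With two unknown factors, the split of the source dimension cannot be captured by a static `tensor.expand_shape` reassociation, so the lowering assembles the target shape at runtime (`tensor.from_elements`) and falls back to `tensor.reshape`, as the CHECK lines assert. The eager analogue (illustrative; `n` stands in for a value only known at runtime):

```python
import torch

x = torch.rand(5, 12)
n = x.size(1) // 4  # pretend this is data-dependent, i.e. unknown statically
y = x.unflatten(1, [n, x.size(1) // n])
assert y.shape == (5, 3, 4)
```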
16 changes: 16 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -934,3 +934,19 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
%0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
return %0 : !torch.vtensor<[1,8,4,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.as_strided$static_shapes
func.func @torch.aten.as_strided$static_shapes(%arg0: !torch.vtensor<[4,8],f32>) -> !torch.vtensor<[2,3],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
%size = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.view {{.*}} -> !torch.vtensor<[2,1],si64>
// CHECK: torch.aten.view {{.*}} -> !torch.vtensor<[1,3],si64>
%0 = torch.aten.as_strided %arg0, %size, %stride, %int0 : !torch.vtensor<[4,8],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[2,3],f32>
return %0 : !torch.vtensor<[2,3],f32>
}
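The two CHECK'd `view` results are the per-dimension arange reshapes — `[2,1]` for the rows and `[1,3]` for the columns — whose broadcast sum produces the gather indices; with the fix their types stay static instead of `[?,1]`/`[1,?]`. A quick numeric sanity check of the op itself (eager PyTorch, not pass output):

```python
import torch

x = torch.arange(32.0).reshape(4, 8)
# size (2, 3), stride (1, 1), offset 0: element [i, j] reads flat[i + j].
out = torch.as_strided(x, (2, 3), (1, 1))
assert torch.equal(out, torch.tensor([[0.0, 1.0, 2.0],
                                      [1.0, 2.0, 3.0]]))
```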