Skip to content
58 changes: 58 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1972,6 +1972,63 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
};
} // namespace

// Decompose 'aten.outer' into 'aten.unsqueeze', 'aten.matmul'

namespace {
// Rewrites aten.outer(self, vec2) as a matmul of a column matrix with a
// row matrix: unsqueeze(self, 1) @ unsqueeze(vec2, 0) -> [n, m].
class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenOuterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getSelf();
    Value rhs = op.getVec2();

    auto lhsTy = cast<BaseTensorType>(lhs.getType());
    auto rhsTy = cast<BaseTensorType>(rhs.getType());

    // Both operands must carry shape information.
    if (!lhsTy.hasSizes() || !rhsTy.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "Inputs must be ranked tensors for aten.outer");

    // aten.outer is only defined on 1-D vectors.
    if (lhsTy.getSizes().size() != 1 || rhsTy.getSizes().size() != 1)
      return rewriter.notifyMatchFailure(
          op, "Inputs must be 1-dimensional vectors for aten.outer");

    // Reshape self from [n] into an [n, 1] column matrix.
    Value dimOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    SmallVector<int64_t> colShape(lhsTy.getSizes());
    colShape.push_back(1);
    Type colMatrixTy =
        lhsTy.getWithSizesAndDtype(colShape, lhsTy.getOptionalDtype());
    Value colMatrix =
        rewriter.create<AtenUnsqueezeOp>(loc, colMatrixTy, lhs, dimOne);

    // Reshape vec2 from [m] into a [1, m] row matrix.
    Value dimZero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    SmallVector<int64_t> rowShape(rhsTy.getSizes());
    rowShape.insert(rowShape.begin(), 1);
    Type rowMatrixTy =
        rhsTy.getWithSizesAndDtype(rowShape, rhsTy.getOptionalDtype());
    Value rowMatrix =
        rewriter.create<AtenUnsqueezeOp>(loc, rowMatrixTy, rhs, dimZero);

    // [n, 1] @ [1, m] yields the [n, m] outer product.
    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), colMatrix,
                                              rowMatrix);
    return success();
  }
};
} // namespace

namespace {
// Decompose aten.atleast_2d into: aten.reshape. See
// https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604
Expand Down Expand Up @@ -12874,6 +12931,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenOuterOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenSoftshrinkOp>();
target.addIllegalOp<AtenEmptyLikeOp>();
target.addIllegalOp<AtenOnesLikeOp>();
target.addIllegalOp<AtenOuterOp>();
target.addIllegalOp<AtenZerosLikeOp>();
target.addIllegalOp<AtenStackOp>();
target.addIllegalOp<AtenHstackOp>();
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4021,6 +4021,7 @@
}

ONNX_TOSA_XFAIL_SET = {
"AtenOuter_basic",
"AtenFftRfft2DLastDim_basic",
"AtenFftRfft2DMiddleDim_basic",
"AtenStftCenter1D_basic",
Expand Down
58 changes: 58 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,61 @@ def forward(self, a, b):
@register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))


# ==============================================================================


class AtenOuter(torch.nn.Module):
    """Exercises torch.outer on two statically-shaped [3] float vectors."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([3], torch.float32, True),
            ([3], torch.float32, True),
        ]
    )
    def forward(self, lhs, rhs):
        # Outer product: result[i, j] == lhs[i] * rhs[j].
        product = torch.outer(lhs, rhs)
        return product


@register_test_case(module_factory=lambda: AtenOuter())
def AtenOuter_basic(module, tu: TestUtils):
    # Both inputs must match the static [3] shape declared in AtenOuter's
    # annotate_args; the original passed tu.rand(2) for the second argument,
    # which contradicts that static-shape contract.
    module.forward(tu.rand(3), tu.rand(3))


# ==============================================================================


class AtenOuterDynamic(torch.nn.Module):
    """Exercises torch.outer on two dynamically-shaped ([-1]) float vectors."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.float32, True),
            ([-1], torch.float32, True),
        ]
    )
    def forward(self, lhs, rhs):
        # Dynamic lengths: an [n] and an [m] vector yield an [n, m] matrix.
        product = torch.outer(lhs, rhs)
        return product


@register_test_case(module_factory=lambda: AtenOuterDynamic())
def AtenOuterDynamic_basic(module, tu: TestUtils):
    # Equal-length dynamic vectors -> square [5, 5] result.
    lhs, rhs = tu.rand(5), tu.rand(5)
    module.forward(lhs, rhs)


@register_test_case(module_factory=lambda: AtenOuterDynamic())
def AtenOuterDynamic_lhs_larger(module, tu: TestUtils):
    # Longer lhs than rhs -> rectangular [7, 4] result.
    lhs, rhs = tu.rand(7), tu.rand(4)
    module.forward(lhs, rhs)


@register_test_case(module_factory=lambda: AtenOuterDynamic())
def AtenOuterDynamic_rhs_larger(module, tu: TestUtils):
    # Longer rhs than lhs -> rectangular [2, 6] result.
    lhs, rhs = tu.rand(2), tu.rand(6)
    module.forward(lhs, rhs)
Loading