Merged
118 changes: 101 additions & 17 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1391,9 +1391,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
return success();
}

if (numSpatialDims != 2)
if (numSpatialDims != 2 && numSpatialDims != 3)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D and 2D grouped convolution supported");
op, "unimplemented: only 2D and 3D grouped convolution supported");
if (numSpatialDims == 3 && inputZp) {
return rewriter.notifyMatchFailure(
op, "unimplemented: quantized 3D grouped convolution not supported");
}

// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
@@ -1435,21 +1439,101 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
weight = transposed ? weight : expandWeight(weight);
auto expandOutputTensor = expandGroups(outputTensor, 1);

// TODO: add 1D and 3D case
if (!inputZp) {
conv = rewriter
.create<linalg::Conv2DNgchwGfchwOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weight},
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);
} else {
conv = rewriter
.create<linalg::Conv2DNgchwGfchwQOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);
if (numSpatialDims == 2) {
// 2D grouped convolution
if (!inputZp) {
conv =
rewriter
.create<linalg::Conv2DNgchwGfchwOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weight},
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);
} else {
conv =
rewriter
.create<linalg::Conv2DNgchwGfchwQOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);
}
} else if (numSpatialDims == 3) {
// MLIR does not have a named 3D grouped convolution op, so we use
// linalg.generic instead.
AffineExpr d0, d1, d2, d3, d4, d5, d6, d7, d8, d9;
bindDims(context, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9);

SmallVector<AffineExpr> inputExprs = {
d0, // N
d1, // G
d6, // C/G
d3 * strideInts[0] + d7 * dilationInts[0], // D
d4 * strideInts[1] + d8 * dilationInts[1], // H
d5 * strideInts[2] + d9 * dilationInts[2] // W
};

SmallVector<AffineExpr> weightExprs = {
d1, // G
d2, // F/G
d6, // C/G
d7, // KD
d8, // KH
d9 // KW
};

SmallVector<AffineExpr> outputExprs = {
d0, // N
d1, // G
d2, // F/G
d3, // OD
d4, // OH
d5, // OW
};

SmallVector<AffineMap> indexingMaps = {
AffineMap::get(10, 0, inputExprs, rewriter.getContext()),
AffineMap::get(10, 0, weightExprs, rewriter.getContext()),
AffineMap::get(10, 0, outputExprs, rewriter.getContext())};

SmallVector<utils::IteratorType> iteratorTypes = {
utils::IteratorType::parallel, // N
utils::IteratorType::parallel, // G
utils::IteratorType::parallel, // F/G
utils::IteratorType::parallel, // OD
utils::IteratorType::parallel, // OH
utils::IteratorType::parallel, // OW
utils::IteratorType::reduction, // C/G
utils::IteratorType::reduction, // KD
utils::IteratorType::reduction, // KH
utils::IteratorType::reduction // KW
};

conv =
rewriter
.create<linalg::GenericOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weight},
expandOutputTensor.getResult(), indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0];
Value weight = args[1];
Value output = args[2];

// Convert input and weight to accumulator type if needed
Type accType = output.getType();
if (input.getType() != accType) {
input = b.create<arith::ExtFOp>(loc, accType, input);
}
if (weight.getType() != accType) {
weight = b.create<arith::ExtFOp>(loc, accType, weight);
}

Value mul = b.create<arith::MulFOp>(loc, input, weight);
Value add = b.create<arith::AddFOp>(loc, mul, output);
b.create<linalg::YieldOp>(loc, add);
})
.getResult(0);
}
conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv,
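For reference, a minimal Python sketch (not part of the diff) of the loop nest that the `linalg.generic` above encodes: the six outer dimensions (N, G, F/G, OD, OH, OW) are parallel, the four inner ones (C/G, KD, KH, KW) are reductions, and the input is indexed at `o * stride + k * dilation` in each spatial dimension, mirroring `inputExprs`. The function name and the pre-expanded (N, G, C/G, D, H, W) / (G, F/G, C/G, KD, KH, KW) layouts are illustrative.

```python
import itertools
import torch

def grouped_conv3d_reference(x, w, stride, dilation):
    # x: N x G x C/G x D x H x W  (input already expanded by groups)
    # w: G x F/G x C/G x KD x KH x KW
    N, G, Cg, D, H, W = x.shape
    _, Fg, _, KD, KH, KW = w.shape
    OD = (D - dilation[0] * (KD - 1) - 1) // stride[0] + 1
    OH = (H - dilation[1] * (KH - 1) - 1) // stride[1] + 1
    OW = (W - dilation[2] * (KW - 1) - 1) // stride[2] + 1
    out = torch.zeros(N, G, Fg, OD, OH, OW, dtype=x.dtype)
    # Parallel dims (n, g, f, od, oh, ow) correspond to d0..d5 in the
    # indexing maps; reduction dims (c, kd, kh, kw) correspond to d6..d9.
    for n, g, f, od, oh, ow in itertools.product(
            range(N), range(G), range(Fg), range(OD), range(OH), range(OW)):
        for c, kd, kh, kw in itertools.product(
                range(Cg), range(KD), range(KH), range(KW)):
            out[n, g, f, od, oh, ow] += (
                x[n, g, c,
                  od * stride[0] + kd * dilation[0],
                  oh * stride[1] + kh * dilation[1],
                  ow * stride[2] + kw * dilation[2]]
                * w[g, f, c, kd, kh, kw])
    return out
```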
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2914,6 +2914,9 @@
"Conv3dModule_basic",
"Conv3dWithSamePaddingModule_basic",
"Conv3dWithValidPaddingModule_basic",
"ConvolutionModule3DGroups_basic",
"ConvolutionModule3DGroupsStrided_basic",
"ConvolutionModule3DGroupsDilated_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
@@ -3721,6 +3724,9 @@
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
"ConvolutionModule2DGroupedTranspose_basic",
"ConvolutionModule3DGroups_basic",
"ConvolutionModule3DGroupsStrided_basic",
"ConvolutionModule3DGroupsDilated_basic",
"CumsumInputDtypeInt32Module_basic",
"CumsumWithDtypeModule_basic",
"CumsumModule_basic",
@@ -4369,6 +4375,9 @@
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
"ConvolutionModule2DGroupedTranspose_basic",
"ConvolutionModule3DGroups_basic",
"ConvolutionModule3DGroupsStrided_basic",
"ConvolutionModule3DGroupsDilated_basic",
"CopyModule_basic",
"CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic",
93 changes: 93 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -679,6 +679,99 @@ def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 32, 4, 4), tu.rand(32, 8, 3, 3))


class ConvolutionModule3DGroups(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(
inputVec,
weight,
bias=None,
stride=[1, 1, 1],
padding=[0, 0, 0],
dilation=[1, 1, 1],
transposed=False,
output_padding=[0, 0, 0],
groups=2,
)


@register_test_case(module_factory=lambda: ConvolutionModule3DGroups())
def ConvolutionModule3DGroups_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 6, 6, 6), tu.rand(8, 2, 3, 3, 3))


class ConvolutionModule3DGroupsStrided(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(
inputVec,
weight,
bias=None,
stride=[2, 2, 2],
padding=[1, 1, 1],
dilation=[1, 1, 1],
transposed=False,
output_padding=[0, 0, 0],
groups=4,
)


@register_test_case(module_factory=lambda: ConvolutionModule3DGroupsStrided())
def ConvolutionModule3DGroupsStrided_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 8, 8, 8, 8), tu.rand(16, 2, 3, 3, 3))


class ConvolutionModule3DGroupsDilated(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(
inputVec,
weight,
bias=None,
stride=[1, 1, 1],
padding=[2, 2, 2],
dilation=[2, 2, 2],
transposed=False,
output_padding=[0, 0, 0],
groups=2,
)


@register_test_case(module_factory=lambda: ConvolutionModule3DGroupsDilated())
def ConvolutionModule3DGroupsDilated_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 8, 8, 8), tu.rand(8, 2, 3, 3, 3))
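As a quick eager-mode sanity check (not part of the PR), the dilated case's shapes can be reproduced with `torch.nn.functional.conv3d`: `groups=2` splits the 4 input channels into two groups of 2, matching the weight's second dimension of `4 // 2 = 2`.

```python
import torch
import torch.nn.functional as F

x = torch.rand(2, 4, 8, 8, 8)   # N=2, C=4, D=H=W=8
w = torch.rand(8, 2, 3, 3, 3)   # F=8, C/groups=2, 3x3x3 kernel
out = F.conv3d(x, w, stride=1, padding=2, dilation=2, groups=2)
# Spatial size: (8 + 2*2 - 2*(3-1) - 1) // 1 + 1 = 8
print(out.shape)  # torch.Size([2, 8, 8, 8, 8])
```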


# ==============================================================================

