Commit 63ab68e
[TOSA] Fix output size calculation for pool ops
TOSA requires (inputDim + padBefore + padAfter - kernel) to be fully divisible by stride. This update adds pad and input-size modifications for the pooling ops (AvgPool2d and MaxPool2d) so that they satisfy that requirement.

Signed-off-by: Justin Ngo <[email protected]>
Change-Id: Iab4021f2dda87cb87e54e4e9ca20bd3688dc1c50
1 parent 7190726 commit 63ab68e
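
To make the requirement concrete: with inputDim = 5, kernel = 2, stride = 2, and no padding, inputDim + padBefore + padAfter - kernel = 3, which leaves a remainder of 1 modulo the stride, so either the after-pad or the input extent must change. Below is a minimal standalone sketch, not part of the patch, that mirrors the new single-dimension arithmetic under the assumption dilation = 1; it omits the tosa::SliceOp that the real lowering emits in the floor-mode branch, and the name adjustedOutputDim is invented for illustration.

// Standalone sketch of the pad adjustment that makes
// (inputDim + padBefore + padAfter - kernelDim) divisible by stride,
// assuming dilation = 1. Not the patch itself.
#include <cassert>
#include <cstdint>

static int64_t adjustedOutputDim(int64_t inputDim, int64_t kernelDim,
                                 int64_t stride, int64_t padBefore,
                                 int64_t &padAfter, bool ceilMode) {
  int64_t dimSize = inputDim + padBefore + padAfter - (kernelDim - 1) - 1;
  int64_t remainderDim = dimSize % stride;

  // Floor mode: drop the unused after-pad. (The real lowering also slices
  // off the unused input rows/columns when remainderDim > padAfter.)
  if (!ceilMode && remainderDim != 0) {
    if (remainderDim > padAfter) {
      dimSize = dimSize - padAfter;
      padAfter = 0;
    } else {
      dimSize = dimSize - padAfter;
      padAfter = padAfter - remainderDim;
      dimSize = dimSize + padAfter;
    }
  }

  int64_t outputDim = dimSize / stride + 1;

  // Ceil mode: shrink the after-pad if it covers the remainder, otherwise
  // grow it to the next stride boundary, then bump outputDim as PyTorch does.
  if (ceilMode && remainderDim != 0) {
    if (remainderDim < padAfter)
      padAfter = padAfter - remainderDim;
    else
      padAfter = padAfter + (stride - remainderDim);
    if (outputDim * stride < inputDim + padBefore)
      outputDim++;
  }
  return outputDim;
}

int main() {
  // inputDim = 5, kernel = 2, stride = 2, pads = 0: dimSize = 3, 3 % 2 = 1.
  int64_t padAfter = 0;
  // Floor mode matches PyTorch: floor((5 - 2) / 2) + 1 = 2.
  assert(adjustedOutputDim(5, 2, 2, /*padBefore=*/0, padAfter, false) == 2);
  // Ceil mode matches PyTorch: ceil((5 - 2) / 2) + 1 = 3, and the after-pad
  // grows to 1 so that (5 + 0 + 1 - 2) is divisible by the stride.
  padAfter = 0;
  assert(adjustedOutputDim(5, 2, 2, /*padBefore=*/0, padAfter, true) == 3);
  assert(padAfter == 1);
  return 0;
}

In floor mode the leftover rows/columns are never visited by any window, so slicing them off is lossless; in ceil mode the extra after-pad supplies the partial window that PyTorch's ceil-mode formula counts.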

File tree

4 files changed, +593 -43 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 94 additions & 43 deletions
@@ -5766,19 +5766,65 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
         op, "Unimplemented pooling input parsing function");
   }
 
-  static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
-                              int64_t stride, int64_t padBefore,
-                              int64_t padAfter, int64_t dilation,
+  static int64_t getOutputDim(PatternRewriter &rewriter, Value &input,
+                              Location loc, int64_t inputRank,
+                              ArrayRef<int64_t> inputShape, Type inputElemTy,
+                              int64_t dimIndex, int64_t kernelDim,
+                              int64_t stride, int64_t &padBefore,
+                              int64_t &padAfter, int64_t dilation,
                               bool ceilMode = false) {
+    int64_t inputDim = inputShape[dimIndex];
     if (inputDim == kUnknownSize) {
       return kUnknownSize;
     } else {
+      // TOSA requires dimSize = inputDim + padBefore + padAfter - kernelDim
+      // to be fully divisible by stride. We have to modify the after-pad
+      // and/or the input in order to achieve that.
+      // Note: The dimSize calculation below is the same as TOSA's dimSize
+      // calculation when dilation = 1, which is the only dilation value that
+      // TOSA supports for MaxPool2d (AvgPool2d doesn't have dilation, so the
+      // value defaults to 1).
       int64_t dimSize =
           inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
+      int64_t remainderDim = dimSize % stride;
+
+      // When PyTorch uses floor mode for the output dim calculation, to
+      // achieve TOSA's divisibility requirement we remove the unused
+      // after-pad and slice off the unused input rows/columns.
+      if (!ceilMode && (remainderDim != 0)) {
+        if (remainderDim > padAfter) {
+          SmallVector<int64_t> startSlice(inputRank, 0);
+          SmallVector<int64_t> sizeSlice(
+              dyn_cast<TensorType>(input.getType()).getShape());
+          sizeSlice[dimIndex] = inputDim - (remainderDim - padAfter);
+          input = rewriter.create<tosa::SliceOp>(
+              loc, RankedTensorType::get(sizeSlice, inputElemTy), input,
+              tosa::getTosaConstShape(rewriter, loc, startSlice),
+              tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+          dimSize = dimSize - padAfter;
+          padAfter = 0;
+        } else {
+          dimSize = dimSize - padAfter;
+          padAfter = padAfter - remainderDim;
+          dimSize = dimSize + padAfter;
+        }
+      }
+
       int64_t outputDim = dimSize / stride + 1;
-      if (ceilMode && (dimSize % stride != 0) &&
-          (outputDim * stride < inputDim + padBefore))
-        outputDim++;
+
+      // When PyTorch uses ceil mode for the output dim calculation, to
+      // achieve TOSA's divisibility requirement we remove the unused
+      // after-pad, or add more after-pad when the remainder is larger than
+      // the after-pad.
+      if (ceilMode && (remainderDim != 0)) {
+        if (remainderDim < padAfter) {
+          padAfter = padAfter - remainderDim;
+        } else {
+          padAfter = padAfter + (stride - remainderDim);
+        }
+
+        if (outputDim * stride < inputDim + padBefore)
+          outputDim++;
+      }
       return outputDim;
     }
   }
@@ -6016,25 +6062,24 @@ class ConvertAtenAdaptivePoolingOp
 
 template <typename AtenOpT, typename tosaOp>
 static Type getOutputTypeForNonAdaptivePoolingOp(
+    PatternRewriter &rewriter, Operation *op, Value &input,
     RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
     SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
     SmallVectorImpl<int64_t> &dilationArray, bool ceilMode = false) {
   auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
   auto inputRank = inputTy.getRank();
   auto inputElemTy = inputTy.getElementType();
 
+  // PyTorch uses xCHW, so Height dim index is rank-2 and Width dim index is
+  // rank-1
   int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
-      padArray[0], dilationArray[0], ceilMode);
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 2, kernelSize[0], strideArray[0], padArray[0],
+      padArray[1], dilationArray[0], ceilMode);
   int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
-      padArray[1], dilationArray[1], ceilMode);
-  padArray[0] = (outputHDim - 1) * strideArray[0] +
-                dilationArray[0] * kernelSize[0] - dilationArray[0] + 1 -
-                padArray[0] * 2 - inputShape[inputRank - 2];
-  padArray[1] = (outputWDim - 1) * strideArray[1] +
-                dilationArray[0] * kernelSize[1] - dilationArray[0] + 1 -
-                padArray[1] * 2 - inputShape[inputRank - 1];
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 1, kernelSize[1], strideArray[1], padArray[2],
+      padArray[3], dilationArray[1], ceilMode);
   SmallVector<int64_t> outputShape;
   if (inputRank > 3)
     outputShape.push_back(inputShape[0]);
@@ -6065,7 +6110,7 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
 static LogicalResult getOutputTypeAndPoolingParameters(
-    AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
     DenseI64ArrayAttr &pad) {
@@ -6138,10 +6183,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
-      inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
-      ceilMode);
-  padArr[1] = padArr[1] + paddingInts[0];
-  padArr[3] = padArr[3] + paddingInts[1];
+      rewriter, op, inputXchw, inputTy, kernelSizeInts, strideInts, padArr,
+      dilationArray, ceilMode);
   pad = rewriter.getDenseI64ArrayAttr(
       {padArr[0], padArr[1], padArr[2], padArr[3]});
   return success();
@@ -6157,6 +6200,7 @@ class ConvertAtenMaxPool2dOp
                               DenseI64ArrayAttr &kernel,
                               DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                               Type &outputTy) const override {
+    auto self = adaptor.getSelf();
     SmallVector<int64_t, 2> dilationArray;
     if (!matchPattern(op.getDilation(),
                       m_TorchListOfConstantInts(dilationArray)))
@@ -6169,14 +6213,13 @@ class ConvertAtenMaxPool2dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6210,11 +6253,15 @@ class ConvertAtenMaxPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t> dilationArray;
     if (!matchPattern(op.getDilation(),
@@ -6231,14 +6278,14 @@ class ConvertAtenMaxPool1dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }
@@ -6254,6 +6301,7 @@ class ConvertAtenAvgPool2dOp
                               DenseI64ArrayAttr &kernel,
                               DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                               Type &outputTy) const override {
+    auto self = adaptor.getSelf();
 
     // Currently, we can not represent `divisor_override` with the existing TOSA
     // AvgPool2d specification. Without the below check, we produce silent wrong
@@ -6267,14 +6315,13 @@ class ConvertAtenAvgPool2dOp
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6308,23 +6355,27 @@ class ConvertAtenAvgPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
    rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
+            pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }
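
Design note: because the floor-mode path may rewrite the pooling input with a tosa::SliceOp inside getOutputDim, the input now flows through the helpers as Value & (and the 1-d patterns pass reshapedSelf rather than re-reading adaptor.getSelf()), so every later use, including transposePoolingInputToHwc, sees the sliced tensor. padBefore and padAfter are taken by reference for the same reason: the adjusted pads must reach the final DenseI64ArrayAttr.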

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 5 additions & 0 deletions
@@ -964,6 +964,8 @@
     "AtenSymConstrainRangeForSize_basic",
     "Aten_AssertScalar_basic",
     "NativeGroupNormModule_basic",
+    "AvgPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
+    "MaxPool2dCeilModeFullDimIndivisibleByStrideModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3300,6 +3302,9 @@
     "Aten_AssertScalar_basic",
     # JIT session error: Symbols not found: [ memrefCopy ]
     "SplitWithSizes_Module_basic",
+    # RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
+    "AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
+    "MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
