@@ -5766,19 +5766,65 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
         op, "Unimplemented pooling input parsing function");
   }
 
-  static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
-                              int64_t stride, int64_t padBefore,
-                              int64_t padAfter, int64_t dilation,
+  static int64_t getOutputDim(PatternRewriter &rewriter, Value &input,
+                              Location loc, int64_t inputRank,
+                              ArrayRef<int64_t> inputShape, Type inputElemTy,
+                              int64_t dimIndex, int64_t kernelDim,
+                              int64_t stride, int64_t &padBefore,
+                              int64_t &padAfter, int64_t dilation,
                               bool ceilMode = false) {
+    int64_t inputDim = inputShape[dimIndex];
     if (inputDim == kUnknownSize) {
       return kUnknownSize;
     } else {
+      // TOSA requires dimSize = inputDim + padBefore + padAfter - kernelDim
+      // to be fully divisible by stride. We would have to modify the after
+      // pad and/or the input to achieve that.
+      // Note: the dimSize calculation below is the same as TOSA's dimSize
+      // calculation when dilation = 1, which is the only dilation value that
+      // TOSA supports for MaxPool2d (AvgPool2d has no dilation, so the value
+      // defaults to 1).
       int64_t dimSize =
           inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
+      int64_t remainderDim = dimSize % stride;
+
+      // When PyTorch uses floor mode for the output dim calculation, we
+      // satisfy TOSA's divisibility requirement by dropping the unused after
+      // pad and slicing off the unused input rows/columns.
+      if (!ceilMode && (remainderDim != 0)) {
+        if (remainderDim > padAfter) {
+          SmallVector<int64_t> startSlice(inputRank, 0);
+          SmallVector<int64_t> sizeSlice(
+              dyn_cast<TensorType>(input.getType()).getShape());
+          sizeSlice[dimIndex] = inputDim - (remainderDim - padAfter);
+          input = rewriter.create<tosa::SliceOp>(
+              loc, RankedTensorType::get(sizeSlice, inputElemTy), input,
+              tosa::getTosaConstShape(rewriter, loc, startSlice),
+              tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+          dimSize = dimSize - padAfter;
+          padAfter = 0;
+        } else {
+          dimSize = dimSize - padAfter;
+          padAfter = padAfter - remainderDim;
+          dimSize = dimSize + padAfter;
+        }
+      }
+
       int64_t outputDim = dimSize / stride + 1;
-      if (ceilMode && (dimSize % stride != 0) &&
-          (outputDim * stride < inputDim + padBefore))
-        outputDim++;
+
+      // When PyTorch uses ceil mode for the output dim calculation, we
+      // satisfy TOSA's divisibility requirement by dropping the unused after
+      // pad, or by adding more after pad when the remainder exceeds it.
+      if (ceilMode && (remainderDim != 0)) {
+        if (remainderDim < padAfter) {
+          padAfter = padAfter - remainderDim;
+        } else {
+          padAfter = padAfter + (stride - remainderDim);
+        }
+
+        if (outputDim * stride < inputDim + padBefore)
+          outputDim++;
+      }
       return outputDim;
     }
   }
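For reference, the floor- and ceil-mode fix-ups above reduce to plain integer arithmetic. The standalone sketch below models getOutputDim with dilation fixed at 1, folding the two floor-mode branches into a single `dimSize -= remainderDim` step (equivalent under integer division, since the sliced and padded variants differ by less than one stride); the helper name and test values are illustrative, not part of the patch.

    #include <cassert>
    #include <cstdint>

    struct PoolDim {
      int64_t outputDim; // pooled size along this dimension
      int64_t padAfter;  // trailing pad after the divisibility fix-up
    };

    static PoolDim computePoolDim(int64_t inputDim, int64_t kernelDim,
                                  int64_t stride, int64_t padBefore,
                                  int64_t padAfter, bool ceilMode) {
      int64_t dimSize = inputDim + padBefore + padAfter - (kernelDim - 1) - 1;
      int64_t rem = dimSize % stride;
      if (!ceilMode && rem != 0) {
        // Floor mode: drop unused trailing pad; any remainder beyond the pad
        // corresponds to input rows/cols the patch slices away.
        dimSize -= rem;
        padAfter = rem > padAfter ? 0 : padAfter - rem;
      }
      int64_t outputDim = dimSize / stride + 1;
      if (ceilMode && rem != 0) {
        // Ceil mode: shrink or grow the trailing pad until divisible.
        padAfter = rem < padAfter ? padAfter - rem : padAfter + (stride - rem);
        if (outputDim * stride < inputDim + padBefore)
          outputDim++;
      }
      return {outputDim, padAfter};
    }

    int main() {
      // 8 wide, kernel 3, stride 2, pad 1/1, floor: dimSize = 7, rem = 1,
      // so padAfter drops to 0 and the output stays at PyTorch's 4.
      PoolDim f = computePoolDim(8, 3, 2, 1, 1, /*ceilMode=*/false);
      assert(f.outputDim == 4 && f.padAfter == 0);
      // Same shape in ceil mode: padAfter grows to 2 and the output is 5,
      // matching ceil((8 + 2 - 3) / 2.0) + 1.
      PoolDim c = computePoolDim(8, 3, 2, 1, 1, /*ceilMode=*/true);
      assert(c.outputDim == 5 && c.padAfter == 2);
      return 0;
    }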
@@ -6016,25 +6062,24 @@ class ConvertAtenAdaptivePoolingOp
 
 template <typename AtenOpT, typename tosaOp>
 static Type getOutputTypeForNonAdaptivePoolingOp(
+    PatternRewriter &rewriter, Operation *op, Value &input,
     RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
     SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
     SmallVectorImpl<int64_t> &dilationArray, bool ceilMode = false) {
   auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
   auto inputRank = inputTy.getRank();
   auto inputElemTy = inputTy.getElementType();
 
+  // PyTorch uses xCHW, so the Height dim index is rank - 2 and the Width dim
+  // index is rank - 1.
   int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
-      padArray[0], dilationArray[0], ceilMode);
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 2, kernelSize[0], strideArray[0], padArray[0],
+      padArray[1], dilationArray[0], ceilMode);
   int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
-      inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
-      padArray[1], dilationArray[1], ceilMode);
-  padArray[0] = (outputHDim - 1) * strideArray[0] +
-                dilationArray[0] * kernelSize[0] - dilationArray[0] + 1 -
-                padArray[0] * 2 - inputShape[inputRank - 2];
-  padArray[1] = (outputWDim - 1) * strideArray[1] +
-                dilationArray[0] * kernelSize[1] - dilationArray[0] + 1 -
-                padArray[1] * 2 - inputShape[inputRank - 1];
+      rewriter, input, op->getLoc(), inputRank, inputShape, inputElemTy,
+      /*dimIndex=*/inputRank - 1, kernelSize[1], strideArray[1], padArray[2],
+      padArray[3], dilationArray[1], ceilMode);
   SmallVector<int64_t> outputShape;
   if (inputRank > 3)
     outputShape.push_back(inputShape[0]);
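A concrete trace of the new call protocol: padArray now carries four entries in TOSA's {top, bottom, left, right} order, each getOutputDim call consumes one before/after pair, and the "after" entry may be rewritten in place. Using the computePoolDim model from the sketch above (illustrative values only):

    // NCHW input 1x1x8x8, kernel 3x3, stride 2x2, pads {1, 1, 1, 1}, floor.
    int64_t padArray[4] = {1, 1, 1, 1};
    PoolDim h = computePoolDim(/*inputDim=*/8, /*kernelDim=*/3, /*stride=*/2,
                               padArray[0], padArray[1], /*ceilMode=*/false);
    PoolDim w = computePoolDim(8, 3, 2, padArray[2], padArray[3], false);
    // h.outputDim == 4 and w.outputDim == 4, so the output type is 1x1x4x4;
    // both trailing pads come back as 0, satisfying TOSA's divisibility rule.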
@@ -6065,7 +6110,7 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
 // vector. Also, gets the output type for the pooling op.
 template <typename AtenOpT, typename tosaOp>
 static LogicalResult getOutputTypeAndPoolingParameters(
-    AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
     SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
     DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
     DenseI64ArrayAttr &pad) {
@@ -6138,10 +6183,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   expandPoolParams(op, dilationArray, 1);
   outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
-      inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray,
-      ceilMode);
-  padArr[1] = padArr[1] + paddingInts[0];
-  padArr[3] = padArr[3] + paddingInts[1];
+      rewriter, op, inputXchw, inputTy, kernelSizeInts, strideInts, padArr,
+      dilationArray, ceilMode);
   pad = rewriter.getDenseI64ArrayAttr(
       {padArr[0], padArr[1], padArr[2], padArr[3]});
   return success();
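The two deleted padArr updates were back-solving the extra bottom/right padding TOSA needs from the already-computed output dims, per axis:

    padExtra = (outputDim - 1) * stride + dilation * (kernelDim - 1) + 1
               - 2 * pad - inputDim

That correction is no longer recomputed here: padArr is handed straight to getOutputTypeForNonAdaptivePoolingOp, whose getOutputDim calls adjust the trailing pads by reference (or slice the input) while computing the output dims, so the values placed in the DenseI64ArrayAttr already satisfy TOSA's divisibility constraint.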
@@ -6157,6 +6200,7 @@ class ConvertAtenMaxPool2dOp
                       DenseI64ArrayAttr &kernel,
                       DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                       Type &outputTy) const override {
+    auto self = adaptor.getSelf();
     SmallVector<int64_t, 2> dilationArray;
     if (!matchPattern(op.getDilation(),
                       m_TorchListOfConstantInts(dilationArray)))
@@ -6169,14 +6213,13 @@ class ConvertAtenMaxPool2dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
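The new `self` local above is load-bearing rather than cosmetic: getOutputTypeAndPoolingParameters now takes the input as `Value &` so it can swap in a `tosa.slice` result, and a temporary such as `adaptor.getSelf()` cannot bind to a non-const lvalue reference. Caching the Value also guarantees the transpose sees the possibly-sliced input, sketched below with a hypothetical helper `f`:

    void f(Value &v);        // may reassign v to a sliced tensor
    // f(adaptor.getSelf()); // would not compile: rvalue can't bind to Value&
    auto self = adaptor.getSelf();
    f(self);                 // ok: later uses of `self` see the update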
@@ -6210,11 +6253,15 @@ class ConvertAtenMaxPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t> dilationArray;
     if (!matchPattern(op.getDilation(),
@@ -6231,14 +6278,14 @@ class ConvertAtenMaxPool1dOp
 
     if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
                                                  tosa::MaxPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel,
+            stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }
@@ -6254,6 +6301,7 @@ class ConvertAtenAvgPool2dOp
                       DenseI64ArrayAttr &kernel,
                       DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                       Type &outputTy) const override {
+    auto self = adaptor.getSelf();
 
     // Currently, we cannot represent `divisor_override` with the existing TOSA
     // AvgPool2d specification. Without the below check, we produce silent wrong
@@ -6267,14 +6315,13 @@ class ConvertAtenAvgPool2dOp
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, adaptor.getSelf(), dilationArray, outputTy, kernel,
-            stride, pad)))
+            op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, adaptor.getSelf());
+        transposePoolingInputToHwc(op, rewriter, self);
 
     return success();
   }
@@ -6308,23 +6355,27 @@ class ConvertAtenAvgPool1dOp
     // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
     SmallVector<int64_t> rank4Shape(selfShape);
     rank4Shape.push_back(1);
-    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
-        op->getLoc(),
-        RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
-                              selfTy.getElementType()),
-        self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape));
+    auto reshapedSelf =
+        rewriter
+            .create<tosa::ReshapeOp>(
+                op->getLoc(),
+                RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
+                                      selfTy.getElementType()),
+                self,
+                tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape))
+            .getResult();
 
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
                                                  tosa::AvgPool2dOp>(
-            op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
-            kernel, stride, pad)))
+            op, rewriter, reshapedSelf, dilationArray, outputTy, kernel,
+            stride, pad)))
       return rewriter.notifyMatchFailure(
           op, "invalid pooling parameters or input type");
 
     // Transpose to xHWC
     input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
-        transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
+        transposePoolingInputToHwc(op, rewriter, reshapedSelf);
 
     return success();
   }