From 25a29cef311efea03d50cc715f5b20e685c16d1a Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Tue, 4 Mar 2025 11:34:23 -0600 Subject: [PATCH] [mlir][tosa] Switch zero point of avgpool2d to input variable type (#128983) This commit changes the TOSA operator AvgPool2d's zero point attributes to inputs to align with TOSA 1.0 spec. Signed-off-by: Luke Hutton Co-authored-by: Luke Hutton --- .../Dialect/Tosa/IR/TosaComplianceData.h.inc | 14 +- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 46 +++-- .../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 1 + .../TosaToLinalg/TosaToLinalgNamed.cpp | 72 ++++--- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 180 ++++++++++-------- .../Transforms/TosaDecomposeDepthwise.cpp | 22 ++- .../Transforms/TosaDecomposeTransposeConv.cpp | 24 ++- .../Tosa/Transforms/TosaProfileCompliance.cpp | 2 + .../TosaToLinalg/tosa-to-linalg-invalid.mlir | 4 +- .../TosaToLinalg/tosa-to-linalg-named.mlir | 16 +- .../TosaToLinalg/tosa-to-linalg-pipeline.mlir | 10 +- mlir/test/Dialect/Tosa/availability.mlir | 4 +- mlir/test/Dialect/Tosa/invalid.mlir | 60 +++++- mlir/test/Dialect/Tosa/level_check.mlir | 52 ++--- mlir/test/Dialect/Tosa/ops.mlir | 24 ++- .../Dialect/Tosa/profile_all_unsupported.mlir | 4 +- .../Tosa/profile_pro_fp_unsupported.mlir | 4 +- mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 20 +- 18 files changed, 355 insertions(+), 204 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index 2617a902c3a0d..a9b458acd87f2 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -5,9 +5,11 @@ profileComplianceMap = { {{{Profile::pro_int}, {{i8T, i32T}}}, {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}}, {"tosa.avg_pool2d", - {{{Profile::pro_int}, {{i8T, i32T, i8T}}}, + {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, 
{fp32T, fp32T, fp32T}}}}}, + {{fp16T, fp16T, fp16T, fp16T, fp16T}, + {fp16T, fp16T, fp16T, fp32T, fp16T}, + {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, {"tosa.conv2d", {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}}, {{Profile::pro_fp}, @@ -243,10 +245,10 @@ extensionComplianceMap = { {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}}, {{Extension::bf16}, {{bf16T, i32T}}}}}, {"tosa.avg_pool2d", - {{{Extension::int16}, {{i16T, i32T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}}, + {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, {"tosa.conv2d", {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}}, {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}}, diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 85bd3fb1bb1cc..e0f2fd411bbe4 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> { let arguments = (ins Tosa_Tensor4D:$input, + Tosa_ScalarIntOrFloatTensor:$input_zp, + Tosa_ScalarIntOrFloatTensor:$output_zp, Tosa_IntArrayAttr2:$kernel, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$pad, - TypeAttrOf:$acc_type, - OptionalAttr:$input_zp, - OptionalAttr:$output_zp + TypeAttrOf:$acc_type ); let results = (outs @@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> { ]; let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; + + let extraClassDeclaration = [{ + FailureOr getInputZeroPoint(); + FailureOr getOutputZeroPoint(); + LogicalResult verifyInputZeroPoint(int64_t zp); + LogicalResult verifyOutputZeroPoint(int64_t 
zp); + }]; + let hasVerifier = 1; } @@ -116,8 +124,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Tosa_ScalarTensor:$input_zp, - Tosa_ScalarTensor:$weight_zp, + Tosa_ScalarIntOrFloatTensor:$input_zp, + Tosa_ScalarIntOrFloatTensor:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, @@ -136,8 +144,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> { ]; let extraClassDeclaration = [{ - LogicalResult getInputZeroPoint(int64_t &zp); - LogicalResult getWeightZeroPoint(int64_t &zp); + FailureOr getInputZeroPoint(); + FailureOr getWeightZeroPoint(); LogicalResult verifyInputZeroPoint(int64_t zp); LogicalResult verifyWeightZeroPoint(int64_t zp); }]; @@ -161,8 +169,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> { Tosa_Tensor5D:$input, TosaTensorRankOf<[Tosa_Weight], [5]>:$weight, Tosa_Tensor1D:$bias, - Tosa_ScalarTensor:$input_zp, - Tosa_ScalarTensor:$weight_zp, + Tosa_ScalarIntOrFloatTensor:$input_zp, + Tosa_ScalarIntOrFloatTensor:$weight_zp, Tosa_IntArrayAttr6:$pad, Tosa_IntArrayAttr3:$stride, @@ -181,8 +189,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> { ]; let extraClassDeclaration = [{ - LogicalResult getInputZeroPoint(int64_t &zp); - LogicalResult getWeightZeroPoint(int64_t &zp); + FailureOr getInputZeroPoint(); + FailureOr getWeightZeroPoint(); LogicalResult verifyInputZeroPoint(int64_t zp); LogicalResult verifyWeightZeroPoint(int64_t zp); }]; @@ -207,8 +215,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Tosa_ScalarTensor:$input_zp, - Tosa_ScalarTensor:$weight_zp, + Tosa_ScalarIntOrFloatTensor:$input_zp, + Tosa_ScalarIntOrFloatTensor:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, @@ -227,8 +235,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> { ]; let extraClassDeclaration = [{ - LogicalResult getInputZeroPoint(int64_t 
&zp); - LogicalResult getWeightZeroPoint(int64_t &zp); + FailureOr getInputZeroPoint(); + FailureOr getWeightZeroPoint(); LogicalResult verifyInputZeroPoint(int64_t zp); LogicalResult verifyWeightZeroPoint(int64_t zp); }]; @@ -412,8 +420,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Tosa_ScalarTensor:$input_zp, - Tosa_ScalarTensor:$weight_zp, + Tosa_ScalarIntOrFloatTensor:$input_zp, + Tosa_ScalarIntOrFloatTensor:$weight_zp, Tosa_IntArrayAttr4:$out_pad, Tosa_IntArrayAttr2:$stride, @@ -431,8 +439,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> { ]; let extraClassDeclaration = [{ - LogicalResult getInputZeroPoint(int64_t &zp); - LogicalResult getWeightZeroPoint(int64_t &zp); + FailureOr getInputZeroPoint(); + FailureOr getWeightZeroPoint(); LogicalResult verifyInputZeroPoint(int64_t zp); LogicalResult verifyWeightZeroPoint(int64_t zp); }]; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index cf6ddc66f4ada..7a8357ecfa430 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -149,6 +149,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>; def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>; +def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>; // We include unranked tensors as a supported type for all possible tosa // Tensors as unranked does not guarantee invalid. 
If unranked tensors exist diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index e3400b9ba4358..2a2589e19d0ac 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -260,18 +260,26 @@ class ConvConverter : public OpConversionPattern { DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr(); // Get and verify zero points. - int64_t inputZpVal; - int64_t weightZpVal; + FailureOr maybeIZp = op.getInputZeroPoint(); + if (failed(maybeIZp)) + return rewriter.notifyMatchFailure( + op, "input zero point cannot be statically determined"); + + FailureOr maybeWZp = op.getWeightZeroPoint(); + if (failed(maybeWZp)) + return rewriter.notifyMatchFailure( + op, "weight zero point cannot be statically determined"); - if (op.getInputZeroPoint(inputZpVal).failed() || - op.getWeightZeroPoint(weightZpVal).failed()) + int64_t inputZpVal = *maybeIZp; + int64_t weightZpVal = *maybeWZp; + + if (op.verifyInputZeroPoint(inputZpVal).failed()) return rewriter.notifyMatchFailure( - op, "bail out if zero points cannot statically be determined"); + op, "input zero point must be zero for non-int8 integer types"); - if (op.verifyInputZeroPoint(inputZpVal).failed() || - op.verifyWeightZeroPoint(weightZpVal).failed()) + if (op.verifyWeightZeroPoint(weightZpVal).failed()) return rewriter.notifyMatchFailure( - op, "zero point must be zero for non-int8 integer types"); + op, "weight zero point must be zero for non-int8 integer types"); bool hasZp = (inputZpVal != 0) || (weightZpVal != 0); @@ -448,18 +456,26 @@ class DepthwiseConvConverter /*kernelSizeDims=*/{0, 1}, rewriter); // Get and verify zero points. 
- int64_t inputZpVal; - int64_t weightZpVal; - if (op.getInputZeroPoint(inputZpVal).failed() || - op.getWeightZeroPoint(weightZpVal).failed()) + FailureOr maybeIZp = op.getInputZeroPoint(); + FailureOr maybeWZp = op.getWeightZeroPoint(); + if (failed(maybeIZp)) + return rewriter.notifyMatchFailure( + op, "input zero point cannot be statically determined"); + if (failed(maybeWZp)) + return rewriter.notifyMatchFailure( + op, "weight zero point cannot be statically determined"); + + int64_t inputZpVal = *maybeIZp; + int64_t weightZpVal = *maybeWZp; + + if (op.verifyInputZeroPoint(inputZpVal).failed()) return rewriter.notifyMatchFailure( - op, "bail out if zero points cannot statically be determined"); + op, "input zero point must be zero for non-int8 integer types"); - if (op.verifyInputZeroPoint(inputZpVal).failed() || - op.verifyWeightZeroPoint(weightZpVal).failed()) + if (op.verifyWeightZeroPoint(weightZpVal).failed()) return rewriter.notifyMatchFailure( - op, "zero point must be zero for non-int8 integer types"); + op, "weight zero point must be zero for non-int8 integer types"); bool hasZp = (inputZpVal != 0) || (weightZpVal != 0); auto weightShape = weightTy.getShape(); @@ -809,6 +825,18 @@ class AvgPool2dConverter : public OpRewritePattern { return failure(); SmallVector dynamicDims = *dynamicDimsOr; + FailureOr maybeIZp = op.getInputZeroPoint(); + FailureOr maybeOZp = op.getOutputZeroPoint(); + if (failed(maybeIZp)) + return rewriter.notifyMatchFailure( + op, "input zero point could not be statically determined"); + if (failed(maybeOZp)) + return rewriter.notifyMatchFailure( + op, "output zero point could not be statically determined"); + + int64_t inputZpVal = *maybeIZp; + int64_t outputZpVal = *maybeOZp; + // Apply padding as necessary. llvm::SmallVector pad; pad.resize(2, 0); @@ -928,9 +956,9 @@ class AvgPool2dConverter : public OpRewritePattern { // If we have quantization information we need to apply an offset // for the input zp value. 
- if (op.getInputZp()) { - auto inputZp = - rewriter.create(loc, op.getInputZpAttr()); + if (inputZpVal != 0) { + auto inputZp = rewriter.create( + loc, b.getIntegerAttr(accETy, inputZpVal)); Value offset = rewriter.create(loc, accETy, count, inputZp); poolVal = @@ -982,9 +1010,9 @@ class AvgPool2dConverter : public OpRewritePattern { // If we have quantization information we need to apply output // zeropoint. - if (op.getOutputZp()) { - auto outputZp = - rewriter.create(loc, op.getOutputZpAttr()); + if (outputZpVal != 0) { + auto outputZp = rewriter.create( + loc, b.getIntegerAttr(scaled.getType(), outputZpVal)); scaled = rewriter.create(loc, scaled, outputZp) .getResult(); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 1050f3f30fe98..8841d53b6e64d 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -321,17 +321,13 @@ static LogicalResult verifyConvOp(T op) { << weightEType << " and " << weightZpEType; } - int64_t inputZpVal; - if (op.getInputZeroPoint(inputZpVal).succeeded() && - op.verifyInputZeroPoint(inputZpVal).failed()) - return op.emitOpError( - "input zero point must be zero for non-int8 integer types"); - - int64_t weightZpVal; - if (op.getWeightZeroPoint(weightZpVal).succeeded() && - op.verifyWeightZeroPoint(weightZpVal).failed()) - return op.emitOpError( - "weight zero point must be zero for non-int8 integer types"); + FailureOr maybeIZp = op.getInputZeroPoint(); + if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed()) + return failure(); + + FailureOr maybeWZp = op.getWeightZeroPoint(); + if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed()) + return failure(); return success(); } @@ -455,18 +451,10 @@ LogicalResult tosa::ArgMaxOp::verify() { } LogicalResult tosa::AvgPool2dOp::verify() { - auto inputType = llvm::cast(getInput().getType()); - - auto inputETy = inputType.getElementType(); - auto resultETy = 
llvm::cast(getType()).getElementType(); - - if (auto quantType = - llvm::dyn_cast(inputETy)) - inputETy = quantType.getStorageType(); - - if (auto quantType = - llvm::dyn_cast(resultETy)) - resultETy = quantType.getStorageType(); + const Type inputETy = getStorageElementTypeOrSelf(getInput().getType()); + const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType()); + const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType()); + const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType()); auto accType = getAccType(); if (llvm::isa(inputETy) && !accType.isInteger(32)) @@ -481,6 +469,24 @@ LogicalResult tosa::AvgPool2dOp::verify() { if (inputETy.isF32() && !accType.isF32()) return emitOpError("accumulator type for f32 tensor is not f32"); + if (inputETy != inputZpETy) + return emitOpError("expect both input and its zero point are the same " + "element type, got ") + << inputETy << " and " << inputZpETy; + + if (resultETy != outputZpETy) + return emitOpError("expect both output and its zero point are the same " + "element type, got ") + << resultETy << " and " << outputZpETy; + + FailureOr maybeIZp = getInputZeroPoint(); + if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed()) + return failure(); + + FailureOr maybeOZp = getOutputZeroPoint(); + if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed()) + return failure(); + if ((inputETy.isF32() && resultETy.isF32()) || (inputETy.isF16() && resultETy.isF16()) || (inputETy.isBF16() && resultETy.isBF16()) || @@ -629,27 +635,48 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder, } /// Both the tosa.avg_pool2d and unary ops use the same -/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it +/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it /// has additional parameters not part of the unary ops. 
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType) { - result.addOperands(input); + const Location loc{result.location}; + int64_t inputZp{0}; + int64_t outputZp{0}; + + if (auto quantAttr = + buildUnaryOpQuantizationAttr(builder, input, outputType)) { + inputZp = quantAttr.getInputZp(); + outputZp = quantAttr.getOutputZp(); + } + const std::optional inputZpOp = + createZeroPointTensor(builder, loc, input.getType(), inputZp); + if (!inputZpOp) { + (void)emitError( + loc, + "Failed to create input zero point tensor for quantized AVG_POOL2D op"); + } + const std::optional outputZpOp = + createZeroPointTensor(builder, loc, outputType, outputZp); + if (!outputZpOp) { + (void)emitError(loc, "Failed to create output zero point tensor for " + "quantized AVG_POOL2D op"); + } + + if (inputZpOp && outputZpOp) { + result.addOperands({input, inputZpOp.value(), outputZpOp.value()}); + } else { + // failed to create one or more zero points above: add only the input as + // an operand; this will trigger an error when building the op because of + // the missing zero points + result.addOperands({input}); + } result.addAttribute("kernel", kernel); result.addAttribute("stride", stride); result.addAttribute("pad", pad); result.addAttribute("acc_type", accType); - auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); - if (quantAttr) { - result.addAttribute("input_zp", - builder.getI32IntegerAttr( - static_cast(quantAttr.getInputZp()))); - result.addAttribute("output_zp", - builder.getI32IntegerAttr( - static_cast(quantAttr.getOutputZp()))); - } result.types.push_back(outputType); } @@ -1471,77 +1498,68 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { return mlir::success(); } +// returns failure if val is not a constant +// returns -1 if val is a non-zero float or is neither an integer nor a float +// otherwise returns val's constant 
value template -static LogicalResult getZeroPoint(T op, Value val, int64_t &zp) { +static FailureOr getZeroPoint(T op, Value val) { ElementsAttr zpAttr; if (!matchPattern(val, m_Constant(&zpAttr))) { return failure(); } Type zpElemType = zpAttr.getElementType(); - if (auto quantType = - llvm::dyn_cast(zpElemType)) { - zp = quantType.getZeroPoint(); - return success(); - } if (llvm::isa(zpElemType)) { - if (!zpAttr.getValues()[0].isZero()) - return op.emitOpError( - "non-zero zero point is not allowed for float types"); - zp = 0; - return success(); + if (zpAttr.getValues()[0].isZero()) { + return 0; + } + // return non-zero value to trigger error check + return -1; } if (llvm::isa(zpElemType)) { - zp = zpAttr.getValues()[0].getSExtValue(); - return success(); + return zpAttr.getValues()[0].getSExtValue(); } - return op.emitOpError("zero point is not allowed for unsupported types"); + // return non-zero value to trigger error check + return -1; } template -static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) { - // TODO clean it up when the entire zero point (attribute -> input tensor - // type) change is done. Remaining Matmul, Rescale, Negate, and AvgPool2D. 
- if constexpr (!std::is_same_v && !std::is_same_v && - !std::is_same_v && - !std::is_same_v) - return failure(); - +static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, + const std::string &operand) { Type zpElemType = getElementTypeOrSelf(val); - if (!zpElemType.isIntOrFloat()) - return op.emitOpError("zero point is not integer or float typss"); - - if (!zpElemType.isInteger(8) && zp != 0) - return op.emitOpError("zero point must be zero for non-int8 integer types"); - - if (zp < -128 || zp > 127) - return failure(); + if (!zpElemType.isInteger(8) && zp != 0) { + // convert operand to lower case for error message + std::string lower = operand; + std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower); + return op.emitOpError() + << lower << " zero point must be zero for non-int8 integer types"; + } return success(); } -#define ZERO_POINT_HELPER(OP) \ - LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) { \ - return getZeroPoint(*this, getInputZp(), zp); \ - } \ - LogicalResult tosa::OP::getWeightZeroPoint(int64_t &zp) { \ - return getZeroPoint(*this, getWeightZp(), zp); \ - } \ - LogicalResult tosa::OP::verifyInputZeroPoint(int64_t zp) { \ - return verifyZeroPoint(*this, getInputZp(), zp); \ +#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \ + FailureOr tosa::OP::get##OPERAND_NAME##ZeroPoint() { \ + return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \ } \ - LogicalResult tosa::OP::verifyWeightZeroPoint(int64_t zp) { \ - return verifyZeroPoint(*this, getWeightZp(), zp); \ - } - -ZERO_POINT_HELPER(Conv2DOp) -ZERO_POINT_HELPER(Conv3DOp) -ZERO_POINT_HELPER(DepthwiseConv2DOp) -ZERO_POINT_HELPER(TransposeConv2DOp) + LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \ + return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \ + } + +ZERO_POINT_HELPER(Conv2DOp, Input) +ZERO_POINT_HELPER(Conv2DOp, Weight) +ZERO_POINT_HELPER(Conv3DOp, Input) +ZERO_POINT_HELPER(Conv3DOp, Weight) 
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Input) +ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight) +ZERO_POINT_HELPER(TransposeConv2DOp, Input) +ZERO_POINT_HELPER(TransposeConv2DOp, Weight) +ZERO_POINT_HELPER(AvgPool2dOp, Input) +ZERO_POINT_HELPER(AvgPool2dOp, Output) #undef ZERO_POINT_HELPER LogicalResult tosa::TransposeOp::inferReturnTypeComponents( diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 14ee422a31541..9b4cf85c480d3 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -54,18 +54,24 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "unsupported type"); // Get and verify zero points. - int64_t iZp; - int64_t wZp; + FailureOr maybeIZp = op.getInputZeroPoint(); + if (failed(maybeIZp)) + return rewriter.notifyMatchFailure( + op, "input zero point cannot be statically determined"); - if (op.getInputZeroPoint(iZp).failed() || - op.getWeightZeroPoint(wZp).failed()) + FailureOr maybeWZp = op.getWeightZeroPoint(); + if (failed(maybeWZp)) return rewriter.notifyMatchFailure( - op, "bail out if zero points cannot statically be determined"); + op, "weight zero point cannot be statically determined"); - if (op.verifyInputZeroPoint(iZp).failed() || - op.verifyWeightZeroPoint(wZp).failed()) + int64_t iZp = *maybeIZp; + int64_t wZp = *maybeWZp; + if (op.verifyInputZeroPoint(iZp).failed()) + return rewriter.notifyMatchFailure( + op, "input zero point must be zero for non-int8 integer types"); + if (op.verifyWeightZeroPoint(wZp).failed()) return rewriter.notifyMatchFailure( - op, "zero point must be zero for non-int8 integer types"); + op, "weight zero point must be zero for non-int8 integer types"); // Reshape input to [N, H, W, C] -> [N, H, W, C, 1]. 
ArrayRef inputShape = inputType.getShape(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 83bdbce5d1857..fe2c85f4f9c86 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -135,18 +135,26 @@ class TransposeConvStridedConverter getTosaConstShape(rewriter, op->getLoc(), weightPadding); // Get and verify zero points. - int64_t inputZpVal; - int64_t weightZpVal; + FailureOr maybeIZp = op.getInputZeroPoint(); + if (failed(maybeIZp)) + return rewriter.notifyMatchFailure( + op, "input zero point cannot be statically determined"); + + FailureOr maybeWZp = op.getWeightZeroPoint(); + if (failed(maybeWZp)) + return rewriter.notifyMatchFailure( + op, "weight zero point cannot be statically determined"); + + int64_t inputZpVal = *maybeIZp; + int64_t weightZpVal = *maybeWZp; - if (op.getInputZeroPoint(inputZpVal).failed() || - op.getWeightZeroPoint(weightZpVal).failed()) + if (op.verifyInputZeroPoint(inputZpVal).failed()) return rewriter.notifyMatchFailure( - op, "bail out if zero points cannot statically be determined"); + op, "input zero point must be zero for non-int8 integer types"); - if (op.verifyInputZeroPoint(inputZpVal).failed() || - op.verifyWeightZeroPoint(weightZpVal).failed()) + if (op.verifyWeightZeroPoint(weightZpVal).failed()) return rewriter.notifyMatchFailure( - op, "zero point must be zero for non-int8 integer types"); + op, "weight zero point must be zero for non-int8 integer types"); if (weightZpVal != 0) { weight = CreateOpAndInferShape( diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 1d8aaa65c2976..345616c9563b5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -58,6 +58,8 @@ void 
ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) { template <> void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) { addValue(op.getInput()); + addValue(op.getInputZp()); + addValue(op.getOutputZp()); addType(op.getAccType()); addValue(op.getOutput()); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir index afc1d5c609181..08b147ac1dc1b 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics // CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type -func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> { +func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform> { // expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}} - %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform> return %0 : tensor<1x7x7x9x!quant.uniform> } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index 02d2f16b74ef8..d4afc468eeea4 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -290,7 +290,9 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) // CHECK: %[[FLT:.+]] = 
arith.sitofp %[[CAST]] // CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]] // CHECK: linalg.yield %[[DIV]] - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xf32>) -> tensor<1x5x33x62xf32> + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x33x62xf32> return %0 : tensor<1x5x33x62xf32> } @@ -375,7 +377,9 @@ func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x6 // CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]] // CHECK: %[[TRUNC:.+]] = arith.truncf %[[DIV]] // CHECK: linalg.yield %[[TRUNC]] - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16> + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x33x62xf16> return %0 : tensor<1x5x33x62xf16> } @@ -416,7 +420,9 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) { // CHECK: %[[CLAMP:.+]] = arith.minsi %[[CMAX]], %[[LOW]] // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]] // CHECK: linalg.yield %[[TRUNC]] - %0 = tosa.avg_pool2d %arg0 {acc_type = i32, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xi8>) -> tensor<1x5x33x62xi8> + %input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> 
tensor<1xi8> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x33x62xi8> return %0 : tensor<1x5x33x62xi8> } @@ -439,7 +445,9 @@ func.func @avg_pool_dyn(%arg0: tensor) -> (tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor // CHECK: %[[GENERIC:.+]] = linalg.generic - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array, kernel = array, stride = array} : (tensor) -> tensor + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array, kernel = array, stride = array} : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor return %0 : tensor } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir index 73da2810abe04..ecd5c792e08b6 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir @@ -23,18 +23,18 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> { // ----- // check that tosa verify kick in -func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> { +func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}} - %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} - : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array, pad = array, 
stride = array} + : (tensor<1x0x?x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } // ----- // check that --tosa-to-linalg kick in -func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> { +func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform> { // expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}} - %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform> return %0 : tensor<1x7x7x9x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 98290c7b9eedd..203a6717337a6 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -19,7 +19,9 @@ func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> { func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { // CHECK: profiles: [ [pro_int, pro_fp] ] // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ] - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index dc556f7486774..e665510ff0143 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -56,7 +56,7 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: func.func @test_conv2d_input_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> { %input_zp = "tosa.const"() <{value = dense<-1.0> : tensor<1xf16>}> : () -> tensor<1xf16> %weight_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> - // expected-error@+1 {{'tosa.conv2d' op non-zero zero point is not allowed for float types}} + // expected-error@+1 {{'tosa.conv2d' op input zero point must be zero for non-int8 integer types}} %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x27x27x16xf16> return %0 : tensor<1x27x27x16xf16> @@ -67,7 +67,7 @@ func.func @test_conv2d_input_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3 func.func @test_conv2d_weight_zp(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> { %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> %weight_zp = "tosa.const"() <{value = dense<-1.0> : tensor<1xf16>}> : () -> tensor<1xf16> - // expected-error@+1 {{'tosa.conv2d' op non-zero zero point is not allowed for float types}} + // expected-error@+1 {{'tosa.conv2d' op weight zero point must be zero for non-int8 integer types}} %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x27x27x16xf16> return %0 : 
tensor<1x27x27x16xf16> @@ -567,19 +567,19 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor< // ----- -func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> { +func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x7x9xf32>'}} - %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} - : (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array, pad = array, stride = array} + : (tensor<1x0x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } // ----- -func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> { +func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}} - %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} - : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array, pad = array, stride = array} + : (tensor<1x0x?x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } @@ -1271,6 +1271,50 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens // ----- +// CHECK-LABEL: test_avg_pool_input_zp_same_element_type +func.func @test_avg_pool_input_zp_same_element_type(%arg0: tensor<1x16x16x8xf16>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xf16> { + // expected-error@+1 {{'tosa.avg_pool2d' op expect both input and its 
zero point are the same element type, got 'f16' and 'i8'}} + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array, pad = array, stride = array} + : (tensor<1x16x16x8xf16>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xf16> + return %0 : tensor<1x16x16x8xf16> +} + +// ----- + +// CHECK-LABEL: test_avg_pool_output_zp_same_element_type +func.func @test_avg_pool_output_zp_same_element_type(%arg0: tensor<1x16x16x8xi8>, %arg1: tensor<1xi8>, %arg2: tensor<1xf16>) -> tensor<1x16x16x8xi8> { + // expected-error@+1 {{'tosa.avg_pool2d' op expect both output and its zero point are the same element type, got 'i8' and 'f16'}} + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array, pad = array, stride = array} + : (tensor<1x16x16x8xi8>, tensor<1xi8>, tensor<1xf16>) -> tensor<1x16x16x8xi8> + return %0 : tensor<1x16x16x8xi8> +} + +// ----- + +// CHECK-LABEL: test_avg_pool_input_zp_non_zero +func.func @test_avg_pool_input_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> { + %input_zp = "tosa.const"() {value = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32> + %output_zp = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.avg_pool2d' op input zero point must be zero for non-int8 integer types}} + %0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {acc_type = f32, kernel = array, pad = array, stride = array} + : (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32> + return %0 : tensor<1x16x16x8xf32> +} + +// ----- + +// CHECK-LABEL: test_avg_pool_output_zp_non_zero +func.func @test_avg_pool_output_zp_non_zero(%arg0: tensor<1x16x16x8xf32>) -> tensor<1x16x16x8xf32> { + %input_zp = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %output_zp = "tosa.const"() {value = dense<-1.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.avg_pool2d' op output zero point must be zero for non-int8 integer 
types}} + %0 = "tosa.avg_pool2d"(%arg0, %input_zp, %output_zp) {acc_type = f32, kernel = array, pad = array, stride = array} + : (tensor<1x16x16x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x16x16x8xf32> + return %0 : tensor<1x16x16x8xf32> +} + +// ----- + func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) { // expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}} %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index c136e8aac9606..e856deda2ab10 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -506,74 +506,74 @@ func.func @test_identity_rank_valid(%arg0: tensor) -> tensor { // ----- -func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { +func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}} - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : - (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } // ----- -func.func @test_avgpool2d_kernel_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { +func.func @test_avgpool2d_kernel_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}} - %0 = 
"tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : - (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } // ----- -func.func @test_avgpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { +func.func @test_avgpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : - (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } // ----- -func.func @test_avgpool2d_stride_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { +func.func @test_avgpool2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : - (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } // ----- -func.func @test_avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { +func.func @test_avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> 
tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : - (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } // ----- -func.func @test_avgpool2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { +func.func @test_avgpool2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : - (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } // ----- -func.func @test_avgpool2d_pad_left(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { +func.func @test_avgpool2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : - (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } // ----- -func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { +func.func 
@test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x32x32x8xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : - (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x8xf32> return %0 : tensor<1x32x32x8xf32> } @@ -1074,9 +1074,9 @@ func.func @test_resize_tensor_size_invalid(%arg0: tensor<1x23178x23178x1xf32>) { // ----- -func.func @test_avg_pool2d_tensor_size_invalid(%arg0: tensor<1x23178x23178x9xf32>) -> tensor<1x23178x23178x9xf32> { +func.func @test_avg_pool2d_tensor_size_invalid(%arg0: tensor<1x23178x23178x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x23178x23178x9xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x23178x23178x9xf32>) -> tensor<1x23178x23178x9xf32> + %0 = tosa.avg_pool2d %arg0, %arg1, %arg2 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x23178x23178x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x23178x23178x9xf32> return %0 : tensor<1x23178x23178x9xf32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 45a87b97125f7..81ff9ad16c713 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -12,42 +12,54 @@ func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> { // ----- // CHECK-LABEL: avg_pool2d_f32 func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : 
(tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } // ----- // CHECK-LABEL: avg_pool2d_f16 func.func @test_avg_pool2d_f16(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> { - %0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f16, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x7x7x9xf16> return %0 : tensor<1x7x7x9xf16> } // ----- // CHECK-LABEL: avg_pool2d_f16_accumf32 func.func @test_avg_pool2d_f16_accumf32(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> { - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x7x7x9xf16> return %0 : tensor<1x7x7x9xf16> } // ----- // CHECK-LABEL: avg_pool2d_i8 func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { - %0 = tosa.avg_pool2d %arg0 {acc_type = i32, kernel = array, pad = array, 
stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> + %input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9xi8> return %0 : tensor<1x7x7x9xi8> } // ----- // CHECK-LABEL: avg_pool2d_i16 func.func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> { - %0 = tosa.avg_pool2d %arg0 {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> + %input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16> + %output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<1x7x7x9xi16> return %0 : tensor<1x7x7x9xi16> } // ----- // CHECK-LABEL: avg_pool2d_q8 func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> { - %0 = tosa.avg_pool2d %arg0 {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> + %input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform> return %0 : tensor<1x7x7x9x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index 
4e4cb63f1d123..342c57b0dd85c 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -19,9 +19,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % } // ----- -func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { +func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op illegal: requires [pro_fp] but not enabled in target}} - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + %0 = tosa.avg_pool2d %arg0, %arg1, %arg2 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index c11977854e5d7..3dd0344e3647d 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -12,9 +12,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % } // ----- -func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { +func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> { // expected-error@+1 {{'tosa.avg_pool2d' op illegal: requires [pro_fp] but not enabled in target}} - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + %0 = tosa.avg_pool2d %arg0, %arg1, %arg2 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } diff --git 
a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 77d77ba957621..549ffab5db048 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -669,8 +669,11 @@ func.func @scatter_minimum_static(%arg0 : tensor, %arg1 : tensor<3x?x // CHECK-LABEL: @test_pool_static func.func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) { + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + // CHECK: -> tensor<3x2x4x7xf32> - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor // CHECK: -> tensor<3x2x4x7xf32> %1 = tosa.max_pool2d %arg0 {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor @@ -699,8 +702,11 @@ func.func @conv2d_dynamic_input(%input: tensor, %weights: tensor<5x // CHECK-LABEL: @test_pool_dynamic_input func.func @test_pool_dynamic_input(%arg0: tensor) { + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + // CHECK: -> tensor - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor) -> tensor + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor, tensor<1xf32>, tensor<1xf32>) -> tensor // CHECK: -> tensor %1 = tosa.max_pool2d %arg0 {kernel = array, pad = array, stride = array} : (tensor) -> tensor @@ -711,8 +717,11 @@ func.func @test_pool_dynamic_input(%arg0: tensor) { // CHECK-LABEL: @test_pool_padded func.func 
@test_pool_padded(%arg0: tensor<3x5x6x7xf32>) { + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + // CHECK: -> tensor<3x5x11x7xf32> - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor // CHECK: -> tensor<3x5x11x7xf32> %1 = tosa.max_pool2d %arg0 {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor @@ -741,8 +750,11 @@ func.func @conv2d_dynamic_bias(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3 // CHECK-LABEL: @test_pool_stride func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) { + %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + // CHECK: -> tensor<3x4x4x7xf32> - %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor + %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor // CHECK: -> tensor<3x4x4x7xf32> %1 = tosa.max_pool2d %arg0 {kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor