Skip to content

Commit

Permalink
[mlir][tosa] Switch zero point of avgpool2d to input variable type (l…
Browse files Browse the repository at this point in the history
…lvm#128983)

This commit changes the TOSA operator AvgPool2d's zero point attributes
to inputs to align with TOSA 1.0 spec.

Signed-off-by: Luke Hutton <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
  • Loading branch information
Tai78641 and lhutton1 authored Mar 4, 2025
1 parent 17bfc00 commit 25a29ce
Show file tree
Hide file tree
Showing 18 changed files with 355 additions and 204 deletions.
14 changes: 8 additions & 6 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ profileComplianceMap = {
{{{Profile::pro_int}, {{i8T, i32T}}},
{{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
{"tosa.avg_pool2d",
{{{Profile::pro_int}, {{i8T, i32T, i8T}}},
{{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
{{Profile::pro_fp},
{{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
{{fp16T, fp16T, fp16T, fp16T, fp16T},
{fp16T, fp16T, fp16T, fp32T, fp16T},
{fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.conv2d",
{{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
{{Profile::pro_fp},
Expand Down Expand Up @@ -243,10 +245,10 @@ extensionComplianceMap = {
{{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
{{Extension::bf16}, {{bf16T, i32T}}}}},
{"tosa.avg_pool2d",
{{{Extension::int16}, {{i16T, i32T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
{{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
{{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
{"tosa.conv2d",
{{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
{{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
Expand Down
46 changes: 27 additions & 19 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$output_zp,
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<I32Attr>:$input_zp,
OptionalAttr<I32Attr>:$output_zp
TypeAttrOf<Tosa_AccType>:$acc_type
);

let results = (outs
Expand All @@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
];

let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];

let extraClassDeclaration = [{
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getOutputZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];

let hasVerifier = 1;
}

Expand All @@ -116,8 +124,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Expand All @@ -136,8 +144,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
Expand All @@ -161,8 +169,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,

Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Expand All @@ -181,8 +189,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
Expand All @@ -207,8 +215,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Expand All @@ -227,8 +235,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
Expand Down Expand Up @@ -412,8 +420,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_ScalarTensor:$input_zp,
Tosa_ScalarTensor:$weight_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$weight_zp,

Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Expand All @@ -431,8 +439,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
];

let extraClassDeclaration = [{
LogicalResult getInputZeroPoint(int64_t &zp);
LogicalResult getWeightZeroPoint(int64_t &zp);
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getWeightZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyWeightZeroPoint(int64_t zp);
}];
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;

def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>;

// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
Expand Down
72 changes: 50 additions & 22 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,18 +260,26 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "input zero point cannot be statically determined");

FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeWZp))
return rewriter.notifyMatchFailure(
op, "weight zero point cannot be statically determined");

if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
int64_t inputZpVal = *maybeIZp;
int64_t weightZpVal = *maybeWZp;

if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");
op, "input zero point must be zero for non-int8 integer types");

if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
if (op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
op, "weight zero point must be zero for non-int8 integer types");

bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

Expand Down Expand Up @@ -448,18 +456,26 @@ class DepthwiseConvConverter
/*kernelSizeDims=*/{0, 1}, rewriter);

// Get and verify zero points.
int64_t inputZpVal;
int64_t weightZpVal;

if (op.getInputZeroPoint(inputZpVal).failed() ||
op.getWeightZeroPoint(weightZpVal).failed())
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "input zero point cannot be statically determined");
if (failed(maybeWZp))
return rewriter.notifyMatchFailure(
op, "weight zero point cannot be statically determined");

int64_t inputZpVal = *maybeIZp;
int64_t weightZpVal = *maybeWZp;

if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
op, "bail out if zero points cannot statically be determined");
op, "input zero point must be zero for non-int8 integer types");

if (op.verifyInputZeroPoint(inputZpVal).failed() ||
op.verifyWeightZeroPoint(weightZpVal).failed())
if (op.verifyWeightZeroPoint(weightZpVal).failed())
return rewriter.notifyMatchFailure(
op, "zero point must be zero for non-int8 integer types");
op, "weight zero point must be zero for non-int8 integer types");

bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
auto weightShape = weightTy.getShape();
Expand Down Expand Up @@ -809,6 +825,18 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
return failure();
SmallVector<Value> dynamicDims = *dynamicDimsOr;

FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeIZp))
return rewriter.notifyMatchFailure(
op, "input zero point could not be statically determined");
if (failed(maybeOZp))
return rewriter.notifyMatchFailure(
op, "output zero point could not be statically determined");

int64_t inputZpVal = *maybeIZp;
int64_t outputZpVal = *maybeOZp;

// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
Expand Down Expand Up @@ -928,9 +956,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {

// If we have quantization information we need to apply an offset
// for the input zp value.
if (op.getInputZp()) {
auto inputZp =
rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
if (inputZpVal != 0) {
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(accETy, inputZpVal));
Value offset =
rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
poolVal =
Expand Down Expand Up @@ -982,9 +1010,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {

// If we have quantization information we need to apply output
// zeropoint.
if (op.getOutputZp()) {
auto outputZp =
rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
if (outputZpVal != 0) {
auto outputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
.getResult();
}
Expand Down
Loading

0 comments on commit 25a29ce

Please sign in to comment.