[ONNX] Add per channel quantization support for Onnx.QLinearConv op #3917

Merged on Mar 10, 2025 (11 commits)
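For context on what this lowering implements: ONNX's QLinearConv computes quantize(conv(dequantize(x), dequantize(w)) + bias), and per-channel quantization means the weight carries one (scale, zero_point) pair per output channel along axis 0. Below is a minimal standalone sketch of these semantics, not converter code; all names and values are illustrative, and a 1-D dot product stands in for the convolution.

#include <cstdint>
#include <cstdio>

// Dequantize one element: x_f = (x_q - x_zp) * x_scale.
float dequant(uint8_t q, float scale, int32_t zp) {
  return (static_cast<int32_t>(q) - zp) * scale;
}

int main() {
  // Hypothetical per-tensor-quantized input.
  uint8_t input[3] = {10, 12, 14};
  float inputScale = 0.5f;
  int32_t inputZp = 8;

  // weight[m][k]: one (scale, zero_point) pair per output channel m
  // (quantization axis 0), M = 2 channels.
  uint8_t weight[2][3] = {{130, 126, 128}, {129, 128, 127}};
  float weightScale[2] = {0.1f, 0.2f};
  int32_t weightZp[2] = {128, 128};

  float outputScale = 0.25f;
  int32_t outputZp = 0;

  for (int m = 0; m < 2; ++m) {
    // Float "convolution" over dequantized operands.
    float acc = 0.0f;
    for (int k = 0; k < 3; ++k)
      acc += dequant(input[k], inputScale, inputZp) *
             dequant(weight[m][k], weightScale[m], weightZp[m]);
    // Requantize into the output's quantized domain.
    int32_t q = static_cast<int32_t>(acc / outputScale) + outputZp;
    std::printf("channel %d: float %f -> quantized %d\n", m, acc, q);
  }
  return 0;
}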
239 changes: 179 additions & 60 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -326,30 +326,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
patterns.onOp(
"QLinearConv", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
+ Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
if ((binder.tensorOperands(operands, 8) &&
binder.tensorOperands(operands, 9)) ||
binder.tensorResultType(resultType))
return failure();
- Value a = operands[0];
- Value aScale = operands[1];
- Value aZp = operands[2];
- Value b = operands[3];
- Value bScale = operands[4];
- Value bZp = operands[5];
- Value cScale = operands[6];
- Value cZp = operands[7];
- Value c = operands.size() == 9 ? operands[8] : nullptr;

- auto check = [](Value v) {
-   auto vTy = cast<Torch::ValueTensorType>(v.getType());
-   return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
- };
- if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
-     !check(cScale) || !check(cScale))
-   return rewriter.notifyMatchFailure(
-       binder.op, "not supported for non per-tensor quantization");
+ Value input = operands[0];
+ Value inputScale = operands[1];
+ Value inputZp = operands[2];
+ Value weight = operands[3];
+ Value weightScale = operands[4];
+ Value weightZp = operands[5];
+ Value outputScale = operands[6];
+ Value outputZp = operands[7];
+ Value bias = operands.size() == 9 ? operands[8] : nullptr;

auto extract = [&rewriter, &binder](Value v) {
auto vTy = cast<Torch::ValueTensorType>(v.getType());
@@ -361,36 +353,153 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
v);
};

- aZp = extract(aZp);
- bZp = extract(bZp);
- cZp = extract(cZp);
- aScale = extract(aScale);
- bScale = extract(bScale);
- cScale = extract(cScale);
+ inputZp = extract(inputZp);
+ outputZp = extract(outputZp);
+ inputScale = extract(inputScale);
+ outputScale = extract(outputScale);

- auto make = [&rewriter, &binder](Value v, Value scale,
-                                  Value zp) -> Value {
+ auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
+                                           Value zp) -> Value {
auto ty = cast<Torch::ValueTensorType>(v.getType());
auto newTy = getQTorchTypeFromTorchIntType(ty);
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), newTy, v, scale, zp);
};

- a = make(a, aScale, aZp);
- b = make(b, bScale, bZp);
+ // ONNX's QLinearConv op allows per-channel quantization only for
+ // the weight tensor, with axis = 0.
+ bool isPerChannelQuantization = false;
+ auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
+ auto weightScaleTy =
+     dyn_cast<Torch::ValueTensorType>(weightScale.getType());
+ auto weightZpTy = dyn_cast<Torch::ValueTensorType>(weightZp.getType());
+ if (!weightTy || !weightScaleTy || !weightZpTy ||
+     !weightTy.hasSizes() || !weightScaleTy.hasSizes() ||
+     !weightZpTy.hasSizes())
+   return rewriter.notifyMatchFailure(
+       binder.op, "Expected weight, weight_scale, and weight_zero_point "
+                  "arguments to have sizes");
+ ArrayRef<int64_t> weightShape(weightTy.getSizes());
+ SmallVector<int64_t> weightScaleShape(weightScaleTy.getSizes());
+ SmallVector<int64_t> weightZpShape(weightZpTy.getSizes());
+ if (weightScaleShape.size() == 0 ||
+     llvm::all_of(weightScaleShape, [](int64_t s) { return s == 1; })) {
+   weightZp = extract(weightZp);
+   weightScale = extract(weightScale);
+   weight = makePerTensor(weight, weightScale, weightZp);
+ } else if (weightScaleShape.size() == 1 &&
+            weightScaleShape[0] != Torch::kUnknownSize &&
+            weightScaleShape[0] == weightShape[0]) {
+   // The convolution operation in the downstream ("Linalg") pipeline
+   // does not support per-channel quantization, so for this particular
+   // case we perform the convolution over the dequantized input and
+   // weight instead of relying on the downstream pipeline to handle it.
+   // This code can be removed and made similar to the other paths in
+   // this lowering once per-channel quantization support is added to
+   // the downstream pipeline.
+   isPerChannelQuantization = true;
+
+   auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
+   if (!inputTy || !inputTy.hasSizes())
+     return rewriter.notifyMatchFailure(
+         binder.op, "Expected input argument to have sizes");

- auto cTy = rewriter.getType<Torch::ValueTensorType>(
-     resultType.getOptionalSizes(),
-     rewriter.getIntegerType(32, /*issigned=*/true));
+ // Dequantizing the input:
+ //   input = input.to(dtype=torch.float32)
+ //   input_dequant = (input - input_zero_point) * input_scale

- // TODO(suderman): insert convolution operator.
- llvm::SmallVector<Value> newOperands = {a, b};
- if (c)
-   newOperands.push_back(c);
+ // Converting the input tensor to float32 type.
+ Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+ Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+ Value float32Type = rewriter.create<Torch::ConstantIntOp>(
+     loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+ Type f32InputType = rewriter.getType<Torch::ValueTensorType>(
+     inputTy.getSizes(), rewriter.getF32Type());
+ input = rewriter.create<Torch::AtenToDtypeOp>(
+     loc, f32InputType, input, float32Type,
+     /*non_blocking=*/cstFalse,
+     /*copy=*/cstFalse,
+     /*memory_format=*/none);

- cTy = rewriter.getType<Torch::ValueTensorType>(
-     resultType.getOptionalSizes(),
-     rewriter.getType<Torch::QInt32Type>());
+ Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
+     loc, rewriter.getF64FloatAttr(1.0));
+ input = rewriter.create<Torch::AtenSubScalarOp>(
+     loc, f32InputType, input, inputZp, cstOne);
+ input = rewriter.create<Torch::AtenMulScalarOp>(loc, f32InputType,
+                                                 input, inputScale);

+ // Dequantizing the weight.
+ // Shapes of the inputs are as follows:
+ //   weight = (M x C/group x k1 x k2 x … x kn)
+ //   weight_scale = (M)
+ //   weight_zero_point = (M)
+ //
+ // We unsqueeze the weight_scale and weight_zero_point to match the
+ // rank of weight. After unsqueeze:
+ //   weight_scale = (M, 1, 1, ..., 1)
+ //   weight_zero_point = (M, 1, 1, ..., 1)
+ //
+ // Then, we compute the dequantized weight:
+ //   weight = weight.to(dtype=torch.float32)
+ //   weight_dequant = (weight - weight_zero_point) * weight_scale
+ int64_t diffRank = weightShape.size() - weightScaleShape.size();
+ for (int i = 1; i <= diffRank; i++) {
+   Value cstDim = rewriter.create<Torch::ConstantIntOp>(
+       loc, rewriter.getI64IntegerAttr(i));
+
+   weightScaleShape.push_back(1);
+   Type weightScaleUnsqueezeType = weightScaleTy.getWithSizesAndDtype(
+       weightScaleShape, weightScaleTy.getOptionalDtype());
+   weightScale = rewriter.create<Torch::AtenUnsqueezeOp>(
+       loc, weightScaleUnsqueezeType, weightScale, cstDim);
+
+   weightZpShape.push_back(1);
+   Type weightZpUnsqueezeType = weightZpTy.getWithSizesAndDtype(
+       weightZpShape, weightZpTy.getOptionalDtype());
+   weightZp = rewriter.create<Torch::AtenUnsqueezeOp>(
+       loc, weightZpUnsqueezeType, weightZp, cstDim);
+ }
+
+ // Converting the weight tensor to float32 type.
+ Type f32WeightType = rewriter.getType<Torch::ValueTensorType>(
+     weightShape, rewriter.getF32Type());
+ weight = rewriter.create<Torch::AtenToDtypeOp>(
+     loc, f32WeightType, weight, float32Type,
+     /*non_blocking=*/cstFalse,
+     /*copy=*/cstFalse,
+     /*memory_format=*/none);
+
+ weight = rewriter.create<Torch::AtenSubTensorOp>(
+     loc, f32WeightType, weight, weightZp, cstOne);
+ weight = rewriter.create<Torch::AtenMulTensorOp>(loc, f32WeightType,
+                                                  weight, weightScale);
+
+ // Converting the bias tensor to float32 type.
+ if (bias) {
+   auto biasTy = dyn_cast<Torch::ValueTensorType>(bias.getType());
+   if (!biasTy || !biasTy.hasSizes())
+     return rewriter.notifyMatchFailure(
+         binder.op, "Expected bias argument to have sizes");
+   Type f32BiasType = rewriter.getType<Torch::ValueTensorType>(
+       biasTy.getSizes(), rewriter.getF32Type());
+   bias = rewriter.create<Torch::AtenToDtypeOp>(
+       loc, f32BiasType, bias, float32Type,
+       /*non_blocking=*/cstFalse,
+       /*copy=*/cstFalse,
+       /*memory_format=*/none);
+ }
+
+ } else {
+   llvm_unreachable("Unidentified case for weight quantization for "
+                    "Onnx.QLinearConv op");
+ }
+
+ if (!isPerChannelQuantization)
+   input = makePerTensor(input, inputScale, inputZp);
+
+ llvm::SmallVector<Value> newOperands = {input, weight};
+ if (bias)
+   newOperands.push_back(bias);

llvm::SmallVector<NamedAttribute> newAttributes;
newAttributes.push_back(
@@ -402,36 +511,46 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
newAttributes.push_back(namedAttr);
}

- c = rewriter
-     .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
-                                newAttributes,
-                                binder.op->getRegions().size())
-     .getResult(0);
+ Type convDtype =
+     isPerChannelQuantization
+         ? cast<Type>(rewriter.getF32Type())
+         : cast<Type>(rewriter.getType<Torch::QInt32Type>());
+ auto outputTy = rewriter.getType<Torch::ValueTensorType>(
+     resultType.getOptionalSizes(), convDtype);
+ Value output = rewriter
+                    .create<Torch::OperatorOp>(
+                        binder.getLoc(), outputTy, newOperands,
+                        newAttributes, binder.op->getRegions().size())
+                    .getResult(0);

+ if (!isPerChannelQuantization) {
+   Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
+       binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
+       weightScale);
+   Value outZp = rewriter.create<Torch::ConstantIntOp>(
+       binder.getLoc(), rewriter.getType<Torch::IntType>(),
+       rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+   output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+       binder.getLoc(), outputTy, output, outScale, outZp);
+   outputTy = rewriter.getType<Torch::ValueTensorType>(
+       resultType.getOptionalSizes(), rewriter.getF32Type());

- Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
-     binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
-     bScale);
- Value outZp = rewriter.create<Torch::ConstantIntOp>(
-     binder.getLoc(), rewriter.getType<Torch::IntType>(),
-     rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
- c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-     binder.getLoc(), cTy, c, outScale, outZp);
- cTy = rewriter.getType<Torch::ValueTensorType>(
-     resultType.getOptionalSizes(), rewriter.getF32Type());
+   output = rewriter.create<Torch::AtenDequantizeSelfOp>(
+       binder.getLoc(), outputTy, output);
+ }

- c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
-                                                  c);
- cTy = getQTorchTypeFromTorchIntType(resultType);
+ outputTy = getQTorchTypeFromTorchIntType(resultType);
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(64),
static_cast<int64_t>(
-         Torch::getScalarTypeForType(cTy.getDtype()))));
- c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
-     binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+         Torch::getScalarTypeForType(outputTy.getDtype()))));
+
+ output = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+     binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
- c);
+ output);
return success();
});
patterns.onOp(
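A note on the unsqueeze step in the per-channel path: weight_scale and weight_zero_point arrive with shape (M) and are unsqueezed to (M, 1, ..., 1) so that the elementwise AtenSubTensorOp/AtenMulTensorOp broadcast each channel's values across that channel's slice of the weight. A small hand-rolled equivalent, with hypothetical shapes and values (broadcasting a (M, 1, 1) scale is the same as indexing by the leading output-channel dimension only):

#include <cstdint>
#include <cstdio>

int main() {
  // Weight of shape (M=2, C=2, K=2); per-channel scale/zero-point of
  // shape (M), quantization axis 0.
  const int M = 2, C = 2, K = 2;
  int8_t weight[M][C][K] = {{{4, -2}, {0, 6}}, {{-8, 2}, {10, -4}}};
  float scale[M] = {0.5f, 0.25f};
  int8_t zp[M] = {0, 2};

  float dequant[M][C][K];
  for (int m = 0; m < M; ++m)
    for (int c = 0; c < C; ++c)
      for (int k = 0; k < K; ++k)
        // (weight - weight_zero_point) * weight_scale, with zp and scale
        // broadcast from shape (M, 1, 1): only the m index matters.
        dequant[m][c][k] = (weight[m][c][k] - zp[m]) * scale[m];

  for (int m = 0; m < M; ++m)
    for (int c = 0; c < C; ++c)
      for (int k = 0; k < K; ++k)
        std::printf("dequant[%d][%d][%d] = %.2f\n", m, c, k,
                    dequant[m][c][k]);
  return 0;
}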
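And on the per-tensor path: the lowering tags the int32 convolution accumulator with scale inputScale * weightScale and zero point 0 (the outScale/outZp values in the diff) before dequantizing. The identity it relies on, checked on made-up scalar values with zero points already subtracted out:

#include <cstdint>
#include <cstdio>

int main() {
  // x_f = xq * sx and w_f = wq * sw, so the integer product xq * wq
  // represents x_f * w_f at scale sx * sw.
  int32_t xq = 20;  // x_f = 20 * 0.5  = 10.0
  int32_t wq = -6;  // w_f = -6 * 0.25 = -1.5
  float sx = 0.5f, sw = 0.25f;

  int32_t acc = xq * wq;               // int32 accumulator: -120
  float viaAccScale = acc * (sx * sw); // dequantize once at sx * sw
  float direct = (xq * sx) * (wq * sw);

  // Both print -15.000000.
  std::printf("acc=%d, acc*(sx*sw)=%f, direct=%f\n", acc, viaAccScale,
              direct);
  return 0;
}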