Skip to content
239 changes: 179 additions & 60 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,30 +326,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
patterns.onOp(
"QLinearConv", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
if ((binder.tensorOperands(operands, 8) &&
binder.tensorOperands(operands, 9)) ||
binder.tensorResultType(resultType))
return failure();
Value a = operands[0];
Value aScale = operands[1];
Value aZp = operands[2];
Value b = operands[3];
Value bScale = operands[4];
Value bZp = operands[5];
Value cScale = operands[6];
Value cZp = operands[7];
Value c = operands.size() == 9 ? operands[8] : nullptr;

auto check = [](Value v) {
auto vTy = cast<Torch::ValueTensorType>(v.getType());
return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
};
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
!check(cScale) || !check(cScale))
return rewriter.notifyMatchFailure(
binder.op, "not supported for non per-tensor quantization");
Value input = operands[0];
Value inputScale = operands[1];
Value inputZp = operands[2];
Value weight = operands[3];
Value weightScale = operands[4];
Value weightZp = operands[5];
Value outputScale = operands[6];
Value outputZp = operands[7];
Value bias = operands.size() == 9 ? operands[8] : nullptr;

auto extract = [&rewriter, &binder](Value v) {
auto vTy = cast<Torch::ValueTensorType>(v.getType());
Expand All @@ -361,36 +353,153 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
v);
};

aZp = extract(aZp);
bZp = extract(bZp);
cZp = extract(cZp);
aScale = extract(aScale);
bScale = extract(bScale);
cScale = extract(cScale);
inputZp = extract(inputZp);
outputZp = extract(outputZp);
inputScale = extract(inputScale);
outputScale = extract(outputScale);

auto make = [&rewriter, &binder](Value v, Value scale,
Value zp) -> Value {
auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
Value zp) -> Value {
auto ty = cast<Torch::ValueTensorType>(v.getType());
auto newTy = getQTorchTypeFromTorchIntType(ty);
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), newTy, v, scale, zp);
};

a = make(a, aScale, aZp);
b = make(b, bScale, bZp);
// The ONNX QLinearConv op allows per-channel quantization only for
// the weight tensor, and only along axis = 0.
bool isPerChannelQuantization = false;
auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
auto weightScaleTy =
dyn_cast<Torch::ValueTensorType>(weightScale.getType());
auto weightZpTy = dyn_cast<Torch::ValueTensorType>(weightZp.getType());
if (!weightTy || !weightScaleTy || !weightZpTy ||
!weightTy.hasSizes() || !weightScaleTy.hasSizes() ||
!weightZpTy.hasSizes())
return rewriter.notifyMatchFailure(
binder.op, "Expected weight, weight_scale, and weight_zero_point "
"arguments to have sizes");
ArrayRef<int64_t> weightShape(weightTy.getSizes());
SmallVector<int64_t> weightScaleShape(weightScaleTy.getSizes());
SmallVector<int64_t> weightZpShape(weightZpTy.getSizes());
if (weightScaleShape.size() == 0 ||
llvm::all_of(weightScaleShape, [](int64_t s) { return s == 1; })) {
weightZp = extract(weightZp);
weightScale = extract(weightScale);
weight = makePerTensor(weight, weightScale, weightZp);
} else if (weightScaleShape.size() == 1 &&
weightScaleShape[0] != Torch::kUnknownSize &&
weightScaleShape[0] == weightShape[0]) {
// Since the convolution operation in the downstream ("Linalg")
// pipeline does not support per-channel quantization, for this
// particular case we perform the convolution over the dequantized
// input and weight instead of relying on the downstream pipeline to
// handle it. This code can be removed and made similar to the other
// paths in this lowering once per-channel quantization support is
// added in the downstream pipeline.
isPerChannelQuantization = true;

auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
if (!inputTy || !inputTy.hasSizes())
return rewriter.notifyMatchFailure(
binder.op, "Expected input argument to have sizes");

auto cTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(),
rewriter.getIntegerType(32, /*issigned=*/true));
// Dequantizing the input
// input = input.to(dtype=torch.float32)
// input_dequant = (input - input_zero_point) * input_scale

// TODO(suderman): insert convolution operator.
llvm::SmallVector<Value> newOperands = {a, b};
if (c)
newOperands.push_back(c);
// Converting the input tensor to float32 type.
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value float32Type = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6));
Type f32InputType = rewriter.getType<Torch::ValueTensorType>(
inputTy.getSizes(), rewriter.getF32Type());
input = rewriter.create<Torch::AtenToDtypeOp>(
loc, f32InputType, input, float32Type,
/*non_blocking=*/cstFalse,
/*copy=*/cstFalse,
/*memory_format=*/none);

cTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(),
rewriter.getType<Torch::QInt32Type>());
Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(1.0));
input = rewriter.create<Torch::AtenSubScalarOp>(
loc, f32InputType, input, inputZp, cstOne);
input = rewriter.create<Torch::AtenMulScalarOp>(loc, f32InputType,
input, inputScale);

// Dequantizing the weight
// Shapes of the inputs are as follows:
// weight = (M x C/group x k1 x k2 x … x kn)
// weight_scale = (M)
// weight_zero_point = (M)
//
// We unsqueeze the weight_scale and weight_zero_point to match the
// rank of weight. After unsqueeze:
// weight_scale = (M, 1, 1, ..., 1)
// weight_zero_point = (M, 1, 1, ..., 1)
//
// Then, we compute the dequantized weight:
// weight = weight.to(dtype=torch.float32)
// weight_dequant = (weight - weight_zero_point) * weight_scale
int64_t diffRank = weightShape.size() - weightScaleShape.size();
for (int i = 1; i <= diffRank; i++) {
Value cstDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));

weightScaleShape.push_back(1);
Type weightScaleUnsqueezeType = weightScaleTy.getWithSizesAndDtype(
weightScaleShape, weightScaleTy.getOptionalDtype());
weightScale = rewriter.create<Torch::AtenUnsqueezeOp>(
loc, weightScaleUnsqueezeType, weightScale, cstDim);

weightZpShape.push_back(1);
Type weightZpUnsqueezeType = weightZpTy.getWithSizesAndDtype(
weightZpShape, weightZpTy.getOptionalDtype());
weightZp = rewriter.create<Torch::AtenUnsqueezeOp>(
loc, weightZpUnsqueezeType, weightZp, cstDim);
}

// Converting the weight tensor to float32 type.
Type f32WeightType = rewriter.getType<Torch::ValueTensorType>(
weightShape, rewriter.getF32Type());
weight = rewriter.create<Torch::AtenToDtypeOp>(
loc, f32WeightType, weight, float32Type,
/*non_blocking=*/cstFalse,
/*copy=*/cstFalse,
/*memory_format=*/none);

weight = rewriter.create<Torch::AtenSubTensorOp>(
loc, f32WeightType, weight, weightZp, cstOne);
weight = rewriter.create<Torch::AtenMulTensorOp>(loc, f32WeightType,
weight, weightScale);

// Converting the bias tensor to float32 type.
if (bias) {
auto biasTy = dyn_cast<Torch::ValueTensorType>(bias.getType());
if (!biasTy || !biasTy.hasSizes())
return rewriter.notifyMatchFailure(
binder.op, "Expected bias argument to have sizes");
Type f32BiasType = rewriter.getType<Torch::ValueTensorType>(
biasTy.getSizes(), rewriter.getF32Type());
bias = rewriter.create<Torch::AtenToDtypeOp>(
loc, f32BiasType, bias, float32Type,
/*non_blocking=*/cstFalse,
/*copy=*/cstFalse,
/*memory_format=*/none);
}

} else {
llvm_unreachable("Unidentified case for weight quantization for "
"Onnx.QLinearConv op");
}

if (!isPerChannelQuantization)
input = makePerTensor(input, inputScale, inputZp);

llvm::SmallVector<Value> newOperands = {input, weight};
if (bias)
newOperands.push_back(bias);

llvm::SmallVector<NamedAttribute> newAttributes;
newAttributes.push_back(
Expand All @@ -402,36 +511,46 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
newAttributes.push_back(namedAttr);
}

c = rewriter
.create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
newAttributes,
binder.op->getRegions().size())
.getResult(0);
Type convDtype =
isPerChannelQuantization
? cast<Type>(rewriter.getF32Type())
: cast<Type>(rewriter.getType<Torch::QInt32Type>());
auto outputTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(), convDtype);
Value output = rewriter
.create<Torch::OperatorOp>(
binder.getLoc(), outputTy, newOperands,
newAttributes, binder.op->getRegions().size())
.getResult(0);

if (!isPerChannelQuantization) {
Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
weightScale);
Value outZp = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), outputTy, output, outScale, outZp);
outputTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(), rewriter.getF32Type());

Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
bScale);
Value outZp = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), cTy, c, outScale, outZp);
cTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(), rewriter.getF32Type());
output = rewriter.create<Torch::AtenDequantizeSelfOp>(
binder.getLoc(), outputTy, output);
}

c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
c);
cTy = getQTorchTypeFromTorchIntType(resultType);
outputTy = getQTorchTypeFromTorchIntType(resultType);
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(64),
static_cast<int64_t>(
Torch::getScalarTypeForType(cTy.getDtype()))));
c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
Torch::getScalarTypeForType(outputTy.getDtype()))));

output = rewriter.create<Torch::AtenQuantizePerTensorOp>(
binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
c);
output);
return success();
});
patterns.onOp(
Expand Down
Loading
Loading