diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index d3251c589ac8..39613830f72c 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -326,30 +326,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
   patterns.onOp(
       "QLinearConv", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
         Torch::ValueTensorType resultType;
         llvm::SmallVector<Value> operands;
         if ((binder.tensorOperands(operands, 8) &&
              binder.tensorOperands(operands, 9)) ||
             binder.tensorResultType(resultType))
           return failure();
-        Value a = operands[0];
-        Value aScale = operands[1];
-        Value aZp = operands[2];
-        Value b = operands[3];
-        Value bScale = operands[4];
-        Value bZp = operands[5];
-        Value cScale = operands[6];
-        Value cZp = operands[7];
-        Value c = operands.size() == 9 ? operands[8] : nullptr;
-
-        auto check = [](Value v) {
-          auto vTy = cast<Torch::ValueTensorType>(v.getType());
-          return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
-        };
-        if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
-            !check(cScale) || !check(cScale))
-          return rewriter.notifyMatchFailure(
-              binder.op, "not supported for non per-tensor quantization");
+        Value input = operands[0];
+        Value inputScale = operands[1];
+        Value inputZp = operands[2];
+        Value weight = operands[3];
+        Value weightScale = operands[4];
+        Value weightZp = operands[5];
+        Value outputScale = operands[6];
+        Value outputZp = operands[7];
+        Value bias = operands.size() == 9 ? operands[8] : nullptr;
 
         auto extract = [&rewriter, &binder](Value v) {
           auto vTy = cast<Torch::ValueTensorType>(v.getType());
@@ -361,36 +353,153 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               v);
         };
 
-        aZp = extract(aZp);
-        bZp = extract(bZp);
-        cZp = extract(cZp);
-        aScale = extract(aScale);
-        bScale = extract(bScale);
-        cScale = extract(cScale);
+        inputZp = extract(inputZp);
+        outputZp = extract(outputZp);
+        inputScale = extract(inputScale);
+        outputScale = extract(outputScale);
 
-        auto make = [&rewriter, &binder](Value v, Value scale,
-                                         Value zp) -> Value {
+        auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
+                                                  Value zp) -> Value {
           auto ty = cast<Torch::ValueTensorType>(v.getType());
           auto newTy = getQTorchTypeFromTorchIntType(ty);
           return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
               binder.getLoc(), newTy, v, scale, zp);
         };
 
-        a = make(a, aScale, aZp);
-        b = make(b, bScale, bZp);
+        // The ONNX QLinearConv op allows per-channel quantization only for
+        // the weight tensor, with axis = 0.
+        bool isPerChannelQuantization = false;
+        auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
+        auto weightScaleTy =
+            dyn_cast<Torch::ValueTensorType>(weightScale.getType());
+        auto weightZpTy = dyn_cast<Torch::ValueTensorType>(weightZp.getType());
+        if (!weightTy || !weightScaleTy || !weightZpTy ||
+            !weightTy.hasSizes() || !weightScaleTy.hasSizes() ||
+            !weightZpTy.hasSizes())
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected weight, weight_scale, and weight_zero_point "
+                         "arguments to have sizes");
+        ArrayRef<int64_t> weightShape(weightTy.getSizes());
+        SmallVector<int64_t> weightScaleShape(weightScaleTy.getSizes());
+        SmallVector<int64_t> weightZpShape(weightZpTy.getSizes());
+        if (weightScaleShape.size() == 0 ||
+            llvm::all_of(weightScaleShape, [](int64_t s) { return s == 1; })) {
+          weightZp = extract(weightZp);
+          weightScale = extract(weightScale);
+          weight = makePerTensor(weight, weightScale, weightZp);
+        } else if (weightScaleShape.size() == 1 &&
+                   weightScaleShape[0] != Torch::kUnknownSize &&
+                   weightScaleShape[0] == weightShape[0]) {
+          // Since the convolution operation in the downstream ("Linalg")
+          // pipeline does not support per-channel quantization, for this
+          // particular case we perform the convolution over the dequantized
+          // input and weight instead of relying on the downstream pipeline
+          // to handle it. This code can be removed and made similar to the
+          // other paths in this lowering once per-channel quantization
+          // support is added to the downstream pipeline.
+          isPerChannelQuantization = true;
+
+          auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
+          if (!inputTy || !inputTy.hasSizes())
+            return rewriter.notifyMatchFailure(
+                binder.op, "Expected input argument to have sizes");
 
-        auto cTy = rewriter.getType<Torch::ValueTensorType>(
-            resultType.getOptionalSizes(),
-            rewriter.getIntegerType(32, /*issigned=*/true));
+          // Dequantizing the input
+          // input = input.to(dtype=torch.float32)
+          // input_dequant = (input - input_zero_point) * input_scale
 
-        // TODO(suderman): insert convolution operator.
-        llvm::SmallVector<Value> newOperands = {a, b};
-        if (c)
-          newOperands.push_back(c);
+          // Converting the input tensor to float32 type.
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+          Value float32Type = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+          Type f32InputType = rewriter.getType<Torch::ValueTensorType>(
+              inputTy.getSizes(), rewriter.getF32Type());
+          input = rewriter.create<Torch::AtenToDtypeOp>(
+              loc, f32InputType, input, float32Type,
+              /*non_blocking=*/cstFalse,
+              /*copy=*/cstFalse,
+              /*memory_format=*/none);
 
-        cTy = rewriter.getType<Torch::ValueTensorType>(
-            resultType.getOptionalSizes(),
-            rewriter.getType<Torch::QInt32Type>());
+          Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
+          input = rewriter.create<Torch::AtenSubScalarOp>(
+              loc, f32InputType, input, inputZp, cstOne);
+          input = rewriter.create<Torch::AtenMulScalarOp>(loc, f32InputType,
+                                                          input, inputScale);
+
+          // Dequantizing the weight
+          // Shapes of the inputs are as follows:
+          //   weight = (M x C/group x k1 x k2 x … x kn)
+          //   weight_scale = (M)
+          //   weight_zero_point = (M)
+          //
+          // We unsqueeze the weight_scale and weight_zero_point to match the
+          // rank of weight. After unsqueeze:
+          //   weight_scale = (M, 1, 1, ..., 1)
+          //   weight_zero_point = (M, 1, 1, ..., 1)
+          //
+          // Then, we compute the dequantized weight:
+          //   weight = weight.to(dtype=torch.float32)
+          //   weight_dequant = (weight - weight_zero_point) * weight_scale
+          int64_t diffRank = weightShape.size() - weightScaleShape.size();
+          for (int i = 1; i <= diffRank; i++) {
+            Value cstDim = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(i));
+
+            weightScaleShape.push_back(1);
+            Type weightScaleUnsqueezeType = weightScaleTy.getWithSizesAndDtype(
+                weightScaleShape, weightScaleTy.getOptionalDtype());
+            weightScale = rewriter.create<Torch::AtenUnsqueezeOp>(
+                loc, weightScaleUnsqueezeType, weightScale, cstDim);
+
+            weightZpShape.push_back(1);
+            Type weightZpUnsqueezeType = weightZpTy.getWithSizesAndDtype(
+                weightZpShape, weightZpTy.getOptionalDtype());
+            weightZp = rewriter.create<Torch::AtenUnsqueezeOp>(
+                loc, weightZpUnsqueezeType, weightZp, cstDim);
+          }
+
+          // Converting the weight tensor to float32 type.
+          Type f32WeightType = rewriter.getType<Torch::ValueTensorType>(
+              weightShape, rewriter.getF32Type());
+          weight = rewriter.create<Torch::AtenToDtypeOp>(
+              loc, f32WeightType, weight, float32Type,
+              /*non_blocking=*/cstFalse,
+              /*copy=*/cstFalse,
+              /*memory_format=*/none);
+
+          weight = rewriter.create<Torch::AtenSubTensorOp>(
+              loc, f32WeightType, weight, weightZp, cstOne);
+          weight = rewriter.create<Torch::AtenMulTensorOp>(loc, f32WeightType,
+                                                           weight, weightScale);
+
+          // Converting the bias tensor to float32 type.
+          if (bias) {
+            auto biasTy = dyn_cast<Torch::ValueTensorType>(bias.getType());
+            if (!biasTy || !biasTy.hasSizes())
+              return rewriter.notifyMatchFailure(
+                  binder.op, "Expected bias argument to have sizes");
+            Type f32BiasType = rewriter.getType<Torch::ValueTensorType>(
+                biasTy.getSizes(), rewriter.getF32Type());
+            bias = rewriter.create<Torch::AtenToDtypeOp>(
+                loc, f32BiasType, bias, float32Type,
+                /*non_blocking=*/cstFalse,
+                /*copy=*/cstFalse,
+                /*memory_format=*/none);
+          }
+
+        } else {
+          llvm_unreachable("Unidentified case for weight quantization for "
+                           "Onnx.QLinearConv op");
+        }
+
+        if (!isPerChannelQuantization)
+          input = makePerTensor(input, inputScale, inputZp);
+
+        llvm::SmallVector<Value> newOperands = {input, weight};
+        if (bias)
+          newOperands.push_back(bias);
 
         llvm::SmallVector<NamedAttribute> newAttributes;
         newAttributes.push_back(
@@ -402,36 +511,46 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           newAttributes.push_back(namedAttr);
         }
 
-        c = rewriter
-                .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
-                                           newAttributes,
-                                           binder.op->getRegions().size())
-                .getResult(0);
+        Type convDtype =
+            isPerChannelQuantization
+                ? cast<Type>(rewriter.getF32Type())
+                : cast<Type>(rewriter.getType<Torch::QInt32Type>());
+        auto outputTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(), convDtype);
+        Value output = rewriter
+                           .create<Torch::OperatorOp>(
+                               binder.getLoc(), outputTy, newOperands,
+                               newAttributes, binder.op->getRegions().size())
+                           .getResult(0);
+
+        if (!isPerChannelQuantization) {
+          Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
+              binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale,
+              weightScale);
+          Value outZp = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+          output = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), outputTy, output, outScale, outZp);
+          outputTy = rewriter.getType<Torch::ValueTensorType>(
+              resultType.getOptionalSizes(), rewriter.getF32Type());
 
-        Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
-            binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
-            bScale);
-        Value outZp = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-        c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-            binder.getLoc(), cTy, c, outScale, outZp);
-        cTy = rewriter.getType<Torch::ValueTensorType>(
-            resultType.getOptionalSizes(), rewriter.getF32Type());
+          output = rewriter.create<Torch::AtenDequantizeSelfOp>(
+              binder.getLoc(), outputTy, output);
+        }
 
-        c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
-                                                         c);
-        cTy = getQTorchTypeFromTorchIntType(resultType);
+        outputTy = getQTorchTypeFromTorchIntType(resultType);
         Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
             rewriter.getIntegerAttr(
                 rewriter.getIntegerType(64),
                 static_cast<int64_t>(
-                    Torch::getScalarTypeForType(cTy.getDtype()))));
-        c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
-            binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+                    Torch::getScalarTypeForType(outputTy.getDtype()))));
+
+        output = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal);
         rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
-                                                          c);
+                                                          output);
         return success();
       });
   patterns.onOp(
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 2caddff9bc3b..b98c10792ecc 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -68,13 +68,13 @@ func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.
 func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
   // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
-  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
   // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
   // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
@@ -107,13 +107,13 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1:
 func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8>
   // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
-  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
   // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
   // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
   // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
   // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
@@ -141,6 +141,52 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t
 
 // -----
 
+// CHECK-LABEL: func.func @test_qlinearconv_weight_per_channel_quantization
+func.func @test_qlinearconv_weight_per_channel_quantization(%arg0: !torch.vtensor<[?,3,224,224],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[64,3,7,7],si8>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[64],si8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.mlfeaturizers = 1 : si64, com.microsoft.nchwc = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
+  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) {torch.onnx.auto_pad = "NOTSET", torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [3 : si64, 3 : si64, 3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,3,224,224],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64,3,7,7],si8>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8>
+  // CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[OUTPUT_ZP:.*]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[INPUT_SCALE:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[OUTPUT_SCALE:.*]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[F32DTYPE:.*]] = torch.constant.int 6
+  // CHECK: %[[F32_INPUT:.*]] = torch.aten.to.dtype %arg0, %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,3,224,224],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3,224,224],f32>
+  // CHECK: %[[VAL_18:.*]] = torch.aten.sub.Scalar %[[F32_INPUT]], %[[INPUT_ZP]], %float1.000000e00 : !torch.vtensor<[?,3,224,224],f32>, !torch.int, !torch.float -> !torch.vtensor<[?,3,224,224],f32>
+  // CHECK: %[[DEQUANT_INPUT:.*]] = torch.aten.mul.Scalar %[[VAL_18]], %[[INPUT_SCALE]] : !torch.vtensor<[?,3,224,224],f32>, !torch.float -> !torch.vtensor<[?,3,224,224],f32>
+  // CHECK: %[[VAL_21:.*]] = torch.aten.unsqueeze %arg4, %int1 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64,1],f32>
+  // CHECK: %[[VAL_22:.*]] = torch.aten.unsqueeze %arg5, %int1 : !torch.vtensor<[64],si8>, !torch.int -> !torch.vtensor<[64,1],si8>
+  // CHECK: %[[VAL_24:.*]] = torch.aten.unsqueeze %[[VAL_21]], %int2 : !torch.vtensor<[64,1],f32>, !torch.int -> !torch.vtensor<[64,1,1],f32>
+  // CHECK: %[[VAL_25:.*]] = torch.aten.unsqueeze %[[VAL_22]], %int2 : !torch.vtensor<[64,1],si8>, !torch.int -> !torch.vtensor<[64,1,1],si8>
+  // CHECK: %[[WEIGHT_SCALE:.*]] = torch.aten.unsqueeze %[[VAL_24]], %int3 : !torch.vtensor<[64,1,1],f32>, !torch.int -> !torch.vtensor<[64,1,1,1],f32>
+  // CHECK: %[[WEIGHT_ZP:.*]] = torch.aten.unsqueeze %[[VAL_25]], %int3 : !torch.vtensor<[64,1,1],si8>, !torch.int -> !torch.vtensor<[64,1,1,1],si8>
+  // CHECK: %[[F32_WEIGHT:.*]] = torch.aten.to.dtype %arg3, %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[64,3,7,7],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64,3,7,7],f32>
+  // CHECK: %[[VAL_30:.*]] = torch.aten.sub.Tensor %[[F32_WEIGHT]], %[[WEIGHT_ZP]], %float1.000000e00 : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],si8>, !torch.float -> !torch.vtensor<[64,3,7,7],f32>
+  // CHECK: %[[DEQUANT_WEIGHT:.*]] = torch.aten.mul.Tensor %[[VAL_30]], %[[WEIGHT_SCALE]] : !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64,1,1,1],f32> -> !torch.vtensor<[64,3,7,7],f32>
+  // CHECK: %[[F32_BIAS:.*]] = torch.aten.to.dtype %arg8, %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[64],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64],f32>
+  // CHECK: %[[VAL_33:.*]] = torch.constant.int 3
+  // CHECK: %[[VAL_34:.*]] = torch.constant.int 3
+  // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[VAL_33]], %[[VAL_34]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_36:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_37:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_38:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_39:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_40:.*]] = torch.constant.int 0
+  // CHECK: %[[KERNEL:.*]] = torch.prim.ListConstruct %[[VAL_36]], %[[VAL_37]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_39]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[VAL_40]], %[[VAL_40]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: %[[CONV:.*]] = torch.aten.convolution %[[DEQUANT_INPUT]], %[[DEQUANT_WEIGHT]], %[[F32_BIAS]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[TRANSPOSED]], %[[STRIDE]], %[[GROUPS]] : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,64,112,112],f32>
+  // CHECK: %[[DTYPE:.*]] = torch.constant.int 13
+  // CHECK: %[[QUANT:.*]] = torch.aten.quantize_per_tensor %[[CONV]], %[[OUTPUT_SCALE]], %[[OUTPUT_ZP]], %[[DTYPE]] : !torch.vtensor<[?,64,112,112],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,64,112,112],!torch.quint8>
+  // CHECK: %[[OUTPUT:.*]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[?,64,112,112],!torch.quint8> -> !torch.vtensor<[?,64,112,112],ui8>
+  // CHECK: return %[[OUTPUT]] : !torch.vtensor<[?,64,112,112],ui8>
+  return %0 : !torch.vtensor<[?,64,112,112],ui8>
+}
+
+// -----
+
 // CHECK-LABEL: @test_qlinearmatmul_2D
 func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8>
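Note (not part of the patch): as a sanity check of the per-channel dequantization arithmetic that the new per-channel path emits, i.e. weight_dequant = (weight - weight_zero_point) * weight_scale with the scale and zero point broadcast along output-channel axis 0, the following standalone C++ sketch reproduces the same math on plain vectors. The names (dequantizeWeightPerChannel, M, elemsPerChannel) are illustrative only and do not appear in torch-mlir.

// Standalone sketch of per-channel weight dequantization for a weight of
// shape (M x C/group x k1 x ... x kn) with scale/zero point of shape (M):
//   weight_dequant[m, ...] = (weight[m, ...] - zp[m]) * scale[m]
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> dequantizeWeightPerChannel(const std::vector<int8_t> &weight,
                                              const std::vector<float> &scale,
                                              const std::vector<int8_t> &zp,
                                              size_t M) {
  // Number of weight elements per output channel: C/group * k1 * ... * kn.
  size_t elemsPerChannel = weight.size() / M;
  std::vector<float> dequant(weight.size());
  for (size_t m = 0; m < M; ++m)
    for (size_t i = 0; i < elemsPerChannel; ++i) {
      size_t idx = m * elemsPerChannel + i;
      dequant[idx] = (static_cast<float>(weight[idx]) - zp[m]) * scale[m];
    }
  return dequant;
}

int main() {
  // Two output channels (M = 2), two weight elements per channel.
  std::vector<int8_t> weight = {10, 20, -30, 40};
  std::vector<float> scale = {0.5f, 0.25f};
  std::vector<int8_t> zp = {2, -4};
  for (float v : dequantizeWeightPerChannel(weight, scale, zp, 2))
    std::cout << v << " "; // prints: 4 9 -6.5 11
  std::cout << "\n";
  return 0;
}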