diff --git a/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp b/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp index 19093151fa2c..6c14ebab3507 100644 --- a/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp +++ b/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp @@ -748,4 +748,126 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( result); return success(); }); + patterns.onOp( + "QLinearGlobalAveragePool", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + int64_t channelsLast; + if (binder.tensorOperands(operands, 5) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(channelsLast, "channels_last")) + return failure(); + + Value x = operands[0]; + Value xScale, xZp, yScale, yZp; + + if (failed(extractPerTensorQuantizationArguments( + rewriter, loc, /*scale=*/operands[1], + /*zero_point=*/operands[2], xScale, xZp))) + return rewriter.notifyMatchFailure( + binder.op, "Incompatible arguments for per-tensor quantization"); + + if (failed(extractPerTensorQuantizationArguments( + rewriter, loc, /*scale=*/operands[3], + /*zero_point=*/operands[4], yScale, yZp))) + return rewriter.notifyMatchFailure( + binder.op, "Incompatible arguments for per-tensor quantization"); + + auto xTy = dyn_cast(x.getType()); + if (!xTy || !xTy.hasSizes()) + return rewriter.notifyMatchFailure( + binder.op, "Expected input argument `x` to have sizes"); + ArrayRef inputShape = xTy.getSizes(); + + xTy = getQTorchTypeFromTorchIntType(xTy); + x = rewriter.create( + loc, xTy, x, xScale, xZp); + xTy = rewriter.getType(inputShape, + rewriter.getF32Type()); + // Dequantizing the input tensor `x`. 
+ x = rewriter.create(loc, xTy, x); + + if (!resultType || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected result type having sizes"); + } + ArrayRef resultShape = resultType.getSizes(); + + // Computing the AvgPool result. + SmallVector cstKernel, cstPadding, cstStrides; + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + unsigned inputRank = inputShape.size(); + for (unsigned i = 2; i < inputRank; i++) { + if (inputShape[i] == Torch::kUnknownSize) { + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + Value inputDimSize = + rewriter.create(loc, x, dim); + cstKernel.push_back(inputDimSize); + } else { + int64_t kernelSize = inputShape[i] - resultShape[i] + 1; + cstKernel.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(kernelSize))); + } + cstPadding.push_back(cstZero); + cstStrides.push_back(cstOne); + } + Value kernelSizeList = rewriter.create( + loc, + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = rewriter.create( + loc, + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value stridesList = rewriter.create( + loc, + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value cstFalse = rewriter.create(loc, false); + Value cstCeilMode = cstFalse; + Value cstCountIncludePad = cstFalse; + Value cstNone = rewriter.create(loc); + + auto yTy = rewriter.getType( + resultShape, rewriter.getF32Type()); + Value avgpool; + if (inputRank == 3) { + avgpool = rewriter.create( + loc, yTy, x, kernelSizeList, stridesList, paddingList, + cstCeilMode, cstCountIncludePad); + } else if (inputRank == 4) { + avgpool = rewriter.create( + loc, yTy, x, kernelSizeList, stridesList, paddingList, + cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + } else if (inputRank == 5) { + avgpool 
= rewriter.create( + loc, yTy, x, kernelSizeList, stridesList, paddingList, + cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + } else { + return failure(); + } + + // Quantizing the result of AvgPool op. + yTy = dyn_cast( + getQTorchTypeFromTorchIntType(resultType)); + Value dtyVal = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(64), + static_cast( + Torch::getScalarTypeForType(yTy.getDtype())))); + avgpool = rewriter.create( + loc, yTy, avgpool, yScale, yZp, dtyVal); + rewriter.replaceOpWithNewOp(binder.op, resultType, + avgpool); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index c233c84af0b6..8ba397a57446 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -3763,3 +3763,37 @@ func.func @test_qlinearconcat(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtens // CHECK: return %[[OUT]] return %0 : !torch.vtensor<[?,?,?,?],ui8> } + +// ----- + +// CHECK-LABEL: @test_qlinearglobalavgpool( +// CHECK-SAME: %[[X:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,1000,13,13],ui8>, +// CHECK-SAME: %[[X_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>, +// CHECK-SAME: %[[X_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>, +// CHECK-SAME: %[[Y_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>, +// CHECK-SAME: %[[Y_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1000,1,1],ui8> +func.func @test_qlinearglobalavgpool(%arg0: !torch.vtensor<[1,1000,13,13],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1000,1,1],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 
10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch.operator "onnx.QLinearGlobalAveragePool"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.channels_last = 0 : si64} : (!torch.vtensor<[1,1000,13,13],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1000,1,1],ui8>
  // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK-DAG: %[[XSCALE:.+]] = torch.aten.item %[[X_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
  // CHECK-DAG: %[[YSCALE:.+]] = torch.aten.item %[[Y_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
  // CHECK-DAG: %[[XZP:.+]] = torch.aten.item %[[X_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
  // CHECK-DAG: %[[YZP:.+]] = torch.aten.item %[[Y_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
  // CHECK-DAG: %[[X_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[X]], %[[XSCALE]], %[[XZP]] : !torch.vtensor<[1,1000,13,13],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1000,13,13],!torch.quint8>
  // CHECK: %[[X_F32:.+]] = torch.aten.dequantize.self %[[X_QUANT]] : !torch.vtensor<[1,1000,13,13],!torch.quint8> -> !torch.vtensor<[1,1000,13,13],f32>
  // CHECK: %[[C0:.*]] = torch.constant.int 0
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[C13:.*]] = torch.constant.int 13
  // CHECK: %[[C13_0:.*]] = torch.constant.int 13
  // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C13]], %[[C13_0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool2d %[[X_F32]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1000,13,13],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1000,1,1],f32>
  // CHECK: %[[DTY:.+]] = torch.constant.int 13
  // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[AVGPOOL]], %[[YSCALE]], %[[YZP]], %[[DTY]] : !torch.vtensor<[1,1000,1,1],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1000,1,1],!torch.quint8>
  // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[1,1000,1,1],!torch.quint8> -> !torch.vtensor<[1,1000,1,1],ui8>
  // CHECK: return %[[OUT]]
  return %0 : !torch.vtensor<[1,1000,1,1],ui8>
}