diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 944c258a8d12..142caf583f24 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2700,12 +2700,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( }); patterns.onOp( "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; + Torch::ValueTensorType outputTensorType; llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; int64_t antialias, exclude_outside; float extrapolation_value, cubic_coeff_a; - Value noneVal = rewriter.create(binder.getLoc()); if (auto attr = binder.op->getAttr("torch.onnx.axes")) { return rewriter.notifyMatchFailure( @@ -2720,7 +2719,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || + binder.tensorResultType(outputTensorType) || binder.customOpNameStringAttr(mode, "mode", "nearest") || binder.customOpNameStringAttr( coordTfMode, "coordinate_transformation_mode", "half_pixel") || @@ -2732,6 +2731,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "round_prefer_floor") || binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) return failure(); + + int64_t const /* */ batchDim = 0; + int64_t const /**/ channelDim = 1; + + SmallVector nonResizableDims{ + batchDim, + channelDim, + }; + + Value inputTensor = operands[0]; + auto inputTensorType = + cast(inputTensor.getType()); + auto sizesOfInputTensor = inputTensorType.getSizes(); + auto sizesOfOutputTensor = outputTensorType.getSizes(); + + auto unknownSize = Torch::kUnknownSize; + + // Compile-time check for dimensions of static size + for (auto &eachDim : nonResizableDims) { + auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim]; + auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim]; + + if (eachSizeOfInputTensor == unknownSize || + eachSizeOfOutputTensor == unknownSize) + continue; + if (eachSizeOfInputTensor == eachSizeOfOutputTensor) + continue; + + auto resizingIntentErrorMessage = + "unsupported: non-trivial intent to resize dimension: " + + std::to_string(eachDim); + + return rewriter.notifyMatchFailure(binder.op, + resizingIntentErrorMessage); + }; + if (antialias != 0) { return rewriter.notifyMatchFailure( binder.op, @@ -2764,27 +2799,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented: cubic coeff must be -0.75"); } - unsigned rank = dyn_cast(operands[0].getType()) - .getSizes() - .size(); + auto loc = binder.getLoc(); - Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value cstTrue = - rewriter.create(binder.getLoc(), true); + Value cstFalse = rewriter.create(loc, false); + Value cstTrue = rewriter.create(loc, true); Value modeStrValue; - Value scalesValueList = noneVal; - Value sizesValueList = noneVal; Value alignCorners = coordTfMode == "align_corners" ? cstTrue : cstFalse; if (mode == "cubic") { std::string modeStr = "cubic"; if (coordTfMode != "half_pixel") modeStr = modeStr + "_" + coordTfMode; - modeStrValue = - rewriter.create(binder.getLoc(), modeStr); + modeStrValue = rewriter.create(loc, modeStr); } + + auto rankOfInputTensor = sizesOfInputTensor.size(); + // supported modes: // bilinear (half_pixel), bilinear with align_corners, // bilinear_pytorch_half_pixel, bilinear_asymmetric nearest @@ -2792,7 +2823,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // nearest_pytorch_half_pixel if (mode == "linear") { std::string modeStr; - switch (rank) { + switch (rankOfInputTensor) { case 3: modeStr = "linear"; break; @@ -2809,8 +2840,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // mode is apparently half_pixel, NOT pytorch_half_pixel if (coordTfMode != "half_pixel" && coordTfMode != "align_corners") modeStr = (modeStr + "_") + coordTfMode; - modeStrValue = - rewriter.create(binder.getLoc(), modeStr); + modeStrValue = rewriter.create(loc, modeStr); } if (mode == "nearest") { std::string modeStr = "nearest"; @@ -2820,33 +2850,84 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStr = (modeStr + "_") + coordTfMode; if (nearest_mode != "floor" && nearest_mode != "") modeStr = modeStr + "," + nearest_mode; - modeStrValue = - rewriter.create(binder.getLoc(), modeStr); + modeStrValue = rewriter.create(loc, modeStr); } - int64_t assumedForemostSpatialDim = 2; + auto numberOfOperands = operands.size(); - if (operands.size() < 4) { - Value scaleOperand = operands[2]; - scalesValueList = - createScalarSublist(binder.getLoc(), scaleOperand, - assumedForemostSpatialDim, rewriter); - sizesValueList = noneVal; - } else { - Value sizeOperand = operands[3]; - scalesValueList = noneVal; - sizesValueList = - createScalarSublist(binder.getLoc(), sizeOperand, - assumedForemostSpatialDim, rewriter); - } - if (isa(scalesValueList.getType()) && - isa(sizesValueList.getType())) { + Type boolType = rewriter.getType(); + + int64_t assumedForemostSpatialDim = 1 + nonResizableDims.back(); + + Value supportedScaleFactors; + Value supportedSizes; + + Value noneVal = rewriter.create(loc); + + if (numberOfOperands == 3) { + Value proposedScaleFactors = operands[2]; + + Value scaleIdentity = rewriter.create( + loc, rewriter.getF64FloatAttr(1.0)); + + // run-time scale factor check for dynamic sizes + for (auto &eachDim : nonResizableDims) { + Value eachProposedScaleFactor = extractTorchScalar( + loc, eachDim, proposedScaleFactors, rewriter); + + Value eachScaleFactorIsIdentity = + rewriter.create( + loc, boolType, eachProposedScaleFactor, scaleIdentity); + + auto errorMessageForEachDim = + "Unsupported: non-trivial scale factor for dimension " + + std::to_string(eachDim); + + rewriter.create( + loc, eachScaleFactorIsIdentity, + rewriter.getStringAttr(errorMessageForEachDim)); + }; + + supportedScaleFactors = createScalarSublist( + loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter); + supportedSizes = noneVal; + } else if (numberOfOperands == 4) { + Value proposedSizes = operands[3]; + + // run-time target size check for dynamic sizes + for (auto &eachDimAsInt : nonResizableDims) { + Value eachDimAsValue = + rewriter.create(loc, eachDimAsInt); + + Value eachSizeOfInputTensor = rewriter.create( + loc, inputTensor, eachDimAsValue); + + Value eachProposedSize = + extractTorchScalar(loc, eachDimAsInt, proposedSizes, rewriter); + + Value eachProposedSizeIsTrivial = + rewriter.create( + loc, boolType, eachProposedSize, eachSizeOfInputTensor); + + auto errorMessageForEachDim = + "Unsupported: non-trivial resizing of dimension " + + std::to_string(eachDimAsInt); + + rewriter.create( + loc, eachProposedSizeIsTrivial, + rewriter.getStringAttr(errorMessageForEachDim)); + }; + + supportedScaleFactors = noneVal; + supportedSizes = createScalarSublist( + loc, proposedSizes, assumedForemostSpatialDim, rewriter); + } else return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); - } + rewriter .replaceOpWithNewOp( - binder.op, resultType, operands[0], sizesValueList, - scalesValueList, modeStrValue, + binder.op, outputTensorType, inputTensor, supportedSizes, + supportedScaleFactors, modeStrValue, /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, /*Torch_BoolType:$antialias*/ cstFalse); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 5dd6ee037b75..22f5cbbe8752 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2256,21 +2256,30 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // CHECK-LABEL: func.func @test_resize_sizes_nearest func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: %[[MODE_STR:.*]] = torch.constant.str "nearest" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "asymmetric", + torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, + torch.onnx.mode = "nearest", + torch.onnx.nearest_mode = "floor" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } // ----- -// CHECK-LABEL: func.func @test_resize_sizes_nearest -func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_resize_sizes_nearest_half_pixel +func.func @test_resize_sizes_nearest_half_pixel(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: %[[MODE_STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { torch.onnx.coordinate_transformation_mode = "half_pixel", - torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + torch.onnx.mode = "nearest" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } @@ -2280,8 +2289,12 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: %[[MODE_STR:.*]] = torch.constant.str "bilinear" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.mode = "linear" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> }