From c0e47ff71ffcae30a58b309cf4698c1c09eb5f96 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Tue, 28 Jan 2025 17:59:31 +0000 Subject: [PATCH 01/24] refactor(ONNX): chains mutually-exclusive guard and operand usage in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 944c258a8d12..b3cca7c5d656 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2832,17 +2832,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( createScalarSublist(binder.getLoc(), scaleOperand, assumedForemostSpatialDim, rewriter); sizesValueList = noneVal; - } else { + } else if (operands.size() == 4) { Value sizeOperand = operands[3]; scalesValueList = noneVal; sizesValueList = createScalarSublist(binder.getLoc(), sizeOperand, assumedForemostSpatialDim, rewriter); - } - if (isa(scalesValueList.getType()) && - isa(sizesValueList.getType())) { + } else return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); - } + rewriter .replaceOpWithNewOp( binder.op, resultType, operands[0], sizesValueList, From bf60cfa8487b9b5abe0d15d8380b1062fbe64df0 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Tue, 28 Jan 2025 18:02:18 +0000 Subject: [PATCH 02/24] refactor(ONNX): narrows operand count comparison in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index b3cca7c5d656..f8d59a470386 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2826,7 +2826,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( int64_t assumedForemostSpatialDim = 2; - if (operands.size() < 4) { + if (operands.size() == 3) { Value scaleOperand = operands[2]; scalesValueList = createScalarSublist(binder.getLoc(), scaleOperand, From 82f6445f604e4aa5e40a769adadf21e09b456cd0 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Tue, 28 Jan 2025 18:05:14 +0000 Subject: [PATCH 03/24] refactor(ONNX): extracts `numberOfOperands` within onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index f8d59a470386..310318294b19 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2824,15 +2824,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.create(binder.getLoc(), modeStr); } + auto numberOfOperands = operands.size(); + int64_t assumedForemostSpatialDim = 2; - if (operands.size() == 3) { + if (numberOfOperands == 3) { Value scaleOperand = operands[2]; scalesValueList = createScalarSublist(binder.getLoc(), scaleOperand, assumedForemostSpatialDim, rewriter); sizesValueList = noneVal; - } else if (operands.size() == 4) { + } else if (numberOfOperands == 4) { Value sizeOperand = operands[3]; scalesValueList = noneVal; sizesValueList = From 50c252992afda0deec33a0831e06cf10dae5d506 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Tue, 14 Jan 2025 17:44:54 +0000 Subject: [PATCH 04/24] refactor(ONNX): enforces min assignment-usage distance for value lists in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 310318294b19..0361733b3efe 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2774,8 +2774,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.create(binder.getLoc(), true); Value modeStrValue; - Value scalesValueList = noneVal; - Value sizesValueList = noneVal; Value alignCorners = coordTfMode == "align_corners" ? cstTrue : cstFalse; if (mode == "cubic") { @@ -2828,6 +2826,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( int64_t assumedForemostSpatialDim = 2; + Value scalesValueList = noneVal; + Value sizesValueList = noneVal; + if (numberOfOperands == 3) { Value scaleOperand = operands[2]; scalesValueList = From e4f4425fa95171b84cba1c5b8c3c2aa10319110b Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Tue, 14 Jan 2025 17:46:48 +0000 Subject: [PATCH 05/24] refactor(ONNX): removes redundant nulling assignment in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 0361733b3efe..668f12184c9d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2826,8 +2826,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( int64_t assumedForemostSpatialDim = 2; - Value scalesValueList = noneVal; - Value sizesValueList = noneVal; + Value scalesValueList; + Value sizesValueList; if (numberOfOperands == 3) { Value scaleOperand = operands[2]; From bd23b93436e9aae16345e7c3b649b0c772a39810 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 8 Jan 2025 15:19:30 +0000 Subject: [PATCH 06/24] refactor(ONNX): enforces min assignment-usage distance for `noneVal` in onnx.resize - avoids SSA before match failures --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 668f12184c9d..1e6047790c33 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2705,7 +2705,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( std::string mode, nearest_mode, coordTfMode; int64_t antialias, exclude_outside; float extrapolation_value, cubic_coeff_a; - Value noneVal = rewriter.create(binder.getLoc()); if (auto attr = binder.op->getAttr("torch.onnx.axes")) { return rewriter.notifyMatchFailure( @@ -2829,6 +2828,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value scalesValueList; Value sizesValueList; + Value noneVal = rewriter.create(binder.getLoc()); + if (numberOfOperands == 3) { Value scaleOperand = operands[2]; scalesValueList = From 87f9f54b164b026534c2ce351a4f86dbeadfb3b4 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Thu, 9 Jan 2025 22:21:09 +0000 Subject: [PATCH 07/24] refactor(ONNX): extracts `loc` within onnx.resize --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 1e6047790c33..955bb502b681 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2767,10 +2767,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( .getSizes() .size(); - Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value cstTrue = - rewriter.create(binder.getLoc(), true); + auto loc = binder.getLoc(); + + Value cstFalse = rewriter.create(loc, false); + Value cstTrue = rewriter.create(loc, true); Value modeStrValue; Value alignCorners = @@ -2779,8 +2779,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( std::string modeStr = "cubic"; if (coordTfMode != "half_pixel") modeStr = modeStr + "_" + coordTfMode; - modeStrValue = - rewriter.create(binder.getLoc(), modeStr); + modeStrValue = rewriter.create(loc, modeStr); } // supported modes: // bilinear (half_pixel), bilinear with align_corners, @@ -2806,8 +2805,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // mode is apparently half_pixel, NOT pytorch_half_pixel if (coordTfMode != "half_pixel" && coordTfMode != "align_corners") modeStr = (modeStr + "_") + coordTfMode; - modeStrValue = - rewriter.create(binder.getLoc(), modeStr); + modeStrValue = rewriter.create(loc, modeStr); } if (mode == "nearest") { std::string modeStr = "nearest"; @@ -2817,8 +2815,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStr = (modeStr + "_") + coordTfMode; if (nearest_mode != "floor" && nearest_mode != "") modeStr = modeStr + "," + nearest_mode; - modeStrValue = - rewriter.create(binder.getLoc(), modeStr); + modeStrValue = rewriter.create(loc, modeStr); } auto numberOfOperands = operands.size(); @@ -2828,20 +2825,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value scalesValueList; Value sizesValueList; - Value noneVal = rewriter.create(binder.getLoc()); + Value noneVal = rewriter.create(loc); if (numberOfOperands == 3) { Value scaleOperand = operands[2]; - scalesValueList = - createScalarSublist(binder.getLoc(), scaleOperand, - assumedForemostSpatialDim, rewriter); + scalesValueList = createScalarSublist( + loc, scaleOperand, assumedForemostSpatialDim, rewriter); sizesValueList = noneVal; } else if (numberOfOperands == 4) { Value sizeOperand = operands[3]; scalesValueList = noneVal; - sizesValueList = - createScalarSublist(binder.getLoc(), sizeOperand, - assumedForemostSpatialDim, rewriter); + sizesValueList = createScalarSublist( + loc, sizeOperand, assumedForemostSpatialDim, rewriter); } else return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); From 874827d9cb59b92b8f82b55dbe0f41ec77f5a273 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Fri, 10 Jan 2025 20:04:19 +0000 Subject: [PATCH 08/24] refactor(ONNX): moves `rank` closer to first usage in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 955bb502b681..c04813fd86b0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2763,10 +2763,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented: cubic coeff must be -0.75"); } - unsigned rank = dyn_cast(operands[0].getType()) - .getSizes() - .size(); - auto loc = binder.getLoc(); Value cstFalse = rewriter.create(loc, false); @@ -2781,6 +2777,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStr = modeStr + "_" + coordTfMode; modeStrValue = rewriter.create(loc, modeStr); } + + unsigned rank = dyn_cast(operands[0].getType()) + .getSizes() + .size(); + // supported modes: // bilinear (half_pixel), bilinear with align_corners, // bilinear_pytorch_half_pixel, bilinear_asymmetric nearest From 1d77eb95a4d60eb03e40bb71583c4e7a6a945034 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 8 Jan 2025 17:08:00 +0000 Subject: [PATCH 09/24] refactor(ONNX): forces cast of operand in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index c04813fd86b0..51fd0492f0a6 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2778,7 +2778,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStrValue = rewriter.create(loc, modeStr); } - unsigned rank = dyn_cast(operands[0].getType()) + unsigned rank = cast(operands[0].getType()) .getSizes() .size(); From e41fa6254fa459b34826165a436fa7ef15161323 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 15 Jan 2025 21:34:21 +0000 Subject: [PATCH 10/24] refactor(ONNX): loosens downcast in onnx.resize - cast to `ValueTensorType` was overly specific for the methods used --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 51fd0492f0a6..77beda583f97 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2778,7 +2778,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStrValue = rewriter.create(loc, modeStr); } - unsigned rank = cast(operands[0].getType()) + unsigned rank = cast(operands[0].getType()) .getSizes() .size(); From 140b62820a2bafea7134c0c81f5ee63997239cfe Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 8 Jan 2025 17:32:40 +0000 Subject: [PATCH 11/24] refactor(ONNX): extracts `inputTensor` within onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 77beda583f97..834930719686 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2778,7 +2778,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStrValue = rewriter.create(loc, modeStr); } - unsigned rank = cast(operands[0].getType()) + Value inputTensor = operands[0]; + unsigned rank = cast(inputTensor.getType()) .getSizes() .size(); @@ -2843,7 +2844,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter .replaceOpWithNewOp( - binder.op, resultType, operands[0], sizesValueList, + binder.op, resultType, inputTensor, sizesValueList, scalesValueList, modeStrValue, /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, From fff08f2bc7d72f09f434e7af76a2407c284ce4d5 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 8 Jan 2025 17:38:15 +0000 Subject: [PATCH 12/24] refactor(ONNX): extracts `inputTensorType` from rank derivation in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 834930719686..e78ee16fdd6a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2779,9 +2779,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } Value inputTensor = operands[0]; - unsigned rank = cast(inputTensor.getType()) - .getSizes() - .size(); + auto inputTensorType = + cast(inputTensor.getType()); + unsigned rank = inputTensorType.getSizes().size(); // supported modes: // bilinear (half_pixel), bilinear with align_corners, From 1f7cdf0b69b8342cd56f24d4ac93daaaaa2b5b83 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 8 Jan 2025 17:40:02 +0000 Subject: [PATCH 13/24] refactor(ONNX): extracts `sizesOfInputTensor` from rank derivation in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index e78ee16fdd6a..97d635a4dd3c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2781,7 +2781,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value inputTensor = operands[0]; auto inputTensorType = cast(inputTensor.getType()); - unsigned rank = inputTensorType.getSizes().size(); + auto sizesOfInputTensor = inputTensorType.getSizes(); + unsigned rank = sizesOfInputTensor.size(); // supported modes: // bilinear (half_pixel), bilinear with align_corners, From e835f1ac197ff07751c832967c2570031d176af7 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Mon, 27 Jan 2025 18:04:19 +0000 Subject: [PATCH 14/24] refactor(ONNX): uses `auto` annotation for `rank` in onnx.resize - intellisense is able to infer `unsigned` aspect from `.size()` --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 97d635a4dd3c..184940ebc1b3 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2782,7 +2782,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto inputTensorType = cast(inputTensor.getType()); auto sizesOfInputTensor = inputTensorType.getSizes(); - unsigned rank = sizesOfInputTensor.size(); + auto rank = sizesOfInputTensor.size(); // supported modes: // bilinear (half_pixel), bilinear with align_corners, From a858e45d01d020635044f07b75785ab3abf6e4de Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Fri, 10 Jan 2025 20:40:28 +0000 Subject: [PATCH 15/24] refactor(ONNX): renames `rank` to `rankOfInputTensor` in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 184940ebc1b3..de3504acd5c0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2782,7 +2782,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto inputTensorType = cast(inputTensor.getType()); auto sizesOfInputTensor = inputTensorType.getSizes(); - auto rank = sizesOfInputTensor.size(); + auto rankOfInputTensor = sizesOfInputTensor.size(); // supported modes: // bilinear (half_pixel), bilinear with align_corners, @@ -2791,7 +2791,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // nearest_pytorch_half_pixel if (mode == "linear") { std::string modeStr; - switch (rank) { + switch (rankOfInputTensor) { case 3: modeStr = "linear"; break; From 3f41467340a34148d16dd62f010d6094ae3d93aa Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 8 Jan 2025 16:31:41 +0000 Subject: [PATCH 16/24] refactor(ONNX): renames `resultType` to `outputTensorType` in onnx.resize - emphasizes parallel to `inputTensorType` --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index de3504acd5c0..65ae090445bf 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2700,7 +2700,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( }); patterns.onOp( "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; + Torch::ValueTensorType outputTensorType; llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; int64_t antialias, exclude_outside; @@ -2719,7 +2719,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || + binder.tensorResultType(outputTensorType) || binder.customOpNameStringAttr(mode, "mode", "nearest") || binder.customOpNameStringAttr( coordTfMode, "coordinate_transformation_mode", "half_pixel") || @@ -2845,7 +2845,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter .replaceOpWithNewOp( - binder.op, resultType, inputTensor, sizesValueList, + binder.op, outputTensorType, inputTensor, sizesValueList, scalesValueList, modeStrValue, /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, From 948a53ef2a87675058c7253d3344833b80c54788 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 15 Jan 2025 15:32:40 +0000 Subject: [PATCH 17/24] refactor(ONNX): renames `sizesValueList` to `supportedSizes` in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 65ae090445bf..4a8abd18c859 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2826,7 +2826,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( int64_t assumedForemostSpatialDim = 2; Value scalesValueList; - Value sizesValueList; + Value supportedSizes; Value noneVal = rewriter.create(loc); @@ -2834,18 +2834,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value scaleOperand = operands[2]; scalesValueList = createScalarSublist( loc, scaleOperand, assumedForemostSpatialDim, rewriter); - sizesValueList = noneVal; + supportedSizes = noneVal; } else if (numberOfOperands == 4) { Value sizeOperand = operands[3]; scalesValueList = noneVal; - sizesValueList = createScalarSublist( + supportedSizes = createScalarSublist( loc, sizeOperand, assumedForemostSpatialDim, rewriter); } else return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); rewriter .replaceOpWithNewOp( - binder.op, outputTensorType, inputTensor, sizesValueList, + binder.op, outputTensorType, inputTensor, supportedSizes, scalesValueList, modeStrValue, /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, From fd20a79182b2bcfb2804c878363c0177c85a4539 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 15 Jan 2025 15:36:13 +0000 Subject: [PATCH 18/24] refactor(ONNX): renames `scalesValueList` to `supportedScaleFactors` in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 4a8abd18c859..c73b3c496fc0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2825,19 +2825,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( int64_t assumedForemostSpatialDim = 2; - Value scalesValueList; + Value supportedScaleFactors; Value supportedSizes; Value noneVal = rewriter.create(loc); if (numberOfOperands == 3) { Value scaleOperand = operands[2]; - scalesValueList = createScalarSublist( + supportedScaleFactors = createScalarSublist( loc, scaleOperand, assumedForemostSpatialDim, rewriter); supportedSizes = noneVal; } else if (numberOfOperands == 4) { Value sizeOperand = operands[3]; - scalesValueList = noneVal; + supportedScaleFactors = noneVal; supportedSizes = createScalarSublist( loc, sizeOperand, assumedForemostSpatialDim, rewriter); } else @@ -2846,7 +2846,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter .replaceOpWithNewOp( binder.op, outputTensorType, inputTensor, supportedSizes, - scalesValueList, modeStrValue, + supportedScaleFactors, modeStrValue, /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, /*Torch_BoolType:$antialias*/ cstFalse); From b897a3400ef8f20cfa3016f4659c6fbd39ee94d5 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Tue, 14 Jan 2025 15:28:09 +0000 Subject: [PATCH 19/24] refactor(ONNX): renames `scaleOperand` to `proposedScaleFactors` in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index c73b3c496fc0..696cddf398e1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2831,9 +2831,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value noneVal = rewriter.create(loc); if (numberOfOperands == 3) { - Value scaleOperand = operands[2]; + Value proposedScaleFactors = operands[2]; supportedScaleFactors = createScalarSublist( - loc, scaleOperand, assumedForemostSpatialDim, rewriter); + loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter); supportedSizes = noneVal; } else if (numberOfOperands == 4) { Value sizeOperand = operands[3]; From 01e2274c5b9e02d04c5177b6a4c7d74a500b00b9 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Tue, 14 Jan 2025 15:34:40 +0000 Subject: [PATCH 20/24] refactor(ONNX): renames `sizeOperand` to `proposedSizes` in onnx.resize --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 696cddf398e1..1cda901dfc51 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2836,10 +2836,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter); supportedSizes = noneVal; } else if (numberOfOperands == 4) { - Value sizeOperand = operands[3]; + Value proposedSizes = operands[3]; supportedScaleFactors = noneVal; supportedSizes = createScalarSublist( - loc, sizeOperand, assumedForemostSpatialDim, rewriter); + loc, proposedSizes, assumedForemostSpatialDim, rewriter); } else return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); From 266a82076c974b05134409702dccea7ffe7ae377 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 22 Jan 2025 15:36:26 +0000 Subject: [PATCH 21/24] refactor(ONNX): prefers multiline attributes in onnx.resize tests - easier to read - allows for cleaner diffs if they ever change --- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 5dd6ee037b75..24088bfa0727 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2257,7 +2257,12 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "asymmetric", + torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, + torch.onnx.mode = "nearest", + torch.onnx.nearest_mode = "floor" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } @@ -2270,7 +2275,8 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { torch.onnx.coordinate_transformation_mode = "half_pixel", - torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + torch.onnx.mode = "nearest" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } @@ -2281,7 +2287,9 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.mode = "linear" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } From c9f219708b1d97c94c4c3c822de82c26fa33eab0 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 22 Jan 2025 15:52:00 +0000 Subject: [PATCH 22/24] refactor(ONNX): distills checks in lit tests for onnx.resize --- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 24088bfa0727..8eaf92218034 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2256,7 +2256,9 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // CHECK-LABEL: func.func @test_resize_sizes_nearest func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: %[[MODE_STR:.*]] = torch.constant.str "nearest" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, @@ -2271,8 +2273,9 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // CHECK-LABEL: func.func @test_resize_sizes_nearest func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: %[[MODE_STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "nearest" @@ -2286,7 +2289,9 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: %[[MODE_STR:.*]] = torch.constant.str "bilinear" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { torch.onnx.mode = "linear" } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> From 6dc3fdfc4e847a488cdda3b1f8576136f19a942c Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 22 Jan 2025 15:47:25 +0000 Subject: [PATCH 23/24] fix(ONNX): differentiates names of lit tests for onnx.resize --- test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 8eaf92218034..22f5cbbe8752 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2270,8 +2270,8 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // ----- -// CHECK-LABEL: func.func @test_resize_sizes_nearest -func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_resize_sizes_nearest_half_pixel +func.func @test_resize_sizes_nearest_half_pixel(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none // CHECK: %[[MODE_STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" // CHECK: torch.aten.__interpolate.size_list_scale_list From a23008444f4f310613b4a7554a0d0191aa020b6c Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Wed, 15 Jan 2025 18:02:14 +0000 Subject: [PATCH 24/24] fix(ONNX): avoids resizing unsupported dimensions --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 91 ++++++++++++++++++- 1 file changed, 86 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 1cda901dfc51..142caf583f24 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2731,6 +2731,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "round_prefer_floor") || binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) return failure(); + + int64_t const /* */ batchDim = 0; + int64_t const /**/ channelDim = 1; + + SmallVector nonResizableDims{ + batchDim, + channelDim, + }; + + Value inputTensor = operands[0]; + auto inputTensorType = + cast(inputTensor.getType()); + auto sizesOfInputTensor = inputTensorType.getSizes(); + auto sizesOfOutputTensor = outputTensorType.getSizes(); + + auto unknownSize = Torch::kUnknownSize; + + // Compile-time check for dimensions of static size + for (auto &eachDim : nonResizableDims) { + auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim]; + auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim]; + + if (eachSizeOfInputTensor == unknownSize || + eachSizeOfOutputTensor == unknownSize) + continue; + if (eachSizeOfInputTensor == eachSizeOfOutputTensor) + continue; + + auto resizingIntentErrorMessage = + "unsupported: non-trivial intent to resize dimension: " + + std::to_string(eachDim); + + return rewriter.notifyMatchFailure(binder.op, + resizingIntentErrorMessage); + }; + if (antialias != 0) { return rewriter.notifyMatchFailure( binder.op, @@ -2778,10 +2814,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStrValue = rewriter.create(loc, modeStr); } - Value inputTensor = operands[0]; - auto inputTensorType = - cast(inputTensor.getType()); - auto sizesOfInputTensor = inputTensorType.getSizes(); auto rankOfInputTensor = sizesOfInputTensor.size(); // supported modes: @@ -2823,7 +2855,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto numberOfOperands = operands.size(); - int64_t assumedForemostSpatialDim = 2; + Type boolType = rewriter.getType(); + + int64_t assumedForemostSpatialDim = 1 + nonResizableDims.back(); Value supportedScaleFactors; Value supportedSizes; @@ -2832,11 +2866,58 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (numberOfOperands == 3) { Value proposedScaleFactors = operands[2]; + + Value scaleIdentity = rewriter.create( + loc, rewriter.getF64FloatAttr(1.0)); + + // run-time scale factor check for dynamic sizes + for (auto &eachDim : nonResizableDims) { + Value eachProposedScaleFactor = extractTorchScalar( + loc, eachDim, proposedScaleFactors, rewriter); + + Value eachScaleFactorIsIdentity = + rewriter.create( + loc, boolType, eachProposedScaleFactor, scaleIdentity); + + auto errorMessageForEachDim = + "Unsupported: non-trivial scale factor for dimension " + + std::to_string(eachDim); + + rewriter.create( + loc, eachScaleFactorIsIdentity, + rewriter.getStringAttr(errorMessageForEachDim)); + }; + supportedScaleFactors = createScalarSublist( loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter); supportedSizes = noneVal; } else if (numberOfOperands == 4) { Value proposedSizes = operands[3]; + + // run-time target size check for dynamic sizes + for (auto &eachDimAsInt : nonResizableDims) { + Value eachDimAsValue = + rewriter.create(loc, eachDimAsInt); + + Value eachSizeOfInputTensor = rewriter.create( + loc, inputTensor, eachDimAsValue); + + Value eachProposedSize = + extractTorchScalar(loc, eachDimAsInt, proposedSizes, rewriter); + + Value eachProposedSizeIsTrivial = + rewriter.create( + loc, boolType, eachProposedSize, eachSizeOfInputTensor); + + auto errorMessageForEachDim = + "Unsupported: non-trivial resizing of dimension " + + std::to_string(eachDimAsInt); + + rewriter.create( + loc, eachProposedSizeIsTrivial, + rewriter.getStringAttr(errorMessageForEachDim)); + }; + supportedScaleFactors = noneVal; supportedSizes = createScalarSublist( loc, proposedSizes, assumedForemostSpatialDim, rewriter);