diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 1a730a0475ed..e6be0304f64d 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1606,6 +1606,74 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                   /* cudnn enabled */ boolFalse);
           return success();
         });
+  patterns.onOp(
+      "MeanVarianceNormalization", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input;
+        SmallVector<int64_t> axes;
+
+        if (binder.tensorOperand(input) ||
+            binder.s64IntegerArrayAttr(axes, "axes",
+                                       llvm::SmallVector<int64_t>({0, 2, 3})) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        if (!resultType.hasSizes() || !resultType.hasDtype()) {
+          return failure();
+        }
+        auto inputTy = cast<Torch::ValueTensorType>(input.getType());
+        if (!inputTy || !inputTy.hasSizes()) {
+          return failure();
+        }
+        int64_t inputRank = inputTy.getSizes().size();
+
+        Location loc = binder.getLoc();
+        Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+        Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+
+        ArrayRef<int64_t> output_shape = resultType.getSizes();
+        SmallVector<int64_t> reduced_shape(output_shape);
+
+        for (int64_t i : axes) {
+          int64_t dim = Torch::toPositiveDim(i, inputRank);
+          if (!Torch::isValidDim(dim, inputRank)) {
+            return failure();
+          }
+          reduced_shape[dim] = 1;
+        }
+        Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
+            resultType.getContext(), reduced_shape, resultType.getDtype());
+        SmallVector<Value> cstAxes;
+        for (int64_t i : axes) {
+          cstAxes.push_back(rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i)));
+        }
+        Value axes_list = rewriter.create<Torch::PrimListConstructOp>(
+            loc,
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstAxes);
+        Value mean = rewriter.create<Torch::AtenMeanDimOp>(
+            loc, reducedOutTy, input, axes_list, keepDim, none);
+        Value variance = rewriter.create<Torch::AtenVarDimOp>(
+            loc, reducedOutTy, input, axes_list, unBiased, keepDim);
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(1e-9));
+        variance = rewriter.create<Torch::AtenAddScalarOp>(
+            loc, reducedOutTy, variance, cstEps, cstOne);
+        Value sqrtVar =
+            rewriter.create<Torch::AtenSqrtOp>(loc, reducedOutTy, variance);
+        Value inputMinusMean = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, resultType, input, mean, cstOne);
+        Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
+            loc, resultType, inputMinusMean, sqrtVar);
+
+        rewriter.replaceOp(binder.op, meanVarNorm);
+        return success();
+      });
   patterns.onOp(
       "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index a336d78f55dd..8ff895bd71cc 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1595,6 +1595,82 @@ func.func @test_mod_int64_no_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: func.func @test_meanvarnorm(
+func.func @test_meanvarnorm(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+  // CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+  // CHECK: %[[VAL_2:.*]] = torch.constant.none
+  // CHECK: %[[VAL_3:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 3
+  // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_7:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_6]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,5,1,1],f32>
+  // CHECK: %[[VAL_8:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_6]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,5,1,1],f32>
+  // CHECK: %[[VAL_9:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e-09
+  // CHECK: %[[VAL_11:.*]] = torch.aten.add.Scalar %[[VAL_8]], %[[VAL_10]], %[[VAL_9]] : !torch.vtensor<[1,5,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,5,1,1],f32>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.sqrt %[[VAL_11]] : !torch.vtensor<[1,5,1,1],f32> -> !torch.vtensor<[1,5,1,1],f32>
+  // CHECK: %[[VAL_13:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[1,5,1,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: %[[VAL_14:.*]] = torch.aten.div.Tensor %[[VAL_13]], %[[VAL_12]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[1,5,1,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: return %[[VAL_14]] : !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_meanvarnorm_axes(
+func.func @test_meanvarnorm_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+  // CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+  // CHECK: %[[VAL_2:.*]] = torch.constant.none
+  // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+  // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_5]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_7:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_5]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_8:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_9:.*]] = torch.constant.float 1.000000e-09
+  // CHECK: %[[VAL_10:.*]] = torch.aten.add.Scalar %[[VAL_7]], %[[VAL_9]], %[[VAL_8]] : !torch.vtensor<[3,1,2,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_11:.*]] = torch.aten.sqrt %[[VAL_10]] : !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_6]], %[[VAL_8]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: %[[VAL_13:.*]] = torch.aten.div.Tensor %[[VAL_12]], %[[VAL_11]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [1 : si64, 3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_meanvarnorm_neg_axes(
+func.func @test_meanvarnorm_neg_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+  // CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+  // CHECK: %[[VAL_2:.*]] = torch.constant.none
+  // CHECK: %[[VAL_3:.*]] = torch.constant.int -1
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int -3
+  // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_5]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_7:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_5]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_8:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_9:.*]] = torch.constant.float 1.000000e-09
+  // CHECK: %[[VAL_10:.*]] = torch.aten.add.Scalar %[[VAL_7]], %[[VAL_9]], %[[VAL_8]] : !torch.vtensor<[3,1,2,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_11:.*]] = torch.aten.sqrt %[[VAL_10]] : !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_6]], %[[VAL_8]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: %[[VAL_13:.*]] = torch.aten.div.Tensor %[[VAL_12]], %[[VAL_11]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [-1 : si64, -3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_not_2d
 func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
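
For reference, the emitted op sequence computes `Y = (X - mean) / sqrt(var + 1e-9)`, where `mean` and the biased variance are reduced over `axes` (default `{0, 2, 3}`) with `keepdim = true` so they broadcast back against the input. Below is a minimal standalone C++ sketch of these semantics for the default NCHW case; the `mvnNCHW` helper and its flat row-major layout are illustrative only, not part of the patch, and note that the ONNX MVN spec has no explicit epsilon, so the `1e-9` stabilizer is a choice made by this lowering:

```cpp
// Sketch of the MVN semantics lowered above, for the default axes {0, 2, 3}
// on an NCHW tensor stored flat in row-major order.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> mvnNCHW(const std::vector<float> &x, int64_t n, int64_t c,
                           int64_t h, int64_t w) {
  const double eps = 1e-9; // matches the ConstantFloatOp in the lowering
  std::vector<float> y(x.size());
  for (int64_t ci = 0; ci < c; ++ci) {
    // Reducing over axes {0, 2, 3} pools every element of channel ci.
    double sum = 0.0, sqSum = 0.0;
    const int64_t count = n * h * w;
    for (int64_t ni = 0; ni < n; ++ni)
      for (int64_t hi = 0; hi < h; ++hi)
        for (int64_t wi = 0; wi < w; ++wi) {
          double v = x[((ni * c + ci) * h + hi) * w + wi];
          sum += v;
          sqSum += v * v;
        }
    double mean = sum / count;
    // Biased variance: aten.var.dim is emitted with unbiased = false.
    double var = sqSum / count - mean * mean;
    double invStd = 1.0 / std::sqrt(var + eps);
    for (int64_t ni = 0; ni < n; ++ni)
      for (int64_t hi = 0; hi < h; ++hi)
        for (int64_t wi = 0; wi < w; ++wi) {
          int64_t idx = ((ni * c + ci) * h + hi) * w + wi;
          y[idx] = static_cast<float>((x[idx] - mean) * invStd);
        }
  }
  return y;
}
```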