From 0888268753289ce9e9331ae66cdea5e6c083c88e Mon Sep 17 00:00:00 2001 From: Praveen G Date: Tue, 7 Jan 2025 11:47:21 +0000 Subject: [PATCH 1/5] [torch-mlir] Support lowering of aten constraint ops 1. aten::sym_constrain_range 2. aten::sym_constrain_range_for_size 3. aten::_assert_scalar --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 71 ++++++++++++++ .../TorchToLinalg/Uncategorized.cpp | 97 +++++++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 52 ++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 4 + .../build_tools/torch_ods_gen.py | 5 + .../torch_mlir_e2e_test/test_suite/basic.py | 63 ++++++++++++ .../Conversion/TorchToLinalg/constraints.mlir | 82 ++++++++++++++++ test/Dialect/Torch/decompose-complex-ops.mlir | 24 +++++ 8 files changed, 398 insertions(+) create mode 100644 test/Conversion/TorchToLinalg/constraints.mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 2d71d0d8fe3d..c5a31a3d2fb2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -17771,6 +17771,77 @@ def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_ }]; } +def Torch_AtenSymConstrainRangeOp : Torch_Op<"aten.sym_constrain_range", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sym_constrain_range : (Scalar, int?, int?) -> ()`"; + let arguments = (ins + AnyTorchScalarType:$size, + AnyTorchOptionalIntType:$min, + AnyTorchOptionalIntType:$max + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSymConstrainRangeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 0); + } + void AtenSymConstrainRangeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 0); + } + }]; +} + +def Torch_AtenSymConstrainRangeForSizeOp : Torch_Op<"aten.sym_constrain_range_for_size", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sym_constrain_range_for_size : (Scalar, int?, int?) 
-> ()`"; + let arguments = (ins + AnyTorchScalarType:$size, + AnyTorchOptionalIntType:$min, + AnyTorchOptionalIntType:$max + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSymConstrainRangeForSizeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 0); + } + void AtenSymConstrainRangeForSizeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 0); + } + }]; +} + +def Torch_Aten_AssertScalarOp : Torch_Op<"aten._assert_scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_assert_scalar : (Scalar, str) -> ()`"; + let arguments = (ins + AnyTorchScalarType:$self, + Torch_StringType:$assert_msg + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AssertScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 0); + } + void Aten_AssertScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 0); + } + }]; +} + def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index c83f49d7f62d..02ad8f6f3b13 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -25,6 +25,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/APSInt.h" #include +#include #include using namespace mlir; @@ -3564,6 +3565,98 @@ class ConvertAtenPolarOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertSymConstrainRangeOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenSymConstrainRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + auto loc = op.getLoc(); + auto min = op.getMin(); + auto max = op.getMax(); + + auto minOp = min.getDefiningOp(); + auto maxOp = max.getDefiningOp(); + + if (!minOp || !maxOp) + return op.emitError("Unimplemented: Non constant min/max values"); + + int64_t minValue = std::numeric_limits::min(); + int64_t maxValue = std::numeric_limits::max(); + + Type operandType = rewriter.getI64Type(); + + if (!isa(minOp)) + if (!matchPattern(min, m_TorchConstantInt(&minValue))) + return rewriter.notifyMatchFailure( + op, "Expected min value to be constant integer"); + + if (!isa(maxOp)) + if (!matchPattern(max, m_TorchConstantInt(&maxValue))) + return rewriter.notifyMatchFailure( + op, "Expected max value to be constant integer"); + + if (maxValue < minValue) { + std::string errorMsg = + "Max must be greater than or equal to min, got min = " + + std::to_string(minValue) + ", max = " + std::to_string(maxValue); + return op.emitError(errorMsg); + } + + min = getConstant(rewriter, loc, minValue, operandType); + max = getConstant(rewriter, loc, maxValue, operandType); + + // Check min <= size <= max + + // FIXME:: Skip the below checks if constraint ops are already inserted as + // part of symbol expr evaluation + auto checkMin = createLessThanOrEqual(rewriter, loc, operandType, min, + adaptor.getSize()); + auto checkMax = createLessThanOrEqual(rewriter, loc, operandType, + adaptor.getSize(), max); + auto compareVal = rewriter.create(loc, 
checkMin, checkMax); + + std::string assertMessage = "Invalid value range for size between [" + + std::to_string(minValue) + ", " + + std::to_string(maxValue) + "]"; + rewriter.create(loc, compareVal, + rewriter.getStringAttr(assertMessage)); + + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +namespace { +class ConvertAssertScalarOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(Aten_AssertScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + auto assertCond = convertScalarToDtype( + rewriter, op.getLoc(), adaptor.getSelf(), rewriter.getI1Type()); + + std::string assertMessage; + if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage))) + return rewriter.notifyMatchFailure( + op, "Assert message must be a constant string"); + + rewriter.replaceOpWithNewOp(op, assertCond, assertMessage); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -3626,4 +3719,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 3303ec1ecc1b..41a1186bc987 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -11455,6 +11455,56 @@ class DecomposeAtenSpecialExpm1Op }; } // namespace +namespace { +class DecomposeConstrainRangeForSizeOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSymConstrainRangeForSizeOp op, + PatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto min = op.getMin(); + auto max = op.getMax(); + auto minOp = min.getDefiningOp(); + auto maxOp = max.getDefiningOp(); + + if (!minOp || !maxOp) + return op.emitError("Unimplemented: Non constant min/max values"); + + int64_t minValue, maxValue; + + if (isa(minOp)) { + // Set min value to 0 + min = rewriter.create(loc, 0); + } else { + // Check if min value is a constant + if (!matchPattern(min, m_TorchConstantInt(&minValue))) + return rewriter.notifyMatchFailure( + op, "Expected min value to be constant integer"); + } + + if (!isa(maxOp)) { + // Verify that max value is greater than 2 + if (!matchPattern(max, m_TorchConstantInt(&maxValue))) + return rewriter.notifyMatchFailure( + op, "Expected max value to be constant integer"); + + if (maxValue <= 2) { + std::string errorMsg = "Max value to constrain_range_for_size must be " + "greater than 2, got: " + + std::to_string(maxValue); + return op.emitError(errorMsg); + } + } + + rewriter.replaceOpWithNewOp(op, op.getSize(), min, + max); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -11753,6 +11803,8 @@ class DecomposeComplexOpsPass // Torchvision ops addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + GreedyRewriteConfig config; config.useTopDownTraversal = true; 
config.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4df3d186f8ea..3c7b878478a5 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -35,6 +35,10 @@ "Aten_TrilinearModuleZerodDimBug_basic", # missing lowering from aten.pow.Tensor_Tensor for integer result "PowIntIntModule_basic", + # Unknown builtin op: aten::_check_is_size in TorchScript + "AtenSymConstrainRange_basic", + "AtenSymConstrainRangeForSize_basic", + "AtenAssertScalar", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 4d7f8d52268c..350fea711bbf 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1232,6 +1232,11 @@ def emit_with_mutating_variants(key, **kwargs): ) emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)") + # Constraint ops + emit("aten::sym_constrain_range : (Scalar, int?, int?) -> ()") + emit("aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()") + emit("aten::_assert_scalar : (Scalar, str) -> ()") + # ========================================================================== # `prim::` namespace. # ========================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index fe8a31186807..16c3aadb337f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -6480,3 +6480,66 @@ def forward(self, x): @register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule()) def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils): module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)) + + +# ============================================================================== + + +class AtenSymConstrainRange(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.int, True)]) + def forward(self, x): + a = x.item() + torch._check_is_size(a) + torch.ops.aten.sym_constrain_range(a, max=5) + return a + + +@register_test_case(module_factory=lambda: AtenSymConstrainRange()) +def AtenSymConstrainRange_basic(module, tu: TestUtils): + module.forward(torch.tensor(4)) + + +# ============================================================================== + + +class AtenSymConstrainRangeForSize(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.int, True)]) + def forward(self, x): + a = x.item() + torch._check_is_size(a) + # max should be >= 2 + torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10) + return a + + +@register_test_case(module_factory=lambda: AtenSymConstrainRangeForSize()) +def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils): + module.forward(torch.tensor(4)) + + +# ============================================================================== +class AtenAssertScalar(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.int, True)]) + def forward(self, x): + a = x.item() + # The below checks introduces 
aten._assert_scalar op + torch._check_is_size(a) + torch._check(a <= 5) + return a + + +@register_test_case(module_factory=lambda: AtenAssertScalar()) +def AtenAssertScalar_basic(module, tu: TestUtils): + module.forward(torch.tensor(4)) diff --git a/test/Conversion/TorchToLinalg/constraints.mlir b/test/Conversion/TorchToLinalg/constraints.mlir new file mode 100644 index 000000000000..8ae62822f9b3 --- /dev/null +++ b/test/Conversion/TorchToLinalg/constraints.mlir @@ -0,0 +1,82 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s +// ----- + +// CHECK-LABEL: func.func @torch.aten.sym_constrain_range( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { +// CHECK: %[[VAL_1:.*]] = torch.constant.int 7 +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int +// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]] +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_7:.*]] = arith.constant 9223372036854775807 : i64 +// CHECK: %[[VAL_8:.*]] = arith.cmpi ule, %[[VAL_6]], %[[VAL_5]] : i64 +// CHECK: %[[VAL_9:.*]] = arith.cmpi ule, %[[VAL_5]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1 +// CHECK: cf.assert %[[VAL_10]], "Invalid value range for size between [0, 9223372036854775807]" +// CHECK: %[[VAL_11:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_12:.*]] = arith.constant 7 : i64 +// CHECK: %[[VAL_13:.*]] = arith.cmpi ule, %[[VAL_11]], %[[VAL_5]] : i64 +// CHECK: %[[VAL_14:.*]] = arith.cmpi ule, %[[VAL_5]], %[[VAL_12]] : i64 +// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1 +// CHECK: cf.assert %[[VAL_15]], "Invalid value range for size between [0, 7]" +// CHECK: return %[[VAL_4]] : !torch.int + +module { + func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int { + %int7 = torch.constant.int 7 + %int0 = torch.constant.int 0 + %none = torch.constant.none + %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int + torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none + torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int + return %0 : !torch.int + } +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten._assert_scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { +// CHECK: %[[VAL_1:.*]] = torch.constant.str "Runtime assertion failed for expression u0 <= 7 on node 'le_1'" +// CHECK: %[[VAL_2:.*]] = torch.constant.int 7 +// CHECK: %[[VAL_3:.*]] = torch.constant.str "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'" +// CHECK: %[[VAL_4:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int +// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]] +// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_9:.*]] = arith.constant 9223372036854775807 : i64 +// CHECK: %[[VAL_10:.*]] = arith.cmpi ule, %[[VAL_8]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_11:.*]] = arith.cmpi ule, %[[VAL_7]], %[[VAL_9]] : i64 +// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_10]], %[[VAL_11]] : i1 +// CHECK: cf.assert %[[VAL_12]], "Invalid value range for size between [0, 9223372036854775807]" +// CHECK: %[[VAL_13:.*]] = torch.aten.ge.int %[[VAL_6]], %[[VAL_4]] : !torch.int, !torch.int -> 
!torch.bool +// CHECK: %[[VAL_14:.*]] = torch.aten.Int.bool %[[VAL_13]] : !torch.bool -> !torch.int +// CHECK: %[[VAL_15:.*]] = torch_c.to_i64 %[[VAL_14]] +// CHECK: %[[VAL_16:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_15]], %[[VAL_16]] : i64 +// CHECK: cf.assert %[[VAL_17]], "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'" +// CHECK: %[[VAL_18:.*]] = torch.aten.le.int %[[VAL_6]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[VAL_19:.*]] = torch.aten.Int.bool %[[VAL_18]] : !torch.bool -> !torch.int +// CHECK: %[[VAL_20:.*]] = torch_c.to_i64 %[[VAL_19]] +// CHECK: %[[VAL_21:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_22:.*]] = arith.cmpi ne, %[[VAL_20]], %[[VAL_21]] : i64 +// CHECK: cf.assert %[[VAL_22]], "Runtime assertion failed for expression u0 <= 7 on node 'le_1'" +// CHECK: return %[[VAL_6]] : !torch.int +func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> !torch.int { + %str = torch.constant.str "Runtime assertion failed for expression u0 <= 7 on node 'le_1'" + %int7 = torch.constant.int 7 + %str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'" + %int0 = torch.constant.int 0 + %none = torch.constant.none + %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int + torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none + %1 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int + torch.aten._assert_scalar %2, %str_0 : !torch.int, !torch.str + %3 = torch.aten.le.int %0, %int7 : !torch.int, !torch.int -> !torch.bool + %4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int + torch.aten._assert_scalar %4, %str : !torch.int, !torch.str + return %0 : !torch.int +} diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 384502ecd2af..0adb10edac80 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -228,3 +228,27 @@ func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> return %out : !torch.vtensor<[19,23],complex> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.sym_constrain_range_for_size( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { +// CHECK: %[[VAL_1:.*]] = torch.constant.int 7 +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int +// CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none +// CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int +// CHECK: return %[[VAL_4]] : !torch.int +module { + func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int { + %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int + %none = torch.constant.none + %none_0 = torch.constant.none + torch.aten.sym_constrain_range_for_size %0, %none, %none_0 : !torch.int, !torch.none, !torch.none + %int0_6 = torch.constant.int 0 + %int7_7 = torch.constant.int 7 + torch.aten.sym_constrain_range_for_size %0, %int0_6, %int7_7 : !torch.int, 
!torch.int, !torch.int + return %0 : !torch.int + } +} From 5001be48c28e9c9c966e09dbe0d92da628d046ce Mon Sep 17 00:00:00 2001 From: Praveen G Date: Tue, 7 Jan 2025 17:34:41 +0000 Subject: [PATCH 2/5] Address review comments --- lib/Conversion/TorchToLinalg/Uncategorized.cpp | 10 +++++----- .../torch_mlir_e2e_test/test_suite/basic.py | 2 +- test/Conversion/TorchToLinalg/constraints.mlir | 16 +++++++--------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 02ad8f6f3b13..360b86ccb701 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -3589,7 +3589,7 @@ class ConvertSymConstrainRangeOp int64_t minValue = std::numeric_limits::min(); int64_t maxValue = std::numeric_limits::max(); - Type operandType = rewriter.getI64Type(); + Type operandType = getTypeConverter()->convertType(op.getSize().getType()); if (!isa(minOp)) if (!matchPattern(min, m_TorchConstantInt(&minValue))) @@ -3615,10 +3615,10 @@ class ConvertSymConstrainRangeOp // FIXME:: Skip the below checks if constraint ops are already inserted as // part of symbol expr evaluation - auto checkMin = createLessThanOrEqual(rewriter, loc, operandType, min, - adaptor.getSize()); - auto checkMax = createLessThanOrEqual(rewriter, loc, operandType, - adaptor.getSize(), max); + auto checkMin = rewriter.create( + loc, arith::CmpIPredicate::sle, min, adaptor.getSize()); + auto checkMax = rewriter.create( + loc, arith::CmpIPredicate::sle, adaptor.getSize(), max); auto compareVal = rewriter.create(loc, checkMin, checkMax); std::string assertMessage = "Invalid value range for size between [" + diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 16c3aadb337f..78856a17db3c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -6515,7 +6515,7 @@ def __init__(self): def forward(self, x): a = x.item() torch._check_is_size(a) - # max should be >= 2 + # max should be > 2 torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10) return a diff --git a/test/Conversion/TorchToLinalg/constraints.mlir b/test/Conversion/TorchToLinalg/constraints.mlir index 8ae62822f9b3..bc48da402fb8 100644 --- a/test/Conversion/TorchToLinalg/constraints.mlir +++ b/test/Conversion/TorchToLinalg/constraints.mlir @@ -10,20 +10,19 @@ // CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]] // CHECK: %[[VAL_6:.*]] = arith.constant 0 : i64 // CHECK: %[[VAL_7:.*]] = arith.constant 9223372036854775807 : i64 -// CHECK: %[[VAL_8:.*]] = arith.cmpi ule, %[[VAL_6]], %[[VAL_5]] : i64 -// CHECK: %[[VAL_9:.*]] = arith.cmpi ule, %[[VAL_5]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_5]] : i64 +// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_7]] : i64 // CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1 // CHECK: cf.assert %[[VAL_10]], "Invalid value range for size between [0, 9223372036854775807]" // CHECK: %[[VAL_11:.*]] = arith.constant 0 : i64 // CHECK: %[[VAL_12:.*]] = arith.constant 7 : i64 -// CHECK: %[[VAL_13:.*]] = arith.cmpi ule, %[[VAL_11]], %[[VAL_5]] : i64 -// CHECK: %[[VAL_14:.*]] = arith.cmpi ule, %[[VAL_5]], %[[VAL_12]] : i64 +// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_11]], %[[VAL_5]] : i64 +// CHECK: %[[VAL_14:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_12]] : i64 // 
CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1 // CHECK: cf.assert %[[VAL_15]], "Invalid value range for size between [0, 7]" // CHECK: return %[[VAL_4]] : !torch.int -module { - func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int { +func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int { %int7 = torch.constant.int 7 %int0 = torch.constant.int 0 %none = torch.constant.none @@ -31,7 +30,6 @@ module { torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int return %0 : !torch.int - } } // ----- @@ -47,8 +45,8 @@ module { // CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = arith.constant 0 : i64 // CHECK: %[[VAL_9:.*]] = arith.constant 9223372036854775807 : i64 -// CHECK: %[[VAL_10:.*]] = arith.cmpi ule, %[[VAL_8]], %[[VAL_7]] : i64 -// CHECK: %[[VAL_11:.*]] = arith.cmpi ule, %[[VAL_7]], %[[VAL_9]] : i64 +// CHECK: %[[VAL_10:.*]] = arith.cmpi sle, %[[VAL_8]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_9]] : i64 // CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_10]], %[[VAL_11]] : i1 // CHECK: cf.assert %[[VAL_12]], "Invalid value range for size between [0, 9223372036854775807]" // CHECK: %[[VAL_13:.*]] = torch.aten.ge.int %[[VAL_6]], %[[VAL_4]] : !torch.int, !torch.int -> !torch.bool From 7b1026f107a8f92782be7cec25e8e573d8ca8bc2 Mon Sep 17 00:00:00 2001 From: Praveen G Date: Wed, 8 Jan 2025 12:48:37 +0000 Subject: [PATCH 3/5] Lower aten::_assert_scalar to torch.runtime.assert --- .../TorchToLinalg/Uncategorized.cpp | 39 ++------------ .../Torch/Transforms/DecomposeComplexOps.cpp | 44 ++++++++++++---- projects/pt1/e2e_testing/xfail_sets.py | 8 ++- .../torch_mlir_e2e_test/test_suite/basic.py | 14 ++--- .../Conversion/TorchToLinalg/constraints.mlir | 52 +------------------ test/Dialect/Torch/decompose-complex-ops.mlir | 35 +++++++++++-- 6 files changed, 85 insertions(+), 107 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 360b86ccb701..4ebdfbf94129 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -21,6 +21,7 @@ #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/APSInt.h" @@ -3580,23 +3581,17 @@ class ConvertSymConstrainRangeOp auto min = op.getMin(); auto max = op.getMax(); - auto minOp = min.getDefiningOp(); - auto maxOp = max.getDefiningOp(); - - if (!minOp || !maxOp) - return op.emitError("Unimplemented: Non constant min/max values"); - int64_t minValue = std::numeric_limits::min(); int64_t maxValue = std::numeric_limits::max(); Type operandType = getTypeConverter()->convertType(op.getSize().getType()); - if (!isa(minOp)) + if (!isa(min.getType())) if (!matchPattern(min, m_TorchConstantInt(&minValue))) return rewriter.notifyMatchFailure( op, "Expected min value to be constant integer"); - if (!isa(maxOp)) + if (!isa(max.getType())) if (!matchPattern(max, m_TorchConstantInt(&maxValue))) return rewriter.notifyMatchFailure( op, "Expected max value to be constant integer"); @@ -3621,7 +3616,7 @@ class 
ConvertSymConstrainRangeOp loc, arith::CmpIPredicate::sle, adaptor.getSize(), max); auto compareVal = rewriter.create(loc, checkMin, checkMax); - std::string assertMessage = "Invalid value range for size between [" + + std::string assertMessage = "Size constraint failed. Expected range: [" + std::to_string(minValue) + ", " + std::to_string(maxValue) + "]"; rewriter.create(loc, compareVal, @@ -3633,30 +3628,6 @@ class ConvertSymConstrainRangeOp }; } // namespace -namespace { -class ConvertAssertScalarOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(Aten_AssertScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); - - auto assertCond = convertScalarToDtype( - rewriter, op.getLoc(), adaptor.getSelf(), rewriter.getI1Type()); - - std::string assertMessage; - if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage))) - return rewriter.notifyMatchFailure( - op, "Assert message must be a constant string"); - - rewriter.replaceOpWithNewOp(op, assertCond, assertMessage); - return success(); - } -}; -} // namespace - void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -3721,6 +3692,4 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 41a1186bc987..1226ad2c03e2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -11456,7 +11456,7 @@ class DecomposeAtenSpecialExpm1Op } // namespace namespace { -class DecomposeConstrainRangeForSizeOp +class DecomposeAtenConstrainRangeForSizeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -11466,15 +11466,10 @@ class DecomposeConstrainRangeForSizeOp auto loc = op.getLoc(); auto min = op.getMin(); auto max = op.getMax(); - auto minOp = min.getDefiningOp(); - auto maxOp = max.getDefiningOp(); - - if (!minOp || !maxOp) - return op.emitError("Unimplemented: Non constant min/max values"); int64_t minValue, maxValue; - if (isa(minOp)) { + if (isa(min.getType())) { // Set min value to 0 min = rewriter.create(loc, 0); } else { @@ -11484,7 +11479,7 @@ class DecomposeConstrainRangeForSizeOp op, "Expected min value to be constant integer"); } - if (!isa(maxOp)) { + if (!isa(max.getType())) { // Verify that max value is greater than 2 if (!matchPattern(max, m_TorchConstantInt(&maxValue))) return rewriter.notifyMatchFailure( @@ -11505,6 +11500,35 @@ class DecomposeConstrainRangeForSizeOp }; } // namespace +namespace { +class DecomposeAten_AssertScalarOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_AssertScalarOp op, + PatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto assertCond = op.getSelf(); + + if (isa(assertCond.getType())) + assertCond = rewriter.create(loc, assertCond); + else if (isa(assertCond.getType())) + assertCond = rewriter.create(loc, assertCond); + assert(isa(assertCond.getType()) && + "Unhandled type encountered in 
aten._assert_scalar op"); + + std::string assertMessage; + if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage))) + return rewriter.notifyMatchFailure( + op, "Assert message must be a constant string"); + + rewriter.replaceOpWithNewOp(op, assertCond, assertMessage); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -11803,7 +11827,9 @@ class DecomposeComplexOpsPass // Torchvision ops addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3c7b878478a5..08cc787bcb74 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -38,7 +38,7 @@ # Unknown builtin op: aten::_check_is_size in TorchScript "AtenSymConstrainRange_basic", "AtenSymConstrainRangeForSize_basic", - "AtenAssertScalar", + "Aten_AssertScalar_basic", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): @@ -945,6 +945,9 @@ "UniformModule_basic", "UniformStaticShapeModule_basic", "ScaledDotProductAttentionGQAModule_basic", + "AtenSymConstrainRange_basic", + "AtenSymConstrainRangeForSize_basic", + "Aten_AssertScalar_basic", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { @@ -3258,6 +3261,9 @@ "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", "Aten_TrilinearModuleZerodDimBug_basic", "ScaledDotProductAttentionGQAModule_basic", + "AtenSymConstrainRange_basic", + "AtenSymConstrainRangeForSize_basic", + "Aten_AssertScalar_basic", } if torch_version_for_comparison() < version.parse("2.3.0.dev"): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 78856a17db3c..4ba497452a76 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -6493,7 +6493,6 @@ def __init__(self): @annotate_args([None, ([-1], torch.int, True)]) def forward(self, x): a = x.item() - torch._check_is_size(a) torch.ops.aten.sym_constrain_range(a, max=5) return a @@ -6514,8 +6513,6 @@ def __init__(self): @annotate_args([None, ([-1], torch.int, True)]) def forward(self, x): a = x.item() - torch._check_is_size(a) - # max should be > 2 torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10) return a @@ -6526,7 +6523,7 @@ def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils): # ============================================================================== -class AtenAssertScalar(torch.nn.Module): +class Aten_AssertScalar(torch.nn.Module): def __init__(self): super().__init__() @@ -6534,12 +6531,11 @@ def __init__(self): @annotate_args([None, ([-1], torch.int, True)]) def forward(self, x): a = x.item() - # The below checks introduces aten._assert_scalar op - torch._check_is_size(a) - torch._check(a <= 5) + assert_msg = "Assertion failed for condition x.item() > 3" + torch.ops.aten._assert_scalar(a > 3, assert_msg) return a -@register_test_case(module_factory=lambda: AtenAssertScalar()) -def AtenAssertScalar_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: Aten_AssertScalar()) +def Aten_AssertScalar_basic(module, tu: TestUtils): module.forward(torch.tensor(4)) diff --git a/test/Conversion/TorchToLinalg/constraints.mlir 
b/test/Conversion/TorchToLinalg/constraints.mlir index bc48da402fb8..11bafaa973d1 100644 --- a/test/Conversion/TorchToLinalg/constraints.mlir +++ b/test/Conversion/TorchToLinalg/constraints.mlir @@ -1,5 +1,4 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s -// ----- // CHECK-LABEL: func.func @torch.aten.sym_constrain_range( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { @@ -13,13 +12,13 @@ // CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_5]] : i64 // CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_7]] : i64 // CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1 -// CHECK: cf.assert %[[VAL_10]], "Invalid value range for size between [0, 9223372036854775807]" +// CHECK: cf.assert %[[VAL_10]], "Size constraint failed. Expected range: [0, 9223372036854775807]" // CHECK: %[[VAL_11:.*]] = arith.constant 0 : i64 // CHECK: %[[VAL_12:.*]] = arith.constant 7 : i64 // CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_11]], %[[VAL_5]] : i64 // CHECK: %[[VAL_14:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_12]] : i64 // CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1 -// CHECK: cf.assert %[[VAL_15]], "Invalid value range for size between [0, 7]" +// CHECK: cf.assert %[[VAL_15]], "Size constraint failed. Expected range: [0, 7]" // CHECK: return %[[VAL_4]] : !torch.int func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int { @@ -31,50 +30,3 @@ func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !to torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int return %0 : !torch.int } - -// ----- - -// CHECK-LABEL: func.func @torch.aten._assert_scalar( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { -// CHECK: %[[VAL_1:.*]] = torch.constant.str "Runtime assertion failed for expression u0 <= 7 on node 'le_1'" -// CHECK: %[[VAL_2:.*]] = torch.constant.int 7 -// CHECK: %[[VAL_3:.*]] = torch.constant.str "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'" -// CHECK: %[[VAL_4:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_5:.*]] = torch.constant.none -// CHECK: %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int -// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]] -// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_9:.*]] = arith.constant 9223372036854775807 : i64 -// CHECK: %[[VAL_10:.*]] = arith.cmpi sle, %[[VAL_8]], %[[VAL_7]] : i64 -// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_9]] : i64 -// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_10]], %[[VAL_11]] : i1 -// CHECK: cf.assert %[[VAL_12]], "Invalid value range for size between [0, 9223372036854775807]" -// CHECK: %[[VAL_13:.*]] = torch.aten.ge.int %[[VAL_6]], %[[VAL_4]] : !torch.int, !torch.int -> !torch.bool -// CHECK: %[[VAL_14:.*]] = torch.aten.Int.bool %[[VAL_13]] : !torch.bool -> !torch.int -// CHECK: %[[VAL_15:.*]] = torch_c.to_i64 %[[VAL_14]] -// CHECK: %[[VAL_16:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_15]], %[[VAL_16]] : i64 -// CHECK: cf.assert %[[VAL_17]], "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'" -// CHECK: %[[VAL_18:.*]] = torch.aten.le.int %[[VAL_6]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool -// CHECK: %[[VAL_19:.*]] = torch.aten.Int.bool %[[VAL_18]] : !torch.bool -> !torch.int -// CHECK: %[[VAL_20:.*]] = torch_c.to_i64 %[[VAL_19]] -// CHECK: %[[VAL_21:.*]] = 
arith.constant 0 : i64 -// CHECK: %[[VAL_22:.*]] = arith.cmpi ne, %[[VAL_20]], %[[VAL_21]] : i64 -// CHECK: cf.assert %[[VAL_22]], "Runtime assertion failed for expression u0 <= 7 on node 'le_1'" -// CHECK: return %[[VAL_6]] : !torch.int -func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> !torch.int { - %str = torch.constant.str "Runtime assertion failed for expression u0 <= 7 on node 'le_1'" - %int7 = torch.constant.int 7 - %str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'" - %int0 = torch.constant.int 0 - %none = torch.constant.none - %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int - torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none - %1 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - %2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int - torch.aten._assert_scalar %2, %str_0 : !torch.int, !torch.str - %3 = torch.aten.le.int %0, %int7 : !torch.int, !torch.int -> !torch.bool - %4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int - torch.aten._assert_scalar %4, %str : !torch.int, !torch.str - return %0 : !torch.int -} diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 0adb10edac80..be3f6548fc98 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -240,8 +240,7 @@ func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) // CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none // CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int // CHECK: return %[[VAL_4]] : !torch.int -module { - func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int { +func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int { %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int %none = torch.constant.none %none_0 = torch.constant.none @@ -250,5 +249,35 @@ module { %int7_7 = torch.constant.int 7 torch.aten.sym_constrain_range_for_size %0, %int0_6, %int7_7 : !torch.int, !torch.int, !torch.int return %0 : !torch.int - } +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten._assert_scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { +// CHECK: %[[VAL_1:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_2:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_3:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int +// CHECK: %[[VAL_4:.*]] = torch.aten.ge.int %[[VAL_3]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[VAL_5:.*]] = torch.aten.Int.bool %[[VAL_4]] : !torch.bool -> !torch.int +// CHECK: %[[VAL_6:.*]] = torch.aten.Bool.int %[[VAL_5]] : !torch.int -> !torch.bool +// CHECK: torch.runtime.assert %[[VAL_6]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" +// CHECK: %[[VAL_7:.*]] = torch.aten.gt.int %[[VAL_3]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[VAL_8:.*]] = torch.aten.Int.bool %[[VAL_7]] : !torch.bool -> !torch.int +// CHECK: %[[VAL_9:.*]] = torch.aten.Bool.int %[[VAL_8]] : !torch.int -> !torch.bool +// CHECK: torch.runtime.assert %[[VAL_9]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" +// CHECK: return %[[VAL_3]] : !torch.int +func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> 
!torch.int { + %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int + %int3 = torch.constant.int 3 + %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool + %2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int + %str = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" + torch.aten._assert_scalar %2, %str : !torch.int, !torch.str + %int2 = torch.constant.int 2 + %3 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool + %4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int + %str_0 = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" + torch.aten._assert_scalar %4, %str_0 : !torch.int, !torch.str + return %0 : !torch.int } From 59ab6ba933585bec8e588829f219541c1075a8cb Mon Sep 17 00:00:00 2001 From: Praveen G Date: Wed, 8 Jan 2025 17:48:55 +0000 Subject: [PATCH 4/5] Move AtenNonzero1DDynamicModule test to crash set --- projects/pt1/e2e_testing/xfail_sets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 08cc787bcb74..e433fabe2712 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -627,7 +627,6 @@ "AtenMmQMixedSigni8_basic", "AtenMmQint8_basic", "AtenMmQuint8_basic", - "AtenNonzero1DDynamicModule_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", "AtenTopKModule_basic", @@ -971,6 +970,7 @@ "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", + "AtenNonzero1DDynamicModule_basic", # error: Mismatched ranks of types2 vs 1 } STABLEHLO_PASS_SET = { From f8e71328e6296abb41c0ec2cc053ad05a3789618 Mon Sep 17 00:00:00 2001 From: Praveen G Date: Thu, 23 Jan 2025 09:46:42 +0000 Subject: [PATCH 5/5] Simplify lit test cases --- .../Conversion/TorchToLinalg/constraints.mlir | 54 +++++++------- test/Dialect/Torch/decompose-complex-ops.mlir | 71 +++++++++---------- 2 files changed, 60 insertions(+), 65 deletions(-) diff --git a/test/Conversion/TorchToLinalg/constraints.mlir b/test/Conversion/TorchToLinalg/constraints.mlir index 11bafaa973d1..19075d72103a 100644 --- a/test/Conversion/TorchToLinalg/constraints.mlir +++ b/test/Conversion/TorchToLinalg/constraints.mlir @@ -1,32 +1,30 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.sym_constrain_range( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { -// CHECK: %[[VAL_1:.*]] = torch.constant.int 7 -// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_3:.*]] = torch.constant.none -// CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int -// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]] -// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_7:.*]] = arith.constant 9223372036854775807 : i64 -// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_5]] : i64 -// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_7]] : i64 -// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1 -// CHECK: cf.assert %[[VAL_10]], "Size constraint failed. 
Expected range: [0, 9223372036854775807]" -// CHECK: %[[VAL_11:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_12:.*]] = arith.constant 7 : i64 -// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_11]], %[[VAL_5]] : i64 -// CHECK: %[[VAL_14:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_12]] : i64 -// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1 -// CHECK: cf.assert %[[VAL_15]], "Size constraint failed. Expected range: [0, 7]" -// CHECK: return %[[VAL_4]] : !torch.int - -func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int { - %int7 = torch.constant.int 7 - %int0 = torch.constant.int 0 - %none = torch.constant.none - %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int - torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none - torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int - return %0 : !torch.int +// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int { +// CHECK: %[[VAL_1:.*]] = torch_c.to_i64 %[[VAL_0]] +// CHECK: %[[VAL_2:.*]] = torch.constant.int 7 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.none +// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_6:.*]] = arith.constant 9223372036854775807 : i64 +// CHECK: %[[VAL_7:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_1]] : i64 +// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_6]] : i64 +// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_7]], %[[VAL_8]] : i1 +// CHECK: cf.assert %[[VAL_9]], "Size constraint failed. Expected range: [0, 9223372036854775807]" +// CHECK: %[[VAL_10:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_11:.*]] = arith.constant 7 : i64 +// CHECK: %[[VAL_12:.*]] = arith.cmpi sle, %[[VAL_10]], %[[VAL_1]] : i64 +// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_11]] : i64 +// CHECK: %[[VAL_14:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : i1 +// CHECK: cf.assert %[[VAL_14]], "Size constraint failed. 
Expected range: [0, 7]" +// CHECK: return %[[VAL_0]] : !torch.int +// CHECK: } +func.func @torch.aten.sym_constrain_range(%arg0: !torch.int) -> !torch.int { + %int7 = torch.constant.int 7 + %int0 = torch.constant.int 0 + %none = torch.constant.none + torch.aten.sym_constrain_range %arg0, %int0, %none : !torch.int, !torch.int, !torch.none + torch.aten.sym_constrain_range %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int + return %arg0 : !torch.int } diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index be3f6548fc98..4c99f4949a38 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -232,52 +232,49 @@ func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) // ----- // CHECK-LABEL: func.func @torch.aten.sym_constrain_range_for_size( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int { // CHECK: %[[VAL_1:.*]] = torch.constant.int 7 // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.none -// CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int -// CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none -// CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int -// CHECK: return %[[VAL_4]] : !torch.int -func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int { - %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int - %none = torch.constant.none - %none_0 = torch.constant.none - torch.aten.sym_constrain_range_for_size %0, %none, %none_0 : !torch.int, !torch.none, !torch.none - %int0_6 = torch.constant.int 0 - %int7_7 = torch.constant.int 7 - torch.aten.sym_constrain_range_for_size %0, %int0_6, %int7_7 : !torch.int, !torch.int, !torch.int - return %0 : !torch.int +// CHECK: torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none +// CHECK: torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int +// CHECK: return %[[VAL_0]] : !torch.int +// CHECK: } +func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.int) -> !torch.int { + %int7 = torch.constant.int 7 + %int0 = torch.constant.int 0 + %none = torch.constant.none + torch.aten.sym_constrain_range_for_size %arg0, %none, %none : !torch.int, !torch.none, !torch.none + torch.aten.sym_constrain_range_for_size %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int + return %arg0 : !torch.int } // ----- // CHECK-LABEL: func.func @torch.aten._assert_scalar( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int { // CHECK: %[[VAL_1:.*]] = torch.constant.int 2 // CHECK: %[[VAL_2:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_3:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int -// CHECK: %[[VAL_4:.*]] = torch.aten.ge.int %[[VAL_3]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool -// CHECK: %[[VAL_5:.*]] = torch.aten.Int.bool %[[VAL_4]] : !torch.bool -> !torch.int -// CHECK: %[[VAL_6:.*]] = torch.aten.Bool.int %[[VAL_5]] : !torch.int -> !torch.bool -// CHECK: torch.runtime.assert %[[VAL_6]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" -// CHECK: %[[VAL_7:.*]] = 
torch.aten.gt.int %[[VAL_3]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool -// CHECK: %[[VAL_8:.*]] = torch.aten.Int.bool %[[VAL_7]] : !torch.bool -> !torch.int -// CHECK: %[[VAL_9:.*]] = torch.aten.Bool.int %[[VAL_8]] : !torch.int -> !torch.bool -// CHECK: torch.runtime.assert %[[VAL_9]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" -// CHECK: return %[[VAL_3]] : !torch.int -func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> !torch.int { - %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int - %int3 = torch.constant.int 3 - %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool - %2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int - %str = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" - torch.aten._assert_scalar %2, %str : !torch.int, !torch.str +// CHECK: %[[VAL_3:.*]] = torch.aten.ge.int %[[VAL_0]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[VAL_4:.*]] = torch.aten.Int.bool %[[VAL_3]] : !torch.bool -> !torch.int +// CHECK: %[[VAL_5:.*]] = torch.aten.Bool.int %[[VAL_4]] : !torch.int -> !torch.bool +// CHECK: torch.runtime.assert %[[VAL_5]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" +// CHECK: %[[VAL_6:.*]] = torch.aten.gt.int %[[VAL_0]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[VAL_7:.*]] = torch.aten.Int.bool %[[VAL_6]] : !torch.bool -> !torch.int +// CHECK: %[[VAL_8:.*]] = torch.aten.Bool.int %[[VAL_7]] : !torch.int -> !torch.bool +// CHECK: torch.runtime.assert %[[VAL_8]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" +// CHECK: return %[[VAL_0]] : !torch.int +// CHECK: } +func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int { + %str = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" %int2 = torch.constant.int 2 - %3 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - %4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int - %str_0 = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" - torch.aten._assert_scalar %4, %str_0 : !torch.int, !torch.str - return %0 : !torch.int + %str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" + %int3 = torch.constant.int 3 + %0 = torch.aten.ge.int %arg0, %int3 : !torch.int, !torch.int -> !torch.bool + %1 = torch.aten.Int.bool %0 : !torch.bool -> !torch.int + torch.aten._assert_scalar %1, %str_0 : !torch.int, !torch.str + %2 = torch.aten.gt.int %arg0, %int2 : !torch.int, !torch.int -> !torch.bool + %3 = torch.aten.Int.bool %2 : !torch.bool -> !torch.int + torch.aten._assert_scalar %3, %str : !torch.int, !torch.str + return %arg0 : !torch.int }
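
---
For reference, a minimal sketch of the user-level pattern these lowerings target, adapted from the e2e tests added in PATCH 1/5 and reworked in PATCH 3/5 above. Running the module eagerly (as the e2e harness does to produce golden values) exercises the same checks that the lowered `cf.assert` / `torch.runtime.assert` ops enforce after compilation. The class name `ConstrainedItem` and the eager invocation at the end are illustrative only; compiling through torch-mlir is assumed to go through the usual e2e test harness.

import torch


class ConstrainedItem(torch.nn.Module):
    def forward(self, x):
        a = x.item()  # data-dependent scalar; an unbacked symbol under export
        # Traced to torch.aten.sym_constrain_range; lowered to
        # arith.cmpi sle + cf.assert by ConvertSymConstrainRangeOp.
        torch.ops.aten.sym_constrain_range(a, min=0, max=5)
        # Traced to torch.aten._assert_scalar; decomposed to
        # torch.runtime.assert by DecomposeAten_AssertScalarOp.
        torch.ops.aten._assert_scalar(a > 3, "expected x.item() > 3")
        return a


m = ConstrainedItem()
print(m(torch.tensor(4)))  # both checks pass; prints 4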