diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index e143f118a1f01..2b649f79c5982 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -242,14 +242,14 @@ enum class PyGreedySimplifyRegionLevel : std::underlying_type_t< }; /// Owning Wrapper around a GreedyRewriteDriverConfig. -class PyGreedyRewriteDriverConfig { +class PyGreedyRewriteConfig { public: - PyGreedyRewriteDriverConfig() + PyGreedyRewriteConfig() : config(mlirGreedyRewriteDriverConfigCreate().ptr, - PyGreedyRewriteDriverConfig::customDeleter) {} - PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept + PyGreedyRewriteConfig::customDeleter) {} + PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept : config(std::move(other.config)) {} - PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept + PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept : config(other.config) {} MlirGreedyRewriteDriverConfig get() { @@ -470,34 +470,32 @@ void populateRewriteSubmodule(nb::module_ &m) { nb::keep_alive<1, 3>()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH - nb::class_(m, "GreedyRewriteDriverConfig") + nb::class_(m, "GreedyRewriteConfig") .def(nb::init<>(), "Create a greedy rewrite driver config with defaults") - .def_prop_rw("max_iterations", - &PyGreedyRewriteDriverConfig::getMaxIterations, - &PyGreedyRewriteDriverConfig::setMaxIterations, + .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations, + &PyGreedyRewriteConfig::setMaxIterations, "Maximum number of iterations") .def_prop_rw("max_num_rewrites", - &PyGreedyRewriteDriverConfig::getMaxNumRewrites, - &PyGreedyRewriteDriverConfig::setMaxNumRewrites, + &PyGreedyRewriteConfig::getMaxNumRewrites, + &PyGreedyRewriteConfig::setMaxNumRewrites, "Maximum number of rewrites per iteration") .def_prop_rw("use_top_down_traversal", - &PyGreedyRewriteDriverConfig::getUseTopDownTraversal, - &PyGreedyRewriteDriverConfig::setUseTopDownTraversal, + &PyGreedyRewriteConfig::getUseTopDownTraversal, + &PyGreedyRewriteConfig::setUseTopDownTraversal, "Whether to use top-down traversal") - .def_prop_rw("enable_folding", - &PyGreedyRewriteDriverConfig::isFoldingEnabled, - &PyGreedyRewriteDriverConfig::enableFolding, + .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled, + &PyGreedyRewriteConfig::enableFolding, "Enable or disable folding") - .def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness, - &PyGreedyRewriteDriverConfig::setStrictness, + .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness, + &PyGreedyRewriteConfig::setStrictness, "Rewrite strictness level") .def_prop_rw("region_simplification_level", - &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel, - &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel, + &PyGreedyRewriteConfig::getRegionSimplificationLevel, + &PyGreedyRewriteConfig::setRegionSimplificationLevel, "Region simplification level") .def_prop_rw("enable_constant_cse", - &PyGreedyRewriteDriverConfig::isConstantCSEEnabled, - &PyGreedyRewriteDriverConfig::enableConstantCSE, + &PyGreedyRewriteConfig::isConstantCSEEnabled, + &PyGreedyRewriteConfig::enableConstantCSE, "Enable or disable constant CSE"); nb::class_(m, "FrozenRewritePatternSet") @@ -508,7 +506,7 @@ void populateRewriteSubmodule(nb::module_ &m) { m.def( "apply_patterns_and_fold_greedily", [](PyModule &module, PyFrozenRewritePatternSet &set, - std::optional config) { + std::optional config) { MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily( module.get(), set.get(), config.has_value() ? config->get() @@ -522,7 +520,7 @@ void populateRewriteSubmodule(nb::module_ &m) { .def( "apply_patterns_and_fold_greedily", [](PyOperationBase &op, PyFrozenRewritePatternSet &set, - std::optional config) { + std::optional config) { MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp( op.getOperation(), set.get(), config.has_value() ? config->get() diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index 43e9b761a0ea2..8ef49981a8b3c 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -89,19 +89,19 @@ def constant_1_to_2(op, rewriter): print(module) -# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigCreation +# CHECK-LABEL: TEST: testGreedyRewriteConfigCreation @run -def testGreedyRewriteDriverConfigCreation(): +def testGreedyRewriteConfigCreation(): # Test basic config creation and destruction - config = GreedyRewriteDriverConfig() + config = GreedyRewriteConfig() # CHECK: Config created successfully print("Config created successfully") -# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigGetters +# CHECK-LABEL: TEST: testGreedyRewriteConfigGetters @run -def testGreedyRewriteDriverConfigGetters(): - config = GreedyRewriteDriverConfig() +def testGreedyRewriteConfigGetters(): + config = GreedyRewriteConfig() # Set some values config.max_iterations = 5 @@ -139,7 +139,7 @@ def testGreedyRewriteDriverConfigGetters(): # CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum @run def testGreedyRewriteStrictnessEnum(): - config = GreedyRewriteDriverConfig() + config = GreedyRewriteConfig() # Test ANY_OP # CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP @@ -163,7 +163,7 @@ def testGreedyRewriteStrictnessEnum(): # CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum @run def testGreedySimplifyRegionLevelEnum(): - config = GreedyRewriteDriverConfig() + config = GreedyRewriteConfig() # Test DISABLED # CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED @@ -184,9 +184,9 @@ def testGreedySimplifyRegionLevelEnum(): print(f"region_level AGGRESSIVE: {level}") -# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig +# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteConfig @run -def testRewriteWithGreedyRewriteDriverConfig(): +def testRewriteWithGreedyRewriteConfig(): def constant_1_to_2(op, rewriter): c = op.value.value if c != 1: @@ -212,7 +212,7 @@ def constant_1_to_2(op, rewriter): """ ) - config = GreedyRewriteDriverConfig() + config = GreedyRewriteConfig() config.enable_constant_cse = False apply_patterns_and_fold_greedily(module, frozen, config) # CHECK: %c2_i64 = arith.constant 2 : i64 @@ -220,7 +220,7 @@ def constant_1_to_2(op, rewriter): # CHECK: return %c2_i64, %c2_i64_0 : i64, i64 print(module) - config = GreedyRewriteDriverConfig() + config = GreedyRewriteConfig() config.enable_constant_cse = True apply_patterns_and_fold_greedily(module, frozen, config) # CHECK: %c2_i64 = arith.constant 2 : i64