diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 9830c277ac147..9b72c4bb69b3c 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -250,80 +250,83 @@ enum PyGreedySimplifyRegionLevel : std::underlying_type_t< class PyGreedyRewriteDriverConfig { public: PyGreedyRewriteDriverConfig() - : config(mlirGreedyRewriteDriverConfigCreate()) {} + : config(mlirGreedyRewriteDriverConfigCreate().ptr, + PyGreedyRewriteDriverConfig::customDeleter) {} PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept - : config(other.config) { - other.config.ptr = nullptr; - } - ~PyGreedyRewriteDriverConfig() { - if (config.ptr != nullptr) - mlirGreedyRewriteDriverConfigDestroy(config); + : config(std::move(other.config)) {} + PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept + : config(other.config) {} + + MlirGreedyRewriteDriverConfig get() { + return MlirGreedyRewriteDriverConfig{config.get()}; } - MlirGreedyRewriteDriverConfig get() { return config; } void setMaxIterations(int64_t maxIterations) { - mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations); + mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations); } void setMaxNumRewrites(int64_t maxNumRewrites) { - mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites); + mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites); } void setUseTopDownTraversal(bool useTopDownTraversal) { - mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config, + mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(), useTopDownTraversal); } void enableFolding(bool enable) { - mlirGreedyRewriteDriverConfigEnableFolding(config, enable); + mlirGreedyRewriteDriverConfigEnableFolding(get(), enable); } void setStrictness(PyGreedyRewriteStrictness strictness) { mlirGreedyRewriteDriverConfigSetStrictness( - config, static_cast(strictness)); + get(), static_cast(strictness)); } void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) { mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel( - config, static_cast(level)); + get(), static_cast(level)); } void enableConstantCSE(bool enable) { - mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable); + mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable); } int64_t getMaxIterations() { - return mlirGreedyRewriteDriverConfigGetMaxIterations(config); + return mlirGreedyRewriteDriverConfigGetMaxIterations(get()); } int64_t getMaxNumRewrites() { - return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config); + return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get()); } bool getUseTopDownTraversal() { - return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config); + return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get()); } bool isFoldingEnabled() { - return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config); + return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get()); } PyGreedyRewriteStrictness getStrictness() { return static_cast( - mlirGreedyRewriteDriverConfigGetStrictness(config)); + mlirGreedyRewriteDriverConfigGetStrictness(get())); } PyGreedySimplifyRegionLevel getRegionSimplificationLevel() { return static_cast( - mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config)); + mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get())); } bool isConstantCSEEnabled() { - return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config); + return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get()); } private: - MlirGreedyRewriteDriverConfig config; + std::shared_ptr config; + static void customDeleter(void *c) { + mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c}); + } }; /// Create the `mlir.rewrite` here. @@ -509,26 +512,31 @@ void populateRewriteSubmodule(nb::module_ &m) { &PyFrozenRewritePatternSet::createFromCapsule); m.def( "apply_patterns_and_fold_greedily", - [](PyModule &module, PyFrozenRewritePatternSet &set) { - auto status = mlirApplyPatternsAndFoldGreedily( - module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate()); + [](PyModule &module, PyFrozenRewritePatternSet &set, + std::optional config) { + MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily( + module.get(), set.get(), + config.has_value() ? config->get() + : mlirGreedyRewriteDriverConfigCreate()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error("pattern application failed to converge"); }, - "module"_a, "set"_a, + "module"_a, "set"_a, "config"_a = nb::none(), "Applys the given patterns to the given module greedily while folding " "results.") .def( "apply_patterns_and_fold_greedily", - [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { - auto status = mlirApplyPatternsAndFoldGreedilyWithOp( + [](PyOperationBase &op, PyFrozenRewritePatternSet &set, + std::optional config) { + MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp( op.getOperation(), set.get(), - mlirGreedyRewriteDriverConfigCreate()); + config.has_value() ? config->get() + : mlirGreedyRewriteDriverConfigCreate()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error( "pattern application failed to converge"); }, - "op"_a, "set"_a, + "op"_a, "set"_a, "config"_a = nb::none(), "Applys the given patterns to the given op greedily while folding " "results.") .def( diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index a2fbbde38b8c0..43e9b761a0ea2 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum(): config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE level = config.region_simplification_level print(f"region_level AGGRESSIVE: {level}") + + +# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig +@run +def testRewriteWithGreedyRewriteDriverConfig(): + def constant_1_to_2(op, rewriter): + c = op.value.value + if c != 1: + return True # failed to match + with rewriter.ip: + new_op = arith.constant(op.type, 2, loc=op.location) + rewriter.replace_op(op, [new_op]) + + with Context(): + patterns = RewritePatternSet() + patterns.add(arith.ConstantOp, constant_1_to_2) + frozen = patterns.freeze() + + module = ModuleOp.parse( + r""" + module { + func.func @const() -> (i64, i64) { + %0 = arith.constant 1 : i64 + %1 = arith.constant 1 : i64 + return %0, %1 : i64, i64 + } + } + """ + ) + + config = GreedyRewriteDriverConfig() + config.enable_constant_cse = False + apply_patterns_and_fold_greedily(module, frozen, config) + # CHECK: %c2_i64 = arith.constant 2 : i64 + # CHECK: %c2_i64_0 = arith.constant 2 : i64 + # CHECK: return %c2_i64, %c2_i64_0 : i64, i64 + print(module) + + config = GreedyRewriteDriverConfig() + config.enable_constant_cse = True + apply_patterns_and_fold_greedily(module, frozen, config) + # CHECK: %c2_i64 = arith.constant 2 : i64 + # CHECK: return %c2_i64, %c2_i64 : i64 + print(module)