Skip to content

Conversation

@makslevental
Copy link
Contributor

@makslevental makslevental commented Jan 8, 2026

We already have GreedyRewriteDriverConfig on the Python side, but it hasn’t yet been exposed as a parameter of apply_patterns_and_fold_greedily. This PR does that.

Before:

def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None

After:

def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet,
                                     config: GreedyRewriteDriverConfig | None = None) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet,
                                     config: GreedyRewriteDriverConfig | None = None) -> None

Note this PR is adapted from #174785 but using std::optional instead of nb::object. Note, this required refactoring PyGreedyRewriteDriverConfig to have a std::shared_ptr so that it could support a copy-ctor.

@makslevental makslevental changed the title [MLIR][Python] Add GreedyRewriteDriverConfig parameter to apply_patte… [MLIR][Python] Add GreedyRewriteDriverConfig parameter to apply_patterns_and_fold_greedily Jan 8, 2026
@makslevental makslevental force-pushed the users/makslevental/mlir-python-rewrite-config- branch 3 times, most recently from b137aed to 2f4367f Compare January 8, 2026 06:16
@makslevental makslevental marked this pull request as ready for review January 8, 2026 06:17
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Jan 8, 2026
@llvmbot
Copy link
Member

llvmbot commented Jan 8, 2026

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

adapted from #174785 but using std::optional instead of nb::object. Note, this required refactoring PyGreedyRewriteDriverConfig so that it could support a copy-ctor.


Full diff: https://github.com/llvm/llvm-project/pull/174913.diff

2 Files Affected:

  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+38-31)
  • (modified) mlir/test/python/rewrite.py (+44)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9830c277ac147..f908433268555 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -250,80 +250,82 @@ enum PyGreedySimplifyRegionLevel : std::underlying_type_t<
 class PyGreedyRewriteDriverConfig {
 public:
   PyGreedyRewriteDriverConfig()
-      : config(mlirGreedyRewriteDriverConfigCreate()) {}
+      : config(mlirGreedyRewriteDriverConfigCreate().ptr, customDeleter) {}
   PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
-      : config(other.config) {
-    other.config.ptr = nullptr;
-  }
-  ~PyGreedyRewriteDriverConfig() {
-    if (config.ptr != nullptr)
-      mlirGreedyRewriteDriverConfigDestroy(config);
+      : config(std::move(other.config)) {}
+  PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept
+      : config(other.config) {}
+
+  MlirGreedyRewriteDriverConfig get() {
+    return MlirGreedyRewriteDriverConfig{config.get()};
   }
-  MlirGreedyRewriteDriverConfig get() { return config; }
 
   void setMaxIterations(int64_t maxIterations) {
-    mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
+    mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations);
   }
 
   void setMaxNumRewrites(int64_t maxNumRewrites) {
-    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
+    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites);
   }
 
   void setUseTopDownTraversal(bool useTopDownTraversal) {
-    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config,
+    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(),
                                                         useTopDownTraversal);
   }
 
   void enableFolding(bool enable) {
-    mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
+    mlirGreedyRewriteDriverConfigEnableFolding(get(), enable);
   }
 
   void setStrictness(PyGreedyRewriteStrictness strictness) {
     mlirGreedyRewriteDriverConfigSetStrictness(
-        config, static_cast<MlirGreedyRewriteStrictness>(strictness));
+        get(), static_cast<MlirGreedyRewriteStrictness>(strictness));
   }
 
   void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
     mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
-        config, static_cast<MlirGreedySimplifyRegionLevel>(level));
+        get(), static_cast<MlirGreedySimplifyRegionLevel>(level));
   }
 
   void enableConstantCSE(bool enable) {
-    mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
+    mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable);
   }
 
   int64_t getMaxIterations() {
-    return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
+    return mlirGreedyRewriteDriverConfigGetMaxIterations(get());
   }
 
   int64_t getMaxNumRewrites() {
-    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
+    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get());
   }
 
   bool getUseTopDownTraversal() {
-    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
+    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get());
   }
 
   bool isFoldingEnabled() {
-    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
+    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get());
   }
 
   PyGreedyRewriteStrictness getStrictness() {
     return static_cast<PyGreedyRewriteStrictness>(
-        mlirGreedyRewriteDriverConfigGetStrictness(config));
+        mlirGreedyRewriteDriverConfigGetStrictness(get()));
   }
 
   PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
     return static_cast<PyGreedySimplifyRegionLevel>(
-        mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
+        mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get()));
   }
 
   bool isConstantCSEEnabled() {
-    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
+    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get());
   }
 
 private:
-  MlirGreedyRewriteDriverConfig config;
+  std::shared_ptr<void> config;
+  static void customDeleter(void *c) {
+    mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
+  }
 };
 
 /// Create the `mlir.rewrite` here.
@@ -509,26 +511,31 @@ void populateRewriteSubmodule(nb::module_ &m) {
            &PyFrozenRewritePatternSet::createFromCapsule);
   m.def(
        "apply_patterns_and_fold_greedily",
-       [](PyModule &module, PyFrozenRewritePatternSet &set) {
-         auto status = mlirApplyPatternsAndFoldGreedily(
-             module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
+       [](PyModule &module, PyFrozenRewritePatternSet &set,
+          std::optional<PyGreedyRewriteDriverConfig> config) {
+         MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
+             module.get(), set.get(),
+             config.has_value() ? config->get()
+                                : mlirGreedyRewriteDriverConfigCreate());
          if (mlirLogicalResultIsFailure(status))
            throw std::runtime_error("pattern application failed to converge");
        },
-       "module"_a, "set"_a,
+       "module"_a, "set"_a, "config"_a = nb::none(),
        "Applys the given patterns to the given module greedily while folding "
        "results.")
       .def(
           "apply_patterns_and_fold_greedily",
-          [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
-            auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
+          [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
+             std::optional<PyGreedyRewriteDriverConfig> config) {
+            MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
                 op.getOperation(), set.get(),
-                mlirGreedyRewriteDriverConfigCreate());
+                config.has_value() ? config->get()
+                                   : mlirGreedyRewriteDriverConfigCreate());
             if (mlirLogicalResultIsFailure(status))
               throw std::runtime_error(
                   "pattern application failed to converge");
           },
-          "op"_a, "set"_a,
+          "op"_a, "set"_a, "config"_a = nb::none(),
           "Applys the given patterns to the given op greedily while folding "
           "results.")
       .def(
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index a2fbbde38b8c0..43e9b761a0ea2 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum():
     config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
     level = config.region_simplification_level
     print(f"region_level AGGRESSIVE: {level}")
+
+
+# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
+@run
+def testRewriteWithGreedyRewriteDriverConfig():
+    def constant_1_to_2(op, rewriter):
+        c = op.value.value
+        if c != 1:
+            return True  # failed to match
+        with rewriter.ip:
+            new_op = arith.constant(op.type, 2, loc=op.location)
+        rewriter.replace_op(op, [new_op])
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add(arith.ConstantOp, constant_1_to_2)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @const() -> (i64, i64) {
+                %0 = arith.constant 1 : i64
+                %1 = arith.constant 1 : i64
+                return %0, %1 : i64, i64
+              }
+            }
+            """
+        )
+
+        config = GreedyRewriteDriverConfig()
+        config.enable_constant_cse = False
+        apply_patterns_and_fold_greedily(module, frozen, config)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: %c2_i64_0 = arith.constant 2 : i64
+        # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
+        print(module)
+
+        config = GreedyRewriteDriverConfig()
+        config.enable_constant_cse = True
+        apply_patterns_and_fold_greedily(module, frozen, config)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: return %c2_i64, %c2_i64 : i64
+        print(module)

@makslevental makslevental force-pushed the users/makslevental/mlir-python-rewrite-config- branch from 2f4367f to e41c852 Compare January 8, 2026 06:18
@makslevental makslevental requested a review from jpienaar January 8, 2026 06:20
Copy link
Member

@PragmaTwice PragmaTwice left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Thank you.

@makslevental makslevental merged commit 94a9565 into main Jan 8, 2026
10 checks passed
@makslevental makslevental deleted the users/makslevental/mlir-python-rewrite-config- branch January 8, 2026 11:48
kshitijvp pushed a commit to kshitijvp/llvm-project that referenced this pull request Jan 9, 2026
…rns_and_fold_greedily (llvm#174913)

We already have `GreedyRewriteDriverConfig` on the Python side, but it
hasn’t yet been exposed as a parameter of
`apply_patterns_and_fold_greedily`. This PR does that.

Before:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None
```

After:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet,
                                     config: GreedyRewriteDriverConfig | None = None) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet,
                                     config: GreedyRewriteDriverConfig | None = None) -> None
```

Note this PR is adapted from
llvm#174785 but using
`std::optional` instead of `nb::object`. Note, this required refactoring
`PyGreedyRewriteDriverConfig` to have a `std::shared_ptr` so that it
could support a copy-ctor.

Co-authored-by: PragmaTwice <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants