Skip to content

Conversation

@PragmaTwice
Copy link
Member

This is mainly for two purposes:

  1. to keep it consistent with the C++ class name mlir::GreedyRewriteConfig,
  2. to make it shorter.

Since this type was only added a few days ago (654b3e8), it shouldn’t cause any obvious compatibility issues.

@llvmbot
Copy link
Member

llvmbot commented Jan 11, 2026

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

This is mainly for two purposes:

  1. to keep it consistent with the C++ class name mlir::GreedyRewriteConfig,
  2. to make it shorter.

Since this type was only added a few days ago (654b3e8), it shouldn’t cause any obvious compatibility issues.


Full diff: https://github.com/llvm/llvm-project/pull/175409.diff

2 Files Affected:

  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+22-24)
  • (modified) mlir/test/python/rewrite.py (+12-12)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index e143f118a1f01..2b649f79c5982 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -242,14 +242,14 @@ enum class PyGreedySimplifyRegionLevel : std::underlying_type_t<
 };
 
 /// Owning Wrapper around a GreedyRewriteDriverConfig.
-class PyGreedyRewriteDriverConfig {
+class PyGreedyRewriteConfig {
 public:
-  PyGreedyRewriteDriverConfig()
+  PyGreedyRewriteConfig()
       : config(mlirGreedyRewriteDriverConfigCreate().ptr,
-               PyGreedyRewriteDriverConfig::customDeleter) {}
-  PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
+               PyGreedyRewriteConfig::customDeleter) {}
+  PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept
       : config(std::move(other.config)) {}
-  PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept
+  PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
       : config(other.config) {}
 
   MlirGreedyRewriteDriverConfig get() {
@@ -470,34 +470,32 @@ void populateRewriteSubmodule(nb::module_ &m) {
           nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
 
-  nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
+  nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig")
       .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
-      .def_prop_rw("max_iterations",
-                   &PyGreedyRewriteDriverConfig::getMaxIterations,
-                   &PyGreedyRewriteDriverConfig::setMaxIterations,
+      .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations,
+                   &PyGreedyRewriteConfig::setMaxIterations,
                    "Maximum number of iterations")
       .def_prop_rw("max_num_rewrites",
-                   &PyGreedyRewriteDriverConfig::getMaxNumRewrites,
-                   &PyGreedyRewriteDriverConfig::setMaxNumRewrites,
+                   &PyGreedyRewriteConfig::getMaxNumRewrites,
+                   &PyGreedyRewriteConfig::setMaxNumRewrites,
                    "Maximum number of rewrites per iteration")
       .def_prop_rw("use_top_down_traversal",
-                   &PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
-                   &PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
+                   &PyGreedyRewriteConfig::getUseTopDownTraversal,
+                   &PyGreedyRewriteConfig::setUseTopDownTraversal,
                    "Whether to use top-down traversal")
-      .def_prop_rw("enable_folding",
-                   &PyGreedyRewriteDriverConfig::isFoldingEnabled,
-                   &PyGreedyRewriteDriverConfig::enableFolding,
+      .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled,
+                   &PyGreedyRewriteConfig::enableFolding,
                    "Enable or disable folding")
-      .def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness,
-                   &PyGreedyRewriteDriverConfig::setStrictness,
+      .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness,
+                   &PyGreedyRewriteConfig::setStrictness,
                    "Rewrite strictness level")
       .def_prop_rw("region_simplification_level",
-                   &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
-                   &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
+                   &PyGreedyRewriteConfig::getRegionSimplificationLevel,
+                   &PyGreedyRewriteConfig::setRegionSimplificationLevel,
                    "Region simplification level")
       .def_prop_rw("enable_constant_cse",
-                   &PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
-                   &PyGreedyRewriteDriverConfig::enableConstantCSE,
+                   &PyGreedyRewriteConfig::isConstantCSEEnabled,
+                   &PyGreedyRewriteConfig::enableConstantCSE,
                    "Enable or disable constant CSE");
 
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
@@ -508,7 +506,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
   m.def(
        "apply_patterns_and_fold_greedily",
        [](PyModule &module, PyFrozenRewritePatternSet &set,
-          std::optional<PyGreedyRewriteDriverConfig> config) {
+          std::optional<PyGreedyRewriteConfig> config) {
          MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
              module.get(), set.get(),
              config.has_value() ? config->get()
@@ -522,7 +520,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "apply_patterns_and_fold_greedily",
           [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
-             std::optional<PyGreedyRewriteDriverConfig> config) {
+             std::optional<PyGreedyRewriteConfig> config) {
             MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
                 op.getOperation(), set.get(),
                 config.has_value() ? config->get()
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index 43e9b761a0ea2..8ef49981a8b3c 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -89,19 +89,19 @@ def constant_1_to_2(op, rewriter):
         print(module)
 
 
-# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigCreation
+# CHECK-LABEL: TEST: testGreedyRewriteConfigCreation
 @run
-def testGreedyRewriteDriverConfigCreation():
+def testGreedyRewriteConfigCreation():
     # Test basic config creation and destruction
-    config = GreedyRewriteDriverConfig()
+    config = GreedyRewriteConfig()
     # CHECK: Config created successfully
     print("Config created successfully")
 
 
-# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigGetters
+# CHECK-LABEL: TEST: testGreedyRewriteConfigGetters
 @run
-def testGreedyRewriteDriverConfigGetters():
-    config = GreedyRewriteDriverConfig()
+def testGreedyRewriteConfigGetters():
+    config = GreedyRewriteConfig()
 
     # Set some values
     config.max_iterations = 5
@@ -139,7 +139,7 @@ def testGreedyRewriteDriverConfigGetters():
 # CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum
 @run
 def testGreedyRewriteStrictnessEnum():
-    config = GreedyRewriteDriverConfig()
+    config = GreedyRewriteConfig()
 
     # Test ANY_OP
     # CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP
@@ -163,7 +163,7 @@ def testGreedyRewriteStrictnessEnum():
 # CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum
 @run
 def testGreedySimplifyRegionLevelEnum():
-    config = GreedyRewriteDriverConfig()
+    config = GreedyRewriteConfig()
 
     # Test DISABLED
     # CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED
@@ -184,9 +184,9 @@ def testGreedySimplifyRegionLevelEnum():
     print(f"region_level AGGRESSIVE: {level}")
 
 
-# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
+# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteConfig
 @run
-def testRewriteWithGreedyRewriteDriverConfig():
+def testRewriteWithGreedyRewriteConfig():
     def constant_1_to_2(op, rewriter):
         c = op.value.value
         if c != 1:
@@ -212,7 +212,7 @@ def constant_1_to_2(op, rewriter):
             """
         )
 
-        config = GreedyRewriteDriverConfig()
+        config = GreedyRewriteConfig()
         config.enable_constant_cse = False
         apply_patterns_and_fold_greedily(module, frozen, config)
         # CHECK: %c2_i64 = arith.constant 2 : i64
@@ -220,7 +220,7 @@ def constant_1_to_2(op, rewriter):
         # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
         print(module)
 
-        config = GreedyRewriteDriverConfig()
+        config = GreedyRewriteConfig()
         config.enable_constant_cse = True
         apply_patterns_and_fold_greedily(module, frozen, config)
         # CHECK: %c2_i64 = arith.constant 2 : i64

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine since this is a fairly recent addition but in general we shouldn't rename public things even if they could be more accurately named or something (it's unnecessary "churn" that usually pisses people off).

@PragmaTwice
Copy link
Member Author

This is fine since this is a fairly recent addition but in general we shouldn't rename public things even if they could be more accurately named or something (it's unnecessary "churn" that usually pisses people off).

yah I know that. this type was introduced on Jan 2nd but it became useful (via #174913) on Jan 8th, so I think it should generally be safe to rename. indeed we should be careful about this : )

@PragmaTwice
Copy link
Member Author

merging..

@PragmaTwice PragmaTwice merged commit 9bd910d into llvm:main Jan 11, 2026
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants