Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,26 +509,42 @@ void populateRewriteSubmodule(nb::module_ &m) {
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
"apply_patterns_and_fold_greedily",
[](PyModule &module, PyFrozenRewritePatternSet &set) {
[](PyModule &module, PyFrozenRewritePatternSet &set, nb::object config) {
if (config.is_none()) {
config = nb::cast(PyGreedyRewriteDriverConfig());
}

auto status = mlirApplyPatternsAndFoldGreedily(
module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
module.get(), set.get(),
nb::cast<PyGreedyRewriteDriverConfig &>(config).get());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"module"_a, "set"_a, "config"_a = nb::none(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just do std::optional and it'll automatically turn into this same kind of thing (default will be None). Also you don't need to use nb::sig anymore after #171775 because now all the API types are the actual binded classes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh sounds good.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm it seems that this requires that PyGreedyRewriteDriverConfig can be copy-constructed (T(const T&)), but now it can only be move-constructed (T(T&&)) instead. 🤔

In file included from /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:9:
In file included from /llvm-project/mlir/lib/Bindings/Python/Rewrite.h:12:
In file included from /llvm-project/mlir/include/mlir/Bindings/Python/NanobindUtils.h:14:
In file included from /llvm-project/mlir/include/mlir/Bindings/Python/Nanobind.h:27:
In file included from /python-env/lib/python3.11/site-packages/nanobind/include/nanobind/stl/optional.h:12:
/python-env/lib/python3.11/site-packages/nanobind/include/nanobind/stl/detail/nb_optional.h:34:15: error: no matching member function for call to 'emplace'
   34 |         value.emplace(caster.operator cast_t<T>());
      |         ~~~~~~^~~~~~~
/python-env/lib/python3.11/site-packages/nanobind/include/nanobind/nb_func.h:254:41: note: in instantiation of member function 'nanobind::detail::optional_caster<std::optional<mlir::python::mlir::PyGreedyRewriteDriverConfig>>::from_python' requested here
  254 |             if ((!in.template get<Is>().from_python(args[Is], args_flags[Is],
      |                                         ^
/python-env/lib/python3.11/site-packages/nanobind/include/nanobind/nb_func.h:352:13: note: in instantiation of function template specialization 'nanobind::detail::func_create<false, true, (lambda at /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:512:8), void, mlir::python::mlir::PyModule &, mlir::python::mlir::PyFrozenRewritePatternSet &, std::optional<mlir::python::mlir::PyGreedyRewriteDriverConfig> &&, 0UL, 1UL, 2UL, nanobind::scope, nanobind::name, nanobind::arg, nanobind::arg, nanobind::arg_v, char[78]>' requested here
  352 |     detail::func_create<false, true>(
      |             ^
/python-env/lib/python3.11/site-packages/nanobind/include/nanobind/nb_func.h:409:5: note: in instantiation of function template specialization 'nanobind::cpp_function_def<void, (lambda at /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:512:8), nanobind::scope, nanobind::name, nanobind::arg, nanobind::arg, nanobind::arg_v, char[78], 0>' requested here
  409 |     cpp_function_def((detail::forward_t<Func>) f, scope(*this),
      |     ^
/llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:510:5: note: in instantiation of function template specialization 'nanobind::module_::def<(lambda at /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:512:8), nanobind::arg, nanobind::arg, nanobind::arg_v, char[78]>' requested here
  510 |   m.def(
      |     ^
/usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/optional:914:2: note: candidate template ignored: requirement 'is_constructible_v<mlir::python::mlir::PyGreedyRewriteDriverConfig, mlir::python::mlir::PyGreedyRewriteDriverConfig &>' was not satisfied [with _Args = <Type &>]
  914 |         emplace(_Args&&... __args)
      |         ^
/usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/optional:926:2: note: candidate template ignored: could not match 'initializer_list<_Up>' against 'Type' (aka 'mlir::python::mlir::PyGreedyRewriteDriverConfig')
  926 |         emplace(initializer_list<_Up> __il, _Args&&... __args)
      |         ^
1 error generated.
ninja: build stopped: subcommand failed.

code:

  m.def(
       "apply_patterns_and_fold_greedily",
       [](PyModule &module, PyFrozenRewritePatternSet &set,
          std::optional<PyGreedyRewriteDriverConfig> config) {
         if (!config) {
           config.emplace(PyGreedyRewriteDriverConfig());
         }

         auto status = mlirApplyPatternsAndFoldGreedily(module.get(), set.get(),
                                                        config->get());
         if (mlirLogicalResultIsFailure(status))
           throw std::runtime_error("pattern application failed to converge");
       },
       "module"_a, "set"_a, "config"_a = nb::none(),
       "Applys the given patterns to the given module greedily while folding "
       "results.")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not possible to add a copy ctor?

Copy link
Member Author

@PragmaTwice PragmaTwice Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently PyGreedyRewriteDriverConfig acts like a "unique_ptr", i.e. it holds a raw pointer from C API. It is possible but maybe we have to add a new C API for copy constructing the actual C++ object. If we just copy the pointer then double free will happen. 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so then we can change to std::shared_ptr with a custom deleter? if you want i can do it but you're gonna have to push this PR up as a branch on this repo (or i can resubmit/take over the PR).

Copy link
Member Author

@PragmaTwice PragmaTwice Jan 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm i tried std::shared_ptr but some runtime error appears:

TypeError: apply_patterns_and_fold_greedily(): incompatible function arguments. The following argument types are supported:
    1. apply_patterns_and_fold_greedily(module: mlir._mlir_libs._mlir.ir.Module, set: mlir._mlir_libs._mlir.rewrite.FrozenRewritePatternSet, config: std::shared_ptr<mlir::python::mlir::PyGreedyRewriteDriverConfig> | None = None) -> None
    2. apply_patterns_and_fold_greedily(op: mlir._mlir_libs._mlir.ir._OperationBase, set: mlir._mlir_libs._mlir.rewrite.FrozenRewritePatternSet, config: std::shared_ptr<mlir::python::mlir::PyGreedyRewriteDriverConfig> | None = None) -> None

Invoked with types: mlir._mlir_libs._mlir.ir.Module, mlir._mlir_libs._mlir.rewrite.FrozenRewritePatternSet

feel free to take over/resubmit this PR if you have time : )

// clang-format off
nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None"),
// clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
[](PyOperationBase &op, PyFrozenRewritePatternSet &set,
nb::object config) {
if (config.is_none()) {
config = nb::cast(PyGreedyRewriteDriverConfig());
}

auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
op.getOperation(), set.get(),
mlirGreedyRewriteDriverConfigCreate());
nb::cast<PyGreedyRewriteDriverConfig &>(config).get());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
"op"_a, "set"_a,
"op"_a, "set"_a, "config"_a = nb::none(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here as above

// clang-format off
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None"),
// clang-format on
"Applys the given patterns to the given op greedily while folding "
"results.")
.def(
Expand Down
44 changes: 44 additions & 0 deletions mlir/test/python/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum():
config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
level = config.region_simplification_level
print(f"region_level AGGRESSIVE: {level}")


# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
@run
def testRewriteWithGreedyRewriteDriverConfig():
def constant_1_to_2(op, rewriter):
c = op.value.value
if c != 1:
return True # failed to match
with rewriter.ip:
new_op = arith.constant(op.type, 2, loc=op.location)
rewriter.replace_op(op, [new_op])

with Context():
patterns = RewritePatternSet()
patterns.add(arith.ConstantOp, constant_1_to_2)
frozen = patterns.freeze()

module = ModuleOp.parse(
r"""
module {
func.func @const() -> (i64, i64) {
%0 = arith.constant 1 : i64
%1 = arith.constant 1 : i64
return %0, %1 : i64, i64
}
}
"""
)

config = GreedyRewriteDriverConfig()
config.enable_constant_cse = False
apply_patterns_and_fold_greedily(module, frozen, config)
# CHECK: %c2_i64 = arith.constant 2 : i64
# CHECK: %c2_i64_0 = arith.constant 2 : i64
# CHECK: return %c2_i64, %c2_i64_0 : i64, i64
print(module)

config = GreedyRewriteDriverConfig()
config.enable_constant_cse = True
apply_patterns_and_fold_greedily(module, frozen, config)
# CHECK: %c2_i64 = arith.constant 2 : i64
# CHECK: return %c2_i64, %c2_i64 : i64
print(module)