-
Notifications
You must be signed in to change notification settings - Fork 15.7k
[MLIR][Python] Add GreedyRewriteDriverConfig parameter to apply_patterns_and_fold_greedily
#174785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…rns_and_fold_greedily
|
@llvm/pr-subscribers-mlir Author: Twice (PragmaTwice) ChangesWe already have Before: def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> NoneAfter: def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet,
config: GreedyRewriteDriverConfig | None = None) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet,
config: GreedyRewriteDriverConfig | None = None) -> NoneFull diff: https://github.com/llvm/llvm-project/pull/174785.diff 2 Files Affected:
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9830c277ac147..faab66d5ce4e5 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -509,26 +509,42 @@ void populateRewriteSubmodule(nb::module_ &m) {
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
"apply_patterns_and_fold_greedily",
- [](PyModule &module, PyFrozenRewritePatternSet &set) {
+ [](PyModule &module, PyFrozenRewritePatternSet &set, nb::object config) {
+ if (config.is_none()) {
+ config = nb::cast(PyGreedyRewriteDriverConfig());
+ }
+
auto status = mlirApplyPatternsAndFoldGreedily(
- module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
+ module.get(), set.get(),
+ nb::cast<PyGreedyRewriteDriverConfig &>(config).get());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("pattern application failed to converge");
},
- "module"_a, "set"_a,
+ "module"_a, "set"_a, "config"_a = nb::none(),
+ // clang-format off
+ nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None"),
+ // clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
- [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
+ [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
+ nb::object config) {
+ if (config.is_none()) {
+ config = nb::cast(PyGreedyRewriteDriverConfig());
+ }
+
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
op.getOperation(), set.get(),
- mlirGreedyRewriteDriverConfigCreate());
+ nb::cast<PyGreedyRewriteDriverConfig &>(config).get());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
- "op"_a, "set"_a,
+ "op"_a, "set"_a, "config"_a = nb::none(),
+ // clang-format off
+ nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None"),
+ // clang-format on
"Applys the given patterns to the given op greedily while folding "
"results.")
.def(
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index a2fbbde38b8c0..43e9b761a0ea2 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum():
config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
level = config.region_simplification_level
print(f"region_level AGGRESSIVE: {level}")
+
+
+# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
+@run
+def testRewriteWithGreedyRewriteDriverConfig():
+ def constant_1_to_2(op, rewriter):
+ c = op.value.value
+ if c != 1:
+ return True # failed to match
+ with rewriter.ip:
+ new_op = arith.constant(op.type, 2, loc=op.location)
+ rewriter.replace_op(op, [new_op])
+
+ with Context():
+ patterns = RewritePatternSet()
+ patterns.add(arith.ConstantOp, constant_1_to_2)
+ frozen = patterns.freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @const() -> (i64, i64) {
+ %0 = arith.constant 1 : i64
+ %1 = arith.constant 1 : i64
+ return %0, %1 : i64, i64
+ }
+ }
+ """
+ )
+
+ config = GreedyRewriteDriverConfig()
+ config.enable_constant_cse = False
+ apply_patterns_and_fold_greedily(module, frozen, config)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: %c2_i64_0 = arith.constant 2 : i64
+ # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
+ print(module)
+
+ config = GreedyRewriteDriverConfig()
+ config.enable_constant_cse = True
+ apply_patterns_and_fold_greedily(module, frozen, config)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: return %c2_i64, %c2_i64 : i64
+ print(module)
|
| throw std::runtime_error("pattern application failed to converge"); | ||
| }, | ||
| "module"_a, "set"_a, | ||
| "module"_a, "set"_a, "config"_a = nb::none(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just do std::optional and it'll automatically turn into this same kind of thing (default will be None). Also you don't need to use nb::sig anymore after #171775 because now all the API types are the actual binded classes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahh sounds good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm it seems that this requires that PyGreedyRewriteDriverConfig can be copy-constructed (T(const T&)), but now it can only be move-constructed (T(T&&)) instead. 🤔
In file included from /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:9:
In file included from /llvm-project/mlir/lib/Bindings/Python/Rewrite.h:12:
In file included from /llvm-project/mlir/include/mlir/Bindings/Python/NanobindUtils.h:14:
In file included from /llvm-project/mlir/include/mlir/Bindings/Python/Nanobind.h:27:
In file included from /python-env/lib/python3.11/site-packages/nanobind/include/nanobind/stl/optional.h:12:
/python-env/lib/python3.11/site-packages/nanobind/include/nanobind/stl/detail/nb_optional.h:34:15: error: no matching member function for call to 'emplace'
34 | value.emplace(caster.operator cast_t<T>());
| ~~~~~~^~~~~~~
/python-env/lib/python3.11/site-packages/nanobind/include/nanobind/nb_func.h:254:41: note: in instantiation of member function 'nanobind::detail::optional_caster<std::optional<mlir::python::mlir::PyGreedyRewriteDriverConfig>>::from_python' requested here
254 | if ((!in.template get<Is>().from_python(args[Is], args_flags[Is],
| ^
/python-env/lib/python3.11/site-packages/nanobind/include/nanobind/nb_func.h:352:13: note: in instantiation of function template specialization 'nanobind::detail::func_create<false, true, (lambda at /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:512:8), void, mlir::python::mlir::PyModule &, mlir::python::mlir::PyFrozenRewritePatternSet &, std::optional<mlir::python::mlir::PyGreedyRewriteDriverConfig> &&, 0UL, 1UL, 2UL, nanobind::scope, nanobind::name, nanobind::arg, nanobind::arg, nanobind::arg_v, char[78]>' requested here
352 | detail::func_create<false, true>(
| ^
/python-env/lib/python3.11/site-packages/nanobind/include/nanobind/nb_func.h:409:5: note: in instantiation of function template specialization 'nanobind::cpp_function_def<void, (lambda at /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:512:8), nanobind::scope, nanobind::name, nanobind::arg, nanobind::arg, nanobind::arg_v, char[78], 0>' requested here
409 | cpp_function_def((detail::forward_t<Func>) f, scope(*this),
| ^
/llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:510:5: note: in instantiation of function template specialization 'nanobind::module_::def<(lambda at /llvm-project/mlir/lib/Bindings/Python/Rewrite.cpp:512:8), nanobind::arg, nanobind::arg, nanobind::arg_v, char[78]>' requested here
510 | m.def(
| ^
/usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/optional:914:2: note: candidate template ignored: requirement 'is_constructible_v<mlir::python::mlir::PyGreedyRewriteDriverConfig, mlir::python::mlir::PyGreedyRewriteDriverConfig &>' was not satisfied [with _Args = <Type &>]
914 | emplace(_Args&&... __args)
| ^
/usr/lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/optional:926:2: note: candidate template ignored: could not match 'initializer_list<_Up>' against 'Type' (aka 'mlir::python::mlir::PyGreedyRewriteDriverConfig')
926 | emplace(initializer_list<_Up> __il, _Args&&... __args)
| ^
1 error generated.
ninja: build stopped: subcommand failed.
code:
m.def(
"apply_patterns_and_fold_greedily",
[](PyModule &module, PyFrozenRewritePatternSet &set,
std::optional<PyGreedyRewriteDriverConfig> config) {
if (!config) {
config.emplace(PyGreedyRewriteDriverConfig());
}
auto status = mlirApplyPatternsAndFoldGreedily(module.get(), set.get(),
config->get());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("pattern application failed to converge");
},
"module"_a, "set"_a, "config"_a = nb::none(),
"Applys the given patterns to the given module greedily while folding "
"results.")There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it not possible to add a copy ctor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
currently PyGreedyRewriteDriverConfig acts like a "unique_ptr", i.e. it holds a raw pointer from C API. It is possible but maybe we have to add a new C API for copy constructing the actual C++ object. If we just copy the pointer then double free will happen. 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so then we can change to std::shared_ptr with a custom deleter? if you want i can do it but you're gonna have to push this PR up as a branch on this repo (or i can resubmit/take over the PR).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm i tried std::shared_ptr but some runtime error appears:
TypeError: apply_patterns_and_fold_greedily(): incompatible function arguments. The following argument types are supported:
1. apply_patterns_and_fold_greedily(module: mlir._mlir_libs._mlir.ir.Module, set: mlir._mlir_libs._mlir.rewrite.FrozenRewritePatternSet, config: std::shared_ptr<mlir::python::mlir::PyGreedyRewriteDriverConfig> | None = None) -> None
2. apply_patterns_and_fold_greedily(op: mlir._mlir_libs._mlir.ir._OperationBase, set: mlir._mlir_libs._mlir.rewrite.FrozenRewritePatternSet, config: std::shared_ptr<mlir::python::mlir::PyGreedyRewriteDriverConfig> | None = None) -> None
Invoked with types: mlir._mlir_libs._mlir.ir.Module, mlir._mlir_libs._mlir.rewrite.FrozenRewritePatternSet
feel free to take over/resubmit this PR if you have time : )
| "pattern application failed to converge"); | ||
| }, | ||
| "op"_a, "set"_a, | ||
| "op"_a, "set"_a, "config"_a = nb::none(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here as above
|
Superseded by #174913. |
…rns_and_fold_greedily (#174913) We already have `GreedyRewriteDriverConfig` on the Python side, but it hasn’t yet been exposed as a parameter of `apply_patterns_and_fold_greedily`. This PR does that. Before: ```python def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None ``` After: ```python def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None ``` Note this PR is adapted from #174785 but using `std::optional` instead of `nb::object`. Note, this required refactoring `PyGreedyRewriteDriverConfig` to have a `std::shared_ptr` so that it could support a copy-ctor. Co-authored-by: PragmaTwice <[email protected]>
…apply_patterns_and_fold_greedily (#174913)
We already have `GreedyRewriteDriverConfig` on the Python side, but it
hasn’t yet been exposed as a parameter of
`apply_patterns_and_fold_greedily`. This PR does that.
Before:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None
```
After:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet,
config: GreedyRewriteDriverConfig | None = None) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet,
config: GreedyRewriteDriverConfig | None = None) -> None
```
Note this PR is adapted from
llvm/llvm-project#174785 but using
`std::optional` instead of `nb::object`. Note, this required refactoring
`PyGreedyRewriteDriverConfig` to have a `std::shared_ptr` so that it
could support a copy-ctor.
Co-authored-by: PragmaTwice <[email protected]>
…rns_and_fold_greedily (#174913)
We already have `GreedyRewriteDriverConfig` on the Python side, but it
hasn’t yet been exposed as a parameter of
`apply_patterns_and_fold_greedily`. This PR does that.
Before:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None
```
After:
```python
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet,
config: GreedyRewriteDriverConfig | None = None) -> None
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet,
config: GreedyRewriteDriverConfig | None = None) -> None
```
Note this PR is adapted from
llvm/llvm-project#174785 but using
`std::optional` instead of `nb::object`. Note, this required refactoring
`PyGreedyRewriteDriverConfig` to have a `std::shared_ptr` so that it
could support a copy-ctor.
Co-authored-by: PragmaTwice <[email protected]>
…rns_and_fold_greedily (llvm#174913) We already have `GreedyRewriteDriverConfig` on the Python side, but it hasn’t yet been exposed as a parameter of `apply_patterns_and_fold_greedily`. This PR does that. Before: ```python def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None ``` After: ```python def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None ``` Note this PR is adapted from llvm#174785 but using `std::optional` instead of `nb::object`. Note, this required refactoring `PyGreedyRewriteDriverConfig` to have a `std::shared_ptr` so that it could support a copy-ctor. Co-authored-by: PragmaTwice <[email protected]>
We already have
GreedyRewriteDriverConfigon the Python side, but it hasn’t yet been exposed as a parameter ofapply_patterns_and_fold_greedily. This PR does that.Before:
After: