-
Notifications
You must be signed in to change notification settings - Fork 15.7k
[MLIR][Python] Add GreedyRewriteDriverConfig parameter to apply_patterns_and_fold_greedily #174913
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Python] Add GreedyRewriteDriverConfig parameter to apply_patterns_and_fold_greedily #174913
Conversation
b137aed to
2f4367f
Compare
|
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) Changesadapted from #174785 but using Full diff: https://github.com/llvm/llvm-project/pull/174913.diff 2 Files Affected:
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 9830c277ac147..f908433268555 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -250,80 +250,82 @@ enum PyGreedySimplifyRegionLevel : std::underlying_type_t<
class PyGreedyRewriteDriverConfig {
public:
PyGreedyRewriteDriverConfig()
- : config(mlirGreedyRewriteDriverConfigCreate()) {}
+ : config(mlirGreedyRewriteDriverConfigCreate().ptr, customDeleter) {}
PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
- : config(other.config) {
- other.config.ptr = nullptr;
- }
- ~PyGreedyRewriteDriverConfig() {
- if (config.ptr != nullptr)
- mlirGreedyRewriteDriverConfigDestroy(config);
+ : config(std::move(other.config)) {}
+ PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept
+ : config(other.config) {}
+
+ MlirGreedyRewriteDriverConfig get() {
+ return MlirGreedyRewriteDriverConfig{config.get()};
}
- MlirGreedyRewriteDriverConfig get() { return config; }
void setMaxIterations(int64_t maxIterations) {
- mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
+ mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations);
}
void setMaxNumRewrites(int64_t maxNumRewrites) {
- mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
+ mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites);
}
void setUseTopDownTraversal(bool useTopDownTraversal) {
- mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config,
+ mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(),
useTopDownTraversal);
}
void enableFolding(bool enable) {
- mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
+ mlirGreedyRewriteDriverConfigEnableFolding(get(), enable);
}
void setStrictness(PyGreedyRewriteStrictness strictness) {
mlirGreedyRewriteDriverConfigSetStrictness(
- config, static_cast<MlirGreedyRewriteStrictness>(strictness));
+ get(), static_cast<MlirGreedyRewriteStrictness>(strictness));
}
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
- config, static_cast<MlirGreedySimplifyRegionLevel>(level));
+ get(), static_cast<MlirGreedySimplifyRegionLevel>(level));
}
void enableConstantCSE(bool enable) {
- mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
+ mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable);
}
int64_t getMaxIterations() {
- return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
+ return mlirGreedyRewriteDriverConfigGetMaxIterations(get());
}
int64_t getMaxNumRewrites() {
- return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
+ return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get());
}
bool getUseTopDownTraversal() {
- return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
+ return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get());
}
bool isFoldingEnabled() {
- return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
+ return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get());
}
PyGreedyRewriteStrictness getStrictness() {
return static_cast<PyGreedyRewriteStrictness>(
- mlirGreedyRewriteDriverConfigGetStrictness(config));
+ mlirGreedyRewriteDriverConfigGetStrictness(get()));
}
PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
return static_cast<PyGreedySimplifyRegionLevel>(
- mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
+ mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get()));
}
bool isConstantCSEEnabled() {
- return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
+ return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get());
}
private:
- MlirGreedyRewriteDriverConfig config;
+ std::shared_ptr<void> config;
+ static void customDeleter(void *c) {
+ mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
+ }
};
/// Create the `mlir.rewrite` here.
@@ -509,26 +511,31 @@ void populateRewriteSubmodule(nb::module_ &m) {
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
"apply_patterns_and_fold_greedily",
- [](PyModule &module, PyFrozenRewritePatternSet &set) {
- auto status = mlirApplyPatternsAndFoldGreedily(
- module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
+ [](PyModule &module, PyFrozenRewritePatternSet &set,
+ std::optional<PyGreedyRewriteDriverConfig> config) {
+ MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
+ module.get(), set.get(),
+ config.has_value() ? config->get()
+ : mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("pattern application failed to converge");
},
- "module"_a, "set"_a,
+ "module"_a, "set"_a, "config"_a = nb::none(),
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
- [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
- auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
+ [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
+ std::optional<PyGreedyRewriteDriverConfig> config) {
+ MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
op.getOperation(), set.get(),
- mlirGreedyRewriteDriverConfigCreate());
+ config.has_value() ? config->get()
+ : mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
- "op"_a, "set"_a,
+ "op"_a, "set"_a, "config"_a = nb::none(),
"Applys the given patterns to the given op greedily while folding "
"results.")
.def(
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index a2fbbde38b8c0..43e9b761a0ea2 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum():
config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
level = config.region_simplification_level
print(f"region_level AGGRESSIVE: {level}")
+
+
+# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
+@run
+def testRewriteWithGreedyRewriteDriverConfig():
+ def constant_1_to_2(op, rewriter):
+ c = op.value.value
+ if c != 1:
+ return True # failed to match
+ with rewriter.ip:
+ new_op = arith.constant(op.type, 2, loc=op.location)
+ rewriter.replace_op(op, [new_op])
+
+ with Context():
+ patterns = RewritePatternSet()
+ patterns.add(arith.ConstantOp, constant_1_to_2)
+ frozen = patterns.freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @const() -> (i64, i64) {
+ %0 = arith.constant 1 : i64
+ %1 = arith.constant 1 : i64
+ return %0, %1 : i64, i64
+ }
+ }
+ """
+ )
+
+ config = GreedyRewriteDriverConfig()
+ config.enable_constant_cse = False
+ apply_patterns_and_fold_greedily(module, frozen, config)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: %c2_i64_0 = arith.constant 2 : i64
+ # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
+ print(module)
+
+ config = GreedyRewriteDriverConfig()
+ config.enable_constant_cse = True
+ apply_patterns_and_fold_greedily(module, frozen, config)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: return %c2_i64, %c2_i64 : i64
+ print(module)
|
…rns_and_fold_greedily Co-authored-by: PragmaTwice <[email protected]>
2f4367f to
e41c852
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Thank you.
…rns_and_fold_greedily (llvm#174913) We already have `GreedyRewriteDriverConfig` on the Python side, but it hasn’t yet been exposed as a parameter of `apply_patterns_and_fold_greedily`. This PR does that. Before: ```python def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet) -> None def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None ``` After: ```python def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet, config: GreedyRewriteDriverConfig | None = None) -> None ``` Note this PR is adapted from llvm#174785 but using `std::optional` instead of `nb::object`. Note, this required refactoring `PyGreedyRewriteDriverConfig` to have a `std::shared_ptr` so that it could support a copy-ctor. Co-authored-by: PragmaTwice <[email protected]>
We already have
GreedyRewriteDriverConfigon the Python side, but it hasn’t yet been exposed as a parameter ofapply_patterns_and_fold_greedily. This PR does that.Before:
After:
Note this PR is adapted from #174785 but using
std::optionalinstead ofnb::object. Note, this required refactoringPyGreedyRewriteDriverConfigto have astd::shared_ptrso that it could support a copy-ctor.