Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 39 additions & 31 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,80 +250,83 @@ enum PyGreedySimplifyRegionLevel : std::underlying_type_t<
class PyGreedyRewriteDriverConfig {
public:
PyGreedyRewriteDriverConfig()
: config(mlirGreedyRewriteDriverConfigCreate()) {}
: config(mlirGreedyRewriteDriverConfigCreate().ptr,
PyGreedyRewriteDriverConfig::customDeleter) {}
PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
: config(other.config) {
other.config.ptr = nullptr;
}
~PyGreedyRewriteDriverConfig() {
if (config.ptr != nullptr)
mlirGreedyRewriteDriverConfigDestroy(config);
: config(std::move(other.config)) {}
PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept
: config(other.config) {}

MlirGreedyRewriteDriverConfig get() {
return MlirGreedyRewriteDriverConfig{config.get()};
}
MlirGreedyRewriteDriverConfig get() { return config; }

void setMaxIterations(int64_t maxIterations) {
mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations);
}

void setMaxNumRewrites(int64_t maxNumRewrites) {
mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites);
}

void setUseTopDownTraversal(bool useTopDownTraversal) {
mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config,
mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(),
useTopDownTraversal);
}

void enableFolding(bool enable) {
mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
mlirGreedyRewriteDriverConfigEnableFolding(get(), enable);
}

void setStrictness(PyGreedyRewriteStrictness strictness) {
mlirGreedyRewriteDriverConfigSetStrictness(
config, static_cast<MlirGreedyRewriteStrictness>(strictness));
get(), static_cast<MlirGreedyRewriteStrictness>(strictness));
}

void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
config, static_cast<MlirGreedySimplifyRegionLevel>(level));
get(), static_cast<MlirGreedySimplifyRegionLevel>(level));
}

void enableConstantCSE(bool enable) {
mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable);
}

int64_t getMaxIterations() {
return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
return mlirGreedyRewriteDriverConfigGetMaxIterations(get());
}

int64_t getMaxNumRewrites() {
return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get());
}

bool getUseTopDownTraversal() {
return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get());
}

bool isFoldingEnabled() {
return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get());
}

PyGreedyRewriteStrictness getStrictness() {
return static_cast<PyGreedyRewriteStrictness>(
mlirGreedyRewriteDriverConfigGetStrictness(config));
mlirGreedyRewriteDriverConfigGetStrictness(get()));
}

PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
return static_cast<PyGreedySimplifyRegionLevel>(
mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get()));
}

bool isConstantCSEEnabled() {
return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get());
}

private:
MlirGreedyRewriteDriverConfig config;
std::shared_ptr<void> config;
static void customDeleter(void *c) {
mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
}
};

/// Create the `mlir.rewrite` here.
Expand Down Expand Up @@ -509,26 +512,31 @@ void populateRewriteSubmodule(nb::module_ &m) {
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
"apply_patterns_and_fold_greedily",
[](PyModule &module, PyFrozenRewritePatternSet &set) {
auto status = mlirApplyPatternsAndFoldGreedily(
module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate());
[](PyModule &module, PyFrozenRewritePatternSet &set,
std::optional<PyGreedyRewriteDriverConfig> config) {
MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
module.get(), set.get(),
config.has_value() ? config->get()
: mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"module"_a, "set"_a, "config"_a = nb::none(),
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
[](PyOperationBase &op, PyFrozenRewritePatternSet &set,
std::optional<PyGreedyRewriteDriverConfig> config) {
MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
op.getOperation(), set.get(),
mlirGreedyRewriteDriverConfigCreate());
config.has_value() ? config->get()
: mlirGreedyRewriteDriverConfigCreate());
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
"op"_a, "set"_a,
"op"_a, "set"_a, "config"_a = nb::none(),
"Applys the given patterns to the given op greedily while folding "
"results.")
.def(
Expand Down
44 changes: 44 additions & 0 deletions mlir/test/python/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,47 @@ def testGreedySimplifyRegionLevelEnum():
config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
level = config.region_simplification_level
print(f"region_level AGGRESSIVE: {level}")


# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
@run
def testRewriteWithGreedyRewriteDriverConfig():
def constant_1_to_2(op, rewriter):
c = op.value.value
if c != 1:
return True # failed to match
with rewriter.ip:
new_op = arith.constant(op.type, 2, loc=op.location)
rewriter.replace_op(op, [new_op])

with Context():
patterns = RewritePatternSet()
patterns.add(arith.ConstantOp, constant_1_to_2)
frozen = patterns.freeze()

module = ModuleOp.parse(
r"""
module {
func.func @const() -> (i64, i64) {
%0 = arith.constant 1 : i64
%1 = arith.constant 1 : i64
return %0, %1 : i64, i64
}
}
"""
)

config = GreedyRewriteDriverConfig()
config.enable_constant_cse = False
apply_patterns_and_fold_greedily(module, frozen, config)
# CHECK: %c2_i64 = arith.constant 2 : i64
# CHECK: %c2_i64_0 = arith.constant 2 : i64
# CHECK: return %c2_i64, %c2_i64_0 : i64, i64
print(module)

config = GreedyRewriteDriverConfig()
config.enable_constant_cse = True
apply_patterns_and_fold_greedily(module, frozen, config)
# CHECK: %c2_i64 = arith.constant 2 : i64
# CHECK: return %c2_i64, %c2_i64 : i64
print(module)