Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,14 @@ enum class PyGreedySimplifyRegionLevel : std::underlying_type_t<
};

/// Owning Wrapper around a GreedyRewriteDriverConfig.
class PyGreedyRewriteDriverConfig {
class PyGreedyRewriteConfig {
public:
PyGreedyRewriteDriverConfig()
PyGreedyRewriteConfig()
: config(mlirGreedyRewriteDriverConfigCreate().ptr,
PyGreedyRewriteDriverConfig::customDeleter) {}
PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
PyGreedyRewriteConfig::customDeleter) {}
PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept
: config(std::move(other.config)) {}
PyGreedyRewriteDriverConfig(const PyGreedyRewriteDriverConfig &other) noexcept
PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
: config(other.config) {}

MlirGreedyRewriteDriverConfig get() {
Expand Down Expand Up @@ -470,34 +470,32 @@ void populateRewriteSubmodule(nb::module_ &m) {
nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig")
.def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
.def_prop_rw("max_iterations",
&PyGreedyRewriteDriverConfig::getMaxIterations,
&PyGreedyRewriteDriverConfig::setMaxIterations,
.def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations,
&PyGreedyRewriteConfig::setMaxIterations,
"Maximum number of iterations")
.def_prop_rw("max_num_rewrites",
&PyGreedyRewriteDriverConfig::getMaxNumRewrites,
&PyGreedyRewriteDriverConfig::setMaxNumRewrites,
&PyGreedyRewriteConfig::getMaxNumRewrites,
&PyGreedyRewriteConfig::setMaxNumRewrites,
"Maximum number of rewrites per iteration")
.def_prop_rw("use_top_down_traversal",
&PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
&PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
&PyGreedyRewriteConfig::getUseTopDownTraversal,
&PyGreedyRewriteConfig::setUseTopDownTraversal,
"Whether to use top-down traversal")
.def_prop_rw("enable_folding",
&PyGreedyRewriteDriverConfig::isFoldingEnabled,
&PyGreedyRewriteDriverConfig::enableFolding,
.def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled,
&PyGreedyRewriteConfig::enableFolding,
"Enable or disable folding")
.def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness,
&PyGreedyRewriteDriverConfig::setStrictness,
.def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness,
&PyGreedyRewriteConfig::setStrictness,
"Rewrite strictness level")
.def_prop_rw("region_simplification_level",
&PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
&PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
&PyGreedyRewriteConfig::getRegionSimplificationLevel,
&PyGreedyRewriteConfig::setRegionSimplificationLevel,
"Region simplification level")
.def_prop_rw("enable_constant_cse",
&PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
&PyGreedyRewriteDriverConfig::enableConstantCSE,
&PyGreedyRewriteConfig::isConstantCSEEnabled,
&PyGreedyRewriteConfig::enableConstantCSE,
"Enable or disable constant CSE");

nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
Expand All @@ -508,7 +506,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
m.def(
"apply_patterns_and_fold_greedily",
[](PyModule &module, PyFrozenRewritePatternSet &set,
std::optional<PyGreedyRewriteDriverConfig> config) {
std::optional<PyGreedyRewriteConfig> config) {
MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
module.get(), set.get(),
config.has_value() ? config->get()
Expand All @@ -522,7 +520,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set,
std::optional<PyGreedyRewriteDriverConfig> config) {
std::optional<PyGreedyRewriteConfig> config) {
MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
op.getOperation(), set.get(),
config.has_value() ? config->get()
Expand Down
24 changes: 12 additions & 12 deletions mlir/test/python/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,19 @@ def constant_1_to_2(op, rewriter):
print(module)


# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigCreation
# CHECK-LABEL: TEST: testGreedyRewriteConfigCreation
@run
def testGreedyRewriteDriverConfigCreation():
def testGreedyRewriteConfigCreation():
# Test basic config creation and destruction
config = GreedyRewriteDriverConfig()
config = GreedyRewriteConfig()
# CHECK: Config created successfully
print("Config created successfully")


# CHECK-LABEL: TEST: testGreedyRewriteDriverConfigGetters
# CHECK-LABEL: TEST: testGreedyRewriteConfigGetters
@run
def testGreedyRewriteDriverConfigGetters():
config = GreedyRewriteDriverConfig()
def testGreedyRewriteConfigGetters():
config = GreedyRewriteConfig()

# Set some values
config.max_iterations = 5
Expand Down Expand Up @@ -139,7 +139,7 @@ def testGreedyRewriteDriverConfigGetters():
# CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum
@run
def testGreedyRewriteStrictnessEnum():
config = GreedyRewriteDriverConfig()
config = GreedyRewriteConfig()

# Test ANY_OP
# CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP
Expand All @@ -163,7 +163,7 @@ def testGreedyRewriteStrictnessEnum():
# CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum
@run
def testGreedySimplifyRegionLevelEnum():
config = GreedyRewriteDriverConfig()
config = GreedyRewriteConfig()

# Test DISABLED
# CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED
Expand All @@ -184,9 +184,9 @@ def testGreedySimplifyRegionLevelEnum():
print(f"region_level AGGRESSIVE: {level}")


# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteDriverConfig
# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteConfig
@run
def testRewriteWithGreedyRewriteDriverConfig():
def testRewriteWithGreedyRewriteConfig():
def constant_1_to_2(op, rewriter):
c = op.value.value
if c != 1:
Expand All @@ -212,15 +212,15 @@ def constant_1_to_2(op, rewriter):
"""
)

config = GreedyRewriteDriverConfig()
config = GreedyRewriteConfig()
config.enable_constant_cse = False
apply_patterns_and_fold_greedily(module, frozen, config)
# CHECK: %c2_i64 = arith.constant 2 : i64
# CHECK: %c2_i64_0 = arith.constant 2 : i64
# CHECK: return %c2_i64, %c2_i64_0 : i64, i64
print(module)

config = GreedyRewriteDriverConfig()
config = GreedyRewriteConfig()
config.enable_constant_cse = True
apply_patterns_and_fold_greedily(module, frozen, config)
# CHECK: %c2_i64 = arith.constant 2 : i64
Expand Down