diff --git a/src/kirin/dialects/ilist/rewrite/flatten_add.py b/src/kirin/dialects/ilist/rewrite/flatten_add.py index 0792e5136..f39b13390 100644 --- a/src/kirin/dialects/ilist/rewrite/flatten_add.py +++ b/src/kirin/dialects/ilist/rewrite/flatten_add.py @@ -18,7 +18,18 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: new_data = () # lhs: - if isinstance((lhs := node.lhs).owner, ilist.New): + lhs = node.lhs + rhs = node.rhs + + if ( + (lhs_parent := lhs.owner.parent) is None + or (rhs_parent := rhs.owner.parent) is None + or lhs_parent is not rhs_parent + ): + # do not flatten across different blocks/regions + return RewriteResult() + + if isinstance(lhs.owner, ilist.New): new_data += lhs.owner.values elif ( not isinstance(const_lhs := lhs.hints.get("const"), const.Value) @@ -27,7 +38,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return RewriteResult() # rhs: - if isinstance((rhs := node.rhs).owner, ilist.New): + if isinstance(rhs.owner, ilist.New): new_data += rhs.owner.values elif ( not isinstance(const_rhs := rhs.hints.get("const"), const.Value) diff --git a/test/dialects/test_ilist.py b/test/dialects/test_ilist.py index 7a860e4e3..68b34ba4e 100644 --- a/test/dialects/test_ilist.py +++ b/test/dialects/test_ilist.py @@ -2,7 +2,7 @@ from kirin import ir, types, rewrite from kirin.passes import aggressive -from kirin.prelude import basic_no_opt, python_basic +from kirin.prelude import structural, basic_no_opt, python_basic from kirin.analysis import const from kirin.dialects import py, func, ilist, lowering from kirin.passes.typeinfer import TypeInfer @@ -328,6 +328,46 @@ def test_ilist_flatten_add_both_new(): assert test_block.is_structurally_equal(expected_block) +def test_region_boundary_structural(): + + # Do not optimize across region boundary like if-else or basic blocks + @structural + def test_impl(n: int): + a = ilist.IList([]) + + if n > 0: + a = a + [n] + + return a + + expected_impl = test_impl.similar() + test_impl.print(hint="const") + rule = rewrite.Walk(ilist.rewrite.FlattenAdd()) + rule.rewrite(test_impl.code) + + assert test_impl.code.is_structurally_equal(expected_impl.code) + + +def test_region_boundary(): + + # Do not optimize across region boundary like if-else or basic blocks + @basic_no_opt + def test_impl(n: int): + a = ilist.IList([]) + + if n > 0: + a = a + [n] + + return a + + expected_impl = test_impl.similar() + test_impl.print(hint="const") + rule = rewrite.Walk(ilist.rewrite.FlattenAdd()) + rule.rewrite(test_impl.code) + + assert test_impl.code.is_structurally_equal(expected_impl.code) + + def test_ilist_constprop(): from kirin.analysis import const