Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/kirin/dialects/ilist/rewrite/flatten_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
new_data = ()

# lhs:
if isinstance((lhs := node.lhs).owner, ilist.New):
lhs = node.lhs
rhs = node.rhs

if (
(lhs_parent := lhs.owner.parent) is None
or (rhs_parent := rhs.owner.parent) is None
or lhs_parent is not rhs_parent
):
# do not flatten across different blocks/regions
return RewriteResult()

if isinstance(lhs.owner, ilist.New):
new_data += lhs.owner.values
elif (
not isinstance(const_lhs := lhs.hints.get("const"), const.Value)
Expand All @@ -27,7 +38,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
return RewriteResult()

# rhs:
if isinstance((rhs := node.rhs).owner, ilist.New):
if isinstance(rhs.owner, ilist.New):
new_data += rhs.owner.values
elif (
not isinstance(const_rhs := rhs.hints.get("const"), const.Value)
Expand Down
42 changes: 41 additions & 1 deletion test/dialects/test_ilist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from kirin import ir, types, rewrite
from kirin.passes import aggressive
from kirin.prelude import basic_no_opt, python_basic
from kirin.prelude import structural, basic_no_opt, python_basic
from kirin.analysis import const
from kirin.dialects import py, func, ilist, lowering
from kirin.passes.typeinfer import TypeInfer
Expand Down Expand Up @@ -328,6 +328,46 @@ def test_ilist_flatten_add_both_new():
assert test_block.is_structurally_equal(expected_block)


def test_region_boundary_structural():

# Do not optimize across region boundary like if-else or basic blocks
@structural
def test_impl(n: int):
a = ilist.IList([])

if n > 0:
a = a + [n]

return a

expected_impl = test_impl.similar()
test_impl.print(hint="const")
rule = rewrite.Walk(ilist.rewrite.FlattenAdd())
rule.rewrite(test_impl.code)

assert test_impl.code.is_structurally_equal(expected_impl.code)


def test_region_boundary():

# Do not optimize across region boundary like if-else or basic blocks
@basic_no_opt
def test_impl(n: int):
a = ilist.IList([])

if n > 0:
a = a + [n]

return a

expected_impl = test_impl.similar()
test_impl.print(hint="const")
rule = rewrite.Walk(ilist.rewrite.FlattenAdd())
rule.rewrite(test_impl.code)

assert test_impl.code.is_structurally_equal(expected_impl.code)


def test_ilist_constprop():
from kirin.analysis import const

Expand Down