Skip to content

Commit aed1c4e

Browse files
weinbe58Roger-luo
authored andcommitted
Fix issue in Flatten Add. (#522)
Currently, this optimization will accidentally merge lists across basic blocks or regions, which should not happen. I have added code to fix this issue and included some tests as well.
1 parent 170968c commit aed1c4e

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

src/kirin/dialects/ilist/rewrite/flatten_add.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,18 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
1818
new_data = ()
1919

2020
# lhs:
21-
if isinstance((lhs := node.lhs).owner, ilist.New):
21+
lhs = node.lhs
22+
rhs = node.rhs
23+
24+
if (
25+
(lhs_parent := lhs.owner.parent) is None
26+
or (rhs_parent := rhs.owner.parent) is None
27+
or lhs_parent is not rhs_parent
28+
):
29+
# do not flatten across different blocks/regions
30+
return RewriteResult()
31+
32+
if isinstance(lhs.owner, ilist.New):
2233
new_data += lhs.owner.values
2334
elif (
2435
not isinstance(const_lhs := lhs.hints.get("const"), const.Value)
@@ -27,7 +38,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2738
return RewriteResult()
2839

2940
# rhs:
30-
if isinstance((rhs := node.rhs).owner, ilist.New):
41+
if isinstance(rhs.owner, ilist.New):
3142
new_data += rhs.owner.values
3243
elif (
3344
not isinstance(const_rhs := rhs.hints.get("const"), const.Value)

test/dialects/test_ilist.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from kirin import ir, types, rewrite
44
from kirin.passes import aggressive
5-
from kirin.prelude import basic_no_opt, python_basic
5+
from kirin.prelude import structural, basic_no_opt, python_basic
66
from kirin.analysis import const
77
from kirin.dialects import py, func, ilist, lowering
88
from kirin.passes.typeinfer import TypeInfer
@@ -328,6 +328,46 @@ def test_ilist_flatten_add_both_new():
328328
assert test_block.is_equal(expected_block)
329329

330330

331+
def test_region_boundary_structural():
332+
333+
# Do not optimize across region boundary like if-else or basic blocks
334+
@structural
335+
def test_impl(n: int):
336+
a = ilist.IList([])
337+
338+
if n > 0:
339+
a = a + [n]
340+
341+
return a
342+
343+
expected_impl = test_impl.similar()
344+
test_impl.print(hint="const")
345+
rule = rewrite.Walk(ilist.rewrite.FlattenAdd())
346+
rule.rewrite(test_impl.code)
347+
348+
assert test_impl.code.is_structurally_equal(expected_impl.code)
349+
350+
351+
def test_region_boundary():
352+
353+
# Do not optimize across region boundary like if-else or basic blocks
354+
@basic_no_opt
355+
def test_impl(n: int):
356+
a = ilist.IList([])
357+
358+
if n > 0:
359+
a = a + [n]
360+
361+
return a
362+
363+
expected_impl = test_impl.similar()
364+
test_impl.print(hint="const")
365+
rule = rewrite.Walk(ilist.rewrite.FlattenAdd())
366+
rule.rewrite(test_impl.code)
367+
368+
assert test_impl.code.is_structurally_equal(expected_impl.code)
369+
370+
331371
def test_ilist_constprop():
332372
from kirin.analysis import const
333373

0 commit comments

Comments
 (0)