Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions src/kirin/dialects/scf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,22 @@
@dialect.register
class Lowering(lowering.FromPythonAST):

@staticmethod
def _frame_or_any_parent_has_def(frame, name) -> ir.SSAValue | None:
# NOTE: this recursively checks all parents of the current frame for the
# def. Required for nested if statements that e.g. assign to variables
# defined in outer scope
if frame is None:
return None

Check warning on line 19 in src/kirin/dialects/scf/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/kirin/dialects/scf/lowering.py#L19

Added line #L19 was not covered by tests

if name in frame.defs:
value = frame.get(name)
if value is None:
raise lowering.BuildError(f"expected value for {name}")

Check warning on line 24 in src/kirin/dialects/scf/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/kirin/dialects/scf/lowering.py#L24

Added line #L24 was not covered by tests
return value

return Lowering._frame_or_any_parent_has_def(frame.parent, name)

def lower_If(self, state: lowering.State, node: ast.If) -> lowering.Result:
cond = state.lower(node.test).expect_one()
frame = state.current_frame
Expand All @@ -29,21 +45,15 @@
yield_names: list[str] = []
body_yields: list[ir.SSAValue] = []
else_yields: list[ir.SSAValue] = []
if node.orelse:
for name in body_frame.defs.keys():
if name in else_frame.defs:
yield_names.append(name)
body_yields.append(body_frame[name])
else_yields.append(else_frame[name])
else:
for name in body_frame.defs.keys():
if name in frame.defs:
yield_names.append(name)
body_yields.append(body_frame[name])
value = frame.get(name)
if value is None:
raise lowering.BuildError(f"expected value for {name}")
else_yields.append(value)
for name in body_frame.defs.keys():
if name in else_frame.defs:
yield_names.append(name)
body_yields.append(body_frame[name])
else_yields.append(else_frame[name])
elif (value := self._frame_or_any_parent_has_def(frame, name)) is not None:
yield_names.append(name)
body_yields.append(body_frame[name])
else_yields.append(value)

if not (
body_frame.curr_block.last_stmt
Expand Down
103 changes: 85 additions & 18 deletions test/dialects/scf/test_ifelse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from kirin import ir
from kirin.passes import Fold
from kirin.prelude import python_basic
from kirin.dialects import scf, func, lowering

Expand All @@ -22,24 +23,90 @@ def run_pass(method):
return run_pass


@kernel
def main(x):
if x > 0:
y = x + 1
z = y + 1
return z
else:
y = x + 2
z = y + 2
def test_basic_if_else():
@kernel
def main(x):
if x > 0:
y = x + 1
z = y + 1
return z
else:
y = x + 2
z = y + 2

if x < 0:
y = y + 3
z = y + 3
else:
y = x + 4
z = y + 4
return y, z
if x < 0:
y = y + 3
z = y + 3
else:
y = x + 4
z = y + 4
return y, z

main.print()
print(main(1))

main.print()
# print(main(1))

def test_if_else_defs():

@kernel
def main(n: int):
x = 0

if x == n:
x = 1
else:
y = 2 # noqa: F841

return x

main.print()

# make sure fold doesn't remove the nested def
main2 = main.similar(kernel)
Fold(main2.dialects)(main2)

assert main(0) == 1 == main2(0)
assert main(10) == 0 == main2(4)

main2.print()

@kernel
def main_elif(n: int):
x = 0

if x == n:
x = 3
elif x == n + 1:
x = 4

return x

main_elif.print()

main_elif2 = main_elif.similar(kernel)
Fold(main_elif2.dialects)(main_elif2)

main_elif2.print()

assert main_elif(0) == 3 == main_elif2(0)
assert main_elif(-1) == 4 == main_elif2(-1)
assert main_elif(5) == 0 == main_elif2(7)

@kernel
def main_nested_if(n: int):
x = 0

if n > 0:
if n > 1:
if n == 3:
x = 4

return x

main_nested_if.print()

main_nested_if2 = main_nested_if.similar(kernel)
Fold(main_nested_if2.dialects)(main_nested_if2)

assert main_nested_if(3) == 4 == main_nested_if2(3)
assert main_nested_if(10) == 0 == main_nested_if2(8)
Loading