diff --git a/src/kirin/dialects/scf/lowering.py b/src/kirin/dialects/scf/lowering.py index 17519fbff..0d24b768b 100644 --- a/src/kirin/dialects/scf/lowering.py +++ b/src/kirin/dialects/scf/lowering.py @@ -10,6 +10,22 @@ @dialect.register class Lowering(lowering.FromPythonAST): + @staticmethod + def _frame_or_any_parent_has_def(frame, name) -> ir.SSAValue | None: + # NOTE: this recursively checks all parents of the current frame for the + # def. Required for nested if statements that e.g. assign to variables + # defined in outer scope + if frame is None: + return None + + if name in frame.defs: + value = frame.get(name) + if value is None: + raise lowering.BuildError(f"expected value for {name}") + return value + + return Lowering._frame_or_any_parent_has_def(frame.parent, name) + def lower_If(self, state: lowering.State, node: ast.If) -> lowering.Result: cond = state.lower(node.test).expect_one() frame = state.current_frame @@ -29,21 +45,15 @@ def lower_If(self, state: lowering.State, node: ast.If) -> lowering.Result: yield_names: list[str] = [] body_yields: list[ir.SSAValue] = [] else_yields: list[ir.SSAValue] = [] - if node.orelse: - for name in body_frame.defs.keys(): - if name in else_frame.defs: - yield_names.append(name) - body_yields.append(body_frame[name]) - else_yields.append(else_frame[name]) - else: - for name in body_frame.defs.keys(): - if name in frame.defs: - yield_names.append(name) - body_yields.append(body_frame[name]) - value = frame.get(name) - if value is None: - raise lowering.BuildError(f"expected value for {name}") - else_yields.append(value) + for name in body_frame.defs.keys(): + if name in else_frame.defs: + yield_names.append(name) + body_yields.append(body_frame[name]) + else_yields.append(else_frame[name]) + elif (value := self._frame_or_any_parent_has_def(frame, name)) is not None: + yield_names.append(name) + body_yields.append(body_frame[name]) + else_yields.append(value) if not ( body_frame.curr_block.last_stmt diff --git a/test/dialects/scf/test_ifelse.py b/test/dialects/scf/test_ifelse.py index 01c225696..e2a67d0fc 100644 --- a/test/dialects/scf/test_ifelse.py +++ b/test/dialects/scf/test_ifelse.py @@ -1,4 +1,5 @@ from kirin import ir +from kirin.passes import Fold from kirin.prelude import python_basic from kirin.dialects import scf, func, lowering @@ -22,24 +23,90 @@ def run_pass(method): return run_pass -@kernel -def main(x): - if x > 0: - y = x + 1 - z = y + 1 - return z - else: - y = x + 2 - z = y + 2 +def test_basic_if_else(): + @kernel + def main(x): + if x > 0: + y = x + 1 + z = y + 1 + return z + else: + y = x + 2 + z = y + 2 - if x < 0: - y = y + 3 - z = y + 3 - else: - y = x + 4 - z = y + 4 - return y, z + if x < 0: + y = y + 3 + z = y + 3 + else: + y = x + 4 + z = y + 4 + return y, z + main.print() + print(main(1)) -main.print() -# print(main(1)) + +def test_if_else_defs(): + + @kernel + def main(n: int): + x = 0 + + if x == n: + x = 1 + else: + y = 2 # noqa: F841 + + return x + + main.print() + + # make sure fold doesn't remove the nested def + main2 = main.similar(kernel) + Fold(main2.dialects)(main2) + + assert main(0) == 1 == main2(0) + assert main(10) == 0 == main2(4) + + main2.print() + + @kernel + def main_elif(n: int): + x = 0 + + if x == n: + x = 3 + elif x == n + 1: + x = 4 + + return x + + main_elif.print() + + main_elif2 = main_elif.similar(kernel) + Fold(main_elif2.dialects)(main_elif2) + + main_elif2.print() + + assert main_elif(0) == 3 == main_elif2(0) + assert main_elif(-1) == 4 == main_elif2(-1) + assert main_elif(5) == 0 == main_elif2(7) + + @kernel + def main_nested_if(n: int): + x = 0 + + if n > 0: + if n > 1: + if n == 3: + x = 4 + + return x + + main_nested_if.print() + + main_nested_if2 = main_nested_if.similar(kernel) + Fold(main_nested_if2.dialects)(main_nested_if2) + + assert main_nested_if(3) == 4 == main_nested_if2(3) + assert main_nested_if(10) == 0 == main_nested_if2(8)