diff --git a/src/kirin/dialects/scf/scf2cf.py b/src/kirin/dialects/scf/scf2cf.py new file mode 100644 index 000000000..a1e6ac5c5 --- /dev/null +++ b/src/kirin/dialects/scf/scf2cf.py @@ -0,0 +1,190 @@ +from ... import ir +from .stmts import For, Yield, IfElse +from ...rewrite.abc import RewriteRule, RewriteResult + + +class ScfToCfRule(RewriteRule): + + def rewrite_ifelse( + self, node: ir.Region, block_idx: int, curr_block: ir.Block, stmt: IfElse + ): + from kirin.dialects import cf + + # create a new block for entering the if statement + entry_block = ir.Block() + for arg in curr_block.args: + arg.replace_by(entry_block.args.append_from(arg.type, arg.name)) + + # delete the args of the old block and replace with the result of the # if statement + for arg in curr_block.args: + curr_block.args.delete(arg) + + for arg in stmt.results: + arg.replace_by(curr_block.args.append_from(arg.type, arg.name)) + + (then_block := stmt.then_body.blocks[0]).detach() + (else_block := stmt.else_body.blocks[0]).detach() + + entry_block.stmts.append( + cf.ConditionalBranch( + cond=stmt.cond, + then_arguments=tuple(stmt.args), + then_successor=then_block, + else_arguments=tuple(stmt.args), + else_successor=else_block, + ) + ) + + # insert the then/else blocks and add branch to the current block + # if the last statement of the then block is a yield + if isinstance(last_stmt := else_block.last_stmt, Yield): + last_stmt.replace_by( + cf.Branch( + arguments=tuple(last_stmt.args), + successor=curr_block, + ) + ) + + if isinstance(last_stmt := then_block.last_stmt, Yield): + last_stmt.replace_by( + cf.Branch( + arguments=tuple(last_stmt.args), + successor=curr_block, + ) + ) + + node.blocks.insert(block_idx, curr_block) + node.blocks.insert(block_idx, else_block) + node.blocks.insert(block_idx, then_block) + + curr_stmt = stmt + next_stmt = stmt.prev_stmt + curr_stmt.delete() + + return next_stmt, entry_block + + def rewrite_for( + self, node: ir.Region, block_idx: int, curr_block: ir.Block, stmt: For + ): + from kirin.dialects import cf, py, func + + (body_block := stmt.body.blocks[0]).detach() + + entry_block = ir.Block() + for arg in curr_block.args: + arg.replace_by(entry_block.args.append_from(arg.type, arg.name)) + + # Get iterator from iterable object + entry_block.stmts.append(iterable_stmt := py.iterable.Iter(stmt.iterable)) + entry_block.stmts.append(const_none := func.ConstantNone()) + last_stmt = entry_block.last_stmt + entry_block.stmts.append( + next_stmt := py.iterable.Next(iterable_stmt.expect_one_result()) + ) + entry_block.stmts.append( + loop_cmp := py.cmp.Is(next_stmt.expect_one_result(), const_none.result) + ) + entry_block.stmts.append( + cf.ConditionalBranch( + cond=loop_cmp.result, + then_arguments=tuple(stmt.initializers), + then_successor=curr_block, + else_arguments=(next_stmt.expect_one_result(),) + + tuple(stmt.initializers), + else_successor=body_block, + ) + ) + + for arg in curr_block.args: + curr_block.args.delete(arg) + + for arg in stmt.results: + arg.replace_by(curr_block.args.append_from(arg.type, arg.name)) + + if isinstance(last_stmt := body_block.last_stmt, Yield): + ( + next_stmt := py.iterable.Next(iterable_stmt.expect_one_result()) + ).insert_before(last_stmt) + ( + loop_cmp := py.cmp.Is(next_stmt.expect_one_result(), const_none.result) + ).insert_before(last_stmt) + last_stmt.replace_by( + cf.ConditionalBranch( + cond=loop_cmp.result, + else_arguments=(next_stmt.expect_one_result(),) + + tuple(last_stmt.args), + else_successor=body_block, + then_arguments=tuple(last_stmt.args), + then_successor=curr_block, + ) + ) + + # insert the body block and add branch to the current block + node.blocks.insert(block_idx, curr_block) + node.blocks.insert(block_idx, body_block) + + curr_stmt = stmt + next_stmt = stmt.prev_stmt + curr_stmt.delete() + + return next_stmt, entry_block + + def rewrite_ssacfg(self, node: ir.Region): + + has_done_something = False + + for block_idx in range(len(node.blocks)): + + block = node.blocks.pop(block_idx) + + stmt = block.last_stmt + if stmt is None: + continue + + curr_block = ir.Block() + + for arg in block.args: + arg.replace_by(curr_block.args.append_from(arg.type, arg.name)) + + while stmt is not None: + if isinstance(stmt, For): + has_done_something = True + stmt, curr_block = self.rewrite_for( + node, block_idx, curr_block, stmt + ) + + elif isinstance(stmt, IfElse): + has_done_something = True + stmt, curr_block = self.rewrite_ifelse( + node, block_idx, curr_block, stmt + ) + else: + curr_stmt = stmt + stmt = stmt.prev_stmt + curr_stmt.detach() + + if curr_block.first_stmt is None: + curr_block.stmts.append(curr_stmt) + else: + curr_stmt.insert_before(curr_block.first_stmt) + + # if the last block is empty, remove it + if curr_block.parent is None and curr_block.first_stmt is not None: + node.blocks.insert(block_idx, curr_block) + + return RewriteResult(has_done_something=has_done_something) + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if ( + isinstance(node, (For, IfElse)) + or not node.has_trait(ir.HasCFG) + and not node.has_trait(ir.SSACFG) + ): + # do not do rewrite in scf regions + return RewriteResult() + + result = RewriteResult() + for region in node.regions: + result = result.join(self.rewrite_ssacfg(region)) + + return result diff --git a/test/dialects/scf/test_scf2cf.py b/test/dialects/scf/test_scf2cf.py new file mode 100644 index 000000000..3040ec242 --- /dev/null +++ b/test/dialects/scf/test_scf2cf.py @@ -0,0 +1,199 @@ +from kirin import ir, types +from kirin.prelude import basic, structural +from kirin.rewrite import Walk +from kirin.dialects import cf, py, func, ilist +from kirin.dialects.scf import scf2cf + + +def test_scf2cf_if_1(): + + @structural(typeinfer=True) + def test(b: bool): + if b: + b = False + else: + b = not b + + return b + + rule = Walk(scf2cf.ScfToCfRule()) + rule.rewrite(test.code) + test = test.similar(basic) + + excpected_callable_region = ir.Region( + [ + branch_block := ir.Block(), + then_block := ir.Block(), + else_block := ir.Block(), + join_block := ir.Block(), + ] + ) + + branch_block.args.append_from(types.MethodType, "self") + b = branch_block.args.append_from(types.Bool, "b") + branch_block.stmts.append( + cf.ConditionalBranch( + cond=b, + then_arguments=(b,), + then_successor=then_block, + else_arguments=(b,), + else_successor=else_block, + ) + ) + + then_block.args.append_from(types.Bool, "b") + then_block.stmts.append(stmt := py.Constant(value=False)) + then_block.stmts.append( + cf.Branch( + arguments=(stmt.result,), + successor=join_block, + ) + ) + + b = else_block.args.append_from(types.Bool) + else_block.stmts.append(stmt := py.unary.Not(b)) + else_block.stmts.append( + cf.Branch( + arguments=(stmt.result,), + successor=join_block, + ) + ) + ret = join_block.args.append_from(types.Bool) + join_block.stmts.append(func.Return(ret)) + + expected_code = func.Function( + sym_name="test", + slots=("b",), + signature=func.Signature( + output=types.Bool, + inputs=(types.Bool,), + ), + body=excpected_callable_region, + ) + + expected_test = ir.Method( + dialects=basic, + code=expected_code, + ) + + if basic.run_pass is not None: + basic.run_pass(expected_test, typeinfer=True) + basic.run_pass(test, typeinfer=True) + + assert expected_test.callable_region.is_structurally_equal(test.callable_region) + + +def test_scf2cf_for_1(): + + @structural(typeinfer=True, fold=False) + def test(): + j = 0 + for i in range(10): + j = j + 1 + + return j + + rule = Walk(scf2cf.ScfToCfRule()) + rule.rewrite(test.code) + test = test.similar(basic) + + expected_callable_region = ir.Region( + [ + entry_block := ir.Block(), + body_block := ir.Block(), + exit_block := ir.Block(), + ] + ) + + entry_block.args.append_from(types.MethodType, "self") + entry_block.stmts.append(j_start := py.Constant(value=0)) + j_start.result.name = "j" + entry_block.stmts.append(iter_start := py.Constant(value=0)) + entry_block.stmts.append(iter_end := py.Constant(value=10)) + entry_block.stmts.append(iter_step := py.Constant(value=1)) + entry_block.stmts.append( + range_stmt := ilist.stmts.Range( + start=iter_start.result, + stop=iter_end.result, + step=iter_step.result, + ) + ) + range_stmt.result.type = ilist.IListType[types.Int, types.Literal(10)] + entry_block.stmts.append(iterable_stmt := py.iterable.Iter(range_stmt.result)) + entry_block.stmts.append(none_stmt := func.ConstantNone()) + entry_block.stmts.append( + first_iter := py.iterable.Next(iterable_stmt.expect_one_result()) + ) + entry_block.stmts.append( + loop_cmp := py.cmp.Is(first_iter.expect_one_result(), none_stmt.result) + ) + entry_block.stmts.append( + cf.ConditionalBranch( + cond=loop_cmp.result, + then_arguments=(j_start.result, j_start.result), + then_successor=exit_block, + else_arguments=( + first_iter.expect_one_result(), + j_start.result, + j_start.result, + ), + else_successor=body_block, + ) + ) + + body_block.args.append_from(types.Int, "i") + body_block.args.append_from(types.Int, "j") + body_block.args.append_from(types.Int, "j") + + body_block.stmts.append(one_stmt := py.Constant(value=1)) + body_block.stmts.append( + j_add := py.binop.Add( + lhs=body_block.args[1], + rhs=one_stmt.result, + ) + ) + j_add.result.name = "j" + j_add.result.type = types.Int + body_block.stmts.append( + next_iter := py.iterable.Next(iterable_stmt.expect_one_result()) + ) + body_block.stmts.append( + loop_cmp := py.cmp.Is(next_iter.expect_one_result(), none_stmt.result) + ) + body_block.stmts.append( + cf.ConditionalBranch( + cond=loop_cmp.result, + then_arguments=(j_add.result, j_add.result), + then_successor=exit_block, + else_arguments=(next_iter.expect_one_result(), j_add.result, j_add.result), + else_successor=body_block, + ) + ) + + exit_block.args.append_from(types.Int, "j") + exit_block.args.append_from(types.Int, "j") + exit_block.stmts.append(func.Return(exit_block.args[1])) + + expected_code = func.Function( + sym_name="test", + slots=(), + signature=func.Signature( + output=types.Literal(10), + inputs=(), + ), + body=expected_callable_region, + ) + + expected_test = ir.Method( + dialects=basic, + code=expected_code, + ) + + test.print() + expected_test.print() + + if basic.run_pass is not None: + basic.run_pass(test, typeinfer=True, fold=False) + basic.run_pass(expected_test, typeinfer=True, fold=False) + + assert expected_test.callable_region.is_structurally_equal(test.callable_region)