From 4cbad40df75c7910f565a746bef7b05c0147933a Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Mon, 24 Feb 2025 14:01:38 +0100 Subject: [PATCH] cf: extract condition to a name (#801) * cf: extract condition to a name Closes #768. * handle if with return --- lib/ReactantCore/src/ReactantCore.jl | 25 +++++++++++++++++++++---- test/control_flow.jl | 3 +-- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 4b02f42bd..acf8862a8 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -225,7 +225,11 @@ function trace_if_with_returns(mod, expr) new_expr, _, all_check_vars = trace_if( mod, expr.args[2]; store_last_line=expr.args[1], depth=1 ) + cond_name = first(all_check_vars) + original_cond = expr.args[2].args[1] + expr.args[2].args[1] = cond_name return quote + $(cond_name) = $(original_cond) if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),)) $(new_expr) else @@ -353,29 +357,42 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) ) false_branch_fn = :($(false_branch_fn_name) = $(false_branch_fn)) + cond_name = gensym(:cond) + reactant_code_block = quote $(true_branch_fn) $(false_branch_fn) ($(all_output_vars...),) = $(traced_if)( - $(cond_expr), + $(cond_name), $(true_branch_fn_name), $(false_branch_fn_name), ($(all_input_vars...),), ) end - all_check_vars = [all_input_vars..., condition_vars...] + non_reactant_code_block = Expr(:if, cond_name, original_expr.args[2]) + if length(original_expr.args) > 2 # has else block + append!(non_reactant_code_block.args, original_expr.args[3:end]) + end + + all_check_vars = [cond_name, all_input_vars..., condition_vars...] unique!(all_check_vars) depth > 0 && return ( - reactant_code_block, (true_branch_fn_name, false_branch_fn_name), all_check_vars + quote + $(cond_name) = $(cond_expr) + $(reactant_code_block) + end, + (true_branch_fn_name, false_branch_fn_name), + all_check_vars, ) return quote + $(cond_name) = $(cond_expr) if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),)) $(reactant_code_block) else - $(original_expr) + $(non_reactant_code_block) end end end diff --git a/test/control_flow.jl b/test/control_flow.jl index 8c54c7124..63f51f126 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -749,8 +749,7 @@ mutable struct TestSimulation{C,I,B} end function step!(sim) - cond = sim.clock.iteration >= sim.stop_iteration - @trace if cond + @trace if sim.clock.iteration >= sim.stop_iteration sim.running = false else sim.clock.iteration += 1 # time step