Skip to content

Commit

Permalink
cf: extract condition to a name (#801)
Browse files Browse the repository at this point in the history
* cf: extract condition to a name

Closes #768.

* handle if with return
  • Loading branch information
Pangoraw authored Feb 24, 2025
1 parent c9f4586 commit 4cbad40
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
25 changes: 21 additions & 4 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,11 @@ function trace_if_with_returns(mod, expr)
new_expr, _, all_check_vars = trace_if(
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
)
cond_name = first(all_check_vars)
original_cond = expr.args[2].args[1]
expr.args[2].args[1] = cond_name
return quote
$(cond_name) = $(original_cond)
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
$(new_expr)
else
Expand Down Expand Up @@ -353,29 +357,42 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
)
false_branch_fn = :($(false_branch_fn_name) = $(false_branch_fn))

cond_name = gensym(:cond)

reactant_code_block = quote
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(cond_expr),
$(cond_name),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),),
)
end

all_check_vars = [all_input_vars..., condition_vars...]
non_reactant_code_block = Expr(:if, cond_name, original_expr.args[2])
if length(original_expr.args) > 2 # has else block
append!(non_reactant_code_block.args, original_expr.args[3:end])
end

all_check_vars = [cond_name, all_input_vars..., condition_vars...]
unique!(all_check_vars)

depth > 0 && return (
reactant_code_block, (true_branch_fn_name, false_branch_fn_name), all_check_vars
quote
$(cond_name) = $(cond_expr)
$(reactant_code_block)
end,
(true_branch_fn_name, false_branch_fn_name),
all_check_vars,
)

return quote
$(cond_name) = $(cond_expr)
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
$(reactant_code_block)
else
$(original_expr)
$(non_reactant_code_block)
end
end
end
Expand Down
3 changes: 1 addition & 2 deletions test/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,7 @@ mutable struct TestSimulation{C,I,B}
end

function step!(sim)
cond = sim.clock.iteration >= sim.stop_iteration
@trace if cond
@trace if sim.clock.iteration >= sim.stop_iteration
sim.running = false
else
sim.clock.iteration += 1 # time step
Expand Down

0 comments on commit 4cbad40

Please sign in to comment.