diff --git a/src/callback_tracking.jl b/src/callback_tracking.jl index 85a83c880..ba5a98d89 100644 --- a/src/callback_tracking.jl +++ b/src/callback_tracking.jl @@ -206,28 +206,40 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback cb.save_positions == [1, 0] && error("save_positions=[1,0] is currently not supported.") + # event times + times = if typeof(cb) <: DiscreteCallback + cb.affect!.event_times + else + [cb.affect!.event_times; cb.affect_neg!.event_times] + end + + # precompute w and wp to generate a tape + if cb isa DiscreteCallback + w, wp = setup_w_wp(cb, true, nothing, nothing) + elseif cb isa ContinuousCallback + # compile a tape for pos and neg + w, wp = [setup_w_wp(cb, pos_neg, nothing, nothing) for pos_neg in (true, false)] + else + # cb isa VectorContinuousCallback + # compile a tape for pos/neg and all event indices + w, wp = [setup_w_wp(cb, pos_neg, event_idx, nothing) for pos_neg in (true, false) + for event_idx in 1:(cb.len)] + end + function affect!(integrator) indx, pos_neg = get_indx(cb, integrator.t) tprev = get_tprev(cb, indx, pos_neg) event_idx = cb isa VectorContinuousCallback ? get_event_idx(cb, indx, pos_neg) : nothing - w = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx - function (du, u, p, t) - _affect! = get_affect!(cb, pos_neg) - fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev) - if cb isa VectorContinuousCallback - _affect!(fakeinteg, event_idx) - else - _affect!(fakeinteg) - end - du .= fakeinteg.u - end - end + # how can we make sure tprev didn't change, so precompiled tape is correct? + w, wp = setup_w_wp(cb, pos_neg, event_idx, tprev) + S = integrator.f.f # get the sensitivity function # Create a fake sensitivity function to do the vjps @assert sensealg.autojacvec isa ReverseDiffVJP + # update diffcache here to use the correct precompiled callback tape? fakeS = CallbackSensitivityFunction(w, sensealg, S.diffcache, integrator.sol.prob) du = first(get_tmp_cache(integrator)) @@ -282,18 +294,6 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback if update_p # changes in parameters if !(sensealg isa QuadratureAdjoint) - wp = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx - function (dp, p, u, t) - _affect! = get_affect!(cb, pos_neg) - fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev) - if cb isa VectorContinuousCallback - _affect!(fakeinteg, event_idx) - else - _affect!(fakeinteg) - end - dp .= fakeinteg.p - end - end fakeSp = CallbackSensitivityFunction(wp, sensealg, S.diffcache, integrator.sol.prob) #vjp with Jacobin given by dw/dp before event and vector given by grad @@ -330,12 +330,6 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback end end - times = if typeof(cb) <: DiscreteCallback - cb.affect!.event_times - else - [cb.affect!.event_times; cb.affect_neg!.event_times] - end - PresetTimeCallback(times, affect!, save_positions = (false, false)) @@ -350,6 +344,35 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback cb end +function setup_w_wp(cb, pos_neg, event_idx, tprev) + w = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx + function (du, u, p, t) + _affect! = get_affect!(cb, pos_neg) + fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev) + if cb isa VectorContinuousCallback + _affect!(fakeinteg, event_idx) + else + _affect!(fakeinteg) + end + du .= fakeinteg.u + end + end + + wp = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx + function (dp, p, u, t) + _affect! = get_affect!(cb, pos_neg) + fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev) + if cb isa VectorContinuousCallback + _affect!(fakeinteg, event_idx) + else + _affect!(fakeinteg) + end + dp .= fakeinteg.p + end + end + return w, wp +end + get_indx(cb::DiscreteCallback, t) = (searchsortedfirst(cb.affect!.event_times, t), true) function get_indx(cb::Union{ContinuousCallback, VectorContinuousCallback}, t) if !isempty(cb.affect!.event_times) || !isempty(cb.affect_neg!.event_times)