Skip to content

Commit

Permalink
pull out w and wp construction from affect function
Browse files Browse the repository at this point in the history
  • Loading branch information
frankschae committed Jan 5, 2023
1 parent b169a12 commit acfdcd1
Showing 1 changed file with 53 additions and 30 deletions.
83 changes: 53 additions & 30 deletions src/callback_tracking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,28 +206,40 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback

cb.save_positions == [1, 0] && error("save_positions=[1,0] is currently not supported.")

# event times
times = if typeof(cb) <: DiscreteCallback
cb.affect!.event_times
else
[cb.affect!.event_times; cb.affect_neg!.event_times]
end

# precompute w and wp to generate a tape
if cb isa DiscreteCallback
w, wp = setup_w_wp(cb, true, nothing, nothing)
elseif cb isa ContinuousCallback
# compile a tape for pos and neg
w, wp = [setup_w_wp(cb, pos_neg, nothing, nothing) for pos_neg in (true, false)]
else
# cb isa VectorContinuousCallback
# compile a tape for pos/neg and all event indices
w, wp = [setup_w_wp(cb, pos_neg, event_idx, nothing) for pos_neg in (true, false)
for event_idx in 1:(cb.len)]
end

function affect!(integrator)
indx, pos_neg = get_indx(cb, integrator.t)
tprev = get_tprev(cb, indx, pos_neg)
event_idx = cb isa VectorContinuousCallback ? get_event_idx(cb, indx, pos_neg) :
nothing

w = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx
function (du, u, p, t)
_affect! = get_affect!(cb, pos_neg)
fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev)
if cb isa VectorContinuousCallback
_affect!(fakeinteg, event_idx)
else
_affect!(fakeinteg)
end
du .= fakeinteg.u
end
end
# how can we make sure tprev didn't change, so precompiled tape is correct?
w, wp = setup_w_wp(cb, pos_neg, event_idx, tprev)

S = integrator.f.f # get the sensitivity function

# Create a fake sensitivity function to do the vjps
@assert sensealg.autojacvec isa ReverseDiffVJP
# update diffcache here to use the correct precompiled callback tape?
fakeS = CallbackSensitivityFunction(w, sensealg, S.diffcache, integrator.sol.prob)

du = first(get_tmp_cache(integrator))
Expand Down Expand Up @@ -282,18 +294,6 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback
if update_p
# changes in parameters
if !(sensealg isa QuadratureAdjoint)
wp = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx
function (dp, p, u, t)
_affect! = get_affect!(cb, pos_neg)
fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev)
if cb isa VectorContinuousCallback
_affect!(fakeinteg, event_idx)
else
_affect!(fakeinteg)
end
dp .= fakeinteg.p
end
end
fakeSp = CallbackSensitivityFunction(wp, sensealg, S.diffcache,
integrator.sol.prob)
#vjp with Jacobin given by dw/dp before event and vector given by grad
Expand Down Expand Up @@ -330,12 +330,6 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback
end
end

times = if typeof(cb) <: DiscreteCallback
cb.affect!.event_times
else
[cb.affect!.event_times; cb.affect_neg!.event_times]
end

PresetTimeCallback(times,
affect!,
save_positions = (false, false))
Expand All @@ -350,6 +344,35 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback
cb
end

function setup_w_wp(cb, pos_neg, event_idx, tprev)
w = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx
function (du, u, p, t)
_affect! = get_affect!(cb, pos_neg)
fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev)
if cb isa VectorContinuousCallback
_affect!(fakeinteg, event_idx)
else
_affect!(fakeinteg)
end
du .= fakeinteg.u
end
end

wp = let tprev = tprev, pos_neg = pos_neg, event_idx = event_idx
function (dp, p, u, t)
_affect! = get_affect!(cb, pos_neg)
fakeinteg = FakeIntegrator([x for x in u], [x for x in p], t, tprev)
if cb isa VectorContinuousCallback
_affect!(fakeinteg, event_idx)
else
_affect!(fakeinteg)
end
dp .= fakeinteg.p
end
end
return w, wp
end

get_indx(cb::DiscreteCallback, t) = (searchsortedfirst(cb.affect!.event_times, t), true)
function get_indx(cb::Union{ContinuousCallback, VectorContinuousCallback}, t)
if !isempty(cb.affect!.event_times) || !isempty(cb.affect_neg!.event_times)
Expand Down

0 comments on commit acfdcd1

Please sign in to comment.