Skip to content

Commit

Permalink
add recompile kwarg to ReverseDiffVJP
Browse files Browse the repository at this point in the history
  • Loading branch information
frankschae committed Jan 5, 2023
1 parent 0d31321 commit 7f97fc8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 8 deletions.
5 changes: 3 additions & 2 deletions src/callback_tracking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback
S = integrator.f.f # get the sensitivity function

# Create a fake sensitivity function to do the vjps
@assert sensealg.autojacvec isa ReverseDiffVJP
fakeS = CallbackSensitivityFunction(w, sensealg, S.diffcache, integrator.sol.prob)

du = first(get_tmp_cache(integrator))
Expand Down Expand Up @@ -297,13 +298,13 @@ function _setup_reverse_callbacks(cb::Union{ContinuousCallback, DiscreteCallback
integrator.sol.prob)
#vjp with Jacobin given by dw/dp before event and vector given by grad
vecjacobian!(dgrad, integrator.p, grad, y, integrator.t, fakeSp;
dgrad = nothing, dy = nothing)
dgrad = nothing, dy = nothing, recompile_tape = true)
grad .= dgrad
end
end

vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS;
dgrad = dgrad, dy = dy)
dgrad = dgrad, dy = dy, recompile_tape = true)

dgrad !== nothing && (dgrad .*= -1)
if cb isa Union{ContinuousCallback, VectorContinuousCallback}
Expand Down
12 changes: 6 additions & 6 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,15 @@ function jacobianmat!(JM::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number
end
function vecjacobian!(dλ, y, λ, p, t, S::TS;
dgrad = nothing, dy = nothing,
W = nothing) where {TS <: SensitivityFunction}
_vecjacobian!(dλ, y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W)
W = nothing, kwargs...) where {TS <: SensitivityFunction}
_vecjacobian!(dλ, y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W; kwargs...)
return
end

function vecjacobian(y, λ, p, t, S::TS;
dgrad = nothing, dy = nothing,
W = nothing) where {TS <: SensitivityFunction}
return _vecjacobian(y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W)
W = nothing, kwargs...) where {TS <: SensitivityFunction}
return _vecjacobian(y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W; kwargs...)
end

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy,
Expand Down Expand Up @@ -418,7 +418,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::TrackerVJP, dgrad,
end

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dgrad, dy,
W) where {TS <: SensitivityFunction}
W; recompile_tape = false) where {TS <: SensitivityFunction}
@unpack sensealg = S
prob = getprob(S)
f = unwrapped_f(S.f)
Expand All @@ -431,7 +431,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ReverseDiffVJP, dg

if prob isa Union{SteadyStateProblem, NonlinearProblem} ||
(eltype(λ) <: eltype(prob.u0) && typeof(t) <: eltype(prob.u0) &&
compile_tape(sensealg.autojacvec))
compile_tape(sensealg.autojacvec) && !recompile_tape)
tape = S.diffcache.paramjac_config

## These other cases happen due to autodiff in stiff ODE solvers
Expand Down
33 changes: 33 additions & 0 deletions test/callbacks/continuous_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,4 +245,37 @@ println("Continuous Callbacks")
@test du01 dstuff[1:2]
@test dp1 dstuff[3:4]
end
@testset "Re-compile tape in ReverseDiffVJP" begin
N0 = [0.0] # initial population
p = [100.0, 50.0] # steady-state pop., M
tspan = (0.0, 10.0) # integration time

# system
f(D, u, p, t) = (D[1] = p[1] - u[1])

# when N = 3α/4 we inject M cells.
condition(u, t, integrator) = u[1] - 3 // 4 * integrator.p[1]
affect!(integrator) = integrator.u[1] += integrator.p[2]
cb = ContinuousCallback(condition, affect!, save_positions = (false, false))
prob = ODEProblem(f, N0, tspan, p)

function loss(p, cb, sensealg)
_prob = remake(prob, p = p)
_sol = solve(_prob, Tsit5(); callback = cb,
abstol = 1e-14, reltol = 1e-14,
sensealg = sensealg)
_sol[end][1]
end

gND = FiniteDiff.finite_difference_gradient(p -> loss(p, cb, nothing), p)
gFD = ForwardDiff.gradient(p -> loss(p, cb, nothing), p)
# @show gND # [0.9999546000702386, 0.00018159971904994378]
@test gNDgFD rtol=1e-10

for compile_tape in (true, false)
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(compile_tape))
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
@test gFDgZy rtol=1e-10
end
end
end

0 comments on commit 7f97fc8

Please sign in to comment.