Skip to content

Commit

Permalink
Merge pull request #811 from SciML/nullparameters
Browse files Browse the repository at this point in the history
Handle nullparameters cases in auto vjp
  • Loading branch information
ChrisRackauckas authored Mar 30, 2023
2 parents c0c1a78 + e4c6df3 commit 185925b
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,18 @@ function inplace_vjp(prob, u0, p, verbose)

vjp = try
f = unwrapped_f(prob.f)
ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
return vec(du1)
if p === nothing || p isa SciMLBase.NullParameters
ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
return vec(du1)
end
else
ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t
du1 = similar(u, size(u))
f(du1, u, p, first(t))
return vec(du1)
end
end
ReverseDiffVJP(compile)
catch e
Expand Down Expand Up @@ -73,7 +81,11 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem,
# so if out-of-place, try Zygote

vjp = try
Zygote.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
if p === nothing || p isa SciMLBase.NullParameters
Zygote.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
else
Zygote.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
end
ZygoteVJP()
catch e
if verbose
Expand All @@ -86,7 +98,11 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem,

if vjp == false
vjp = try
ReverseDiff.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
if p === nothing || p isa SciMLBase.NullParameters
ReverseDiff.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
else
ReverseDiff.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
end
ReverseDiffVJP()
catch e
if verbose
Expand All @@ -100,7 +116,11 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem,

if vjp == false
vjp = try
Tracker.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
if p === nothing || p isa SciMLBase.NullParameters
Tracker.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
else
Tracker.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
end
TrackerVJP()
catch e
if verbose
Expand Down

0 comments on commit 185925b

Please sign in to comment.