diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 22a45664c..bfd6f09eb 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -40,10 +40,18 @@ function inplace_vjp(prob, u0, p, verbose) vjp = try f = unwrapped_f(prob.f) - ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t - du1 = similar(u, size(u)) - f(du1, u, p, first(t)) - return vec(du1) + if p === nothing || p isa SciMLBase.NullParameters + ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t + du1 = similar(u, size(u)) + f(du1, u, p, first(t)) + return vec(du1) + end + else + ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t + du1 = similar(u, size(u)) + f(du1, u, p, first(t)) + return vec(du1) + end end ReverseDiffVJP(compile) catch e @@ -73,7 +81,11 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem, # so if out-of-place, try Zygote vjp = try - Zygote.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p) + if p === nothing || p isa SciMLBase.NullParameters + Zygote.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0) + else + Zygote.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p) + end ZygoteVJP() catch e if verbose @@ -86,7 +98,11 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem, if vjp == false vjp = try - ReverseDiff.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p) + if p === nothing || p isa SciMLBase.NullParameters + ReverseDiff.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0) + else + ReverseDiff.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p) + end ReverseDiffVJP() catch e if verbose @@ -100,7 +116,11 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem, if vjp == false vjp = try - Tracker.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p) + if p === nothing || p isa SciMLBase.NullParameters + Tracker.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0) + else + Tracker.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p) + end TrackerVJP() catch e if verbose