Skip to content

Commit ef7a3b8

Browse files
Merge pull request #1079 from SciML/dg/length
Check `p` while choosing default sensealg
2 parents 4f90aca + 94e212a commit ef7a3b8

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/concrete_solve.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,20 @@ function automatic_sensealg_choice(
7979
prob::Union{SciMLBase.AbstractODEProblem,
8080
SciMLBase.AbstractSDEProblem},
8181
u0, p, verbose, repack)
82+
if p === nothing || p isa SciMLBase.NullParameters
83+
tunables, repack = p, identity
84+
elseif isscimlstructure(p)
85+
tunables, repack, _ = canonicalize(Tunable(), p)
86+
else
87+
throw(SciMLStructuresCompatibilityError())
88+
end
89+
8290
default_sensealg = if p !== DiffEqBase.NullParameters() &&
8391
!(eltype(u0) <: ForwardDiff.Dual) &&
8492
!(eltype(p) <: ForwardDiff.Dual) &&
8593
!(eltype(u0) <: Complex) &&
8694
!(eltype(p) <: Complex) &&
87-
length(u0) + length(p) <= 100
95+
length(u0) + length(tunables) <= 100
8896
ForwardDiffSensitivity()
8997
elseif u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob)
9098
# only Zygote is GPU compatible and fast
@@ -124,7 +132,7 @@ function automatic_sensealg_choice(
124132
ReverseDiff.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
125133
else
126134
ReverseDiff.gradient(
127-
(u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, p)
135+
(u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, tunables)
128136
end
129137
ReverseDiffVJP()
130138
catch e
@@ -151,8 +159,8 @@ function automatic_sensealg_choice(
151159
end
152160
tmp1 = back(λ)
153161
else
154-
_dy, back = Tracker.forward(y, p) do u, p
155-
vec(f(u, p, t))
162+
_dy, back = Tracker.forward(y, tunables) do u, tunables
163+
vec(f(u, repack(tunables), t))
156164
end
157165
tmp1, tmp2 = back(λ)
158166
end

0 commit comments

Comments
 (0)