@@ -79,12 +79,20 @@ function automatic_sensealg_choice(
7979 prob:: Union {SciMLBase. AbstractODEProblem,
8080 SciMLBase. AbstractSDEProblem},
8181 u0, p, verbose, repack)
82+ if p === nothing || p isa SciMLBase. NullParameters
83+ tunables, repack = p, identity
84+ elseif isscimlstructure (p)
85+ tunables, repack, _ = canonicalize (Tunable (), p)
86+ else
87+ throw (SciMLStructuresCompatibilityError ())
88+ end
89+
8290 default_sensealg = if p != = DiffEqBase. NullParameters () &&
8391 ! (eltype (u0) <: ForwardDiff.Dual ) &&
8492 ! (eltype (p) <: ForwardDiff.Dual ) &&
8593 ! (eltype (u0) <: Complex ) &&
8694 ! (eltype (p) <: Complex ) &&
87- length (u0) + length (p ) <= 100
95+ length (u0) + length (tunables ) <= 100
8896 ForwardDiffSensitivity ()
8997 elseif u0 isa GPUArraysCore. AbstractGPUArray || ! DiffEqBase. isinplace (prob)
9098 # only Zygote is GPU compatible and fast
@@ -124,7 +132,7 @@ function automatic_sensealg_choice(
124132 ReverseDiff. gradient ((u) -> sum (prob. f (u, p, prob. tspan[1 ])), u0)
125133 else
126134 ReverseDiff. gradient (
127- (u, _p) -> sum (prob. f (u, repack (_p), prob. tspan[1 ])), u0, p )
135+ (u, _p) -> sum (prob. f (u, repack (_p), prob. tspan[1 ])), u0, tunables )
128136 end
129137 ReverseDiffVJP ()
130138 catch e
@@ -151,8 +159,8 @@ function automatic_sensealg_choice(
151159 end
152160 tmp1 = back (λ)
153161 else
154- _dy, back = Tracker. forward (y, p ) do u, p
155- vec (f (u, p , t))
162+ _dy, back = Tracker. forward (y, tunables ) do u, tunables
163+ vec (f (u, repack (tunables) , t))
156164 end
157165 tmp1, tmp2 = back (λ)
158166 end
0 commit comments