diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index e044c8220..dc0c859bc 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -640,7 +640,7 @@ function DiffEqBase._concrete_solve_adjoint( du0 = reshape(du0, size(u0)) dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : - dp isa AbstractArray ? reshape(dp', size(p)) : dp + dp isa AbstractArray ? reshape(dp', size(tunables)) : dp if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..482556b89 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -484,8 +484,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) ReverseDiff.reverse_pass!(tape) copyto!(vec(out), ReverseDiff.deriv(tp)) elseif sensealg.autojacvec isa ZygoteVJP - _dy, back = Zygote.pullback(p) do p - vec(f(y, p, t)) + _dy, back = Zygote.pullback(tunables) do tunables + vec(f(y, repack(tunables), t)) end tmp = back(λ) if tmp[1] === nothing