From 40b086a7359d9ad1fdc94e8129c505946e2fd614 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 26 Aug 2024 13:44:17 +0530 Subject: [PATCH] chore: use tunables in vec_pjac --- src/concrete_solve.jl | 2 +- src/gauss_adjoint.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index e044c8220..dc0c859bc 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -640,7 +640,7 @@ function DiffEqBase._concrete_solve_adjoint( du0 = reshape(du0, size(u0)) dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : - dp isa AbstractArray ? reshape(dp', size(p)) : dp + dp isa AbstractArray ? reshape(dp', size(tunables)) : dp if originator isa SciMLBase.TrackerOriginator || originator isa SciMLBase.ReverseDiffOriginator diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 656556ad3..482556b89 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -484,8 +484,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) ReverseDiff.reverse_pass!(tape) copyto!(vec(out), ReverseDiff.deriv(tp)) elseif sensealg.autojacvec isa ZygoteVJP - _dy, back = Zygote.pullback(p) do p - vec(f(y, p, t)) + _dy, back = Zygote.pullback(tunables) do tunables + vec(f(y, repack(tunables), t)) end tmp = back(λ) if tmp[1] === nothing