Skip to content

Commit 9143ea7

Browse files
Merge pull request #1135 from SciML/dg/nnrev2
Avoid needing adjoints of SciMLStructures' constructor
2 parents 1458111 + 20d64fd commit 9143ea7

File tree

4 files changed

+114
-10
lines changed

4 files changed

+114
-10
lines changed

src/concrete_solve.jl

+31-8
Original file line numberDiff line numberDiff line change
@@ -639,15 +639,26 @@ function DiffEqBase._concrete_solve_adjoint(
639639

640640
du0 = reshape(du0, size(u0))
641641

642-
dp = p === nothing || p === SciMLBase.NullParameters() ? nothing :
643-
dp isa AbstractArray ? reshape(dp', size(p)) : dp
642+
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
643+
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp
644+
645+
_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
646+
!isscimlstructure(p)
647+
nothing, x -> (x,)
648+
else
649+
Zygote.pullback(p) do p
650+
t, _, _ = canonicalize(Tunable(), p)
651+
t
652+
end
653+
end
644654

645655
if originator isa SciMLBase.TrackerOriginator ||
646656
originator isa SciMLBase.ReverseDiffOriginator
647-
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
657+
(NoTangent(), NoTangent(), du0, repack_adjoint(dp)[1], NoTangent(),
648658
ntuple(_ -> NoTangent(), length(args))...)
649659
else
650-
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
660+
(NoTangent(), NoTangent(), NoTangent(),
661+
du0, repack_adjoint(dp)[1], NoTangent(),
651662
ntuple(_ -> NoTangent(), length(args))...)
652663
end
653664
end
@@ -835,7 +846,7 @@ function DiffEqBase._concrete_solve_adjoint(
835846
pparts = typeof(tunables[1:1])[]
836847
for j in 0:(num_chunks - 1)
837848
local chunk
838-
if ((j + 1) * chunk_size) <= length(p)
849+
if ((j + 1) * chunk_size) <= length(tunables)
839850
chunk = ((j * chunk_size + 1):((j + 1) * chunk_size))
840851
pchunk = vec(tunables)[chunk]
841852
pdualpart = seed_duals(pchunk, prob.f,
@@ -957,7 +968,7 @@ function DiffEqBase._concrete_solve_adjoint(
957968
end
958969
push!(pparts, vec(_dp))
959970
end
960-
SciMLStructures.replace(Tunable(), p, reduce(vcat, pparts))
971+
reduce(vcat, pparts)
961972
end
962973
else
963974
dp = nothing
@@ -1134,12 +1145,24 @@ function DiffEqBase._concrete_solve_adjoint(
11341145
end
11351146
end
11361147

1148+
_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
1149+
!isscimlstructure(p)
1150+
nothing, x -> (x,)
1151+
else
1152+
Zygote.pullback(p) do p
1153+
t, _, _ = canonicalize(Tunable(), p)
1154+
t
1155+
end
1156+
end
1157+
11371158
if originator isa SciMLBase.TrackerOriginator ||
11381159
originator isa SciMLBase.ReverseDiffOriginator
1139-
(NoTangent(), NoTangent(), unthunk(du0), unthunk(dp), NoTangent(),
1160+
(NoTangent(), NoTangent(), unthunk(du0),
1161+
repack_adjoint(unthunk(dp))[1], NoTangent(),
11401162
ntuple(_ -> NoTangent(), length(args))...)
11411163
else
1142-
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
1164+
(NoTangent(), NoTangent(), NoTangent(),
1165+
du0, repack_adjoint(unthunk(dp))[1], NoTangent(),
11431166
ntuple(_ -> NoTangent(), length(args))...)
11441167
end
11451168
end

src/gauss_adjoint.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,8 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
482482
ReverseDiff.reverse_pass!(tape)
483483
copyto!(vec(out), ReverseDiff.deriv(tp))
484484
elseif sensealg.autojacvec isa ZygoteVJP
485-
_dy, back = Zygote.pullback(p) do p
486-
vec(f(y, p, t))
485+
_dy, back = Zygote.pullback(tunables) do tunables
486+
vec(f(y, tunables, t))
487487
end
488488
tmp = back(λ)
489489
if tmp[1] === nothing

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ end
2828
@time @safetestset "Prob Kwargs" include("prob_kwargs.jl")
2929
@time @safetestset "DiscreteProblem Adjoints" include("discrete.jl")
3030
@time @safetestset "Time Type Mixing Adjoints" include("time_type_mixing.jl")
31+
@time @safetestset "SciMLStructures Interface" include("scimlstructures_interface.jl")
3132
end
3233
end
3334

test/scimlstructures_interface.jl

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# taken from https://github.com/SciML/SciMLStructures.jl/pull/28
2+
using OrdinaryDiffEq, SciMLSensitivity, Zygote
3+
using LinearAlgebra
4+
import SciMLStructures as SS
5+
6+
# Per-equation parameter bundle. Only `p` is treated as tunable; `q` and `r`
# are carried along unchanged. The struct is mutable because `SS.replace!`
# overwrites `p` in place.
mutable struct SubproblemParameters{TP, TQ, TR}
    p::TP  # tunable
    q::TQ  # not tunable
    r::TR  # not tunable
end
11+
# Top-level parameter object handed to the ODE problem: a collection of
# sub-parameter bundles plus a coefficient matrix whose entries are all tunable.
mutable struct Parameters{TS, TC}
    subparams::TS
    coeffs::TC  # tunable matrix
end
15+
# In-place right-hand side.
#
# For each sub-parameter bundle `i` in `p.subparams`:
#     du[i] = p_i * u[i]^2 + q_i * u[i] + r_i * t
# The remaining entries of `du` (after the first `length(p.subparams)`) receive
# the matrix-vector product `p.coeffs * u`.
function rhs!(du, u, p::Parameters, t)
    for (i, subpars) in enumerate(p.subparams)
        quadratic = subpars.p * u[i]^2
        linear = subpars.q * u[i]
        forcing = subpars.r * t
        # Summed left to right, matching the single-expression form.
        du[i] = quadratic + linear + forcing
    end
    nsub = length(p.subparams)
    tail = view(du, (nsub + 1):(length(du)))
    # `mul!` writes the product directly into the trailing slice of `du`.
    mul!(tail, p.coeffs, u)
    return nothing
end
25+
# Initial state: sin sampled on 0.1:0.1:1.0 (ten entries).
u = [sin(x) for x in 0.1:0.1:1.0]
# Five sub-parameter bundles whose fields scale linearly with the index.
subparams = map(i -> SubproblemParameters(0.1i, 0.2i, 0.3i), 1:5)
# 5x10 coefficient matrix, so `coeffs * u` fills the last five entries of `du`.
p = Parameters(subparams, [cos(0.1i + 0.33j) for i in 1:5, j in 1:10])
tspan = (0.0, 1.0)
prob = ODEProblem(rhs!, u, tspan, p)
# Forward solve with the plain struct, run before any SciMLStructures methods
# are defined for `Parameters`; the result is discarded.
solve(prob, Tsit5())
31+
32+
# Opt `Parameters` into the SciMLStructures interface.
function SS.isscimlstructure(::Parameters)
    return true
end

# Declare the structure as mutable.
function SS.ismutablescimlstructure(::Parameters)
    return true
end

# Only the `Tunable` portion is exposed. A `Constants` portion holding the
# non-tunable values could be added along the same lines as the `Tunable`
# implementation.
function SS.hasportion(::SS.Tunable, ::Parameters)
    return true
end
40+
# Flatten the tunable portion of `p` into one vector.
#
# Returns a 3-tuple:
#   1. the flat buffer: every `subpar.p` followed by the entries of `p.coeffs`,
#   2. a repack function that takes a new vector of the same length and builds
#      a new `Parameters` from it, keeping the old non-tunable values,
#   3. `false`, meaning the buffer does not alias values in `p`.
function SS.canonicalize(::SS.Tunable, p::Parameters)
    scalar_tunables = [sub.p for sub in p.subparams]
    buffer = vcat(scalar_tunables, vec(p.coeffs))
    # Repacking is exactly what `SS.replace` does, so delegate to it. The `let`
    # binds the parameter object to a fresh local for the closure to capture.
    repack = let params = p
        newbuffer -> SS.replace(SS.Tunable(), params, newbuffer)
    end
    return buffer, repack, false
end
56+
# Build a new `Parameters` whose tunable values come from `newbuffer`.
#
# The first `length(p.subparams)` entries become the `p` field of each
# sub-parameter bundle (with `q` and `r` copied from the existing bundle); the
# remaining entries are viewed and reshaped to the size of `p.coeffs`. The
# input `p` is not modified.
function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
    nsub = length(p.subparams)
    N = nsub + length(p.coeffs)
    @assert length(newbuffer) == N
    # The coefficient block is a reshaped view of the buffer tail, not a copy.
    tail = view(newbuffer, (nsub + 1):length(newbuffer))
    coeffs = reshape(tail, size(p.coeffs))
    subparams = [SubproblemParameters(newbuffer[i], old.q, old.r)
                 for (i, old) in enumerate(p.subparams)]
    return Parameters(subparams, coeffs)
end
65+
# In-place counterpart of `SS.replace`: overwrite the tunable values stored in
# `p` with the entries of `newbuffer` and return `p`.
#
# The first `length(p.subparams)` entries are written to the `p` field of each
# sub-parameter bundle; the remaining entries are copied into `p.coeffs`.
function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
    N = length(p.subparams) + length(p.coeffs)
    @assert length(newbuffer) == N
    # `zip` stops after the last sub-parameter, so only the leading entries of
    # `newbuffer` are consumed by this loop.
    for (subpar, val) in zip(p.subparams, newbuffer)
        subpar.p = val
    end
    # The destination must be the matrix owned by `p`. A bare `coeffs` is not
    # defined in this scope and would throw `UndefVarError` on every call.
    copyto!(p.coeffs, view(newbuffer, (length(p.subparams) + 1):length(newbuffer)))
    return p
end
74+
75+
# Differentiate a scalar loss of the final ODE state with respect to the flat
# tunable vector. The loss rebuilds the parameter object from the tunables,
# remakes the problem with it, solves, and sums the last saved state.
let
    ntunables = length(first(SS.canonicalize(SS.Tunable(), p)))
    start = fill(0.1, ntunables)
    Zygote.gradient(start) do tunables
        rebuilt = SS.replace(SS.Tunable(), p, tunables)
        sol = solve(remake(prob; p = rebuilt), Tsit5())
        return sum(sol.u[end])
    end
end

0 commit comments

Comments
 (0)