Skip to content

Commit 785f378

Browse files
Merge pull request #979 from DhairyaLGandhi/dg/paramgrad
feat: pass parameter gradients from `getindex`
2 parents 9248437 + 993e978 commit 785f378

File tree

3 files changed

+30
-21
lines changed

3 files changed

+30
-21
lines changed

ext/SciMLBaseZygoteExt.jl

+6-8
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,13 @@ end
9393
if is_observed(VA, sym)
9494
f = observed(VA, sym)
9595
p = parameter_values(VA)
96-
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
9796
u = state_values(VA)
9897
t = current_time(VA)
99-
y, back = Zygote.pullback(u, tunables) do u, tunables
100-
_p = repack(tunables)
101-
f.(u, Ref(_p), t)
98+
y, back = Zygote.pullback(u, p) do u, p
99+
f.(u, Ref(p), t)
102100
end
103101
gs = back(Δ)
104-
(gs[1], nothing)
102+
(u = gs[1], prob = (p = gs[2],)), nothing
105103
elseif i === nothing
106104
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
107105
else
@@ -163,14 +161,14 @@ end
163161
gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
164162
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)
165163

166-
a = Zygote.accum(gs_obs[1], gs_not_obs)
164+
a = Zygote.accum(gs_obs[1], (u = gs_not_obs,))
167165

168166
(a, nothing)
169167
end
170168
VA[sym], ODESolution_getindex_pullback
171169
end
172170

173-
@adjoint function Base.getindex(VA::SciMLBase.NonlinearSolution, sym)
171+
@adjoint function Base.getindex(VA::SciMLBase.AbstractNonlinearSolution, sym)
174172
function NonlinearSolution_getindex_pullback(Δ)
175173
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
176174
if is_observed(VA, sym)
@@ -181,7 +179,7 @@ end
181179
f.f_oop(u, p)
182180
end
183181
gs = back(Δ)
184-
((u = gs[1], prob = (p = gs[2],),), nothing)
182+
((u = gs[1], prob = (p = gs[2],)), nothing)
185183
elseif i === nothing
186184
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
187185
elseif i isa Int && VA.u isa Number

test/downstream/adjoints.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ idx_vecsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lor
5252
true_grad_vecsym = zeros(length(ModelingToolkit.unknowns(sys)))
5353
true_grad_vecsym[idx_vecsym] .= 1.0
5454

55-
@test all(map(x -> x == true_grad_vecsym, gs_vec))
55+
@test all(map(x -> x == true_grad_vecsym, gs_vec.u))
5656

5757
gs_tup, = Zygote.gradient(sol) do sol
5858
sum(sum.(collect.(sol[(lorenz1.x, lorenz2.x)])))
@@ -61,7 +61,7 @@ idx_tupsym = SymbolicIndexingInterface.variable_index.(Ref(sys), [lorenz1.x, lor
6161
true_grad_tupsym = zeros(length(ModelingToolkit.unknowns(sys)))
6262
true_grad_tupsym[idx_tupsym] .= 1.0
6363

64-
@test all(map(x -> x == true_grad_tupsym, gs_tup))
64+
@test all(map(x -> x == true_grad_tupsym, gs_tup.u))
6565

6666
gs_ts, = Zygote.gradient(sol) do sol
6767
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))

test/downstream/observables_autodiff.jl

+22-11
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ p = [σ => 28.0,
2626
β => 8 / 3]
2727

2828
tspan = (0.0, 100.0)
29-
prob = ODEProblem(sys, u0, tspan, p, jac = true)
29+
prob = ODEProblem(sys, u0, tspan, p)
3030
sol = solve(prob, Tsit5())
3131

3232
@testset "AutoDiff Observable Functions" begin
@@ -35,25 +35,36 @@ sol = solve(prob, Tsit5())
3535
end
3636
du_ = [0.0, 1.0, 1.0, 1.0]
3737
du = [du_ for _ in sol.u]
38-
@test du == gs
38+
@test du == gs.u
3939

4040
# Observable in a vector
4141
gs, = gradient(sol) do sol
4242
sum(sum.(sol[[sys.w, sys.x]]))
4343
end
4444
du_ = [0.0, 1.0, 1.0, 2.0]
4545
du = [du_ for _ in sol.u]
46-
@test du == gs
46+
@test du == gs.u
4747
end
4848

49-
# @testset "AD Observable Functions for Initialization" begin
50-
# iprob = prob.f.initialization_data.initalizeprob
51-
# isol = solve(iprob)
52-
# gs, = gradient(isol) do isol
53-
# isol[w]
54-
# end
49+
@testset "AD Observable Functions for Initialization" begin
50+
iprob = prob.f.initialization_data.initializeprob
51+
isol = solve(iprob)
52+
gs, = Zygote.gradient(isol) do isol
53+
isol[w]
54+
end
5555

56-
# end
56+
@test gs isa NamedTuple
57+
@test isempty(setdiff(fieldnames(typeof(gs)), fieldnames(typeof(isol))))
58+
59+
# Compare gradient for parameters match from observed function
60+
# to ensure parameter gradients are passed through the observed function
61+
f = SII.observed(iprob.f.sys, w)
62+
gu0, gp = gradient(SII.state_values(iprob), SII.parameter_values(iprob)) do u0, p
63+
f(u0, p)
64+
end
65+
66+
@test gs.prob.p == gp
67+
end
5768

5869
# DAE
5970

@@ -92,7 +103,7 @@ end
92103
end
93104
du_ = [0.2, 1.0]
94105
du = [du_ for _ in sol.u]
95-
@test gs == du
106+
@test gs.u == du
96107

97108
@testset "DAE Initialization Observable function AD" begin
98109
iprob = prob.f.initialization_data.initializeprob

0 commit comments

Comments
 (0)