Skip to content

Commit 1e62b87

Browse files
fix: fix ODESolution-related adjoints
1 parent 8f07d49 commit 1e62b87

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

Diff for: ext/SciMLBaseZygoteExt.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using RecursiveArrayTools
3131
N = length((size(dprob.u0)..., length(du)))
3232
end
3333
Δ′ = ODESolution{T, N}(du, nothing, nothing,
34-
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
34+
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
3535
VA.alg_choice, VA.retcode)
3636
(Δ′, nothing, nothing)
3737
end
@@ -66,7 +66,7 @@ end
6666
N = length((size(dprob.u0)..., length(du)))
6767
end
6868
Δ′ = ODESolution{T, N}(du, nothing, nothing,
69-
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
69+
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
7070
VA.alg_choice, VA.retcode)
7171
(Δ′, nothing, nothing)
7272
end
@@ -144,15 +144,15 @@ end
144144
VA[sym], ODESolution_getindex_pullback
145145
end
146146

147-
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
147+
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13
148148
}(u,
149149
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
150-
T9, T10, T11, T12}
150+
T9, T10, T11, T12, T13}
151151
function ODESolutionAdjoint(ȳ)
152152
(ȳ, ntuple(_ -> nothing, length(args))...)
153153
end
154154

155-
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
155+
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13}(u, args...),
156156
ODESolutionAdjoint
157157
end
158158

0 commit comments

Comments
 (0)