Feat: Handle Adjoints through Initialization #1168
base: master
Conversation
I wanted to ask whether it is preferred to retain …
LinearSolve.jl should be faster across the board? It depends a bit on the CPU architecture since it depends on whether it guesses the right LU correctly, …
Note that with the latest MTK update, there is now an …
@@ -425,6 +425,21 @@ function DiffEqBase._concrete_solve_adjoint(
        save_end = true, kwargs_fwd...)
    end

    # Get gradients for the initialization problem if it exists
    igs = if _prob.f.initialization_data.initializeprob != nothing
This should be before the solve, since you can use the initialization solution from here in the `remake`s of 397-405 in order to set the new `u0` and `p`, and thus skip running the initialization a second time.
How can I indicate to `solve` to avoid running initialization?
`initializealg = NoInit()`. Should probably just do `CheckInit()` for safety, but either is fine.
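A minimal sketch of the suggested flow, using the `initialization_data` fields from the diff above. The exact call signature of `initializeprobmap`, the availability of `CheckInit`/`NoInit` as `initializealg` values, and the `alg`/trailing kwargs are assumptions for illustration; parameters may additionally need remapping via `initializeprobpmap`, which is not shown:

```julia
# Sketch: run the initialization once, reuse its solution via `remake`,
# and then skip (or merely check) initialization inside `solve`.
initdata = _prob.f.initialization_data
iprob    = initdata.initializeprob

isol   = solve(iprob)                       # initialization solve, done once
new_u0 = initdata.initializeprobmap(isol)   # map the init solution back to u0
                                            # (p remapping via initializeprobpmap not shown)

_prob2 = remake(_prob; u0 = new_u0)

# NoInit() skips initialization entirely; CheckInit() only verifies that the
# provided (u0, p) are consistent.
sol = solve(_prob2, alg; initializealg = SciMLBase.CheckInit(), kwargs_fwd...)
```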
@@ -103,15 +102,18 @@ end
    else
        if linsolve === nothing && isempty(sensealg.linsolve_kwargs)
            # For the default case use `\` to avoid any form of unnecessary cache allocation
Yeah, I don't know about that comment. I think it's just old. (a) `\` always allocates because it uses `lu` instead of `lu!`, so it's re-allocating the whole matrix, which is larger than any LinearSolve allocation, and (b) we have since 2023 set up tests on StaticArrays, so the immutable path is non-allocating. I don't think (b) was true when this was written.
So glad we can remove this branch altogether.
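For context on that allocation point, a small self-contained illustration (a generic example, not code from this PR) of how a LinearSolve.jl cache reuses the factorization where repeated `\` refactorizes and reallocates each time:

```julia
using LinearAlgebra, LinearSolve

A = rand(100, 100) + 100I
b = rand(100)

# `\` factorizes with `lu`, which copies A, so every call reallocates the
# whole factorization.
x1 = A \ b

# LinearSolve keeps the factorization in a cache; subsequent solves with a
# new right-hand side reuse it instead of refactorizing.
cache = init(LinearProblem(A, b), LUFactorization())
x2 = solve!(cache).u

cache.b = rand(100)       # update the RHS in place
x3 = solve!(cache).u      # reuses the existing LU factorization
```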
iprob = _prob.f.initialization_data.initializeprob
ip = parameter_values(iprob)
itunables, irepack, ialiases = canonicalize(Tunable(), ip)
igs, = Zygote.gradient(ip) do ip
This gradient isn't used? I think this would go into the backpass and, if I'm thinking clearly, the resulting return is `dp .* igs`?
Not yet. These gradients are currently against the parameters of the initialization problem, not the system exactly, and the mapping between the two is ill-defined, so we cannot simply accumulate. I spoke with @AayushSabharwal about a way to map; it seems `initialization_data.initializeprobmap` might have some support to return the correctly shaped vector, but there are cases where we cannot know the ordering of `dp` either.
There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling accumulation of gradients, or any transforms we might need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we would need to guard against incorrect gradients manually.
Oh yes, you need to use the `initializeprobmap` (https://github.com/SciML/SciMLBase.jl/blob/master/src/initialization.jl#L268) to map it back to the shape of the initial parameters.

> but there are cases where we cannot know the ordering of `dp` either.

`p` and `dp` just need the same ordering, so `initializeprobmap` should do the trick.

> There's another subtlety. I am not sure we haven't missed some part of the cfg by manually handling accumulation of gradients, or any transforms we might need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we would need to guard against incorrect gradients manually.

This is the only change to `(u0, p)` before solving, so this would account for it, given `initializeprobmap` is just an index map and so effectively an identity function.
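To make the shape of that accumulation concrete, a purely illustrative sketch follows. `remap_init_grads` is a hypothetical helper standing in for whatever `initializeprobmap`-based index map is settled on, and whether the mapped term enters additively or as the `dp .* igs` product mentioned above is exactly what is under discussion, so the combination shown is only one possibility:

```julia
# Hypothetical sketch: combine the adjoint's parameter gradient `dp` with the
# initialization gradient `igs` once the latter has been mapped back into the
# ordering/shape of the system's parameters.
function accumulate_init_grads(dp, igs, remap_init_grads)
    igs_p = remap_init_grads(igs)   # hypothetical index map back to p's layout
    return dp .+ igs_p              # one possible accumulation; see discussion above
end
```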
Trying to use the initialization end to end caused gradients against parameters to get dropped. https://github.com/DhairyaLGandhi/SciMLBase.jl/tree/dg/nonlinear is a WIP branch which adds adjoints to the …
Checklist
- I have read the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
Additional context
MTK and SciML construct an initialization problem before starting the time stepping to ensure the starting values of the unknowns and parameters adhere to any constraints needed for the system. This PR adds handling for adjoint sensitivities of the NonlinearProblem, NonlinearLeastSquaresProblem, SCCNonlinearProblem, etc.
I am opening this to get some feedback regarding how we can accumulate gradients correctly. I have also included a test case for a DAE which I will update to use the values out of SciMLSensitivity.
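Since that test case isn't reproduced here, the following is a hedged, self-contained toy of the scenario being targeted: `u0` is determined from `p` by a nonlinear initialization solve, so the adjoint of the loss has to differentiate through that solve as well. It relies on SciMLSensitivity's existing adjoints for nonlinear solves; the model, loss, and numbers are invented for illustration and are not the PR's DAE test:

```julia
using OrdinaryDiffEq, NonlinearSolve, SciMLSensitivity, Zygote

rhs(u, p, t) = -p[1] .* u

# Toy initialization constraint: u0 must satisfy u0^2 + u0 - p[2] = 0,
# so the initial state is only implicitly defined by the parameters.
init_residual(u, p) = u .^ 2 .+ u .- p[2]

function loss(p)
    iprob = NonlinearProblem(init_residual, [1.0], p)
    u0 = solve(iprob, NewtonRaphson()).u            # u0 depends on p

    prob = ODEProblem(rhs, u0, (0.0, 1.0), p)
    sol = solve(prob, Tsit5(); sensealg = InterpolatingAdjoint())
    sum(abs2, Array(sol)[:, end])
end

# The gradient w.r.t. p[2] is nonzero only because differentiation
# goes through the initialization solve.
Zygote.gradient(loss, [1.0, 2.0])
```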
Currently the gradients get calculated but don't get accumulated; we need to be able to update the gradients for the parameters. Since this is a manual dispatch, the usual graph building in AD is bypassed, and we need to handle this ourselves. Ideally, we should make it so the cfg itself includes the initialization, so we would not have gotten incorrect gradients in the first place 😅 We are also forced to use a LinearProblem instead of `\` because `\` cannot handle singular Jacobians.

cc @ChrisRackauckas
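On the singular-Jacobian point, a small generic illustration of why `\` fails while a LinearProblem with a rank-tolerant factorization does not. The `SVDFactorization` choice here is just one option for the sketch, not necessarily what this PR uses:

```julia
using LinearAlgebra, LinearSolve

# A singular "Jacobian": plain `\` (dense LU) cannot solve with it.
J = [1.0 2.0; 2.0 4.0]
b = [1.0, 2.0]

# J \ b                    # throws SingularException (zero pivot in LU)

# A rank-tolerant factorization returns a least-squares / minimum-norm
# solution instead of erroring.
sol = solve(LinearProblem(J, b), SVDFactorization())
sol.u
```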