Apply optimizer to model weights without data copy #222
base: main
Conversation
Another issue I just found is that the arrays of the optimizer, e.g., …
@jvdp1 Indeed, that's the extra bookkeeping needed. However, I'm now wondering whether we should instantiate an optimizer per layer rather than one for the whole network. That would take care of the bookkeeping altogether, because the "memory" arrays like …
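A minimal Fortran sketch of that per-layer idea (all names here, including the `velocity` component, are illustrative rather than the actual neural-fortran API): each layer owns its own optimizer instance, so the optimizer's memory arrays are sized to that layer's parameters and no offsets into a whole-network parameter vector are needed.

```fortran
module per_layer_optimizer_sketch
  implicit none

  ! Illustrative stand-in for the library's optimizer class.
  type :: sgd_momentum_sketch
    real :: learning_rate = 0.01
    real :: momentum = 0.9
    real, allocatable :: velocity(:)  ! sized to this layer only
  end type

  type :: dense_layer_sketch
    real, allocatable :: params(:)
    type(sgd_momentum_sketch) :: optimizer  ! one instance per layer
  end type

contains

  subroutine update_layer(layer, gradients)
    type(dense_layer_sketch), intent(inout) :: layer
    real, intent(in) :: gradients(:)
    associate (opt => layer % optimizer)
      ! The "memory" array is local to the layer, so no start/end
      ! bookkeeping over the whole network's parameter vector is needed.
      if (.not. allocated(opt % velocity)) &
        allocate(opt % velocity(size(layer % params)), source=0.0)
      opt % velocity = opt % momentum * opt % velocity &
                     - opt % learning_rate * gradients
      layer % params = layer % params + opt % velocity
    end associate
  end subroutine update_layer

end module per_layer_optimizer_sketch
```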
… switched to per-layer optimizer instances
…ed 2 calls per batch; this is now generalized to allow any number of calls until size(params) is exhausted
All layers, with the exception of embedding and MHA, now implement getting parameters and gradients as pointers. This removes the need for data copies in those layers. MHA and embedding are left alone because they are a bit more complex and I don't yet feel comfortable refactoring them, but they should be switched to the new approach as well. Note that these layers were not previously integrated with the optimizer either.
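As a rough sketch of what pointer-based access to a layer's parameters can look like (the module, type, and procedure names below are illustrative assumptions, not necessarily the exact API introduced by this PR), a layer can hand out flat pointer views into its own storage instead of copying parameters out:

```fortran
module params_pointer_sketch
  implicit none

  type :: dense_sketch
    real, allocatable :: weights(:,:)
    real, allocatable :: biases(:)
  contains
    procedure :: get_params_ptr
  end type

contains

  ! Return flat pointers into the layer's own storage; the optimizer
  ! can then update the weights and biases without any copy.
  subroutine get_params_ptr(self, w_ptr, b_ptr)
    class(dense_sketch), intent(inout), target :: self
    real, pointer, intent(out) :: w_ptr(:), b_ptr(:)
    ! Rank-remapping pointer assignment (Fortran 2008): view the
    ! rank-2 weights as a rank-1 array without copying.
    w_ptr(1:size(self % weights)) => self % weights
    b_ptr => self % biases
  end subroutine get_params_ptr

end module params_pointer_sketch
```

Because `w_ptr` aliases the layer's own `weights` array, anything the optimizer writes through it lands directly in the layer.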
Alternative approach to #184
Currently this is implemented only for the dense layer, so a lot of other things are not working yet. We now pass pointers to each layer's weights and biases to the optimizer, which updates them in place.
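To make the in-place idea concrete, here is a tiny self-contained example in plain Fortran (no neural-fortran API assumed): an SGD step applied through a pointer modifies the layer's own storage, with no temporary copy of the parameters.

```fortran
program in_place_update_sketch
  implicit none
  real, allocatable, target :: weights(:)
  real, allocatable :: gradients(:)
  real, pointer :: w_ptr(:)
  real :: learning_rate

  learning_rate = 0.1
  weights = [1.0, 2.0, 3.0]
  gradients = [0.1, 0.1, 0.1]

  ! The optimizer would receive this pointer instead of a copy.
  w_ptr => weights

  ! A plain SGD step through the pointer updates the layer's own
  ! storage; no temporary array of parameters is created.
  w_ptr = w_ptr - learning_rate * gradients

  print '(3f6.2)', weights  ! 0.99  1.99  2.99
end program in_place_update_sketch
```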
Since this approach runs the optimizer on a layer-by-layer basis, optimizer memory such as velocity, RMS gradients, etc. is not implemented yet, as it originally assumed a whole-network update. Additional bookkeeping is needed there. A possible approach to this bookkeeping is to require the caller of `optimizer % minimize()` to pass explicit start and end indices over which the optimizer will run. This may seem tedious, but helper functions can make it easy.
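A hedged sketch of what such an interface could look like, assuming an SGD-with-momentum optimizer whose `velocity` array spans the whole network; `istart` and `iend` select the slice belonging to the layer being updated (all names here are made up for illustration):

```fortran
module indexed_minimize_sketch
  implicit none

  type :: sgd_sketch
    real :: learning_rate = 0.01
    real :: momentum = 0.9
    real, allocatable :: velocity(:)  ! spans all network parameters
  contains
    procedure :: minimize => minimize_slice
  end type

contains

  ! Update one layer's parameters in place, using only this layer's
  ! slice [istart:iend] of the whole-network velocity array.
  subroutine minimize_slice(self, params, gradients, istart, iend)
    class(sgd_sketch), intent(inout) :: self
    real, intent(inout) :: params(:)  ! pointer view into the layer
    real, intent(in) :: gradients(:)
    integer, intent(in) :: istart, iend
    associate (v => self % velocity(istart:iend))
      v = self % momentum * v - self % learning_rate * gradients
      params = params + v
    end associate
  end subroutine minimize_slice

end module indexed_minimize_sketch
```

A helper on the network side could accumulate these offsets from each layer's `size(params)`, so user code would never pass raw indices; the per-layer-optimizer idea discussed above would do away with the indices entirely.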
On MNIST training (`examples/dense_mnist`), the difference is only ~1% in run time. However, running the profiler on both the main branch's `dense_mnist` and this branch's `dense_mnist` shows that the data copies (mostly in `get_params`; two redundant copies of all model parameters happening there) are now gone:

main branch `dense_mnist`: [profiler output omitted]

This PR's `dense_mnist`: [profiler output omitted]

The above runs were compiled using gfortran 14.2.0 and the `-pg -Ofast` flags.