ENH Cache DoRA weight norm for inference #2661
Conversation
Resolves huggingface#2651

This is an experimental branch that caches the weight norm from DoRA for faster inference. Since the weights don't change during inference, there is no need to recalculate the weight norm of a DoRA module each time. During training, recalculation is needed, so there is no caching while the module has `training=True`.

The cache does not prevent every possible duplicate calculation. For instance, the weight norm is calculated during module initialization and then again during the first forward pass when performing inference; only from the second forward pass onward are the cached weight norms used.

The reason this is just a draft PR is that, before finishing it, I want to ensure the addition of a cache is worth it. Caches can be a tricky business and lead to subtle bugs. Just as an example, we detect when users put the model into training mode by calling `model.train()` and clear the cache; however, we would not detect a user directly setting `module.training = True`.

Some preliminary testing with a small model, meta-llama/Llama-3.2-1B, and some dummy data didn't show huge improvements; averaged over 10 runs:

- with caching: 1.3555 sec
- w/o caching: 1.4071 sec

Thus, before continuing, I'd like to see some more real-world measurements. If the change is considered worth it, tests should be added to the PR before it's ready.
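To make the idea concrete, here is a minimal sketch of the caching pattern described above. This is not the code in this PR; the class and attribute names are made up, and the norm computation is only illustrative.

```python
# Sketch only: cache the DoRA weight norm while in eval mode, recompute during
# training, and clear the cache when the module is put back into training mode
# via .train(). Setting module.training directly would bypass the clearing.
import torch
import torch.nn as nn


class DoraLayerSketch(nn.Module):
    def __init__(self, base_weight, lora_A, lora_B):
        super().__init__()
        self.base_weight = base_weight
        self.lora_A = lora_A
        self.lora_B = lora_B
        self._cached_weight_norm = None  # filled lazily during inference

    def _compute_weight_norm(self):
        # illustrative column-wise norm of the adapted weight
        adapted = self.base_weight + self.lora_B @ self.lora_A
        return adapted.norm(p=2, dim=1, keepdim=True)

    def get_weight_norm(self):
        if self.training:
            # weights change every step during training, so always recompute
            return self._compute_weight_norm()
        if self._cached_weight_norm is None:
            # the first eval forward pass still computes the norm once
            self._cached_weight_norm = self._compute_weight_norm()
        return self._cached_weight_norm

    def train(self, mode: bool = True):
        # model.train() clears the cache so training sees fresh norms
        if mode:
            self._cached_weight_norm = None
        return super().train(mode)
```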
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@phemw The PR is ready from my side, if you want to give this a try, LMK what you find. Note that the memory overhead of caching is quite significant (41% in one test), so it's turned off by default and users need to opt in:

```python
from peft.helpers import DoraCaching

model.eval()
with DoraCaching():
    output = model(inputs)
```
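If it helps with the measurements, a rough timing harness along these lines should work. The helper name, `model`, and `inputs` are placeholders to adapt to whatever you are benchmarking; only `DoraCaching` comes from this PR.

```python
import time

import torch
from peft.helpers import DoraCaching


def average_forward_time(model, inputs, n_runs=10, use_cache=False):
    # returns the mean wall-clock time of n_runs forward passes;
    # for GPU models, add torch.cuda.synchronize() before reading the clock
    model.eval()
    timings = []
    with torch.no_grad():
        for _ in range(n_runs):
            start = time.perf_counter()
            if use_cache:
                with DoraCaching():
                    model(inputs)
            else:
                model(inputs)
            timings.append(time.perf_counter() - start)
    return sum(timings) / len(timings)
```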
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.