Conversation

BenjaminBossan (Member)

Resolves #2651

This is an experimental branch that caches the DoRA weight norm for faster inference. Since the weights don't change during inference, there is no need to recalculate the weight norm of a DoRA module on every forward pass.

During training, recalculation is needed, so there is no caching while the module has training=True.
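
To make the caching idea concrete, here is a minimal, self-contained sketch. It is not the actual PEFT code; CachedNormLinear and its attribute names are made up for illustration. The layer recomputes the weight norm in training mode and lazily caches it in eval mode.

import torch
import torch.nn as nn

class CachedNormLinear(nn.Module):
    """DoRA-style layer that caches its weight norm while in eval mode."""

    def __init__(self, in_features, out_features, rank=8):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=False)
        self.magnitude = nn.Parameter(torch.ones(out_features))
        self._cached_weight_norm = None  # filled lazily, only in eval mode

    def _weight_norm(self):
        # Column-wise L2 norm of the merged weight, as used by DoRA for rescaling.
        merged = self.base.weight + self.lora_B.weight @ self.lora_A.weight
        return merged.norm(p=2, dim=1)

    def forward(self, x):
        if self.training:
            # The weights change every optimizer step, so always recompute.
            weight_norm = self._weight_norm()
        else:
            # During inference the weights are frozen: compute once, then reuse.
            if self._cached_weight_norm is None:
                self._cached_weight_norm = self._weight_norm().detach()
            weight_norm = self._cached_weight_norm
        scale = self.magnitude / weight_norm
        return scale * (self.base(x) + self.lora_B(self.lora_A(x)))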

The cache does not prevent every possible duplicate calculation. For instance, the weight norm is calculated during module initialization and then again during the first forward pass when performing inference. Only from the second forward pass onward is the cached weight norm reused.

The reason why this is still a draft PR is that, before finishing it, I want to ensure that adding a cache is worth it. Caches can be a tricky business and lead to subtle bugs. As an example: we detect when users put the model into training mode by calling model.train() and clear the cache, but we would not detect it if the user directly sets module.training=True.
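
For illustration, a hypothetical sketch of such an invalidation hook (not the PR's code): overriding train() catches model.train() and model.eval(), because nn.Module.train() recurses into all submodules, but a direct assignment to module.training bypasses it.

import torch.nn as nn

class CacheClearingDoraLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self._cached_weight_norm = None

    def train(self, mode: bool = True):
        # Called for every submodule when model.train()/model.eval() is invoked
        # on a parent, so the cache is reliably cleared on a mode switch ...
        self._cached_weight_norm = None
        return super().train(mode)

# ... but this assignment flips the flag without going through train(), so a
# previously cached weight norm would silently be reused:
# layer.training = True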

Some preliminary testing with a small model, meta-llama/Llama-3.2-1B, and some dummy data didn't show huge improvements; the average over 10 runs was:

  • with caching: 1.3555 sec
  • w/o caching: 1.4071 sec

Thus, before continuing, I'd like to see some more real-world measurements. If the change is considered worth it, tests should be added to the PR before it's ready.
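
For reference, a rough sketch of the kind of measurement meant here: run the same timing loop once on this branch and once on main, then compare the averages. The dtype, device handling, and dummy prompt below are placeholders.

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", use_dora=True))
model.eval()

inputs = tokenizer("some dummy text " * 64, return_tensors="pt").to(model.device)

n_runs = 10
times = []
with torch.no_grad():
    model(**inputs)  # warm-up pass, excluded from the timing
    for _ in range(n_runs):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        model(**inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)

print(f"average over {n_runs} runs: {sum(times) / len(times):.4f} sec")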

BenjaminBossan marked this pull request as ready for review on August 18, 2025, 18:06
BenjaminBossan changed the title from "[WIP] ENH Cache DoRA weight norm for inference" to "ENH Cache DoRA weight norm for inference" on Aug 18, 2025
BenjaminBossan (Member, Author)

@phemw The PR is ready from my side; if you want to give it a try, let me know what you find. Note that the memory overhead of caching is quite significant (41% in one test), so it's turned off by default and users need to opt in.

from peft.helpers import DoraCaching

model.eval()
with DoraCaching():
    output = model(inputs)
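
This is not the actual DoraCaching implementation, just a sketch of the general pattern such an opt-in context manager tends to follow: flip a process-wide flag on entry, restore it on exit, and let the DoRA forward pass consult the flag before caching. Whether the cache is also dropped on exit is up to the real implementation.

_CACHING_ENABLED = False  # consulted by the layer's forward pass

class CachingContext:
    """Scoped opt-in: caching is only active inside the with-block."""

    def __enter__(self):
        global _CACHING_ENABLED
        self._previous = _CACHING_ENABLED
        _CACHING_ENABLED = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        global _CACHING_ENABLED
        _CACHING_ENABLED = self._previous
        return False  # do not swallow exceptions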

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Successfully merging this pull request may close these issues: DoRA slow forward inference (#2651)