[Gemma3] compile ✨ #37447
Conversation
```python
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
```
Core change for the PR. `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]` requires either passing an integer in the signature to build `offset` (the previous solution, which triggers recompilation at each forward 🚫) or doing data-dependent slicing with `offset` as a tensor (which crashes compile 🚫).
The solution is to:
- build an `arange` from shapes ✅ (we can use shapes to create compile-compatible arrays on the fly, as opposed to using arbitrary tensors to create tensors)
- add a tensor (`offset`) to a tensor (the fixed-shape array) ✅
- slice a tensor (the attention mask) with another tensor (the offset-shifted fixed-shape array) ✅
(Note: at first I tried `torch.roll` + fixed-shape slicing, but `torch.roll` doesn't support the argument `shifts=offset` when `offset` is a tensor; `shifts` has to be an integer 😢)
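For readers following along, here is a self-contained sketch of the pattern described above (function and variable names are illustrative, not the actual Gemma3 code):

```python
import torch

def slice_window(attention_mask, offset, effective_seq_len):
    # same trick as above: an arange built from shapes, shifted by a tensor offset,
    # then used to index the mask -- no data-dependent slicing
    mask_indexes = torch.arange(
        min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
    )
    mask_indexes = mask_indexes + offset
    return attention_mask[:, :, :, mask_indexes]

compiled_slice = torch.compile(slice_window, fullgraph=True)
mask = torch.ones(1, 1, 1, 16)
# offset changes across calls, but stays a tensor -> no recompilation expected
print(compiled_slice(mask, torch.tensor(4), 8).shape)  # torch.Size([1, 1, 1, 8])
print(compiled_slice(mask, torch.tensor(5), 8).shape)  # torch.Size([1, 1, 1, 8])
```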
```python
mask_indexes = torch.arange(
    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
```
Are you sure this is CUDA graph compatible?
Yes, see e.g. the scripts at the top of the PR header. Also, see this comment explaining why :D
super nice
```diff
-# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
-# ALL changes from the PR that commented the line below when reactivating it.
-# is_compileable = True
+is_compileable = True
```
Nice! Can we update the cache to also init the layers lazily, like we do for the HybridChunked cache?
@ArthurZucker `HybridChunkedCache` only works if we don't compile the first forward pass; `HybridCache` works regardless of whether we compile the first forward pass or not. `torch._dynamo.mark_static_address` can't be called inside `torch.compile`, which is what lazy init does.
This means that if a user creates their own custom code with `HybridChunkedCache`, they can't simply compile the forward pass. If anything, `HybridChunkedCache` should move away from lazy init :P
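As a rough illustration of the constraint being described (simplified toy code, not the actual cache implementation; it assumes `torch._dynamo.mark_static_address` is available in your torch version):

```python
import torch

batch_size, num_heads, sliding_window, head_dim = 1, 2, 8, 4

# eager (non-lazy) init: the buffer exists, and is marked as a static address,
# *before* any compiled code runs
key_cache = torch.zeros(batch_size, num_heads, sliding_window, head_dim)
torch._dynamo.mark_static_address(key_cache)

def decode_step(key_cache, new_key, cache_position):
    # in-place update of the pre-allocated buffer is fine inside torch.compile
    key_cache.index_copy_(2, cache_position, new_key)
    return key_cache

compiled_step = torch.compile(decode_step)
compiled_step(key_cache, torch.randn(batch_size, num_heads, 1, head_dim), torch.tensor([3]))

# a lazy cache would instead allocate (and call mark_static_address) inside the
# first forward pass, which is exactly what can't run under torch.compile
```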
Chatted offline:
- lazy init is needed for TP
- however, lazy init is incompatible with compiling the first forward pass (prefill). Lazy init + `@torch.compiler.disable()` doesn't solve it either
- solution: add a new flag `lazy_init = None`. If `torch.distributed` is initialized and the flag is unset, then it will be `True`
- apply this change to ALL caches -> ALL caches compatible with TP + no non-TP drawbacks
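A hypothetical sketch of what such a flag could look like (class, attribute, and shape names are made up for illustration; this is not the agreed-upon API):

```python
from typing import Optional

import torch
import torch.distributed as dist


class ToyHybridCache:
    """Hypothetical sketch of the proposed `lazy_init` flag -- not the real transformers API."""

    def __init__(self, lazy_init: Optional[bool] = None):
        if lazy_init is None:
            # default: only go lazy when TP (torch.distributed) is actually in use,
            # so single-device users keep the compile-friendly eager init
            lazy_init = dist.is_available() and dist.is_initialized()
        self.lazy_init = lazy_init
        self.key_cache = None if lazy_init else self._allocate()

    def _allocate(self):
        cache = torch.zeros(1, 2, 8, 4)  # (batch, heads, sliding_window, head_dim), toy sizes
        torch._dynamo.mark_static_address(cache)  # must happen outside torch.compile
        return cache

    def update(self, new_key, cache_position):
        if self.key_cache is None:
            # lazy path: the first call allocates, so the first forward can't be compiled
            self.key_cache = self._allocate()
        self.key_cache.index_copy_(2, cache_position, new_key)
        return self.key_cache
```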
Nice, thanks for the detailed investigation! Do you think we can add a slow test to compare generation time with compile, or include `HybridCache` in the benchmarks board, so we don't accidentally introduce graph breaks? Given that Gemma3 is a high-usage model and supports only the Hybrid cache, I think it's important not to break it.
```python
mask_indexes = torch.arange(
    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
```
super nice
Hey! Super nice work, thanks a lot! Just the offset is wrong by 1; I added a detailed comment about that.
Also, as I assume you're already aware, this only works under the assumption of a single prefill step with more than 1 token, followed by decoding steps of 1 token each. This was bothering me for some time, I must say (what about context caching???), but I never took the time to change it to make it more general and foolproof. However, with Llama4, as I added prefill chunking, I had to make it work for any situation and any number of new tokens at any time. This is a nice precedent that I think we should use to now make the sliding caches more general.
However, this should probably be a separate PR; if you plan on working on it at some point, let me know!! 🤗
```diff
-offset = last_cache_position - effective_seq_len
+offset = cache_position[-1] - effective_seq_len
```
Very important detail: here it should actually be `offset = cache_position[-1] + 1 - effective_seq_len`. Or equivalently, `offset = cache_position[0] - sliding_window + 1` (perhaps more understandable, and more general if we want to extend the behavior later). The idea being that `last_cache_position` should in fact be the number of total processed tokens, not the final position (as positions start at 0). The comments were wrong, but the code was correct (it used to take the shape of the attention mask, which is the length, not the last index).
Also, note that this only works for one prefill step followed by decoding-only steps. It will fail in general with prefill chunking, or e.g. prefill caching (if the cache is already "full", i.e. we processed more than `sliding_window` tokens, and we want to do a forward with more than 1 new token, e.g. a new conversation turn). For the fully general case, see the work I did in Llama4 here, as well as the cache that goes with it here. `HybridCache` and `HybridChunkedCache` are fully equivalent in the tokens they return and the necessary mask offsets; `HybridChunkedCache` is just more general, as it can always handle an arbitrary number of input tokens and, based on its state, return the necessary past states. The only difference is then how Llama4 creates the mask from those states (chunked block vs sliding lower diagonal). But from a cache-logic point of view, they are fully equivalent (I first modified `HybridCache`, but then we decided to create a new one for now).
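To make the off-by-one concrete, a small numeric check (the numbers are arbitrary):

```python
import torch

sliding_window = 4                   # also the effective_seq_len during decoding
cache_position = torch.tensor([10])  # decoding the token at absolute position 10
effective_seq_len = sliding_window

# correct offset: the window ends at (and includes) the current position
offset = cache_position[-1] + 1 - effective_seq_len
window = torch.arange(effective_seq_len) + offset
print(window)  # tensor([ 7,  8,  9, 10])

# equivalent form from the comment above
assert offset == cache_position[0] - sliding_window + 1

# off-by-one version (missing the +1): the current position 10 is dropped
print(torch.arange(effective_seq_len) + cache_position[-1] - effective_seq_len)  # tensor([6, 7, 8, 9])
```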
> Very important detail: here it should actually be `offset = cache_position[-1] + 1 - effective_seq_len`. Or equivalently, `offset = cache_position[0] - sliding_window + 1` (perhaps more understandable, and more general if we want to extend behavior later). The idea being `last_cache_position` should in fact be the number of total processed tokens, not the final position (as it starts including 0). The comments are wrong, but the code was correct (it used to take the shape of the attention mask, which is the length, not last index).
Hehe, good point. I replaced the slicing according to the local comments and didn't double-check what was being fed into the `last_cache_position` variable 👍
> [working with prefill chunking / more than one input token]
Also a good point. Let's open a separate PR for it, to avoid bloating this PR. Having compile working on the base case is already very valuable for the community.
@zucchini-nlp Benchmarks are hard to turn into a test: different devices/versions -> different speeds, and they need multiple runs to avoid being flaky 💔 However, we should test that compilation only happens once for the entire forward pass, i.e. that there are no graph breaks nor recompilations. I'm going to explore the torch docs to see if we can add this as a test in this PR. At the moment, we are often doubting the quality of our compiled forward passes, which is not great.
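One possible shape for such a check, as a sketch only (it assumes `fullgraph=True` for graph breaks and `torch._dynamo.config.error_on_recompile` for recompilations; the toy compiled function below is not the model forward):

```python
import torch

torch._dynamo.reset()
torch._dynamo.config.error_on_recompile = True  # raise if any call triggers a recompilation

def sliced_forward(x, offset):
    # toy stand-in for a model forward using the tensor-offset slicing trick
    return x[:, torch.arange(4) + offset]

compiled_forward = torch.compile(sliced_forward, fullgraph=True)  # fullgraph=True errors on graph breaks

x = torch.randn(1, 16)
for step in range(3):
    # the changing offset is a tensor, so this should stay on a single compiled graph
    compiled_forward(x, torch.tensor([step]))
```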
@gante yeah, we can do something similar to what diffusers is trying to do (continuing yesterday's thread discussions). Fine by me, as long as we check that there are no recompilations at every step.
@Cyrilvallez off-by-1 comment addressed 👍 LMK if you'd like any further changes.
@Cyrilvallez @zucchini-nlp There are three follow-up items to this PR; each will be a separate PR.
LGTM!! 🤗 Thanks a lot, super glad to have compile back and to simplify by removing the extra arg that was introduced as well! 💛
What does this PR do?
Enables compilation on Gemma3 (and re-enables it on Gemma2 / Cohere2).
Reverts #36620
Supersedes #37433 (solves the same problem, but this PR is much cleaner)
Performance
Measured on an RTX 4090, excluding compile warmup time:
- `main` (not compiled) -> 2.39s this PR
- `main` (not compiled) -> 2.18s this PR

Tests
- `main` -> need to be revisited
- on `main`, `tests/models/gemma3/test_modeling_gemma3.py::Gemma3Vision2TextModelTest::test_eager_matches_sdpa_generate` gets fixed in this PR

Post-mortem: How did we break compile on Gemma 2?
- According to `git bisect`, compilation first "breaks" in the PR where the cache is initialized on the `meta` device (Init cache on meta device #35164). "Break" here doesn't mean "crash", but rather "becomes very slow". Curiously, this change doesn't slow down `StaticCache` + `llama` (why?), so it flew under the radar when we benchmarked before merging. Nevertheless, that specific PR has been reverted ([Cache] Don't initialize the cache on `meta` device #36543).
- The replacement solution was not `torch.compile` friendly: `forward` now has an `int` argument that is different at each forward pass at generation time, causing recompilation (reference). The changes in this PR work around this issue.
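For illustration, a minimal repro of the recompilation pattern described in the second point (toy code, not the actual modeling code):

```python
import torch

@torch.compile
def forward(hidden_states, last_cache_position: int):
    # the Python int ends up baked into guards/graph, so a value that changes at
    # every generation step can force repeated recompilations (exact behavior
    # depends on the torch version and dynamic-shape settings)
    return hidden_states[:, last_cache_position - 4 : last_cache_position]

hidden_states = torch.randn(1, 32, 8)
for pos in range(8, 12):
    forward(hidden_states, pos)

# the workaround in this PR keeps the offset as a tensor and expresses the slice
# with `torch.arange(...) + offset`, so a single compiled graph can be reused
```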