
[Gemma3] compile ✨ #37447

Open · wants to merge 9 commits into main

Conversation

@gante (Member) commented Apr 11, 2025

What does this PR do?

Enables compilation on Gemma3 (and re-enables it on Gemma2 / Cohere2).

Reverts #36620
Supersedes #37433 (solves the same problem, but this PR is much cleaner)

Performance

Measured on an RTX 4090, excluding compile warmup time.

Tests

  • slow gemma 2 tests (9 failing tests from main -> need to be revisited)
  • slow gemma 3 tests (2 failing tests from main, tests/models/gemma3/test_modeling_gemma3.py::Gemma3Vision2TextModelTest::test_eager_matches_sdpa_generate gets fixed in this PR)

Post-mortem: How did we break compile on Gemma 2?

  1. Doing git bisect, compilation first "breaks" in the PR where the cache is initialized on the meta device (Init cache on meta device #35164). "Break" here doesn't mean "crash", but rather "becomes very slow". Curiously, this change doesn't slow down StaticCache + llama (why?), so it flew under the radar when we benchmarked before merging. Nevertheless, that specific PR has already been reverted ([Cache] Don't initialize the cache on meta device #36543).
  2. Along the way, we corrected how sliding-window attention works, by slicing the attention mask correctly (Fix mask slicing for models with HybridCache #35681). However, that solution is not torch.compile friendly: forward now has an int argument that is different at each forward pass at generation time, causing recompilation (reference). The changes in this PR work around this issue.

Comment on lines +402 to +411
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
@gante (Member Author) Apr 11, 2025

Core change for the PR.

attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] requires either passing an integer in the signature to build offset (the previous solution, which triggers recompilation at each forward 🚫) or doing data-dependent slicing with offset as a tensor (which crashes compile 🚫)

The solution (see the sketch after this list) is to:

  1. build an arange from shapes ✅ (we can use shapes to create compile-compatible arrays on the fly, as opposed to using arbitrary tensors to create tensors)
  2. add some tensor (offset) to a tensor (fixed-shape array) ✅
  3. slice a tensor (attention mask) with another tensor (offset modified fixed-shape array) ✅

(Note: at first I tried torch.roll + fixed-shape slicing, but torch.roll doesn't accept a tensor for shifts; it has to be an integer 😢)
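
To make this concrete, here is a minimal standalone sketch of the three steps; the shapes and values are illustrative, and slice_mask is a stand-in for the modeling code, not the actual Gemma3 implementation.

import torch

batch, heads, q_len, kv_len = 1, 1, 1, 4096
effective_seq_len = 512  # sliding window size (illustrative)
attention_mask = torch.zeros(batch, heads, q_len, kv_len)

def slice_mask(attention_mask, offset, effective_seq_len):
    # 1. arange built from shapes only (compile friendly)
    mask_indexes = torch.arange(
        min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
    )
    # 2. tensor + tensor
    mask_indexes += offset
    # 3. tensor-indexed slice, equivalent to attention_mask[:, :, :, offset : offset + effective_seq_len]
    return attention_mask[:, :, :, mask_indexes]

compiled_slice = torch.compile(slice_mask, fullgraph=True)
for step in (100, 101, 102):      # offset changes at every decoding step...
    offset = torch.tensor(step)   # ...but stays a tensor, so no per-step recompilation
    out = compiled_slice(attention_mask, offset, effective_seq_len)
print(out.shape)                  # torch.Size([1, 1, 1, 512])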


Comment on lines +352 to +353
mask_indexes = torch.arange(
    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
Collaborator

are you sure this is cuda graph compatible?

@gante (Member Author) Apr 11, 2025

yes, see e.g. scripts at the top of the PR header

also, see this comment explaining why :D

Member

super nice

- # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
- # ALL changes from the PR that commented the line below when reactivating it.
- # is_compileable = True
+ is_compileable = True
Collaborator

nice! Can we update the cache to also init the layers lazily like we do for the HybridChunked cache?

Member Author

@ArthurZucker HybridChunkedCache only works if we don't compile the first forward pass; HybridCache works regardless of whether we compile the first forward pass or not. torch._dynamo.mark_static_address can't be called inside torch.compile, which is what lazy init ends up doing.

This means that if a user creates their own custom code with HybridChunkedCache, they can't simply compile the forward pass. If anything, HybridChunkedCache should move away from lazy init :P

@gante (Member Author) Apr 11, 2025

Chatted offline:

  1. lazy init is needed for TP
  2. however, lazy init is incompatible with compiling the first forward pass (prefill). lazy init + @torch.compiler.disable() doesn't solve it either
  3. solution: add a new flag lazy_init = None. If torch.distributed is initialized and the flag is unset, then it will be True (see the sketch after this list).
  4. Apply this change to ALL caches -> ALL caches compatible with TP + no non-TP drawbacks
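
For illustration, here is a rough sketch of what point 3 could look like; SketchStaticCache and its arguments are hypothetical and only show the intent, not the actual transformers cache API.

import torch
import torch.distributed as dist

class SketchStaticCache:
    def __init__(self, num_layers, shape, lazy_init=None):
        # If the flag is unset, only fall back to lazy init when TP (torch.distributed) is active.
        if lazy_init is None:
            lazy_init = dist.is_available() and dist.is_initialized()
        self.lazy_init = lazy_init
        self.key_cache, self.value_cache = [], []
        if not self.lazy_init:
            # Eager allocation: the tensors exist before any compiled region runs, so
            # torch._dynamo.mark_static_address is called outside torch.compile.
            for _ in range(num_layers):
                k, v = torch.zeros(shape), torch.zeros(shape)
                torch._dynamo.mark_static_address(k)
                torch._dynamo.mark_static_address(v)
                self.key_cache.append(k)
                self.value_cache.append(v)

# Non-TP path: eager allocation, so compiling the first forward pass (prefill) stays possible.
cache = SketchStaticCache(num_layers=2, shape=(1, 4, 1024, 64))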

@gante gante marked this pull request as ready for review April 11, 2025 17:49
@zucchini-nlp (Member) left a comment

Nice, thanks for the detailed investigation! Do you think we can add a slow test to compare generation time with compile, or include HybridCache in the benchmarks board, so we don't accidentally introduce graph breaks? Given that Gemma3 is a high-usage model and supports only the Hybrid cache, I think it's important not to break it


@Cyrilvallez (Member) left a comment

Hey! Super nice work, thanks a lot! Just the offset is off by 1; I added a detailed comment about that.

Also, as I assume you're already aware, this only works under the assumption of a single prefill step with more than 1 token, followed by decoding steps of 1 token each. This has been bothering me for some time, I must say (what about context caching???), but I never took the time to change it to make it more general and foolproof. However, with Llama4, since I added prefill chunking, I had to make it work for any situation and any number of new tokens at any time. That is a nice precedent that I think we should use to make the sliding caches more general.
However, this should probably be a separate PR; if you plan on working on it at some point, let me know!! 🤗

Comment on lines 347 to 345
- offset = last_cache_position - effective_seq_len
+ offset = cache_position[-1] - effective_seq_len
@Cyrilvallez (Member) Apr 15, 2025

Very important detail: here it should actually be offset = cache_position[-1] + 1 - effective_seq_len. Or, equivalently, offset = cache_position[0] - sliding_window + 1 (perhaps more understandable, and more general if we want to extend the behavior later). The idea is that last_cache_position should in fact be the number of total processed tokens, not the final position (since positions start at 0). The comments were wrong, but the code was correct (it used to take the shape of the attention mask, which is the length, not the last index).

Also, note that this only works for one prefill step followed by decoding steps only. It will fail in general with prefill chunking, or e.g. prefill caching (if the cache is already "full", i.e. we processed more than sliding_window tokens, and we want to do a forward with more than 1 new token, e.g. a new conversation turn). For the fully general case, see the work I did in Llama4 here, as well as the cache that goes with it here. HybridCache and HybridChunkedCache are fully equivalent in the tokens they return and the necessary mask offsets; HybridChunked is just more general, as it can always handle an arbitrary number of input tokens and, based on its state, return the necessary past states. The only difference is then how Llama creates the mask from those states (chunked block vs sliding lower diagonal). But from a cache-logic point of view, they are fully equivalent (I first modified HybridCache, but then we decided to create a new one for now).
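
To make the off-by-one concrete, here is a small numeric check with illustrative numbers (single-token decode step, sliding window of 4):

import torch

sliding_window = 4
effective_seq_len = sliding_window       # during decoding
cache_position = torch.tensor([10])      # the new token sits at absolute position 10

offset_old = cache_position[-1] - effective_seq_len          # 6 -> mask columns 6..9, misses position 10
offset_fix = cache_position[-1] + 1 - effective_seq_len      # 7 -> mask columns 7..10, includes the new token
assert offset_fix == cache_position[0] - sliding_window + 1  # the equivalent formulation
print(offset_old.item(), offset_fix.item())                  # 6 7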

@gante (Member Author)

> Very important detail: here it should actually be offset = cache_position[-1] + 1 - effective_seq_len. Or, equivalently, offset = cache_position[0] - sliding_window + 1 [...]

hehe good point, I replaced the slicing according to the local comments, and didn't double check what was being fed into the last_cache_position variable 👍

> [working with prefill chunking / more than one input token]

Also a good point. Let's open a separate PR for it, to avoid bloating this PR. Having compile working on the base case is already very valuable for the community.

@gante (Member Author) commented Apr 17, 2025

> Do you think we can add a slow test to compare generation time with compile, or include HybridCache in the benchmarks board, so we don't accidentally introduce graph breaks? Given that Gemma3 is a high-usage model and supports only the Hybrid cache, I think it's important not to break it

@zucchini-nlp Benchmarks are hard to turn into a test: different devices/versions mean different speeds, and it needs multiple runs to avoid being flaky 💔 However, we should test that compilation only happens once for the entire forward pass, i.e. that there are no graph breaks and no recompilations. I'm going to explore the torch docs to see if we can add this as a test in this PR. At the moment we are often doubting the quality of our compiled forward passes, which is not great
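
Something along these lines could serve as that check; torch._dynamo.config.error_on_recompile and fullgraph=True are existing torch knobs, but decode_step here is only a toy stand-in for a compiled decoding forward pass, not the actual mixin test.

import torch

torch._dynamo.reset()
torch._dynamo.config.error_on_recompile = True   # any recompilation raises instead of passing silently

def decode_step(hidden, cache, cache_position):
    # static-shaped cache update, mimicking a compiled decoding forward pass
    cache.index_copy_(1, cache_position, hidden)
    return cache.sum(dim=1)

compiled_step = torch.compile(decode_step, fullgraph=True)  # fullgraph=True turns graph breaks into errors

cache = torch.zeros(2, 16, 8)
for step in range(4):
    hidden = torch.randn(2, 1, 8)
    compiled_step(hidden, cache, torch.tensor([step]))
# Reaching this point means no graph breaks and no recompilation across the decode steps.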

@zucchini-nlp (Member) commented Apr 17, 2025

@gante yeah, we can do something similar to what diffusers is trying to do (in continuation of yesterday's thread discussions). Fine by me, as long as we check that there are no recompilations at every step

@gante (Member Author) commented Apr 17, 2025

@Cyrilvallez off-by-1 comment addressed 👍 LMK if you'd like any further changes

@Cyrilvallez @zucchini-nlp There are three follow-up items to this PR, where each will be a separate PR:

  1. Mixin test update to confirm that no recompilations, dynamic shapes, ... are happening. Related comment. [I've decided not to include it in this PR, since it needs to touch other models to fix failures];
  2. TP + compile compatible caches. Related comment
  3. Gemma 3 + Prefill chunking + more than one new token at generation time. Related comment

@Cyrilvallez (Member) left a comment

LGTM!! 🤗 Thanks a lot, super glad to have compile back and to simplify by removing the extra arg that was introduced as well! 💛
