[Gemma3] compile ✨ #37447


Merged (9 commits, Apr 18, 2025)
6 changes: 1 addition & 5 deletions src/transformers/cache_utils.py
@@ -1654,9 +1654,7 @@ class HybridCache(Cache):
```
"""

# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
# ALL changes from the PR that commented the line below when reactivating it.
# is_compileable = True
is_compileable = True
Collaborator:
nice! Can we update the cache to also init the layers lazily like we do for the HybridChunked cache?

Member Author:
@ArthurZucker HybridChunkedCache only works if we don't compile the first forward pass; HybridCache works regardless of whether we compile the first forward pass or not. `torch._dynamo.mark_static_address` can't be called inside `torch.compile`, which is what lazy init does.

This means that if a user creates their own custom code with HybridChunkedCache, they can't simply compile the forward pass. If anything, HybridChunkedCache should move away from lazy init :P
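
For context, a minimal standalone sketch (assumed shapes, not the actual cache code) of why eager init plays well with a compiled prefill: the cache buffers are allocated and marked as static addresses in eager mode, before `torch.compile` runs, whereas lazy init would have to call `torch._dynamo.mark_static_address` from inside the compiled region.

```python
import torch

# Pre-allocate cache buffers eagerly (illustrative shapes: batch, heads, max_len, head_dim).
key_cache = torch.zeros(1, 8, 1024, 64)
value_cache = torch.zeros(1, 8, 1024, 64)

# Marking static addresses must happen outside any compiled region,
# which eager (non-lazy) init guarantees.
torch._dynamo.mark_static_address(key_cache)
torch._dynamo.mark_static_address(value_cache)


@torch.compile
def prefill_step(new_keys: torch.Tensor, pos: int) -> torch.Tensor:
    # Toy stand-in for an attention update writing into the static cache.
    key_cache[:, :, pos : pos + new_keys.shape[2]] = new_keys
    return key_cache[:, :, : pos + new_keys.shape[2]]


out = prefill_step(torch.randn(1, 8, 16, 64), 0)
```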

Member Author (@gante, Apr 11, 2025):
Chatted offline:

  1. lazy init is needed for TP
  2. however, lazy init is incompatible with compiling the first forward pass (prefill). lazy init + `@torch.compiler.disable()` doesn't solve it either
  3. solution: add a new flag, `lazy_init = None`. If `torch.distributed` is initialized and the flag is unset, it resolves to `True` (see the sketch below)
  4. apply this change to ALL caches -> all caches become compatible with TP, with no drawbacks for non-TP use
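
A hedged sketch of that flag resolution (the `lazy_init` name and signature are assumptions from this thread, not a final API):

```python
from typing import Optional

import torch.distributed as dist


def resolve_lazy_init(lazy_init: Optional[bool] = None) -> bool:
    """Default to lazy cache init only when torch.distributed (e.g. TP) is active."""
    if lazy_init is None:
        return dist.is_available() and dist.is_initialized()
    return lazy_init
```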


def __init__(
self,
@@ -1858,8 +1856,6 @@ class HybridChunkedCache(Cache):
```
"""

# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
# ALL changes from the PR that commented the line below when reactivating it.
is_compileable = True

def __init__(
33 changes: 11 additions & 22 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -42,6 +42,7 @@
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_cohere2 import Cohere2Config


@@ -300,6 +301,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
self.sliding_window = config.sliding_window

@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -309,7 +311,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -330,7 +331,6 @@
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
"""

if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -349,11 +349,16 @@
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
Comment on lines +357 to +360
Collaborator:
are you sure this is CUDA graph compatible?

Member Author (@gante, Apr 11, 2025):
yes, see e.g. scripts at the top of the PR header

also, see this comment explaining why :D

Member:
super nice

attention_mask = attention_mask[:, :, :, mask_indexes]

residual = hidden_states

@@ -539,6 +544,7 @@ def set_input_embeddings(self, value):

@can_return_tuple
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -550,7 +556,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -590,16 +595,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -627,7 +622,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -638,7 +632,6 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)

@@ -928,10 +921,6 @@ def prepare_inputs_for_generation(
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
46 changes: 18 additions & 28 deletions src/transformers/models/cohere2/modular_cohere2.py
@@ -23,15 +23,12 @@
from ...cache_utils import Cache, HybridCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
)
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import (
logging,
)
from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from ..cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
@@ -45,6 +42,9 @@
from ..gemma2.modeling_gemma2 import Gemma2Model


COHERE2_INPUTS_DOCSTRING = None # Will be picked up by modular


logger = logging.get_logger(__name__)


@@ -351,6 +351,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
self.sliding_window = config.sliding_window

@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -360,7 +361,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -381,7 +381,6 @@
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
"""

if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -400,11 +399,16 @@
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]
Comment on lines +405 to +411
Member Author (@gante, Apr 11, 2025):
Core change for the PR.

`attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]` requires either passing an integer through the signature to build `offset` (the previous solution, which triggers a recompilation at each forward 🚫 ) or doing data-dependent slicing with `offset` as a tensor (which crashes compile 🚫 )

The solution is to:

  1. build an arange from shapes ✅ (we can use shapes to create compile-compatible arrays on the fly, as opposed to using arbitrary tensors to create tensors)
  2. add some tensor (offset) to a tensor (fixed-shape array) ✅
  3. slice a tensor (attention mask) with another tensor (offset modified fixed-shape array) ✅

(Note: at first I tried `torch.roll` + fixed-shape slicing, but `torch.roll` doesn't accept a tensor for `shifts=offset`; `shifts` has to be an integer 😢 )
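
A standalone check (assumed shapes, not the model's real tensors) that the arange-based gather selects the same window as the data-dependent slice it replaces:

```python
import torch

attention_mask = torch.randn(1, 1, 1, 4096)  # hypothetical 4D mask
effective_seq_len = 512
cache_position = torch.arange(2990, 3001)  # decoding well beyond the sliding window

offset = max(0, int(cache_position[-1]) - effective_seq_len + 1)

# Original, data-dependent slice (what breaks dynamo when `offset` is a tensor).
reference = attention_mask[:, :, :, offset : offset + effective_seq_len]

# Compile-friendly version: build indices from shapes, shift by `offset`, gather.
mask_indexes = torch.arange(min(effective_seq_len, attention_mask.shape[-1]))
mask_indexes += offset
candidate = attention_mask[:, :, :, mask_indexes]

assert torch.equal(reference, candidate)
```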


residual = hidden_states

@@ -452,6 +456,9 @@ def __init__(self, config: Cohere2Config):
self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
self.rotary_emb = Cohere2RotaryEmbedding(config=config)

@can_return_tuple
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -463,7 +470,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -503,16 +509,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -540,7 +536,6 @@
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -551,7 +546,6 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)

@@ -625,10 +619,6 @@ def prepare_inputs_for_generation(
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
31 changes: 11 additions & 20 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -47,6 +47,7 @@
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_gemma2 import Gemma2Config


@@ -285,6 +286,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window

@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -295,7 +297,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -314,11 +315,16 @@
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(
min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
)
mask_indexes += offset
attention_mask = attention_mask[:, :, :, mask_indexes]

residual = hidden_states

@@ -542,6 +548,7 @@ def set_input_embeddings(self, value):

@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -553,7 +560,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -594,16 +600,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -639,7 +635,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -651,7 +646,6 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)

@@ -922,9 +916,6 @@ def prepare_inputs_for_generation(
**kwargs,
)

# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
