[Gemma3] compile ✨ #37447
Changes from all commits: a5eb7ce, e397840, ae15edf, 4453f58, 690a168, 7085bc3, 32ac69b, 9638ad7, 1930251
@@ -42,6 +42,7 @@
     logging,
     replace_return_docstrings,
 )
+from ...utils.deprecation import deprecate_kwarg
 from .configuration_cohere2 import Cohere2Config


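For context on the import added above: a minimal sketch of how a kwarg-deprecation decorator of this kind can work. This is an illustrative stand-in (the helper name `deprecate_kwarg_sketch` and its exact behavior are assumptions), not the actual `transformers.utils.deprecation.deprecate_kwarg` implementation.

```python
import functools
import warnings


def deprecate_kwarg_sketch(old_name: str, version: str):
    """Illustrative stand-in for a `deprecate_kwarg`-style decorator.

    If a caller still passes `old_name`, it is dropped before the wrapped
    function runs, and a FutureWarning announces removal in `version`.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                kwargs.pop(old_name)  # drop the deprecated kwarg so `func` never sees it
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in version {version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecate_kwarg_sketch("last_cache_position", version="4.53.0")
def forward(hidden_states, cache_position=None):
    return hidden_states


# Passing the removed kwarg still works, but warns instead of raising a TypeError.
forward("hs", last_cache_position=3)
```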
@@ -300,6 +301,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
         self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
         self.sliding_window = config.sliding_window

+    @deprecate_kwarg("last_cache_position", version="4.53.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -309,7 +311,6 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        last_cache_position: int = 0,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -330,7 +331,6 @@ def forward(
                 (see `past_key_values`).
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence
-            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
         """

         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
@@ -349,11 +349,16 @@
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                 # In case we are beyond the sliding window, we need to correctly offset the mask slicing
-                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
-                offset = last_cache_position - effective_seq_len
+                offset = cache_position[-1] - effective_seq_len + 1
                 # Should only be used when beyond the sliding window (i.e. offset > 0)
                 offset = max(0, offset)
-                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+                # but without data-dependent slicing (i.e. torch.compile friendly)
+                mask_indexes = torch.arange(
+                    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+                )
+                mask_indexes += offset
Comment on lines +357 to +360

are you sure this is cuda graph compatible?

yes, see e.g. the scripts at the top of the PR header; also, see this comment explaining why :D

super nice
+                attention_mask = attention_mask[:, :, :, mask_indexes]

         residual = hidden_states

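A minimal standalone sketch of the indexing trick in the hunk above (toy shapes and values, not the model code): selecting the window with a fixed-size `torch.arange` plus an offset returns the same slice as the data-dependent `offset : offset + effective_seq_len` slicing, while keeping the output shape independent of the offset value, which is what makes it compile-friendly per the PR discussion.

```python
import torch

attention_mask = torch.randn(1, 1, 4, 32)  # (batch, heads, q_len, kv_len), toy values
effective_seq_len = 8


def slice_window(mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    offset = max(0, cache_position[-1] - effective_seq_len + 1)
    # Data-dependent slice: the start index depends on a tensor value, which can
    # break the graph or force recompilation under torch.compile (per the PR discussion).
    return mask[:, :, :, offset : offset + effective_seq_len]


def index_window(mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    offset = max(0, cache_position[-1] - effective_seq_len + 1)
    # Fixed-size index tensor: the shape is known ahead of time, only the values
    # change, so shapes stay static across decoding steps.
    mask_indexes = torch.arange(min(effective_seq_len, mask.shape[-1]), device=mask.device)
    mask_indexes += offset
    return mask[:, :, :, mask_indexes]


cache_position = torch.arange(20, 24)  # decoding past the sliding window
assert torch.equal(
    slice_window(attention_mask, cache_position),
    index_window(attention_mask, cache_position),
)
```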
@@ -539,6 +544,7 @@ def set_input_embeddings(self, value):

     @can_return_tuple
     @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
+    @deprecate_kwarg("last_cache_position", version="4.53.0")
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -550,7 +556,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -590,16 +595,6 @@
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
-        # (retrieving the same value from `cache_position` later on would crash dynamo)
-        if last_cache_position is None:
-            last_cache_position = 0
-            if attention_mask is not None:
-                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
-                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
-                last_cache_position = (
-                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
-                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
@@ -627,7 +622,6 @@
                     output_attentions,
                     use_cache,
                     cache_position,
-                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -638,7 +632,6 @@
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )

@@ -928,10 +921,6 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

-        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
-        # (retrieving the same value from `cache_position` later on would crash dynamo)
-        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
-
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2
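The kwarg can be dropped from `prepare_inputs_for_generation` because, for the 2D masks that `generate` builds, the value it carried is recoverable from `cache_position` inside the layer. A small sketch of that invariant (toy numbers mimicking a decoding step; not the actual generation code):

```python
import torch

past_seen_tokens = 10  # tokens already in the cache
new_tokens = 1         # tokens in the current forward pass (1 during decoding)

# `generate` builds cache_position as the absolute positions of the new tokens...
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
# ...and a 2D attention mask covering every token seen so far.
attention_mask = torch.ones(1, past_seen_tokens + new_tokens, dtype=torch.long)

# What the removed code used to pass as `last_cache_position`:
last_cache_position = attention_mask.shape[-1]
# The same quantity is available from cache_position itself, so the extra kwarg is redundant.
assert last_cache_position == int(cache_position[-1]) + 1
```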
@@ -23,15 +23,12 @@
 from ...cache_utils import Cache, HybridCache
 from ...configuration_utils import PretrainedConfig
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import (
-    BaseModelOutputWithPast,
-)
+from ...modeling_outputs import BaseModelOutputWithPast
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
-from ...utils import (
-    logging,
-)
+from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
 from ..cohere.modeling_cohere import (
     CohereAttention,
     CohereDecoderLayer,
@@ -45,6 +42,9 @@
 from ..gemma2.modeling_gemma2 import Gemma2Model


+COHERE2_INPUTS_DOCSTRING = None  # Will be picked up by modular
+
+
 logger = logging.get_logger(__name__)


@@ -351,6 +351,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
         self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
         self.sliding_window = config.sliding_window

+    @deprecate_kwarg("last_cache_position", version="4.53.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -360,7 +361,6 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        last_cache_position: int = 0,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -381,7 +381,6 @@ def forward(
                 (see `past_key_values`).
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence
-            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
         """

         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
@@ -400,11 +399,16 @@
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                 # In case we are beyond the sliding window, we need to correctly offset the mask slicing
-                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
-                offset = last_cache_position - effective_seq_len
+                offset = cache_position[-1] - effective_seq_len + 1
                 # Should only be used when beyond the sliding window (i.e. offset > 0)
                 offset = max(0, offset)
-                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+                # but without data-dependent slicing (i.e. torch.compile friendly)
+                mask_indexes = torch.arange(
+                    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+                )
+                mask_indexes += offset
+                attention_mask = attention_mask[:, :, :, mask_indexes]
Comment on lines +405 to +411

Core change for the PR.

The solution is to:

(Note: at first I tried

         residual = hidden_states

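An earlier thread asked whether this indexing pattern is cuda-graph compatible; the author points at the benchmark scripts in the PR description. As a rough, self-contained illustration of the kind of check those scripts perform (this is not the PR's script; shapes, the `window_mask` helper, and the use of `torch.clamp` here are assumptions made for the sketch), one can compile the windowing logic with `mode="reduce-overhead"`, which records cuda graphs when a GPU is available, and run successive decode steps:

```python
import torch

EFFECTIVE_SEQ_LEN = 8


def window_mask(mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    # Same idea as the diff above: fixed-size arange + offset instead of a data-dependent slice.
    offset = torch.clamp(cache_position[-1] - EFFECTIVE_SEQ_LEN + 1, min=0)
    idx = torch.arange(min(EFFECTIVE_SEQ_LEN, mask.shape[-1]), device=mask.device) + offset
    return mask[:, :, :, idx]


device = "cuda" if torch.cuda.is_available() else "cpu"
# `mode="reduce-overhead"` records cuda graphs on GPU; on CPU it still compiles,
# just without cuda-graph replay.
compiled = torch.compile(window_mask, mode="reduce-overhead", fullgraph=True)

mask = torch.zeros(1, 1, 1, 32, device=device)
for step in range(16, 32):  # simulate successive decode steps past the sliding window
    cache_position = torch.tensor([step], device=device)
    out = compiled(mask, cache_position)
    assert out.shape[-1] == EFFECTIVE_SEQ_LEN  # static output shape at every step
```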
@@ -452,6 +456,9 @@ def __init__(self, config: Cohere2Config):
         self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
         self.rotary_emb = Cohere2RotaryEmbedding(config=config)

+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
+    @deprecate_kwarg("last_cache_position", version="4.53.0")
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -463,7 +470,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -503,16 +509,6 @@
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
-        # (retrieving the same value from `cache_position` later on would crash dynamo)
-        if last_cache_position is None:
-            last_cache_position = 0
-            if attention_mask is not None:
-                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
-                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
-                last_cache_position = (
-                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
-                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
@@ -540,7 +536,6 @@
                     output_attentions,
                     use_cache,
                     cache_position,
-                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -551,7 +546,6 @@
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )

@@ -625,10 +619,6 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

-        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
-        # (retrieving the same value from `cache_position` later on would crash dynamo)
-        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
-
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2
nice! Can we update the cache to also init the layers lazily like we do for HybridChunkedCache?

@ArthurZucker HybridChunkedCache only works if we don't compile the first forward pass; HybridCache works regardless of whether we compile the first forward pass or not. torch._dynamo.mark_static_address can't be called inside torch.compile, which lazy init does.

This means that if a user creates their own custom code with HybridChunkedCache, they can't simply compile the forward pass. If anything, HybridChunkedCache should move away from lazy init :P
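A sketch of the eager pre-allocation pattern this comment alludes to (toy shapes; not the actual HybridCache implementation), assuming a recent PyTorch where `torch._dynamo.mark_static_address` is available:

```python
import torch
import torch._dynamo

num_layers, batch, heads, max_seq, head_dim = 2, 1, 4, 128, 64

key_cache, value_cache = [], []
for _ in range(num_layers):
    k = torch.zeros(batch, heads, max_seq, head_dim)
    v = torch.zeros(batch, heads, max_seq, head_dim)
    # Marking the buffers as static must happen here, in eager mode, so that cuda
    # graphs can reuse their addresses. A cache that only allocates these tensors
    # lazily inside the first (compiled) forward pass cannot do this, which is the
    # objection to lazy init raised above.
    torch._dynamo.mark_static_address(k)
    torch._dynamo.mark_static_address(v)
    key_cache.append(k)
    value_cache.append(v)
```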
Chatted offline:
- @torch.compiler.disable() doesn't solve it either
- lazy_init = None: if torch.distributed is initialized and the flag is unset, then it will be True.