
Commit aa907f2

Commit message: working

1 parent 7e9b57c · commit aa907f2

8 files changed: +268 -354 lines changed

src/transformers/cache_utils.py

Lines changed: 1 addition & 5 deletions
@@ -1652,9 +1652,7 @@ class HybridCache(Cache):
     ```
     """

-    # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
-    # ALL changes from the PR that commented the line below when reactivating it.
-    # is_compileable = True
+    is_compileable = True

     def __init__(
         self,
@@ -1856,8 +1854,6 @@ class HybridChunkedCache(Cache):
     ```
     """

-    # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
-    # ALL changes from the PR that commented the line below when reactivating it.
     is_compileable = True

     def __init__(
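
Net effect of the two hunks above: `HybridCache` advertises `is_compileable = True` again (for `HybridChunkedCache` only the stale TODO comment is dropped; its flag was already set). This class attribute marks the cache as safe to use under `torch.compile`. A minimal sanity check, assuming a transformers checkout that contains this commit:

    # Hypothetical check, not part of the diff
    from transformers.cache_utils import HybridCache, HybridChunkedCache

    assert HybridCache.is_compileable          # re-enabled by this commit
    assert HybridChunkedCache.is_compileable   # was already True; only the TODO comment is removed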

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 77 additions & 91 deletions
@@ -296,6 +296,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
         self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
         self.sliding_window = config.sliding_window

+    @deprecate_kwarg("last_cache_position", version="4.53.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -305,7 +306,6 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        last_cache_position: int = 0,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -326,7 +326,6 @@ def forward(
                 (see `past_key_values`).
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence
-            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
         """

         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
@@ -345,11 +344,16 @@ def forward(
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                 # In case we are beyond the sliding window, we need to correctly offset the mask slicing
-                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
-                offset = last_cache_position - effective_seq_len
+                offset = cache_position[-1] - effective_seq_len
                 # Should only be used when beyond the sliding window (i.e. offset > 0)
                 offset = max(0, offset)
-                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+                # but without data-dependent slicing (i.e. torch.compile friendly)
+                mask_indexes = torch.arange(
+                    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+                )
+                mask_indexes += offset
+                attention_mask = attention_mask[:, :, :, mask_indexes]

         residual = hidden_states
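
The last hunk above is the heart of the change: instead of slicing the mask with a Python integer (which is why `last_cache_position` had to be threaded through every call to keep dynamo happy), the offset now stays a tensor and is applied through a `torch.arange` index. A self-contained sketch of the equivalence, with illustrative shapes that are not taken from the diff:

    import torch

    # Illustrative setup: a 32-slot mask, a sliding window of 16, current position 20.
    attention_mask = torch.randn(1, 1, 1, 32)
    effective_seq_len = 16
    cache_position = torch.tensor([20])

    offset = cache_position[-1] - effective_seq_len
    offset = max(0, offset)  # only offset once we are beyond the sliding window

    # Old approach: a data-dependent slice, which needs a Python int under dynamo.
    sliced = attention_mask[:, :, :, offset : offset + effective_seq_len]

    # New approach in this commit: index with a tensor built by torch.arange,
    # which torch.compile can trace without a graph break.
    mask_indexes = torch.arange(min(effective_seq_len, attention_mask.shape[-1]))
    mask_indexes += offset
    gathered = attention_mask[:, :, :, mask_indexes]

    assert torch.equal(sliced, gathered)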

@@ -428,73 +432,6 @@ def _init_weights(self, module):
                 module.weight.data[module.padding_idx].zero_()


-COHERE2_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-            it.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
-            `past_key_values`).
-
-            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
-            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
-            information on the default strategy.
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.n_positions - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
-        past_key_values (`Cache`, *optional*):
-            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
-            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
-            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
-            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
-
-            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
-            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
-            of shape `(batch_size, sequence_length)`.
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
-            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
-            the complete sequence length.
-"""
-
-
 @add_start_docstrings(
     "The bare Cohere2 Model outputting raw hidden-states without any specific head on top.",
     COHERE2_START_DOCSTRING,
@@ -528,8 +465,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @can_return_tuple
-    @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
+    @deprecate_kwarg("last_cache_position", version="4.53.0")
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -541,7 +477,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
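
Both `forward` methods now carry `@deprecate_kwarg("last_cache_position", version="4.53.0")` instead of accepting the argument, so existing callers keep working until 4.53.0 while the value itself is ignored. A hand-written sketch of that decorator pattern, for illustration only (the real helper ships with transformers and is more featureful):

    import functools
    import warnings

    def deprecate_kwarg(old_name: str, version: str):
        """Accept, warn about, and drop a keyword argument the wrapped function no longer uses."""
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if old_name in kwargs:
                    kwargs.pop(old_name)
                    warnings.warn(
                        f"`{old_name}` is deprecated and will be removed in v{version}.",
                        FutureWarning,
                    )
                return func(*args, **kwargs)
            return wrapper
        return decorator

    @deprecate_kwarg("last_cache_position", version="4.53.0")
    def forward(hidden_states, cache_position=None):
        return hidden_states

    forward("hidden", last_cache_position=7)  # warns, then runs without the kwarg
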
@@ -581,16 +516,6 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
-        # (retrieving the same value from `cache_position` later on would crash dynamo)
-        if last_cache_position is None:
-            last_cache_position = 0
-            if attention_mask is not None:
-                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
-                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
-                last_cache_position = (
-                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
-                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
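
The deleted fallback relied on `cache_position[-1].item()`, and pulling a Python scalar out of a tensor is exactly what dynamo cannot trace through (it graph-breaks, and errors under `fullgraph=True` unless scalar capture is enabled). With the arange-based mask indexing earlier in this commit, the integer is simply never needed. A toy comparison, not taken from the diff:

    import torch

    def slice_with_item(x, window):
        # .item() converts to a Python int: graph break under torch.compile
        offset = max(0, x.argmax().item() - window)
        return x[offset : offset + window]

    def slice_with_arange(x, window):
        # same result, but every intermediate stays a tensor (the trick this commit uses)
        offset = torch.clamp(x.argmax() - window, min=0)
        return x[torch.arange(window) + offset]

    x = torch.tensor([0.0, 1.0, 5.0, 2.0, 3.0])
    assert torch.equal(slice_with_item(x, 2), slice_with_arange(x, 2))

    # only the tensor-only version compiles cleanly with fullgraph=True
    compiled = torch.compile(slice_with_arange, fullgraph=True)
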
@@ -618,7 +543,6 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
-                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -629,7 +553,6 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )

@@ -748,6 +671,73 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


+COHERE2_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
 class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
@@ -916,10 +906,6 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

-        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
-        # (retrieving the same value from `cache_position` later on would crash dynamo)
-        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
-
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2
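
With `last_cache_position` gone from `prepare_inputs_for_generation`, the decode step no longer threads a Python int through every layer call, which together with the cache change is what lets hybrid-cache models run their forward pass under `torch.compile`. A hedged end-to-end sketch of the intended usage (checkpoint name, dtype and compile options are illustrative assumptions, not part of this commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "CohereForAI/c4ai-command-r7b-12-2024"  # assumed Cohere2 checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

    # Compile the forward pass; with HybridCache marked compileable again, generate
    # can reuse the compiled graph at each decoding step instead of falling back to eager.
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

    inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=16, cache_implementation="hybrid")
    print(tokenizer.decode(output[0], skip_special_tokens=True))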
