@@ -296,6 +296,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
         self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
         self.sliding_window = config.sliding_window
 
+    @deprecate_kwarg("last_cache_position", version="4.53.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -305,7 +306,6 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        last_cache_position: int = 0,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -326,7 +326,6 @@ def forward(
                 (see `past_key_values`).
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence
-            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
         """
 
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
@@ -345,11 +344,16 @@ def forward(
             )
             attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
             # In case we are beyond the sliding window, we need to correctly offset the mask slicing
-            # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
-            offset = last_cache_position - effective_seq_len
+            offset = cache_position[-1] - effective_seq_len
             # Should only be used when beyond the sliding window (i.e. offset > 0)
             offset = max(0, offset)
-            attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+            # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+            # but without data-dependent slicing (i.e. torch.compile friendly)
+            mask_indexes = torch.arange(
+                min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+            )
+            mask_indexes += offset
+            attention_mask = attention_mask[:, :, :, mask_indexes]
 
         residual = hidden_states
 
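The hunk above replaces a data-dependent slice (`offset : offset + effective_seq_len`, with `offset` derived from a tensor) by indexing with a `torch.arange` that is shifted by the tensor offset, which dynamo can trace without a graph break. A self-contained sketch of the equivalence, using made-up shapes and a made-up `cache_position` in place of the real model state:

```python
import torch

# Stand-in values: a 4D additive attention mask, a sliding-window length, and a
# cache position that is already past the window. All numbers are made up.
batch, heads, q_len, kv_len = 1, 1, 4, 16
effective_seq_len = 8
attention_mask = torch.randn(batch, heads, q_len, kv_len)
cache_position = torch.arange(10, 10 + q_len)

# Old path: a Python-int offset and a data-dependent slice.
offset_int = max(0, int(cache_position[-1]) - effective_seq_len)
sliced = attention_mask[:, :, :, offset_int : offset_int + effective_seq_len]

# New path: keep the offset as a tensor and index with a fixed-size arange.
offset = cache_position[-1] - effective_seq_len
offset = max(0, offset)
mask_indexes = torch.arange(min(effective_seq_len, attention_mask.shape[-1]))
mask_indexes = mask_indexes + offset
gathered = attention_mask[:, :, :, mask_indexes]

assert torch.equal(sliced, gathered)  # same values, compile-friendly indexing
```

The `min(effective_seq_len, attention_mask.shape[-1])` bound matters during prefill, when the mask can be narrower than the sliding window; in that case the arange simply covers the whole mask width and the clamped offset stays at zero.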
@@ -428,73 +432,6 @@ def _init_weights(self, module):
             module.weight.data[module.padding_idx].zero_()
 
 
-COHERE2_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-            it.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
-            `past_key_values`).
-
-            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
-            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
-            information on the default strategy.
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.n_positions - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
-        past_key_values (`Cache`, *optional*):
-            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
-            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
-            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
-            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
-
-            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
-            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
-            of shape `(batch_size, sequence_length)`.
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
-            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
-            the complete sequence length.
-"""
-
-
 @add_start_docstrings(
     "The bare Cohere2 Model outputting raw hidden-states without any specific head on top.",
     COHERE2_START_DOCSTRING,
@@ -528,8 +465,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
-    @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
+    @deprecate_kwarg("last_cache_position", version="4.53.0")
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -541,7 +477,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -581,16 +516,6 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
-        # (retrieving the same value from `cache_position` later on would crash dynamo)
-        if last_cache_position is None:
-            last_cache_position = 0
-            if attention_mask is not None:
-                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
-                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
-                last_cache_position = (
-                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
-                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
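The bookkeeping removed above existed only because reading `cache_position[-1].item()` inside a traced region materializes a Python int, which dynamo cannot capture by default (hence the old comment about breaking tracing). Below is a small standalone check of that behavior; it assumes `torch._dynamo.config.capture_scalar_outputs` is left at its default, and the exact exception type may vary across torch versions.

```python
import torch


def old_bookkeeping(cache_position: torch.Tensor) -> int:
    # `.item()` turns a tensor into a Python int: data-dependent, graph-breaking.
    return cache_position[-1].item()


def new_bookkeeping(cache_position: torch.Tensor) -> torch.Tensor:
    # Keeping the value as a 0-dim tensor stays traceable.
    return cache_position[-1]


cache_position = torch.arange(5, 12)

# Eager mode: both return the same logical value.
assert old_bookkeeping(cache_position) == int(new_bookkeeping(cache_position))

# Under torch.compile with fullgraph=True, the tensor variant compiles cleanly...
compiled_new = torch.compile(new_bookkeeping, fullgraph=True)
print(compiled_new(cache_position))

# ...while the `.item()` variant typically refuses to trace as a single graph.
try:
    compiled_old = torch.compile(old_bookkeeping, fullgraph=True)
    print(compiled_old(cache_position))
except Exception as exc:  # usually torch._dynamo.exc.Unsupported
    print(f"graph break on .item(): {type(exc).__name__}")
```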
@@ -618,7 +543,6 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
-                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -629,7 +553,6 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
-                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )
 
@@ -748,6 +671,73 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
 
 
+COHERE2_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
 class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
@@ -916,10 +906,6 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
-        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
-        # (retrieving the same value from `cache_position` later on would crash dynamo)
-        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
-
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2