Commit c06e551

Merge branch 'main' into patch-1

2 parents aa430a7 + bf41e54
File tree: 1 file changed (+2 -2 lines)

src/transformers/modeling_flash_attention_utils.py (+2 -2)
@@ -280,7 +280,7 @@ def _flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
-    attention_mask: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
     query_length: int,
     is_causal: bool,
     dropout: float = 0.0,
@@ -308,7 +308,7 @@ def _flash_attention_forward(
             Input key states to be passed to Flash Attention API
         value_states (`torch.Tensor`):
            Input value states to be passed to Flash Attention API
-        attention_mask (`torch.Tensor`):
+        attention_mask (`torch.Tensor`, *optional*):
             The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
             position of padding tokens and 1 for the position of non-padding tokens.
         dropout (`float`):
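
For context, a minimal sketch of the padding-mask convention the docstring describes and of what the Optional annotation now permits. It uses only plain torch; the sequence lengths below are made up for illustration and are not part of the commit.

import torch

# Padding mask as described in the docstring: shape (batch_size, seq_len),
# 1 for real tokens, 0 for padding. The lengths here are illustrative only.
seq_len = 6
lengths = torch.tensor([6, 4, 3])
attention_mask = (torch.arange(seq_len)[None, :] < lengths[:, None]).long()
# tensor([[1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0, 0]])

# With the annotation now Optional[torch.Tensor], a caller whose batch has no
# padding can pass attention_mask=None instead of building an all-ones mask.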
