1 file changed: +2 −2 lines changed

@@ -280,7 +280,7 @@ def _flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
-    attention_mask: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
     query_length: int,
     is_causal: bool,
     dropout: float = 0.0,
@@ -308,7 +308,7 @@ def _flash_attention_forward(
             Input key states to be passed to Flash Attention API
         value_states (`torch.Tensor`):
             Input value states to be passed to Flash Attention API
-        attention_mask (`torch.Tensor`):
+        attention_mask (`torch.Tensor`, *optional*):
             The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
             position of padding tokens and 1 for the position of non-padding tokens.
         dropout (`float`):
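The diff makes the padding mask optional: when every sequence in the batch is full length, callers can pass `None` instead of an all-ones mask. A minimal sketch of that pattern (hypothetical helper, not the actual transformers implementation) using the same 0/1 mask convention described in the docstring:

```python
from typing import Optional

def build_padding_mask(seq_lens: list[int], max_len: int) -> Optional[list[list[int]]]:
    """Return a (batch_size, max_len) 0/1 padding mask, or None when no
    sequence is padded. 1 marks a real token, 0 marks a padding position,
    matching the convention in the docstring above."""
    if all(n == max_len for n in seq_lens):
        # No padding anywhere: an Optional mask lets us skip the mask
        # entirely instead of building a tensor of all ones.
        return None
    return [[1 if i < n else 0 for i in range(max_len)] for n in seq_lens]

# Equal-length batch: no mask is needed.
assert build_padding_mask([4, 4], 4) is None
# Ragged batch: zeros appear at the padded tail of the shorter sequence.
assert build_padding_mask([2, 4], 4) == [[1, 1, 0, 0], [1, 1, 1, 1]]
```

Downstream code then branches on `attention_mask is None`, taking the faster non-varlen path when no padding is present.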