diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index b4d3b6fc3e01..56c35e8d1956 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -122,7 +122,7 @@ def make_flex_block_causal_mask(
 
     if attention_chunk_size is not None:
         # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
-        document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
+        chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
 
     # Instead of passing a tensor mask, flex attention requires a mask_mod function
     # that determines which elements of QK^T should be included in the attention
@@ -143,6 +143,16 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
         final_mask = causal_mask & padding_mask & document_mask
         return final_mask
 
+    def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+        """
+        Combines the chunk mask with the causal mask for chunked attention.
+        """
+        chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
+        causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
+        return chunk_mask & causal_doc_mask
+
+    mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
+
     if offsets is not None:
         q_offset = offsets[0]
         kv_offset = offsets[1]
@@ -150,9 +160,9 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
         def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
             offset_q = q_idx + q_offset
             offset_kv = kv_idx + kv_offset
-            return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
+            return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
     else:
-        mask_mod = causal_mask_mod
+        mask_mod = mask_mod_maybe_combined
     return create_block_causal_mask_flex(
         mask_mod=mask_mod,
         B=batch_size,
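
For context, here is a minimal sketch (not part of the patch) of what the new chunked mask computes. It reuses the `chunk_idxs` trick from the diff on a toy `attention_mask_2d`; the simplified `padding_mask` and the dense printout are illustrative assumptions, not the library's implementation.

```python
import torch

# Toy inputs (illustrative only): one sequence of 8 positions, last two padded.
attention_mask_2d = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
attention_chunk_size = 3

# Same trick as in the diff: fill with ones, cumsum, subtract one, floor-divide
# by the chunk size -> per-position chunk indices [0, 0, 0, 1, 1, 1, 2, 2].
document_ids = attention_mask_2d.clone()
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // attention_chunk_size


def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    # Causal within a chunk, no attention across chunk boundaries or on padding.
    causal_mask = q_idx >= kv_idx
    padding_mask = attention_mask_2d[batch_idx, q_idx] > 0  # simplified stand-in
    chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
    return causal_mask & padding_mask & chunk_mask


# Dense view of the mask for batch 0 / head 0, just to visualize the block
# structure (flex attention would normally consume the mask_mod directly).
q = torch.arange(8).view(-1, 1)
kv = torch.arange(8).view(1, -1)
print(chunk_causal_mask_mod(0, 0, q, kv).int())
```

The printout shows two causal 3x3 blocks followed by fully masked padding rows, which is the behaviour the renamed `chunk_idxs` preserves: by no longer overwriting `document_ids`, the document/padding logic in `causal_mask_mod` keeps working while the chunk constraint is layered on top.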