[Gemma2] Use nn.SDPA via MultiHeadAttention #2844

Draft · wants to merge 1 commit into base: main

126 changes: 126 additions & 0 deletions tests/torchtune/models/gemma2/test_sliding_attention_mask.py
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.gemma2._attention_mask import get_sliding_attention_mask


class TestGetSlidingAttentionMask:
@pytest.fixture
def basic_params(self):
return {"bsz": 2, "seq_len": 4, "sliding_window_size": 2, "device": None}

def test_get_sliding_attention_mask(self, basic_params):
"""Test that when mask is None, a causal mask is created and sliding window is applied."""
mask = get_sliding_attention_mask(
mask=None,
sliding_window_size=basic_params["sliding_window_size"],
bsz=basic_params["bsz"],
seq_len=basic_params["seq_len"],
device=basic_params["device"],
)

assert mask.shape == (
basic_params["bsz"],
basic_params["seq_len"],
basic_params["seq_len"],
)
assert mask.dtype == torch.bool

# Check that the mask has the expected sliding window pattern
# True positions can be attended to, False positions are masked
expected_pattern = torch.tensor(
[
[True, False, False, False],
[True, True, False, False],
[False, True, True, False],
[False, False, True, True],
],
dtype=torch.bool,
)

# Check first batch element
torch.testing.assert_close(mask[0], expected_pattern)
# All batch elements should be identical
torch.testing.assert_close(mask[0], mask[1])

def test_get_sliding_attention_mask_different_window_sizes(self):
"""Test sliding window with different window sizes."""
bsz, seq_len = 1, 5

# Test window size 1 (only current position)
mask = get_sliding_attention_mask(
mask=None,
sliding_window_size=1,
bsz=bsz,
seq_len=seq_len,
device=None,
)

expected_window_1 = torch.tensor(
[
[True, False, False, False, False],
[False, True, False, False, False],
[False, False, True, False, False],
[False, False, False, True, False],
[False, False, False, False, True],
],
dtype=torch.bool,
)

torch.testing.assert_close(mask[0], expected_window_1)

# Test window size 3
mask = get_sliding_attention_mask(
mask=None,
sliding_window_size=3,
bsz=bsz,
seq_len=seq_len,
device=None,
)

expected_window_3 = torch.tensor(
[
[True, False, False, False, False],
[True, True, False, False, False],
[True, True, True, False, False],
[False, True, True, True, False],
[False, False, True, True, True],
],
dtype=torch.bool,
)

torch.testing.assert_close(mask[0], expected_window_3)

def test_get_sliding_attention_mask_large_window(self):
"""Test sliding window larger than sequence length."""
bsz, seq_len = 1, 3
sliding_window_size = 5 # Larger than seq_len

mask = get_sliding_attention_mask(
mask=None,
sliding_window_size=sliding_window_size,
bsz=bsz,
seq_len=seq_len,
device=None,
)

# Should behave like a regular causal mask when window is larger than seq_len
expected_causal = torch.tensor(
[
[True, False, False],
[True, True, False],
[True, True, True],
],
dtype=torch.bool,
)

torch.testing.assert_close(mask[0], expected_causal)
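
As an aside, not part of this PR: the hand-written expected tensors above all follow the rule that position i may attend to position j iff j <= i and i - j < sliding_window_size. A small helper like the following (hypothetical, for illustration only) could derive the same patterns for arbitrary window sizes:

import torch

def reference_sliding_mask(seq_len: int, window: int) -> torch.Tensor:
    # Position i may attend to position j iff 0 <= i - j < window
    # (causal and within the sliding window).
    idx = torch.arange(seq_len)
    rel = idx[:, None] - idx[None, :]  # rel[i, j] = i - j
    return (rel >= 0) & (rel < window)

# e.g. reference_sliding_mask(4, 2) reproduces expected_pattern from the first test above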
2 changes: 1 addition & 1 deletion tests/torchtune/modules/test_attention_utils.py
@@ -122,7 +122,7 @@ def test_flex_attention(self, mock_sdpa, mock_flex):
_attention_call = _sdpa_or_flex_attention()
_ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal)
mock_sdpa.assert_not_called()
-mock_flex.assert_called_with(q, k, v, block_mask=attn_mask)
+mock_flex.assert_called_with(q, k, v, block_mask=attn_mask, scale=None)
# If mask is not a BlockMask, then we should call SDPA
_attention_call = _sdpa_or_flex_attention()
_ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal)
52 changes: 52 additions & 0 deletions torchtune/models/gemma2/_attention_mask.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from torchtune.modules.attention_utils import _MaskType


def get_sliding_attention_mask(
mask: Optional[_MaskType],
sliding_window_size: int,
bsz: int,
seq_len: int,
device: Optional[torch.device] = None,
) -> _MaskType:
"""
Args:
mask (Optional[_MaskType]): Mask to apply to the attention scores.
sliding_window_size (int): Sliding window size to apply to the attention mask.
bsz (int): Batch size. Argument is unused, but listed for consistency.
seq_len (int): Sequence length.
device (Optional[torch.device]): Device to use for the mask. Defaults to None.

Returns:
A tensor mask that applies sliding window masking.

Raises:
ValueError: If the input mask is not a Tensor
"""

    # Default to a standard lower-triangular (causal) boolean mask when none is provided.
    if mask is None:
        mask = torch.tril(
            torch.ones(size=(bsz, seq_len, seq_len), dtype=torch.bool, device=device)
        )

if not isinstance(mask, torch.Tensor):
raise ValueError(
f"For non-flex attention, mask must be a Tensor. Got: {type(mask)}"
)

    # Band constraint: position i may attend to positions j with |i - j| < sliding_window_size;
    # intersecting it with the causal mask above restricts attention to the last
    # sliding_window_size positions.
    all_ones = torch.ones_like(mask, dtype=torch.bool)
    sliding_mask = torch.triu(all_ones, -1 * sliding_window_size + 1) & torch.tril(
        all_ones, sliding_window_size - 1
    )
mask = mask & sliding_mask

return mask
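
For context, a minimal usage sketch (not part of this diff; shapes and the sliding_window_size value are illustrative assumptions) showing how the boolean mask returned here can be passed to torch.nn.functional.scaled_dot_product_attention, which treats True entries as positions that may be attended to:

import torch
import torch.nn.functional as F

from torchtune.models.gemma2._attention_mask import get_sliding_attention_mask

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# (bsz, seq_len, seq_len) boolean mask; True = may attend
mask = get_sliding_attention_mask(
    mask=None,
    sliding_window_size=4,
    bsz=bsz,
    seq_len=seq_len,
    device=q.device,
)

# SDPA expects the mask to broadcast over the head dimension
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, :, :])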