diff --git a/tests/torchtune/models/gemma2/test_sliding_attention_mask.py b/tests/torchtune/models/gemma2/test_sliding_attention_mask.py new file mode 100644 index 0000000000..447af59369 --- /dev/null +++ b/tests/torchtune/models/gemma2/test_sliding_attention_mask.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from torchtune.models.gemma2._attention_mask import get_sliding_attention_mask + + +class TestGetSlidingAttentionMask: + @pytest.fixture + def basic_params(self): + return {"bsz": 2, "seq_len": 4, "sliding_window_size": 2, "device": None} + + def test_get_sliding_attention_mask(self, basic_params): + """Test that when mask is None, a causal mask is created and sliding window is applied.""" + bsz = 2 + seq_len = 4 + sliding_window_size = 2 + mask = get_sliding_attention_mask( + mask=None, + sliding_window_size=basic_params["sliding_window_size"], + bsz=basic_params["bsz"], + seq_len=basic_params["seq_len"], + device=basic_params["device"], + ) + + assert mask.shape == ( + basic_params["bsz"], + basic_params["seq_len"], + basic_params["seq_len"], + ) + assert mask.dtype == torch.bool + + # Check that the mask has the expected sliding window pattern + # True positions can be attended to, False positions are masked + expected_pattern = torch.tensor( + [ + [True, False, False, False], + [True, True, False, False], + [False, True, True, False], + [False, False, True, True], + ], + dtype=torch.bool, + ) + + # Check first batch element + torch.testing.assert_close(mask[0], expected_pattern) + # All batch elements should be identical + torch.testing.assert_close(mask[0], mask[1]) + + def test_get_sliding_attention_mask_different_window_sizes(self): + """Test sliding window with different window sizes.""" + bsz, seq_len = 1, 5 + + # Test window size 1 (only current position) + mask = get_sliding_attention_mask( + mask=None, + sliding_window_size=1, + bsz=bsz, + seq_len=seq_len, + device=None, + ) + + expected_window_1 = torch.tensor( + [ + [True, False, False, False, False], + [False, True, False, False, False], + [False, False, True, False, False], + [False, False, False, True, False], + [False, False, False, False, True], + ], + dtype=torch.bool, + ) + + torch.testing.assert_close(mask[0], expected_window_1) + + # Test window size 3 + mask = get_sliding_attention_mask( + mask=None, + sliding_window_size=3, + bsz=bsz, + seq_len=seq_len, + device=None, + ) + + expected_window_3 = torch.tensor( + [ + [True, False, False, False, False], + [True, True, False, False, False], + [True, True, True, False, False], + [False, True, True, True, False], + [False, False, True, True, True], + ], + dtype=torch.bool, + ) + + torch.testing.assert_close(mask[0], expected_window_3) + + def test_get_sliding_attention_mask_large_window(self): + """Test sliding window larger than sequence length.""" + bsz, seq_len = 1, 3 + sliding_window_size = 5 # Larger than seq_len + + mask = get_sliding_attention_mask( + mask=None, + sliding_window_size=sliding_window_size, + bsz=bsz, + seq_len=seq_len, + device=None, + ) + + # Should behave like a regular causal mask when window is larger than seq_len + expected_causal = torch.tensor( + [ + [True, False, False], + [True, True, False], + [True, True, True], + ], + dtype=torch.bool, + ) + + torch.testing.assert_close(mask[0], expected_causal) diff --git 
a/tests/torchtune/modules/test_attention_utils.py b/tests/torchtune/modules/test_attention_utils.py index feac48d2d3..a5a655c80c 100644 --- a/tests/torchtune/modules/test_attention_utils.py +++ b/tests/torchtune/modules/test_attention_utils.py @@ -122,7 +122,7 @@ def test_flex_attention(self, mock_sdpa, mock_flex): _attention_call = _sdpa_or_flex_attention() _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal) mock_sdpa.assert_not_called() - mock_flex.assert_called_with(q, k, v, block_mask=attn_mask) + mock_flex.assert_called_with(q, k, v, block_mask=attn_mask, scale=None) # If mask is not a BlockMask, then we should call SDPA _attention_call = _sdpa_or_flex_attention() _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal) diff --git a/torchtune/models/gemma2/_attention_mask.py b/torchtune/models/gemma2/_attention_mask.py new file mode 100644 index 0000000000..de720bbea0 --- /dev/null +++ b/torchtune/models/gemma2/_attention_mask.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from torchtune.modules.attention_utils import _MaskType + + +def get_sliding_attention_mask( + mask: Optional[_MaskType], + sliding_window_size: int, + bsz: int, + seq_len: int, + device: Optional[torch.device] = None, +) -> _MaskType: + """ + Applies a sliding window restriction on top of the attention mask. + + Args: + mask (Optional[_MaskType]): Mask to apply to the attention scores. If None, a causal mask is used. + sliding_window_size (int): Sliding window size to apply to the attention mask. + bsz (int): Batch size used to construct the causal mask when ``mask`` is None. + seq_len (int): Sequence length. + device (Optional[torch.device]): Device to use for the mask. Defaults to None. + + Returns: + A boolean mask of shape ``[bsz, seq_len, seq_len]`` where True marks positions that can be attended to. + + Raises: + ValueError: If the input mask is not a Tensor. + """ + + if mask is None: + mask = torch.tril( + torch.ones(size=(bsz, seq_len, seq_len), dtype=torch.bool).to(device) + ) + + if not isinstance(mask, torch.Tensor): + raise ValueError( + f"For non-flex attention, mask must be a Tensor. Got: {type(mask)}" + ) + + all_ones = torch.ones_like(mask, dtype=torch.bool) + sliding_mask = torch.triu(all_ones, -1 * sliding_window_size + 1) & torch.tril( + all_ones, sliding_window_size - 1 + ) + mask = mask & sliding_mask + + return mask diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py index 623eb94f7a..01abbdce5d 100644 --- a/torchtune/models/gemma2/_component_builders.py +++ b/torchtune/models/gemma2/_component_builders.py @@ -4,25 +4,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree.
+from functools import partial from typing import Optional import torch from torch import nn -from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks -from typing import Optional +from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp +from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings +from torchtune.models.gemma.rms_norm import GemmaRMSNorm + +from torchtune.models.gemma2._attention_mask import get_sliding_attention_mask from torchtune.modules import ( FrozenNF4Linear, RotaryPositionalEmbeddings, + TiedLinear, + TransformerDecoder, TransformerSelfAttentionLayer, ) - -from torchtune.models.gemma2._attention import Gemma2Attention -from torchtune.models.gemma.rms_norm import GemmaRMSNorm -from torchtune.modules import TransformerDecoder, TiedLinear -from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings +from torchtune.modules.attention import MultiHeadAttention +from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear -from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp """ Component builders for the Gemma2 2B, 9B models and popular variants such as LoRA. @@ -36,6 +38,7 @@ the building blocks simple. """ + class TanhSoftCapping(nn.Module): def __init__( self, @@ -55,17 +58,13 @@ class Gemma2FinalNorm(nn.Module): """ Combines RMSNorm and SoftCapping """ - def __init__( - self, - capping_value: float, - embed_dim: int, - eps: float - ) -> None: + + def __init__(self, capping_value: float, embed_dim: int, eps: float) -> None: super().__init__() self.capping_value = capping_value self.rms_norm = GemmaRMSNorm(embed_dim, eps=eps) self.logit_capping = TanhSoftCapping(capping_value) - + def forward(self, x): x = self.rms_norm(x) x = self.logit_capping(x) @@ -115,14 +114,17 @@ def gemma2( Returns: TransformerDecoder: Instantiation of gemma model. 
""" - rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) - + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base + ) + layers = torch.nn.ModuleList() for layer_idx in range(num_layers): - + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) - - self_att = Gemma2Attention( + + # Since `nn.SPDA` doesn't support SoftCapping, soft capping is skipped + self_att = MultiHeadAttention( embed_dim=embed_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, @@ -135,12 +137,16 @@ def gemma2( kv_cache=None, max_seq_len=max_seq_len, attn_dropout=attn_dropout, - # perform sliding window on half of the layers only - sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None, - softcapping=hidden_capping_value, - query_pre_attn_scalar=query_pre_attn_scalar + scale=(query_pre_attn_scalar or head_dim) ** -0.5, + ) + # Sliding window is applied on half of the layers only + # Currently returns a Tensor Mask so FlashAttention is not used + mask_mod = ( + partial(get_sliding_attention_mask, sliding_window_size=sliding_window_size) + if (layer_idx % 2) == 0 + else None ) - + layer = TransformerSelfAttentionLayer( attn=self_att, mlp=mlp, @@ -148,6 +154,7 @@ def gemma2( mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), sa_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + mask_mod=mask_mod, ) layers.append(layer) @@ -165,7 +172,6 @@ def gemma2( return model - def lora_gemma2( lora_attn_modules: list[LORA_ATTN_MODULES], apply_lora_to_mlp: bool = False, @@ -182,8 +188,8 @@ def lora_gemma2( attn_dropout: float = 0.0, norm_eps: float = 1e-6, rope_base: int = 10_000, - hidden_capping_value: float = 50., - final_capping_value: float = 30., + hidden_capping_value: float = 50.0, + final_capping_value: float = 30.0, sliding_window_size: int = 4096, query_pre_attn_scalar: Optional[int] = None, # LoRA args @@ -232,7 +238,6 @@ def lora_gemma2( tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) output_proj = TiedLinear(tok_embeddings) - layers = nn.ModuleList() for layer_idx in range(num_layers): if apply_lora_to_mlp: @@ -246,7 +251,9 @@ def lora_gemma2( quantize_base=quantize_base, ) else: - mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) + mlp = gemma_mlp( + dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base + ) self_att = lora_gemma2_self_attention( lora_modules=lora_attn_modules, embed_dim=embed_dim, @@ -257,7 +264,7 @@ def lora_gemma2( max_seq_len=max_seq_len, attn_dropout=attn_dropout, # perform sliding window on half of the layers only - sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None, + sliding_window_size=sliding_window_size if (layer_idx % 2) == 0 else None, softcapping=hidden_capping_value, query_pre_attn_scalar=query_pre_attn_scalar, lora_rank=lora_rank, @@ -266,7 +273,14 @@ def lora_gemma2( use_dora=use_dora, quantize_base=quantize_base, ) - + # Sliding window is applied on half of the layers only + # Currently returns a Tensor Mask so FlashAttention is not used + mask_mod = ( + partial(get_sliding_attention_mask, sliding_window_size=sliding_window_size) + if (layer_idx % 2) == 0 + else None + ) + layer = TransformerSelfAttentionLayer( attn=self_att, mlp=mlp, @@ -274,9 +288,10 @@ def lora_gemma2( mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), sa_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + mask_mod=mask_mod, ) layers.append(layer) - 
+ model = TransformerDecoder( tok_embeddings=tok_embeddings, layers=layers, @@ -284,7 +299,7 @@ def lora_gemma2( num_heads=num_heads, output=output_proj, head_dim=head_dim, - norm=Gemma2FinalNorm(final_capping_value, embed_dim, eps=norm_eps) + norm=Gemma2FinalNorm(final_capping_value, embed_dim, eps=norm_eps), ) if quantize_base: @@ -292,7 +307,9 @@ def lora_gemma2( # so as to not increase peak memory # TODO this is clowny, figure out a better way to get what precision the rest # of the model is in - _register_reparametrize_state_dict_hooks(model, dtype=tok_embeddings.weight.dtype) + _register_reparametrize_state_dict_hooks( + model, dtype=tok_embeddings.weight.dtype + ) return model @@ -317,8 +334,7 @@ def lora_gemma2_self_attention( lora_dropout: float = 0.0, use_dora: bool = False, quantize_base: bool = False, - -) -> Gemma2Attention: +) -> MultiHeadAttention: if not lora_modules: raise ValueError( f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" ) @@ -392,23 +408,24 @@ def lora_gemma2_self_attention( ) ) - rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) - - self_att = Gemma2Attention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=q_proj, - k_proj=k_proj, - v_proj=v_proj, - output_proj=output_proj, - pos_embeddings=rope, - kv_cache=None, - max_seq_len=max_seq_len, - attn_dropout=attn_dropout, - sliding_window_size=sliding_window_size, - softcapping=softcapping, - query_pre_attn_scalar=query_pre_attn_scalar - ) - return self_att \ No newline at end of file + rope = RotaryPositionalEmbeddings( + dim=head_dim, max_seq_len=max_seq_len, base=rope_base + ) + + # Soft capping is skipped since SDPA (`scaled_dot_product_attention`) does not support it + self_att = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + scale=(query_pre_attn_scalar or head_dim) ** -0.5, + ) + return self_att diff --git a/torchtune/models/llama4/_chunked_attention.py b/torchtune/models/llama4/_chunked_attention.py index 8559c514db..295bc6d11c 100644 --- a/torchtune/models/llama4/_chunked_attention.py +++ b/torchtune/models/llama4/_chunked_attention.py @@ -22,6 +22,8 @@ def get_chunked_attention_mask( chunk_size: int, bsz: int, seq_len: int, + # Unused, but listed for consistency + device: Optional[torch.device] = None, ) -> _MaskType: """ """ # TODO: check this somewhere that doesn't get called every forward diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 62e4227b57..e1cffaf3d2 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Optional +from typing import Callable, Optional import torch from torch import nn @@ -71,6 +71,8 @@ class MultiHeadAttention(nn.Module): is_causal (bool): sets the default mask to causal when no mask is provided attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function. Default value is 0.0. + scale (Optional[float]): Scaling factor applied to the attention scores after the query-key multiplication + and before the softmax. Defaults to None, in which case 1 / sqrt(head_dim) is used.
Raises: ValueError: @@ -98,6 +100,7 @@ def __init__( max_seq_len: int = 4096, is_causal: bool = True, attn_dropout: float = 0.0, + scale: Optional[float] = None, ) -> None: super().__init__() if num_heads % num_kv_heads != 0: @@ -139,6 +142,9 @@ def __init__( # Use flex attention if supported and we are sample packing self._attention_call = _sdpa_or_flex_attention() + # Set attention arguments + self.scale = scale + # this flag indicates whether to update the kv-cache during forward # passes. when disabled, we can have the cache setup but still # perform normal forward passes @@ -294,6 +300,7 @@ def forward( k, v, mask=mask, + scale=self.scale, dropout_p=self.attn_dropout if self.training else 0.0, is_causal=self.kv_cache is None and mask is None and self.is_causal, ) diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py index 87714d8494..ebabd63fca 100644 --- a/torchtune/modules/attention_utils.py +++ b/torchtune/modules/attention_utils.py @@ -53,8 +53,9 @@ def compile_friendly_flex_attention( k: torch.Tensor, v: torch.Tensor, block_mask: BlockMask, + scale: Optional[float] = None, ) -> torch.Tensor: - return flex_attention_compiled(q, k, v, block_mask=block_mask) + return flex_attention_compiled(q, k, v, block_mask=block_mask, scale=scale) _MaskType = Union[torch.Tensor, BlockMask] else: @@ -200,6 +201,7 @@ def _sdpa_call( mask: Optional[_MaskType], dropout_p: float, is_causal: bool, + scale: Optional[float] = None, ) -> torch.Tensor: # shape: [b, 1, s, s] if mask is not None: @@ -207,7 +209,13 @@ def _sdpa_call( # Flash attention from https://pytorch.org/blog/accelerating-large-language-models/ return nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal + q, + k, + v, + attn_mask=mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, ) if not _SUPPORTS_FLEX_ATTENTION: @@ -221,6 +229,7 @@ def _attention_call( mask: Optional[_MaskType], dropout_p: float, is_causal: bool, + scale: Optional[float] = None, ) -> torch.Tensor: # Flex attention uses the BlockMask # (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L168) @@ -244,10 +253,11 @@ def _attention_call( k, v, block_mask=mask, + scale=scale, ) else: # If mask is a standard boolean tensor or None, then use SDPA - return _sdpa_call(q, k, v, mask, dropout_p, is_causal) + return _sdpa_call(q, k, v, mask, dropout_p, is_causal, scale) return _attention_call diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 724138b14e..614d66b939 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -128,7 +128,7 @@ def forward( if self.mask_mod is not None: # With TP we need to use a replicated tensor here bsz, seq_len, *_ = h.shape - mask = self.mask_mod(mask=mask, bsz=bsz, seq_len=seq_len) + mask = self.mask_mod(mask=mask, bsz=bsz, seq_len=seq_len, device=h.device) attn_out = self.attn(h, h, mask=mask, input_pos=input_pos) # Residual connection; shape: [batch_size, seq_length, embed_dim] h = self.sa_scale(attn_out) + x
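
Note (illustrative sketch, not part of the patch): the snippet below shows what the new `get_sliding_attention_mask` produces and how the builders above wire it into `TransformerSelfAttentionLayer` via `functools.partial`. It assumes the module is importable as `torchtune.models.gemma2._attention_mask` exactly as added in this diff; the expected output mirrors the `expected_window_3` case in the new unit tests.

# Sketch: build a sliding-window causal mask directly and check the partial-based wiring
from functools import partial

import torch

from torchtune.models.gemma2._attention_mask import get_sliding_attention_mask

bsz, seq_len, sliding_window_size = 1, 5, 3

# Passing mask=None builds a causal mask first, then intersects it with the
# sliding window band, so row i can attend to columns [i - 2, i].
mask = get_sliding_attention_mask(
    mask=None,
    sliding_window_size=sliding_window_size,
    bsz=bsz,
    seq_len=seq_len,
    device=None,
)
print(mask[0].int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)

# The component builders bind only sliding_window_size; TransformerSelfAttentionLayer
# supplies mask, bsz, seq_len, and device at forward time through mask_mod
# (see the transformer.py hunk above).
mask_mod = partial(get_sliding_attention_mask, sliding_window_size=sliding_window_size)
same_mask = mask_mod(mask=None, bsz=bsz, seq_len=seq_len, device=None)
assert torch.equal(mask, same_mask)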