From 114b3a3d08dd3c8c6f117a4a3b41909b605470f0 Mon Sep 17 00:00:00 2001 From: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Date: Mon, 22 Jun 2026 21:23:22 +0000 Subject: [PATCH 1/4] Fix bidirectional attention masking crossing sliding window boundaries in Gemma 3/4 --- .../models/gemma3/modeling_gemma3.py | 69 +++++++++++++--- .../models/gemma3/modular_gemma3.py | 69 +++++++++++++--- .../models/gemma4/modeling_gemma4.py | 81 ++++++++++++++----- .../models/gemma4/modular_gemma4.py | 46 ++++++----- tests/models/gemma3/test_modeling_gemma3.py | 48 +++++++++++ tests/models/gemma4/test_modeling_gemma4.py | 48 +++++++++++ 6 files changed, 301 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 227c6a2b51d8..0e32b4d8ccc0 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -31,7 +31,13 @@ from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin from ...integrations import use_kernel_func_from_hub, use_kernelized_func -from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask +from ...masking_utils import ( + blockwise_overlay, + create_causal_mask, + create_masks_for_generate, + create_sliding_window_causal_mask, + sliding_window_overlay, +) from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -709,6 +715,46 @@ def get_block_sequence_ids_for_mask(token_type_ids: torch.Tensor, device: torch. return block_sequence_ids +def create_masks_for_vision_model( + config: PreTrainedConfig, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + block_sequence_ids: torch.Tensor, +) -> dict: + """Create full_attention and sliding_attention masks with correct composition. + + For global (full attention) layers: OR(causal, blockwise) + For local (sliding window) layers: AND(sliding_window, OR(causal, blockwise)) + """ + mask_kwargs = { + "config": config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + + # Full attention: OR(causal, blockwise) — use block_sequence_ids directly + full_mask = create_causal_mask(**mask_kwargs, block_sequence_ids=block_sequence_ids) + + # Sliding attention: AND(sliding_window, OR(causal, blockwise)) + # Pass blockwise as or_mask_function (applied as step 2 in create_causal_mask) + # Pass sliding_window as and_mask_function (applied as step 3, after OR) + # Do NOT pass block_sequence_ids (to avoid the incorrect step 4 final OR) + sliding_mask = create_causal_mask( + **mask_kwargs, + or_mask_function=blockwise_overlay(block_sequence_ids), + and_mask_function=sliding_window_overlay(config.sliding_window), + ) + + return { + "full_attention": full_mask, + "sliding_attention": sliding_mask, + } + + @auto_docstring( custom_intro=""" The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head., @@ -841,16 +887,13 @@ def forward( } if token_type_ids is not None: - mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask( - token_type_ids, device=inputs_embeds.device + block_sequence_ids = get_block_sequence_ids_for_mask(token_type_ids, device=inputs_embeds.device) + causal_mask_mapping = create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, ) - - # Create the masks - sliding_mask_kwargs = mask_kwargs.copy() - causal_mask_mapping = { - "full_attention": create_causal_mask(**mask_kwargs), - "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), - } + else: + causal_mask_mapping = create_masks_for_generate(**mask_kwargs) outputs = self.language_model( attention_mask=causal_mask_mapping, @@ -1064,8 +1107,10 @@ def create_masks_for_generate( } if token_type_ids is not None: - mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask( - token_type_ids, device=inputs_embeds.device + block_sequence_ids = get_block_sequence_ids_for_mask(token_type_ids, device=inputs_embeds.device) + return create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, ) return create_masks_for_generate(**mask_kwargs) diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 82afbc884c65..08aa81ea57eb 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -22,7 +22,13 @@ from ... import initialization as init from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig -from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask +from ...masking_utils import ( + blockwise_overlay, + create_causal_mask, + create_masks_for_generate, + create_sliding_window_causal_mask, + sliding_window_overlay, +) from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ( @@ -611,6 +617,46 @@ def get_block_sequence_ids_for_mask(token_type_ids: torch.Tensor, device: torch. return block_sequence_ids +def create_masks_for_vision_model( + config: PreTrainedConfig, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + block_sequence_ids: torch.Tensor, +) -> dict: + """Create full_attention and sliding_attention masks with correct composition. + + For global (full attention) layers: OR(causal, blockwise) + For local (sliding window) layers: AND(sliding_window, OR(causal, blockwise)) + """ + mask_kwargs = { + "config": config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + + # Full attention: OR(causal, blockwise) — use block_sequence_ids directly + full_mask = create_causal_mask(**mask_kwargs, block_sequence_ids=block_sequence_ids) + + # Sliding attention: AND(sliding_window, OR(causal, blockwise)) + # Pass blockwise as or_mask_function (applied as step 2 in create_causal_mask) + # Pass sliding_window as and_mask_function (applied as step 3, after OR) + # Do NOT pass block_sequence_ids (to avoid the incorrect step 4 final OR) + sliding_mask = create_causal_mask( + **mask_kwargs, + or_mask_function=blockwise_overlay(block_sequence_ids), + and_mask_function=sliding_window_overlay(config.sliding_window), + ) + + return { + "full_attention": full_mask, + "sliding_attention": sliding_mask, + } + + class Gemma3Model(PaliGemmaModel): # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch accepts_loss_kwargs = False @@ -679,16 +725,15 @@ def forward( } if token_type_ids is not None: - mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask( + block_sequence_ids = get_block_sequence_ids_for_mask( token_type_ids, device=inputs_embeds.device ) - - # Create the masks - sliding_mask_kwargs = mask_kwargs.copy() - causal_mask_mapping = { - "full_attention": create_causal_mask(**mask_kwargs), - "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), - } + causal_mask_mapping = create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, + ) + else: + causal_mask_mapping = create_masks_for_generate(**mask_kwargs) outputs = self.language_model( attention_mask=causal_mask_mapping, @@ -885,9 +930,13 @@ def create_masks_for_generate( } if token_type_ids is not None: - mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask( + block_sequence_ids = get_block_sequence_ids_for_mask( token_type_ids, device=inputs_embeds.device ) + return create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, + ) return create_masks_for_generate(**mask_kwargs) diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py index 73887286dc22..6224ac00b870 100644 --- a/src/transformers/models/gemma4/modeling_gemma4.py +++ b/src/transformers/models/gemma4/modeling_gemma4.py @@ -36,10 +36,12 @@ from ...generation import GenerationMixin from ...integrations import use_experts_implementation from ...masking_utils import ( + blockwise_overlay, create_bidirectional_mask, create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask, + sliding_window_overlay, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -2103,6 +2105,46 @@ def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: return self.embedding_projection(embs_normed) +def create_masks_for_vision_model( + config: PreTrainedConfig, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + block_sequence_ids: torch.Tensor, +) -> dict: + """Create full_attention and sliding_attention masks with correct composition. + + For global (full attention) layers: OR(causal, blockwise) + For local (sliding window) layers: AND(sliding_window, OR(causal, blockwise)) + """ + mask_kwargs = { + "config": config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + + # Full attention: OR(causal, blockwise) — use block_sequence_ids directly + full_mask = create_causal_mask(**mask_kwargs, block_sequence_ids=block_sequence_ids) + + # Sliding attention: AND(sliding_window, OR(causal, blockwise)) + # Pass blockwise as or_mask_function (applied as step 2 in create_causal_mask) + # Pass sliding_window as and_mask_function (applied as step 3, after OR) + # Do NOT pass block_sequence_ids (to avoid the incorrect step 4 final OR) + sliding_mask = create_causal_mask( + **mask_kwargs, + or_mask_function=blockwise_overlay(block_sequence_ids), + and_mask_function=sliding_window_overlay(config.sliding_window), + ) + + return { + "full_attention": full_mask, + "sliding_attention": sliding_mask, + } + + def get_block_sequence_ids_for_mask(mm_token_type_ids: torch.Tensor, device: torch.device) -> torch.Tensor: mm_token_type_ids = mm_token_type_ids.to(device) @@ -2346,19 +2388,19 @@ def forward( "position_ids": position_ids, } - # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs - # Smaller Gemma models use a conventional casual attention mask - if self.config.get_text_config().use_bidirectional_attention == "vision": - block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device) - if mm_token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask( - mm_token_type_ids, device=inputs_embeds.device - ) + text_config = self.config.get_text_config() + use_bidir = text_config.use_bidirectional_attention == "vision" - mask_kwargs["block_sequence_ids"] = block_sequence_ids - - # Create the masks - causal_mask_mapping = create_masks_for_generate(**mask_kwargs) + if use_bidir and mm_token_type_ids is not None: + block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device) + causal_mask_mapping = create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, + ) + else: + # Smaller Gemma models (use_bidirectional_attention=None) or + # text-only inputs use standard causal masking + causal_mask_mapping = create_masks_for_generate(**mask_kwargs) outputs = self.language_model( per_layer_inputs=per_layer_inputs, @@ -2623,14 +2665,15 @@ def create_masks_for_generate( "position_ids": position_ids, } - # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs - # Smaller Gemma models use a conventional casual attention mask - if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision": - block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device) - if mm_token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device) + text_config = config.get_text_config() + use_bidir = getattr(text_config, "use_bidirectional_attention", None) == "vision" - mask_kwargs["block_sequence_ids"] = block_sequence_ids + if use_bidir and mm_token_type_ids is not None: + block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device) + return create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, + ) return create_masks_for_generate(**mask_kwargs) diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py index 34f0b4f082ab..f39d8292e517 100644 --- a/src/transformers/models/gemma4/modular_gemma4.py +++ b/src/transformers/models/gemma4/modular_gemma4.py @@ -27,10 +27,12 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...masking_utils import ( + blockwise_overlay, create_bidirectional_mask, create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask, + sliding_window_overlay, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling @@ -57,6 +59,7 @@ Gemma3TextModel, Gemma3TextScaledWordEmbedding, ) +from ..gemma3.modeling_gemma3 import create_masks_for_vision_model # noqa: F811 from ..gemma3n.modeling_gemma3n import ( Gemma3nCausalLMOutputWithPast, Gemma3nForConditionalGeneration, @@ -2060,19 +2063,21 @@ def forward( "position_ids": position_ids, } - # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs - # Smaller Gemma models use a conventional casual attention mask - if self.config.get_text_config().use_bidirectional_attention == "vision": - block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device) - if mm_token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask( - mm_token_type_ids, device=inputs_embeds.device - ) - - mask_kwargs["block_sequence_ids"] = block_sequence_ids + text_config = self.config.get_text_config() + use_bidir = text_config.use_bidirectional_attention == "vision" - # Create the masks - causal_mask_mapping = create_masks_for_generate(**mask_kwargs) + if use_bidir and mm_token_type_ids is not None: + block_sequence_ids = get_block_sequence_ids_for_mask( + mm_token_type_ids, device=inputs_embeds.device + ) + causal_mask_mapping = create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, + ) + else: + # Smaller Gemma models (use_bidirectional_attention=None) or + # text-only inputs use standard causal masking + causal_mask_mapping = create_masks_for_generate(**mask_kwargs) outputs = self.language_model( per_layer_inputs=per_layer_inputs, @@ -2249,14 +2254,17 @@ def create_masks_for_generate( "position_ids": position_ids, } - # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs - # Smaller Gemma models use a conventional casual attention mask - if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision": - block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device) - if mm_token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device) + text_config = config.get_text_config() + use_bidir = getattr(text_config, "use_bidirectional_attention", None) == "vision" - mask_kwargs["block_sequence_ids"] = block_sequence_ids + if use_bidir and mm_token_type_ids is not None: + block_sequence_ids = get_block_sequence_ids_for_mask( + mm_token_type_ids, device=inputs_embeds.device + ) + return create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, + ) return create_masks_for_generate(**mask_kwargs) diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index bb34be7e92e5..b7b0f4df3750 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -427,6 +427,54 @@ def test_flash_attn_3_from_config(self): def test_flash_attn_4_from_config(self): self.flash_attn_from_config(attn_implementation="flash_attention_4", test_fwd_in_train=False) + def test_attention_mask_composition(self): + from transformers.models.gemma3.modeling_gemma3 import create_masks_for_vision_model + + config = self.model_tester.get_config() + config.text_config._attn_implementation = "eager" + + # Override sliding window to a known small value to test truncation + sliding_window = 4 + config.text_config.sliding_window = sliding_window + + # Create a sequence of 13 tokens: 0..4 text, 5..11 image (7 tokens), 12 text + # block_sequence_ids maps image tokens to group 0, and text tokens to -1 + block_sequence_ids = torch.tensor([[-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, -1]], dtype=torch.long) + attention_mask = torch.ones((1, 13), dtype=torch.bool) + position_ids = torch.arange(13).unsqueeze(0) + inputs_embeds = torch.randn(1, 13, config.text_config.hidden_size) + + mask_dict = create_masks_for_vision_model( + config=config.text_config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=None, + position_ids=position_ids, + block_sequence_ids=block_sequence_ids, + ) + + full_mask = mask_dict["full_attention"] + sliding_mask = mask_dict["sliding_attention"] + + min_val = torch.finfo(full_mask.dtype).min + + # In full_attention, bidirectional image block (tokens 5-11) is fully visible + # (Both look-back and look-ahead are 0.0) + self.assertEqual(full_mask[0, 0, 5, 11].item(), 0.0) + self.assertEqual(full_mask[0, 0, 11, 5].item(), 0.0) + + # In sliding_attention, look-back within the sliding window is visible bidirectionally + # Token 8 looking back at 5 (dist 3 < 4) -> VISIBLE + self.assertEqual(sliding_mask[0, 0, 8, 5].item(), 0.0) + + # In sliding_attention, look-back outside the sliding window is strictly masked + # Token 11 looking back at 5 (dist 6 > 4) -> MASKED + self.assertLess(sliding_mask[0, 0, 11, 5].item(), -1000) + + # Verify that causal masking still applies correctly to text + # Token 12 (text) looking ahead to Token 11 (image) is masked + self.assertLess(full_mask[0, 0, 11, 12].item(), -1000) + @slow @require_torch_accelerator diff --git a/tests/models/gemma4/test_modeling_gemma4.py b/tests/models/gemma4/test_modeling_gemma4.py index c9e89c52f42c..2cf16dcf544f 100644 --- a/tests/models/gemma4/test_modeling_gemma4.py +++ b/tests/models/gemma4/test_modeling_gemma4.py @@ -601,6 +601,54 @@ def count_calls(*args, **kwargs): _ = model(inputs_embeds=inputs_embeds) self.assertEqual(counter["call_count"], 1) + def test_attention_mask_composition(self): + from transformers.models.gemma4.modeling_gemma4 import create_masks_for_vision_model + + config = self.model_tester.get_config() + config.text_config._attn_implementation = "eager" + + # Override sliding window to a known small value to test truncation + sliding_window = 4 + config.text_config.sliding_window = sliding_window + + # Create a sequence of 13 tokens: 0..4 text, 5..11 image (7 tokens), 12 text + # block_sequence_ids maps image tokens to group 0, and text tokens to -1 + block_sequence_ids = torch.tensor([[-1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, -1]], dtype=torch.long) + attention_mask = torch.ones((1, 13), dtype=torch.bool) + position_ids = torch.arange(13).unsqueeze(0) + inputs_embeds = torch.randn(1, 13, config.text_config.hidden_size) + + mask_dict = create_masks_for_vision_model( + config=config.text_config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=None, + position_ids=position_ids, + block_sequence_ids=block_sequence_ids, + ) + + full_mask = mask_dict["full_attention"] + sliding_mask = mask_dict["sliding_attention"] + + min_val = torch.finfo(full_mask.dtype).min + + # In full_attention, bidirectional image block (tokens 5-11) is fully visible + # (Both look-back and look-ahead are 0.0) + self.assertEqual(full_mask[0, 0, 5, 11].item(), 0.0) + self.assertEqual(full_mask[0, 0, 11, 5].item(), 0.0) + + # In sliding_attention, look-back within the sliding window is visible bidirectionally + # Token 8 looking back at 5 (dist 3 < 4) -> VISIBLE + self.assertEqual(sliding_mask[0, 0, 8, 5].item(), 0.0) + + # In sliding_attention, look-back outside the sliding window is strictly masked + # Token 11 looking back at 5 (dist 6 > 4) -> MASKED + self.assertLess(sliding_mask[0, 0, 11, 5].item(), -1000) + + # Verify that causal masking still applies correctly to text + # Token 12 (text) looking ahead to Token 11 (image) is masked + self.assertLess(full_mask[0, 0, 11, 12].item(), -1000) + @slow @require_torch_accelerator From 77876736b8c9243220a8f1b7a51ff22fa465a4c8 Mon Sep 17 00:00:00 2001 From: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Date: Tue, 23 Jun 2026 22:34:26 +0000 Subject: [PATCH 2/4] run make fix-repo --- src/transformers/models/gemma3/modular_gemma3.py | 8 ++------ src/transformers/models/gemma4/modular_gemma4.py | 12 +++--------- tests/models/gemma3/test_modeling_gemma3.py | 2 +- tests/models/gemma4/test_modeling_gemma4.py | 2 +- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 08aa81ea57eb..abd4bc67a13c 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -725,9 +725,7 @@ def forward( } if token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask( - token_type_ids, device=inputs_embeds.device - ) + block_sequence_ids = get_block_sequence_ids_for_mask(token_type_ids, device=inputs_embeds.device) causal_mask_mapping = create_masks_for_vision_model( block_sequence_ids=block_sequence_ids, **mask_kwargs, @@ -930,9 +928,7 @@ def create_masks_for_generate( } if token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask( - token_type_ids, device=inputs_embeds.device - ) + block_sequence_ids = get_block_sequence_ids_for_mask(token_type_ids, device=inputs_embeds.device) return create_masks_for_vision_model( block_sequence_ids=block_sequence_ids, **mask_kwargs, diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py index f39d8292e517..754a7a52a9ff 100644 --- a/src/transformers/models/gemma4/modular_gemma4.py +++ b/src/transformers/models/gemma4/modular_gemma4.py @@ -27,12 +27,10 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...masking_utils import ( - blockwise_overlay, create_bidirectional_mask, create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask, - sliding_window_overlay, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling @@ -58,8 +56,8 @@ Gemma3RotaryEmbedding, Gemma3TextModel, Gemma3TextScaledWordEmbedding, + create_masks_for_vision_model, # noqa: F811 ) -from ..gemma3.modeling_gemma3 import create_masks_for_vision_model # noqa: F811 from ..gemma3n.modeling_gemma3n import ( Gemma3nCausalLMOutputWithPast, Gemma3nForConditionalGeneration, @@ -2067,9 +2065,7 @@ def forward( use_bidir = text_config.use_bidirectional_attention == "vision" if use_bidir and mm_token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask( - mm_token_type_ids, device=inputs_embeds.device - ) + block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device) causal_mask_mapping = create_masks_for_vision_model( block_sequence_ids=block_sequence_ids, **mask_kwargs, @@ -2258,9 +2254,7 @@ def create_masks_for_generate( use_bidir = getattr(text_config, "use_bidirectional_attention", None) == "vision" if use_bidir and mm_token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask( - mm_token_type_ids, device=inputs_embeds.device - ) + block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device) return create_masks_for_vision_model( block_sequence_ids=block_sequence_ids, **mask_kwargs, diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index b7b0f4df3750..18727f2d25b4 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -432,7 +432,7 @@ def test_attention_mask_composition(self): config = self.model_tester.get_config() config.text_config._attn_implementation = "eager" - + # Override sliding window to a known small value to test truncation sliding_window = 4 config.text_config.sliding_window = sliding_window diff --git a/tests/models/gemma4/test_modeling_gemma4.py b/tests/models/gemma4/test_modeling_gemma4.py index 2cf16dcf544f..3a944bf9e41e 100644 --- a/tests/models/gemma4/test_modeling_gemma4.py +++ b/tests/models/gemma4/test_modeling_gemma4.py @@ -606,7 +606,7 @@ def test_attention_mask_composition(self): config = self.model_tester.get_config() config.text_config._attn_implementation = "eager" - + # Override sliding window to a known small value to test truncation sliding_window = 4 config.text_config.sliding_window = sliding_window From 77908536a1cd179ddfc12124c2c61baf6ca066a3 Mon Sep 17 00:00:00 2001 From: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Date: Tue, 23 Jun 2026 22:50:02 +0000 Subject: [PATCH 3/4] remove unused vars in test --- tests/models/gemma3/test_modeling_gemma3.py | 2 -- tests/models/gemma4/test_modeling_gemma4.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 18727f2d25b4..3f7066847427 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -456,8 +456,6 @@ def test_attention_mask_composition(self): full_mask = mask_dict["full_attention"] sliding_mask = mask_dict["sliding_attention"] - min_val = torch.finfo(full_mask.dtype).min - # In full_attention, bidirectional image block (tokens 5-11) is fully visible # (Both look-back and look-ahead are 0.0) self.assertEqual(full_mask[0, 0, 5, 11].item(), 0.0) diff --git a/tests/models/gemma4/test_modeling_gemma4.py b/tests/models/gemma4/test_modeling_gemma4.py index 3a944bf9e41e..e1b290800474 100644 --- a/tests/models/gemma4/test_modeling_gemma4.py +++ b/tests/models/gemma4/test_modeling_gemma4.py @@ -630,8 +630,6 @@ def test_attention_mask_composition(self): full_mask = mask_dict["full_attention"] sliding_mask = mask_dict["sliding_attention"] - min_val = torch.finfo(full_mask.dtype).min - # In full_attention, bidirectional image block (tokens 5-11) is fully visible # (Both look-back and look-ahead are 0.0) self.assertEqual(full_mask[0, 0, 5, 11].item(), 0.0) From 3e6f07da6db0191a132773d4de1e831ef3fa1472 Mon Sep 17 00:00:00 2001 From: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Date: Tue, 23 Jun 2026 23:02:39 +0000 Subject: [PATCH 4/4] update unified model as well --- .../gemma4_unified/modeling_gemma4_unified.py | 63 ++++++++++++++++--- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/gemma4_unified/modeling_gemma4_unified.py b/src/transformers/models/gemma4_unified/modeling_gemma4_unified.py index c520b997cb06..7d4e05416508 100644 --- a/src/transformers/models/gemma4_unified/modeling_gemma4_unified.py +++ b/src/transformers/models/gemma4_unified/modeling_gemma4_unified.py @@ -30,7 +30,13 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PreTrainedConfig from ...generation import GenerationMixin -from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask +from ...masking_utils import ( + blockwise_overlay, + create_causal_mask, + create_masks_for_generate, + create_sliding_window_causal_mask, + sliding_window_overlay, +) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput @@ -1188,6 +1194,46 @@ def get_video_features( return self.get_image_features(pixel_values_videos, video_position_ids, **kwargs) +def create_masks_for_vision_model( + config: PreTrainedConfig, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None, + position_ids: torch.Tensor | None, + block_sequence_ids: torch.Tensor, +) -> dict: + """Create full_attention and sliding_attention masks with correct composition. + + For global (full attention) layers: OR(causal, blockwise) + For local (sliding window) layers: AND(sliding_window, OR(causal, blockwise)) + """ + mask_kwargs = { + "config": config, + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + + # Full attention: OR(causal, blockwise) — use block_sequence_ids directly + full_mask = create_causal_mask(**mask_kwargs, block_sequence_ids=block_sequence_ids) + + # Sliding attention: AND(sliding_window, OR(causal, blockwise)) + # Pass blockwise as or_mask_function (applied as step 2 in create_causal_mask) + # Pass sliding_window as and_mask_function (applied as step 3, after OR) + # Do NOT pass block_sequence_ids (to avoid the incorrect step 4 final OR) + sliding_mask = create_causal_mask( + **mask_kwargs, + or_mask_function=blockwise_overlay(block_sequence_ids), + and_mask_function=sliding_window_overlay(config.sliding_window), + ) + + return { + "full_attention": full_mask, + "sliding_attention": sliding_mask, + } + + @auto_docstring( custom_intro=""" The base Gemma 4 model comprising a vision backbone, an audio backbone, a language model, and a language modeling @@ -1356,14 +1402,15 @@ def create_masks_for_generate( "position_ids": position_ids, } - # Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs - # Smaller Gemma models use a conventional casual attention mask - if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision": - block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device) - if mm_token_type_ids is not None: - block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device) + text_config = config.get_text_config() + use_bidir = getattr(text_config, "use_bidirectional_attention", None) == "vision" - mask_kwargs["block_sequence_ids"] = block_sequence_ids + if use_bidir and mm_token_type_ids is not None: + block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device) + return create_masks_for_vision_model( + block_sequence_ids=block_sequence_ids, + **mask_kwargs, + ) return create_masks_for_generate(**mask_kwargs)