Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 57 additions & 12 deletions src/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...masking_utils import (
blockwise_overlay,
create_causal_mask,
create_masks_for_generate,
create_sliding_window_causal_mask,
sliding_window_overlay,
)
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
Expand Down Expand Up @@ -709,6 +715,46 @@ def get_block_sequence_ids_for_mask(token_type_ids: torch.Tensor, device: torch.
return block_sequence_ids


def create_masks_for_vision_model(
config: PreTrainedConfig,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor | None,
past_key_values: Cache | None,
position_ids: torch.Tensor | None,
block_sequence_ids: torch.Tensor,
) -> dict:
"""Create full_attention and sliding_attention masks with correct composition.

For global (full attention) layers: OR(causal, blockwise)
For local (sliding window) layers: AND(sliding_window, OR(causal, blockwise))
"""
mask_kwargs = {
"config": config,
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}

# Full attention: OR(causal, blockwise) — use block_sequence_ids directly
full_mask = create_causal_mask(**mask_kwargs, block_sequence_ids=block_sequence_ids)

# Sliding attention: AND(sliding_window, OR(causal, blockwise))
# Pass blockwise as or_mask_function (applied as step 2 in create_causal_mask)
# Pass sliding_window as and_mask_function (applied as step 3, after OR)
# Do NOT pass block_sequence_ids (to avoid the incorrect step 4 final OR)
sliding_mask = create_causal_mask(
**mask_kwargs,
or_mask_function=blockwise_overlay(block_sequence_ids),
and_mask_function=sliding_window_overlay(config.sliding_window),
)
Comment on lines +746 to +750

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

passing or/and functions will trigger torch.vmap path, which might add some overhead. Gemma models are the only ones using blockwise_overlay for sliding attention

I am seeing we updated all except for paligemma, does that mean paligemma-2 needs to keep full-mask in prefix for sliding layers?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking that we might change the internal create_sliding_window_mask if the pattern is used only by gemma models and in the same way


return {
"full_attention": full_mask,
"sliding_attention": sliding_mask,
}


@auto_docstring(
custom_intro="""
The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
Expand Down Expand Up @@ -841,16 +887,13 @@ def forward(
}

if token_type_ids is not None:
mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
token_type_ids, device=inputs_embeds.device
block_sequence_ids = get_block_sequence_ids_for_mask(token_type_ids, device=inputs_embeds.device)
causal_mask_mapping = create_masks_for_vision_model(
block_sequence_ids=block_sequence_ids,
**mask_kwargs,
)

# Create the masks
sliding_mask_kwargs = mask_kwargs.copy()
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
}
else:
causal_mask_mapping = create_masks_for_generate(**mask_kwargs)

outputs = self.language_model(
attention_mask=causal_mask_mapping,
Expand Down Expand Up @@ -1064,8 +1107,10 @@ def create_masks_for_generate(
}

if token_type_ids is not None:
mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
token_type_ids, device=inputs_embeds.device
block_sequence_ids = get_block_sequence_ids_for_mask(token_type_ids, device=inputs_embeds.device)
return create_masks_for_vision_model(
block_sequence_ids=block_sequence_ids,
**mask_kwargs,
)

return create_masks_for_generate(**mask_kwargs)
Expand Down
69 changes: 57 additions & 12 deletions src/transformers/models/gemma3/modular_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from ... import initialization as init
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from ...masking_utils import (
blockwise_overlay,
create_causal_mask,
create_masks_for_generate,
create_sliding_window_causal_mask,
sliding_window_overlay,
)
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, SequenceClassifierOutputWithPast
from ...modeling_rope_utils import (
Expand Down Expand Up @@ -611,6 +617,46 @@ def get_block_sequence_ids_for_mask(token_type_ids: torch.Tensor, device: torch.
return block_sequence_ids


def create_masks_for_vision_model(
config: PreTrainedConfig,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor | None,
past_key_values: Cache | None,
position_ids: torch.Tensor | None,
block_sequence_ids: torch.Tensor,
) -> dict:
"""Create full_attention and sliding_attention masks with correct composition.

For global (full attention) layers: OR(causal, blockwise)
For local (sliding window) layers: AND(sliding_window, OR(causal, blockwise))
"""
mask_kwargs = {
"config": config,
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}

# Full attention: OR(causal, blockwise) — use block_sequence_ids directly
full_mask = create_causal_mask(**mask_kwargs, block_sequence_ids=block_sequence_ids)

# Sliding attention: AND(sliding_window, OR(causal, blockwise))
# Pass blockwise as or_mask_function (applied as step 2 in create_causal_mask)
# Pass sliding_window as and_mask_function (applied as step 3, after OR)
# Do NOT pass block_sequence_ids (to avoid the incorrect step 4 final OR)
sliding_mask = create_causal_mask(
**mask_kwargs,
or_mask_function=blockwise_overlay(block_sequence_ids),
and_mask_function=sliding_window_overlay(config.sliding_window),
)

return {
"full_attention": full_mask,
"sliding_attention": sliding_mask,
}


class Gemma3Model(PaliGemmaModel):
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
accepts_loss_kwargs = False
Expand Down Expand Up @@ -679,16 +725,13 @@ def forward(
}

if token_type_ids is not None:
mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
token_type_ids, device=inputs_embeds.device
block_sequence_ids = get_block_sequence_ids_for_mask(token_type_ids, device=inputs_embeds.device)
causal_mask_mapping = create_masks_for_vision_model(
block_sequence_ids=block_sequence_ids,
**mask_kwargs,
)

# Create the masks
sliding_mask_kwargs = mask_kwargs.copy()
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
}
else:
causal_mask_mapping = create_masks_for_generate(**mask_kwargs)

outputs = self.language_model(
attention_mask=causal_mask_mapping,
Expand Down Expand Up @@ -885,8 +928,10 @@ def create_masks_for_generate(
}

if token_type_ids is not None:
mask_kwargs["block_sequence_ids"] = get_block_sequence_ids_for_mask(
token_type_ids, device=inputs_embeds.device
block_sequence_ids = get_block_sequence_ids_for_mask(token_type_ids, device=inputs_embeds.device)
return create_masks_for_vision_model(
block_sequence_ids=block_sequence_ids,
**mask_kwargs,
)

return create_masks_for_generate(**mask_kwargs)
Expand Down
81 changes: 62 additions & 19 deletions src/transformers/models/gemma4/modeling_gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
from ...generation import GenerationMixin
from ...integrations import use_experts_implementation
from ...masking_utils import (
blockwise_overlay,
create_bidirectional_mask,
create_causal_mask,
create_masks_for_generate,
create_sliding_window_causal_mask,
sliding_window_overlay,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
Expand Down Expand Up @@ -2103,6 +2105,46 @@ def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
return self.embedding_projection(embs_normed)


def create_masks_for_vision_model(
config: PreTrainedConfig,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor | None,
past_key_values: Cache | None,
position_ids: torch.Tensor | None,
block_sequence_ids: torch.Tensor,
) -> dict:
"""Create full_attention and sliding_attention masks with correct composition.

For global (full attention) layers: OR(causal, blockwise)
For local (sliding window) layers: AND(sliding_window, OR(causal, blockwise))
"""
mask_kwargs = {
"config": config,
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}

# Full attention: OR(causal, blockwise) — use block_sequence_ids directly
full_mask = create_causal_mask(**mask_kwargs, block_sequence_ids=block_sequence_ids)

# Sliding attention: AND(sliding_window, OR(causal, blockwise))
# Pass blockwise as or_mask_function (applied as step 2 in create_causal_mask)
# Pass sliding_window as and_mask_function (applied as step 3, after OR)
# Do NOT pass block_sequence_ids (to avoid the incorrect step 4 final OR)
sliding_mask = create_causal_mask(
**mask_kwargs,
or_mask_function=blockwise_overlay(block_sequence_ids),
and_mask_function=sliding_window_overlay(config.sliding_window),
)

return {
"full_attention": full_mask,
"sliding_attention": sliding_mask,
}


def get_block_sequence_ids_for_mask(mm_token_type_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
mm_token_type_ids = mm_token_type_ids.to(device)

Expand Down Expand Up @@ -2346,19 +2388,19 @@ def forward(
"position_ids": position_ids,
}

# Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
# Smaller Gemma models use a conventional casual attention mask
if self.config.get_text_config().use_bidirectional_attention == "vision":
block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device)
if mm_token_type_ids is not None:
block_sequence_ids = get_block_sequence_ids_for_mask(
mm_token_type_ids, device=inputs_embeds.device
)
text_config = self.config.get_text_config()
use_bidir = text_config.use_bidirectional_attention == "vision"

mask_kwargs["block_sequence_ids"] = block_sequence_ids

# Create the masks
causal_mask_mapping = create_masks_for_generate(**mask_kwargs)
if use_bidir and mm_token_type_ids is not None:
block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device)
causal_mask_mapping = create_masks_for_vision_model(
block_sequence_ids=block_sequence_ids,
**mask_kwargs,
)
else:
# Smaller Gemma models (use_bidirectional_attention=None) or
# text-only inputs use standard causal masking
causal_mask_mapping = create_masks_for_generate(**mask_kwargs)

outputs = self.language_model(
per_layer_inputs=per_layer_inputs,
Expand Down Expand Up @@ -2623,14 +2665,15 @@ def create_masks_for_generate(
"position_ids": position_ids,
}

# Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
# Smaller Gemma models use a conventional casual attention mask
if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision":
block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device)
if mm_token_type_ids is not None:
block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device)
text_config = config.get_text_config()
use_bidir = getattr(text_config, "use_bidirectional_attention", None) == "vision"

mask_kwargs["block_sequence_ids"] = block_sequence_ids
if use_bidir and mm_token_type_ids is not None:
block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device)
return create_masks_for_vision_model(
block_sequence_ids=block_sequence_ids,
**mask_kwargs,
)

return create_masks_for_generate(**mask_kwargs)

Expand Down
40 changes: 21 additions & 19 deletions src/transformers/models/gemma4/modular_gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
Gemma3RotaryEmbedding,
Gemma3TextModel,
Gemma3TextScaledWordEmbedding,
create_masks_for_vision_model, # noqa: F811
)
from ..gemma3n.modeling_gemma3n import (
Gemma3nCausalLMOutputWithPast,
Expand Down Expand Up @@ -2060,19 +2061,19 @@ def forward(
"position_ids": position_ids,
}

# Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
# Smaller Gemma models use a conventional casual attention mask
if self.config.get_text_config().use_bidirectional_attention == "vision":
block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device)
if mm_token_type_ids is not None:
block_sequence_ids = get_block_sequence_ids_for_mask(
mm_token_type_ids, device=inputs_embeds.device
)

mask_kwargs["block_sequence_ids"] = block_sequence_ids
text_config = self.config.get_text_config()
use_bidir = text_config.use_bidirectional_attention == "vision"

# Create the masks
causal_mask_mapping = create_masks_for_generate(**mask_kwargs)
if use_bidir and mm_token_type_ids is not None:
block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device)
causal_mask_mapping = create_masks_for_vision_model(
block_sequence_ids=block_sequence_ids,
**mask_kwargs,
)
else:
# Smaller Gemma models (use_bidirectional_attention=None) or
# text-only inputs use standard causal masking
causal_mask_mapping = create_masks_for_generate(**mask_kwargs)

outputs = self.language_model(
per_layer_inputs=per_layer_inputs,
Expand Down Expand Up @@ -2249,14 +2250,15 @@ def create_masks_for_generate(
"position_ids": position_ids,
}

# Larger Gemma 4 models use Gemma 3's bidirectional attention mask for vision inputs
# Smaller Gemma models use a conventional casual attention mask
if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision":
block_sequence_ids = torch.full([*inputs_embeds.size()[:-1]], -1, device=inputs_embeds.device)
if mm_token_type_ids is not None:
block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device)
text_config = config.get_text_config()
use_bidir = getattr(text_config, "use_bidirectional_attention", None) == "vision"

mask_kwargs["block_sequence_ids"] = block_sequence_ids
if use_bidir and mm_token_type_ids is not None:
block_sequence_ids = get_block_sequence_ids_for_mask(mm_token_type_ids, device=inputs_embeds.device)
return create_masks_for_vision_model(
block_sequence_ids=block_sequence_ids,
**mask_kwargs,
)

return create_masks_for_generate(**mask_kwargs)

Expand Down
Loading
Loading