diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3921898051d9..ea3f05c14d5d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1235,6 +1235,8 @@ title: GlmOcr - local: model_doc/got_ocr2 title: GOT-OCR2 + - local: model_doc/granite4_vision + title: Granite4Vision - local: model_doc/granitevision title: GraniteVision - local: model_doc/grounding-dino diff --git a/docs/source/en/model_doc/granite4_vision.md b/docs/source/en/model_doc/granite4_vision.md new file mode 100644 index 000000000000..da5090c251ba --- /dev/null +++ b/docs/source/en/model_doc/granite4_vision.md @@ -0,0 +1,185 @@ + +*This model was released on 2026-03-27 and added to Hugging Face Transformers on 2026-05-03.* + +
+<div style="float: right;">
+    <div class="flex flex-wrap space-x-1">
+        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
+        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
+    </div>
+</div>
+
+# Granite4Vision
+
+[Granite Vision 4.1](https://huggingface.co/ibm-granite/granite-vision-4.1-4b) is a vision-language model from IBM Research designed for enterprise-grade document data extraction. It specializes in chart extraction (Chart2CSV, Chart2Summary, Chart2Code), table extraction (JSON, HTML, OTSL), and semantic key-value pair extraction.
+
+The model builds on [LLaVA-NeXT](llava_next) with several architectural innovations:
+
+1. **SigLIP2 Vision Encoder** ([`google/siglip2-so400m-patch16-384`](https://huggingface.co/google/siglip2-so400m-patch16-384)): images are tiled into 384x384 patches.
+2. **Window Q-Former Projectors**: compress visual features 4x using windowed cross-attention over 4x4 patch windows into 2x2 tokens.
+3. **DeepStack Feature Injection** with 8 vision-to-LLM injection points:
+   - *LayerDeepstack*: features from 4 vision encoder depths are projected into different early LLM layers.
+   - *SpatialDeepstack*: the deepest vision features are split into 4 spatial groups and injected at later LLM layers.
+4. **Language Model**: [Granite 4.1](https://huggingface.co/ibm-granite/granite-4.1-4b-base) (4B params) with LoRA adapters (rank 256) across all self-attention and MLP layers.
+
+The model is delivered as a LoRA adapter on top of the base LLM, enabling a single deployment to serve both multimodal and text-only workloads. Total parameter count is ~4B.
+
+> [!TIP]
+> This model was contributed by the [IBM Granite Vision Team](https://github.com/ibm-granite).
+
+## Usage Tips
+
+- Set `padding_side="left"` during batched generation for more accurate results.
+
+```py
+processor.tokenizer.padding_side = "left"
+```
+
+- The model supports specialized task tags for document extraction covering the chart, table, and key-value tasks listed above. Pass the appropriate tag as the text prompt along with a document image; refer to the model card for the exact tag strings.
+
+- For key-value pair extraction, provide a JSON schema describing the fields to extract. The model returns structured JSON matching the schema (see the sketch after the examples below).
+
+The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
+
+<hfoptions id="usage">
+<hfoption id="Pipeline">
+
+```python
+from transformers import pipeline
+
+pipe = pipeline(
+    task="image-text-to-text",
+    model="ibm-granite/granite-vision-4.1-4b",
+)
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+            {"type": "text", "text": "Describe this image."},
+        ],
+    }
+]
+pipe(text=messages, max_new_tokens=100, return_full_text=False)
+```
+
+</hfoption>
+<hfoption id="AutoModel">
+
+```python
+import torch
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+model_id = "ibm-granite/granite-vision-4.1-4b"
+
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForImageTextToText.from_pretrained(model_id)
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+            {"type": "text", "text": "Describe this image."},
+        ],
+    },
+]
+inputs = processor.apply_chat_template(
+    conversation,
+    add_generation_prompt=True,
+    tokenize=True,
+    return_dict=True,
+    return_tensors="pt",
+).to(model.device)
+
+output = model.generate(**inputs, max_new_tokens=100)
+print(processor.decode(output[0], skip_special_tokens=True))
+```
+
+</hfoption>
+</hfoptions>
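+
+The sketch below adapts the [`AutoModel`] example above to key-value pair extraction. The schema and prompt wording are illustrative only; check the [model card](https://huggingface.co/ibm-granite/granite-vision-4.1-4b) for the exact task tags and prompt format the checkpoint expects.
+
+```python
+import json
+
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+model_id = "ibm-granite/granite-vision-4.1-4b"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForImageTextToText.from_pretrained(model_id)
+
+# Hypothetical schema describing the fields to extract; adjust it to your document.
+schema = {"invoice_number": "string", "issue_date": "string", "total_amount": "string"}
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            # Replace the URL with your own document image.
+            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+            {"type": "text", "text": f"Extract the following fields and return them as JSON: {json.dumps(schema)}"},
+        ],
+    },
+]
+inputs = processor.apply_chat_template(
+    conversation,
+    add_generation_prompt=True,
+    tokenize=True,
+    return_dict=True,
+    return_tensors="pt",
+).to(model.device)
+
+output = model.generate(**inputs, max_new_tokens=256)
+print(processor.decode(output[0], skip_special_tokens=True))
+```
+
+Quantization reduces the memory burden of large models by representing the weights in a lower precision.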
Refer to the [Quantization](../quantization/overview) overview for more available quantization backends. + +The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4. + +```python +import torch +from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig + +quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_quant_type="nf4", +) + +model_id = "ibm-granite/granite-vision-4.1-4b" +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForImageTextToText.from_pretrained( + model_id, quantization_config=quant_config, device_map="auto" +) + +conversation = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"}, + {"type": "text", "text": "Describe this image."}, + ], + }, +] +inputs = processor.apply_chat_template( + conversation, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", +).to(model.device) + +output = model.generate(**inputs, max_new_tokens=100) +print(processor.decode(output[0], skip_special_tokens=True)) +``` + +## Granite4VisionConfig + +[[autodoc]] Granite4VisionConfig + +## Granite4VisionTextConfig + +[[autodoc]] Granite4VisionTextConfig + +## Granite4VisionProcessor + +[[autodoc]] Granite4VisionProcessor + - __call__ + +## Granite4VisionModel + +[[autodoc]] Granite4VisionModel + +## Granite4VisionTextModel + +[[autodoc]] Granite4VisionTextModel + +## Granite4VisionForConditionalGeneration + +[[autodoc]] Granite4VisionForConditionalGeneration + - forward + - get_image_features diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7fd1bbacf90b..887ad30183c3 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -178,6 +178,7 @@ from .gpt_sw3 import * from .gptj import * from .granite import * + from .granite4_vision import * from .granite_speech import * from .granite_speech_plus import * from .granitemoe import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 0b2d7dd79167..53755949645b 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -235,6 +235,7 @@ ("gpt_oss", "GptOssConfig"), ("gptj", "GPTJConfig"), ("granite", "GraniteConfig"), + ("granite4_vision", "Granite4VisionConfig"), ("granite_speech", "GraniteSpeechConfig"), ("granite_speech_encoder", "GraniteSpeechEncoderConfig"), ("granite_speech_plus", "GraniteSpeechPlusConfig"), @@ -892,6 +893,7 @@ ("glm_image", {"pil": "GlmImageImageProcessorPil", "torchvision": "GlmImageImageProcessor"}), ("glpn", {"pil": "GLPNImageProcessorPil", "torchvision": "GLPNImageProcessor"}), ("got_ocr2", {"pil": "GotOcr2ImageProcessorPil", "torchvision": "GotOcr2ImageProcessor"}), + ("granite4_vision", {"pil": "LlavaNextImageProcessorPil", "torchvision": "LlavaNextImageProcessor"}), ("grounding-dino", {"pil": "GroundingDinoImageProcessorPil", "torchvision": "GroundingDinoImageProcessor"}), ("idefics", {"pil": "IdeficsImageProcessorPil", "torchvision": "IdeficsImageProcessor"}), ("idefics2", {"pil": "Idefics2ImageProcessorPil", "torchvision": "Idefics2ImageProcessor"}), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 98447b6d1724..c1a46661aecf 100644 --- 
a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -88,6 +88,7 @@ ("focalnet", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), ("gemma3n", {"torchvision": "SiglipImageProcessor", "pil": "SiglipImageProcessorPil"}), ("git", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), + ("granite4_vision", {"torchvision": "LlavaNextImageProcessor", "pil": "LlavaNextImageProcessorPil"}), ("groupvit", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}), ("hiera", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), ("ijepa", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 67dc6cd64cce..5fb61fcf53b6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -211,6 +211,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gpt_oss", "GptOssModel"), ("gptj", "GPTJModel"), ("granite", "GraniteModel"), + ("granite4_vision", "Granite4VisionModel"), ("granite_speech", "GraniteSpeechForConditionalGeneration"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), @@ -995,6 +996,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("glm4v_moe", "Glm4vMoeForConditionalGeneration"), ("glm_ocr", "GlmOcrForConditionalGeneration"), ("got_ocr2", "GotOcr2ForConditionalGeneration"), + ("granite4_vision", "Granite4VisionForConditionalGeneration"), ("idefics", "IdeficsForVisionText2Text"), ("idefics2", "Idefics2ForConditionalGeneration"), ("idefics3", "Idefics3ForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 24008bd4263a..c6a7b3fd1296 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -89,6 +89,7 @@ ("glm_image", "Glm4vProcessor"), ("glmasr", "GlmAsrProcessor"), ("got_ocr2", "GotOcr2Processor"), + ("granite4_vision", "Granite4VisionProcessor"), ("granite_speech", "GraniteSpeechProcessor"), ("granite_speech_plus", "GraniteSpeechProcessor"), ("grounding-dino", "GroundingDinoProcessor"), diff --git a/src/transformers/models/granite4_vision/__init__.py b/src/transformers/models/granite4_vision/__init__.py new file mode 100644 index 000000000000..113694a1a26c --- /dev/null +++ b/src/transformers/models/granite4_vision/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 IBM. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_granite4_vision import * + from .modeling_granite4_vision import * + from .processing_granite4_vision import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/granite4_vision/configuration_granite4_vision.py b/src/transformers/models/granite4_vision/configuration_granite4_vision.py new file mode 100644 index 000000000000..9b98bd64e152 --- /dev/null +++ b/src/transformers/models/granite4_vision/configuration_granite4_vision.py @@ -0,0 +1,209 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granite4_vision/modular_granite4_vision.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granite4_vision.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Literal + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring +from ..auto import CONFIG_MAPPING, AutoConfig + + +@auto_docstring(checkpoint="ibm-granite4_vision_text/granite4_vision_text-3.0-8b-base") +@strict +class Granite4VisionTextConfig(PreTrainedConfig): + r""" + ```python + >>> from transformers import Granite4VisionTextModel, Granite4VisionTextConfig + + >>> # Initializing a Granite4VisionText granite4_vision_text-3b style configuration + >>> configuration = Granite4VisionTextConfig() + + >>> # Initializing a model from the granite4_vision_text-7b style configuration + >>> model = Granite4VisionTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "granite4_vision_text" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Granite4VisionTextModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + vocab_size: int = 32000 + hidden_size: int = 4096 + intermediate_size: int = 11008 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int | None = None + hidden_act: str = "silu" + max_position_embeddings: int = 2048 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int | None = None + bos_token_id: int | None = 1 + eos_token_id: int | list[int] | None = 2 + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + attention_bias: bool = False + attention_dropout: float | int = 0.0 + mlp_bias: bool = False + embedding_multiplier: float | int = 1.0 + logits_scaling: float | int = 1.0 + residual_multiplier: float | int = 1.0 + attention_multiplier: float | int = 1.0 + base_config_key = "text_config" + + def __post_init__(self, **kwargs): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + super().__post_init__(**kwargs) + + +@auto_docstring(checkpoint="llava-hf/llava-v1.6-mistral-7b-hf") +@strict +class Granite4VisionConfig(PreTrainedConfig): + r""" + downsample_rate (`str`, *optional*): + Fractional downsample rate for the Window Q-Former projector, e.g. `"1/4"` or `"3/8"`. + The numerator is the query window side, the denominator is the key window side. + deepstack_layer_map (`list`, *optional*): + List of `[vision_layer_idx, llm_layer_idx]` pairs. Features from each vision encoder layer + are projected and injected at the corresponding LLM decoder layer during forward pass. + use_spatial_sampling (`bool`, *optional*, defaults to `False`): + Whether to enable spatial offset sampling, which creates 4 groups (TL, TR, BL, BR) from + a single vision layer, each injected at a different LLM layer. + spatial_vision_layer (`int`, *optional*, defaults to `-1`): + Index of the vision encoder layer used for spatial sampling. 
+ spatial_target_layers (`list`, *optional*, defaults to `[12, 15, 18, 21]`): + Target LLM layers for the 4 spatial offset groups. + projector_dropout (`float`, *optional*, defaults to `0.1`): + Dropout probability in the Window Q-Former projector. + qformer_config (`dict` or `Blip2QFormerConfig`, *optional*): + Configuration for the Window Q-Former projector. If `None`, defaults are derived from + `vision_config.hidden_size`. + image_grid_pinpoints (`list`, *optional*): + A list of possible resolutions to use for processing high resolution images. Each item in the list should be a + tuple or list of the form `(height, width)`. + """ + + model_type = "granite4_vision" + attribute_map = {"image_token_id": "image_token_index"} + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "qformer_config": AutoConfig} + + vision_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + image_token_index: int = 32000 + vision_feature_select_strategy: Literal["default", "full"] = "default" + vision_feature_layer: int | list[int] = -2 + tie_word_embeddings: bool = False + image_grid_pinpoints: list | None = None + image_seq_length: int = 576 + + downsample_rate: str | None = None + deepstack_layer_map: list | None = None + use_spatial_sampling: bool = False + spatial_vision_layer: int = -1 + spatial_target_layers: list | None = None + projector_dropout: float = 0.1 + qformer_config: dict | PreTrainedConfig | None = None + + def __post_init__(self, **kwargs): + if self.deepstack_layer_map is not None: + self.deepstack_layer_map = [(int(v), int(l)) for v, l in self.deepstack_layer_map] + + if self.spatial_target_layers is None: + self.spatial_target_layers = [12, 15, 18, 21] + + # Peek at vision hidden_size before super() to build a fully-specified qformer_config, + # avoiding any runtime field patching after super(). + if isinstance(self.vision_config, dict): + vision_hidden_size = self.vision_config.get("hidden_size", 1152) + elif self.vision_config is not None: + vision_hidden_size = self.vision_config.hidden_size + else: + vision_hidden_size = 1152 + + # Convert qformer_config dict → object before super() so _attn_implementation.setter + # (called inside super().__post_init__) sees a config object, not a raw dict. 
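+        # If no qformer_config is provided, the fallback below derives the Q-Former dimensions from the
+        # vision encoder: e.g. with the default vision hidden_size of 1152 this gives hidden_size=1152,
+        # encoder_hidden_size=1152 and 1152 // 64 = 18 attention heads.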
+ if isinstance(self.qformer_config, dict): + model_type = self.qformer_config.get("model_type", "blip_2_qformer") + self.qformer_config = CONFIG_MAPPING[model_type](**self.qformer_config) + elif self.qformer_config is None: + self.qformer_config = CONFIG_MAPPING["blip_2_qformer"]( + num_hidden_layers=1, + intermediate_size=3072, + cross_attention_frequency=1, + max_position_embeddings=2048, + use_qformer_text_input=False, + hidden_size=vision_hidden_size, + num_attention_heads=vision_hidden_size // 64, + encoder_hidden_size=vision_hidden_size, + ) + if isinstance(self.vision_config, dict): + self.vision_config["model_type"] = self.vision_config.get("model_type", "clip_vision_model") + self.vision_config = CONFIG_MAPPING[self.vision_config["model_type"]](**self.vision_config) + elif self.vision_config is None: + self.vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + if isinstance(self.text_config, dict): + self.text_config["model_type"] = self.text_config.get("model_type", "llama") + self.text_config = CONFIG_MAPPING[self.text_config["model_type"]](**self.text_config) + elif self.text_config is None: + self.text_config = CONFIG_MAPPING["llama"]() + + self.image_grid_pinpoints = ( + self.image_grid_pinpoints + if self.image_grid_pinpoints is not None + else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] + ) + + super().__post_init__(**kwargs) + + +__all__ = ["Granite4VisionConfig", "Granite4VisionTextConfig"] diff --git a/src/transformers/models/granite4_vision/modeling_granite4_vision.py b/src/transformers/models/granite4_vision/modeling_granite4_vision.py new file mode 100644 index 000000000000..2e5919df5588 --- /dev/null +++ b/src/transformers/models/granite4_vision/modeling_granite4_vision.py @@ -0,0 +1,1290 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granite4_vision/modular_granite4_vision.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granite4_vision.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from dataclasses import dataclass +from fractions import Fraction +from typing import Optional + +import numpy as np +import torch +from torch import nn + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...image_processing_utils import select_best_resolution +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from ..auto import AutoModel +from .configuration_granite4_vision import Granite4VisionConfig, Granite4VisionTextConfig + + +@dataclass +class Granite4VisionModelOutputWithPast(BaseModelOutputWithPast): + """ + Args: + deepstack_features (`list[tuple[int, list[torch.Tensor]]]`, *optional*): + List of `(llm_layer_idx, packed_features)` pairs produced by the deepstack + and spatial projectors. Each entry targets one LLM decoder layer; `packed_features` + is a per-image list of tensors of shape `(num_image_tokens, hidden_size)`. + """ + + image_hidden_states: torch.FloatTensor | None = None + + deepstack_features: list | None = None + + +@dataclass +class Granite4VisionCausalLMOutputWithPast(ModelOutput): + """ + Args: + deepstack_features (`list[tuple[int, list[torch.Tensor]]]`, *optional*): + List of `(llm_layer_idx, packed_features)` pairs. See `Granite4VisionModelOutputWithPast`. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None + + deepstack_features: list | None = None + + +@dataclass +class Granite4VisionImageFeaturesOutput(BaseModelOutputWithPooling): + """ + Output of `Granite4VisionModel.get_image_features`. + + Args: + deepstack_features (`list[tuple[int, list[torch.Tensor]]]`): + List of `(llm_layer_idx, packed_features)` pairs. Each entry targets one LLM + decoder layer; `packed_features` is a per-image list of tensors of shape + `(num_image_tokens, hidden_size)`. 
+ """ + + deepstack_features: list | None = None + + +# ── Downsampling helpers ───────────────────────────────────────────────────── + + +def interpolate_downsample(image_features: torch.Tensor, orig_side: int, new_side: int) -> torch.Tensor: + """Spatial downsampling via area interpolation.""" + batch, _, channels = image_features.size() + spatial = image_features.view(batch, orig_side, orig_side, channels).permute(0, 3, 1, 2) + spatial = torch.nn.functional.interpolate(spatial, size=(new_side, new_side), mode="area") + return spatial.permute(0, 2, 3, 1).flatten(1, 2) + + +def spatial_offset_downsample(image_features: torch.Tensor, orig_side: int, offset: int = 0) -> torch.Tensor: + """Sample one position from each 2x2 block; offset selects which corner (0=TL,1=TR,2=BL,3=BR).""" + offset_h, offset_w = [(0, 0), (0, 1), (1, 0), (1, 1)][offset] + new_side = orig_side // 2 + batch, _, channels = image_features.shape + grid = image_features.reshape(batch, orig_side, orig_side, channels) + grid = grid.reshape(batch, new_side, 2, new_side, 2, channels) + return grid[:, :, offset_h, :, offset_w, :].reshape(batch, -1, channels) + + +class Granite4VisionWindowQFormerDownsampler(nn.Module): + """Window-based QFormer downsampler that processes image patches in windows.""" + + def __init__(self, config, spatial_offset=None): + super().__init__() + llm_hidden_size = config.text_config.hidden_size + vision_hidden_size = config.vision_config.hidden_size + + self.dropout = nn.Dropout(config.projector_dropout) + self._spatial_offset = spatial_offset + self._downsample_rate = config.downsample_rate + + self.qformer = AutoModel.from_config(config.qformer_config) + + self.image_side = config.vision_config.image_size // config.vision_config.patch_size + query_side_str, window_side_str = config.downsample_rate.split("/") + self.query_side, self.window_side = int(query_side_str), int(window_side_str) + self.query_length = self.query_side**2 + self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6) + self.query = nn.Parameter(torch.empty(1, self.query_length, vision_hidden_size)) + self.image_positions = nn.Parameter(torch.empty(1, self.window_side**2, vision_hidden_size)) + self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True) + + def _windowed_raster(self, features, side, window_size): + """(B, side*side, C) raster -> (B*num_win*num_win, window_size*window_size, C)""" + batch, _, channels = features.shape + num_win = side // window_size + return ( + features.view(batch, side, side, channels) + .view(batch, num_win, window_size, num_win, window_size, channels) + .transpose(2, 3) + .flatten(0, 2) + .flatten(1, 2) + ) + + def _unwindowed_raster(self, windowed_features, num_win, window_size): + """(B*num_win*num_win, window_size*window_size, C) -> (B, (num_win*window_size)^2, C)""" + batch_win, _, channels = windowed_features.shape + if batch_win % (num_win * num_win) != 0: + raise ValueError(f"Expected batch_win ({batch_win}) to be divisible by num_win^2 ({num_win**2}).") + batch = batch_win // (num_win * num_win) + side = num_win * window_size + return ( + windowed_features.view(batch, num_win, num_win, window_size, window_size, channels) + .transpose(2, 3) + .contiguous() + .view(batch, side, side, channels) + .flatten(1, 2) + ) + + def forward(self, image_features): + batch, hw, channels = image_features.shape + if self.image_side * self.image_side != hw: + raise ValueError( + f"Expected image_features with {self.image_side**2} spatial tokens, got {hw}. 
" + "Check that the vision encoder image_size and patch_size match the config." + ) + num_windows = self.image_side // self.window_side + interp_side = int(self.image_side * Fraction(self._downsample_rate)) + image_features = self.norm(image_features) + enc = self._windowed_raster(image_features, self.image_side, self.window_side) + + if self._spatial_offset is not None: + downsampled = spatial_offset_downsample(image_features, self.image_side, self._spatial_offset) + else: + downsampled = interpolate_downsample(image_features, self.image_side, interp_side) + + downsampled_side = num_windows * self.query_side + downsampled_w = self._windowed_raster(downsampled, downsampled_side, self.query_side) + + query_embeds = self.query + downsampled_w + encoder_embeds = self.dropout(enc + self.image_positions) + out_w = self.qformer( + query_embeds=query_embeds, + encoder_hidden_states=encoder_embeds, + return_dict=True, + ).last_hidden_state + + out = self._unwindowed_raster(out_w, num_win=num_windows, window_size=self.query_side) + out = self.dropout(out) + return self.out_linear(out) + + +class Granite4VisionTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Granite4VisionTextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: Granite4VisionTextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +@use_kernelized_func(apply_rotary_pos_emb) +class Granite4VisionTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Granite4VisionTextConfig, layer_idx: int | None = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.attention_multiplier + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + 
attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class Granite4VisionTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Granite4VisionTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Granite4VisionTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Granite4VisionTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Granite4VisionTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Granite4VisionTextAttention(config=config, layer_idx=layer_idx) + + self.mlp = Granite4VisionTextMLP(config) + self.input_layernorm = Granite4VisionTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Granite4VisionTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_values (`Cache`, *optional*): cached past key and value projection states + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +@auto_docstring +class Granite4VisionPreTrainedModel(PreTrainedModel): + config: Granite4VisionConfig + base_model_prefix = "model" + input_modalities = ("image", "text") + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_flex_attn = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Granite4VisionTextDecoderLayer, + "attentions": Granite4VisionTextAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + init.zeros_(module.bias) + elif isinstance(module, Granite4VisionModel): + embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) + init.normal_(module.image_newline, mean=0.0, std=embed_std) + if isinstance(module, nn.Embedding): + init.normal_( + module.weight, + mean=0.0, + std=getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range), + ) + if module.padding_idx is not None: + init.zeros_(module.weight[module.padding_idx]) + elif isinstance(module, nn.LayerNorm) or ( + hasattr(module, "weight") and hasattr(module, "variance_epsilon") and not isinstance(module, nn.Linear) + ): + init.ones_(module.weight) + if hasattr(module, "bias") and module.bias is not None: + init.zeros_(module.bias) + if isinstance(module, Granite4VisionTextRotaryEmbedding): + # Non-persistent buffers (inv_freq, original_inv_freq) are replaced with + # torch.empty_like() garbage by _move_missing_keys_from_meta_to_device. + # Recompute them here so _initialize_missing_keys restores correct values. 
+ rope_type = module.config.rope_parameters.get("rope_type", "default") + if rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + else: + rope_init_fn = module.compute_default_rope_parameters + inv_freq, attention_scaling = rope_init_fn(module.config, module.inv_freq.device) + init.copy_(module.inv_freq, inv_freq) + init.copy_(module.original_inv_freq, inv_freq) + module.attention_scaling = attention_scaling + if isinstance(module, Granite4VisionWindowQFormerDownsampler): + embed_std = 1 / math.sqrt(module.query.shape[-1]) + init.normal_(module.query, mean=0.0, std=embed_std) + init.normal_(module.image_positions, mean=0.0, std=embed_std) + + def _deepstack_inject( + self, + hidden_states: torch.Tensor, + vision_mask: torch.Tensor, + features: torch.Tensor, + ) -> torch.Tensor: + """Add projected vision features into the image-token positions of hidden_states.""" + vision_mask = vision_mask.to(hidden_states.device) + features = features.to(hidden_states.device, hidden_states.dtype) + return hidden_states.masked_scatter( + vision_mask, + (hidden_states[vision_mask] + features.flatten()).view(-1), + ) + + +@auto_docstring +class Granite4VisionTextModel(Granite4VisionPreTrainedModel): + """Granite LLM backbone with deepstack feature injection support.""" + + base_model_prefix = "model" + _no_split_modules = ["Granite4VisionTextDecoderLayer"] + + def __init__(self, config: Granite4VisionTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Granite4VisionTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Granite4VisionTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Granite4VisionTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.embedding_multiplier = config.embedding_multiplier + + # Initialize weights and apply final processing + self.post_init() + + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + vision_mask: torch.BoolTensor | None = None, + deepstack_features: dict | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + r""" + vision_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Boolean mask marking image token positions. Required when `deepstack_features` is provided. + deepstack_features (`dict[int, torch.Tensor]`, *optional*): + Mapping from LLM layer index to projected vision features of shape `(num_image_tokens, hidden_size)`. + Features are added into image-token positions of hidden states before the corresponding decoder layer. 
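+            For example, a mapping like `{3: feats_a, 6: feats_b}` (layer indices are illustrative) adds
+            `feats_a` to the image-token positions right before decoder layer 3 and `feats_b` before layer 6.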
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + ).unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if deepstack_features is not None and layer_idx in deepstack_features: + hidden_states = self._deepstack_inject(hidden_states, vision_mask, deepstack_features[layer_idx]) + + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + """ + Calculate the number of patches after the preprocessing for images of any resolution. + + Args: + image_size (`torch.LongTensor` or `np.ndarray` or `tuple[int, int]`): + The size of the input image in the format (height, width). ? + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + int: the number of patches + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + num_patches = 0 + # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + num_patches += 1 + # add the base patch + num_patches += 1 + return num_patches + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + if not isinstance(original_size, (list, tuple)): + if not isinstance(original_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(round(original_height * scale_factor, 7)) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(round(original_width * scale_factor, 7)) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +@auto_docstring( + custom_intro=""" + The Llava-Next model which consists of a vision backbone and a language model without language modeling head. 
+ """ +) +class Granite4VisionModel(Granite4VisionPreTrainedModel): + base_model_prefix = "model" + config_class = Granite4VisionConfig + + def __init__(self, config: Granite4VisionConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + embed_std = 1 / math.sqrt(config.text_config.hidden_size) + self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std) + + self.vocab_size = config.text_config.vocab_size + + # Replace the inherited LLM backbone with our deepstack-aware subclass + self.language_model = Granite4VisionTextModel(config.text_config) + + self.spatial_projectors = None + self.downsample_rate = config.downsample_rate + self.projector_dropout = config.projector_dropout + + # Deepstack projectors: one per (vision_layer, llm_layer) pair + self.layerwise_projectors = nn.ModuleList( + [Granite4VisionWindowQFormerDownsampler(config) for _ in range(len(config.deepstack_layer_map))] + ) + + # Spatial sampling projectors: 4 offset groups (TL, TR, BL, BR) + if config.use_spatial_sampling: + self.spatial_projectors = nn.ModuleList( + [Granite4VisionWindowQFormerDownsampler(config, spatial_offset=i) for i in range(4)] + ) + + self.pad_token_id = ( + self.config.text_config.pad_token_id if self.config.text_config.pad_token_id is not None else -1 + ) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): + """ + Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. + + Overrides the parent to apply downsample_rate to height/width calculations. + """ + new_image_features = [] + feature_lens = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + if self.layerwise_projectors is not None: + ds_rate = Fraction(self.downsample_rate) + height = int(height * ds_rate) + width = int(width * ds_rate) + + if ( + np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0 + and vision_feature_select_strategy == "default" + ): + raise ValueError( + "Image feature shape does not line up with the provided patch size. " + "You may be using the `default` vision_feature_select_strategy with a " + "visual encoder that does not have CLS token." 
+ ) + + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device, image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) + return new_image_features, feature_lens + + @merge_with_config_defaults + @can_return_tuple + @auto_docstring( + custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection." + ) + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int | list[int] | None = None, + vision_feature_select_strategy: str | None = None, + **kwargs, + ) -> Granite4VisionImageFeaturesOutput: + """ + Extract image features via deepstack (multi-layer) and spatial sampling projections. + + Runs the vision tower once, then: + 1. Deepstack: for each (vision_layer, llm_layer) in deepstack_layer_map, + extracts features from that vision layer, downsamples via interpolation + QFormer, + and pairs them with the target LLM layer. + 2. Spatial: if enabled, extracts the spatial_vision_layer and creates 4 spatial + offset groups (TL, TR, BL, BR), each targeting a different LLM layer. 
+ """ + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) + for imsize in image_sizes + ] + + if pixel_values.dim() == 5: + _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + + output_hidden_states = kwargs.pop("output_hidden_states", None) + if output_hidden_states is None: + output_hidden_states = getattr(self.config, "output_hidden_states", False) + vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + + # Deepstack features: extract from multiple vision layers, downsample via interpolation + all_features = [] + for projection_idx, (vision_layer, llm_layer) in enumerate(self.config.deepstack_layer_map): + selected_feature = vision_outputs.hidden_states[vision_layer] + + if vision_feature_select_strategy == "default": + selected_feature = selected_feature[:, 1:] + + projected_features = self.layerwise_projectors[projection_idx](selected_feature) + projected_features = torch.split(projected_features, image_num_patches, dim=0) + + packed_features, _ = self.pack_image_features( + projected_features, + image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + ) + + all_features.append((llm_layer, packed_features)) + + # Spatial features: extract 4 offset groups from a single vision layer + if self.config.use_spatial_sampling: + spatial_feature = vision_outputs.hidden_states[self.config.spatial_vision_layer] + + if vision_feature_select_strategy == "default": + spatial_feature = spatial_feature[:, 1:] + + for group_idx, llm_layer in enumerate(self.config.spatial_target_layers): + projected_group = self.spatial_projectors[group_idx](spatial_feature) + projected_group_split = torch.split(projected_group, image_num_patches, dim=0) + + packed_group, _ = self.pack_image_features( + projected_group_split, + image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + ) + + all_features.append((llm_layer, packed_group)) + + return Granite4VisionImageFeaturesOutput( + deepstack_features=all_features, + hidden_states=vision_outputs.hidden_states if output_hidden_states else None, + ) + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_image_mask].numel() == image_features.numel(), + f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", + ) + return special_image_mask + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + image_sizes: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + vision_feature_layer: int | list[int] | None = None, + vision_feature_select_strategy: str | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Granite4VisionModelOutputWithPast: + r""" + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. + If `"full"`, the full vision features are used. + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # Build deepstack injection map and scatter initial image embeddings + deepstack_features = None + vision_mask = None + image_features = None + if pixel_values is not None and pixel_values.size(0) > 0: + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + deepstack_features = {} + for idx, (llm_layer_idx, packed_features) in enumerate(image_features.deepstack_features): + concat_features = torch.cat(packed_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + if idx == 0: + vision_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=concat_features + ) + inputs_embeds = inputs_embeds.masked_fill(vision_mask, 0.0) + deepstack_features[llm_layer_idx] = concat_features + + # Bypass nn.Module.__call__ overhead by calling the unbound forward directly. + # nn.Module.__call__ has non-trivial per-call overhead that accumulates across 40 layers × N steps. 
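+        # Image-token positions in `inputs_embeds` were zeroed above; `Granite4VisionTextModel.forward`
+        # adds `deepstack_features[layer_idx]` back into those positions right before each mapped decoder layer.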
+        outputs = Granite4VisionTextModel.forward(
+            self.language_model,
+            input_ids=None,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            vision_mask=vision_mask,
+            deepstack_features=deepstack_features,
+            **kwargs,
+        )
+
+        return Granite4VisionModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            deepstack_features=image_features.deepstack_features if pixel_values is not None else None,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The Granite4Vision model, which consists of a vision backbone and a language model.
+    """
+)
+class Granite4VisionForConditionalGeneration(Granite4VisionPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+    config_class = Granite4VisionConfig
+
+    def __init__(self, config: Granite4VisionConfig):
+        super().__init__(config)
+        self.model = Granite4VisionModel(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.lm_head
+
+    def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+        return self.model.pack_image_features(
+            image_features=image_features,
+            image_sizes=image_sizes,
+            vision_feature_select_strategy=vision_feature_select_strategy,
+            image_newline=image_newline,
+        )
+
+    @merge_with_config_defaults
+    @can_return_tuple
+    @auto_docstring
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_sizes: torch.Tensor,
+        vision_feature_layer: int | list[int] | None = None,
+        vision_feature_select_strategy: str | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple | BaseModelOutputWithPooling:
+        r"""
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
+            The tensors corresponding to the input images.
+        image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
+            Actual image size of each image (H, W).
+        vision_feature_layer (`Union[int, list[int]]`, *optional*):
+            The index of the layer to select the vision feature. If multiple indices are provided,
+            the vision feature of the corresponding indices will be concatenated to form the
+            vision features.
+        vision_feature_select_strategy (`str`, *optional*):
+            The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"` + """ + return self.model.get_image_features( + pixel_values=pixel_values, + image_sizes=image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + **kwargs, + ) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + image_sizes: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + vision_feature_layer: int | list[int] | None = None, + vision_feature_select_strategy: str | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Granite4VisionCausalLMOutputWithPast: + r""" + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. + If `"full"`, the full vision features are used. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import httpx + >>> from io import BytesIO + >>> from transformers import AutoProcessor, Granite4VisionForConditionalGeneration + + >>> model = Granite4VisionForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") + + >>> prompt = "[INST] \nWhat is shown in this image? [/INST]" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "[INST] \nWhat is shown in this image? 
[/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)" + ```""" + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + outputs = self.model( + input_ids, + pixel_values=pixel_values, + image_sizes=image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + return_dict=True, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + loss = None + logits = self.lm_head(hidden_states) + logits = logits / self.config.text_config.logits_scaling + if labels is not None: + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.text_config.vocab_size, + **kwargs, + ) + + if isinstance(logits_to_keep, int) and logits_to_keep > 0: + logits = logits[:, -logits_to_keep:, :] + + return Granite4VisionCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + deepstack_features=outputs.deepstack_features, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + logits_to_keep=None, + is_first_iteration=False, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + logits_to_keep=logits_to_keep, + is_first_iteration=is_first_iteration, + **kwargs, + ) + + # Pixel values are used only in the first iteration if available + # In subsequent iterations, they are already merged with text and cached + # NOTE: first iteration doesn't have to be prefill, it can be the first + # iteration with a question and cached system prompt (continue generate from cache) + if is_first_iteration or not kwargs.get("use_cache", True): + model_inputs["pixel_values"] = pixel_values + model_inputs["image_sizes"] = image_sizes + + return model_inputs + + +__all__ = [ + "Granite4VisionPreTrainedModel", + "Granite4VisionTextModel", + "Granite4VisionModel", + "Granite4VisionForConditionalGeneration", +] diff --git a/src/transformers/models/granite4_vision/modular_granite4_vision.py b/src/transformers/models/granite4_vision/modular_granite4_vision.py new file mode 100644 index 000000000000..3559a3298ce3 --- /dev/null +++ b/src/transformers/models/granite4_vision/modular_granite4_vision.py @@ -0,0 +1,830 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from fractions import Fraction + +import numpy as np +import torch +from torch import nn + +from ... import initialization as init +from ...cache_utils import Cache +from ...configuration_utils import PreTrainedConfig +from ...image_processing_utils import select_best_resolution +from ...masking_utils import create_causal_mask +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.output_capturing import capture_outputs +from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel +from ..granite.configuration_granite import GraniteConfig +from ..granite.modeling_granite import GraniteAttention, GraniteDecoderLayer, GraniteModel, GraniteRotaryEmbedding +from ..llava_next.configuration_llava_next import LlavaNextConfig +from ..llava_next.modeling_llava_next import ( + LlavaNextCausalLMOutputWithPast, + LlavaNextForConditionalGeneration, + LlavaNextModel, + LlavaNextModelOutputWithPast, + LlavaNextPreTrainedModel, + get_anyres_image_grid_shape, + image_size_to_num_patches, + unpad_image, +) +from ..llava_next.processing_llava_next import LlavaNextProcessor + + +# ── Output classes ────────────────────────────────────────────────────────── + + +@dataclass +class Granite4VisionModelOutputWithPast(LlavaNextModelOutputWithPast): + """ + Args: + deepstack_features (`list[tuple[int, list[torch.Tensor]]]`, *optional*): + List of `(llm_layer_idx, packed_features)` pairs produced by the deepstack + and spatial projectors. Each entry targets one LLM decoder layer; `packed_features` + is a per-image list of tensors of shape `(num_image_tokens, hidden_size)`. + """ + + deepstack_features: list | None = None + + +@dataclass +class Granite4VisionCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast): + """ + Args: + deepstack_features (`list[tuple[int, list[torch.Tensor]]]`, *optional*): + List of `(llm_layer_idx, packed_features)` pairs. See `Granite4VisionModelOutputWithPast`. + """ + + deepstack_features: list | None = None + + +@dataclass +class Granite4VisionImageFeaturesOutput(BaseModelOutputWithPooling): + """ + Output of `Granite4VisionModel.get_image_features`. + + Args: + deepstack_features (`list[tuple[int, list[torch.Tensor]]]`): + List of `(llm_layer_idx, packed_features)` pairs. Each entry targets one LLM + decoder layer; `packed_features` is a per-image list of tensors of shape + `(num_image_tokens, hidden_size)`. + """ + + deepstack_features: list | None = None + + +# ── Config ────────────────────────────────────────────────────────────────── + + +class Granite4VisionTextConfig(GraniteConfig): + model_type = "granite4_vision_text" + base_config_key = "text_config" + + +class Granite4VisionConfig(LlavaNextConfig): + r""" + downsample_rate (`str`, *optional*): + Fractional downsample rate for the Window Q-Former projector, e.g. `"1/4"` or `"3/8"`. + The numerator is the query window side, the denominator is the key window side. + deepstack_layer_map (`list`, *optional*): + List of `[vision_layer_idx, llm_layer_idx]` pairs. Features from each vision encoder layer + are projected and injected at the corresponding LLM decoder layer during forward pass. 
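+            For example, a map of `[[10, 2], [16, 5]]` (illustrative values) would project features from vision
+            encoder layers 10 and 16 and inject them before LLM decoder layers 2 and 5, respectively.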
+ use_spatial_sampling (`bool`, *optional*, defaults to `False`): + Whether to enable spatial offset sampling, which creates 4 groups (TL, TR, BL, BR) from + a single vision layer, each injected at a different LLM layer. + spatial_vision_layer (`int`, *optional*, defaults to `-1`): + Index of the vision encoder layer used for spatial sampling. + spatial_target_layers (`list`, *optional*, defaults to `[12, 15, 18, 21]`): + Target LLM layers for the 4 spatial offset groups. + projector_dropout (`float`, *optional*, defaults to `0.1`): + Dropout probability in the Window Q-Former projector. + qformer_config (`dict` or `Blip2QFormerConfig`, *optional*): + Configuration for the Window Q-Former projector. If `None`, defaults are derived from + `vision_config.hidden_size`. + image_grid_pinpoints (`list`, *optional*): + A list of possible resolutions to use for processing high resolution images. Each item in the list should be a + tuple or list of the form `(height, width)`. + """ + + model_type = "granite4_vision" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "qformer_config": AutoConfig} + + multimodal_projector_bias = AttributeError() + projector_hidden_act = AttributeError() + + downsample_rate: str | None = None + deepstack_layer_map: list | None = None + use_spatial_sampling: bool = False + spatial_vision_layer: int = -1 + spatial_target_layers: list | None = None + projector_dropout: float = 0.1 + qformer_config: dict | PreTrainedConfig | None = None + + def __post_init__(self, **kwargs): + if self.deepstack_layer_map is not None: + self.deepstack_layer_map = [(int(v), int(l)) for v, l in self.deepstack_layer_map] + + if self.spatial_target_layers is None: + self.spatial_target_layers = [12, 15, 18, 21] + + # Peek at vision hidden_size before super() to build a fully-specified qformer_config, + # avoiding any runtime field patching after super(). + if isinstance(self.vision_config, dict): + vision_hidden_size = self.vision_config.get("hidden_size", 1152) + elif self.vision_config is not None: + vision_hidden_size = self.vision_config.hidden_size + else: + vision_hidden_size = 1152 + + # Convert qformer_config dict → object before super() so _attn_implementation.setter + # (called inside super().__post_init__) sees a config object, not a raw dict. + if isinstance(self.qformer_config, dict): + model_type = self.qformer_config.get("model_type", "blip_2_qformer") + self.qformer_config = CONFIG_MAPPING[model_type](**self.qformer_config) + elif self.qformer_config is None: + self.qformer_config = CONFIG_MAPPING["blip_2_qformer"]( + num_hidden_layers=1, + intermediate_size=3072, + cross_attention_frequency=1, + max_position_embeddings=2048, + use_qformer_text_input=False, + hidden_size=vision_hidden_size, + num_attention_heads=vision_hidden_size // 64, + encoder_hidden_size=vision_hidden_size, + ) + + super().__post_init__(**kwargs) + + +# ── Processor ─────────────────────────────────────────────────────────────── + + +class Granite4VisionProcessor(LlavaNextProcessor): + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size=None, + vision_feature_select_strategy=None, + chat_template=None, + image_token="", + num_additional_image_tokens=0, + downsample_rate=None, + **kwargs, + ): + r""" + patch_size (`int`, *optional*): + Patch size from the vision tower. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Should be same as in model's config. 
+ image_token (`str`, *optional*, defaults to `""`): + Special token used to denote image location. + num_additional_image_tokens (`int`, *optional*, defaults to `0`): + Number of additional tokens added to the image embeddings, such as CLS (+1). + downsample_rate (`str`, *optional*): + Fractional downsample rate (e.g. `"1/4"`), used to adjust the number of image tokens + when computing token counts for padding/truncation. + """ + super().__init__( + image_processor=image_processor, + tokenizer=tokenizer, + patch_size=patch_size, + vision_feature_select_strategy=vision_feature_select_strategy, + chat_template=chat_template, + image_token=image_token, + num_additional_image_tokens=num_additional_image_tokens, + ) + self.downsample_rate = downsample_rate + + def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: + image_grid_pinpoints = self.image_processor.image_grid_pinpoints + + height_best_resolution, width_best_resolution = select_best_resolution( + [orig_height, orig_width], image_grid_pinpoints + ) + scale_height, scale_width = height_best_resolution // height, width_best_resolution // width + + patches_height = height // self.patch_size + patches_width = width // self.patch_size + if self.downsample_rate is not None: + ds_rate = Fraction(self.downsample_rate) + patches_height = int(patches_height * ds_rate) + patches_width = int(patches_width * ds_rate) + + unpadded_features, newline_features = self._get_unpadded_features( + orig_height, orig_width, patches_height, patches_width, scale_height, scale_width + ) + base_features = patches_height * patches_width + self.num_additional_image_tokens + num_image_tokens = unpadded_features + newline_features + base_features + return num_image_tokens + + +# ── Downsampling helpers ───────────────────────────────────────────────────── + + +def interpolate_downsample(image_features: torch.Tensor, orig_side: int, new_side: int) -> torch.Tensor: + """Spatial downsampling via area interpolation.""" + batch, _, channels = image_features.size() + spatial = image_features.view(batch, orig_side, orig_side, channels).permute(0, 3, 1, 2) + spatial = torch.nn.functional.interpolate(spatial, size=(new_side, new_side), mode="area") + return spatial.permute(0, 2, 3, 1).flatten(1, 2) + + +def spatial_offset_downsample(image_features: torch.Tensor, orig_side: int, offset: int = 0) -> torch.Tensor: + """Sample one position from each 2x2 block; offset selects which corner (0=TL,1=TR,2=BL,3=BR).""" + offset_h, offset_w = [(0, 0), (0, 1), (1, 0), (1, 1)][offset] + new_side = orig_side // 2 + batch, _, channels = image_features.shape + grid = image_features.reshape(batch, orig_side, orig_side, channels) + grid = grid.reshape(batch, new_side, 2, new_side, 2, channels) + return grid[:, :, offset_h, :, offset_w, :].reshape(batch, -1, channels) + + +class Granite4VisionWindowQFormerDownsampler(nn.Module): + """Window-based QFormer downsampler that processes image patches in windows.""" + + def __init__(self, config, spatial_offset=None): + super().__init__() + llm_hidden_size = config.text_config.hidden_size + vision_hidden_size = config.vision_config.hidden_size + + self.dropout = nn.Dropout(config.projector_dropout) + self._spatial_offset = spatial_offset + self._downsample_rate = config.downsample_rate + + self.qformer = AutoModel.from_config(config.qformer_config) + + self.image_side = config.vision_config.image_size // config.vision_config.patch_size + query_side_str, window_side_str = config.downsample_rate.split("/") 
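+        # e.g. downsample_rate="1/2" parses to query_side=1 and window_side=2: each 2x2 key window
+        # is summarized by a single query token (numerator = query window side, denominator = key window side).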
+ self.query_side, self.window_side = int(query_side_str), int(window_side_str) + self.query_length = self.query_side**2 + self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6) + self.query = nn.Parameter(torch.empty(1, self.query_length, vision_hidden_size)) + self.image_positions = nn.Parameter(torch.empty(1, self.window_side**2, vision_hidden_size)) + self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True) + + def _windowed_raster(self, features, side, window_size): + """(B, side*side, C) raster -> (B*num_win*num_win, window_size*window_size, C)""" + batch, _, channels = features.shape + num_win = side // window_size + return ( + features.view(batch, side, side, channels) + .view(batch, num_win, window_size, num_win, window_size, channels) + .transpose(2, 3) + .flatten(0, 2) + .flatten(1, 2) + ) + + def _unwindowed_raster(self, windowed_features, num_win, window_size): + """(B*num_win*num_win, window_size*window_size, C) -> (B, (num_win*window_size)^2, C)""" + batch_win, _, channels = windowed_features.shape + if batch_win % (num_win * num_win) != 0: + raise ValueError(f"Expected batch_win ({batch_win}) to be divisible by num_win^2 ({num_win**2}).") + batch = batch_win // (num_win * num_win) + side = num_win * window_size + return ( + windowed_features.view(batch, num_win, num_win, window_size, window_size, channels) + .transpose(2, 3) + .contiguous() + .view(batch, side, side, channels) + .flatten(1, 2) + ) + + def forward(self, image_features): + batch, hw, channels = image_features.shape + if self.image_side * self.image_side != hw: + raise ValueError( + f"Expected image_features with {self.image_side**2} spatial tokens, got {hw}. " + "Check that the vision encoder image_size and patch_size match the config." + ) + num_windows = self.image_side // self.window_side + interp_side = int(self.image_side * Fraction(self._downsample_rate)) + image_features = self.norm(image_features) + enc = self._windowed_raster(image_features, self.image_side, self.window_side) + + if self._spatial_offset is not None: + downsampled = spatial_offset_downsample(image_features, self.image_side, self._spatial_offset) + else: + downsampled = interpolate_downsample(image_features, self.image_side, interp_side) + + downsampled_side = num_windows * self.query_side + downsampled_w = self._windowed_raster(downsampled, downsampled_side, self.query_side) + + query_embeds = self.query + downsampled_w + encoder_embeds = self.dropout(enc + self.image_positions) + out_w = self.qformer( + query_embeds=query_embeds, + encoder_hidden_states=encoder_embeds, + return_dict=True, + ).last_hidden_state + + out = self._unwindowed_raster(out_w, num_win=num_windows, window_size=self.query_side) + out = self.dropout(out) + return self.out_linear(out) + + +# ── Model ─────────────────────────────────────────────────────────────────── + + +class Granite4VisionTextRotaryEmbedding(GraniteRotaryEmbedding): + pass + + +class Granite4VisionTextAttention(GraniteAttention): + pass + + +class Granite4VisionTextDecoderLayer(GraniteDecoderLayer): + pass + + +class Granite4VisionPreTrainedModel(LlavaNextPreTrainedModel): + _can_record_outputs = { + "hidden_states": Granite4VisionTextDecoderLayer, + "attentions": Granite4VisionTextAttention, + } + + def _deepstack_inject( + self, + hidden_states: torch.Tensor, + vision_mask: torch.Tensor, + features: torch.Tensor, + ) -> torch.Tensor: + """Add projected vision features into the image-token positions of hidden_states.""" + vision_mask = vision_mask.to(hidden_states.device) + 
features = features.to(hidden_states.device, hidden_states.dtype) + return hidden_states.masked_scatter( + vision_mask, + (hidden_states[vision_mask] + features.flatten()).view(-1), + ) + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, nn.Embedding): + init.normal_( + module.weight, + mean=0.0, + std=getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range), + ) + if module.padding_idx is not None: + init.zeros_(module.weight[module.padding_idx]) + elif isinstance(module, nn.LayerNorm) or ( + hasattr(module, "weight") and hasattr(module, "variance_epsilon") and not isinstance(module, nn.Linear) + ): + init.ones_(module.weight) + if hasattr(module, "bias") and module.bias is not None: + init.zeros_(module.bias) + if isinstance(module, Granite4VisionTextRotaryEmbedding): + # Non-persistent buffers (inv_freq, original_inv_freq) are replaced with + # torch.empty_like() garbage by _move_missing_keys_from_meta_to_device. + # Recompute them here so _initialize_missing_keys restores correct values. + rope_type = module.config.rope_parameters.get("rope_type", "default") + if rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + else: + rope_init_fn = module.compute_default_rope_parameters + inv_freq, attention_scaling = rope_init_fn(module.config, module.inv_freq.device) + init.copy_(module.inv_freq, inv_freq) + init.copy_(module.original_inv_freq, inv_freq) + module.attention_scaling = attention_scaling + if isinstance(module, Granite4VisionWindowQFormerDownsampler): + embed_std = 1 / math.sqrt(module.query.shape[-1]) + init.normal_(module.query, mean=0.0, std=embed_std) + init.normal_(module.image_positions, mean=0.0, std=embed_std) + + +class Granite4VisionTextModel(Granite4VisionPreTrainedModel, GraniteModel): + """Granite LLM backbone with deepstack feature injection support.""" + + base_model_prefix = "model" + _no_split_modules = ["Granite4VisionTextDecoderLayer"] + + def __init__(self, config: Granite4VisionTextConfig): + super().__init__(config) + + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + vision_mask: torch.BoolTensor | None = None, + deepstack_features: dict | None = None, + **kwargs: Unpack[TransformersKwargs], + ): + r""" + vision_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Boolean mask marking image token positions. Required when `deepstack_features` is provided. + deepstack_features (`dict[int, torch.Tensor]`, *optional*): + Mapping from LLM layer index to projected vision features of shape `(num_image_tokens, hidden_size)`. + Features are added into image-token positions of hidden states before the corresponding decoder layer. 
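+            For example, `{2: feats_a, 5: feats_b}` (illustrative) adds `feats_a` into the image-token positions
+            of the hidden states right before decoder layer 2 and `feats_b` right before decoder layer 5.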
+ """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + ).unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if deepstack_features is not None and layer_idx in deepstack_features: + hidden_states = self._deepstack_inject(hidden_states, vision_mask, deepstack_features[layer_idx]) + + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class Granite4VisionModel(LlavaNextModel): + config_class = Granite4VisionConfig + + def __init__(self, config: Granite4VisionConfig): + super().__init__(config) + + # Replace parent's single multi_modal_projector with layerwise_projectors + del self.multi_modal_projector + + self.spatial_projectors = None + self.downsample_rate = config.downsample_rate + self.projector_dropout = config.projector_dropout + + # Deepstack projectors: one per (vision_layer, llm_layer) pair + self.layerwise_projectors = nn.ModuleList( + [Granite4VisionWindowQFormerDownsampler(config) for _ in range(len(config.deepstack_layer_map))] + ) + + # Spatial sampling projectors: 4 offset groups (TL, TR, BL, BR) + if config.use_spatial_sampling: + self.spatial_projectors = nn.ModuleList( + [Granite4VisionWindowQFormerDownsampler(config, spatial_offset=i) for i in range(4)] + ) + + self.pad_token_id = ( + self.config.text_config.pad_token_id if self.config.text_config.pad_token_id is not None else -1 + ) + + # Replace the inherited LLM backbone with our deepstack-aware subclass + self.language_model = Granite4VisionTextModel(config.text_config) + + def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): + """ + Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. + + Overrides the parent to apply downsample_rate to height/width calculations. 
+ """ + new_image_features = [] + feature_lens = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + if self.layerwise_projectors is not None: + ds_rate = Fraction(self.downsample_rate) + height = int(height * ds_rate) + width = int(width * ds_rate) + + if ( + np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0 + and vision_feature_select_strategy == "default" + ): + raise ValueError( + "Image feature shape does not line up with the provided patch size. " + "You may be using the `default` vision_feature_select_strategy with a " + "visual encoder that does not have CLS token." + ) + + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device, image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) + return new_image_features, feature_lens + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int | list[int] | None = None, + vision_feature_select_strategy: str | None = None, + **kwargs, + ) -> Granite4VisionImageFeaturesOutput: + """ + Extract image features via deepstack (multi-layer) and spatial sampling projections. + + Runs the vision tower once, then: + 1. Deepstack: for each (vision_layer, llm_layer) in deepstack_layer_map, + extracts features from that vision layer, downsamples via interpolation + QFormer, + and pairs them with the target LLM layer. + 2. Spatial: if enabled, extracts the spatial_vision_layer and creates 4 spatial + offset groups (TL, TR, BL, BR), each targeting a different LLM layer. 
+ """ + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) + for imsize in image_sizes + ] + + if pixel_values.dim() == 5: + _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + + output_hidden_states = kwargs.pop("output_hidden_states", None) + if output_hidden_states is None: + output_hidden_states = getattr(self.config, "output_hidden_states", False) + vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + + # Deepstack features: extract from multiple vision layers, downsample via interpolation + all_features = [] + for projection_idx, (vision_layer, llm_layer) in enumerate(self.config.deepstack_layer_map): + selected_feature = vision_outputs.hidden_states[vision_layer] + + if vision_feature_select_strategy == "default": + selected_feature = selected_feature[:, 1:] + + projected_features = self.layerwise_projectors[projection_idx](selected_feature) + projected_features = torch.split(projected_features, image_num_patches, dim=0) + + packed_features, _ = self.pack_image_features( + projected_features, + image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + ) + + all_features.append((llm_layer, packed_features)) + + # Spatial features: extract 4 offset groups from a single vision layer + if self.config.use_spatial_sampling: + spatial_feature = vision_outputs.hidden_states[self.config.spatial_vision_layer] + + if vision_feature_select_strategy == "default": + spatial_feature = spatial_feature[:, 1:] + + for group_idx, llm_layer in enumerate(self.config.spatial_target_layers): + projected_group = self.spatial_projectors[group_idx](spatial_feature) + projected_group_split = torch.split(projected_group, image_num_patches, dim=0) + + packed_group, _ = self.pack_image_features( + projected_group_split, + image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + ) + + all_features.append((llm_layer, packed_group)) + + return Granite4VisionImageFeaturesOutput( + deepstack_features=all_features, + hidden_states=vision_outputs.hidden_states if output_hidden_states else None, + ) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + image_sizes: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + vision_feature_layer: int | list[int] | None = None, + vision_feature_select_strategy: str | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Granite4VisionModelOutputWithPast: + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else 
self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # Build deepstack injection map and scatter initial image embeddings + deepstack_features = None + vision_mask = None + image_features = None + if pixel_values is not None and pixel_values.size(0) > 0: + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + deepstack_features = {} + for idx, (llm_layer_idx, packed_features) in enumerate(image_features.deepstack_features): + concat_features = torch.cat(packed_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + if idx == 0: + vision_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=concat_features + ) + inputs_embeds = inputs_embeds.masked_fill(vision_mask, 0.0) + deepstack_features[llm_layer_idx] = concat_features + + # Bypass nn.Module.__call__ overhead by calling the unbound forward directly. + # nn.Module.__call__ has non-trivial per-call overhead that accumulates across 40 layers × N steps. + outputs = Granite4VisionTextModel.forward( + self.language_model, + input_ids=None, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + vision_mask=vision_mask, + deepstack_features=deepstack_features, + **kwargs, + ) + + return Granite4VisionModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + deepstack_features=image_features.deepstack_features if pixel_values is not None else None, + ) + + +# ── ForConditionalGeneration ──────────────────────────────────────────────── + + +class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration): + config_class = Granite4VisionConfig + + def __init__(self, config: Granite4VisionConfig): + super().__init__(config) + self.model = Granite4VisionModel(config) + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + image_sizes: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + vision_feature_layer: int | list[int] | None = None, + vision_feature_select_strategy: str | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Granite4VisionCausalLMOutputWithPast: + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + outputs = self.model( + input_ids, + pixel_values=pixel_values, + image_sizes=image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + attention_mask=attention_mask, + position_ids=position_ids, + 
past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + return_dict=True, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + loss = None + logits = self.lm_head(hidden_states) + logits = logits / self.config.text_config.logits_scaling + if labels is not None: + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.text_config.vocab_size, + **kwargs, + ) + + if isinstance(logits_to_keep, int) and logits_to_keep > 0: + logits = logits[:, -logits_to_keep:, :] + + return Granite4VisionCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + deepstack_features=outputs.deepstack_features, + ) + + +__all__ = [ + "Granite4VisionConfig", + "Granite4VisionTextConfig", + "Granite4VisionProcessor", + "Granite4VisionPreTrainedModel", + "Granite4VisionTextModel", + "Granite4VisionModel", + "Granite4VisionForConditionalGeneration", +] diff --git a/src/transformers/models/granite4_vision/processing_granite4_vision.py b/src/transformers/models/granite4_vision/processing_granite4_vision.py new file mode 100644 index 000000000000..f8fa2c2b09ce --- /dev/null +++ b/src/transformers/models/granite4_vision/processing_granite4_vision.py @@ -0,0 +1,237 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granite4_vision/modular_granite4_vision.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granite4_vision.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fractions import Fraction + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import select_best_resolution +from ...image_utils import ImageInput, SizeDict, get_image_size, to_numpy_array +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import auto_docstring + + +class Granite4VisionProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": False, + }, + "images_kwargs": { + "do_pad": True, + }, + } + + +@auto_docstring +class Granite4VisionProcessor(ProcessorMixin): + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size=None, + vision_feature_select_strategy=None, + chat_template=None, + image_token="", + num_additional_image_tokens=0, + downsample_rate=None, + **kwargs, + ): + r""" + patch_size (`int`, *optional*): + Patch size from the vision tower. 
+ vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Should be same as in model's config. + image_token (`str`, *optional*, defaults to `""`): + Special token used to denote image location. + num_additional_image_tokens (`int`, *optional*, defaults to `0`): + Number of additional tokens added to the image embeddings, such as CLS (+1). + downsample_rate (`str`, *optional*): + Fractional downsample rate (e.g. `"1/4"`), used to adjust the number of image tokens + when computing token counts for padding/truncation. + """ + self.patch_size = patch_size + self.num_additional_image_tokens = num_additional_image_tokens + self.vision_feature_select_strategy = vision_feature_select_strategy + self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + self.downsample_rate = downsample_rate + + @auto_docstring + def __call__( + self, + images: ImageInput | None = None, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, + **kwargs: Unpack[Granite4VisionProcessorKwargs], + ) -> BatchFeature: + r""" + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is None and text is None: + raise ValueError("You have to specify at least images or text.") + + output_kwargs = self._merge_kwargs( + Granite4VisionProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + else: + image_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise TypeError("Invalid input text. 
Please provide a string, or a list of strings")
+
+        prompt_strings = text
+        if image_inputs:
+            image_sizes = iter(image_inputs["image_sizes"])
+            height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
+            prompt_strings = []
+            for sample in text:
+                while self.image_token in sample:
+                    image_size = next(image_sizes)
+                    if not isinstance(image_size, (list, tuple)):
+                        # cast to list to avoid numerical precision errors when calculating unpadding
+                        image_size = image_size.tolist()
+                    orig_height, orig_width = image_size
+                    num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
+                    if self.vision_feature_select_strategy == "default":
+                        num_image_tokens -= 1
+                    sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
+                prompt_strings.append(sample)
+            prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+        self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+        if return_mm_token_type_ids:
+            text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
+        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+    def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+        image_grid_pinpoints = self.image_processor.image_grid_pinpoints
+
+        height_best_resolution, width_best_resolution = select_best_resolution(
+            [orig_height, orig_width], image_grid_pinpoints
+        )
+        scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
+
+        patches_height = height // self.patch_size
+        patches_width = width // self.patch_size
+        if self.downsample_rate is not None:
+            ds_rate = Fraction(self.downsample_rate)
+            patches_height = int(patches_height * ds_rate)
+            patches_width = int(patches_width * ds_rate)
+
+        unpadded_features, newline_features = self._get_unpadded_features(
+            orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
+        )
+        base_features = patches_height * patches_width + self.num_additional_image_tokens
+        num_image_tokens = unpadded_features + newline_features + base_features
+        return num_image_tokens
+
+    def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
+        """
+        Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
+        because it divides each image into patches depending on its resolution. Therefore we need to calculate how many
+        patches an image is divided into and get the number of features from that.
+ """ + current_height = patches_height * scale_height + current_width = patches_width * scale_width + + original_aspect_ratio = width / height + current_aspect_ratio = current_width / current_height + if original_aspect_ratio > current_aspect_ratio: + new_height = int(round(height * (current_width / width), 7)) + padding = (current_height - new_height) // 2 + current_height -= padding * 2 + else: + new_width = int(round(width * (current_height / height), 7)) + padding = (current_width - new_width) // 2 + current_width -= padding * 2 + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (list[list[str]], *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + vision_data = {} + if image_sizes is not None: + images_kwargs = Granite4VisionProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + size = images_kwargs.get("size", None) or self.image_processor.size + if isinstance(size, SizeDict): + size = ( + (size.shortest_edge, size.shortest_edge) + if size.shortest_edge is not None + else (min(size.height, size.width), min(size.height, size.width)) + ) + else: + size = ( + (size["shortest_edge"], size["shortest_edge"]) + if "shortest_edge" in size + else (min(size["height"], size["width"]), min(size["height"], size["width"])) + ) + processed_height, processed_width = size + + batch_num_image_tokens = [] + num_image_patches = [1] * len(image_sizes) # llava-next doesn't batch pixels as Idefics, thus `1` patch` + for image_size in image_sizes: + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features( + orig_height, orig_width, processed_height, processed_width + ) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + batch_num_image_tokens.append(num_image_tokens) + vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + + +__all__ = ["Granite4VisionProcessor"] diff --git a/tests/models/granite4_vision/__init__.py b/tests/models/granite4_vision/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/granite4_vision/test_modeling_granite4_vision.py b/tests/models/granite4_vision/test_modeling_granite4_vision.py new file mode 100644 index 000000000000..10dafd8b3abe --- /dev/null +++ b/tests/models/granite4_vision/test_modeling_granite4_vision.py @@ -0,0 +1,241 @@ +# Copyright 2026 IBM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch Granite4Vision model.""" + +import unittest + +import pytest +import requests + +from transformers import ( + AutoProcessor, + CLIPVisionConfig, + Granite4VisionConfig, + Granite4VisionForConditionalGeneration, + Granite4VisionModel, + GraniteConfig, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + cleanup, + require_torch, + slow, + torch_device, +) + +from ...test_modeling_common import floats_tensor +from ...vlm_tester import VLMModelTest, VLMModelTester + + +if is_torch_available(): + import torch + + +if is_vision_available(): + from PIL import Image + + +class Granite4VisionModelTester(VLMModelTester): + base_model_class = Granite4VisionModel + config_class = Granite4VisionConfig + conditional_generation_class = Granite4VisionForConditionalGeneration + text_config_class = GraniteConfig + vision_config_class = CLIPVisionConfig + + def __init__(self, parent, **kwargs): + # Vision hidden_size must be divisible by 64 (QFormer num_attention_heads = hidden_size // 64) + kwargs.setdefault("hidden_size", 64) + kwargs.setdefault("intermediate_size", 64) + kwargs.setdefault("num_attention_heads", 2) + kwargs.setdefault("num_key_value_heads", 2) + kwargs.setdefault("num_hidden_layers", 2) + # Image/patch sizes: image_side = image_size // patch_size must be divisible by window_side + kwargs.setdefault("image_size", 8) + kwargs.setdefault("patch_size", 2) + kwargs.setdefault("projection_dim", 64) + kwargs.setdefault("num_patches_per_image", 2) + # Granite4Vision-specific + kwargs.setdefault("downsample_rate", "1/2") + kwargs.setdefault("deepstack_layer_map", [[1, 0]]) + kwargs.setdefault("use_spatial_sampling", False) + kwargs.setdefault("projector_dropout", 0.0) + kwargs.setdefault("image_token_index", kwargs.get("image_token_id", 3)) + + # Compute num_image_tokens after downsampling: + # image_side = image_size/patch_size = 4, ds 1/2 -> patches_h = patches_w = 2 + # pinpoints [[8,8]] -> scale 1x1 -> current_h = current_w = 2 + # unpadded = 2*2 = 4, newline = 2, base = 2*2 = 4 -> total = 10 + kwargs.setdefault("num_image_tokens", 10) + + super().__init__(parent, **kwargs) + + def create_pixel_values(self): + """Granite4Vision expects 5D pixel_values: (batch_size, num_patches, channels, height, width)""" + return floats_tensor( + [ + self.batch_size, + self.num_patches_per_image, + self.num_channels, + self.image_size, + self.image_size, + ] + ) + + def get_additional_inputs(self, config, input_ids, pixel_values): + """Granite4Vision requires image_sizes tensor""" + return { + "image_sizes": torch.tensor([[self.image_size, self.image_size]] * self.batch_size), + } + + def get_config(self): + config = super().get_config() + config.image_grid_pinpoints = [[self.image_size, self.image_size]] + config.downsample_rate = self.downsample_rate + config.deepstack_layer_map = self.deepstack_layer_map + config.use_spatial_sampling = self.use_spatial_sampling + config.projector_dropout = self.projector_dropout + return config + + +@require_torch +class Granite4VisionModelTest(VLMModelTest, unittest.TestCase): + """ + Model tester for `Granite4VisionForConditionalGeneration`. 
+ """ + + model_tester_class = Granite4VisionModelTester + skip_test_image_features_output_shape = True + test_torch_exportable = False + # Custom layer-by-layer forward doesn't support output_attentions + # (GraniteDecoderLayer discards attention weights internally) + test_attention_outputs = False + has_attentions = False + + @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.") + def test_training_gradient_checkpointing(self): + super().test_training_gradient_checkpointing() + + @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.") + def test_training_gradient_checkpointing_use_reentrant_false(self): + super().test_training_gradient_checkpointing_use_reentrant_false() + + @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.") + def test_training_gradient_checkpointing_use_reentrant_true(self): + super().test_training_gradient_checkpointing_use_reentrant_true() + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" + ) + def test_eager_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Custom layer-by-layer forward has graph breaks incompatible with fullgraph compile") + def test_generate_compile_model_forward_fullgraph(self): + pass + + @unittest.skip("Blip2QFormerModel in WindowQFormerDownsampler does not support SDPA dispatch") + def test_can_set_attention_dynamically_composite_model(self): + pass + + +@require_torch +class Granite4VisionIntegrationTest(unittest.TestCase): + model_id = "ibm-granite/granite-vision-4.1-4b" + + def setUp(self): + self.processor = AutoProcessor.from_pretrained(self.model_id) + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + self.image = Image.open(requests.get(url, stream=True).raw) + + def make_prompt(self, question): + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}] + return self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_small_model_integration_test(self): + model = Granite4VisionForConditionalGeneration.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to( + torch_device + ) + + prompt = self.make_prompt("Describe this image briefly.") + inputs = self.processor(text=prompt, images=self.image, return_tensors="pt").to(model.device) + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + new_tokens = output[:, inputs["input_ids"].shape[1] :] + + EXPECTED_RESPONSE = "The image depicts two cats resting on a pink couch. 
They are lying in a relaxed, sprawled position, with one cat appearing to be in a" # fmt: skip + self.assertEqual(self.processor.decode(new_tokens[0], skip_special_tokens=True), EXPECTED_RESPONSE) + + @slow + def test_small_model_integration_test_batch(self): + model = Granite4VisionForConditionalGeneration.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to( + torch_device + ) + + url2 = "http://images.cocodataset.org/val2017/000000001000.jpg" + image2 = Image.open(requests.get(url2, stream=True).raw) + + prompt = self.make_prompt("What do you see in this image?") + inputs = self.processor( + text=[prompt, prompt], + images=[self.image, image2], + return_tensors="pt", + padding=True, + ).to(model.device) + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + new_tokens = output[:, inputs["input_ids"].shape[1] :] + responses = self.processor.batch_decode(new_tokens, skip_special_tokens=True) + + self.assertIn("cat", responses[0].lower()) + self.assertIn("tennis", responses[1].lower()) + + @slow + def test_small_model_integration_test_batch_matches_single(self): + model = Granite4VisionForConditionalGeneration.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to( + torch_device + ) + + prompt = self.make_prompt("What do you see in this image?") + + # Single inference + inputs_single = self.processor(text=prompt, images=self.image, return_tensors="pt").to(model.device) + output_single = model.generate(**inputs_single, max_new_tokens=30, do_sample=False) + decoded_single = self.processor.decode( + output_single[0, inputs_single["input_ids"].shape[1] :], skip_special_tokens=True + ) + + # Batch inference (same image as first in batch) + url2 = "http://images.cocodataset.org/val2017/000000001000.jpg" + image2 = Image.open(requests.get(url2, stream=True).raw) + inputs_batch = self.processor( + text=[prompt, prompt], + images=[self.image, image2], + return_tensors="pt", + padding=True, + ).to(model.device) + output_batch = model.generate(**inputs_batch, max_new_tokens=30, do_sample=False) + decoded_batch = self.processor.decode( + output_batch[0, inputs_batch["input_ids"].shape[1] :], skip_special_tokens=True + ) + + self.assertEqual(decoded_single, decoded_batch) diff --git a/tests/models/granite4_vision/test_processing_granite4_vision.py b/tests/models/granite4_vision/test_processing_granite4_vision.py new file mode 100644 index 000000000000..8a56aa69b020 --- /dev/null +++ b/tests/models/granite4_vision/test_processing_granite4_vision.py @@ -0,0 +1,122 @@ +# Copyright 2025 IBM. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+import unittest
+
+import torch
+
+from transformers import Granite4VisionProcessor
+from transformers.testing_utils import (
+    require_vision,
+)
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+@require_vision
+class Granite4VisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = Granite4VisionProcessor
+    # Image token expansion with downsample_rate="1/2" produces more tokens than the defaults
+    image_text_kwargs_max_length = 300
+    image_text_kwargs_override_max_length = 280
+    image_unstructured_max_length = 260
+
+    @classmethod
+    def _setup_tokenizer(cls):
+        tokenizer_class = cls._get_component_class_from_processor("tokenizer")
+        tokenizer = tokenizer_class.from_pretrained("huggyllama/llama-7b")
+        tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[PAD]"
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token_id = 0
+        return tokenizer
+
+    @classmethod
+    def _setup_test_attributes(cls, processor):
+        cls.image_token = processor.image_token
+
+    @staticmethod
+    def prepare_processor_dict():
+        return {
+            "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
+            "patch_size": 14,
+            "vision_feature_select_strategy": "default",
+            "downsample_rate": "1/2",
+        }  # fmt: skip
+
+    def test_get_num_vision_tokens(self):
+        """Tests general functionality of the helper used internally in vLLM"""
+        processor = self.get_processor()
+
+        output = processor._get_num_multimodal_tokens(image_sizes=[(100, 100), (300, 100), (500, 30)])
+        self.assertTrue("num_image_tokens" in output)
+        self.assertEqual(len(output["num_image_tokens"]), 3)
+
+        self.assertTrue("num_image_patches" in output)
+        self.assertEqual(len(output["num_image_patches"]), 3)
+
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to json in processors
+        self.assertFalse("chat_template" in processor_dict_loaded)
+
+        # they have to be saved as a separate file and loaded back from that file
+        # so we check if the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
+    def test_image_token_filling(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        processor.patch_size = 14
+        processor.vision_feature_select_strategy = "default"
+        processor.downsample_rate = "1/2"
+        processor.image_processor.crop_size = {"height": 336, "width": 336}
+        processor.image_processor.size = {"shortest_edge": 336}
+        processor.image_processor.image_grid_pinpoints = [[672, 336]]
+        # Important to check with a non-square image
+        image = torch.randint(0, 2, (3, 503, 316))
+        image_token_index = processor.image_token_id
+
+        # With
downsample_rate="1/2" and patch_size=14: + # patches = 336/14 = 24, after ds: 24*1/2 = 12 + # best resolution for (503, 316): [672, 336] + # scale_height=2, scale_width=1 + # current = 12*2=24 h, 12*1=12 w + # aspect: 316/503 = 0.628, 12/24 = 0.5 -> orig > current -> new_height = round(503*(12/316)) = 19 + # padding = (24-19)//2 = 2, current_height = 24 - 4 = 20 + # unpadded = 20*12 = 240, newline = 20 + # base = 12*12 + 0 = 144 + # total = 240 + 20 + 144 = 404 + # with "default" strategy: 404 - 1 = 403 + expected_image_tokens = 403 + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + inputs = processor( + text=[processor.apply_chat_template(messages)], + images=[image], + return_tensors="pt", + ) + image_tokens = (inputs["input_ids"] == image_token_index).sum().item() + self.assertEqual(expected_image_tokens, image_tokens) diff --git a/utils/check_repo.py b/utils/check_repo.py index 0a89b4aa63a9..af97b084a7f6 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -283,6 +283,7 @@ "Gemma4VisionModel", # Building part of a bigger model, tested implicitly "Gemma4AudioModel", # Building part of a bigger model, tested implicitly "Sam3LiteTextTextModel", # Building part of a bigger model, tested implicitly through Sam3LiteTextModel + "Granite4VisionTextModel", # Building part of bigger (tested) model. Tested implicitly through Granite4VisionModel. ] ) @@ -509,6 +510,7 @@ "Ernie4_5_VL_MoeModel", # BC Alias "Ernie4_5_VL_MoeTextModel", # BC Alias "UVDocBridge", # Building part of a bigger model, tested implicitly through UVDocModel + "Granite4VisionTextModel", # Building part of bigger (tested) model. ]
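As a sanity check on the token-count arithmetic spelled out in the comments of `Granite4VisionModelTester` and `test_image_token_filling`, here is a minimal standalone sketch (illustrative only, not part of the diff above). It mirrors the `_get_number_of_features` / `_get_unpadded_features` logic under two simplifying assumptions taken from those tests: `image_grid_pinpoints` holds a single pinpoint, and `num_additional_image_tokens` is 0. The function name and signature are hypothetical.

```python
from fractions import Fraction


def expected_image_tokens(orig_hw, crop, patch, pinpoint, ds="1/2", extra=0, drop_cls=True):
    """Re-derive the number of image placeholders for one image (sketch, single pinpoint only)."""
    orig_h, orig_w = orig_hw
    # patches per tile side after the Q-Former downsampling (e.g. 336 / 14 = 24 -> 12 with ds "1/2")
    side = int((crop // patch) * Fraction(ds))
    # with a single pinpoint, best-resolution selection is trivial
    best_h, best_w = pinpoint
    scale_h, scale_w = best_h // crop, best_w // crop
    cur_h, cur_w = side * scale_h, side * scale_w
    # undo the aspect-ratio padding, mirroring _get_unpadded_features
    if orig_w / orig_h > cur_w / cur_h:
        new_h = int(round(orig_h * (cur_w / orig_w), 7))
        cur_h -= 2 * ((cur_h - new_h) // 2)
    else:
        new_w = int(round(orig_w * (cur_h / orig_h), 7))
        cur_w -= 2 * ((cur_w - new_w) // 2)
    unpadded, newline = cur_h * cur_w, cur_h
    base = side * side + extra
    total = unpadded + newline + base
    # the "default" vision_feature_select_strategy drops one token
    return total - 1 if drop_cls else total


# processor test: 503x316 image, 336px tiles, patch 14, pinpoint [672, 336] -> 403
print(expected_image_tokens((503, 316), 336, 14, (672, 336)))
# tiny model-tester config: 8px image, patch 2, pinpoint [8, 8], no token dropped -> 10
print(expected_image_tokens((8, 8), 8, 2, (8, 8), drop_cls=False))
```

Both printed values should match the figures asserted in the tests above (403 and 10).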