diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 3921898051d9..ea3f05c14d5d 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1235,6 +1235,8 @@
title: GlmOcr
- local: model_doc/got_ocr2
title: GOT-OCR2
+ - local: model_doc/granite4_vision
+ title: Granite4Vision
- local: model_doc/granitevision
title: GraniteVision
- local: model_doc/grounding-dino
diff --git a/docs/source/en/model_doc/granite4_vision.md b/docs/source/en/model_doc/granite4_vision.md
new file mode 100644
index 000000000000..da5090c251ba
--- /dev/null
+++ b/docs/source/en/model_doc/granite4_vision.md
@@ -0,0 +1,185 @@
+
+*This model was released on 2026-03-27 and added to Hugging Face Transformers on 2026-05-03.*
+
+
+
+# Granite4Vision
+
+[Granite Vision 4.1](https://huggingface.co/ibm-granite/granite-vision-4.1-4b) is a vision-language model from IBM Research designed for enterprise-grade document data extraction. It specializes in chart extraction (Chart2CSV, Chart2Summary, Chart2Code), table extraction (JSON, HTML, OTSL), and semantic key-value pair extraction.
+
+The model builds on [LLaVA-NeXT](llava_next) with several architectural innovations:
+
+1. **SigLIP2 Vision Encoder** ([`google/siglip2-so400m-patch16-384`](https://huggingface.co/google/siglip2-so400m-patch16-384)): images are split into 384x384 tiles, each encoded with 16x16-pixel patches.
+2. **Window Q-Former Projectors**: compress visual features 4x by cross-attending 2x2 learned query tokens to each 4x4 window of patches (see the sketch after this list).
+3. **DeepStack Feature Injection** with 8 vision-to-LLM injection points:
+ - *LayerDeepstack*: features from 4 vision encoder depths are projected into different early LLM layers.
+ - *SpatialDeepstack*: deepest vision features are split into 4 spatial groups and injected at later LLM layers.
+4. **Language Model**: [Granite 4.1](https://huggingface.co/ibm-granite/granite-4.1-4b-base) (4B params) with LoRA adapters (rank 256) across all self-attention and MLP layers.
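+
+To make the token budget concrete, the snippet below is a quick back-of-the-envelope check of the 4x figure, using only the numbers quoted above (384-pixel tiles, 16-pixel patches, 4x4 windows compressed to 2x2 query tokens):
+
+```python
+patches_per_side = 384 // 16                              # SigLIP2 so400m-patch16-384: 24 patches per side
+tokens_per_tile = patches_per_side**2                     # 576 patch tokens per 384x384 tile
+compression = (4 * 4) // (2 * 2)                          # each 4x4 window is summarized by 2x2 query tokens
+tokens_after_projector = tokens_per_tile // compression   # 144 visual tokens per tile
+print(patches_per_side, tokens_per_tile, compression, tokens_after_projector)  # 24 576 4 144
+```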
+
+The model is delivered as a LoRA adapter on top of the base LLM, so a single deployment can serve both multimodal and text-only workloads. The total parameter count is ~4B.
+
+> [!TIP]
+> This model was contributed by the [IBM Granite Vision Team](https://github.com/ibm-granite).
+
+## Usage Tips
+
+- Set `padding_side="left"` during batched generation for more accurate results.
+
+```py
+processor.tokenizer.padding_side = "left"
+```
+
+- The model supports specialized task tags for document extraction: ``, ``, ``, ``, ``, ``. Pass these as the text prompt along with a document image.
+
+- For key-value pair extraction, provide a JSON schema describing the fields to extract. The model returns structured JSON matching the schema (see the sketch below).
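+
+As a minimal sketch of a key-value extraction prompt: the schema, field names, and prompt wording below are illustrative placeholders, and the exact task-tag and schema format expected by the checkpoint is described on the model card. The chat-template pattern itself is the same one used in the examples further down.
+
+```python
+import json
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+model_id = "ibm-granite/granite-vision-4.1-4b"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForImageTextToText.from_pretrained(model_id)
+
+# Hypothetical schema describing the fields to extract; adapt it to your document.
+schema = {"invoice_number": "string", "total_amount": "string", "due_date": "string"}
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            # Replace with your own document image.
+            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+            {"type": "text", "text": f"Extract key-value pairs matching this JSON schema: {json.dumps(schema)}"},
+        ],
+    },
+]
+inputs = processor.apply_chat_template(
+    conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+).to(model.device)
+
+output = model.generate(**inputs, max_new_tokens=200)
+print(processor.decode(output[0], skip_special_tokens=True))
+```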
+
+The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
+
+
+
+
+
+```python
+from transformers import pipeline
+
+pipe = pipeline(
+ task="image-text-to-text",
+ model="ibm-granite/granite-vision-4.1-4b",
+)
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ {"type": "text", "text": "Describe this image."},
+ ],
+ }
+]
+pipe(text=messages, max_new_tokens=100, return_full_text=False)
+```
+
+
+
+
+
+```python
+import torch
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+model_id = "ibm-granite/granite-vision-4.1-4b"
+
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForImageTextToText.from_pretrained(model_id)
+
+conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ {"type": "text", "text": "Describe this image."},
+ ],
+ },
+]
+inputs = processor.apply_chat_template(
+ conversation,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+).to(model.device)
+
+output = model.generate(**inputs, max_new_tokens=100)
+print(processor.decode(output[0], skip_special_tokens=True))
+```
+
+
+
+
+
+Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
+
+The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.
+
+```python
+import torch
+from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
+
+quant_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_quant_type="nf4",
+)
+
+model_id = "ibm-granite/granite-vision-4.1-4b"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModelForImageTextToText.from_pretrained(
+ model_id, quantization_config=quant_config, device_map="auto"
+)
+
+conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ {"type": "text", "text": "Describe this image."},
+ ],
+ },
+]
+inputs = processor.apply_chat_template(
+ conversation,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+).to(model.device)
+
+output = model.generate(**inputs, max_new_tokens=100)
+print(processor.decode(output[0], skip_special_tokens=True))
+```
+
+## Granite4VisionConfig
+
+[[autodoc]] Granite4VisionConfig
+
+## Granite4VisionTextConfig
+
+[[autodoc]] Granite4VisionTextConfig
+
+## Granite4VisionProcessor
+
+[[autodoc]] Granite4VisionProcessor
+ - __call__
+
+## Granite4VisionModel
+
+[[autodoc]] Granite4VisionModel
+
+## Granite4VisionTextModel
+
+[[autodoc]] Granite4VisionTextModel
+
+## Granite4VisionForConditionalGeneration
+
+[[autodoc]] Granite4VisionForConditionalGeneration
+ - forward
+ - get_image_features
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 7fd1bbacf90b..887ad30183c3 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -178,6 +178,7 @@
from .gpt_sw3 import *
from .gptj import *
from .granite import *
+ from .granite4_vision import *
from .granite_speech import *
from .granite_speech_plus import *
from .granitemoe import *
diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py
index 0b2d7dd79167..53755949645b 100644
--- a/src/transformers/models/auto/auto_mappings.py
+++ b/src/transformers/models/auto/auto_mappings.py
@@ -235,6 +235,7 @@
("gpt_oss", "GptOssConfig"),
("gptj", "GPTJConfig"),
("granite", "GraniteConfig"),
+ ("granite4_vision", "Granite4VisionConfig"),
("granite_speech", "GraniteSpeechConfig"),
("granite_speech_encoder", "GraniteSpeechEncoderConfig"),
("granite_speech_plus", "GraniteSpeechPlusConfig"),
@@ -892,6 +893,7 @@
("glm_image", {"pil": "GlmImageImageProcessorPil", "torchvision": "GlmImageImageProcessor"}),
("glpn", {"pil": "GLPNImageProcessorPil", "torchvision": "GLPNImageProcessor"}),
("got_ocr2", {"pil": "GotOcr2ImageProcessorPil", "torchvision": "GotOcr2ImageProcessor"}),
+ ("granite4_vision", {"pil": "LlavaNextImageProcessorPil", "torchvision": "LlavaNextImageProcessor"}),
("grounding-dino", {"pil": "GroundingDinoImageProcessorPil", "torchvision": "GroundingDinoImageProcessor"}),
("idefics", {"pil": "IdeficsImageProcessorPil", "torchvision": "IdeficsImageProcessor"}),
("idefics2", {"pil": "Idefics2ImageProcessorPil", "torchvision": "Idefics2ImageProcessor"}),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 98447b6d1724..c1a46661aecf 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -88,6 +88,7 @@
("focalnet", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}),
("gemma3n", {"torchvision": "SiglipImageProcessor", "pil": "SiglipImageProcessorPil"}),
("git", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}),
+ ("granite4_vision", {"torchvision": "LlavaNextImageProcessor", "pil": "LlavaNextImageProcessorPil"}),
("groupvit", {"torchvision": "CLIPImageProcessor", "pil": "CLIPImageProcessorPil"}),
("hiera", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}),
("ijepa", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 67dc6cd64cce..5fb61fcf53b6 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -211,6 +211,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("gpt_oss", "GptOssModel"),
("gptj", "GPTJModel"),
("granite", "GraniteModel"),
+ ("granite4_vision", "Granite4VisionModel"),
("granite_speech", "GraniteSpeechForConditionalGeneration"),
("granitemoe", "GraniteMoeModel"),
("granitemoehybrid", "GraniteMoeHybridModel"),
@@ -995,6 +996,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("glm4v_moe", "Glm4vMoeForConditionalGeneration"),
("glm_ocr", "GlmOcrForConditionalGeneration"),
("got_ocr2", "GotOcr2ForConditionalGeneration"),
+ ("granite4_vision", "Granite4VisionForConditionalGeneration"),
("idefics", "IdeficsForVisionText2Text"),
("idefics2", "Idefics2ForConditionalGeneration"),
("idefics3", "Idefics3ForConditionalGeneration"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 24008bd4263a..c6a7b3fd1296 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -89,6 +89,7 @@
("glm_image", "Glm4vProcessor"),
("glmasr", "GlmAsrProcessor"),
("got_ocr2", "GotOcr2Processor"),
+ ("granite4_vision", "Granite4VisionProcessor"),
("granite_speech", "GraniteSpeechProcessor"),
("granite_speech_plus", "GraniteSpeechProcessor"),
("grounding-dino", "GroundingDinoProcessor"),
diff --git a/src/transformers/models/granite4_vision/__init__.py b/src/transformers/models/granite4_vision/__init__.py
new file mode 100644
index 000000000000..113694a1a26c
--- /dev/null
+++ b/src/transformers/models/granite4_vision/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 IBM. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_granite4_vision import *
+ from .modeling_granite4_vision import *
+ from .processing_granite4_vision import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/granite4_vision/configuration_granite4_vision.py b/src/transformers/models/granite4_vision/configuration_granite4_vision.py
new file mode 100644
index 000000000000..9b98bd64e152
--- /dev/null
+++ b/src/transformers/models/granite4_vision/configuration_granite4_vision.py
@@ -0,0 +1,209 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/granite4_vision/modular_granite4_vision.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_granite4_vision.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 IBM and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+from huggingface_hub.dataclasses import strict
+
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_rope_utils import RopeParameters
+from ...utils import auto_docstring
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+@auto_docstring(checkpoint="ibm-granite4_vision_text/granite4_vision_text-3.0-8b-base")
+@strict
+class Granite4VisionTextConfig(PreTrainedConfig):
+ r"""
+ ```python
+ >>> from transformers import Granite4VisionTextModel, Granite4VisionTextConfig
+
+    >>> # Initializing a Granite4VisionText configuration
+ >>> configuration = Granite4VisionTextConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+ >>> model = Granite4VisionTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "granite4_vision_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `Granite4VisionTextModel`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ vocab_size: int = 32000
+ hidden_size: int = 4096
+ intermediate_size: int = 11008
+ num_hidden_layers: int = 32
+ num_attention_heads: int = 32
+ num_key_value_heads: int | None = None
+ hidden_act: str = "silu"
+ max_position_embeddings: int = 2048
+ initializer_range: float = 0.02
+ rms_norm_eps: float = 1e-6
+ use_cache: bool = True
+ pad_token_id: int | None = None
+ bos_token_id: int | None = 1
+ eos_token_id: int | list[int] | None = 2
+ tie_word_embeddings: bool = False
+ rope_parameters: RopeParameters | dict | None = None
+ attention_bias: bool = False
+ attention_dropout: float | int = 0.0
+ mlp_bias: bool = False
+ embedding_multiplier: float | int = 1.0
+ logits_scaling: float | int = 1.0
+ residual_multiplier: float | int = 1.0
+ attention_multiplier: float | int = 1.0
+ base_config_key = "text_config"
+
+ def __post_init__(self, **kwargs):
+ if self.num_key_value_heads is None:
+ self.num_key_value_heads = self.num_attention_heads
+
+ super().__post_init__(**kwargs)
+
+
+@auto_docstring(checkpoint="llava-hf/llava-v1.6-mistral-7b-hf")
+@strict
+class Granite4VisionConfig(PreTrainedConfig):
+ r"""
+ downsample_rate (`str`, *optional*):
+ Fractional downsample rate for the Window Q-Former projector, e.g. `"1/4"` or `"3/8"`.
+ The numerator is the query window side, the denominator is the key window side.
+ deepstack_layer_map (`list`, *optional*):
+ List of `[vision_layer_idx, llm_layer_idx]` pairs. Features from each vision encoder layer
+        are projected and injected at the corresponding LLM decoder layer during the forward pass.
+ use_spatial_sampling (`bool`, *optional*, defaults to `False`):
+ Whether to enable spatial offset sampling, which creates 4 groups (TL, TR, BL, BR) from
+ a single vision layer, each injected at a different LLM layer.
+ spatial_vision_layer (`int`, *optional*, defaults to `-1`):
+ Index of the vision encoder layer used for spatial sampling.
+ spatial_target_layers (`list`, *optional*, defaults to `[12, 15, 18, 21]`):
+ Target LLM layers for the 4 spatial offset groups.
+ projector_dropout (`float`, *optional*, defaults to `0.1`):
+ Dropout probability in the Window Q-Former projector.
+ qformer_config (`dict` or `Blip2QFormerConfig`, *optional*):
+ Configuration for the Window Q-Former projector. If `None`, defaults are derived from
+ `vision_config.hidden_size`.
+ image_grid_pinpoints (`list`, *optional*):
+ A list of possible resolutions to use for processing high resolution images. Each item in the list should be a
+ tuple or list of the form `(height, width)`.
+ """
+
+ model_type = "granite4_vision"
+ attribute_map = {"image_token_id": "image_token_index"}
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "qformer_config": AutoConfig}
+
+ vision_config: dict | PreTrainedConfig | None = None
+ text_config: dict | PreTrainedConfig | None = None
+ image_token_index: int = 32000
+ vision_feature_select_strategy: Literal["default", "full"] = "default"
+ vision_feature_layer: int | list[int] = -2
+ tie_word_embeddings: bool = False
+ image_grid_pinpoints: list | None = None
+ image_seq_length: int = 576
+
+ downsample_rate: str | None = None
+ deepstack_layer_map: list | None = None
+ use_spatial_sampling: bool = False
+ spatial_vision_layer: int = -1
+ spatial_target_layers: list | None = None
+ projector_dropout: float = 0.1
+ qformer_config: dict | PreTrainedConfig | None = None
+
+ def __post_init__(self, **kwargs):
+ if self.deepstack_layer_map is not None:
+ self.deepstack_layer_map = [(int(v), int(l)) for v, l in self.deepstack_layer_map]
+
+ if self.spatial_target_layers is None:
+ self.spatial_target_layers = [12, 15, 18, 21]
+
+ # Peek at vision hidden_size before super() to build a fully-specified qformer_config,
+ # avoiding any runtime field patching after super().
+ if isinstance(self.vision_config, dict):
+ vision_hidden_size = self.vision_config.get("hidden_size", 1152)
+ elif self.vision_config is not None:
+ vision_hidden_size = self.vision_config.hidden_size
+ else:
+ vision_hidden_size = 1152
+
+ # Convert qformer_config dict → object before super() so _attn_implementation.setter
+ # (called inside super().__post_init__) sees a config object, not a raw dict.
+ if isinstance(self.qformer_config, dict):
+ model_type = self.qformer_config.get("model_type", "blip_2_qformer")
+ self.qformer_config = CONFIG_MAPPING[model_type](**self.qformer_config)
+ elif self.qformer_config is None:
+ self.qformer_config = CONFIG_MAPPING["blip_2_qformer"](
+ num_hidden_layers=1,
+ intermediate_size=3072,
+ cross_attention_frequency=1,
+ max_position_embeddings=2048,
+ use_qformer_text_input=False,
+ hidden_size=vision_hidden_size,
+ num_attention_heads=vision_hidden_size // 64,
+ encoder_hidden_size=vision_hidden_size,
+ )
+ if isinstance(self.vision_config, dict):
+ self.vision_config["model_type"] = self.vision_config.get("model_type", "clip_vision_model")
+ self.vision_config = CONFIG_MAPPING[self.vision_config["model_type"]](**self.vision_config)
+ elif self.vision_config is None:
+ self.vision_config = CONFIG_MAPPING["clip_vision_model"](
+ intermediate_size=4096,
+ hidden_size=1024,
+ patch_size=14,
+ image_size=336,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ vocab_size=32000,
+ projection_dim=768,
+ )
+
+ if isinstance(self.text_config, dict):
+ self.text_config["model_type"] = self.text_config.get("model_type", "llama")
+ self.text_config = CONFIG_MAPPING[self.text_config["model_type"]](**self.text_config)
+ elif self.text_config is None:
+ self.text_config = CONFIG_MAPPING["llama"]()
+
+ self.image_grid_pinpoints = (
+ self.image_grid_pinpoints
+ if self.image_grid_pinpoints is not None
+ else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
+ )
+
+ super().__post_init__(**kwargs)
+
+
+__all__ = ["Granite4VisionConfig", "Granite4VisionTextConfig"]
diff --git a/src/transformers/models/granite4_vision/modeling_granite4_vision.py b/src/transformers/models/granite4_vision/modeling_granite4_vision.py
new file mode 100644
index 000000000000..2e5919df5588
--- /dev/null
+++ b/src/transformers/models/granite4_vision/modeling_granite4_vision.py
@@ -0,0 +1,1290 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/granite4_vision/modular_granite4_vision.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_granite4_vision.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 IBM and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Callable
+from dataclasses import dataclass
+from fractions import Fraction
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import nn
+
+from ... import initialization as init
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...image_processing_utils import select_best_resolution
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
+from ...utils.generic import maybe_autocast, merge_with_config_defaults
+from ...utils.output_capturing import capture_outputs
+from ..auto import AutoModel
+from .configuration_granite4_vision import Granite4VisionConfig, Granite4VisionTextConfig
+
+
+@dataclass
+class Granite4VisionModelOutputWithPast(BaseModelOutputWithPast):
+ """
+ Args:
+ deepstack_features (`list[tuple[int, list[torch.Tensor]]]`, *optional*):
+ List of `(llm_layer_idx, packed_features)` pairs produced by the deepstack
+ and spatial projectors. Each entry targets one LLM decoder layer; `packed_features`
+ is a per-image list of tensors of shape `(num_image_tokens, hidden_size)`.
+ """
+
+ image_hidden_states: torch.FloatTensor | None = None
+
+ deepstack_features: list | None = None
+
+
+@dataclass
+class Granite4VisionCausalLMOutputWithPast(ModelOutput):
+ """
+ Args:
+ deepstack_features (`list[tuple[int, list[torch.Tensor]]]`, *optional*):
+ List of `(llm_layer_idx, packed_features)` pairs. See `Granite4VisionModelOutputWithPast`.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ image_hidden_states: torch.FloatTensor | None = None
+
+ deepstack_features: list | None = None
+
+
+@dataclass
+class Granite4VisionImageFeaturesOutput(BaseModelOutputWithPooling):
+ """
+ Output of `Granite4VisionModel.get_image_features`.
+
+ Args:
+ deepstack_features (`list[tuple[int, list[torch.Tensor]]]`):
+ List of `(llm_layer_idx, packed_features)` pairs. Each entry targets one LLM
+ decoder layer; `packed_features` is a per-image list of tensors of shape
+ `(num_image_tokens, hidden_size)`.
+ """
+
+ deepstack_features: list | None = None
+
+
+# ── Downsampling helpers ─────────────────────────────────────────────────────
+
+
+def interpolate_downsample(image_features: torch.Tensor, orig_side: int, new_side: int) -> torch.Tensor:
+ """Spatial downsampling via area interpolation."""
+ batch, _, channels = image_features.size()
+ spatial = image_features.view(batch, orig_side, orig_side, channels).permute(0, 3, 1, 2)
+ spatial = torch.nn.functional.interpolate(spatial, size=(new_side, new_side), mode="area")
+ return spatial.permute(0, 2, 3, 1).flatten(1, 2)
+
+
+def spatial_offset_downsample(image_features: torch.Tensor, orig_side: int, offset: int = 0) -> torch.Tensor:
+ """Sample one position from each 2x2 block; offset selects which corner (0=TL,1=TR,2=BL,3=BR)."""
+ offset_h, offset_w = [(0, 0), (0, 1), (1, 0), (1, 1)][offset]
+ new_side = orig_side // 2
+ batch, _, channels = image_features.shape
+ grid = image_features.reshape(batch, orig_side, orig_side, channels)
+ grid = grid.reshape(batch, new_side, 2, new_side, 2, channels)
+ return grid[:, :, offset_h, :, offset_w, :].reshape(batch, -1, channels)
+
+
+class Granite4VisionWindowQFormerDownsampler(nn.Module):
+ """Window-based QFormer downsampler that processes image patches in windows."""
+
+ def __init__(self, config, spatial_offset=None):
+ super().__init__()
+ llm_hidden_size = config.text_config.hidden_size
+ vision_hidden_size = config.vision_config.hidden_size
+
+ self.dropout = nn.Dropout(config.projector_dropout)
+ self._spatial_offset = spatial_offset
+ self._downsample_rate = config.downsample_rate
+
+ self.qformer = AutoModel.from_config(config.qformer_config)
+
+ self.image_side = config.vision_config.image_size // config.vision_config.patch_size
+ query_side_str, window_side_str = config.downsample_rate.split("/")
+ self.query_side, self.window_side = int(query_side_str), int(window_side_str)
+ self.query_length = self.query_side**2
+ self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
+ self.query = nn.Parameter(torch.empty(1, self.query_length, vision_hidden_size))
+ self.image_positions = nn.Parameter(torch.empty(1, self.window_side**2, vision_hidden_size))
+ self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)
+
+ def _windowed_raster(self, features, side, window_size):
+ """(B, side*side, C) raster -> (B*num_win*num_win, window_size*window_size, C)"""
+ batch, _, channels = features.shape
+ num_win = side // window_size
+ return (
+ features.view(batch, side, side, channels)
+ .view(batch, num_win, window_size, num_win, window_size, channels)
+ .transpose(2, 3)
+ .flatten(0, 2)
+ .flatten(1, 2)
+ )
+
+ def _unwindowed_raster(self, windowed_features, num_win, window_size):
+ """(B*num_win*num_win, window_size*window_size, C) -> (B, (num_win*window_size)^2, C)"""
+ batch_win, _, channels = windowed_features.shape
+ if batch_win % (num_win * num_win) != 0:
+ raise ValueError(f"Expected batch_win ({batch_win}) to be divisible by num_win^2 ({num_win**2}).")
+ batch = batch_win // (num_win * num_win)
+ side = num_win * window_size
+ return (
+ windowed_features.view(batch, num_win, num_win, window_size, window_size, channels)
+ .transpose(2, 3)
+ .contiguous()
+ .view(batch, side, side, channels)
+ .flatten(1, 2)
+ )
+
+ def forward(self, image_features):
+ batch, hw, channels = image_features.shape
+ if self.image_side * self.image_side != hw:
+ raise ValueError(
+ f"Expected image_features with {self.image_side**2} spatial tokens, got {hw}. "
+ "Check that the vision encoder image_size and patch_size match the config."
+ )
+ num_windows = self.image_side // self.window_side
+ interp_side = int(self.image_side * Fraction(self._downsample_rate))
+ image_features = self.norm(image_features)
+ enc = self._windowed_raster(image_features, self.image_side, self.window_side)
+
+ if self._spatial_offset is not None:
+ downsampled = spatial_offset_downsample(image_features, self.image_side, self._spatial_offset)
+ else:
+ downsampled = interpolate_downsample(image_features, self.image_side, interp_side)
+
+ downsampled_side = num_windows * self.query_side
+ downsampled_w = self._windowed_raster(downsampled, downsampled_side, self.query_side)
+
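+        # Learned queries, seeded with the downsampled content of each window, cross-attend inside the
+        # Q-Former to the full-resolution patches of that same window.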
+ query_embeds = self.query + downsampled_w
+ encoder_embeds = self.dropout(enc + self.image_positions)
+ out_w = self.qformer(
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_embeds,
+ return_dict=True,
+ ).last_hidden_state
+
+ out = self._unwindowed_raster(out_w, num_win=num_windows, window_size=self.query_side)
+ out = self.dropout(out)
+ return self.out_linear(out)
+
+
+class Granite4VisionTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Granite4VisionTextConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: Granite4VisionTextConfig | None = None,
+ device: Optional["torch.device"] = None,
+ seq_len: int | None = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+@use_kernel_func_from_hub("rotary_pos_emb")
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+@use_kernelized_func(apply_rotary_pos_emb)
+class Granite4VisionTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Granite4VisionTextConfig, layer_idx: int | None = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = config.attention_multiplier
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ attention_mask: torch.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Granite4VisionTextRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+ """
+ Granite4VisionTextRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Granite4VisionTextMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class Granite4VisionTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Granite4VisionTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Granite4VisionTextAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = Granite4VisionTextMLP(config)
+ self.input_layernorm = Granite4VisionTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Granite4VisionTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.residual_multiplier = config.residual_multiplier
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ use_cache: bool | None = False,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_values (`Cache`, *optional*): cached past key and value projection states
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states * self.residual_multiplier
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states * self.residual_multiplier
+
+ return hidden_states
+
+
+@auto_docstring
+class Granite4VisionPreTrainedModel(PreTrainedModel):
+ config: Granite4VisionConfig
+ base_model_prefix = "model"
+ input_modalities = ("image", "text")
+ supports_gradient_checkpointing = True
+    _no_split_modules = ["Granite4VisionTextDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Granite4VisionTextDecoderLayer,
+ "attentions": Granite4VisionTextAttention,
+ }
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, nn.Linear):
+ init.normal_(module.weight, mean=0.0, std=std)
+ if module.bias is not None:
+ init.zeros_(module.bias)
+ elif isinstance(module, Granite4VisionModel):
+ embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
+ init.normal_(module.image_newline, mean=0.0, std=embed_std)
+ if isinstance(module, nn.Embedding):
+ init.normal_(
+ module.weight,
+ mean=0.0,
+ std=getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range),
+ )
+ if module.padding_idx is not None:
+ init.zeros_(module.weight[module.padding_idx])
+ elif isinstance(module, nn.LayerNorm) or (
+ hasattr(module, "weight") and hasattr(module, "variance_epsilon") and not isinstance(module, nn.Linear)
+ ):
+ init.ones_(module.weight)
+ if hasattr(module, "bias") and module.bias is not None:
+ init.zeros_(module.bias)
+ if isinstance(module, Granite4VisionTextRotaryEmbedding):
+ # Non-persistent buffers (inv_freq, original_inv_freq) are replaced with
+ # torch.empty_like() garbage by _move_missing_keys_from_meta_to_device.
+ # Recompute them here so _initialize_missing_keys restores correct values.
+ rope_type = module.config.rope_parameters.get("rope_type", "default")
+ if rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+ else:
+ rope_init_fn = module.compute_default_rope_parameters
+ inv_freq, attention_scaling = rope_init_fn(module.config, module.inv_freq.device)
+ init.copy_(module.inv_freq, inv_freq)
+ init.copy_(module.original_inv_freq, inv_freq)
+ module.attention_scaling = attention_scaling
+ if isinstance(module, Granite4VisionWindowQFormerDownsampler):
+ embed_std = 1 / math.sqrt(module.query.shape[-1])
+ init.normal_(module.query, mean=0.0, std=embed_std)
+ init.normal_(module.image_positions, mean=0.0, std=embed_std)
+
+ def _deepstack_inject(
+ self,
+ hidden_states: torch.Tensor,
+ vision_mask: torch.Tensor,
+ features: torch.Tensor,
+ ) -> torch.Tensor:
+ """Add projected vision features into the image-token positions of hidden_states."""
+ vision_mask = vision_mask.to(hidden_states.device)
+ features = features.to(hidden_states.device, hidden_states.dtype)
+ return hidden_states.masked_scatter(
+ vision_mask,
+ (hidden_states[vision_mask] + features.flatten()).view(-1),
+ )
+
+
+@auto_docstring
+class Granite4VisionTextModel(Granite4VisionPreTrainedModel):
+ """Granite LLM backbone with deepstack feature injection support."""
+
+ base_model_prefix = "model"
+ _no_split_modules = ["Granite4VisionTextDecoderLayer"]
+
+ def __init__(self, config: Granite4VisionTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Granite4VisionTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Granite4VisionTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Granite4VisionTextRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+ self.embedding_multiplier = config.embedding_multiplier
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ vision_mask: torch.BoolTensor | None = None,
+ deepstack_features: dict | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+ r"""
+ vision_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Boolean mask marking image token positions. Required when `deepstack_features` is provided.
+ deepstack_features (`dict[int, torch.Tensor]`, *optional*):
+ Mapping from LLM layer index to projected vision features of shape `(num_image_tokens, hidden_size)`.
+ Features are added into image-token positions of hidden states before the corresponding decoder layer.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ inputs_embeds = inputs_embeds * self.embedding_multiplier
+
+ if position_ids is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ position_ids = (
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ ).unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+ for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
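+            # DeepStack injection: before running this decoder layer, add any projected vision features
+            # that target it into the image-token positions of the hidden states.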
+ if deepstack_features is not None and layer_idx in deepstack_features:
+ hidden_states = self._deepstack_inject(hidden_states, vision_mask, deepstack_features[layer_idx])
+
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`tuple`):
+            The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+        tuple: The shape of the image patch grid in the format (height, width).
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a tuple; otherwise the grid calculation will be wrong
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+ )
+ image_size = image_size.tolist()
+
+ height, width = select_best_resolution(image_size, grid_pinpoints)
+ return height // patch_size, width // patch_size
+
+
+def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
+ """
+ Calculate the number of patches after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (`torch.LongTensor` or `np.ndarray` or `tuple[int, int]`):
+            The size of the input image in the format (height, width).
+ grid_pinpoints (`List`):
+ A list containing possible resolutions. Each item in the list should be a tuple or list
+ of the form `(height, width)`.
+ patch_size (`int`):
+ The size of each image patch.
+
+ Returns:
+ int: the number of patches
+ """
+ if not isinstance(grid_pinpoints, list):
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
+
+    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a tuple; otherwise the grid calculation will be wrong
+ if not isinstance(image_size, (list, tuple)):
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
+ image_size = image_size.tolist()
+
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
+ height, width = best_resolution
+ num_patches = 0
+    # consider changing this to ceil(height/patch_size) * ceil(width/patch_size) + 1
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ num_patches += 1
+ # add the base patch
+ num_patches += 1
+ return num_patches
+
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (`torch.Tensor`):
+ The image tensor, assumed to be of shape (num_channels, height, width).
+ original_size (`tuple`):
+ The original size of the image (height, width).
+
+ Returns:
+ `torch.Tensor`: The unpadded image tensor.
+ """
+ if not isinstance(original_size, (list, tuple)):
+ if not isinstance(original_size, (torch.Tensor, np.ndarray)):
+ raise TypeError(
+ f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
+ )
+ original_size = original_size.tolist()
+ original_height, original_width = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(round(original_height * scale_factor, 7))
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(round(original_width * scale_factor, 7))
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+ return unpadded_tensor
+
+
+@auto_docstring(
+ custom_intro="""
+    The Granite4Vision model, which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class Granite4VisionModel(Granite4VisionPreTrainedModel):
+ base_model_prefix = "model"
+ config_class = Granite4VisionConfig
+
+ def __init__(self, config: Granite4VisionConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+ embed_std = 1 / math.sqrt(config.text_config.hidden_size)
+ self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
+
+ self.vocab_size = config.text_config.vocab_size
+
+ # Replace the inherited LLM backbone with our deepstack-aware subclass
+ self.language_model = Granite4VisionTextModel(config.text_config)
+
+ self.spatial_projectors = None
+ self.downsample_rate = config.downsample_rate
+ self.projector_dropout = config.projector_dropout
+
+ # Deepstack projectors: one per (vision_layer, llm_layer) pair
+ self.layerwise_projectors = nn.ModuleList(
+ [Granite4VisionWindowQFormerDownsampler(config) for _ in range(len(config.deepstack_layer_map))]
+ )
+
+ # Spatial sampling projectors: 4 offset groups (TL, TR, BL, BR)
+ if config.use_spatial_sampling:
+ self.spatial_projectors = nn.ModuleList(
+ [Granite4VisionWindowQFormerDownsampler(config, spatial_offset=i) for i in range(4)]
+ )
+
+ self.pad_token_id = (
+ self.config.text_config.pad_token_id if self.config.text_config.pad_token_id is not None else -1
+ )
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+ """
+ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+ Overrides the parent to apply downsample_rate to height/width calculations.
+ """
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
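+                # The Q-Former projector shrinks each tile's token grid by downsample_rate, so the
+                # per-tile height/width used for packing must shrink accordingly.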
+ if self.layerwise_projectors is not None:
+ ds_rate = Fraction(self.downsample_rate)
+ height = int(height * ds_rate)
+ width = int(width * ds_rate)
+
+ if (
+ np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+ and vision_feature_select_strategy == "default"
+ ):
+ raise ValueError(
+ "Image feature shape does not line up with the provided patch size. "
+ "You may be using the `default` vision_feature_select_strategy with a "
+ "visual encoder that does not have CLS token."
+ )
+
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None]
+ .expand(*image_feature.shape[:-1], 1)
+ .to(image_feature.device, image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
+ return new_image_features, feature_lens
+
+ @merge_with_config_defaults
+ @can_return_tuple
+ @auto_docstring(
+ custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
+ )
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: int | list[int] | None = None,
+ vision_feature_select_strategy: str | None = None,
+ **kwargs,
+ ) -> Granite4VisionImageFeaturesOutput:
+ """
+ Extract image features via deepstack (multi-layer) and spatial sampling projections.
+
+ Runs the vision tower once, then:
+ 1. Deepstack: for each (vision_layer, llm_layer) in deepstack_layer_map,
+ extracts features from that vision layer, downsamples via interpolation + QFormer,
+ and pairs them with the target LLM layer.
+ 2. Spatial: if enabled, extracts the spatial_vision_layer and creates 4 spatial
+ offset groups (TL, TR, BL, BR), each targeting a different LLM layer.
+ """
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
+ )
+ for imsize in image_sizes
+ ]
+
+ if pixel_values.dim() == 5:
+ _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
+
+ output_hidden_states = kwargs.pop("output_hidden_states", None)
+ if output_hidden_states is None:
+ output_hidden_states = getattr(self.config, "output_hidden_states", False)
+ vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
+
+ # Deepstack features: extract from multiple vision layers, downsample via interpolation
+ all_features = []
+ for projection_idx, (vision_layer, llm_layer) in enumerate(self.config.deepstack_layer_map):
+ selected_feature = vision_outputs.hidden_states[vision_layer]
+
+ if vision_feature_select_strategy == "default":
+ selected_feature = selected_feature[:, 1:]
+
+ projected_features = self.layerwise_projectors[projection_idx](selected_feature)
+ projected_features = torch.split(projected_features, image_num_patches, dim=0)
+
+ packed_features, _ = self.pack_image_features(
+ projected_features,
+ image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
+ all_features.append((llm_layer, packed_features))
+
+ # Spatial features: extract 4 offset groups from a single vision layer
+ if self.config.use_spatial_sampling:
+ spatial_feature = vision_outputs.hidden_states[self.config.spatial_vision_layer]
+
+ if vision_feature_select_strategy == "default":
+ spatial_feature = spatial_feature[:, 1:]
+
+ for group_idx, llm_layer in enumerate(self.config.spatial_target_layers):
+ projected_group = self.spatial_projectors[group_idx](spatial_feature)
+ projected_group_split = torch.split(projected_group, image_num_patches, dim=0)
+
+ packed_group, _ = self.pack_image_features(
+ projected_group_split,
+ image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
+ all_features.append((llm_layer, packed_group))
+
+ return Granite4VisionImageFeaturesOutput(
+ deepstack_features=all_features,
+ hidden_states=vision_outputs.hidden_states if output_hidden_states else None,
+ )
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ torch_compilable_check(
+ inputs_embeds[special_image_mask].numel() == image_features.numel(),
+ f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ image_sizes: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ vision_feature_layer: int | list[int] | None = None,
+ vision_feature_select_strategy: str | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Granite4VisionModelOutputWithPast:
+ r"""
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # Build deepstack injection map and scatter initial image embeddings
+ deepstack_features = None
+ vision_mask = None
+ image_features = None
+ if pixel_values is not None and pixel_values.size(0) > 0:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+ deepstack_features = {}
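+            # Image-token embeddings are zeroed here; visual content reaches the LLM only through the
+            # per-layer deepstack additions collected below.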
+ for idx, (llm_layer_idx, packed_features) in enumerate(image_features.deepstack_features):
+ concat_features = torch.cat(packed_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ if idx == 0:
+ vision_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=concat_features
+ )
+ inputs_embeds = inputs_embeds.masked_fill(vision_mask, 0.0)
+ deepstack_features[llm_layer_idx] = concat_features
+
+ # Bypass nn.Module.__call__ overhead by calling the unbound forward directly.
+ # nn.Module.__call__ has non-trivial per-call overhead that accumulates across 40 layers × N steps.
+ outputs = Granite4VisionTextModel.forward(
+ self.language_model,
+ input_ids=None,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ vision_mask=vision_mask,
+ deepstack_features=deepstack_features,
+ **kwargs,
+ )
+
+ return Granite4VisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+            deepstack_features=image_features.deepstack_features if image_features is not None else None,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+    The Granite4Vision model, which consists of a vision backbone and a language model.
+ """
+)
+class Granite4VisionForConditionalGeneration(Granite4VisionPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+ config_class = Granite4VisionConfig
+
+ def __init__(self, config: Granite4VisionConfig):
+ super().__init__(config)
+ self.model = Granite4VisionModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+ return self.model.pack_image_features(
+ image_features=image_features,
+ image_sizes=image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=image_newline,
+ )
+
+ @merge_with_config_defaults
+ @can_return_tuple
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+        vision_feature_layer: int | list[int] | None = None,
+ vision_feature_select_strategy: str | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | BaseModelOutputWithPooling:
+ r"""
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`):
+            The tensors corresponding to the input images.
+        image_sizes (`torch.Tensor` of shape `(num_images, 2)`):
+            Actual size of each image as (height, width).
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`
+ """
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ **kwargs,
+ )
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ image_sizes: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ vision_feature_layer: int | list[int] | None = None,
+ vision_feature_select_strategy: str | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Granite4VisionCausalLMOutputWithPast:
+ r"""
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
+ If `"full"`, the full vision features are used.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import httpx
+ >>> from io import BytesIO
+ >>> from transformers import AutoProcessor, Granite4VisionForConditionalGeneration
+
+ >>> model = Granite4VisionForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+
+ >>> prompt = "[INST] \nWhat is shown in this image? [/INST]"
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> with httpx.stream("GET", url) as response:
+ ... image = Image.open(BytesIO(response.read()))
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
+ ```"""
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ outputs = self.model(
+ input_ids,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ return_dict=True,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+
+ loss = None
+ logits = self.lm_head(hidden_states)
+ logits = logits / self.config.text_config.logits_scaling
+ if labels is not None:
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.text_config.vocab_size,
+ **kwargs,
+ )
+
+ if isinstance(logits_to_keep, int) and logits_to_keep > 0:
+ logits = logits[:, -logits_to_keep:, :]
+
+ return Granite4VisionCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ deepstack_features=outputs.deepstack_features,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ image_sizes=None,
+ attention_mask=None,
+ logits_to_keep=None,
+ is_first_iteration=False,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ logits_to_keep=logits_to_keep,
+ is_first_iteration=is_first_iteration,
+ **kwargs,
+ )
+
+ # Pixel values are used only in the first iteration if available
+ # In subsequent iterations, they are already merged with text and cached
+ # NOTE: first iteration doesn't have to be prefill, it can be the first
+ # iteration with a question and cached system prompt (continue generate from cache)
+ if is_first_iteration or not kwargs.get("use_cache", True):
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["image_sizes"] = image_sizes
+
+ return model_inputs
+
+
+__all__ = [
+ "Granite4VisionPreTrainedModel",
+ "Granite4VisionTextModel",
+ "Granite4VisionModel",
+ "Granite4VisionForConditionalGeneration",
+]
diff --git a/src/transformers/models/granite4_vision/modular_granite4_vision.py b/src/transformers/models/granite4_vision/modular_granite4_vision.py
new file mode 100644
index 000000000000..3559a3298ce3
--- /dev/null
+++ b/src/transformers/models/granite4_vision/modular_granite4_vision.py
@@ -0,0 +1,830 @@
+# Copyright 2026 IBM and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from fractions import Fraction
+
+import numpy as np
+import torch
+from torch import nn
+
+from ... import initialization as init
+from ...cache_utils import Cache
+from ...configuration_utils import PreTrainedConfig
+from ...image_processing_utils import select_best_resolution
+from ...masking_utils import create_causal_mask
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.output_capturing import capture_outputs
+from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
+from ..granite.configuration_granite import GraniteConfig
+from ..granite.modeling_granite import GraniteAttention, GraniteDecoderLayer, GraniteModel, GraniteRotaryEmbedding
+from ..llava_next.configuration_llava_next import LlavaNextConfig
+from ..llava_next.modeling_llava_next import (
+ LlavaNextCausalLMOutputWithPast,
+ LlavaNextForConditionalGeneration,
+ LlavaNextModel,
+ LlavaNextModelOutputWithPast,
+ LlavaNextPreTrainedModel,
+ get_anyres_image_grid_shape,
+ image_size_to_num_patches,
+ unpad_image,
+)
+from ..llava_next.processing_llava_next import LlavaNextProcessor
+
+
+# ── Output classes ──────────────────────────────────────────────────────────
+
+
+@dataclass
+class Granite4VisionModelOutputWithPast(LlavaNextModelOutputWithPast):
+ """
+ Args:
+ deepstack_features (`list[tuple[int, list[torch.Tensor]]]`, *optional*):
+ List of `(llm_layer_idx, packed_features)` pairs produced by the deepstack
+ and spatial projectors. Each entry targets one LLM decoder layer; `packed_features`
+ is a per-image list of tensors of shape `(num_image_tokens, hidden_size)`.
+ """
+
+ deepstack_features: list | None = None
+
+
+@dataclass
+class Granite4VisionCausalLMOutputWithPast(LlavaNextCausalLMOutputWithPast):
+ """
+ Args:
+ deepstack_features (`list[tuple[int, list[torch.Tensor]]]`, *optional*):
+ List of `(llm_layer_idx, packed_features)` pairs. See `Granite4VisionModelOutputWithPast`.
+ """
+
+ deepstack_features: list | None = None
+
+
+@dataclass
+class Granite4VisionImageFeaturesOutput(BaseModelOutputWithPooling):
+ """
+ Output of `Granite4VisionModel.get_image_features`.
+
+ Args:
+ deepstack_features (`list[tuple[int, list[torch.Tensor]]]`):
+ List of `(llm_layer_idx, packed_features)` pairs. Each entry targets one LLM
+ decoder layer; `packed_features` is a per-image list of tensors of shape
+ `(num_image_tokens, hidden_size)`.
+ """
+
+ deepstack_features: list | None = None
+
+
+# ── Config ──────────────────────────────────────────────────────────────────
+
+
+class Granite4VisionTextConfig(GraniteConfig):
+ model_type = "granite4_vision_text"
+ base_config_key = "text_config"
+
+
+class Granite4VisionConfig(LlavaNextConfig):
+ r"""
+ downsample_rate (`str`, *optional*):
+ Fractional downsample rate for the Window Q-Former projector, e.g. `"1/4"` or `"3/8"`.
+ The numerator is the query window side, the denominator is the key window side.
+ deepstack_layer_map (`list`, *optional*):
+ List of `[vision_layer_idx, llm_layer_idx]` pairs. Features from each vision encoder layer
+ are projected and injected at the corresponding LLM decoder layer during forward pass.
+ use_spatial_sampling (`bool`, *optional*, defaults to `False`):
+ Whether to enable spatial offset sampling, which creates 4 groups (TL, TR, BL, BR) from
+ a single vision layer, each injected at a different LLM layer.
+ spatial_vision_layer (`int`, *optional*, defaults to `-1`):
+ Index of the vision encoder layer used for spatial sampling.
+ spatial_target_layers (`list`, *optional*, defaults to `[12, 15, 18, 21]`):
+ Target LLM layers for the 4 spatial offset groups.
+ projector_dropout (`float`, *optional*, defaults to `0.1`):
+ Dropout probability in the Window Q-Former projector.
+ qformer_config (`dict` or `Blip2QFormerConfig`, *optional*):
+ Configuration for the Window Q-Former projector. If `None`, defaults are derived from
+ `vision_config.hidden_size`.
+ image_grid_pinpoints (`list`, *optional*):
+ A list of possible resolutions to use for processing high resolution images. Each item in the list should be a
+ tuple or list of the form `(height, width)`.
+ """
+
+ model_type = "granite4_vision"
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "qformer_config": AutoConfig}
+
+ multimodal_projector_bias = AttributeError()
+ projector_hidden_act = AttributeError()
+
+ downsample_rate: str | None = None
+ deepstack_layer_map: list | None = None
+ use_spatial_sampling: bool = False
+ spatial_vision_layer: int = -1
+ spatial_target_layers: list | None = None
+ projector_dropout: float = 0.1
+ qformer_config: dict | PreTrainedConfig | None = None
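+
+    # A minimal construction sketch (illustrative values, not taken from the released checkpoint):
+    #   Granite4VisionConfig(downsample_rate="2/4", deepstack_layer_map=[[8, 1], [16, 3], [24, 5], [32, 7]],
+    #                        use_spatial_sampling=True, spatial_target_layers=[12, 15, 18, 21])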
+
+ def __post_init__(self, **kwargs):
+ if self.deepstack_layer_map is not None:
+ self.deepstack_layer_map = [(int(v), int(l)) for v, l in self.deepstack_layer_map]
+
+ if self.spatial_target_layers is None:
+ self.spatial_target_layers = [12, 15, 18, 21]
+
+ # Peek at vision hidden_size before super() to build a fully-specified qformer_config,
+ # avoiding any runtime field patching after super().
+ if isinstance(self.vision_config, dict):
+ vision_hidden_size = self.vision_config.get("hidden_size", 1152)
+ elif self.vision_config is not None:
+ vision_hidden_size = self.vision_config.hidden_size
+ else:
+ vision_hidden_size = 1152
+
+ # Convert qformer_config dict → object before super() so _attn_implementation.setter
+ # (called inside super().__post_init__) sees a config object, not a raw dict.
+ if isinstance(self.qformer_config, dict):
+ model_type = self.qformer_config.get("model_type", "blip_2_qformer")
+ self.qformer_config = CONFIG_MAPPING[model_type](**self.qformer_config)
+ elif self.qformer_config is None:
+ self.qformer_config = CONFIG_MAPPING["blip_2_qformer"](
+ num_hidden_layers=1,
+ intermediate_size=3072,
+ cross_attention_frequency=1,
+ max_position_embeddings=2048,
+ use_qformer_text_input=False,
+ hidden_size=vision_hidden_size,
+ num_attention_heads=vision_hidden_size // 64,
+ encoder_hidden_size=vision_hidden_size,
+ )
+
+ super().__post_init__(**kwargs)
+
+
+# ── Processor ───────────────────────────────────────────────────────────────
+
+
+class Granite4VisionProcessor(LlavaNextProcessor):
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ patch_size=None,
+ vision_feature_select_strategy=None,
+ chat_template=None,
+ image_token="",
+ num_additional_image_tokens=0,
+ downsample_rate=None,
+ **kwargs,
+ ):
+ r"""
+ patch_size (`int`, *optional*):
+ Patch size from the vision tower.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Should be same as in model's config.
+        image_token (`str`, *optional*, defaults to `"<image>"`):
+ Special token used to denote image location.
+ num_additional_image_tokens (`int`, *optional*, defaults to `0`):
+ Number of additional tokens added to the image embeddings, such as CLS (+1).
+ downsample_rate (`str`, *optional*):
+ Fractional downsample rate (e.g. `"1/4"`), used to adjust the number of image tokens
+ when computing token counts for padding/truncation.
+ """
+ super().__init__(
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ patch_size=patch_size,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ chat_template=chat_template,
+ image_token=image_token,
+ num_additional_image_tokens=num_additional_image_tokens,
+ )
+ self.downsample_rate = downsample_rate
+
+ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+ image_grid_pinpoints = self.image_processor.image_grid_pinpoints
+
+ height_best_resolution, width_best_resolution = select_best_resolution(
+ [orig_height, orig_width], image_grid_pinpoints
+ )
+ scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
+
+ patches_height = height // self.patch_size
+ patches_width = width // self.patch_size
+ if self.downsample_rate is not None:
+ ds_rate = Fraction(self.downsample_rate)
+ patches_height = int(patches_height * ds_rate)
+ patches_width = int(patches_width * ds_rate)
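+            # e.g. a 384x384 tile with 16px patches gives 24 patches per side; with an assumed
+            # downsample_rate of "2/4" this becomes 12x12 before the unpadded/newline bookkeeping.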
+
+ unpadded_features, newline_features = self._get_unpadded_features(
+ orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
+ )
+ base_features = patches_height * patches_width + self.num_additional_image_tokens
+ num_image_tokens = unpadded_features + newline_features + base_features
+ return num_image_tokens
+
+
+# ── Downsampling helpers ─────────────────────────────────────────────────────
+
+
+def interpolate_downsample(image_features: torch.Tensor, orig_side: int, new_side: int) -> torch.Tensor:
+ """Spatial downsampling via area interpolation."""
+ batch, _, channels = image_features.size()
+ spatial = image_features.view(batch, orig_side, orig_side, channels).permute(0, 3, 1, 2)
+ spatial = torch.nn.functional.interpolate(spatial, size=(new_side, new_side), mode="area")
+ return spatial.permute(0, 2, 3, 1).flatten(1, 2)
+
+
+def spatial_offset_downsample(image_features: torch.Tensor, orig_side: int, offset: int = 0) -> torch.Tensor:
+ """Sample one position from each 2x2 block; offset selects which corner (0=TL,1=TR,2=BL,3=BR)."""
+ offset_h, offset_w = [(0, 0), (0, 1), (1, 0), (1, 1)][offset]
+ new_side = orig_side // 2
+ batch, _, channels = image_features.shape
+ grid = image_features.reshape(batch, orig_side, orig_side, channels)
+ grid = grid.reshape(batch, new_side, 2, new_side, 2, channels)
+ return grid[:, :, offset_h, :, offset_w, :].reshape(batch, -1, channels)
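+
+# Illustrative shapes for the two helpers above, assuming the 24x24 patch grid of a 384px/16px SigLIP2 tower:
+# interpolate_downsample(x, 24, 12) maps (B, 576, C) -> (B, 144, C) via area pooling, while
+# spatial_offset_downsample(x, 24, offset=1) keeps the top-right element of every 2x2 block, also giving (B, 144, C).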
+
+
+class Granite4VisionWindowQFormerDownsampler(nn.Module):
+ """Window-based QFormer downsampler that processes image patches in windows."""
+
+ def __init__(self, config, spatial_offset=None):
+ super().__init__()
+ llm_hidden_size = config.text_config.hidden_size
+ vision_hidden_size = config.vision_config.hidden_size
+
+ self.dropout = nn.Dropout(config.projector_dropout)
+ self._spatial_offset = spatial_offset
+ self._downsample_rate = config.downsample_rate
+
+ self.qformer = AutoModel.from_config(config.qformer_config)
+
+ self.image_side = config.vision_config.image_size // config.vision_config.patch_size
+ query_side_str, window_side_str = config.downsample_rate.split("/")
+ self.query_side, self.window_side = int(query_side_str), int(window_side_str)
+ self.query_length = self.query_side**2
+ self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
+ self.query = nn.Parameter(torch.empty(1, self.query_length, vision_hidden_size))
+ self.image_positions = nn.Parameter(torch.empty(1, self.window_side**2, vision_hidden_size))
+ self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)
+
+ def _windowed_raster(self, features, side, window_size):
+ """(B, side*side, C) raster -> (B*num_win*num_win, window_size*window_size, C)"""
+ batch, _, channels = features.shape
+ num_win = side // window_size
+ return (
+ features.view(batch, side, side, channels)
+ .view(batch, num_win, window_size, num_win, window_size, channels)
+ .transpose(2, 3)
+ .flatten(0, 2)
+ .flatten(1, 2)
+ )
+
+ def _unwindowed_raster(self, windowed_features, num_win, window_size):
+ """(B*num_win*num_win, window_size*window_size, C) -> (B, (num_win*window_size)^2, C)"""
+ batch_win, _, channels = windowed_features.shape
+ if batch_win % (num_win * num_win) != 0:
+ raise ValueError(f"Expected batch_win ({batch_win}) to be divisible by num_win^2 ({num_win**2}).")
+ batch = batch_win // (num_win * num_win)
+ side = num_win * window_size
+ return (
+ windowed_features.view(batch, num_win, num_win, window_size, window_size, channels)
+ .transpose(2, 3)
+ .contiguous()
+ .view(batch, side, side, channels)
+ .flatten(1, 2)
+ )
+
+ def forward(self, image_features):
+ batch, hw, channels = image_features.shape
+ if self.image_side * self.image_side != hw:
+ raise ValueError(
+ f"Expected image_features with {self.image_side**2} spatial tokens, got {hw}. "
+ "Check that the vision encoder image_size and patch_size match the config."
+ )
+ num_windows = self.image_side // self.window_side
+ interp_side = int(self.image_side * Fraction(self._downsample_rate))
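+        # e.g. an image_side of 24 with window_side 4 yields 6x6 windows; each 4x4 key window is summarised
+        # by a 2x2 query grid, giving the 4x token compression described in the model card.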
+ image_features = self.norm(image_features)
+ enc = self._windowed_raster(image_features, self.image_side, self.window_side)
+
+ if self._spatial_offset is not None:
+ downsampled = spatial_offset_downsample(image_features, self.image_side, self._spatial_offset)
+ else:
+ downsampled = interpolate_downsample(image_features, self.image_side, interp_side)
+
+ downsampled_side = num_windows * self.query_side
+ downsampled_w = self._windowed_raster(downsampled, downsampled_side, self.query_side)
+
+ query_embeds = self.query + downsampled_w
+ encoder_embeds = self.dropout(enc + self.image_positions)
+ out_w = self.qformer(
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_embeds,
+ return_dict=True,
+ ).last_hidden_state
+
+ out = self._unwindowed_raster(out_w, num_win=num_windows, window_size=self.query_side)
+ out = self.dropout(out)
+ return self.out_linear(out)
+
+
+# ── Model ───────────────────────────────────────────────────────────────────
+
+
+class Granite4VisionTextRotaryEmbedding(GraniteRotaryEmbedding):
+ pass
+
+
+class Granite4VisionTextAttention(GraniteAttention):
+ pass
+
+
+class Granite4VisionTextDecoderLayer(GraniteDecoderLayer):
+ pass
+
+
+class Granite4VisionPreTrainedModel(LlavaNextPreTrainedModel):
+ _can_record_outputs = {
+ "hidden_states": Granite4VisionTextDecoderLayer,
+ "attentions": Granite4VisionTextAttention,
+ }
+
+ def _deepstack_inject(
+ self,
+ hidden_states: torch.Tensor,
+ vision_mask: torch.Tensor,
+ features: torch.Tensor,
+ ) -> torch.Tensor:
+ """Add projected vision features into the image-token positions of hidden_states."""
+ vision_mask = vision_mask.to(hidden_states.device)
+ features = features.to(hidden_states.device, hidden_states.dtype)
+ return hidden_states.masked_scatter(
+ vision_mask,
+ (hidden_states[vision_mask] + features.flatten()).view(-1),
+ )
+
+ def _init_weights(self, module):
+ super()._init_weights(module)
+ if isinstance(module, nn.Embedding):
+ init.normal_(
+ module.weight,
+ mean=0.0,
+ std=getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range),
+ )
+ if module.padding_idx is not None:
+ init.zeros_(module.weight[module.padding_idx])
+ elif isinstance(module, nn.LayerNorm) or (
+ hasattr(module, "weight") and hasattr(module, "variance_epsilon") and not isinstance(module, nn.Linear)
+ ):
+ init.ones_(module.weight)
+ if hasattr(module, "bias") and module.bias is not None:
+ init.zeros_(module.bias)
+ if isinstance(module, Granite4VisionTextRotaryEmbedding):
+ # Non-persistent buffers (inv_freq, original_inv_freq) are replaced with
+ # torch.empty_like() garbage by _move_missing_keys_from_meta_to_device.
+ # Recompute them here so _initialize_missing_keys restores correct values.
+ rope_type = module.config.rope_parameters.get("rope_type", "default")
+ if rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+ else:
+ rope_init_fn = module.compute_default_rope_parameters
+ inv_freq, attention_scaling = rope_init_fn(module.config, module.inv_freq.device)
+ init.copy_(module.inv_freq, inv_freq)
+ init.copy_(module.original_inv_freq, inv_freq)
+ module.attention_scaling = attention_scaling
+ if isinstance(module, Granite4VisionWindowQFormerDownsampler):
+ embed_std = 1 / math.sqrt(module.query.shape[-1])
+ init.normal_(module.query, mean=0.0, std=embed_std)
+ init.normal_(module.image_positions, mean=0.0, std=embed_std)
+
+
+class Granite4VisionTextModel(Granite4VisionPreTrainedModel, GraniteModel):
+ """Granite LLM backbone with deepstack feature injection support."""
+
+ base_model_prefix = "model"
+ _no_split_modules = ["Granite4VisionTextDecoderLayer"]
+
+ def __init__(self, config: Granite4VisionTextConfig):
+ super().__init__(config)
+
+ @capture_outputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ vision_mask: torch.BoolTensor | None = None,
+ deepstack_features: dict | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ r"""
+ vision_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Boolean mask marking image token positions. Required when `deepstack_features` is provided.
+ deepstack_features (`dict[int, torch.Tensor]`, *optional*):
+ Mapping from LLM layer index to projected vision features of shape `(num_image_tokens, hidden_size)`.
+ Features are added into image-token positions of hidden states before the corresponding decoder layer.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ inputs_embeds = inputs_embeds * self.embedding_multiplier
+
+ if position_ids is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ position_ids = (
+ torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
+ ).unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
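+        # Deepstack injection: features mapped to layer_idx are added into the image-token positions
+        # of hidden_states just before that decoder layer runs.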
+ for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+ if deepstack_features is not None and layer_idx in deepstack_features:
+ hidden_states = self._deepstack_inject(hidden_states, vision_mask, deepstack_features[layer_idx])
+
+ hidden_states = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+class Granite4VisionModel(LlavaNextModel):
+ config_class = Granite4VisionConfig
+
+ def __init__(self, config: Granite4VisionConfig):
+ super().__init__(config)
+
+ # Replace parent's single multi_modal_projector with layerwise_projectors
+ del self.multi_modal_projector
+
+ self.spatial_projectors = None
+ self.downsample_rate = config.downsample_rate
+ self.projector_dropout = config.projector_dropout
+
+ # Deepstack projectors: one per (vision_layer, llm_layer) pair
+ self.layerwise_projectors = nn.ModuleList(
+ [Granite4VisionWindowQFormerDownsampler(config) for _ in range(len(config.deepstack_layer_map))]
+ )
+
+ # Spatial sampling projectors: 4 offset groups (TL, TR, BL, BR)
+ if config.use_spatial_sampling:
+ self.spatial_projectors = nn.ModuleList(
+ [Granite4VisionWindowQFormerDownsampler(config, spatial_offset=i) for i in range(4)]
+ )
+
+ self.pad_token_id = (
+ self.config.text_config.pad_token_id if self.config.text_config.pad_token_id is not None else -1
+ )
+
+ # Replace the inherited LLM backbone with our deepstack-aware subclass
+ self.language_model = Granite4VisionTextModel(config.text_config)
+
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
+ """
+ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
+
+ Overrides the parent to apply downsample_rate to height/width calculations.
+ """
+ new_image_features = []
+ feature_lens = []
+ for image_idx, image_feature in enumerate(image_features):
+ if image_feature.shape[0] > 1:
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
+
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+ image_sizes[image_idx],
+ self.config.image_grid_pinpoints,
+ self.config.vision_config.image_size,
+ )
+ if self.layerwise_projectors is not None:
+ ds_rate = Fraction(self.downsample_rate)
+ height = int(height * ds_rate)
+ width = int(width * ds_rate)
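+                # e.g. with the 384px/16px-patch SigLIP2 tower described in the model card, height = width = 24;
+                # an assumed downsample_rate of "2/4" (4x4 windows -> 2x2 queries) halves each side to 12.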
+
+ if (
+ np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
+ and vision_feature_select_strategy == "default"
+ ):
+ raise ValueError(
+ "Image feature shape does not line up with the provided patch size. "
+ "You may be using the `default` vision_feature_select_strategy with a "
+ "visual encoder that does not have CLS token."
+ )
+
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ if image_newline is not None:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ image_newline[:, None, None]
+ .expand(*image_feature.shape[:-1], 1)
+ .to(image_feature.device, image_feature.dtype),
+ ),
+ dim=-1,
+ )
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else:
+ image_feature = image_feature[0]
+ if image_newline is not None:
+ image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
+ new_image_features.append(image_feature)
+ feature_lens.append(image_feature.size(0))
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
+ return new_image_features, feature_lens
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ image_sizes: torch.Tensor,
+ vision_feature_layer: int | list[int] | None = None,
+ vision_feature_select_strategy: str | None = None,
+ **kwargs,
+ ) -> Granite4VisionImageFeaturesOutput:
+ """
+ Extract image features via deepstack (multi-layer) and spatial sampling projections.
+
+ Runs the vision tower once, then:
+ 1. Deepstack: for each (vision_layer, llm_layer) in deepstack_layer_map,
+ extracts features from that vision layer, downsamples via interpolation + QFormer,
+ and pairs them with the target LLM layer.
+ 2. Spatial: if enabled, extracts the spatial_vision_layer and creates 4 spatial
+ offset groups (TL, TR, BL, BR), each targeting a different LLM layer.
+ """
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
+ )
+ for imsize in image_sizes
+ ]
+
+ if pixel_values.dim() == 5:
+ _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
+
+ output_hidden_states = kwargs.pop("output_hidden_states", None)
+ if output_hidden_states is None:
+ output_hidden_states = getattr(self.config, "output_hidden_states", False)
+ vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
+
+ # Deepstack features: extract from multiple vision layers, downsample via interpolation
+ all_features = []
+ for projection_idx, (vision_layer, llm_layer) in enumerate(self.config.deepstack_layer_map):
+ selected_feature = vision_outputs.hidden_states[vision_layer]
+
+ if vision_feature_select_strategy == "default":
+ selected_feature = selected_feature[:, 1:]
+
+ projected_features = self.layerwise_projectors[projection_idx](selected_feature)
+ projected_features = torch.split(projected_features, image_num_patches, dim=0)
+
+ packed_features, _ = self.pack_image_features(
+ projected_features,
+ image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
+ all_features.append((llm_layer, packed_features))
+
+ # Spatial features: extract 4 offset groups from a single vision layer
+ if self.config.use_spatial_sampling:
+ spatial_feature = vision_outputs.hidden_states[self.config.spatial_vision_layer]
+
+ if vision_feature_select_strategy == "default":
+ spatial_feature = spatial_feature[:, 1:]
+
+ for group_idx, llm_layer in enumerate(self.config.spatial_target_layers):
+ projected_group = self.spatial_projectors[group_idx](spatial_feature)
+ projected_group_split = torch.split(projected_group, image_num_patches, dim=0)
+
+ packed_group, _ = self.pack_image_features(
+ projected_group_split,
+ image_sizes,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_newline=self.image_newline,
+ )
+
+ all_features.append((llm_layer, packed_group))
+
+ return Granite4VisionImageFeaturesOutput(
+ deepstack_features=all_features,
+ hidden_states=vision_outputs.hidden_states if output_hidden_states else None,
+ )
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ image_sizes: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ vision_feature_layer: int | list[int] | None = None,
+ vision_feature_select_strategy: str | None = None,
+ use_cache: bool | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Granite4VisionModelOutputWithPast:
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # Build deepstack injection map and scatter initial image embeddings
+ deepstack_features = None
+ vision_mask = None
+ image_features = None
+ if pixel_values is not None and pixel_values.size(0) > 0:
+ image_features = self.get_image_features(
+ pixel_values,
+ image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ )
+
+ deepstack_features = {}
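+            # Image-token embeddings are zeroed here; visual content reaches the LLM only through the
+            # per-layer deepstack additions collected below.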
+ for idx, (llm_layer_idx, packed_features) in enumerate(image_features.deepstack_features):
+ concat_features = torch.cat(packed_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ if idx == 0:
+ vision_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=concat_features
+ )
+ inputs_embeds = inputs_embeds.masked_fill(vision_mask, 0.0)
+ deepstack_features[llm_layer_idx] = concat_features
+
+ # Bypass nn.Module.__call__ overhead by calling the unbound forward directly.
+ # nn.Module.__call__ has non-trivial per-call overhead that accumulates across 40 layers × N steps.
+ outputs = Granite4VisionTextModel.forward(
+ self.language_model,
+ input_ids=None,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ vision_mask=vision_mask,
+ deepstack_features=deepstack_features,
+ **kwargs,
+ )
+
+ return Granite4VisionModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+            deepstack_features=image_features.deepstack_features if image_features is not None else None,
+ )
+
+
+# ── ForConditionalGeneration ────────────────────────────────────────────────
+
+
+class Granite4VisionForConditionalGeneration(LlavaNextForConditionalGeneration):
+ config_class = Granite4VisionConfig
+
+ def __init__(self, config: Granite4VisionConfig):
+ super().__init__(config)
+ self.model = Granite4VisionModel(config)
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ pixel_values: torch.FloatTensor | None = None,
+ image_sizes: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ vision_feature_layer: int | list[int] | None = None,
+ vision_feature_select_strategy: str | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | Granite4VisionCausalLMOutputWithPast:
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ outputs = self.model(
+ input_ids,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ return_dict=True,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+
+ loss = None
+ logits = self.lm_head(hidden_states)
+ logits = logits / self.config.text_config.logits_scaling
+ if labels is not None:
+ loss = self.loss_function(
+ logits,
+ labels,
+ vocab_size=self.config.text_config.vocab_size,
+ **kwargs,
+ )
+
+ if isinstance(logits_to_keep, int) and logits_to_keep > 0:
+ logits = logits[:, -logits_to_keep:, :]
+
+ return Granite4VisionCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ deepstack_features=outputs.deepstack_features,
+ )
+
+
+__all__ = [
+ "Granite4VisionConfig",
+ "Granite4VisionTextConfig",
+ "Granite4VisionProcessor",
+ "Granite4VisionPreTrainedModel",
+ "Granite4VisionTextModel",
+ "Granite4VisionModel",
+ "Granite4VisionForConditionalGeneration",
+]
diff --git a/src/transformers/models/granite4_vision/processing_granite4_vision.py b/src/transformers/models/granite4_vision/processing_granite4_vision.py
new file mode 100644
index 000000000000..f8fa2c2b09ce
--- /dev/null
+++ b/src/transformers/models/granite4_vision/processing_granite4_vision.py
@@ -0,0 +1,237 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/granite4_vision/modular_granite4_vision.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_granite4_vision.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2026 IBM and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from fractions import Fraction
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils import select_best_resolution
+from ...image_utils import ImageInput, SizeDict, get_image_size, to_numpy_array
+from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import auto_docstring
+
+
+class Granite4VisionProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "do_pad": True,
+ },
+ }
+
+
+@auto_docstring
+class Granite4VisionProcessor(ProcessorMixin):
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ patch_size=None,
+ vision_feature_select_strategy=None,
+ chat_template=None,
+ image_token="",
+ num_additional_image_tokens=0,
+ downsample_rate=None,
+ **kwargs,
+ ):
+ r"""
+ patch_size (`int`, *optional*):
+ Patch size from the vision tower.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Should be same as in model's config.
+        image_token (`str`, *optional*, defaults to `"<image>"`):
+ Special token used to denote image location.
+ num_additional_image_tokens (`int`, *optional*, defaults to `0`):
+ Number of additional tokens added to the image embeddings, such as CLS (+1).
+ downsample_rate (`str`, *optional*):
+ Fractional downsample rate (e.g. `"1/4"`), used to adjust the number of image tokens
+ when computing token counts for padding/truncation.
+ """
+ self.patch_size = patch_size
+ self.num_additional_image_tokens = num_additional_image_tokens
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+ self.image_token_id = (
+ tokenizer.image_token_id
+ if getattr(tokenizer, "image_token_id", None)
+ else tokenizer.convert_tokens_to_ids(self.image_token)
+ )
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+ self.downsample_rate = downsample_rate
+
+ @auto_docstring
+ def __call__(
+ self,
+ images: ImageInput | None = None,
+        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
+ **kwargs: Unpack[Granite4VisionProcessorKwargs],
+ ) -> BatchFeature:
+ r"""
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ if images is None and text is None:
+ raise ValueError("You have to specify at least images or text.")
+
+ output_kwargs = self._merge_kwargs(
+ Granite4VisionProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ if images is not None:
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ else:
+ image_inputs = {}
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) and not isinstance(text[0], str):
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+ prompt_strings = text
+ if image_inputs:
+ image_sizes = iter(image_inputs["image_sizes"])
+ height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
+ prompt_strings = []
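+            # Each image token in the prompt is expanded to `num_image_tokens` copies so the text stream
+            # reserves exactly one position per projected visual feature.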
+ for sample in text:
+ while self.image_token in sample:
+ image_size = next(image_sizes)
+ if not isinstance(image_size, (list, tuple)):
+ # cast to list to avoid numerical precision errors when calculating unpadding
+ image_size = image_size.tolist()
+ orig_height, orig_width = image_size
+ num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
+ if self.vision_feature_select_strategy == "default":
+ num_image_tokens -= 1
+                    sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
+ prompt_strings.append(sample)
+ prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings]
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+ image_grid_pinpoints = self.image_processor.image_grid_pinpoints
+
+ height_best_resolution, width_best_resolution = select_best_resolution(
+ [orig_height, orig_width], image_grid_pinpoints
+ )
+ scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
+
+ patches_height = height // self.patch_size
+ patches_width = width // self.patch_size
+ if self.downsample_rate is not None:
+ ds_rate = Fraction(self.downsample_rate)
+ patches_height = int(patches_height * ds_rate)
+ patches_width = int(patches_width * ds_rate)
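+            # e.g. a 384x384 tile with 16px patches gives 24 patches per side; with an assumed
+            # downsample_rate of "2/4" this becomes 12x12 before the unpadded/newline bookkeeping.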
+
+ unpadded_features, newline_features = self._get_unpadded_features(
+ orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
+ )
+ base_features = patches_height * patches_width + self.num_additional_image_tokens
+ num_image_tokens = unpadded_features + newline_features + base_features
+ return num_image_tokens
+
+ def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
+ """
+        Get the number of features for a given image with height/width. LLaVA-NeXT differs from LLaVA
+        because it divides each image into patches depending on its resolution. Therefore we need to calculate how
+        many patches an image is divided into and get the number of features from that.
+ """
+ current_height = patches_height * scale_height
+ current_width = patches_width * scale_width
+
+ original_aspect_ratio = width / height
+ current_aspect_ratio = current_width / current_height
+ if original_aspect_ratio > current_aspect_ratio:
+ new_height = int(round(height * (current_width / width), 7))
+ padding = (current_height - new_height) // 2
+ current_height -= padding * 2
+ else:
+ new_width = int(round(width * (current_height / height), 7))
+ padding = (current_width - new_width) // 2
+ current_width -= padding * 2
+
+ unpadded_features = current_height * current_width
+ newline_features = current_height
+ return (unpadded_features, newline_features)
+
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
+ """
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
+ Args:
+            image_sizes (list[list[int]], *optional*):
+                The input sizes formatted as (height, width) for each image.
+ Returns:
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
+ input modalities, along with other useful data.
+ """
+ vision_data = {}
+ if image_sizes is not None:
+ images_kwargs = Granite4VisionProcessorKwargs._defaults.get("images_kwargs", {})
+ images_kwargs.update(kwargs)
+
+ size = images_kwargs.get("size", None) or self.image_processor.size
+ if isinstance(size, SizeDict):
+ size = (
+ (size.shortest_edge, size.shortest_edge)
+ if size.shortest_edge is not None
+ else (min(size.height, size.width), min(size.height, size.width))
+ )
+ else:
+ size = (
+ (size["shortest_edge"], size["shortest_edge"])
+ if "shortest_edge" in size
+ else (min(size["height"], size["width"]), min(size["height"], size["width"]))
+ )
+ processed_height, processed_width = size
+
+ batch_num_image_tokens = []
+            num_image_patches = [1] * len(image_sizes)  # llava-next doesn't batch pixels as Idefics, thus `1` patch
+ for image_size in image_sizes:
+ orig_height, orig_width = image_size
+ num_image_tokens = self._get_number_of_features(
+ orig_height, orig_width, processed_height, processed_width
+ )
+ if self.vision_feature_select_strategy == "default":
+ num_image_tokens -= 1
+ batch_num_image_tokens.append(num_image_tokens)
+ vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
+
+ return MultiModalData(**vision_data)
+
+
+__all__ = ["Granite4VisionProcessor"]
diff --git a/tests/models/granite4_vision/__init__.py b/tests/models/granite4_vision/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/granite4_vision/test_modeling_granite4_vision.py b/tests/models/granite4_vision/test_modeling_granite4_vision.py
new file mode 100644
index 000000000000..10dafd8b3abe
--- /dev/null
+++ b/tests/models/granite4_vision/test_modeling_granite4_vision.py
@@ -0,0 +1,241 @@
+# Copyright 2026 IBM and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Granite4Vision model."""
+
+import unittest
+
+import pytest
+import requests
+
+from transformers import (
+ AutoProcessor,
+ CLIPVisionConfig,
+ Granite4VisionConfig,
+ Granite4VisionForConditionalGeneration,
+ Granite4VisionModel,
+ GraniteConfig,
+ is_torch_available,
+ is_vision_available,
+)
+from transformers.testing_utils import (
+ cleanup,
+ require_torch,
+ slow,
+ torch_device,
+)
+
+from ...test_modeling_common import floats_tensor
+from ...vlm_tester import VLMModelTest, VLMModelTester
+
+
+if is_torch_available():
+ import torch
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+class Granite4VisionModelTester(VLMModelTester):
+ base_model_class = Granite4VisionModel
+ config_class = Granite4VisionConfig
+ conditional_generation_class = Granite4VisionForConditionalGeneration
+ text_config_class = GraniteConfig
+ vision_config_class = CLIPVisionConfig
+
+ def __init__(self, parent, **kwargs):
+ # Vision hidden_size must be divisible by 64 (QFormer num_attention_heads = hidden_size // 64)
+ kwargs.setdefault("hidden_size", 64)
+ kwargs.setdefault("intermediate_size", 64)
+ kwargs.setdefault("num_attention_heads", 2)
+ kwargs.setdefault("num_key_value_heads", 2)
+ kwargs.setdefault("num_hidden_layers", 2)
+ # Image/patch sizes: image_side = image_size // patch_size must be divisible by window_side
+ kwargs.setdefault("image_size", 8)
+ kwargs.setdefault("patch_size", 2)
+ kwargs.setdefault("projection_dim", 64)
+ kwargs.setdefault("num_patches_per_image", 2)
+ # Granite4Vision-specific
+ kwargs.setdefault("downsample_rate", "1/2")
+ kwargs.setdefault("deepstack_layer_map", [[1, 0]])
+ kwargs.setdefault("use_spatial_sampling", False)
+ kwargs.setdefault("projector_dropout", 0.0)
+ kwargs.setdefault("image_token_index", kwargs.get("image_token_id", 3))
+
+ # Compute num_image_tokens after downsampling:
+ # image_side = image_size/patch_size = 4, ds 1/2 -> patches_h = patches_w = 2
+ # pinpoints [[8,8]] -> scale 1x1 -> current_h = current_w = 2
+ # unpadded = 2*2 = 4, newline = 2, base = 2*2 = 4 -> total = 10
+ kwargs.setdefault("num_image_tokens", 10)
+
+ super().__init__(parent, **kwargs)
+
+ def create_pixel_values(self):
+ """Granite4Vision expects 5D pixel_values: (batch_size, num_patches, channels, height, width)"""
+ return floats_tensor(
+ [
+ self.batch_size,
+ self.num_patches_per_image,
+ self.num_channels,
+ self.image_size,
+ self.image_size,
+ ]
+ )
+
+ def get_additional_inputs(self, config, input_ids, pixel_values):
+        """Granite4Vision additionally requires an `image_sizes` tensor holding each image's original (height, width)."""
+ return {
+ "image_sizes": torch.tensor([[self.image_size, self.image_size]] * self.batch_size),
+ }
+
+ def get_config(self):
+ config = super().get_config()
+ config.image_grid_pinpoints = [[self.image_size, self.image_size]]
+ config.downsample_rate = self.downsample_rate
+ config.deepstack_layer_map = self.deepstack_layer_map
+ config.use_spatial_sampling = self.use_spatial_sampling
+ config.projector_dropout = self.projector_dropout
+ return config
+
+
+@require_torch
+class Granite4VisionModelTest(VLMModelTest, unittest.TestCase):
+ """
+ Model tester for `Granite4VisionForConditionalGeneration`.
+ """
+
+ model_tester_class = Granite4VisionModelTester
+ skip_test_image_features_output_shape = True
+ test_torch_exportable = False
+ # Custom layer-by-layer forward doesn't support output_attentions
+ # (GraniteDecoderLayer discards attention weights internally)
+ test_attention_outputs = False
+ has_attentions = False
+
+ @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
+ def test_training_gradient_checkpointing(self):
+ super().test_training_gradient_checkpointing()
+
+ @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
+ def test_training_gradient_checkpointing_use_reentrant_false(self):
+ super().test_training_gradient_checkpointing_use_reentrant_false()
+
+ @pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
+ def test_training_gradient_checkpointing_use_reentrant_true(self):
+ super().test_training_gradient_checkpointing_use_reentrant_true()
+
+ @unittest.skip(
+ "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
+ )
+ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
+ pass
+
+ @unittest.skip(
+ "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
+ )
+ def test_eager_padding_matches_padding_free_with_position_ids(self):
+ pass
+
+ @unittest.skip("Custom layer-by-layer forward has graph breaks incompatible with fullgraph compile")
+ def test_generate_compile_model_forward_fullgraph(self):
+ pass
+
+ @unittest.skip("Blip2QFormerModel in WindowQFormerDownsampler does not support SDPA dispatch")
+ def test_can_set_attention_dynamically_composite_model(self):
+ pass
+
+
+@require_torch
+class Granite4VisionIntegrationTest(unittest.TestCase):
+ model_id = "ibm-granite/granite-vision-4.1-4b"
+
+ def setUp(self):
+ self.processor = AutoProcessor.from_pretrained(self.model_id)
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ self.image = Image.open(requests.get(url, stream=True).raw)
+
+ def make_prompt(self, question):
+ messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}]
+ return self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ @slow
+ def test_small_model_integration_test(self):
+ model = Granite4VisionForConditionalGeneration.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(
+ torch_device
+ )
+
+ prompt = self.make_prompt("Describe this image briefly.")
+ inputs = self.processor(text=prompt, images=self.image, return_tensors="pt").to(model.device)
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ new_tokens = output[:, inputs["input_ids"].shape[1] :]
+
+ EXPECTED_RESPONSE = "The image depicts two cats resting on a pink couch. They are lying in a relaxed, sprawled position, with one cat appearing to be in a" # fmt: skip
+ self.assertEqual(self.processor.decode(new_tokens[0], skip_special_tokens=True), EXPECTED_RESPONSE)
+
+ @slow
+ def test_small_model_integration_test_batch(self):
+ model = Granite4VisionForConditionalGeneration.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(
+ torch_device
+ )
+
+ url2 = "http://images.cocodataset.org/val2017/000000001000.jpg"
+ image2 = Image.open(requests.get(url2, stream=True).raw)
+
+ prompt = self.make_prompt("What do you see in this image?")
+ inputs = self.processor(
+ text=[prompt, prompt],
+ images=[self.image, image2],
+ return_tensors="pt",
+ padding=True,
+ ).to(model.device)
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ new_tokens = output[:, inputs["input_ids"].shape[1] :]
+ responses = self.processor.batch_decode(new_tokens, skip_special_tokens=True)
+
+ self.assertIn("cat", responses[0].lower())
+ self.assertIn("tennis", responses[1].lower())
+
+ @slow
+ def test_small_model_integration_test_batch_matches_single(self):
+ model = Granite4VisionForConditionalGeneration.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(
+ torch_device
+ )
+
+ prompt = self.make_prompt("What do you see in this image?")
+
+ # Single inference
+ inputs_single = self.processor(text=prompt, images=self.image, return_tensors="pt").to(model.device)
+ output_single = model.generate(**inputs_single, max_new_tokens=30, do_sample=False)
+ decoded_single = self.processor.decode(
+ output_single[0, inputs_single["input_ids"].shape[1] :], skip_special_tokens=True
+ )
+
+ # Batch inference (same image as first in batch)
+ url2 = "http://images.cocodataset.org/val2017/000000001000.jpg"
+ image2 = Image.open(requests.get(url2, stream=True).raw)
+ inputs_batch = self.processor(
+ text=[prompt, prompt],
+ images=[self.image, image2],
+ return_tensors="pt",
+ padding=True,
+ ).to(model.device)
+ output_batch = model.generate(**inputs_batch, max_new_tokens=30, do_sample=False)
+ decoded_batch = self.processor.decode(
+ output_batch[0, inputs_batch["input_ids"].shape[1] :], skip_special_tokens=True
+ )
+
+ self.assertEqual(decoded_single, decoded_batch)
diff --git a/tests/models/granite4_vision/test_processing_granite4_vision.py b/tests/models/granite4_vision/test_processing_granite4_vision.py
new file mode 100644
index 000000000000..8a56aa69b020
--- /dev/null
+++ b/tests/models/granite4_vision/test_processing_granite4_vision.py
@@ -0,0 +1,122 @@
+# Copyright 2025 IBM. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import unittest
+
+import torch
+
+from transformers import Granite4VisionProcessor
+from transformers.testing_utils import (
+ require_vision,
+)
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+@require_vision
+class Granite4VisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+ processor_class = Granite4VisionProcessor
+    # Image token expansion with downsample_rate="1/2" yields more tokens than the tester mixin's default max_length values
+ image_text_kwargs_max_length = 300
+ image_text_kwargs_override_max_length = 280
+ image_unstructured_max_length = 260
+
+ @classmethod
+ def _setup_tokenizer(cls):
+ tokenizer_class = cls._get_component_class_from_processor("tokenizer")
+ tokenizer = tokenizer_class.from_pretrained("huggyllama/llama-7b")
+        tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
+ if not tokenizer.pad_token:
+ tokenizer.pad_token = "[PAD]"
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token_id = 0
+ return tokenizer
+
+ @classmethod
+ def _setup_test_attributes(cls, processor):
+ cls.image_token = processor.image_token
+
+ @staticmethod
+ def prepare_processor_dict():
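+        # Minimal chat template used only for testing (assumed to mirror the released processor's
+        # behavior): it emits one "<image>" placeholder per image entry and wraps assistant text in
+        # {% generation %} tags so assistant-token masks can be derived.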
+ return {
+            "chat_template": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}",
+ "patch_size": 14,
+ "vision_feature_select_strategy": "default",
+ "downsample_rate": "1/2",
+ } # fmt: skip
+
+ def test_get_num_vision_tokens(self):
+ """Tests general functionality of the helper used internally in vLLM"""
+ processor = self.get_processor()
+
+ output = processor._get_num_multimodal_tokens(image_sizes=[(100, 100), (300, 100), (500, 30)])
+ self.assertTrue("num_image_tokens" in output)
+ self.assertEqual(len(output["num_image_tokens"]), 3)
+
+ self.assertTrue("num_image_patches" in output)
+ self.assertEqual(len(output["num_image_patches"]), 3)
+
+ def test_chat_template_is_saved(self):
+ processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+ processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+ # chat templates aren't serialized to json in processors
+ self.assertFalse("chat_template" in processor_dict_loaded)
+
+ # they have to be saved as separate file and loaded back from that file
+ # so we check if the same template is loaded
+ processor_dict = self.prepare_processor_dict()
+ self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
+ def test_image_token_filling(self):
+ processor = self.processor_class.from_pretrained(self.tmpdirname)
+ processor.patch_size = 14
+ processor.vision_feature_select_strategy = "default"
+ processor.downsample_rate = "1/2"
+ processor.image_processor.crop_size = {"height": 336, "width": 336}
+ processor.image_processor.size = {"shortest_edge": 336}
+ processor.image_processor.image_grid_pinpoints = [[672, 336]]
+ # Important to check with non square image
+ image = torch.randint(0, 2, (3, 503, 316))
+ image_token_index = processor.image_token_id
+
+ # With downsample_rate="1/2" and patch_size=14:
+ # patches = 336/14 = 24, after ds: 24*1/2 = 12
+ # best resolution for (503, 316): [672, 336]
+ # scale_height=2, scale_width=1
+ # current = 12*2=24 h, 12*1=12 w
+ # aspect: 316/503 = 0.628, 12/24 = 0.5 -> orig > current -> new_height = round(503*(12/316)) = 19
+ # padding = (24-19)//2 = 2, current_height = 24 - 4 = 20
+ # unpadded = 20*12 = 240, newline = 20
+ # base = 12*12 + 0 = 144
+ # total = 240 + 20 + 144 = 404
+ # with "default" strategy: 404 - 1 = 403
+ expected_image_tokens = 403
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+ inputs = processor(
+ text=[processor.apply_chat_template(messages)],
+ images=[image],
+ return_tensors="pt",
+ )
+ image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
+ self.assertEqual(expected_image_tokens, image_tokens)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 0a89b4aa63a9..af97b084a7f6 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -283,6 +283,7 @@
"Gemma4VisionModel", # Building part of a bigger model, tested implicitly
"Gemma4AudioModel", # Building part of a bigger model, tested implicitly
"Sam3LiteTextTextModel", # Building part of a bigger model, tested implicitly through Sam3LiteTextModel
+    "Granite4VisionTextModel",  # Building part of a bigger model, tested implicitly through Granite4VisionModel
]
)
@@ -509,6 +510,7 @@
"Ernie4_5_VL_MoeModel", # BC Alias
"Ernie4_5_VL_MoeTextModel", # BC Alias
"UVDocBridge", # Building part of a bigger model, tested implicitly through UVDocModel
+    "Granite4VisionTextModel",  # Building part of a bigger model, tested implicitly through Granite4VisionModel
]