Commit 293ff45

andimarafioti authored and amyeroberts committed
Add Idefics 3! (huggingface#32473)
* Add Idefics 3!
* fixes to make both pipelines identical
* fix for quantized models
* First pass at the review
* remove vocab size from the main config (it's still in the text_config)
* hot fix for merve
* Apply suggestions from code review Co-authored-by: amyeroberts <[email protected]>
* re-add model_type for text_config
* remove support for old_cache
* remove hidden_size from main config
* rename idefics3 HF repo
* few changes suggested in the PR
* fix to input_data_format computation
* remove overwrite of _autoset_attn_implementation following @zucchini-nlp suggestion
* improve example
* few improvements from amy's review
* big change to enable processing input images as numpy arrays
* Changes to the code to uniformize processor kwargs
* image processing tests
* image processing tests fixes and some bugs they discovered
* addressed review comments from Yoni
* fix modeling tests
* remove special tokens that are not special
* fixes tests
* skip failing tests - they also fail for idefics2
* added paper and readded the tests with multi gpu, who knows
* Update docs/source/en/model_doc/idefics3.md Co-authored-by: amyeroberts <[email protected]>
* Apply suggestions from code review Co-authored-by: amyeroberts <[email protected]>
* review amy until image_processing_idefics3
* last comments from Amy
* review amy
* Update src/transformers/models/idefics3/image_processing_idefics3.py Co-authored-by: amyeroberts <[email protected]>
* Update src/transformers/models/idefics3/modeling_idefics3.py Co-authored-by: amyeroberts <[email protected]>
* Update docs/source/en/model_doc/idefics3.md Co-authored-by: amyeroberts <[email protected]>
* doc improvement - amy review
* fix runtime error during fine-tuning
* amy's review
* Update src/transformers/models/idefics3/image_processing_idefics3.py Co-authored-by: amyeroberts <[email protected]>
* Update src/transformers/models/idefics3/image_processing_idefics3.py Co-authored-by: amyeroberts <[email protected]>
* Update src/transformers/models/idefics3/modeling_idefics3.py Co-authored-by: amyeroberts <[email protected]>
* ruff
* amy's comment on the order
* ruff ruff
* fix copies
* square images when they are not splitted
* ruff :(
* Update src/transformers/models/idefics3/image_processing_idefics3.py Co-authored-by: amyeroberts <[email protected]>
* Update tests/models/idefics3/test_processing_idefics3.py Co-authored-by: amyeroberts <[email protected]>
* fix small bug introduced in refactor
* amy's image processing changes
* fixes peft tests and ruff
* modify to_pil_image from transformers. and review from emanuele.
* add modified to_pil_image

---------

Co-authored-by: amyeroberts <[email protected]>
1 parent 2407a62 commit 293ff45

27 files changed: +4482 -2 lines changed

docs/source/en/_toctree.yml

+2
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,8 @@
         title: IDEFICS
       - local: model_doc/idefics2
         title: Idefics2
+      - local: model_doc/idefics3
+        title: Idefics3
       - local: model_doc/instructblip
         title: InstructBLIP
       - local: model_doc/instructblipvideo

docs/source/en/index.md

+1
@@ -170,6 +170,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [I-BERT](model_doc/ibert) ||||
 | [IDEFICS](model_doc/idefics) ||||
 | [Idefics2](model_doc/idefics2) ||||
+| [Idefics3](model_doc/idefics3) ||||
 | [ImageGPT](model_doc/imagegpt) ||||
 | [Informer](model_doc/informer) ||||
 | [InstructBLIP](model_doc/instructblip) ||||

docs/source/en/model_doc/idefics3.md

+73
@@ -0,0 +1,73 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# Idefics3
+
+## Overview
+
+The Idefics3 model was proposed in [Building and better understanding vision-language models: insights and future directions](https://huggingface.co/papers/2408.12637) by Hugo Laurençon, Andrés Marafioti, Victor Sanh, and Léo Tronchon.
+
+Idefics3 is an adaptation of the Idefics2 model with three main differences:
+
+- It uses Llama3 for the text model.
+- It uses an updated processing logic for the images.
+- It removes the perceiver.
+
+The abstract from the paper is the following:
+
+*The field of vision-language models (VLMs), which take images and texts as inputs and output texts, is rapidly evolving and has yet to reach consensus on several key aspects of the development pipeline, including data, architecture, and training methods. This paper can be seen as a tutorial for building a VLM. We begin by providing a comprehensive overview of the current state-of-the-art approaches, highlighting the strengths and weaknesses of each, addressing the major challenges in the field, and suggesting promising research directions for underexplored areas. We then walk through the practical steps to build Idefics3-8B, a powerful VLM that significantly outperforms its predecessor Idefics2-8B, while being trained efficiently, exclusively on open datasets, and using a straightforward pipeline. These steps include the creation of Docmatix, a dataset for improving document understanding capabilities, which is 240 times larger than previously available datasets. We release the model along with the datasets created for its training.*
+
+## Usage tips
+
+Input images are processed either by upsampling (if resizing is enabled) or at their original resolution. The resizing behavior depends on two parameters: `do_resize` and `size`.
+
+If `do_resize` is set to `True`, the model resizes images so that the longest edge is 4*364 pixels by default.
+The default resizing behavior can be customized by passing a dictionary to the `size` parameter. For example, `{"longest_edge": 4 * 364}` is the default, but you can change it to a different value if needed.
+
+Here’s how to control resizing and set a custom size:
+```python
+image_processor = Idefics3ImageProcessor(do_resize=True, size={"longest_edge": 2 * 364}, max_image_size=364)
+```
+
+Additionally, the `max_image_size` parameter, which controls the size of each square patch the image is decomposed into, is set to 364 by default but can be adjusted as needed. After resizing (if applicable), the image processor decomposes the images into square patches based on the `max_image_size` parameter.
+
+This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [andimarafioti](https://huggingface.co/andito).
+
+
+## Idefics3Config
+
+[[autodoc]] Idefics3Config
+
+
+## Idefics3Model
+
+[[autodoc]] Idefics3Model
+    - forward
+
+## Idefics3ForConditionalGeneration
+
+[[autodoc]] Idefics3ForConditionalGeneration
+    - forward
+
+
+## Idefics3ImageProcessor
+[[autodoc]] Idefics3ImageProcessor
+    - preprocess
+
+
+## Idefics3Processor
+[[autodoc]] Idefics3Processor
+    - __call__
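
The resizing and patch decomposition described under "Usage tips" in this file can be illustrated with plain arithmetic. The sketch below is not the `Idefics3ImageProcessor` implementation: `resized_shape` and `patch_grid` are hypothetical helpers, and the rounding choices (rounding the shorter edge, ceiling division for the grid) are assumptions made only to show how `size["longest_edge"]` and `max_image_size` interact.

```python
import math


def resized_shape(height, width, longest_edge=4 * 364):
    """Scale (height, width) so the longest edge equals `longest_edge`.

    Assumption: aspect ratio is preserved and the other edge is rounded.
    """
    longest = max(height, width)
    new_height = max(1, round(height * longest_edge / longest))
    new_width = max(1, round(width * longest_edge / longest))
    return new_height, new_width


def patch_grid(height, width, max_image_size=364):
    """Rows and columns of square patches needed to cover the image."""
    return math.ceil(height / max_image_size), math.ceil(width / max_image_size)


# Default setting: longest edge of 4 * 364 pixels.
default_shape = resized_shape(600, 1200)
default_grid = patch_grid(*default_shape)

# Custom setting from the usage tip: longest edge of 2 * 364 pixels.
custom_shape = resized_shape(600, 1200, longest_edge=2 * 364)
custom_grid = patch_grid(*custom_shape)

print(default_shape, default_grid)
print(custom_shape, custom_grid)
```

Under these assumptions, halving `longest_edge` roughly quarters the number of patches, which is the practical reason to lower it when memory is tight.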

docs/source/en/perf_infer_gpu_one.md

+1
@@ -54,6 +54,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
 * [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
 * [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
+* [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
 * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)

src/transformers/__init__.py

+18
@@ -482,6 +482,7 @@
     "models.ibert": ["IBertConfig"],
     "models.idefics": ["IdeficsConfig"],
     "models.idefics2": ["Idefics2Config"],
+    "models.idefics3": ["Idefics3Config"],
     "models.imagegpt": ["ImageGPTConfig"],
     "models.informer": ["InformerConfig"],
     "models.instructblip": [
@@ -1192,6 +1193,7 @@
     _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
     _import_structure["models.idefics"].extend(["IdeficsImageProcessor"])
     _import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"])
+    _import_structure["models.idefics3"].extend(["Idefics3ImageProcessor"])
     _import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"])
     _import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"])
     _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
@@ -2429,6 +2431,14 @@
             "Idefics2Processor",
         ]
     )
+    _import_structure["models.idefics3"].extend(
+        [
+            "Idefics3ForConditionalGeneration",
+            "Idefics3Model",
+            "Idefics3PreTrainedModel",
+            "Idefics3Processor",
+        ]
+    )
     _import_structure["models.imagegpt"].extend(
         [
             "ImageGPTForCausalImageModeling",
@@ -5299,6 +5309,7 @@
         IdeficsConfig,
     )
     from .models.idefics2 import Idefics2Config
+    from .models.idefics3 import Idefics3Config
     from .models.imagegpt import ImageGPTConfig
     from .models.informer import InformerConfig
     from .models.instructblip import (
@@ -6047,6 +6058,7 @@
         from .models.grounding_dino import GroundingDinoImageProcessor
         from .models.idefics import IdeficsImageProcessor
         from .models.idefics2 import Idefics2ImageProcessor
+        from .models.idefics3 import Idefics3ImageProcessor
         from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor
         from .models.instructblipvideo import InstructBlipVideoImageProcessor
         from .models.layoutlmv2 import (
@@ -7087,6 +7099,12 @@
             Idefics2PreTrainedModel,
             Idefics2Processor,
         )
+        from .models.idefics3 import (
+            Idefics3ForConditionalGeneration,
+            Idefics3Model,
+            Idefics3PreTrainedModel,
+            Idefics3Processor,
+        )
         from .models.imagegpt import (
             ImageGPTForCausalImageModeling,
             ImageGPTForImageClassification,

src/transformers/image_transforms.py

+4-1
@@ -162,6 +162,7 @@ def _rescale_for_pil_conversion(image):
 def to_pil_image(
     image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
     do_rescale: Optional[bool] = None,
+    image_mode: Optional[str] = None,
     input_data_format: Optional[Union[str, ChannelDimension]] = None,
 ) -> "PIL.Image.Image":
     """
@@ -175,6 +176,8 @@ def to_pil_image(
             Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
             to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
             and `False` otherwise.
+        image_mode (`str`, *optional*):
+            The mode to use for the PIL image. If unset, will use the default mode for the input image type.
         input_data_format (`ChannelDimension`, *optional*):
             The channel dimension format of the input image. If unset, will use the inferred format from the input.

@@ -207,7 +210,7 @@ def to_pil_image(
         image = rescale(image, 255)

     image = image.astype(np.uint8)
-    return PIL.Image.fromarray(image)
+    return PIL.Image.fromarray(image, mode=image_mode)


 # Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366

src/transformers/models/__init__.py

+1
@@ -116,6 +116,7 @@
     ibert,
     idefics,
     idefics2,
+    idefics3,
     imagegpt,
     informer,
     instructblip,

src/transformers/models/auto/configuration_auto.py

+2
@@ -134,6 +134,7 @@
         ("ibert", "IBertConfig"),
         ("idefics", "IdeficsConfig"),
         ("idefics2", "Idefics2Config"),
+        ("idefics3", "Idefics3Config"),
         ("imagegpt", "ImageGPTConfig"),
         ("informer", "InformerConfig"),
         ("instructblip", "InstructBlipConfig"),
@@ -434,6 +435,7 @@
         ("ibert", "I-BERT"),
         ("idefics", "IDEFICS"),
         ("idefics2", "Idefics2"),
+        ("idefics3", "Idefics3"),
         ("imagegpt", "ImageGPT"),
         ("informer", "Informer"),
         ("instructblip", "InstructBLIP"),
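
The two entries added in this file register the `idefics3` model type under a config class name and a display name. As a rough illustration of how such name tables can be consulted, here is a standalone sketch. The dictionary names and the `describe` helper are hypothetical and only mirror the tuples shown in the diff; they are not the library's internal lookup code.

```python
from collections import OrderedDict

# Entries copied from the diff; the variable names are assumptions.
config_names = OrderedDict(
    [
        ("idefics", "IdeficsConfig"),
        ("idefics2", "Idefics2Config"),
        ("idefics3", "Idefics3Config"),
    ]
)
display_names = OrderedDict(
    [
        ("idefics", "IDEFICS"),
        ("idefics2", "Idefics2"),
        ("idefics3", "Idefics3"),
    ]
)


def describe(model_type):
    """Return (config class name, display name) for a registered model type."""
    if model_type not in config_names:
        raise KeyError(f"Unknown model type: {model_type!r}")
    return config_names[model_type], display_names[model_type]


print(describe("idefics3"))
```

Keeping class names as strings rather than imported classes means a table like this can be built without importing any model code.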

src/transformers/models/auto/image_processing_auto.py

+1
@@ -89,6 +89,7 @@
             ("hiera", ("BitImageProcessor",)),
             ("idefics", ("IdeficsImageProcessor",)),
             ("idefics2", ("Idefics2ImageProcessor",)),
+            ("idefics3", ("Idefics3ImageProcessor",)),
             ("imagegpt", ("ImageGPTImageProcessor",)),
             ("instructblip", ("BlipImageProcessor",)),
             ("instructblipvideo", ("InstructBlipVideoImageProcessor",)),

src/transformers/models/auto/modeling_auto.py

+3
@@ -131,6 +131,7 @@
         ("ibert", "IBertModel"),
         ("idefics", "IdeficsModel"),
         ("idefics2", "Idefics2Model"),
+        ("idefics3", "Idefics3Model"),
         ("imagegpt", "ImageGPTModel"),
         ("informer", "InformerModel"),
         ("jamba", "JambaModel"),
@@ -316,6 +317,7 @@
         ("ibert", "IBertForMaskedLM"),
         ("idefics", "IdeficsForVisionText2Text"),
         ("idefics2", "Idefics2ForConditionalGeneration"),
+        ("idefics3", "Idefics3ForConditionalGeneration"),
         ("layoutlm", "LayoutLMForMaskedLM"),
         ("llava", "LlavaForConditionalGeneration"),
         ("llava_next", "LlavaNextForConditionalGeneration"),
@@ -736,6 +738,7 @@
         ("chameleon", "ChameleonForConditionalGeneration"),
         ("git", "GitForCausalLM"),
         ("idefics2", "Idefics2ForConditionalGeneration"),
+        ("idefics3", "Idefics3ForConditionalGeneration"),
         ("instructblip", "InstructBlipForConditionalGeneration"),
         ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
         ("kosmos-2", "Kosmos2ForConditionalGeneration"),

src/transformers/models/auto/processing_auto.py

+1
@@ -65,6 +65,7 @@
         ("hubert", "Wav2Vec2Processor"),
         ("idefics", "IdeficsProcessor"),
         ("idefics2", "Idefics2Processor"),
+        ("idefics3", "Idefics3Processor"),
         ("instructblip", "InstructBlipProcessor"),
         ("instructblipvideo", "InstructBlipVideoProcessor"),
         ("kosmos-2", "Kosmos2Processor"),

src/transformers/models/auto/tokenization_auto.py

+1
@@ -219,6 +219,7 @@
             ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
             ("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
             ("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+            ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
             ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
             ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
             (

src/transformers/models/idefics2/modeling_idefics2.py

+1-1
@@ -1098,7 +1098,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         std = (
-            self.config.text_config.initializer_range
+            self.config.initializer_range
             if hasattr(self.config, "initializer_range")
             else self.config.text_config.initializer_range
         )
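
Before this change both branches of the conditional read `text_config.initializer_range`, so a top-level `initializer_range` was never used. The sketch below reproduces the corrected selection logic on stand-in config objects. `SimpleNamespace` and `resolve_std` are illustrative assumptions, not part of the library.

```python
from types import SimpleNamespace


def resolve_std(config):
    """Prefer a top-level initializer_range, else fall back to the text config's."""
    return (
        config.initializer_range
        if hasattr(config, "initializer_range")
        else config.text_config.initializer_range
    )


# Stand-in configs with made-up values, used only to exercise both branches.
text_config = SimpleNamespace(initializer_range=0.02)
with_top_level = SimpleNamespace(initializer_range=0.01, text_config=text_config)
without_top_level = SimpleNamespace(text_config=text_config)

print(resolve_std(with_top_level))     # the top-level value is used
print(resolve_std(without_top_level))  # falls back to text_config
```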
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {"configuration_idefics3": ["Idefics3Config"]}
+
+
+try:
+    if not is_vision_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["image_processing_idefics3"] = ["Idefics3ImageProcessor"]
+
+
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_idefics3"] = [
+        "Idefics3ForConditionalGeneration",
+        "Idefics3PreTrainedModel",
+        "Idefics3Model",
+    ]
+    _import_structure["processing_idefics3"] = ["Idefics3Processor"]
+
+if TYPE_CHECKING:
+    from .configuration_idefics3 import Idefics3Config
+
+    try:
+        if not is_vision_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .image_processing_idefics3 import Idefics3ImageProcessor
+
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_idefics3 import (
+            Idefics3ForConditionalGeneration,
+            Idefics3Model,
+            Idefics3PreTrainedModel,
+        )
+        from .processing_idefics3 import Idefics3Processor
+
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
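
The module above builds its export table conditionally, so that names depending on optional packages are only advertised when those packages are installed. The following standard-library sketch mirrors that gating. `is_available` and `build_import_structure` are hypothetical stand-ins for the library's availability checks, and the sketch deliberately leaves out the lazy-loading step performed by `_LazyModule`.

```python
import importlib.util


def is_available(package_name):
    """Hypothetical stand-in for checks such as is_torch_available()."""
    return importlib.util.find_spec(package_name) is not None


def build_import_structure(vision_ok, torch_ok):
    """Map submodule names to exported names, gated on optional dependencies."""
    structure = {"configuration_idefics3": ["Idefics3Config"]}
    if vision_ok:
        structure["image_processing_idefics3"] = ["Idefics3ImageProcessor"]
    if torch_ok:
        structure["modeling_idefics3"] = [
            "Idefics3ForConditionalGeneration",
            "Idefics3PreTrainedModel",
            "Idefics3Model",
        ]
        # As in the diff, the processor is registered in the torch branch.
        structure["processing_idefics3"] = ["Idefics3Processor"]
    return structure


structure = build_import_structure(is_available("PIL"), is_available("torch"))
print(sorted(structure))
```

The configuration entry is unconditional, so the config class stays importable even in an environment with neither vision nor torch support.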
