Commit 203e270

Add image text to text pipeline (#34170)
* Standardize image-text-to-text model outputs: add `post_process_image_text_to_text` to chameleon and clean up; fix legacy kwarg behavior and deprecation warning; add `post_process_image_text_to_text` to qwen2_vl and llava_onevision; add `post_process_image_text_to_text` to the idefics3, mllama, and pixtral processors
* nit: var name `post_process_image_text_to_text` in udop
* nit: fix deprecation warnings
* Add image-text-to-text pipeline
* Add support for image URLs in chat templates for the pipeline
* Reformat to be fully compatible with chat templates
* Add chat template tests
* Fix imports and tests
* Add pipeline tag
* Change logic for handling a single prompt with multiple images
* Add pipeline mapping to models
* Fix batched inference
* Fix tests
* Add manual batching for preprocessing
* Fix outputs with nested images
* Add support for all common processing kwargs
* Add default padding when there are multiple text inputs (batch size > 1)
* nit: change version in deprecation warning
* Add support for text-only inference
* Add chat_template warnings
* Add pipeline tests and add "copied from" post-process function
* Fix batched pipeline tests
* nit
* Fix blip2 pipeline tests
* Remove unnecessary max_new_tokens
* Revert kosmos2 processing and remove unnecessary max_new_tokens
* Fix idefics pipeline tests
* Force trying to load the processor if the pipeline supports it
* Revert load_processor change
* Hardcode loading only the processor
* Remove unnecessary try/except
* Skip image-text-to-text tests for kosmos2, as the tiny model causes problems
* Make code clearer
* Address review comments
* Remove preprocessing logic from pipeline
* Fix fuyu
* Add BC resize for fuyu
* Move post_process_image_text_to_text to ProcessorMixin
* Add guard in post_process
* Fix zero-shot object detection pipeline
* Add support for generator input in pipeline
* nit
* Change default image-text-to-text model to llava onevision
* Fix owlv2 size dict
* Change legacy deprecation warning to only show when True
1 parent c443d8d commit 203e270

47 files changed: +988 −33 lines

docs/source/en/main_classes/pipelines.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -478,6 +478,12 @@ Pipelines available for multimodal tasks include the following.
     - __call__
     - all
 
+### ImageTextToTextPipeline
+
+[[autodoc]] ImageTextToTextPipeline
+    - __call__
+    - all
+
 ### MaskGenerationPipeline
 
 [[autodoc]] MaskGenerationPipeline
```
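
For context while reviewing the docs entries above, a minimal usage sketch of the new pipeline. The checkpoint name is illustrative (the commit only says the default was changed to a LLaVA-OneVision model), and the chat-style message schema follows the chat-template support added in this PR:

```python
from transformers import pipeline

# Illustrative checkpoint; any image-text-to-text model with a chat template should work.
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.jpg"},  # placeholder URL
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
outputs = pipe(text=messages, max_new_tokens=32)
print(outputs[0]["generated_text"])
```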

docs/source/ja/main_classes/pipelines.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -481,6 +481,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
     - __call__
     - all
 
+### ImageTextToTextPipeline
+
+[[autodoc]] ImageTextToTextPipeline
+    - __call__
+    - all
+
 ### VisualQuestionAnsweringPipeline
 
 [[autodoc]] VisualQuestionAnsweringPipeline
```

docs/source/zh/main_classes/pipelines.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -455,6 +455,12 @@ See [`TokenClassificationPipeline`] for all details.
     - __call__
     - all
 
+### ImageTextToTextPipeline
+
+[[autodoc]] ImageTextToTextPipeline
+    - __call__
+    - all
+
 ### MaskGenerationPipeline
 
 [[autodoc]] MaskGenerationPipeline
```

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -868,6 +868,7 @@
         "ImageClassificationPipeline",
         "ImageFeatureExtractionPipeline",
         "ImageSegmentationPipeline",
+        "ImageTextToTextPipeline",
         "ImageToImagePipeline",
         "ImageToTextPipeline",
         "JsonPipelineDataFormat",
@@ -5794,6 +5795,7 @@
         ImageClassificationPipeline,
         ImageFeatureExtractionPipeline,
         ImageSegmentationPipeline,
+        ImageTextToTextPipeline,
         ImageToImagePipeline,
         ImageToTextPipeline,
         JsonPipelineDataFormat,
```

src/transformers/image_utils.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -385,6 +385,27 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
     return image
 
 
+def load_images(
+    images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
+) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
+    """Loads images, handling different levels of nesting.
+
+    Args:
+        images: A single image, a list of images, or a list of lists of images to load.
+        timeout: Timeout for loading images.
+
+    Returns:
+        A single image, a list of images, or a list of lists of images.
+    """
+    if isinstance(images, (list, tuple)):
+        if len(images) and isinstance(images[0], (list, tuple)):
+            return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
+        else:
+            return [load_image(image, timeout=timeout) for image in images]
+    else:
+        return load_image(images, timeout=timeout)
+
+
 def validate_preprocess_arguments(
     do_rescale: Optional[bool] = None,
     rescale_factor: Optional[float] = None,
```
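
A quick sketch of the three input shapes the new helper accepts, per its docstring (file names are placeholders):

```python
from transformers.image_utils import load_images

single = load_images("cat.png")                               # -> PIL.Image.Image
flat = load_images(["cat.png", "dog.png"])                    # -> List[PIL.Image.Image]
nested = load_images([["cat.png"], ["dog.png", "bird.png"]])  # -> List[List[PIL.Image.Image]], one group per prompt
```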

src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -114,6 +114,7 @@
         ("oneformer", ("OneFormerImageProcessor",)),
         ("owlv2", ("Owlv2ImageProcessor",)),
         ("owlvit", ("OwlViTImageProcessor",)),
+        ("paligemma", ("SiglipImageProcessor",)),
         ("perceiver", ("PerceiverImageProcessor",)),
         ("pix2struct", ("Pix2StructImageProcessor",)),
         ("pixtral", ("PixtralImageProcessor",)),
```

src/transformers/models/donut/processing_donut.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -24,12 +24,16 @@
 from ...image_utils import ImageInput
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
 
 
 class DonutProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {}
 
 
+logger = logging.get_logger(__name__)
+
+
 class DonutProcessor(ProcessorMixin):
     r"""
     Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single
@@ -85,6 +89,16 @@ def __call__(
         [`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
         """
         # For backward compatibility
+        legacy = kwargs.pop("legacy", True)
+        if legacy:
+            # With `add_special_tokens=True`, the performance of Donut is degraded when working with both images and text.
+            logger.warning_once(
+                "Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
+                "In the new behavior, if both images and text are provided, the default value of `add_special_tokens` "
+                "will be changed to `False` when calling the tokenizer if `add_special_tokens` is unset. "
+                "To test the new behavior, set `legacy=False` as a processor call argument."
+            )
+
         if self._in_target_context_manager:
             return self.current_processor(images, text, **kwargs)
 
@@ -100,6 +114,8 @@ def __call__(
         if images is not None:
            inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
         if text is not None:
+            if not legacy and images is not None:
+                output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
         if text is None:
```
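
To make the new `legacy` kwarg concrete, a hedged sketch of the two call styles (checkpoint and prompt are illustrative):

```python
from PIL import Image
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")  # illustrative checkpoint
image = Image.new("RGB", (1280, 960), "white")  # stand-in document image

# Default (legacy=True): old tokenization behavior; warns once about the 5.0.0 change.
legacy_inputs = processor(images=image, text="<s_docvqa>", return_tensors="pt")

# Opt in: with both modalities given, add_special_tokens now defaults to False.
new_inputs = processor(images=image, text="<s_docvqa>", legacy=False, return_tensors="pt")
```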

src/transformers/models/fuyu/image_processing_fuyu.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -19,7 +19,7 @@
 
 import numpy as np
 
-from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import (
     pad,
     resize,
@@ -475,6 +475,7 @@ def preprocess(
             input_data_format = infer_channel_dimension_format(batch_images[0][0])
 
         original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
+        size = get_size_dict(size)  # for BC
 
         if do_resize:
             batch_images = [
```
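
The `get_size_dict(size)` call restores backward compatibility for callers that still pass `size` as an int or tuple; roughly, it normalizes the legacy forms into the dict the processor expects:

```python
from transformers.image_processing_utils import get_size_dict

get_size_dict(480)                              # {"height": 480, "width": 480} (square by default)
get_size_dict((1080, 1920))                     # {"height": 1080, "width": 1920}
get_size_dict({"height": 1080, "width": 1920})  # already a dict: passed through unchanged
```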

src/transformers/models/fuyu/processing_fuyu.py

Lines changed: 28 additions & 2 deletions
```diff
@@ -264,10 +264,10 @@ def _tokenize_prompts_with_image_and_batch(
         bos_token = tokenizer.vocab["|ENDOFTEXT|"]
         prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens]
     if add_beginning_of_answer_token:
-        boa = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
+        beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
         # Only add bbox open token to the last subsequence since that is what will be completed
         for token_seq in prompts_tokens:
-            token_seq[-1].append(boa)
+            token_seq[-1].append(beginning_of_answer)
 
     # Now we have a list of lists of tokens in which each list has a different
     # size. We want to extend this list to:
@@ -682,6 +682,32 @@ def tokens_to_points(tokens, original_size):
 
         return results
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                containing the token ids of the generated sequences.
+
+        Returns:
+            `List[str]`: The decoded text output.
+        """
+        beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
+        # get boa index for each outputted sequence tensor
+        # start all generated sequences from the beginning of the answer token, pad to have consistent length
+        unpadded_output_sequences = [
+            seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs
+        ]
+        max_len = max(len(seq) for seq in unpadded_output_sequences)
+        # convert to torch and pad sequences
+        padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
+        for i, seq in enumerate(unpadded_output_sequences):
+            padded_output_sequences[i, : len(seq)] = torch.tensor(seq)
+
+        return self.batch_decode(padded_output_sequences, skip_special_tokens=True)
+
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
```
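
A hedged usage sketch of the new hook (checkpoint and prompt are illustrative): unlike plain `batch_decode`, it drops everything up to and including the beginning-of-answer token before decoding.

```python
from PIL import Image
from transformers import FuyuForConditionalGeneration, FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")  # illustrative checkpoint
model = FuyuForConditionalGeneration.from_pretrained("adept/fuyu-8b")

image = Image.new("RGB", (512, 512), "white")  # stand-in image
inputs = processor(images=image, text="Generate a coco-style caption.\n", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)

# Returns only the answer text, with the prompt and special tokens stripped.
print(processor.post_process_image_text_to_text(generated))
```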

src/transformers/models/git/processing_git.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -22,12 +22,16 @@
 from ...image_utils import ImageInput
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
 
 
 class GitProcessorKwargs(ProcessingKwargs, total=False):
     _defaults = {}
 
 
+logger = logging.get_logger(__name__)
+
+
 class GitProcessor(ProcessorMixin):
     r"""
     Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.
@@ -91,6 +95,15 @@ def __call__(
               `None`).
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
+        legacy = kwargs.pop("legacy", True)
+        if legacy:
+            logger.warning_once(
+                "Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
+                "In the new behavior, if both images and text are provided, the last token (EOS token) "
+                "of the input_ids and attention_mask tensors will be removed. "
+                "To test the new behavior, set `legacy=False` as a processor call argument."
+            )
+
         if text is None and images is None:
             raise ValueError("You have to specify either text or images. Both cannot be none.")
 
@@ -110,6 +123,10 @@ def __call__(
         if images is not None:
             image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
             data.update(image_features)
+            if not legacy:
+                data["input_ids"] = data["input_ids"][:, :-1]
+                data["attention_mask"] = data["attention_mask"][:, :-1]
+
         return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))
 
     def batch_decode(self, *args, **kwargs):
```
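
As a hedged sketch of what the flag does here (checkpoint illustrative): with `legacy=False` and both modalities provided, the trailing EOS token is dropped from `input_ids` and `attention_mask`, so `generate()` continues the prompt instead of stopping at it.

```python
from PIL import Image
from transformers import GitProcessor

processor = GitProcessor.from_pretrained("microsoft/git-base-coco")  # illustrative checkpoint
image = Image.new("RGB", (224, 224), "white")  # stand-in image

old = processor(images=image, text="a photo of", return_tensors="pt")
new = processor(images=image, text="a photo of", legacy=False, return_tensors="pt")
# new["input_ids"] is one token shorter than old["input_ids"]: the final EOS was removed.
```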

src/transformers/models/kosmos2/processing_kosmos2.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -428,6 +428,21 @@ def post_process_generation(self, text, cleanup_and_extract=True):
             return clean_text_and_extract_entities_with_bboxes(caption)
         return caption
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
+        return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]
+
     @property
     # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
     def model_input_names(self):
```

src/transformers/models/mllama/processing_mllama.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -342,6 +342,22 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(
+            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
     @property
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
```

src/transformers/models/owlv2/image_processing_owlv2.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -19,7 +19,7 @@
 
 import numpy as np
 
-from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import (
     center_to_corners_format,
     pad,
@@ -399,6 +399,7 @@ def preprocess(
         image_std = image_std if image_std is not None else self.image_std
 
         size = size if size is not None else self.size
+        size = get_size_dict(size)  # for BC
 
         images = make_list_of_images(images)
```

src/transformers/models/pix2struct/processing_pix2struct.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -21,6 +21,7 @@
 from ...feature_extraction_utils import BatchFeature
 from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+from ...utils import logging
 
 
 class Pix2StructImagesKwargs(ImagesKwargs, total=False):
@@ -48,6 +49,9 @@ class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
     }
 
 
+logger = logging.get_logger(__name__)
+
+
 class Pix2StructProcessor(ProcessorMixin):
     r"""
     Constructs a PIX2STRUCT processor which wraps a BERT tokenizer and PIX2STRUCT image processor into a single
@@ -85,6 +89,15 @@ def __call__(
 
         Please refer to the docstring of the above two methods for more information.
         """
+        legacy = kwargs.pop("legacy", True)
+        if legacy:
+            logger.warning_once(
+                "Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
+                "In the new behavior, if both images and text are provided, the image processor is not a VQA processor, and `add_special_tokens` is unset, "
+                "the default value of `add_special_tokens` will be changed to `False` when calling the tokenizer. "
+                "To test the new behavior, set `legacy=False` as a processor call argument."
+            )
+
         if images is None and text is None:
             raise ValueError("You have to specify either images or text.")
 
@@ -93,8 +106,12 @@ def __call__(
             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
             **kwargs,
         )
+        add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None)
         # Get only text
         if images is None and not self.image_processor.is_vqa:
+            output_kwargs["text_kwargs"]["add_special_tokens"] = (
+                add_special_tokens if add_special_tokens is not None else True
+            )
             self.current_processor = self.tokenizer
             text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
             return text_encoding
@@ -108,6 +125,9 @@ def __call__(
         encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
 
         if text is not None and not self.image_processor.is_vqa:
+            output_kwargs["text_kwargs"]["add_special_tokens"] = (
+                add_special_tokens if add_special_tokens is not None else legacy
+            )
             text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
 
         if "attention_mask" in text_encoding:
```

src/transformers/models/qwen2_vl/processing_qwen2_vl.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -168,6 +168,22 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(
+            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
     @property
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
```

0 commit comments

Comments
 (0)