Skip to content

Commit d8500cd

Browse files
authored
Uniformize kwargs for Pixtral processor (#33521)
* add uniformized pixtral and kwargs * update doc * fix _validate_images_text_input_order * nit
1 parent c29a869 commit d8500cd

File tree

7 files changed

+255
-62
lines changed

7 files changed

+255
-62
lines changed

docs/source/en/model_doc/pixtral.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ The Pixtral model was released by the Mistral AI team on [Vllm](https://github.c
2424
Tips:
2525

2626
- Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
27-
- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
27+
- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
2828
- The format for one or mulitple prompts is the following:
2929
```
3030
"<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
@@ -35,7 +35,7 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)
3535

3636
Here is an example of how to run it:
3737

38-
```python
38+
```python
3939
from transformers import LlavaForConditionalGeneration, AutoProcessor
4040
from PIL import Image
4141

@@ -51,7 +51,7 @@ IMG_URLS = [
5151
]
5252
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
5353

54-
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
54+
inputs = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt").to("cuda")
5555
generate_ids = model.generate(**inputs, max_new_tokens=500)
5656
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
5757

src/transformers/models/pixtral/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343

4444

4545
if TYPE_CHECKING:
46-
from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig
46+
from .configuration_pixtral import PixtralVisionConfig
47+
from .processing_pixtral import PixtralProcessor
4748

4849
try:
4950
if not is_torch_available():

src/transformers/models/pixtral/processing_pixtral.py

+40-34
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,36 @@
1616
Processor class for Pixtral.
1717
"""
1818

19-
from typing import List, Optional, Union
19+
import sys
20+
from typing import List, Union
2021

2122
from ...feature_extraction_utils import BatchFeature
2223
from ...image_utils import ImageInput, is_valid_image, load_image
23-
from ...processing_utils import ProcessorMixin
24-
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
25-
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
24+
from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order
25+
from ...tokenization_utils_base import PreTokenizedInput, TextInput
26+
from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
2627

2728

29+
if sys.version_info >= (3, 11):
30+
from typing import Unpack
31+
else:
32+
from typing_extensions import Unpack
33+
2834
logger = logging.get_logger(__name__)
2935

3036

37+
class PixtralProcessorKwargs(ProcessingKwargs, total=False):
38+
_defaults = {
39+
"text_kwargs": {
40+
"padding": False,
41+
},
42+
"images_kwargs": {},
43+
"common_kwargs": {
44+
"return_tensors": "pt",
45+
},
46+
}
47+
48+
3149
# Copied from transformers.models.idefics2.processing_idefics2.is_url
3250
def is_url(val) -> bool:
3351
return isinstance(val, str) and val.startswith("http")
@@ -143,12 +161,11 @@ def __init__(
143161

144162
def __call__(
145163
self,
146-
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
147164
images: ImageInput = None,
148-
padding: Union[bool, str, PaddingStrategy] = False,
149-
truncation: Union[bool, str, TruncationStrategy] = None,
150-
max_length=None,
151-
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
165+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
166+
audio=None,
167+
videos=None,
168+
**kwargs: Unpack[PixtralProcessorKwargs],
152169
) -> BatchMixFeature:
153170
"""
154171
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -158,26 +175,13 @@ def __call__(
158175
of the above two methods for more information.
159176
160177
Args:
178+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
179+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
180+
tensor. Both channels-first and channels-last formats are supported.
161181
text (`str`, `List[str]`, `List[List[str]]`):
162182
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
163183
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
164184
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
165-
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
166-
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
167-
tensor. Both channels-first and channels-last formats are supported.
168-
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
169-
Select a strategy to pad the returned sequences (according to the model's padding side and padding
170-
index) among:
171-
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
172-
sequence if provided).
173-
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
174-
acceptable input length for the model if that argument is not provided.
175-
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
176-
lengths).
177-
max_length (`int`, *optional*):
178-
Maximum length of the returned list and optionally padding length (see above).
179-
truncation (`bool`, *optional*):
180-
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
181185
return_tensors (`str` or [`~utils.TensorType`], *optional*):
182186
If set, will return tensors of a particular framework. Acceptable values are:
183187
@@ -195,6 +199,15 @@ def __call__(
195199
`None`).
196200
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
197201
"""
202+
# check if images and text inputs are reversed for BC
203+
images, text = _validate_images_text_input_order(images, text)
204+
205+
output_kwargs = self._merge_kwargs(
206+
PixtralProcessorKwargs,
207+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
208+
**kwargs,
209+
)
210+
198211
if images is not None:
199212
if is_image_or_image_url(images):
200213
images = [[images]]
@@ -209,7 +222,7 @@ def __call__(
209222
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
210223
)
211224
images = [[load_image(im) for im in sample] for sample in images]
212-
image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors)
225+
image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
213226
else:
214227
image_inputs = {}
215228

@@ -246,16 +259,9 @@ def __call__(
246259
while "<placeholder>" in sample:
247260
replace_str = replace_strings.pop(0)
248261
sample = sample.replace("<placeholder>", replace_str, 1)
249-
250262
prompt_strings.append(sample)
251263

252-
text_inputs = self.tokenizer(
253-
prompt_strings,
254-
return_tensors=return_tensors,
255-
padding=padding,
256-
truncation=truncation,
257-
max_length=max_length,
258-
)
264+
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
259265
return BatchMixFeature(data={**text_inputs, **image_inputs})
260266

261267
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama

src/transformers/processing_utils.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import numpy as np
2828

2929
from .dynamic_module_utils import custom_object_save
30-
from .image_utils import ChannelDimension, is_vision_available, valid_images
30+
from .image_utils import ChannelDimension, is_valid_image, is_vision_available
3131

3232

3333
if is_vision_available():
@@ -1003,6 +1003,20 @@ def _validate_images_text_input_order(images, text):
10031003
in the processor's `__call__` method before calling this method.
10041004
"""
10051005

1006+
def is_url(val) -> bool:
1007+
return isinstance(val, str) and val.startswith("http")
1008+
1009+
def _is_valid_images_input_for_processor(imgs):
1010+
# If we have an list of images, make sure every image is valid
1011+
if isinstance(imgs, (list, tuple)):
1012+
for img in imgs:
1013+
if not _is_valid_images_input_for_processor(img):
1014+
return False
1015+
# If not a list or tuple, we have been given a single image or batched tensor of images
1016+
elif not (is_valid_image(imgs) or is_url(imgs)):
1017+
return False
1018+
return True
1019+
10061020
def _is_valid_text_input_for_processor(t):
10071021
if isinstance(t, str):
10081022
# Strings are fine
@@ -1019,11 +1033,11 @@ def _is_valid_text_input_for_processor(t):
10191033
def _is_valid(input, validator):
10201034
return validator(input) or input is None
10211035

1022-
images_is_valid = _is_valid(images, valid_images)
1023-
images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False
1036+
images_is_valid = _is_valid(images, _is_valid_images_input_for_processor)
1037+
images_is_text = _is_valid_text_input_for_processor(images)
10241038

10251039
text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
1026-
text_is_images = valid_images(text) if not text_is_valid else False
1040+
text_is_images = _is_valid_images_input_for_processor(text)
10271041
# Handle cases where both inputs are valid
10281042
if images_is_valid and text_is_valid:
10291043
return images, text

0 commit comments

Comments
 (0)