
Commit 7a411e9

Uniformize processor kwargs and add tests
1 parent c61fcde commit 7a411e9

File tree: 4 files changed, +207 −38 lines

src/transformers/models/llava_next_video/processing_llava_next_video.py

Lines changed: 32 additions & 34 deletions
@@ -16,22 +16,33 @@
 Processor class for LLaVa-NeXT-Video.
 """

-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Union

 from ...feature_extraction_utils import BatchFeature
 from ...image_processing_utils import select_best_resolution
 from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, logging
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging


-if TYPE_CHECKING:
-    pass
-
 logger = logging.get_logger(__name__)


+class LlavaNextVideoProcessorKwargs(ProcessingKwargs, total=False):
+    # see processing_utils.ProcessingKwargs documentation for usage.
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "image_kwargs": {},
+        "videos_kwargs": {},
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }
+
+
 class LlavaNextVideoProcessor(ProcessorMixin):
     r"""
     Constructs a LLaVa-NeXT-Video processor which wraps a LLaVa-NeXT image processor, LLaVa-NeXT-Video video processor and
@@ -100,13 +111,11 @@ def __init__(

     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
         images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
         videos: VideoInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: int = None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        **kwargs: Unpack[LlavaNextVideoProcessorKwargs],
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -128,19 +137,6 @@ def __call__(
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
-           padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-               Select a strategy to pad the returned sequences (according to the model's padding side and padding
-               index) among:
-               - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                 sequence if provided).
-               - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                 acceptable input length for the model if that argument is not provided.
-               - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                 lengths).
-           max_length (`int`, *optional*):
-               Maximum length of the returned list and optionally padding length (see above).
-           truncation (`bool`, *optional*):
-               Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
@@ -158,13 +154,21 @@ def __call__(
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            LlavaNextVideoProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
         if images is not None:
-            image_inputs = self.image_processor(images, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
         else:
             image_inputs = {}

         if videos is not None:
-            videos_inputs = self.video_processor(videos, return_tensors=return_tensors)
+            videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
         else:
             videos_inputs = {}

@@ -206,13 +210,7 @@ def __call__(
             prompt_strings.append(sample)
         text = prompt_strings

-        text_inputs = self.tokenizer(
-            text,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
         return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

     # Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_number_of_features
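For reference, a minimal usage sketch of the refactored `__call__`: per-modality options now travel through a single `**kwargs`, `_merge_kwargs` routes them to the tokenizer, image processor, and video processor, and `_validate_images_text_input_order` keeps old call sites that pass `text` first working. The dummy image and prompt below are illustrative, not part of the commit:

import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
image = np.zeros((336, 336, 3), dtype=np.uint8)  # dummy RGB image

inputs = processor(
    images=image,
    text="USER: <image>\nWhat is shown in this image? ASSISTANT:",
    padding=True,         # merged into output_kwargs["text_kwargs"] -> tokenizer
    return_tensors="pt",  # a common kwarg, applied to every sub-processor
)

# Backward compatibility: the old (text, images) positional order is detected
# and swapped by _validate_images_text_input_order, so this still works too.
legacy_inputs = processor("USER: <image>\nHi! ASSISTANT:", image, padding=True)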

src/transformers/models/llava_onevision/processing_llava_onevision.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
             "padding": False,
         },
         "image_kwargs": {},
-        "video_kwargs": {},
+        "videos_kwargs": {},
     }
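The rename matters because `_merge_kwargs` only picks up `_defaults` entries whose keys match the kwarg groups declared on `ProcessingKwargs` (`text_kwargs`, `images_kwargs`, `videos_kwargs`, `audio_kwargs`, `common_kwargs`), so a `"video_kwargs"` entry was silently ignored. A minimal sketch of declaring defaults against those group names (the class name is illustrative, not from the commit):

from transformers.processing_utils import ProcessingKwargs


class MyProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {"padding": False},
        # Note the plural: a misspelled group such as "video_kwargs" would
        # never be merged into the kwargs passed to the video processor.
        "videos_kwargs": {"do_resize": True},
    }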

src/transformers/pipelines/image_text_to_text.py

Lines changed: 3 additions & 3 deletions
@@ -345,9 +345,9 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, pro
         # if batched text inputs, we set padding to True unless specified otherwise
         if isinstance(text, (list, tuple)) and len(text) > 1:
             processing_kwargs.setdefault("padding", True)
-        model_inputs = self.processor(
-            images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs
-        ).to(dtype=self.torch_dtype)
+        model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **processing_kwargs).to(
+            dtype=self.torch_dtype
+        )

         model_inputs["text"] = inputs_text
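With the hard-coded `legacy=False` gone, the pipeline now calls processors through the uniform kwargs path alone. A hedged usage sketch of the pipeline this file implements (checkpoint and image URL are illustrative choices, not from the commit):

from transformers import pipeline

pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
out = pipe(text=messages, max_new_tokens=20)
print(out[0]["generated_text"])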

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import shutil
+import tempfile
+import unittest
+
+from transformers.testing_utils import require_av, require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    from transformers import (
+        AutoProcessor,
+        LlamaTokenizerFast,
+        LlavaNextImageProcessor,
+        LlavaNextVideoImageProcessor,
+        LlavaNextVideoProcessor,
+    )
+
+if is_torch_available():
+    import torch
+
+
+@require_vision
+class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = LlavaNextVideoProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        image_processor = LlavaNextImageProcessor()
+        video_processor = LlavaNextVideoImageProcessor()
+        tokenizer = LlamaTokenizerFast.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
+        processor_kwargs = self.prepare_processor_dict()
+
+        processor = LlavaNextVideoProcessor(
+            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs
+        )
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def get_video_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
+
+    def prepare_processor_dict(self):
+        return {
+            "chat_template": "dummy_template",
+            "num_additional_image_tokens": 6,
+            "vision_feature_select_strategy": "default",
+        }
+
+    def test_processor_to_json_string(self):
+        processor = self.get_processor()
+        obj = json.loads(processor.to_json_string())
+        print(processor)
+        for key, value in self.prepare_processor_dict().items():
+            # chat_template is tested as a separate test because it is saved in a separate file
+            if key != "chat_template":
+                self.assertEqual(obj[key], value)
+                self.assertEqual(getattr(processor, key, None), value)
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to json in processors
+        self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+        # they have to be saved as a separate file and loaded back from that file,
+        # so we check if the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def test_chat_template(self):
+        processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
+        expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "What is shown in this image?"},
+                ],
+            },
+        ]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        self.assertEqual(expected_prompt, formatted_prompt)
+
+    @require_av
+    def test_chat_template_dict(self):
+        processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "video"},
+                    {"type": "text", "text": "What is shown in this video?"},
+                ],
+            },
+        ]
+
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True, return_tensors=None
+        )
+        expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 13, 5618, 338, 4318, 297, 445, 4863, 29973, 319, 1799, 9047, 13566, 29901]]  # fmt: skip
+        self.assertListEqual(expected_output, formatted_prompt_tokenized)
+
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # add a video URL so return_dict also yields pixel values
+        messages[0]["content"][0] = {
+            "type": "video",
+            "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
+        }
+        out_dict_with_video = processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True, return_dict=True
+        )
+        self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
+
+    @require_torch
+    @require_av
+    def test_chat_template_dict_torch(self):
+        processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "video",
+                        "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
+                    },
+                    {"type": "text", "text": "What is shown in this video?"},
+                ],
+            },
+        ]
+
+        out_dict_tensors = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
+        self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
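To exercise the new suite locally, discovery along these lines should work, assuming the repository's usual tests/models/<model> layout (the new file's path is not shown on this page):

import unittest

# Collect and run the processor tests; roughly equivalent to
# `python -m pytest tests/models/llava_next_video` from the repo root.
suite = unittest.defaultTestLoader.discover(
    "tests/models/llava_next_video", pattern="test_processor*.py", top_level_dir="."
)
unittest.TextTestRunner(verbosity=2).run(suite)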
