Uniformize LlavaNextVideoProcessor kwargs #35613

Merged
src/transformers/models/llava_next_video/processing_llava_next_video.py

@@ -16,24 +16,33 @@
Processor class for LLaVa-NeXT-Video.
"""

-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, logging
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging


-if TYPE_CHECKING:
-    pass

logger = logging.get_logger(__name__)


+class LlavaNextVideoProcessorKwargs(ProcessingKwargs, total=False):
+    # see processing_utils.ProcessingKwargs documentation for usage.
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }


class LlavaNextVideoProcessor(ProcessorMixin):
r"""
Constructs a LLaVa-NeXT-Video processor which wraps a LLaVa-NeXT image processor, LLaVa-NeXT-Video video processor and
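With the structured kwargs above, call-time options are passed flat and routed to the right sub-processor, instead of being threaded through dedicated `padding`/`truncation`/`return_tensors` parameters. A minimal usage sketch, assuming the llava-hf checkpoint used in the tests below (the dummy image and the printed keys are illustrative):

```python
import numpy as np
from PIL import Image

from transformers import LlavaNextVideoProcessor

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
image = Image.fromarray(np.zeros((336, 336, 3), dtype=np.uint8))  # dummy image

inputs = processor(
    images=image,
    text="USER: <image>\nWhat is shown in this image? ASSISTANT:",
    padding="max_length",  # overrides the text_kwargs default (padding=False)
    max_length=128,        # routed to the tokenizer via text_kwargs
)
# return_tensors defaults to "pt" through common_kwargs
print(inputs.keys())  # e.g. input_ids, attention_mask, pixel_values, image_sizes
```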
@@ -102,13 +111,11 @@ def __init__(

    def __call__(
        self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
Member: tiny concern about squishing in audio before the videos, but I don't think any user passes videos as a positional arg, so maybe we are okay.

I don't want to add more complexity by trying to validate the order of audio and video now, so let's leave it as is and just take this as noted.

        videos: VideoInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: int = None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        **kwargs: Unpack[LlavaNextVideoProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
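To make the review note above concrete: `_validate_images_text_input_order` (called in the body below) keeps legacy `(text, images)` call sites working by detecting and swapping reversed positional inputs, while videos are only safe as a keyword argument now that `audio` sits before `videos`. A hedged sketch, reusing the same checkpoint and dummy inputs as the earlier example:

```python
import numpy as np
from PIL import Image

from transformers import LlavaNextVideoProcessor

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
image = Image.fromarray(np.zeros((336, 336, 3), dtype=np.uint8))
video = np.zeros((8, 336, 336, 3), dtype=np.uint8)  # 8 frames, channels-last

prompt = "USER: <image>\nWhat is shown here? ASSISTANT:"
out_new = processor(image, prompt)  # new (images, text) positional order
out_old = processor(prompt, image)  # legacy order, detected and swapped for BC

# Videos should be passed by keyword; a third positional arg now lands on `audio`.
out_vid = processor(text="USER: <video>\nWhat happens? ASSISTANT:", videos=video)
```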
@@ -130,19 +137,6 @@ def __call__(
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence is provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

@@ -160,13 +154,21 @@
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            LlavaNextVideoProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
        if images is not None:
-            image_inputs = self.image_processor(images, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}

        if videos is not None:
-            videos_inputs = self.video_processor(videos, return_tensors=return_tensors)
+            videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
        else:
            videos_inputs = {}
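For reference, `_merge_kwargs` resolves options per modality from lowest to highest priority: the class `_defaults`, then tokenizer init kwargs, then call-time kwargs, with `common_kwargs` propagated into every group. Conceptually (keys and values here are illustrative, not an exact trace):

```python
# Flat call-time kwargs...
kwargs = {"padding": True, "max_length": 64, "do_rescale": False}

# ...come out of _merge_kwargs grouped roughly like this:
output_kwargs = {
    "text_kwargs": {"padding": True, "max_length": 64, "return_tensors": "pt"},
    "images_kwargs": {"do_rescale": False, "return_tensors": "pt"},
    "videos_kwargs": {"return_tensors": "pt"},
    "common_kwargs": {"return_tensors": "pt"},
}
```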

@@ -212,13 +214,7 @@ def __call__(
                prompt_strings.append(sample)
            text = prompt_strings

-        text_inputs = self.tokenizer(
-            text,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

    # Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_number_of_features
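The returned `BatchFeature` merges the per-modality outputs and behaves like a dict with tensor conveniences, which is what lets the pipeline change further below chain `.to(dtype=...)` directly on the processor output. A small sketch continuing the first example above (`processor` and `image` as defined there; the CUDA device is an assumption, drop it to stay on CPU):

```python
import torch

inputs = processor(images=image, text="USER: <image>\nDescribe. ASSISTANT:")
inputs = inputs.to("cuda", dtype=torch.float16)  # assumes a CUDA device
# floating-point tensors get the new dtype; integer tensors such as
# input_ids only change device
print({k: (v.device, v.dtype) for k, v in inputs.items()})
```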
src/transformers/models/llava_onevision/processing_llava_onevision.py

@@ -41,7 +41,7 @@ class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
},
"image_kwargs": {},
"video_kwargs": {},
"videos_kwargs": {},
Member: nice!

    }
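The rename matters because `_merge_kwargs` looks up default buckets by the modality names declared on `ProcessingKwargs` (`text_kwargs`, `images_kwargs`, `audio_kwargs`, `videos_kwargs`, `common_kwargs`); defaults filed under a non-matching key are silently ignored. A minimal sketch:

```python
from transformers.processing_utils import ProcessingKwargs


class ExampleProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "videos_kwargs": {"do_rescale": False},  # matches a declared modality name
        # "video_kwargs": {...} would be an unrecognized bucket and never applied
    }
```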


6 changes: 3 additions & 3 deletions src/transformers/pipelines/image_text_to_text.py
@@ -345,9 +345,9 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, pro
        # if batched text inputs, we set padding to True unless specified otherwise
        if isinstance(text, (list, tuple)) and len(text) > 1:
            processing_kwargs.setdefault("padding", True)
-        model_inputs = self.processor(
-            images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs
-        ).to(dtype=self.torch_dtype)
+        model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **processing_kwargs).to(
+            dtype=self.torch_dtype
+        )
Member (on lines +348 to +350): am I right that we don't need legacy=False anymore?

Member (author): I ended up keeping it because I had put v5.0.0 as the deprecation version.

        model_inputs["text"] = inputs_text
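For context, the batched-text path touched here can be exercised as below; a hedged sketch where the checkpoint and image URL are placeholders, not taken from this PR:

```python
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="llava-hf/llava-1.5-7b-hf")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # placeholder image
outputs = pipe(
    images=[url, url],
    text=[
        "USER: <image>\nDescribe the image. ASSISTANT:",
        "USER: <image>\nWhat animals appear? ASSISTANT:",
    ],
)
# with more than one text, preprocess() defaults padding=True before
# handing the batch to the processor
```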

166 changes: 166 additions & 0 deletions tests/models/llava_next_video/test_processor_llava_next_video.py
@@ -0,0 +1,166 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
import tempfile
import unittest

from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import LlavaNextImageProcessor, LlavaNextVideoImageProcessor

if is_torch_available():
    import torch


@require_vision
class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
Member: afaik ProcessorTesterMixin doesn't test videos_kwargs yet. I think we need to add video tests to make sure that the llava-next-video processor works as expected.

Member (author): Indeed! I added tests for video_kwargs, very similar to those on images_kwargs :)

    processor_class = LlavaNextVideoProcessor

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        image_processor = LlavaNextImageProcessor()
        video_processor = LlavaNextVideoImageProcessor()
        tokenizer = LlamaTokenizerFast.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
        processor_kwargs = self.prepare_processor_dict()

        processor = LlavaNextVideoProcessor(
            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs
        )
        processor.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def get_video_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor

    def prepare_processor_dict(self):
        return {
            "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '<video>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ '\n' + content['text'] }}{% endgeneration %}{% endfor %}{% endif %}{{'<|im_end|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
            "num_additional_image_tokens": 6,
            "patch_size": 4,
            "vision_feature_select_strategy": "default",
        }

    def test_processor_to_json_string(self):
        processor = self.get_processor()
        obj = json.loads(processor.to_json_string())
        for key, value in self.prepare_processor_dict().items():
            # chat_template is tested in a separate test because it is saved in a separate file
            if key != "chat_template":
                self.assertEqual(obj[key], value)
                self.assertEqual(getattr(processor, key, None), value)

    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
    def test_chat_template_is_saved(self):
        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
        # chat templates aren't serialized to json in processors
        self.assertFalse("chat_template" in processor_dict_loaded.keys())

        # they have to be saved as a separate file and loaded back from that file
        # so we check if the same template is loaded
        processor_dict = self.prepare_processor_dict()
        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_chat_template(self):
        processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
        expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]

        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        self.assertEqual(expected_prompt, formatted_prompt)

    @require_av
    def test_chat_template_dict(self):
        processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video"},
                    {"type": "text", "text": "What is shown in this video?"},
                ],
            },
        ]

        formatted_prompt_tokenized = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True, return_tensors=None
        )
        expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 13, 5618, 338, 4318, 297, 445, 4863, 29973, 319, 1799, 9047, 13566, 29901]]  # fmt: skip
        self.assertListEqual(expected_output, formatted_prompt_tokenized)

        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])

        # add video URL for return dict
        messages[0]["content"][0] = {
            "type": "video",
            "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
        }
        out_dict_with_video = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True, return_dict=True
        )
        self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])

    @require_torch
    @require_av
    def test_chat_template_dict_torch(self):
        processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
                    },
                    {"type": "text", "text": "What is shown in this video?"},
                ],
            },
        ]

        out_dict_tensors = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
        self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
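The video-kwargs tests the author mentions adding are not shown in this view; a sketch of what one could look like in the `ProcessorTesterMixin` style (the method name, frame sizes, and the `do_center_crop=False` shortcut are illustrative, not the PR's exact code, and `numpy` would additionally need to be imported):

```python
    @require_torch
    def test_video_kwargs_override_defaults(self):
        processor = self.processor_class.from_pretrained(self.tmpdirname)
        video = np.zeros((8, 224, 224, 3), dtype=np.uint8)  # 8 dummy frames

        inputs = processor(
            text="USER: <video>\nWhat happens? ASSISTANT:",
            videos=video,
            do_center_crop=False,  # keep the resize visible in the output shape
            size={"height": 214, "width": 214},  # routed to videos_kwargs
            return_tensors="pt",
        )
        # frames should come back resized to the requested width
        self.assertEqual(inputs["pixel_values_videos"].shape[-1], 214)
```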