diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py
index 8a1294611b0d..6ec40209df8b 100644
--- a/src/transformers/models/llava_next_video/processing_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py
@@ -16,24 +16,33 @@
 Processor class for LLaVa-NeXT-Video.
 """
 
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Union
 
 import numpy as np
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_processing_utils import select_best_resolution
 from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, logging
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
 
-if TYPE_CHECKING:
-    pass
-
 logger = logging.get_logger(__name__)
 
 
+class LlavaNextVideoProcessorKwargs(ProcessingKwargs, total=False):
+    # see processing_utils.ProcessingKwargs documentation for usage.
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }
+
+
 class LlavaNextVideoProcessor(ProcessorMixin):
     r"""
     Constructs a LLaVa-NeXT-Video processor which wraps a LLaVa-NeXT image processor, LLaVa-NeXT-Video video processor and
@@ -102,13 +111,11 @@ def __init__(
 
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
         images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
         videos: VideoInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: int = None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        **kwargs: Unpack[LlavaNextVideoProcessorKwargs],
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -130,19 +137,6 @@ def __call__(
             videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                 The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor,
                 or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
@@ -160,13 +154,21 @@ def __call__(
               `None`).
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            LlavaNextVideoProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
 
         if images is not None:
-            image_inputs = self.image_processor(images, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
         else:
             image_inputs = {}
 
         if videos is not None:
-            videos_inputs = self.video_processor(videos, return_tensors=return_tensors)
+            videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
         else:
             videos_inputs = {}
 
@@ -212,13 +214,7 @@ def __call__(
                 prompt_strings.append(sample)
             text = prompt_strings
-        text_inputs = self.tokenizer(
-            text,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
         return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
 
     # Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_number_of_features
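
Note: a minimal sketch of the refactored call path (not part of the patch; it assumes the public llava-hf/LLaVA-NeXT-Video-7B-hf checkpoint, the same one the new tests load). Tokenizer options such as `padding` are no longer named parameters of `__call__`: they ride in `**kwargs`, and `_merge_kwargs` buckets them into `output_kwargs["text_kwargs"]`, with caller kwargs overriding both the `_defaults` above and the tokenizer's init kwargs.

import numpy as np

from transformers import LlavaNextVideoProcessor

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

# Dummy RGB image; the processor expands the <image> placeholder itself.
image = np.zeros((336, 336, 3), dtype=np.uint8)

# `padding` lands in output_kwargs["text_kwargs"]; return_tensors defaults to
# "pt" via common_kwargs, so everything below comes back as PyTorch tensors.
inputs = processor(images=image, text="<image> What is shown here?", padding=True)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)
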
diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py
index f4ca90f28c21..aa97799da645 100644
--- a/src/transformers/models/llava_onevision/processing_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py
@@ -41,7 +41,7 @@ class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
             "padding": False,
         },
         "image_kwargs": {},
-        "video_kwargs": {},
+        "videos_kwargs": {},
     }
 
 
diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py
index 5afba0d7c041..0ce0d3521d3a 100644
--- a/src/transformers/pipelines/image_text_to_text.py
+++ b/src/transformers/pipelines/image_text_to_text.py
@@ -345,9 +345,9 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, pro
         # if batched text inputs, we set padding to True unless specified otherwise
         if isinstance(text, (list, tuple)) and len(text) > 1:
             processing_kwargs.setdefault("padding", True)
-        model_inputs = self.processor(
-            images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs
-        ).to(dtype=self.torch_dtype)
+        model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **processing_kwargs).to(
+            dtype=self.torch_dtype
+        )
 
         model_inputs["text"] = inputs_text
 
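
Note: the one-character rename above ("video_kwargs" to "videos_kwargs") is load-bearing. `_merge_kwargs` buckets options under the `videos_kwargs` key that `__call__` later reads back as `output_kwargs["videos_kwargs"]`, so defaults filed under the misspelled key were silently dropped. A sketch of the video path through the LLaVa-NeXT-Video processor, under the same checkpoint assumption as above:

import numpy as np

from transformers import LlavaNextVideoProcessor

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

# An 8-frame dummy clip; options aimed at the video processor travel in the
# videos_kwargs bucket and reach self.video_processor(videos, **...).
video = np.zeros((8, 336, 336, 3), dtype=np.uint8)
inputs = processor(text="<video> Describe the clip.", videos=video)
print(inputs["pixel_values_videos"].shape)
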
diff --git a/tests/models/llava_next_video/test_processor_llava_next_video.py b/tests/models/llava_next_video/test_processor_llava_next_video.py
new file mode 100644
index 000000000000..764c944bac89
--- /dev/null
+++ b/tests/models/llava_next_video/test_processor_llava_next_video.py
@@ -0,0 +1,166 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import shutil
+import tempfile
+import unittest
+
+from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
+from transformers.testing_utils import require_av, require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    from transformers import LlavaNextImageProcessor, LlavaNextVideoImageProcessor
+
+if is_torch_available():
+    import torch
+
+
+@require_vision
+class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = LlavaNextVideoProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        image_processor = LlavaNextImageProcessor()
+        video_processor = LlavaNextVideoImageProcessor()
+        tokenizer = LlamaTokenizerFast.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
+        processor_kwargs = self.prepare_processor_dict()
+
+        processor = LlavaNextVideoProcessor(
+            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs
+        )
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def get_video_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
+
+    def prepare_processor_dict(self):
+        return {
+            "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '
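
Note: the `_validate_images_text_input_order` shim added to `__call__` deserves its own sketch (hypothetical check, same checkpoint assumption as the notes above). It keeps the legacy `(text, images)` positional order working even though the new signature is `(images, text)`:

import numpy as np

from transformers import LlavaNextVideoProcessor

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
image = np.zeros((336, 336, 3), dtype=np.uint8)
prompt = "<image> What is shown here?"

# The helper inspects the argument types and swaps them back if they arrive
# reversed, so both call orders should produce the same encoding.
new_order = processor(image, prompt)
old_order = processor(prompt, image)
assert new_order["input_ids"].equal(old_order["input_ids"])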