
Commit fd69d5c

add videos_kwargs tests

Committed Jan 14, 2025 · 1 parent 7a411e9

File tree: 3 files changed (+194, -12 lines)


‎src/transformers/models/llava_next_video/image_processing_llava_next_video.py

+1 -1

@@ -54,7 +54,7 @@ def make_batched_videos(videos) -> List[VideoInput]:
         return videos
 
     elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
-        if isinstance(videos[0], Image.Image):
+        if isinstance(videos[0], Image.Image) or len(videos[0].shape) == 3:
             return [videos]
         elif len(videos[0].shape) == 4:
             return [list(video) for video in videos]
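The one-line change above lets a plain list of 3-D numpy frames be treated as a single video instead of falling through unhandled. A minimal sketch of the resulting behavior, assuming the module path from the file shown above at this commit (shapes are illustrative):

import numpy as np

from transformers.models.llava_next_video.image_processing_llava_next_video import make_batched_videos

# One video supplied as a list of 3-D frames (channels, height, width):
# the new `len(videos[0].shape) == 3` check wraps it as a batch of one video.
frames = [np.zeros((3, 30, 400), dtype=np.uint8)] * 8
assert len(make_batched_videos(frames)) == 1

# A batch of 4-D arrays (frames, channels, height, width) still takes the
# `len(videos[0].shape) == 4` branch: one list of frames per video.
videos_4d = [np.zeros((8, 3, 30, 400), dtype=np.uint8)] * 2
assert len(make_batched_videos(videos_4d)) == 2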

‎tests/models/llava_next_video/test_processor_llava_next_video.py

+3 -8

@@ -18,20 +18,15 @@
 import tempfile
 import unittest
 
+from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
 from transformers.testing_utils import require_av, require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available
 
 from ...test_processing_common import ProcessorTesterMixin
 
 
 if is_vision_available():
-    from transformers import (
-        AutoProcessor,
-        LlamaTokenizerFast,
-        LlavaNextImageProcessor,
-        LlavaNextVideoImageProcessor,
-        LlavaNextVideoProcessor,
-    )
+    from transformers import LlavaNextImageProcessor, LlavaNextVideoImageProcessor
 
 if is_torch_available:
     import torch
@@ -66,13 +61,13 @@ def prepare_processor_dict(self):
         return {
             "chat_template": "dummy_template",
             "num_additional_image_tokens": 6,
+            "patch_size": 4,
             "vision_feature_select_strategy": "default",
         }
 
     def test_processor_to_json_string(self):
         processor = self.get_processor()
         obj = json.loads(processor.to_json_string())
-        print(processor)
         for key, value in self.prepare_processor_dict().items():
             # chat_tempalate are tested as a separate test because they are saved in separate files
             if key != "chat_template":

‎tests/test_processing_common.py

+190 -3

@@ -121,11 +121,12 @@ def prepare_image_inputs(self, batch_size: Optional[int] = None):
         return prepare_image_inputs() * batch_size
 
     @require_vision
-    def prepare_video_inputs(self):
+    def prepare_video_inputs(self, batch_size: Optional[int] = None):
         """This function prepares a list of numpy videos."""
         video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
-        image_inputs = [video_input] * 3  # batch-size=3
-        return image_inputs
+        if batch_size is None:
+            return video_input
+        return [video_input] * batch_size
 
     def test_processor_to_json_string(self):
         processor = self.get_processor()
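With the new batch_size parameter, prepare_video_inputs mirrors prepare_image_inputs: called without an argument it returns a single video (a list of 8 numpy frames), while an explicit batch_size wraps that video into a batch. A standalone sketch of the helper from the hunk above (in the test suite it is a method on ProcessorTesterMixin):

from typing import Optional

import numpy as np

def prepare_video_inputs(batch_size: Optional[int] = None):
    """Copy of the updated helper's logic, reproduced here for illustration only."""
    video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
    if batch_size is None:
        return video_input  # a single video: 8 frames of shape (3, 30, 400)
    return [video_input] * batch_size  # a batch of identical videos

assert len(prepare_video_inputs()) == 8
assert len(prepare_video_inputs(batch_size=2)) == 2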
@@ -484,6 +485,192 @@ def test_structured_kwargs_audio_nested(self):
         elif "labels" in inputs:
             self.assertEqual(len(inputs["labels"][0]), 76)
 
+    def test_tokenizer_defaults_preserved_by_kwargs_video(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
+        processor_kwargs = self.prepare_processor_dict()
+
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = self.prepare_text_inputs()
+        video_input = self.prepare_video_inputs()
+        inputs = processor(text=input_str, videos=video_input, return_tensors="pt")
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 117)
+
+    def test_video_processor_defaults_preserved_by_video_kwargs(self):
+        """
+        We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor.
+        We then check that the mean of the pixel_values is less than or equal to 0 after processing.
+        Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied.
+        """
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_components["video_processor"] = self.get_component(
+            "video_processor", do_rescale=True, rescale_factor=-1
+        )
+        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
+        processor_kwargs = self.prepare_processor_dict()
+
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs()
+        video_input = self.prepare_video_inputs()
+
+        inputs = processor(text=input_str, videos=video_input, return_tensors="pt")
+        self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
+
+    def test_kwargs_overrides_default_tokenizer_kwargs_video(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest")
+        processor_kwargs = self.prepare_processor_dict()
+
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = self.prepare_text_inputs()
+        video_input = self.prepare_video_inputs()
+        inputs = processor(
+            text=input_str, videos=video_input, return_tensors="pt", max_length=112, padding="max_length"
+        )
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 112)
+
+    def test_kwargs_overrides_default_video_processor_kwargs(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_components["video_processor"] = self.get_component(
+            "video_processor", do_rescale=True, rescale_factor=1
+        )
+        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
+        processor_kwargs = self.prepare_processor_dict()
+
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs()
+        video_input = self.prepare_video_inputs()
+
+        inputs = processor(text=input_str, videos=video_input, do_rescale=True, rescale_factor=-1, return_tensors="pt")
+        self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
+
+    def test_unstructured_kwargs_video(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_kwargs = self.prepare_processor_dict()
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs()
+        video_input = self.prepare_video_inputs()
+        inputs = processor(
+            text=input_str,
+            videos=video_input,
+            return_tensors="pt",
+            do_rescale=True,
+            rescale_factor=-1,
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
+
+    def test_unstructured_kwargs_batched_video(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_kwargs = self.prepare_processor_dict()
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs(batch_size=2)
+        video_input = self.prepare_video_inputs(batch_size=2)
+        inputs = processor(
+            text=input_str,
+            videos=video_input,
+            return_tensors="pt",
+            do_rescale=True,
+            rescale_factor=-1,
+            padding="longest",
+            max_length=76,
+        )
+
+        self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
+        self.assertTrue(
+            len(inputs[self.text_input_name][0]) == len(inputs[self.text_input_name][1])
+            and len(inputs[self.text_input_name][1]) < 76
+        )
+
+    def test_doubly_passed_kwargs_video(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_kwargs = self.prepare_processor_dict()
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = [self.prepare_text_inputs()]
+        video_input = self.prepare_video_inputs()
+        with self.assertRaises(ValueError):
+            _ = processor(
+                text=input_str,
+                videos=video_input,
+                videos_kwargs={"do_rescale": True, "rescale_factor": -1},
+                do_rescale=True,
+                return_tensors="pt",
+            )
+
+    def test_structured_kwargs_nested_video(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_kwargs = self.prepare_processor_dict()
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs()
+        video_input = self.prepare_video_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "videos_kwargs": {"do_rescale": True, "rescale_factor": -1},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, videos=video_input, **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
+
+    def test_structured_kwargs_nested_from_dict_video(self):
+        if "video_processor" not in self.processor_class.attributes:
+            self.skipTest(f"video_processor attribute not present in {self.processor_class}")
+        processor_components = self.prepare_components()
+        processor_kwargs = self.prepare_processor_dict()
+        processor = self.processor_class(**processor_components, **processor_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = self.prepare_text_inputs()
+        video_input = self.prepare_video_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "videos_kwargs": {"do_rescale": True, "rescale_factor": -1},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, videos=video_input, **all_kwargs)
+        self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
+
     # TODO: the same test, but for audio + text processors that have strong overlap in kwargs
     # TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
     def test_overlapping_text_kwargs_handling(self):
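The do_rescale=True, rescale_factor=-1 pattern used throughout the tests above works because rescaling multiplies the raw uint8 pixel values (all in [0, 255]) by the factor, so the processed values become non-positive and their mean is at most 0, which is what the assertLessEqual(..., 0) checks rely on. A rough numpy-only sketch of that sanity check, not the actual image/video processor code:

import numpy as np

frame = np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)
assert frame.mean() >= 0  # raw pixels are non-negative

rescaled = frame.astype(np.float32) * -1  # what rescale_factor=-1 amounts to
assert rescaled.mean() <= 0  # hence the assertLessEqual(..., 0) checks in the tests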

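Similarly, test_doubly_passed_kwargs_video expects a ValueError when the same option is supplied both nested under videos_kwargs and as a flat keyword. The snippet below only sketches that conflict check; the real merging logic lives in the library's processing code, and merge_video_kwargs is a made-up name:

def merge_video_kwargs(videos_kwargs=None, **flat_kwargs):
    """Hypothetical stand-in for the kwargs merging the test exercises."""
    merged = dict(videos_kwargs or {})
    for key, value in flat_kwargs.items():
        if key in merged:
            # The same option arrived both nested and flat: ambiguous, so refuse.
            raise ValueError(f"Keyword argument {key!r} was passed twice")
        merged[key] = value
    return merged

# Mirrors the conflicting call in test_doubly_passed_kwargs_video:
try:
    merge_video_kwargs(videos_kwargs={"do_rescale": True, "rescale_factor": -1}, do_rescale=True)
except ValueError as err:
    print(err)  # Keyword argument 'do_rescale' was passed twice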