Skip to content

Commit 0065697

Browse files
committed
fix tests processors idefics/2
1 parent 6e97b9a commit 0065697

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

tests/models/idefics/test_processor_idefics.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def test_image_processor_defaults_preserved_by_image_kwargs(self):
229229
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
230230
self.skip_processor_without_typed_kwargs(processor)
231231

232-
input_str = "lower newer"
232+
input_str = self.prepare_text_inputs()
233233
image_input = self.prepare_image_inputs()
234234

235235
inputs = processor(text=input_str, images=image_input)
@@ -246,7 +246,7 @@ def test_kwargs_overrides_default_image_processor_kwargs(self):
246246
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
247247
self.skip_processor_without_typed_kwargs(processor)
248248

249-
input_str = "lower newer"
249+
input_str = self.prepare_text_inputs()
250250
image_input = self.prepare_image_inputs()
251251

252252
inputs = processor(text=input_str, images=image_input, image_size=224)
@@ -263,7 +263,7 @@ def test_unstructured_kwargs(self):
263263
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
264264
self.skip_processor_without_typed_kwargs(processor)
265265

266-
input_str = "lower newer"
266+
input_str = self.prepare_text_inputs()
267267
image_input = self.prepare_image_inputs()
268268
inputs = processor(
269269
text=input_str,
@@ -288,8 +288,8 @@ def test_unstructured_kwargs_batched(self):
288288
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
289289
self.skip_processor_without_typed_kwargs(processor)
290290

291-
input_str = ["lower newer", "upper older longer string"]
292-
image_input = self.prepare_image_inputs() * 2
291+
input_str = self.prepare_text_inputs(batch_size=2)
292+
image_input = self.prepare_image_inputs(batch_size=2)
293293
inputs = processor(
294294
text=input_str,
295295
images=image_input,
@@ -313,7 +313,7 @@ def test_structured_kwargs_nested(self):
313313
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
314314
self.skip_processor_without_typed_kwargs(processor)
315315

316-
input_str = "lower newer"
316+
input_str = self.prepare_text_inputs()
317317
image_input = self.prepare_image_inputs()
318318

319319
# Define the kwargs for each modality
@@ -339,7 +339,7 @@ def test_structured_kwargs_nested_from_dict(self):
339339

340340
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
341341
self.skip_processor_without_typed_kwargs(processor)
342-
input_str = "lower newer"
342+
input_str = self.prepare_text_inputs()
343343
image_input = self.prepare_image_inputs()
344344

345345
# Define the kwargs for each modality

tests/models/idefics2/test_processor_idefics2.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import tempfile
1818
import unittest
1919
from io import BytesIO
20+
from typing import Optional
2021

21-
import numpy as np
2222
import requests
2323

2424
from transformers import Idefics2Processor
@@ -262,16 +262,26 @@ def test_apply_chat_template(self):
262262
)
263263
self.assertEqual(rendered, expected_rendered)
264264

265-
def prepare_text_inputs(self, batched=False):
266-
if batched:
267-
return ["<image> lower newer", "<image> upper older longer string"]
268-
return "<image> lower newer"
265+
# Override as Idefics2Processor needs image tokens in prompts
266+
def prepare_text_inputs(self, batch_size: Optional[int] = None):
267+
if batch_size is None:
268+
return "lower newer <image>"
269269

270+
if batch_size < 1:
271+
raise ValueError("batch_size must be greater than 0")
272+
273+
if batch_size == 1:
274+
return ["lower newer <image>"]
275+
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
276+
batch_size - 2
277+
)
278+
279+
# Override as PixtralProcessor needs nested images to work properly with batched inputs
270280
@require_vision
271-
def prepare_image_inputs(self, batched=False):
281+
def prepare_image_inputs(self, batch_size: Optional[int] = None):
272282
"""This function prepares a list of PIL images for testing"""
273-
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
274-
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
275-
if batched:
276-
return [image_inputs] * 2
277-
return image_inputs
283+
if batch_size is None:
284+
return super().prepare_image_inputs()
285+
if batch_size < 1:
286+
raise ValueError("batch_size must be greater than 0")
287+
return [[super().prepare_image_inputs()]] * batch_size

0 commit comments

Comments
 (0)