13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
16
+ import shutil
17
+ import tempfile
16
18
import unittest
19
+ from typing import Optional
17
20
18
21
import numpy as np
19
22
20
23
from transformers import MllamaProcessor
21
24
from transformers .testing_utils import require_torch , require_vision
22
25
from transformers .utils import is_vision_available
23
26
27
+ from ...test_processing_common import ProcessorTesterMixin
28
+
24
29
25
30
if is_vision_available ():
26
31
from PIL import Image
27
32
28
33
29
34
@require_torch
30
35
@require_vision
31
- class MllamaProcessorTest (unittest .TestCase ):
36
+ class MllamaProcessorTest (ProcessorTesterMixin , unittest .TestCase ):
37
+ processor_class = MllamaProcessor
38
+
32
39
def setUp(self):
    """Prepare the shared fixtures used by the processor tests.

    Loads the processor from ``self.checkpoint``, builds two RGB test images
    of different sizes, records the special-token attributes read from the
    processor, and saves the processor into a fresh temporary directory.
    The tests reload the processor from ``self.tmpdirname`` instead of
    keeping a processor instance on ``self``.
    """
    self.checkpoint = "hf-internal-testing/mllama-11b"
    processor = MllamaProcessor.from_pretrained(self.checkpoint)

    # Two images with different sizes and aspect ratios.
    self.image1 = Image.new("RGB", (224, 220))
    self.image2 = Image.new("RGB", (512, 128))

    # Special tokens / ids taken from the loaded processor and its tokenizer.
    self.image_token = processor.image_token
    self.image_token_id = processor.image_token_id
    self.pad_token_id = processor.tokenizer.pad_token_id
    self.bos_token = processor.bos_token
    self.bos_token_id = processor.tokenizer.bos_token_id

    # Persist the processor so tests can call
    # ``MllamaProcessor.from_pretrained(self.tmpdirname)``; the directory is
    # removed again in ``tearDown``.
    self.tmpdirname = tempfile.mkdtemp()
    processor.save_pretrained(self.tmpdirname)
51
+
52
def tearDown(self):
    """Delete the temporary directory stored in ``self.tmpdirname``."""
    shutil.rmtree(self.tmpdirname)
42
54
43
55
def test_apply_chat_template (self ):
44
56
# Message contains content which is a mix of lists with images, image URLs, and strings
@@ -64,8 +76,8 @@ def test_apply_chat_template(self):
64
76
],
65
77
},
66
78
]
67
-
68
- rendered = self . processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = False )
79
+ processor = MllamaProcessor . from_pretrained ( self . tmpdirname )
80
+ rendered = processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = False )
69
81
70
82
expected_rendered = (
71
83
"<|begin_of_text|>"
@@ -96,7 +108,7 @@ def test_apply_chat_template(self):
96
108
],
97
109
},
98
110
]
99
- input_ids = self . processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = True )
111
+ input_ids = processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = True )
100
112
expected_ids = [
101
113
128000 , # <|begin_of_text|>
102
114
128006 , # <|start_header_id|>
@@ -142,15 +154,15 @@ def test_apply_chat_template(self):
142
154
}
143
155
]
144
156
145
- rendered = self . processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = False )
157
+ rendered = processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = False )
146
158
expected_rendered = (
147
159
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n \n "
148
160
"Describe this image in two sentences<|image|> Test sentence <|image|>ok\n <|eot_id|>"
149
161
"<|start_header_id|>assistant<|end_header_id|>\n \n "
150
162
)
151
163
self .assertEqual (rendered , expected_rendered )
152
164
153
- input_ids = self . processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = True )
165
+ input_ids = processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = True )
154
166
# fmt: off
155
167
expected_ids = [
156
168
128000 , 128006 , 882 , 128007 , 271 , 75885 , 420 , 2217 , 304 , 1403 , 23719 , 128256 ,
@@ -176,18 +188,19 @@ def test_apply_chat_template(self):
176
188
}
177
189
]
178
190
179
- rendered_list = self . processor .apply_chat_template (messages_list , add_generation_prompt = True , tokenize = False )
180
- rendered_str = self . processor .apply_chat_template (messages_str , add_generation_prompt = True , tokenize = False )
191
+ rendered_list = processor .apply_chat_template (messages_list , add_generation_prompt = True , tokenize = False )
192
+ rendered_str = processor .apply_chat_template (messages_str , add_generation_prompt = True , tokenize = False )
181
193
self .assertEqual (rendered_list , rendered_str )
182
194
183
195
def test_process_interleaved_images_prompts_image_splitting (self ):
196
+ processor = MllamaProcessor .from_pretrained (self .tmpdirname )
184
197
# Test that a single image is processed correctly
185
- inputs = self . processor (images = self .image2 , size = {"width" : 224 , "height" : 224 })
198
+ inputs = processor (images = self .image2 , size = {"width" : 224 , "height" : 224 })
186
199
self .assertEqual (inputs ["pixel_values" ].shape , (1 , 1 , 4 , 3 , 224 , 224 ))
187
200
188
201
# Test that text is processed correctly
189
202
text = "<|begin_of_text|>This is a test sentence.<|end_of_text|>"
190
- inputs = self . processor (text = text )
203
+ inputs = processor (text = text )
191
204
expected_ids = [128000 , 2028 , 374 , 264 , 1296 , 11914 , 13 , 128001 ]
192
205
self .assertEqual (inputs ["input_ids" ][0 ], expected_ids )
193
206
self .assertEqual (inputs ["attention_mask" ][0 ], [1 ] * len (expected_ids ))
@@ -197,7 +210,7 @@ def test_process_interleaved_images_prompts_image_splitting(self):
197
210
image_str = "<|image|>"
198
211
text_str = "This is a test sentence."
199
212
text = image_str + text_str
200
- inputs = self . processor (
213
+ inputs = processor (
201
214
text = text ,
202
215
images = self .image1 ,
203
216
size = {"width" : 128 , "height" : 128 },
@@ -225,7 +238,7 @@ def test_process_interleaved_images_prompts_image_splitting(self):
225
238
]
226
239
# fmt: on
227
240
images = [[self .image1 ], [self .image1 , self .image2 ]]
228
- inputs = self . processor (text = text , images = images , padding = True , size = {"width" : 256 , "height" : 256 })
241
+ inputs = processor (text = text , images = images , padding = True , size = {"width" : 256 , "height" : 256 })
229
242
230
243
self .assertEqual (inputs ["pixel_values" ].shape , (2 , 2 , 4 , 3 , 256 , 256 ))
231
244
for input_ids_i , attention_mask_i , expected_ids_i in zip (inputs ["input_ids" ], inputs ["attention_mask" ], expected_ids ):
@@ -264,34 +277,49 @@ def test_process_interleaved_images_prompts_image_error(self):
264
277
"This is a test sentence." ,
265
278
"In this other sentence we try some good things" ,
266
279
]
267
- inputs = self .processor (text = text , images = None , padding = True )
280
+ processor = MllamaProcessor .from_pretrained (self .tmpdirname )
281
+ inputs = processor (text = text , images = None , padding = True )
268
282
self .assertIsNotNone (inputs ["input_ids" ])
269
283
270
284
text = [
271
285
"This is a test sentence.<|image|>" ,
272
286
"In this other sentence we try some good things" ,
273
287
]
274
288
with self .assertRaises (ValueError ):
275
- self . processor (text = text , images = None , padding = True )
289
+ processor (text = text , images = None , padding = True )
276
290
277
291
images = [[self .image1 ], []]
278
292
with self .assertRaises (ValueError ):
279
- self . processor (text = text , images = images , padding = True )
293
+ processor (text = text , images = images , padding = True )
280
294
281
295
text = [
282
296
"This is a test sentence.<|image|>" ,
283
297
"In this other sentence we try some good things<|image|>" ,
284
298
]
285
299
with self .assertRaises (ValueError ):
286
- self . processor (text = text , images = None , padding = True )
300
+ processor (text = text , images = None , padding = True )
287
301
288
302
text = [
289
303
"This is a test sentence.<|image|>" ,
290
304
"In this other sentence we try some good things<|image|>" ,
291
305
]
292
306
images = [[self .image1 ], [self .image2 ]]
293
- inputs = self . processor (text = text , images = images , padding = True )
307
+ inputs = processor (text = text , images = images , padding = True )
294
308
295
309
images = [[self .image1 , self .image2 ], []]
296
310
with self .assertRaises (ValueError ):
297
- self .processor (text = text , images = None , padding = True )
311
+ processor (text = text , images = None , padding = True )
312
+
313
+ # Override as MllamaProcessor needs image tokens in prompts
314
def prepare_text_inputs(self, batch_size: Optional[int] = None):
    """Return text prompt(s) that contain the ``<|image|>`` token.

    Overridden because MllamaProcessor needs image tokens in prompts.

    Args:
        batch_size: Number of prompts to return. ``None`` returns a single
            string instead of a list.

    Returns:
        A single prompt string when ``batch_size`` is ``None``; otherwise a
        list with ``batch_size`` prompt strings.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size is None:
        return "lower newer <|image|>"

    if batch_size < 1:
        raise ValueError("batch_size must be greater than 0")

    if batch_size == 1:
        return ["lower newer <|image|>"]
    # The first two prompts differ in length; any remaining slots are filled
    # with a repeated third prompt.
    return ["lower newer <|image|>", "<|image|> upper older longer string"] + ["<|image|> lower newer"] * (
        batch_size - 2
    )
0 commit comments