Skip to content

Commit 4ff7eb2

Browse files
tibor-reissBernardZach
authored andcommitted
🚨🚨🚨 Uniformize kwargs for TrOCR Processor (huggingface#34587)
* Make kwargs uniform for TrOCR * Add tests * Put back current_processor * Remove args * Add todo comment * Code review - breaking change
1 parent 5a6f716 commit 4ff7eb2

File tree

2 files changed

+155
-11
lines changed

2 files changed

+155
-11
lines changed

src/transformers/models/trocr/processing_trocr.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,16 @@
1818

1919
import warnings
2020
from contextlib import contextmanager
21+
from typing import List, Union
2122

22-
from ...processing_utils import ProcessorMixin
23+
from ...image_processing_utils import BatchFeature
24+
from ...image_utils import ImageInput
25+
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
26+
from ...tokenization_utils_base import PreTokenizedInput, TextInput
27+
28+
29+
class TrOCRProcessorKwargs(ProcessingKwargs, total=False):
30+
_defaults = {}
2331

2432

2533
class TrOCRProcessor(ProcessorMixin):
@@ -61,7 +69,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
6169
self.current_processor = self.image_processor
6270
self._in_target_context_manager = False
6371

64-
def __call__(self, *args, **kwargs):
72+
def __call__(
73+
self,
74+
images: ImageInput = None,
75+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
76+
audio=None,
77+
videos=None,
78+
**kwargs: Unpack[TrOCRProcessorKwargs],
79+
) -> BatchFeature:
6580
"""
6681
When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
6782
[`~AutoImageProcessor.__call__`] and returns its output. If used in the context
@@ -70,21 +85,21 @@ def __call__(self, *args, **kwargs):
7085
"""
7186
# For backward compatibility
7287
if self._in_target_context_manager:
73-
return self.current_processor(*args, **kwargs)
74-
75-
images = kwargs.pop("images", None)
76-
text = kwargs.pop("text", None)
77-
if len(args) > 0:
78-
images = args[0]
79-
args = args[1:]
88+
return self.current_processor(images, **kwargs)
8089

8190
if images is None and text is None:
8291
raise ValueError("You need to specify either an `images` or `text` input to process.")
8392

93+
output_kwargs = self._merge_kwargs(
94+
TrOCRProcessorKwargs,
95+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
96+
**kwargs,
97+
)
98+
8499
if images is not None:
85-
inputs = self.image_processor(images, *args, **kwargs)
100+
inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
86101
if text is not None:
87-
encodings = self.tokenizer(text, **kwargs)
102+
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
88103

89104
if text is None:
90105
return inputs
+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import os
2+
import shutil
3+
import tempfile
4+
import unittest
5+
6+
import pytest
7+
8+
from transformers.models.xlm_roberta.tokenization_xlm_roberta import VOCAB_FILES_NAMES
9+
from transformers.testing_utils import (
10+
require_sentencepiece,
11+
require_tokenizers,
12+
require_vision,
13+
)
14+
from transformers.utils import is_vision_available
15+
16+
from ...test_processing_common import ProcessorTesterMixin
17+
18+
19+
if is_vision_available():
20+
from transformers import TrOCRProcessor, ViTImageProcessor, XLMRobertaTokenizerFast
21+
22+
23+
@require_sentencepiece
24+
@require_tokenizers
25+
@require_vision
26+
class TrOCRProcessorTest(ProcessorTesterMixin, unittest.TestCase):
27+
text_input_name = "labels"
28+
processor_class = TrOCRProcessor
29+
30+
def setUp(self):
31+
self.tmpdirname = tempfile.mkdtemp()
32+
33+
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"] # fmt: skip
34+
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
35+
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
36+
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
37+
38+
image_processor = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
39+
tokenizer = XLMRobertaTokenizerFast.from_pretrained("FacebookAI/xlm-roberta-base")
40+
processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)
41+
processor.save_pretrained(self.tmpdirname)
42+
43+
def tearDown(self):
44+
shutil.rmtree(self.tmpdirname)
45+
46+
def get_tokenizer(self, **kwargs):
47+
return XLMRobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
48+
49+
def get_image_processor(self, **kwargs):
50+
return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)
51+
52+
def test_save_load_pretrained_default(self):
53+
image_processor = self.get_image_processor()
54+
tokenizer = self.get_tokenizer()
55+
processor = TrOCRProcessor(image_processor=image_processor, tokenizer=tokenizer)
56+
57+
processor.save_pretrained(self.tmpdirname)
58+
processor = TrOCRProcessor.from_pretrained(self.tmpdirname)
59+
60+
self.assertIsInstance(processor.tokenizer, XLMRobertaTokenizerFast)
61+
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
62+
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
63+
self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
64+
65+
def test_save_load_pretrained_additional_features(self):
66+
processor = TrOCRProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
67+
processor.save_pretrained(self.tmpdirname)
68+
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
69+
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
70+
71+
processor = TrOCRProcessor.from_pretrained(
72+
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
73+
)
74+
75+
self.assertIsInstance(processor.tokenizer, XLMRobertaTokenizerFast)
76+
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
77+
78+
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
79+
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
80+
81+
def test_image_processor(self):
82+
image_processor = self.get_image_processor()
83+
tokenizer = self.get_tokenizer()
84+
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
85+
image_input = self.prepare_image_inputs()
86+
87+
input_feat_extract = image_processor(image_input, return_tensors="np")
88+
input_processor = processor(images=image_input, return_tensors="np")
89+
90+
for key in input_feat_extract.keys():
91+
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
92+
93+
def test_tokenizer(self):
94+
image_processor = self.get_image_processor()
95+
tokenizer = self.get_tokenizer()
96+
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
97+
input_str = "lower newer"
98+
99+
encoded_processor = processor(text=input_str)
100+
encoded_tok = tokenizer(input_str)
101+
102+
for key in encoded_tok.keys():
103+
self.assertListEqual(encoded_tok[key], encoded_processor[key])
104+
105+
def test_processor_text(self):
106+
image_processor = self.get_image_processor()
107+
tokenizer = self.get_tokenizer()
108+
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
109+
input_str = "lower newer"
110+
image_input = self.prepare_image_inputs()
111+
112+
inputs = processor(text=input_str, images=image_input)
113+
114+
self.assertListEqual(list(inputs.keys()), ["pixel_values", "labels"])
115+
116+
# test if it raises when no input is passed
117+
with pytest.raises(ValueError):
118+
processor()
119+
120+
def test_tokenizer_decode(self):
121+
image_processor = self.get_image_processor()
122+
tokenizer = self.get_tokenizer()
123+
processor = TrOCRProcessor(tokenizer=tokenizer, image_processor=image_processor)
124+
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
125+
126+
decoded_processor = processor.batch_decode(predicted_ids)
127+
decoded_tok = tokenizer.batch_decode(predicted_ids)
128+
129+
self.assertListEqual(decoded_tok, decoded_processor)

0 commit comments

Comments
 (0)