
Commit 1456120

Uniformize kwargs for Udop processor and update docs (#33628)

* Add optional kwargs and uniformize udop
* cleanup Unpack
* nit Udop

1 parent be9cf07 commit 1456120

File tree: 3 files changed, +110 -88 lines changed


src/transformers/models/udop/modeling_udop.py (1 addition, 1 deletion)

@@ -1790,7 +1790,7 @@ def forward(
         >>> # one can use the various task prefixes (prompts) used during pre-training
         >>> # e.g. the task prefix for DocVQA is "Question answering. "
         >>> question = "Question answering. What is the date on the form?"
-        >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
+        >>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")
 
         >>> # autoregressive generation
         >>> predicted_ids = model.generate(**encoding)
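In full, the updated docstring convention corresponds to a call like the following; a minimal sketch, assuming `apply_ocr=False` so that hypothetical OCR words and boxes can be supplied by hand:

    from PIL import Image
    from transformers import UdopForConditionalGeneration, UdopProcessor

    # apply_ocr=False because the words and boxes are supplied manually below
    processor = UdopProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
    model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large")

    image = Image.open("document.png").convert("RGB")  # hypothetical input document
    question = "Question answering. What is the date on the form?"
    words = ["7", "Sept", "22"]  # hypothetical OCR tokens
    boxes = [[10, 10, 80, 30], [90, 10, 160, 30], [170, 10, 220, 30]]  # hypothetical word boxes

    # `words` is now passed via the `text_pair` keyword instead of positionally
    encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")
    predicted_ids = model.generate(**encoding)
    print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])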

src/transformers/models/udop/processing_udop.py (91 additions, 68 deletions)
@@ -18,10 +18,38 @@
 
 from typing import List, Optional, Union
 
+from transformers import logging
+
+from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class UdopTextKwargs(TextKwargs, total=False):
+    word_labels: Optional[Union[List[int], List[List[int]]]]
+    boxes: Union[List[List[int]], List[List[List[int]]]]
+
+
+class UdopProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: UdopTextKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "truncation": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {},
+    }
 
 
 class UdopProcessor(ProcessorMixin):
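`UdopProcessorKwargs` follows the uniform processor-kwargs pattern: each group (`text_kwargs`, `images_kwargs`, ...) is seeded from `_defaults`, then overlaid with matching tokenizer init kwargs, then with per-call kwargs. A minimal sketch of the precedence only, not the real `_merge_kwargs` implementation:

    # Illustrative: per-call kwargs beat tokenizer init kwargs, which beat _defaults.
    defaults = dict(UdopProcessorKwargs._defaults["text_kwargs"])  # e.g. padding=False
    init_overrides = {"padding": "longest"}                        # hypothetical tokenizer init kwarg
    call_overrides = {"padding": "max_length", "max_length": 512}  # what the caller passes

    resolved = {**defaults, **init_overrides, **call_overrides}
    assert resolved["padding"] == "max_length"  # the per-call value wins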
@@ -49,6 +77,8 @@ class UdopProcessor(ProcessorMixin):
     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "LayoutLMv3ImageProcessor"
     tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast")
+    # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
+    optional_call_args = ["text_pair"]
 
     def __init__(self, image_processor, tokenizer):
         super().__init__(image_processor, tokenizer)
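Because `text_pair` is listed in `optional_call_args`, a third positional argument is still captured through `*args` and mapped back by `prepare_and_validate_optional_call_args`, so existing positional call sites keep working. A sketch of the equivalence, reusing the hypothetical `image`/`question`/`words`/`boxes` from the earlier example:

    # legacy positional form, still accepted for backward compatibility
    enc_old = processor(image, question, words, boxes=boxes, return_tensors="pt")

    # uniformized keyword form used in the updated docs
    enc_new = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")

    assert enc_old.keys() == enc_new.keys()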
@@ -57,28 +87,16 @@ def __call__(
         self,
         images: Optional[ImageInput] = None,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
-        boxes: Union[List[List[int]], List[List[List[int]]]] = None,
-        word_labels: Optional[Union[List[int], List[List[int]]]] = None,
-        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        text_pair_target: Optional[
-            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
-        ] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = False,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_token_type_ids: Optional[bool] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> BatchEncoding:
+        # The following is to capture `text_pair` argument that may be passed as a positional argument.
+        # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
+        # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
+        # This behavior is only needed for backward compatibility and will be removed in future versions.
+        #
+        *args,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[UdopProcessorKwargs],
+    ) -> BatchFeature:
         """
         This method first forwards the `images` argument to [`~UdopImageProcessor.__call__`]. In case
         [`UdopImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
@@ -93,6 +111,20 @@ def __call__(
         Please refer to the docstring of the above two methods for more information.
         """
         # verify input
+        output_kwargs = self._merge_kwargs(
+            UdopProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+            **self.prepare_and_validate_optional_call_args(*args),
+        )
+
+        boxes = output_kwargs["text_kwargs"].pop("boxes", None)
+        word_labels = output_kwargs["text_kwargs"].pop("word_labels", None)
+        text_pair = output_kwargs["text_kwargs"].pop("text_pair", None)
+        return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False)
+        return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False)
+        text_target = output_kwargs["text_kwargs"].get("text_target", None)
+
         if self.image_processor.apply_ocr and (boxes is not None):
             raise ValueError(
                 "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True."
@@ -103,69 +135,47 @@
                 "You cannot provide word labels if you initialized the image processor with apply_ocr set to True."
             )
 
-        if return_overflowing_tokens is True and return_offsets_mapping is False:
+        if return_overflowing_tokens and not return_offsets_mapping:
             raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")
 
         if text_target is not None:
             # use the processor to prepare the targets of UDOP
             return self.tokenizer(
-                text_target=text_target,
-                text_pair_target=text_pair_target,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_token_type_ids=return_token_type_ids,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
+                **output_kwargs["text_kwargs"],
             )
 
         else:
             # use the processor to prepare the inputs of UDOP
             # first, apply the image processor
-            features = self.image_processor(images=images, return_tensors=return_tensors)
+            features = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+            features_words = features.pop("words", None)
+            features_boxes = features.pop("boxes", None)
+
+            output_kwargs["text_kwargs"].pop("text_target", None)
+            output_kwargs["text_kwargs"].pop("text_pair_target", None)
+            output_kwargs["text_kwargs"]["text_pair"] = text_pair
+            output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes
+            output_kwargs["text_kwargs"]["word_labels"] = word_labels
 
             # second, apply the tokenizer
             if text is not None and self.image_processor.apply_ocr and text_pair is None:
                 if isinstance(text, str):
                     text = [text]  # add batch dimension (as the image processor always adds a batch dimension)
-                text_pair = features["words"]
+                output_kwargs["text_kwargs"]["text_pair"] = features_words
 
             encoded_inputs = self.tokenizer(
-                text=text if text is not None else features["words"],
-                text_pair=text_pair if text_pair is not None else None,
-                boxes=boxes if boxes is not None else features["boxes"],
-                word_labels=word_labels,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_token_type_ids=return_token_type_ids,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
+                text=text if text is not None else features_words,
+                **output_kwargs["text_kwargs"],
             )
 
             # add pixel values
-            pixel_values = features.pop("pixel_values")
             if return_overflowing_tokens is True:
-                pixel_values = self.get_overflowing_images(pixel_values, encoded_inputs["overflow_to_sample_mapping"])
-            encoded_inputs["pixel_values"] = pixel_values
+                features["pixel_values"] = self.get_overflowing_images(
+                    features["pixel_values"], encoded_inputs["overflow_to_sample_mapping"]
+                )
+            features.update(encoded_inputs)
 
-            return encoded_inputs
+            return features
 
     # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.get_overflowing_images
     def get_overflowing_images(self, images, overflow_to_sample_mapping):
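One behavioral detail the restructured branch preserves: with `return_overflowing_tokens=True` (which requires `return_offsets_mapping=True`, and hence a fast tokenizer), `get_overflowing_images` repeats each image's `pixel_values` once per overflow chunk so the batch dimensions stay aligned. A hedged sketch:

    # Sketch: a long hypothetical word list overflows max_length into several chunks.
    encoding = processor(
        image,
        question,
        text_pair=long_word_list,      # hypothetical long OCR word list
        boxes=long_box_list,           # hypothetical matching boxes
        truncation=True,
        max_length=64,
        stride=16,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,   # required alongside overflowing tokens
        return_tensors="pt",
    )
    # one row of input_ids per chunk, with pixel_values duplicated to match
    assert encoding["input_ids"].shape[0] == encoding["pixel_values"].shape[0]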
@@ -198,7 +208,20 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape
+                `(batch_size, sequence_length)` or `(sequence_length,)`.
+
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
+
     @property
-    # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.model_input_names
     def model_input_names(self):
-        return ["input_ids", "bbox", "attention_mask", "pixel_values"]
+        return ["pixel_values", "input_ids", "bbox", "attention_mask"]

tests/models/udop/test_processor_udop.py (18 additions, 19 deletions)
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
-import os
 import shutil
 import tempfile
 import unittest
@@ -34,7 +32,7 @@
     require_torch,
     slow,
 )
-from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available, is_torch_available
+from transformers.utils import cached_property, is_pytesseract_available, is_torch_available
 
 from ...test_processing_common import ProcessorTesterMixin

@@ -55,20 +53,19 @@
 class UdopProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     tokenizer_class = UdopTokenizer
     rust_tokenizer_class = UdopTokenizerFast
-    maxDiff = None
     processor_class = UdopProcessor
+    maxDiff = None
 
     def setUp(self):
-        image_processor_map = {
-            "do_resize": True,
-            "size": 224,
-            "apply_ocr": True,
-        }
-
         self.tmpdirname = tempfile.mkdtemp()
-        self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
-        with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
-            fp.write(json.dumps(image_processor_map) + "\n")
+        image_processor = LayoutLMv3ImageProcessor(
+            do_resize=True,
+            size=224,
+            apply_ocr=True,
+        )
+        tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large")
+        processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer)
+        processor.save_pretrained(self.tmpdirname)
 
         self.tokenizer_pretrained_name = "microsoft/udop-large"
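Building and saving a real `UdopProcessor` (instead of hand-writing the image-processor JSON) means the fixture directory now round-trips through the standard API; a sketch of what the refactored `setUp` enables, with `tmpdirname` standing in for the fixture directory:

    # Sketch: the saved fixture reloads as a complete processor.
    processor = UdopProcessor.from_pretrained(tmpdirname)
    assert isinstance(processor.image_processor, LayoutLMv3ImageProcessor)
    assert processor.image_processor.apply_ocr is True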

@@ -80,15 +77,15 @@ def setUp(self):
     def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
         return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
 
+    def get_image_processor(self, **kwargs):
+        return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs)
+
     def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
         return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
 
     def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
         return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
 
-    def get_image_processor(self, **kwargs):
-        return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs)
-
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)

@@ -153,7 +150,7 @@ def test_model_input_names(self):
         input_str = "lower newer"
         image_input = self.prepare_image_inputs()
 
-        inputs = processor(text=input_str, images=image_input)
+        inputs = processor(images=image_input, text=input_str)
 
         self.assertListEqual(list(inputs.keys()), processor.model_input_names)

@@ -472,7 +469,7 @@ def test_processor_case_5(self):
         question = "What's his name?"
         words = ["hello", "world"]
         boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
-        input_processor = processor(images[0], question, words, boxes, return_tensors="pt")
+        input_processor = processor(images[0], question, text_pair=words, boxes=boxes, return_tensors="pt")
 
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
@@ -488,7 +485,9 @@
         questions = ["How old is he?", "what's the time"]
         words = [["hello", "world"], ["my", "name", "is", "niels"]]
         boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
-        input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt")
+        input_processor = processor(
+            images, questions, text_pair=words, boxes=boxes, padding=True, return_tensors="pt"
+        )
 
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
