Skip to content

Commit 0decb11

Browse files
committed
Make kwargs uniform for VisionTextDualEncoder
1 parent 203e270 commit 0decb11

File tree

1 file changed

+28
-6
lines changed

1 file changed

+28
-6
lines changed

src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py

+28 -6
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,15 @@
1717
"""
1818

1919
import warnings
20+
from typing import List, Optional, Union
2021

21-
from ...processing_utils import ProcessorMixin
22-
from ...tokenization_utils_base import BatchEncoding
22+
from ...image_utils import ImageInput
23+
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
24+
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
25+
26+
27+
class VisionTextDualEncoderProcessorKwargs(ProcessingKwargs, total=False):
28+
_defaults = {}
2329

2430

2531
class VisionTextDualEncoderProcessor(ProcessorMixin):
@@ -61,7 +67,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
6167
super().__init__(image_processor, tokenizer)
6268
self.current_processor = self.image_processor
6369

64-
def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
70+
def __call__(
71+
self,
72+
images: Optional[ImageInput] = None,
73+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
74+
audio=None,
75+
videos=None,
76+
**kwargs: Unpack[VisionTextDualEncoderProcessorKwargs],
77+
) -> BatchEncoding:
6578
"""
6679
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
6780
and `kwargs` arguments to VisionTextDualEncoderTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not
@@ -99,19 +112,28 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
99112
if text is None and images is None:
100113
raise ValueError("You have to specify either text or images. Both cannot be none.")
101114

115+
output_kwargs = self._merge_kwargs(
116+
VisionTextDualEncoderProcessorKwargs,
117+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
118+
**kwargs,
119+
)
120+
102121
if text is not None:
103-
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
122+
encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
104123

105124
if images is not None:
106-
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
125+
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
107126

108127
if text is not None and images is not None:
109128
encoding["pixel_values"] = image_features.pixel_values
110129
return encoding
111130
elif text is not None:
112131
return encoding
113132
else:
114-
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
133+
return BatchEncoding(
134+
data=dict(**image_features),
135+
tensor_type=output_kwargs["common_kwargs"].get("return_tensors"),
136+
)
115137

116138
def batch_decode(self, *args, **kwargs):
117139
"""

0 commit comments

Comments
 (0)