17 | 17 | """
18 | 18 |
19 | 19 | import warnings
   | 20 | +from typing import List, Optional, Union
20 | 21 |
21 |    | -from ...processing_utils import ProcessorMixin
22 |    | -from ...tokenization_utils_base import BatchEncoding
   | 22 | +from ...image_utils import ImageInput
   | 23 | +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
   | 24 | +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
   | 25 | +
   | 26 | +
   | 27 | +class VisionTextDualEncoderProcessorKwargs(ProcessingKwargs, total=False):
   | 28 | +    _defaults = {}
23 | 29 |
24 | 30 |
25 | 31 | class VisionTextDualEncoderProcessor(ProcessorMixin):
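Note on the new kwargs class: `VisionTextDualEncoderProcessorKwargs` follows the uniform `ProcessingKwargs` pattern used by other processors, where a `_defaults` mapping supplies per-modality defaults that `_merge_kwargs` applies when the caller does not override them. Here `_defaults` stays empty, so behaviour is driven only by the tokenizer's init kwargs and call-time kwargs. For context, a minimal sketch of what a non-empty `_defaults` typically looks like in this pattern (class name and values are illustrative, not part of this change; assumes a transformers version that ships `ProcessingKwargs`):

```python
from transformers.processing_utils import ProcessingKwargs


# Illustrative only: a ProcessingKwargs subclass declaring per-modality defaults.
# This PR deliberately keeps _defaults empty for the dual-encoder processor.
class ExampleProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {"padding": False},
        "images_kwargs": {},
        "common_kwargs": {"return_tensors": "pt"},
    }
```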
@@ -61,7 +67,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
61 | 67 |         super().__init__(image_processor, tokenizer)
62 | 68 |         self.current_processor = self.image_processor
63 | 69 |
64 |    | -    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
   | 70 | +    def __call__(
   | 71 | +        self,
   | 72 | +        images: Optional[ImageInput] = None,
   | 73 | +        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
   | 74 | +        audio=None,
   | 75 | +        videos=None,
   | 76 | +        **kwargs: Unpack[VisionTextDualEncoderProcessorKwargs],
   | 77 | +    ) -> BatchEncoding:
65 | 78 |         """
66 | 79 |         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
67 | 80 |         and `kwargs` arguments to VisionTextDualEncoderTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not
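Signature note: the first positional parameter changes from `text` to `images`, matching the uniform processor signature, so callers passing arguments positionally should switch to keywords; `return_tensors` now travels through `**kwargs` instead of a dedicated parameter. A minimal usage sketch under the new signature (checkpoint names are illustrative, not part of this change):

```python
from PIL import Image

from transformers import AutoImageProcessor, AutoTokenizer, VisionTextDualEncoderProcessor

# Illustrative components; any image processor / tokenizer pair works here.
processor = VisionTextDualEncoderProcessor(
    image_processor=AutoImageProcessor.from_pretrained("google/vit-base-patch16-224"),
    tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"),
)

image = Image.new("RGB", (224, 224))  # placeholder image

# Keyword arguments avoid ambiguity now that `images` comes before `text`;
# `padding` is routed to the tokenizer, `return_tensors` is shared by both backends.
batch = processor(images=image, text="a photo of a cat", padding=True, return_tensors="pt")
print(sorted(batch.keys()))  # expect input_ids, attention_mask, pixel_values, plus tokenizer-specific keys
```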
@@ -99,19 +112,28 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
99  | 112 |         if text is None and images is None:
100 | 113 |             raise ValueError("You have to specify either text or images. Both cannot be none.")
101 | 114 |
    | 115 | +        output_kwargs = self._merge_kwargs(
    | 116 | +            VisionTextDualEncoderProcessorKwargs,
    | 117 | +            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
    | 118 | +            **kwargs,
    | 119 | +        )
    | 120 | +
102 | 121 |         if text is not None:
103 |     | -            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
    | 122 | +            encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
104 | 123 |
105 | 124 |         if images is not None:
106 |     | -            image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
    | 125 | +            image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
107 | 126 |
108 | 127 |         if text is not None and images is not None:
109 | 128 |             encoding["pixel_values"] = image_features.pixel_values
110 | 129 |             return encoding
111 | 130 |         elif text is not None:
112 | 131 |             return encoding
113 | 132 |         else:
114 |     | -            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
    | 133 | +            return BatchEncoding(
    | 134 | +                data=dict(**image_features),
    | 135 | +                tensor_type=output_kwargs["common_kwargs"].get("return_tensors"),
    | 136 | +            )
115 | 137 |
116 | 138 |     def batch_decode(self, *args, **kwargs):
117 | 139 |         """
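Routing note: `_merge_kwargs` splits call-time kwargs by modality, so the tokenizer receives `text_kwargs`, the image processor receives `images_kwargs`, and shared options such as `return_tensors` land in `common_kwargs`, which is what the images-only branch reads when building the final `BatchEncoding` (as with other processors using this helper, common kwargs are also propagated into the per-modality dicts, which is how the tokenizer still sees `return_tensors`). A small sketch of that images-only path (same illustrative checkpoints as above):

```python
from PIL import Image

from transformers import AutoImageProcessor, AutoTokenizer, VisionTextDualEncoderProcessor

processor = VisionTextDualEncoderProcessor(
    image_processor=AutoImageProcessor.from_pretrained("google/vit-base-patch16-224"),
    tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"),
)
images = [Image.new("RGB", (224, 224)) for _ in range(2)]  # placeholder images

# Images-only call: exercises the final `else` branch above, which wraps the
# image features in a BatchEncoding tensorized from common_kwargs["return_tensors"].
batch = processor(images=images, return_tensors="pt")
print(batch["pixel_values"].shape)  # expect (2, 3, 224, 224) with a 224x224 ViT image processor
```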