Skip to content

Commit b184e46

Browse files
committed
uniformize kwargs for chameleon
1 parent dba8d08 commit b184e46

File tree

1 file changed

+39
-38
lines changed

1 file changed

+39
-38
lines changed

src/transformers/models/chameleon/processing_chameleon.py

+39-38
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,44 @@
1616
Processor class for Chameleon.
1717
"""
1818

19-
from typing import List, Optional, Union
19+
import sys
20+
from typing import List, Union
2021

2122
import numpy as np
2223

2324
from ...feature_extraction_utils import BatchFeature
2425
from ...image_utils import ImageInput
25-
from ...processing_utils import ProcessorMixin
26-
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
26+
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
27+
from ...tokenization_utils_base import PreTokenizedInput, TextInput
2728
from ...utils import TensorType, is_vision_available
2829

30+
if sys.version_info >= (3, 11):
31+
from typing import Unpack
32+
else:
33+
from typing_extensions import Unpack
2934

3035
if is_vision_available():
3136
import PIL
3237

3338

39+
class ChameleonTextKwargs(TextKwargs, total=False):
40+
return_for_text_completion: bool
41+
42+
43+
class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
44+
text_kwargs: ChameleonTextKwargs
45+
_defaults = {
46+
"text_kwargs": {
47+
"padding": False,
48+
"stride": 0,
49+
"return_for_text_completion": False,
50+
},
51+
"common_kwargs": {
52+
"return_tensors": TensorType.PYTORCH,
53+
},
54+
}
55+
56+
3457
class ChameleonProcessor(ProcessorMixin):
3558
r"""
3659
Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single
@@ -65,11 +88,7 @@ def __call__(
6588
self,
6689
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
6790
images: ImageInput = None,
68-
padding: Union[bool, str, PaddingStrategy] = False,
69-
truncation: Union[bool, str, TruncationStrategy] = None,
70-
max_length: int = None,
71-
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
72-
return_for_text_completion: bool = False,
91+
**kwargs: Unpack[ChameleonProcessorKwargs],
7392
) -> BatchFeature:
7493
"""
7594
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -86,26 +105,6 @@ def __call__(
86105
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
87106
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
88107
tensor. Both channels-first and channels-last formats are supported.
89-
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
90-
Select a strategy to pad the returned sequences (according to the model's padding side and padding
91-
index) among:
92-
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
93-
sequence if provided).
94-
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
95-
acceptable input length for the model if that argument is not provided.
96-
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
97-
lengths).
98-
max_length (`int`, *optional*):
99-
Maximum length of the returned list and optionally padding length (see above).
100-
truncation (`bool`, *optional*):
101-
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
102-
return_tensors (`str` or [`~utils.TensorType`], *optional*):
103-
If set, will return tensors of a particular framework. Acceptable values are:
104-
105-
- `'tf'`: Return TensorFlow `tf.constant` objects.
106-
- `'pt'`: Return PyTorch `torch.Tensor` objects.
107-
- `'np'`: Return NumPy `np.ndarray` objects.
108-
- `'jax'`: Return JAX `jnp.ndarray` objects.
109108
110109
Returns:
111110
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
@@ -120,6 +119,15 @@ def __call__(
120119
text = [text]
121120
elif not isinstance(text, list) and not isinstance(text[0], str):
122121
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
122+
if text is None and images is None:
123+
raise ValueError("You must provide either text or images")
124+
125+
output_kwargs = self._merge_kwargs(
126+
ChameleonProcessorKwargs,
127+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
128+
**kwargs,
129+
)
130+
return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)
123131

124132
# Replace the image token with the expanded image token sequence
125133
prompt_strings = []
@@ -130,19 +138,12 @@ def __call__(
130138
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
131139
prompt_strings.append(sample)
132140

133-
data = self.tokenizer(
134-
prompt_strings,
135-
return_tensors=return_tensors,
136-
padding=padding,
137-
truncation=truncation,
138-
max_length=max_length,
139-
)
141+
data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
140142

141143
if images is not None:
142-
pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
143-
data["pixel_values"] = pixel_values
144+
data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
144145

145-
return BatchFeature(data=data, tensor_type=return_tensors)
146+
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])
146147

147148
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
148149
def batch_decode(self, *args, **kwargs):

0 commit comments

Comments
 (0)