17
17
"""
18
18
19
19
from copy import deepcopy
20
- from typing import Optional , Union
20
+ from typing import List , Optional , Union
21
21
22
22
import numpy as np
23
23
24
- from ...processing_utils import ProcessorMixin
25
- from ...tokenization_utils_base import BatchEncoding
26
- from ...utils import TensorType , is_tf_available , is_torch_available
24
+ from ...image_utils import ImageInput , VideoInput
25
+ from ...processing_utils import ImagesKwargs , ProcessingKwargs , ProcessorMixin
26
+ from ...tokenization_utils_base import AudioInput , BatchEncoding , PreTokenizedInput , TextInput
27
+ from ...utils import is_tf_available , is_torch_available
27
28
28
29
29
30
if is_torch_available ():
33
34
import tensorflow as tf
34
35
35
36
37
+ class SamImagesKwargs (ImagesKwargs ):
38
+ segmentation_maps : Optional [ImageInput ]
39
+ input_points : Optional [List [List [float ]]]
40
+ input_labels : Optional [List [List [int ]]]
41
+ input_boxes : Optional [List [List [List [float ]]]]
42
+ point_pad_value : Optional [int ]
43
+
44
+
45
+ class SamProcessorKwargs (ProcessingKwargs , total = False ):
46
+ images_kwargs : SamImagesKwargs
47
+ _defaults = {
48
+ "images_kwargs" : {
49
+ "point_pad_value" : - 10 ,
50
+ }
51
+ }
52
+
53
+
36
54
class SamProcessor (ProcessorMixin ):
37
55
r"""
38
56
Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
@@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin):
48
66
49
67
attributes = ["image_processor" ]
50
68
image_processor_class = "SamImageProcessor"
69
+ # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
70
+ optional_call_args = [
71
+ "segmentation_maps" ,
72
+ "input_points" ,
73
+ "input_labels" ,
74
+ "input_boxes" ,
75
+ ]
51
76
52
77
def __init__ (self , image_processor ):
53
78
super ().__init__ (image_processor )
54
- self .current_processor = self .image_processor
55
- self .point_pad_value = - 10
56
79
self .target_size = self .image_processor .size ["longest_edge" ]
57
80
58
81
def __call__ (
59
82
self ,
60
- images = None ,
61
- segmentation_maps = None ,
62
- input_points = None ,
63
- input_labels = None ,
64
- input_boxes = None ,
65
- return_tensors : Optional [Union [str , TensorType ]] = None ,
83
+ images : Optional [ImageInput ] = None ,
84
+ # The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
85
+ # arguments that may be passed as a positional argument.
86
+ # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
87
+ # or this conversation for more context:
88
+ # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
89
+ # This behavior is only needed for backward compatibility and will be removed in future versions.
90
+ * args , # to be deprecated
91
+ text : Optional [Union [TextInput , PreTokenizedInput , List [TextInput ], List [PreTokenizedInput ]]] = None ,
92
+ audio : Optional [AudioInput ] = None ,
93
+ video : Optional [VideoInput ] = None ,
66
94
** kwargs ,
67
95
) -> BatchEncoding :
68
96
"""
69
97
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
70
98
points and bounding boxes for the model if they are provided.
71
99
"""
100
+ output_kwargs = self ._merge_kwargs (
101
+ SamProcessorKwargs ,
102
+ tokenizer_init_kwargs = {},
103
+ ** kwargs ,
104
+ ** self .prepare_and_validate_optional_call_args (* args ),
105
+ )
106
+ input_points = output_kwargs ["images_kwargs" ].pop ("input_points" , None )
107
+ input_labels = output_kwargs ["images_kwargs" ].pop ("input_labels" , None )
108
+ input_boxes = output_kwargs ["images_kwargs" ].pop ("input_boxes" , None )
109
+
72
110
encoding_image_processor = self .image_processor (
73
111
images ,
74
- segmentation_maps = segmentation_maps ,
75
- return_tensors = return_tensors ,
76
- ** kwargs ,
112
+ ** output_kwargs ["images_kwargs" ],
77
113
)
78
114
79
115
# pop arguments that are not used in the foward but used nevertheless
@@ -94,7 +130,8 @@ def __call__(
94
130
input_points = input_points ,
95
131
input_labels = input_labels ,
96
132
input_boxes = input_boxes ,
97
- return_tensors = return_tensors ,
133
+ return_tensors = output_kwargs ["common_kwargs" ].get ("return_tensors" ),
134
+ point_pad_value = output_kwargs ["images_kwargs" ].get ("point_pad_value" ),
98
135
)
99
136
100
137
return encoding_image_processor
@@ -107,6 +144,7 @@ def _normalize_and_convert(
107
144
input_labels = None ,
108
145
input_boxes = None ,
109
146
return_tensors = "pt" ,
147
+ point_pad_value = - 10 ,
110
148
):
111
149
if input_points is not None :
112
150
if len (original_sizes ) != len (input_points ):
@@ -121,7 +159,9 @@ def _normalize_and_convert(
121
159
# check that all arrays have the same shape
122
160
if not all (point .shape == input_points [0 ].shape for point in input_points ):
123
161
if input_labels is not None :
124
- input_points , input_labels = self ._pad_points_and_labels (input_points , input_labels )
162
+ input_points , input_labels = self ._pad_points_and_labels (
163
+ input_points , input_labels , point_pad_value
164
+ )
125
165
126
166
input_points = np .array (input_points )
127
167
@@ -174,7 +214,7 @@ def _normalize_and_convert(
174
214
175
215
return encoding_image_processor
176
216
177
- def _pad_points_and_labels (self , input_points , input_labels ):
217
+ def _pad_points_and_labels (self , input_points , input_labels , point_pad_value ):
178
218
r"""
179
219
The method pads the 2D points and labels to the maximum number of points in the batch.
180
220
"""
@@ -183,9 +223,9 @@ def _pad_points_and_labels(self, input_points, input_labels):
183
223
for i , point in enumerate (input_points ):
184
224
if point .shape [0 ] != expected_nb_points :
185
225
point = np .concatenate (
186
- [point , np .zeros ((expected_nb_points - point .shape [0 ], 2 )) + self . point_pad_value ], axis = 0
226
+ [point , np .zeros ((expected_nb_points - point .shape [0 ], 2 )) + point_pad_value ], axis = 0
187
227
)
188
- input_labels [i ] = np .append (input_labels [i ], [self . point_pad_value ])
228
+ input_labels [i ] = np .append (input_labels [i ], [point_pad_value ])
189
229
processed_input_points .append (point )
190
230
input_points = processed_input_points
191
231
return input_points , input_labels
0 commit comments