Skip to content

Commit e10be82

Browse files
authored
uniformize kwargs for SAM (#34578)
* Make kwargs uniform for SAM * Remove unused attribute * Make point_pad_value part of image_kwargs * Update annotations * Code review - use existing methods * Use ProcessorTesterMixin * Do not add ProcessorTesterMixin everywhere
1 parent 2bb6098 commit e10be82

File tree

2 files changed

+81
-29
lines changed

2 files changed

+81
-29
lines changed

src/transformers/models/sam/processing_sam.py

+60-20
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
"""
1818

1919
from copy import deepcopy
20-
from typing import Optional, Union
20+
from typing import List, Optional, Union
2121

2222
import numpy as np
2323

24-
from ...processing_utils import ProcessorMixin
25-
from ...tokenization_utils_base import BatchEncoding
26-
from ...utils import TensorType, is_tf_available, is_torch_available
24+
from ...image_utils import ImageInput, VideoInput
25+
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
26+
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
27+
from ...utils import is_tf_available, is_torch_available
2728

2829

2930
if is_torch_available():
@@ -33,6 +34,23 @@
3334
import tensorflow as tf
3435

3536

37+
class SamImagesKwargs(ImagesKwargs):
38+
segmentation_maps: Optional[ImageInput]
39+
input_points: Optional[List[List[float]]]
40+
input_labels: Optional[List[List[int]]]
41+
input_boxes: Optional[List[List[List[float]]]]
42+
point_pad_value: Optional[int]
43+
44+
45+
class SamProcessorKwargs(ProcessingKwargs, total=False):
46+
images_kwargs: SamImagesKwargs
47+
_defaults = {
48+
"images_kwargs": {
49+
"point_pad_value": -10,
50+
}
51+
}
52+
53+
3654
class SamProcessor(ProcessorMixin):
3755
r"""
3856
Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
@@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin):
4866

4967
attributes = ["image_processor"]
5068
image_processor_class = "SamImageProcessor"
69+
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
70+
optional_call_args = [
71+
"segmentation_maps",
72+
"input_points",
73+
"input_labels",
74+
"input_boxes",
75+
]
5176

5277
def __init__(self, image_processor):
5378
super().__init__(image_processor)
54-
self.current_processor = self.image_processor
55-
self.point_pad_value = -10
5679
self.target_size = self.image_processor.size["longest_edge"]
5780

5881
def __call__(
5982
self,
60-
images=None,
61-
segmentation_maps=None,
62-
input_points=None,
63-
input_labels=None,
64-
input_boxes=None,
65-
return_tensors: Optional[Union[str, TensorType]] = None,
83+
images: Optional[ImageInput] = None,
84+
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
85+
# arguments that may be passed as a positional argument.
86+
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
87+
# or this conversation for more context:
88+
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
89+
# This behavior is only needed for backward compatibility and will be removed in future versions.
90+
*args, # to be deprecated
91+
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
92+
audio: Optional[AudioInput] = None,
93+
video: Optional[VideoInput] = None,
6694
**kwargs,
6795
) -> BatchEncoding:
6896
"""
6997
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
7098
points and bounding boxes for the model if they are provided.
7199
"""
100+
output_kwargs = self._merge_kwargs(
101+
SamProcessorKwargs,
102+
tokenizer_init_kwargs={},
103+
**kwargs,
104+
**self.prepare_and_validate_optional_call_args(*args),
105+
)
106+
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
107+
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
108+
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)
109+
72110
encoding_image_processor = self.image_processor(
73111
images,
74-
segmentation_maps=segmentation_maps,
75-
return_tensors=return_tensors,
76-
**kwargs,
112+
**output_kwargs["images_kwargs"],
77113
)
78114

79115
# pop arguments that are not used in the foward but used nevertheless
@@ -94,7 +130,8 @@ def __call__(
94130
input_points=input_points,
95131
input_labels=input_labels,
96132
input_boxes=input_boxes,
97-
return_tensors=return_tensors,
133+
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
134+
point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"),
98135
)
99136

100137
return encoding_image_processor
@@ -107,6 +144,7 @@ def _normalize_and_convert(
107144
input_labels=None,
108145
input_boxes=None,
109146
return_tensors="pt",
147+
point_pad_value=-10,
110148
):
111149
if input_points is not None:
112150
if len(original_sizes) != len(input_points):
@@ -121,7 +159,9 @@ def _normalize_and_convert(
121159
# check that all arrays have the same shape
122160
if not all(point.shape == input_points[0].shape for point in input_points):
123161
if input_labels is not None:
124-
input_points, input_labels = self._pad_points_and_labels(input_points, input_labels)
162+
input_points, input_labels = self._pad_points_and_labels(
163+
input_points, input_labels, point_pad_value
164+
)
125165

126166
input_points = np.array(input_points)
127167

@@ -174,7 +214,7 @@ def _normalize_and_convert(
174214

175215
return encoding_image_processor
176216

177-
def _pad_points_and_labels(self, input_points, input_labels):
217+
def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
178218
r"""
179219
The method pads the 2D points and labels to the maximum number of points in the batch.
180220
"""
@@ -183,9 +223,9 @@ def _pad_points_and_labels(self, input_points, input_labels):
183223
for i, point in enumerate(input_points):
184224
if point.shape[0] != expected_nb_points:
185225
point = np.concatenate(
186-
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
226+
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
187227
)
188-
input_labels[i] = np.append(input_labels[i], [self.point_pad_value])
228+
input_labels[i] = np.append(input_labels[i], [point_pad_value])
189229
processed_input_points.append(point)
190230
input_points = processed_input_points
191231
return input_points, input_labels

tests/models/sam/test_processor_sam.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from transformers.utils import is_tf_available, is_torch_available, is_vision_available
2828

29-
from ...test_processing_common import prepare_image_inputs
29+
from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs
3030

3131

3232
if is_vision_available():
@@ -43,7 +43,9 @@
4343

4444
@require_vision
4545
@require_torchvision
46-
class SamProcessorTest(unittest.TestCase):
46+
class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase):
47+
processor_class = SamProcessor
48+
4749
def setUp(self):
4850
self.tmpdirname = tempfile.mkdtemp()
4951
image_processor = SamImageProcessor()
@@ -56,11 +58,6 @@ def get_image_processor(self, **kwargs):
5658
def tearDown(self):
5759
shutil.rmtree(self.tmpdirname)
5860

59-
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
60-
def prepare_image_inputs(self):
61-
"""This function prepares a list of PIL images."""
62-
return prepare_image_inputs()
63-
6461
def prepare_mask_inputs(self):
6562
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
6663
or a list of PyTorch tensors if one specifies torchify=True.
@@ -69,6 +66,21 @@ def prepare_mask_inputs(self):
6966
mask_inputs = [Image.fromarray(x) for x in mask_inputs]
7067
return mask_inputs
7168

69+
def test_chat_template_save_loading(self):
70+
self.skipTest("SamProcessor does not have a tokenizer")
71+
72+
def test_image_processor_defaults_preserved_by_image_kwargs(self):
73+
self.skipTest("SamProcessor does not have a tokenizer")
74+
75+
def test_kwargs_overrides_default_image_processor_kwargs(self):
76+
self.skipTest("SamProcessor does not have a tokenizer")
77+
78+
def test_kwargs_overrides_default_tokenizer_kwargs(self):
79+
self.skipTest("SamProcessor does not have a tokenizer")
80+
81+
def test_tokenizer_defaults_preserved_by_kwargs(self):
82+
self.skipTest("SamProcessor does not have a tokenizer")
83+
7284
def test_save_load_pretrained_additional_features(self):
7385
processor = SamProcessor(image_processor=self.get_image_processor())
7486
processor.save_pretrained(self.tmpdirname)
@@ -165,7 +177,7 @@ def get_image_processor(self, **kwargs):
165177
def tearDown(self):
166178
shutil.rmtree(self.tmpdirname)
167179

168-
# Processor tester class can't use ProcessorTesterMixin as processor is atypical e.g. only contains an image processor and it assumes torch
180+
# This is to avoid repeating the skipping of the common tests
169181
def prepare_image_inputs(self):
170182
"""This function prepares a list of PIL images."""
171183
return prepare_image_inputs()
@@ -248,7 +260,7 @@ def get_image_processor(self, **kwargs):
248260
def tearDown(self):
249261
shutil.rmtree(self.tmpdirname)
250262

251-
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
263+
# This is to avoid repeating the skipping of the common tests
252264
def prepare_image_inputs(self):
253265
"""This function prepares a list of PIL images."""
254266
return prepare_image_inputs()

0 commit comments

Comments
 (0)