[Fast image processors] Improve handling of image-like inputs other than images (segmentation_maps) #39489

Merged
44 changes: 35 additions & 9 deletions src/transformers/image_processing_utils_fast.py
@@ -453,6 +453,7 @@ def filter_out_unused_kwargs(self, kwargs: dict):
def _prepare_images_structure(
self,
images: ImageInput,
expected_ndims: int = 3,
) -> ImageInput:
"""
Prepare the images structure for processing.
@@ -464,7 +465,7 @@ def _prepare_images_structure(
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
return make_flat_list_of_images(images, expected_ndims=expected_ndims)

def _process_image(
self,
@@ -486,6 +487,10 @@ def _process_image(
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
image = torch.from_numpy(image).contiguous()

# If the image is 2D, we need to unsqueeze it to add a channel dimension for processing
if image.ndim == 2:
image = image.unsqueeze(0)

# Infer the channel dimension format if not provided
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
@@ -500,32 +505,35 @@

return image

def _prepare_input_images(
def _prepare_image_like_inputs(
self,
images: ImageInput,
do_convert_rgb: Optional[bool] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
expected_ndims: int = 3,
) -> list["torch.Tensor"]:
"""
Prepare the input images for processing.
Prepare image-like inputs for processing.

Args:
images (`ImageInput`):
The input images to process.
The image-like inputs to process.
do_convert_rgb (`bool`, *optional*):
Whether to convert the images to RGB.
input_data_format (`str` or `ChannelDimension`, *optional*):
The input data format of the images.
device (`torch.device`, *optional*):
The device to put the processed images on.
expected_ndims (`int`, *optional*, defaults to 3):
The expected number of dimensions for the images (can be 2 for segmentation maps, etc.).

Returns:
List[`torch.Tensor`]: The processed images.
"""

# Get structured images (potentially nested)
images = self._prepare_images_structure(images)
images = self._prepare_images_structure(images, expected_ndims=expected_ndims)

process_image_partial = partial(
self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
@@ -627,10 +635,6 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImag
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)

# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
@@ -652,6 +656,28 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImag
kwargs.pop("default_to_square")
kwargs.pop("data_format")

return self._preprocess_image_like_inputs(
images, *args, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device, **kwargs
)

def _preprocess_image_like_inputs(
self,
images: ImageInput,
*args,
do_convert_rgb: bool,
input_data_format: ChannelDimension,
device: Optional[Union[str, "torch.device"]] = None,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> BatchFeature:
"""
Preprocess image-like inputs.
To be overridden by subclasses when image-like inputs other than images should be processed.
It can be used for segmentation maps, depth maps, etc.
"""
# Prepare input images
images = self._prepare_image_like_inputs(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
return self._preprocess(images, *args, **kwargs)

def _preprocess(
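Taken together, the base class now routes preprocess through an overridable _preprocess_image_like_inputs hook. Below is a minimal sketch (not part of this PR) of how a subclass could reuse that hook for a second image-like input; the class name and the depth_maps input are hypothetical, and the calls assume the signatures introduced in this diff:

# Hypothetical subclass (illustration only): reuse the new hook for depth maps.
from typing import Optional, Union

from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import ChannelDimension, ImageInput


class DepthAwareImageProcessorFast(BaseImageProcessorFast):
    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        depth_maps: Optional[ImageInput] = None,
        *args,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Optional[Union[str, "torch.device"]] = None,
        **kwargs,
    ):
        # Regular images: default expected_ndims=3, optional RGB conversion.
        images = self._prepare_image_like_inputs(
            images=images,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )
        batch_feature = self._preprocess(images, *args, **kwargs)

        if depth_maps is not None:
            # Single-channel maps: expected_ndims=2, so _process_image
            # unsqueezes them to (1, H, W) before the shared pipeline runs.
            depth_maps = self._prepare_image_like_inputs(
                images=depth_maps,
                expected_ndims=2,
                do_convert_rgb=False,
                input_data_format=ChannelDimension.FIRST,
                device=device,
            )
            depth_kwargs = {**kwargs, "do_normalize": False, "do_rescale": False}
            batch_feature["depth_values"] = self._preprocess(
                depth_maps, **depth_kwargs
            ).pixel_values.squeeze(1)

        return batch_feature

A matching preprocess override that accepts the extra input and forwards it positionally, as the BEiT changes below do for segmentation_maps, completes the pattern.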
22 changes: 14 additions & 8 deletions src/transformers/image_utils.py
@@ -213,13 +213,16 @@ def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:

def make_flat_list_of_images(
images: Union[list[ImageInput], ImageInput],
expected_ndims: int = 3,
) -> ImageInput:
"""
Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1.
If the input is a nested list of images, it is converted to a flat list of images.
Args:
images (`Union[list[ImageInput], ImageInput]`):
The input image.
expected_ndims (`int`, *optional*, defaults to 3):
The expected number of dimensions for a single input image.
Returns:
list: A list of images or a 4d array of images.
"""
@@ -232,28 +235,31 @@ def make_flat_list_of_images(
return [img for img_list in images for img in img_list]

if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
if is_pil_image(images[0]) or images[0].ndim == 3:
if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
return images
if images[0].ndim == 4:
if images[0].ndim == expected_ndims + 1:
return [img for img_list in images for img in img_list]

if is_valid_image(images):
if is_pil_image(images) or images.ndim == 3:
if is_pil_image(images) or images.ndim == expected_ndims:
return [images]
if images.ndim == 4:
if images.ndim == expected_ndims + 1:
return list(images)

raise ValueError(f"Could not make a flat list of images from {images}")


def make_nested_list_of_images(
images: Union[list[ImageInput], ImageInput],
expected_ndims: int = 3,
) -> ImageInput:
"""
Ensure that the output is a nested list of images.
Args:
images (`Union[list[ImageInput], ImageInput]`):
The input image.
expected_ndims (`int`, *optional*, defaults to 3):
The expected number of dimensions for a single input image.
Returns:
list: A list of list of images or a list of 4d array of images.
"""
@@ -267,16 +273,16 @@ def make_nested_list_of_images(

# If it's a list of images, it's a single batch, so convert it to a list of lists
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
if is_pil_image(images[0]) or images[0].ndim == 3:
if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
return [images]
if images[0].ndim == 4:
if images[0].ndim == expected_ndims + 1:
return [list(image) for image in images]

# If it's a single image, convert it to a list of lists
if is_valid_image(images):
if is_pil_image(images) or images.ndim == 3:
if is_pil_image(images) or images.ndim == expected_ndims:
return [[images]]
if images.ndim == 4:
if images.ndim == expected_ndims + 1:
return [list(images)]

raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
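As a quick illustration of the new expected_ndims argument in these helpers (a sketch assuming this PR's version of transformers.image_utils; shapes are arbitrary):

import numpy as np

from transformers.image_utils import make_flat_list_of_images

# Default expected_ndims=3: a single (H, W, C) image yields a list of length 1,
# and a (B, H, W, C) stack is unpacked into B images.
image = np.zeros((32, 32, 3), dtype=np.uint8)
assert len(make_flat_list_of_images(image)) == 1
stack = np.zeros((4, 32, 32, 3), dtype=np.uint8)
assert len(make_flat_list_of_images(stack)) == 4

# expected_ndims=2: a single (H, W) segmentation map counts as one "image",
# and a (B, H, W) stack is unpacked into B maps.
seg_map = np.zeros((32, 32), dtype=np.uint8)
assert len(make_flat_list_of_images(seg_map, expected_ndims=2)) == 1
seg_stack = np.zeros((4, 32, 32), dtype=np.uint8)
assert len(make_flat_list_of_images(seg_stack, expected_ndims=2)) == 4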
155 changes: 54 additions & 101 deletions src/transformers/models/beit/image_processing_beit_fast.py
@@ -34,9 +34,6 @@
PILImageResampling,
SizeDict,
is_torch_tensor,
make_list_of_images,
pil_torch_interpolation_mapping,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring
@@ -91,6 +88,59 @@ def reduce_label(self, labels: list["torch.Tensor"]):

return label

@auto_docstring
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
**kwargs: Unpack[BeitFastImageProcessorKwargs],
) -> BatchFeature:
r"""
segmentation_maps (`ImageInput`, *optional*):
The segmentation maps to preprocess.
"""
return super().preprocess(images, segmentation_maps, **kwargs)

def _preprocess_image_like_inputs(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput],
do_convert_rgb: bool,
input_data_format: ChannelDimension,
device: Optional[Union[str, "torch.device"]] = None,
**kwargs: Unpack[BeitFastImageProcessorKwargs],
) -> BatchFeature:
"""
Preprocess image-like inputs.
"""
images = self._prepare_image_like_inputs(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
images_kwargs = kwargs.copy()
images_kwargs["do_reduce_labels"] = False
batch_feature = self._preprocess(images, **images_kwargs)

if segmentation_maps is not None:
processed_segmentation_maps = self._prepare_image_like_inputs(
images=segmentation_maps,
expected_ndims=2,
do_convert_rgb=False,
input_data_format=ChannelDimension.FIRST,
)

segmentation_maps_kwargs = kwargs.copy()
segmentation_maps_kwargs["do_normalize"] = False
segmentation_maps_kwargs["do_rescale"] = False
segmentation_maps_kwargs["input_data_format"] = ChannelDimension.FIRST
processed_segmentation_maps = self._preprocess(
images=processed_segmentation_maps, **segmentation_maps_kwargs
)
processed_segmentation_maps = processed_segmentation_maps.pixel_values.squeeze(1)
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
batch_feature["labels"] = processed_segmentation_maps

return batch_feature

def _preprocess(
self,
images: list["torch.Tensor"],
@@ -136,105 +186,8 @@ def _preprocess(

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return processed_images

def _preprocess_images(
self,
images,
**kwargs,
):
"""Preprocesses images."""
kwargs["do_reduce_labels"] = False
processed_images = self._preprocess(images=images, **kwargs)
return processed_images

def _preprocess_segmentation_maps(
self,
segmentation_maps,
**kwargs,
):
"""Preprocesses segmentation maps."""
processed_segmentation_maps = []
for segmentation_map in segmentation_maps:
segmentation_map = self._process_image(
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
)

if segmentation_map.ndim == 2:
segmentation_map = segmentation_map[None, ...]

processed_segmentation_maps.append(segmentation_map)

kwargs["do_normalize"] = False
kwargs["do_rescale"] = False
kwargs["input_data_format"] = ChannelDimension.FIRST
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)

processed_segmentation_maps = processed_segmentation_maps.squeeze(1)

processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
return processed_segmentation_maps

@auto_docstring
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
**kwargs: Unpack[BeitFastImageProcessorKwargs],
) -> BatchFeature:
r"""
segmentation_maps (`ImageInput`, *optional*):
The segmentation maps to preprocess.
"""
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)

# Prepare segmentation maps
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)

# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)

# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)

# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)

# Pop kwargs that are not needed in _preprocess
kwargs.pop("default_to_square")
kwargs.pop("data_format")

images = self._preprocess_images(
images=images,
**kwargs,
)
data = {"pixel_values": images}

if segmentation_maps is not None:
segmentation_maps = self._preprocess_segmentation_maps(
segmentation_maps=segmentation_maps,
**kwargs,
)
data["labels"] = segmentation_maps

return BatchFeature(data=data)
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
"""
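End to end, the fast BEiT processor can now take raw 2D segmentation maps alongside images. A usage sketch (the checkpoint name is illustrative, and output shapes depend on the processor's configured size):

import numpy as np

from transformers import BeitImageProcessorFast

processor = BeitImageProcessorFast.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
seg_map = np.random.randint(0, 150, (480, 640), dtype=np.uint8)  # 2D, no channel dim

out = processor(image, segmentation_maps=seg_map, return_tensors="pt")
print(out.pixel_values.shape)              # e.g. torch.Size([1, 3, 640, 640])
print(out.labels.shape, out.labels.dtype)  # e.g. torch.Size([1, 640, 640]) torch.int64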