
Uniformize OwlViT and Owlv2 processors #35700

Merged
10 commits merged on Feb 13, 2025
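
This change aligns both processors with the shared processing-kwargs pattern: `__call__` now takes `images` first and `text` second, reversed image/text inputs are detected and swapped for backward compatibility, per-call options are merged through `_merge_kwargs` with `padding="max_length"` and `return_tensors="np"` as defaults, and the output is a `BatchFeature` instead of a `BatchEncoding`. The snippet below is a minimal usage sketch, not part of the diff; the checkpoint name, image URL and prompts are illustrative assumptions.

# Minimal usage sketch for the uniformized API (checkpoint/URL/prompts are assumptions).
import requests
from PIL import Image
from transformers import Owlv2Processor

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# New uniform signature: images first, then text, remaining options as typed kwargs.
inputs = processor(images=image, text=[["a photo of a cat", "a photo of a dog"]], return_tensors="pt")
print(inputs.keys())  # BatchFeature with input_ids, attention_mask, pixel_values
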
110 changes: 76 additions & 34 deletions src/transformers/models/owlv2/processing_owlv2.py
@@ -21,15 +21,40 @@

import numpy as np

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available


if TYPE_CHECKING:
from .modeling_owlv2 import Owlv2ImageGuidedObjectDetectionOutput, Owlv2ObjectDetectionOutput


class Owlv2ImagesKwargs(ImagesKwargs, total=False):
query_images: Optional[ImageInput]


class Owlv2ProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: Owlv2ImagesKwargs
_defaults = {
"text_kwargs": {
"padding": "max_length",
},
"images_kwargs": {},
"common_kwargs": {
"return_tensors": "np",
},
}


class Owlv2Processor(ProcessorMixin):
r"""
Constructs an Owlv2 processor which wraps [`Owlv2ImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into
@@ -46,12 +71,27 @@ class Owlv2Processor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "Owlv2ImageProcessor"
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = ["query_images"]

def __init__(self, image_processor, tokenizer, **kwargs):
super().__init__(image_processor, tokenizer)

# Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.__call__ with OwlViT->Owlv2
def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs):
def __call__(
self,
images: Optional[ImageInput] = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
# The following is to capture `query_images` argument that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
# or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
#
*args,
audio=None,
videos=None,
**kwargs: Unpack[Owlv2ProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
`kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
@@ -60,14 +100,14 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
of the above two methods for more information.

Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The query image to be prepared, one query image is expected per target image to be queried. Each image
can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image
@@ -78,36 +118,49 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.

Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **query_pixel_values** -- Pixel values of the query images to be fed to a model. Returned when `query_images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Owlv2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)
query_images = output_kwargs["images_kwargs"].pop("query_images", None)
return_tensors = output_kwargs["common_kwargs"]["return_tensors"]

if text is None and query_images is None and images is None:
raise ValueError(
"You have to specify at least one text or query image or image. All three cannot be none."
)
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)

data = {}
if text is not None:
if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])]

elif isinstance(text, List) and isinstance(text[0], List):
encodings = []

# Maximum number of queries across batch
max_num_queries = max([len(t) for t in text])
max_num_queries = max([len(text_single) for text_single in text])

# Pad all batch samples to max number of text queries
for t in text:
if len(t) != max_num_queries:
t = t + [" "] * (max_num_queries - len(t))
for text_single in text:
if len(text_single) != max_num_queries:
text_single = text_single + [" "] * (max_num_queries - len(text_single))

encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
encoding = self.tokenizer(text_single, **output_kwargs["text_kwargs"])
encodings.append(encoding)
else:
raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
@@ -137,30 +190,19 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
else:
raise ValueError("Target return tensor type could not be returned")

encoding = BatchEncoding()
encoding["input_ids"] = input_ids
encoding["attention_mask"] = attention_mask
data["input_ids"] = input_ids
data["attention_mask"] = attention_mask

if query_images is not None:
encoding = BatchEncoding()
query_pixel_values = self.image_processor(
query_images, return_tensors=return_tensors, **kwargs
).pixel_values
encoding["query_pixel_values"] = query_pixel_values
query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values
# Query images always override the text prompt
data = {"query_pixel_values": query_pixel_values}

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif query_images is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif text is not None or query_images is not None:
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data["pixel_values"] = image_features.pixel_values

return BatchFeature(data=data, tensor_type=return_tensors)

# Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_object_detection with OwlViT->Owlv2
def post_process_object_detection(self, *args, **kwargs):
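
The image-guided query path is folded into the same machinery in both files: `query_images` is declared on the images kwargs, popped out of `images_kwargs` before the image processor call, and always overrides the text prompt, while the old argument order keeps working through `optional_call_args` and `_validate_images_text_input_order`. Below is a hedged sketch of the two call styles, reusing `processor` and `image` from the sketch above and assuming a PIL `query_image`.

# Image-guided detection: query images take precedence over any text prompt.
inputs = processor(images=image, query_images=query_image, return_tensors="pt")
print(inputs.keys())  # query_pixel_values, pixel_values

# The old argument order (text first, images second) is still accepted:
# `_validate_images_text_input_order` detects the swapped inputs and reorders them,
# and an extra positional argument would be mapped to `query_images` via
# `prepare_and_validate_optional_call_args`. Kept only for backward compatibility.
inputs = processor(["a photo of a cat"], image, return_tensors="pt")
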
110 changes: 76 additions & 34 deletions src/transformers/models/owlvit/processing_owlvit.py
@@ -21,15 +21,40 @@

import numpy as np

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available


if TYPE_CHECKING:
from .modeling_owlvit import OwlViTImageGuidedObjectDetectionOutput, OwlViTObjectDetectionOutput


class OwlViTImagesKwargs(ImagesKwargs, total=False):
query_images: Optional[ImageInput]


class OwlViTProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: OwlViTImagesKwargs
_defaults = {
"text_kwargs": {
"padding": "max_length",
},
"images_kwargs": {},
"common_kwargs": {
"return_tensors": "np",
},
}


class OwlViTProcessor(ProcessorMixin):
r"""
Constructs an OWL-ViT processor which wraps [`OwlViTImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`]
@@ -46,6 +71,8 @@ class OwlViTProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "OwlViTImageProcessor"
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = ["query_images"]

def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
@@ -65,7 +92,20 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):

super().__init__(image_processor, tokenizer)

def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs):
def __call__(
self,
images: Optional[ImageInput] = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
# The following is to capture `query_images` argument that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
# or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
#
*args,
audio=None,
videos=None,
**kwargs: Unpack[OwlViTProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
`kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
@@ -74,14 +114,14 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
of the above two methods for more information.

Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The query image to be prepared, one query image is expected per target image to be queried. Each image
can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image
@@ -92,36 +132,49 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.

Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **query_pixel_values** -- Pixel values of the query images to be fed to a model. Returned when `query_images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
OwlViTProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)
query_images = output_kwargs["images_kwargs"].pop("query_images", None)
return_tensors = output_kwargs["common_kwargs"]["return_tensors"]

if text is None and query_images is None and images is None:
raise ValueError(
"You have to specify at least one text or query image or image. All three cannot be none."
)
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)

data = {}
if text is not None:
if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
encodings = [self.tokenizer(text, **output_kwargs["text_kwargs"])]

elif isinstance(text, List) and isinstance(text[0], List):
encodings = []

# Maximum number of queries across batch
max_num_queries = max([len(t) for t in text])
max_num_queries = max([len(text_single) for text_single in text])

# Pad all batch samples to max number of text queries
for t in text:
if len(t) != max_num_queries:
t = t + [" "] * (max_num_queries - len(t))
for text_single in text:
if len(text_single) != max_num_queries:
text_single = text_single + [" "] * (max_num_queries - len(text_single))

encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
encoding = self.tokenizer(text_single, **output_kwargs["text_kwargs"])
encodings.append(encoding)
else:
raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
@@ -151,30 +204,19 @@ def __call__(self, text=None, images=None, query_images=None, padding="max_lengt
else:
raise ValueError("Target return tensor type could not be returned")

encoding = BatchEncoding()
encoding["input_ids"] = input_ids
encoding["attention_mask"] = attention_mask
data["input_ids"] = input_ids
data["attention_mask"] = attention_mask

if query_images is not None:
encoding = BatchEncoding()
query_pixel_values = self.image_processor(
query_images, return_tensors=return_tensors, **kwargs
).pixel_values
encoding["query_pixel_values"] = query_pixel_values
query_pixel_values = self.image_processor(query_images, **output_kwargs["images_kwargs"]).pixel_values
# Query images always override the text prompt
data = {"query_pixel_values": query_pixel_values}

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif query_images is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif text is not None or query_images is not None:
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data["pixel_values"] = image_features.pixel_values

return BatchFeature(data=data, tensor_type=return_tensors)

def post_process(self, *args, **kwargs):
"""
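
Because the result is now a `BatchFeature`, downstream code that simply unpacks the processor output into the model continues to work, and the processor still exposes post-processing helpers such as `post_process_object_detection`. A hedged end-to-end sketch, not part of the diff, assuming `Owlv2ForObjectDetection` with the same checkpoint and reusing `processor` and `image` from above:

import torch
from transformers import Owlv2ForObjectDetection

model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
inputs = processor(images=image, text=[["a photo of a cat", "a photo of a dog"]], return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# target_sizes is (height, width) per image; the processor forwards
# post_process_object_detection to its image processor.
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
print(results[0]["boxes"].shape, results[0]["scores"].shape, results[0]["labels"].shape)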