diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py index 7ccf5ec96cec..58b8de9edbc6 100644 --- a/.circleci/create_circleci_config.py +++ b/.circleci/create_circleci_config.py @@ -298,6 +298,14 @@ def job_name(self): exotic_models_job = CircleCIJob( "exotic_models", docker_image=[{"image":"huggingface/transformers-exotic-models"}], + tests_to_run=[ + *glob.glob("tests/models/layoutlm*/*.py", recursive=True), + *glob.glob("tests/models/layoutxlm/*.py", recursive=True), + *glob.glob("tests/models/*nat/*.py", recursive=True), + *glob.glob("tests/models/deta/*.py", recursive=True), + *glob.glob("tests/models/udop/*.py", recursive=True), + *glob.glob("tests/models/nougat/*.py", recursive=True), + ], pytest_num_workers=12, parallelism=4, pytest_options={"durations": 100}, diff --git a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py index c47d58c30c01..f67b4caef070 100644 --- a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py @@ -127,7 +127,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor): def __init__( self, do_resize: bool = True, - size: Dict[str, int] = None, + size: Optional[Dict[str, int]] = None, resample: PILImageResampling = PILImageResampling.BILINEAR, apply_ocr: bool = True, ocr_lang: Optional[str] = None, @@ -198,10 +198,10 @@ def resize( def preprocess( self, images: ImageInput, - do_resize: bool = None, - size: Dict[str, int] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, resample: PILImageResampling = None, - apply_ocr: bool = None, + apply_ocr: Optional[bool] = None, ocr_lang: Optional[str] = None, tesseract_config: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, diff --git a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py index 1edf87465bbf..6a9c851078dd 100644 --- a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py @@ -16,12 +16,51 @@ Processor class for LayoutLMv2. 
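For context on the CircleCI change at the top of this diff: `tests_to_run` is expanded eagerly with `glob` when the config is generated, so the `exotic_models` job only collects the model test directories named in the patterns. A minimal sketch of how those patterns expand when run from the repository root (the directory names in the comments are the ones the patterns are expected to match):

import glob

exotic_model_tests = [
    *glob.glob("tests/models/layoutlm*/*.py", recursive=True),   # layoutlm, layoutlmv2, layoutlmv3
    *glob.glob("tests/models/layoutxlm/*.py", recursive=True),
    *glob.glob("tests/models/*nat/*.py", recursive=True),        # nat, dinat
    *glob.glob("tests/models/deta/*.py", recursive=True),
    *glob.glob("tests/models/udop/*.py", recursive=True),
    *glob.glob("tests/models/nougat/*.py", recursive=True),
]
print(f"{len(exotic_model_tests)} test files selected for the exotic_models job")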
""" +import sys import warnings from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class LayoutLMv2TextKwargs(TextKwargs, total=False): + boxes: Optional[Union[List[List[int]], List[List[List[int]]]]] + word_labels: Optional[Union[List[int], List[List[int]]]] + + +class LayoutLMv2ImagesKwargs(ImagesKwargs, total=False): + apply_ocr: bool + ocr_lang: Optional[str] + tesseract_config: Optional[str] + + +class LayoutLMv2ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: LayoutLMv2TextKwargs + images_kwargs: LayoutLMv2ImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + "apply_ocr": True, + }, + } class LayoutLMv2Processor(ProcessorMixin): @@ -47,6 +86,7 @@ class LayoutLMv2Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "LayoutLMv2ImageProcessor" tokenizer_class = ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast") + optional_call_args = ["text_pair", "boxes", "word_labels"] def __init__(self, image_processor=None, tokenizer=None, **kwargs): feature_extractor = None @@ -68,27 +108,16 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): def __call__( self, - images, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, - boxes: Union[List[List[int]], List[List[List[int]]]] = None, - word_labels: Optional[Union[List[int], List[List[int]]]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, - ) -> BatchEncoding: + images: ImageInput, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + # The following is to capture `text_pair`, `boxes`, `word_labels` arguments that may be passed as a positional argument. + # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + # This behavior is only needed for backward compatibility and will be removed in future versions. + *args, + audio=None, + videos=None, + **kwargs: Unpack[LayoutLMv2ProcessorKwargs], + ) -> BatchFeature: """ This method first forwards the `images` argument to [`~LayoutLMv2ImageProcessor.__call__`]. 
In case [`LayoutLMv2ImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and @@ -98,59 +127,90 @@ def __call__( arguments to [`~LayoutLMv2Tokenizer.__call__`] and returns the output, together with resized `images``. Please refer to the docstring of the above two methods for more information. + + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **image** -- Pixel values to be fed to a model. + - **bbox** -- Bounding boxes of the words in the image. """ + output_kwargs = self._merge_kwargs( + LayoutLMv2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + + text_pair = output_kwargs["text_kwargs"].pop("text_pair", None) + boxes = output_kwargs["text_kwargs"].pop("boxes", None) + word_labels = output_kwargs["text_kwargs"].pop("word_labels", None) + apply_ocr = output_kwargs["images_kwargs"].get("apply_ocr", self.image_processor.apply_ocr) + # verify input - if self.image_processor.apply_ocr and (boxes is not None): + if apply_ocr and (boxes is not None): raise ValueError( "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True." ) - if self.image_processor.apply_ocr and (word_labels is not None): + if apply_ocr and (word_labels is not None): raise ValueError( "You cannot provide word labels if you initialized the image processor with apply_ocr set to True." 
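With `_merge_kwargs` in place, callers can pass either flat kwargs (routed to the right component by name) or the nested per-modality dictionaries, exactly as the updated tests below exercise. An illustrative sketch; the checkpoint name and image path are placeholders, and Tesseract is needed while `apply_ocr` is left at its default of `True`:

from PIL import Image
from transformers import LayoutLMv2Processor

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")  # placeholder checkpoint
image = Image.open("document.png").convert("RGB")                                     # placeholder image

# Flat ("unstructured") kwargs: _merge_kwargs routes padding/max_length to the
# tokenizer and return_tensors to the shared common_kwargs.
encoding = processor(images=image, padding="max_length", max_length=32, return_tensors="pt")

# The same call with explicit per-modality ("structured") kwargs.
encoding = processor(
    images=image,
    text_kwargs={"padding": "max_length", "max_length": 32},
    common_kwargs={"return_tensors": "pt"},
)
print(encoding.keys())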
) - if return_overflowing_tokens is True and return_offsets_mapping is False: + if ( + output_kwargs["text_kwargs"]["return_overflowing_tokens"] + and not output_kwargs["text_kwargs"]["return_offsets_mapping"] + ): raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") # first, apply the image processor - features = self.image_processor(images=images, return_tensors=return_tensors) + features = self.image_processor(images=images, **output_kwargs["images_kwargs"]) # second, apply the tokenizer - if text is not None and self.image_processor.apply_ocr and text_pair is None: + if text is not None and apply_ocr and text_pair is None: if isinstance(text, str): text = [text] # add batch dimension (as the image processor always adds a batch dimension) text_pair = features["words"] + if text is None: + if not hasattr(features, "words"): + raise ValueError("You need to provide `text` or set `apply_ocr` to `True`") + text = features["words"] + if boxes is None: + if not hasattr(features, "boxes"): + raise ValueError("You need to provide `boxes` or set `apply_ocr` to `True`") + boxes = features["boxes"] + encoded_inputs = self.tokenizer( - text=text if text is not None else features["words"], - text_pair=text_pair if text_pair is not None else None, - boxes=boxes if boxes is not None else features["boxes"], + text=text, + text_pair=text_pair, + boxes=boxes, word_labels=word_labels, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, + **output_kwargs["text_kwargs"], ) # add pixel values images = features.pop("pixel_values") - if return_overflowing_tokens is True: + if output_kwargs["text_kwargs"]["return_overflowing_tokens"] is True: images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"]) encoded_inputs["image"] = images - return encoded_inputs + return BatchFeature( + data=dict(**encoded_inputs), tensor_type=output_kwargs["common_kwargs"].get("return_tensors") + ) def get_overflowing_images(self, images, overflow_to_sample_mapping): # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image diff --git a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py index 6f16435c14dd..d8059fd883d1 100644 --- a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py @@ -144,13 +144,13 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor): def __init__( self, do_resize: bool = True, - size: Dict[str, int] = None, + size: Optional[Dict[str, int]] = None, resample: PILImageResampling = PILImageResampling.BILINEAR, do_rescale: bool = True, rescale_value: float = 1 / 255, do_normalize: bool = True, - image_mean: Union[float, Iterable[float]] = None, - image_std: Union[float, Iterable[float]] = None, + image_mean: Optional[Union[float, Iterable[float]]] = None, + image_std: Optional[Union[float, Iterable[float]]] = None, apply_ocr: bool = True, ocr_lang: Optional[str] = None, 
tesseract_config: Optional[str] = "", @@ -225,15 +225,15 @@ def resize( def preprocess( self, images: ImageInput, - do_resize: bool = None, - size: Dict[str, int] = None, - resample=None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, - image_mean: Union[float, Iterable[float]] = None, - image_std: Union[float, Iterable[float]] = None, - apply_ocr: bool = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, Iterable[float]]] = None, + image_std: Optional[Union[float, Iterable[float]]] = None, + apply_ocr: Optional[bool] = None, ocr_lang: Optional[str] = None, tesseract_config: Optional[str] = None, return_tensors: Optional[Union[str, TensorType]] = None, @@ -251,7 +251,7 @@ def preprocess( Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): Desired size of the output image after applying `resize`. - resample (`int`, *optional*, defaults to `self.resample`): + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters. Only has an effect if `do_resize` is set to `True`. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): diff --git a/src/transformers/models/layoutlmv3/processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py index 369bd51bec28..ad15c1b76313 100644 --- a/src/transformers/models/layoutlmv3/processing_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py @@ -16,12 +16,51 @@ Processor class for LayoutLMv3. 
""" +import sys import warnings from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class LayoutLMv3TextKwargs(TextKwargs, total=False): + boxes: Optional[Union[List[List[int]], List[List[List[int]]]]] + word_labels: Optional[Union[List[int], List[List[int]]]] + + +class LayoutLMv3ImagesKwargs(ImagesKwargs, total=False): + apply_ocr: bool + ocr_lang: Optional[str] + tesseract_config: Optional[str] + + +class LayoutLMv3ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: LayoutLMv3TextKwargs + images_kwargs: LayoutLMv3ImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + "apply_ocr": True, + }, + } class LayoutLMv3Processor(ProcessorMixin): @@ -47,6 +86,7 @@ class LayoutLMv3Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "LayoutLMv3ImageProcessor" tokenizer_class = ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast") + optional_call_args = ["text_pair", "boxes", "word_labels"] def __init__(self, image_processor=None, tokenizer=None, **kwargs): feature_extractor = None @@ -68,27 +108,16 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): def __call__( self, - images, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, - boxes: Union[List[List[int]], List[List[List[int]]]] = None, - word_labels: Optional[Union[List[int], List[List[int]]]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, - ) -> BatchEncoding: + images: ImageInput, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + # The following is to capture `text_pair`, `boxes`, `word_labels` arguments that may be passed as a positional argument. + # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + # This behavior is only needed for backward compatibility and will be removed in future versions. + *args, + audio=None, + videos=None, + **kwargs: Unpack[LayoutLMv3ProcessorKwargs], + ) -> BatchFeature: """ This method first forwards the `images` argument to [`~LayoutLMv3ImageProcessor.__call__`]. 
In case [`LayoutLMv3ImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and @@ -99,56 +128,84 @@ def __call__( resized and normalized `pixel_values`. Please refer to the docstring of the above two methods for more information. + + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **image** -- Pixel values to be fed to a model. + - **bbox** -- Bounding boxes of the words in the image. """ + output_kwargs = self._merge_kwargs( + LayoutLMv3ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + + text_pair = output_kwargs["text_kwargs"].pop("text_pair", None) + boxes = output_kwargs["text_kwargs"].pop("boxes", None) + word_labels = output_kwargs["text_kwargs"].pop("word_labels", None) + apply_ocr = output_kwargs["images_kwargs"].get("apply_ocr", self.image_processor.apply_ocr) + # verify input - if self.image_processor.apply_ocr and (boxes is not None): + if apply_ocr and (boxes is not None): raise ValueError( "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True." ) - if self.image_processor.apply_ocr and (word_labels is not None): + if apply_ocr and (word_labels is not None): raise ValueError( "You cannot provide word labels if you initialized the image processor with apply_ocr set to True." 
) # first, apply the image processor - features = self.image_processor(images=images, return_tensors=return_tensors) + features = self.image_processor(images=images, **output_kwargs["images_kwargs"]) # second, apply the tokenizer - if text is not None and self.image_processor.apply_ocr and text_pair is None: + if text is not None and apply_ocr and text_pair is None: if isinstance(text, str): text = [text] # add batch dimension (as the image processor always adds a batch dimension) text_pair = features["words"] + if text is None: + if not hasattr(features, "words"): + raise ValueError("You need to provide `text` or set `apply_ocr` to `True`") + text = features["words"] + if boxes is None: + if not hasattr(features, "boxes"): + raise ValueError("You need to provide `boxes` or set `apply_ocr` to `True`") + boxes = features["boxes"] + encoded_inputs = self.tokenizer( - text=text if text is not None else features["words"], - text_pair=text_pair if text_pair is not None else None, - boxes=boxes if boxes is not None else features["boxes"], + text=text, + text_pair=text_pair, + boxes=boxes, word_labels=word_labels, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, + **output_kwargs["text_kwargs"], ) # add pixel values images = features.pop("pixel_values") - if return_overflowing_tokens is True: + if output_kwargs["text_kwargs"]["return_overflowing_tokens"]: images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"]) encoded_inputs["pixel_values"] = images - return encoded_inputs + return BatchFeature( + data=dict(**encoded_inputs), tensor_type=output_kwargs["common_kwargs"].get("return_tensors") + ) def get_overflowing_images(self, images, overflow_to_sample_mapping): # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image diff --git a/src/transformers/models/layoutxlm/processing_layoutxlm.py b/src/transformers/models/layoutxlm/processing_layoutxlm.py index 1cbd3f20c2fa..3f58c561cdee 100644 --- a/src/transformers/models/layoutxlm/processing_layoutxlm.py +++ b/src/transformers/models/layoutxlm/processing_layoutxlm.py @@ -16,12 +16,51 @@ Processor class for LayoutXLM. 
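Returning a `BatchFeature` built from the merged encodings means tensor conversion happens once, at the very end, driven by `common_kwargs["return_tensors"]`, instead of being threaded through every component. A small sketch of that final step (assumes PyTorch is installed):

from transformers import BatchFeature

encoded = {
    "input_ids": [[101, 7592, 2088, 102]],
    "bbox": [[[0, 0, 0, 0], [1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 0, 0]]],
}
batch = BatchFeature(data=encoded, tensor_type="pt")
print(type(batch["input_ids"]), batch["input_ids"].shape)  # torch tensor of shape (1, 4)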
""" +import sys import warnings from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class LayoutXLMTextKwargs(TextKwargs, total=False): + boxes: Optional[Union[List[List[int]], List[List[List[int]]]]] + word_labels: Optional[Union[List[int], List[List[int]]]] + + +class LayoutXLMImagesKwargs(ImagesKwargs, total=False): + apply_ocr: bool + ocr_lang: Optional[str] + tesseract_config: Optional[str] + + +class LayoutXLMProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: LayoutXLMTextKwargs + images_kwargs: LayoutXLMImagesKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + "apply_ocr": True, + }, + } class LayoutXLMProcessor(ProcessorMixin): @@ -47,6 +86,7 @@ class LayoutXLMProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "LayoutLMv2ImageProcessor" tokenizer_class = ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast") + optional_call_args = ["text_pair", "boxes", "word_labels"] def __init__(self, image_processor=None, tokenizer=None, **kwargs): if "feature_extractor" in kwargs: @@ -67,90 +107,110 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): def __call__( self, - images, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, - boxes: Union[List[List[int]], List[List[List[int]]]] = None, - word_labels: Optional[Union[List[int], List[List[int]]]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, - ) -> BatchEncoding: + images: ImageInput, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + # The following is to capture `text_pair`, `boxes`, `word_labels` arguments that may be passed as a positional argument. + # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + # This behavior is only needed for backward compatibility and will be removed in future versions. + *args, + audio=None, + videos=None, + **kwargs: Unpack[LayoutXLMProcessorKwargs], + ) -> BatchFeature: """ This method first forwards the `images` argument to [`~LayoutLMv2ImagePrpcessor.__call__`]. 
In case - [`LayoutLMv2ImagePrpcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and + [`LayoutLMv2ImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and bounding boxes along with the additional arguments to [`~LayoutXLMTokenizer.__call__`] and returns the output, together with resized `images`. In case [`LayoutLMv2ImagePrpcessor`] was initialized with `apply_ocr` set to `False`, it passes the words (`text`/``text_pair`) and `boxes` specified by the user along with the additional arguments to [`~LayoutXLMTokenizer.__call__`] and returns the output, together with resized `images``. Please refer to the docstring of the above two methods for more information. + + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **image** -- Pixel values to be fed to a model. + - **bbox** -- Bounding boxes of the words in the image. """ + output_kwargs = self._merge_kwargs( + LayoutXLMProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + + text_pair = output_kwargs["text_kwargs"].pop("text_pair", None) + boxes = output_kwargs["text_kwargs"].pop("boxes", None) + word_labels = output_kwargs["text_kwargs"].pop("word_labels", None) + apply_ocr = output_kwargs["images_kwargs"].get("apply_ocr", self.image_processor.apply_ocr) + # verify input - if self.image_processor.apply_ocr and (boxes is not None): + if apply_ocr and (boxes is not None): raise ValueError( "You cannot provide bounding boxes " "if you initialized the image processor with apply_ocr set to True." ) - if self.image_processor.apply_ocr and (word_labels is not None): + if apply_ocr and (word_labels is not None): raise ValueError( "You cannot provide word labels if you initialized the image processor with apply_ocr set to True." 
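The validation above enforces the two mutually exclusive modes that all three processors share, and `apply_ocr` can now be flipped per call because it is read from the merged `images_kwargs` rather than only from the image processor's init. An illustrative sketch; checkpoint and image path are placeholders, and Tesseract is only needed in the OCR branch:

from PIL import Image
from transformers import LayoutXLMProcessor

processor = LayoutXLMProcessor.from_pretrained("microsoft/layoutxlm-base")  # placeholder checkpoint
image = Image.open("invoice.png").convert("RGB")                            # placeholder image

# Mode 1: let the image processor run OCR; do NOT pass boxes or word_labels.
with_ocr = processor(images=image, return_tensors="pt")

# Mode 2: turn OCR off for this call and supply the words and boxes yourself.
no_ocr = processor(
    images=image,
    text=["hello", "world"],
    boxes=[[1, 2, 3, 4], [5, 6, 7, 8]],
    apply_ocr=False,
    return_tensors="pt",
)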
) - if return_overflowing_tokens is True and return_offsets_mapping is False: + if ( + output_kwargs["text_kwargs"]["return_overflowing_tokens"] + and not output_kwargs["text_kwargs"]["return_offsets_mapping"] + ): raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") # first, apply the image processor - features = self.image_processor(images=images, return_tensors=return_tensors) + features = self.image_processor(images=images, **output_kwargs["images_kwargs"]) # second, apply the tokenizer - if text is not None and self.image_processor.apply_ocr and text_pair is None: + if text is not None and apply_ocr and text_pair is None: if isinstance(text, str): text = [text] # add batch dimension (as the image processor always adds a batch dimension) text_pair = features["words"] + if text is None: + if not hasattr(features, "words"): + raise ValueError("You need to provide `text` or set `apply_ocr` to `True`") + text = features["words"] + if boxes is None: + if not hasattr(features, "boxes"): + raise ValueError("You need to provide `boxes` or set `apply_ocr` to `True`") + boxes = features["boxes"] + encoded_inputs = self.tokenizer( - text=text if text is not None else features["words"], - text_pair=text_pair if text_pair is not None else None, - boxes=boxes if boxes is not None else features["boxes"], + text=text, + text_pair=text_pair, + boxes=boxes, word_labels=word_labels, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, + **output_kwargs["text_kwargs"], ) # add pixel values images = features.pop("pixel_values") - if return_overflowing_tokens is True: + if output_kwargs["text_kwargs"]["return_overflowing_tokens"]: images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"]) encoded_inputs["image"] = images - return encoded_inputs + return BatchFeature( + data=dict(**encoded_inputs), tensor_type=output_kwargs["common_kwargs"].get("return_tensors") + ) def get_overflowing_images(self, images, overflow_to_sample_mapping): # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image diff --git a/tests/models/layoutlmv2/test_processor_layoutlmv2.py b/tests/models/layoutlmv2/test_processor_layoutlmv2.py index a2676195ffd3..261106ecf193 100644 --- a/tests/models/layoutlmv2/test_processor_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_processor_layoutlmv2.py @@ -37,6 +37,7 @@ @require_pytesseract @require_tokenizers class LayoutLMv2ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + images_input_name = "image" tokenizer_class = LayoutLMv2Tokenizer rust_tokenizer_class = LayoutLMv2TokenizerFast processor_class = LayoutLMv2Processor @@ -183,6 +184,155 @@ def preprocess_data(examples): self.assertEqual(len(train_data["image"]), len(train_data["input_ids"])) + def test_model_specific_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", apply_ocr=True) + tokenizer = 
self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + + image_input = self.prepare_image_inputs() + with self.assertRaises(ValueError): + # LayoutLMv2's processor expects `text` to be provided when `apply_ocr` is set to False + processor( + images=image_input, + return_tensors="pt", + apply_ocr=False, + ) + + def test_image_processor_defaults_preserved_by_image_kwargs(self): + """ + We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor. + We then check that the mean of the pixel_values is less than or equal to 0 after processing. + Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied. + """ + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", do_rescale=True, rescale_factor=-1 + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", do_rescale=True, rescale_factor=1 + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, do_rescale=True, rescale_factor=-1, return_tensors="pt") + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + max_length = 76 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + do_rescale=True, + rescale_factor=-1, + padding="max_length", + max_length=max_length, + ) + + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + 
self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + do_rescale=True, + rescale_factor=-1, + padding="longest", + max_length=76, + ) + + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + self.assertTrue( + len(inputs[self.text_input_name][0]) == len(inputs[self.text_input_name][1]) + and len(inputs[self.text_input_name][1]) < 76 + ) + + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"do_rescale": True, "rescale_factor": -1}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"do_rescale": True, "rescale_factor": -1}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + # different use cases tests @require_torch diff --git a/tests/models/layoutlmv3/test_processor_layoutlmv3.py b/tests/models/layoutlmv3/test_processor_layoutlmv3.py index e55b19ea44b0..f5f895697ebc 100644 --- a/tests/models/layoutlmv3/test_processor_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_processor_layoutlmv3.py @@ -37,6 +37,7 @@ @require_pytesseract @require_tokenizers class LayoutLMv3ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "microsoft/layoutlmv3-base" tokenizer_class = LayoutLMv3Tokenizer rust_tokenizer_class = LayoutLMv3TokenizerFast processor_class = LayoutLMv3Processor @@ -87,6 +88,9 @@ def setUp(self): with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: fp.write(json.dumps(image_processor_map) + "\n") + tokenizer = LayoutLMv3Tokenizer.from_pretrained(self.from_pretrained_id) + tokenizer.save_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer: return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) @@ -163,6 +167,23 @@ def test_model_input_names(self): self.assertListEqual(list(inputs.keys()), processor.model_input_names) + def test_model_specific_kwargs(self): + 
if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", apply_ocr=True) + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + + image_input = self.prepare_image_inputs() + with self.assertRaises(ValueError): + # LayoutLMv3's processor expects `text` to be provided when `apply_ocr` is set to False + processor( + images=image_input, + return_tensors="pt", + apply_ocr=False, + ) + # different use cases tests @require_torch diff --git a/tests/models/layoutxlm/test_processor_layoutxlm.py b/tests/models/layoutxlm/test_processor_layoutxlm.py index b970a3e52683..164f0830bba8 100644 --- a/tests/models/layoutxlm/test_processor_layoutxlm.py +++ b/tests/models/layoutxlm/test_processor_layoutxlm.py @@ -43,6 +43,7 @@ @require_sentencepiece @require_tokenizers class LayoutXLMProcessorTest(ProcessorTesterMixin, unittest.TestCase): + images_input_name = "image" tokenizer_class = LayoutXLMTokenizer rust_tokenizer_class = LayoutXLMTokenizerFast processor_class = LayoutXLMProcessor @@ -61,6 +62,8 @@ def setUp(self): # taken from `test_tokenization_layoutxlm.LayoutXLMTokenizationTest.test_save_pretrained` self.tokenizer_pretrained_name = "hf-internal-testing/tiny-random-layoutxlm" + tokenizer = LayoutXLMTokenizer.from_pretrained(self.tokenizer_pretrained_name) + tokenizer.save_pretrained(self.tmpdirname) tokenizer = self.get_tokenizer() image_processor = self.get_image_processor() @@ -182,6 +185,155 @@ def preprocess_data(examples): self.assertEqual(len(train_data["image"]), len(train_data["input_ids"])) + def test_model_specific_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", apply_ocr=True) + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + + image_input = self.prepare_image_inputs() + with self.assertRaises(ValueError): + # LayoutXLM's processor expects `text` to be provided when `apply_ocr` is set to False + processor( + images=image_input, + return_tensors="pt", + apply_ocr=False, + ) + + def test_image_processor_defaults_preserved_by_image_kwargs(self): + """ + We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor. + We then check that the mean of the pixel_values is less than or equal to 0 after processing. + Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied. 
+ """ + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", do_rescale=True, rescale_factor=-1 + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", do_rescale=True, rescale_factor=1 + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, do_rescale=True, rescale_factor=-1, return_tensors="pt") + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + max_length = 76 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + do_rescale=True, + rescale_factor=-1, + padding="max_length", + max_length=max_length, + ) + + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + do_rescale=True, + rescale_factor=-1, + padding="longest", + max_length=76, + ) + + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + self.assertTrue( + len(inputs[self.text_input_name][0]) == len(inputs[self.text_input_name][1]) + and len(inputs[self.text_input_name][1]) < 76 + ) + + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = 
self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"do_rescale": True, "rescale_factor": -1}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"do_rescale": True, "rescale_factor": -1}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertLessEqual(inputs[self.images_input_name].shape[-1], 224) + self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + # different use cases tests @require_sentencepiece diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 8cc71147c220..715fa84e3075 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -218,6 +218,7 @@ def test_unstructured_kwargs(self): input_str = "lower newer" image_input = self.prepare_image_inputs() + max_length = 76 inputs = processor( text=input_str, images=image_input, @@ -225,7 +226,7 @@ def test_unstructured_kwargs(self): do_rescale=True, rescale_factor=-1, padding="max_length", - max_length=76, + max_length=max_length, ) self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
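Finally, one helper that all three processors above keep unchanged is `get_overflowing_images`: when `return_overflowing_tokens=True` splits one document into several windows, its image has to be repeated once per window so that images and `input_ids` stay aligned. Roughly, as a sketch of the assumed behavior built on the tokenizer's `overflow_to_sample_mapping`:

def get_overflowing_images(images, overflow_to_sample_mapping):
    # one image per produced window, indexed by the sample it came from
    return [images[sample_idx] for sample_idx in overflow_to_sample_mapping]

images = ["image_for_doc_0", "image_for_doc_1"]
overflow_to_sample_mapping = [0, 0, 1]  # doc 0 was split into two windows
print(get_overflowing_images(images, overflow_to_sample_mapping))
# ['image_for_doc_0', 'image_for_doc_0', 'image_for_doc_1']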