|
49 | 49 | _CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base" |
50 | 50 | _EXPECTED_OUTPUT_SHAPE = [1, 49, 768] |
51 | 51 |
|
| 52 | +# Image classification docstring |
| 53 | +_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder" |
| 54 | +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" |
| 55 | + |
52 | 56 |
|
53 | 57 | @dataclass |
54 | 58 | # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin |
@@ -121,6 +125,43 @@ class DonutSwinModelOutput(ModelOutput): |
121 | 125 | reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None |
122 | 126 |
|
123 | 127 |
|
# Output container returned by the DonutSwin image-classification head.
# Every field defaults to `None`, so every annotation has to admit `None`; `logits` is therefore typed `Optional[...]`
# like its siblings.
# NOTE(review): the `# Copied from` marker below presumably ties this class to the Swin original via the repository's
# copy-consistency tooling -- apply the same `Optional` annotation to the Swin class in the same change; confirm.
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin
class DonutSwinImageClassifierOutput(ModelOutput):
    """
    DonutSwin outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
| 163 | + |
| 164 | + |
124 | 165 | # Copied from transformers.models.swin.modeling_swin.window_partition |
125 | 166 | def window_partition(input_feature, window_size): |
126 | 167 | """ |
@@ -845,15 +886,15 @@ def forward( |
845 | 886 | ) |
846 | 887 |
|
847 | 888 |
|
848 | | -# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin |
| 889 | +# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut |
849 | 890 | class DonutSwinPreTrainedModel(PreTrainedModel): |
850 | 891 | """ |
851 | 892 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
852 | 893 | models. |
853 | 894 | """ |
854 | 895 |
|
855 | 896 | config_class = DonutSwinConfig |
856 | | - base_model_prefix = "swin" |
| 897 | + base_model_prefix = "donut" |
857 | 898 | main_input_name = "pixel_values" |
858 | 899 | supports_gradient_checkpointing = True |
859 | 900 | _no_split_modules = ["DonutSwinStage"] |
@@ -1015,4 +1056,90 @@ def forward( |
1015 | 1056 | ) |
1016 | 1057 |
|
1017 | 1058 |
|
1018 | | -__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel"] |
# Image-classification variant of DonutSwin. `forward` runs the `DonutSwinModel` backbone, passes its pooled output
# (`outputs[1]`) through `self.classifier`, and computes a loss through `self.loss_function` only when `labels` is given.
# With `return_dict=False` it returns a tuple `(loss?, logits, *outputs[2:])`; otherwise a
# `DonutSwinImageClassifierOutput`.
# NOTE(review): everything below the `# Copied from` marker is left textually unchanged because it is presumably
# checked against the Swin original by the repository's copy-consistency tooling -- confirm before editing it.
@add_start_docstrings(
    """
    DonutSwin Model transformer with an image classification head on top (a linear layer on top of the pooled output
    of the backbone) e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """,
    SWIN_START_DOCSTRING,
)
# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut
class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.donut = DonutSwinModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=DonutSwinImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DonutSwinImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.donut(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DonutSwinImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
| 1143 | + |
| 1144 | + |
| 1145 | +__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"] |
0 commit comments