Skip to content

Commit 7ecc5b8

Browse files
authored
Add image classifier donut & update loss calculation for all swins (#37224)
* add classifier head to donut * add to transformers __init__ * add to auto model * fix typo * add loss for image classification * add checkpoint * remove unneeded import * reorder import * format * consistency * add test of classifier * add doc * try ignore * update loss for all swin models
1 parent 5ae9b2c commit 7ecc5b8

File tree

9 files changed

+177
-48
lines changed

9 files changed

+177
-48
lines changed

docs/source/en/model_doc/donut.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -226,3 +226,8 @@ print(answer)
226226

227227
[[autodoc]] DonutSwinModel
228228
- forward
229+
230+
## DonutSwinForImageClassification
231+
232+
[[autodoc]] transformers.DonutSwinForImageClassification
233+
- forward

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2304,6 +2304,7 @@
23042304
)
23052305
_import_structure["models.donut"].extend(
23062306
[
2307+
"DonutSwinForImageClassification",
23072308
"DonutSwinModel",
23082309
"DonutSwinPreTrainedModel",
23092310
]
@@ -7457,6 +7458,7 @@
74577458
DistilBertPreTrainedModel,
74587459
)
74597460
from .models.donut import (
7461+
DonutSwinForImageClassification,
74607462
DonutSwinModel,
74617463
DonutSwinPreTrainedModel,
74627464
)

src/transformers/loss/loss_utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -145,6 +145,7 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
145145
"ForMaskedLM": ForMaskedLMLoss,
146146
"ForQuestionAnswering": ForQuestionAnsweringLoss,
147147
"ForSequenceClassification": ForSequenceClassificationLoss,
148+
"ForImageClassification": ForSequenceClassificationLoss,
148149
"ForTokenClassification": ForTokenClassification,
149150
"ForSegmentation": ForSegmentationLoss,
150151
"ForObjectDetection": ForObjectDetectionLoss,

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -707,6 +707,7 @@
707707
("dinat", "DinatForImageClassification"),
708708
("dinov2", "Dinov2ForImageClassification"),
709709
("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
710+
("donut-swin", "DonutSwinForImageClassification"),
710711
(
711712
"efficientformer",
712713
(

src/transformers/models/donut/modeling_donut_swin.py

Lines changed: 130 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,10 @@
4949
_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
5050
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
5151

52+
# Image classification docstring
53+
_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder"
54+
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
55+
5256

5357
@dataclass
5458
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
@@ -121,6 +125,43 @@ class DonutSwinModelOutput(ModelOutput):
121125
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
122126

123127

128+
@dataclass
129+
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin
130+
class DonutSwinImageClassifierOutput(ModelOutput):
131+
"""
132+
DonutSwin outputs for image classification.
133+
134+
Args:
135+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
136+
Classification (or regression if config.num_labels==1) loss.
137+
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
138+
Classification (or regression if config.num_labels==1) scores (before SoftMax).
139+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
140+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
141+
shape `(batch_size, sequence_length, hidden_size)`.
142+
143+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
144+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
145+
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
146+
sequence_length)`.
147+
148+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
149+
heads.
150+
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
151+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
152+
shape `(batch_size, hidden_size, height, width)`.
153+
154+
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
155+
include the spatial dimensions.
156+
"""
157+
158+
loss: Optional[torch.FloatTensor] = None
159+
logits: torch.FloatTensor = None
160+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
161+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
162+
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
163+
164+
124165
# Copied from transformers.models.swin.modeling_swin.window_partition
125166
def window_partition(input_feature, window_size):
126167
"""
@@ -845,15 +886,15 @@ def forward(
845886
)
846887

847888

848-
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
889+
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut
849890
class DonutSwinPreTrainedModel(PreTrainedModel):
850891
"""
851892
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
852893
models.
853894
"""
854895

855896
config_class = DonutSwinConfig
856-
base_model_prefix = "swin"
897+
base_model_prefix = "donut"
857898
main_input_name = "pixel_values"
858899
supports_gradient_checkpointing = True
859900
_no_split_modules = ["DonutSwinStage"]
@@ -1015,4 +1056,90 @@ def forward(
10151056
)
10161057

10171058

1018-
__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel"]
1059+
@add_start_docstrings(
1060+
"""
1061+
DonutSwin Model transformer with an image classification head on top (a linear layer on top of the pooled output of
1062+
the base DonutSwin model), e.g. for ImageNet.
1063+
1064+
<Tip>
1065+
1066+
Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
1067+
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
1068+
position embeddings to the higher resolution.
1069+
1070+
</Tip>
1071+
""",
1072+
SWIN_START_DOCSTRING,
1073+
)
1074+
# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut
1075+
class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
1076+
def __init__(self, config):
1077+
super().__init__(config)
1078+
1079+
self.num_labels = config.num_labels
1080+
self.donut = DonutSwinModel(config)
1081+
1082+
# Classifier head
1083+
self.classifier = (
1084+
nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
1085+
)
1086+
1087+
# Initialize weights and apply final processing
1088+
self.post_init()
1089+
1090+
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
1091+
@add_code_sample_docstrings(
1092+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
1093+
output_type=DonutSwinImageClassifierOutput,
1094+
config_class=_CONFIG_FOR_DOC,
1095+
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1096+
)
1097+
def forward(
1098+
self,
1099+
pixel_values: Optional[torch.FloatTensor] = None,
1100+
head_mask: Optional[torch.FloatTensor] = None,
1101+
labels: Optional[torch.LongTensor] = None,
1102+
output_attentions: Optional[bool] = None,
1103+
output_hidden_states: Optional[bool] = None,
1104+
interpolate_pos_encoding: bool = False,
1105+
return_dict: Optional[bool] = None,
1106+
) -> Union[Tuple, DonutSwinImageClassifierOutput]:
1107+
r"""
1108+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1109+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1110+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
1111+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1112+
"""
1113+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1114+
1115+
outputs = self.donut(
1116+
pixel_values,
1117+
head_mask=head_mask,
1118+
output_attentions=output_attentions,
1119+
output_hidden_states=output_hidden_states,
1120+
interpolate_pos_encoding=interpolate_pos_encoding,
1121+
return_dict=return_dict,
1122+
)
1123+
1124+
pooled_output = outputs[1]
1125+
1126+
logits = self.classifier(pooled_output)
1127+
1128+
loss = None
1129+
if labels is not None:
1130+
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
1131+
1132+
if not return_dict:
1133+
output = (logits,) + outputs[2:]
1134+
return ((loss,) + output) if loss is not None else output
1135+
1136+
return DonutSwinImageClassifierOutput(
1137+
loss=loss,
1138+
logits=logits,
1139+
hidden_states=outputs.hidden_states,
1140+
attentions=outputs.attentions,
1141+
reshaped_hidden_states=outputs.reshaped_hidden_states,
1142+
)
1143+
1144+
1145+
__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"]

src/transformers/models/swin/modeling_swin.py

Lines changed: 1 addition & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
import torch
2424
import torch.utils.checkpoint
2525
from torch import nn
26-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2726

2827
from ...activations import ACT2FN
2928
from ...modeling_outputs import BackboneOutput
@@ -1285,26 +1284,7 @@ def forward(
12851284

12861285
loss = None
12871286
if labels is not None:
1288-
if self.config.problem_type is None:
1289-
if self.num_labels == 1:
1290-
self.config.problem_type = "regression"
1291-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1292-
self.config.problem_type = "single_label_classification"
1293-
else:
1294-
self.config.problem_type = "multi_label_classification"
1295-
1296-
if self.config.problem_type == "regression":
1297-
loss_fct = MSELoss()
1298-
if self.num_labels == 1:
1299-
loss = loss_fct(logits.squeeze(), labels.squeeze())
1300-
else:
1301-
loss = loss_fct(logits, labels)
1302-
elif self.config.problem_type == "single_label_classification":
1303-
loss_fct = CrossEntropyLoss()
1304-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1305-
elif self.config.problem_type == "multi_label_classification":
1306-
loss_fct = BCEWithLogitsLoss()
1307-
loss = loss_fct(logits, labels)
1287+
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
13081288

13091289
if not return_dict:
13101290
output = (logits,) + outputs[2:]

src/transformers/models/swinv2/modeling_swinv2.py

Lines changed: 1 addition & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
import torch
2424
import torch.utils.checkpoint
2525
from torch import Tensor, nn
26-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
2726

2827
from ...activations import ACT2FN
2928
from ...modeling_outputs import BackboneOutput
@@ -1339,26 +1338,7 @@ def forward(
13391338

13401339
loss = None
13411340
if labels is not None:
1342-
if self.config.problem_type is None:
1343-
if self.num_labels == 1:
1344-
self.config.problem_type = "regression"
1345-
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1346-
self.config.problem_type = "single_label_classification"
1347-
else:
1348-
self.config.problem_type = "multi_label_classification"
1349-
1350-
if self.config.problem_type == "regression":
1351-
loss_fct = MSELoss()
1352-
if self.num_labels == 1:
1353-
loss = loss_fct(logits.squeeze(), labels.squeeze())
1354-
else:
1355-
loss = loss_fct(logits, labels)
1356-
elif self.config.problem_type == "single_label_classification":
1357-
loss_fct = CrossEntropyLoss()
1358-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1359-
elif self.config.problem_type == "multi_label_classification":
1360-
loss_fct = BCEWithLogitsLoss()
1361-
loss = loss_fct(logits, labels)
1341+
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
13621342

13631343
if not return_dict:
13641344
output = (logits,) + outputs[2:]

src/transformers/utils/dummy_pt_objects.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3829,6 +3829,13 @@ def __init__(self, *args, **kwargs):
38293829
requires_backends(self, ["torch"])
38303830

38313831

3832+
class DonutSwinForImageClassification(metaclass=DummyObject):
3833+
_backends = ["torch"]
3834+
3835+
def __init__(self, *args, **kwargs):
3836+
requires_backends(self, ["torch"])
3837+
3838+
38323839
class DonutSwinModel(metaclass=DummyObject):
38333840
_backends = ["torch"]
38343841

tests/models/donut/test_modeling_donut_swin.py

Lines changed: 29 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
import torch
3030
from torch import nn
3131

32-
from transformers import DonutSwinModel
32+
from transformers import DonutSwinForImageClassification, DonutSwinModel
3333

3434

3535
class DonutSwinModelTester:
@@ -129,6 +129,24 @@ def create_and_check_model(self, config, pixel_values, labels):
129129

130130
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
131131

132+
def create_and_check_for_image_classification(self, config, pixel_values, labels):
133+
config.num_labels = self.type_sequence_label_size
134+
model = DonutSwinForImageClassification(config)
135+
model.to(torch_device)
136+
model.eval()
137+
result = model(pixel_values, labels=labels)
138+
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
139+
140+
# test greyscale images
141+
config.num_channels = 1
142+
model = DonutSwinForImageClassification(config)
143+
model.to(torch_device)
144+
model.eval()
145+
146+
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
147+
result = model(pixel_values)
148+
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
149+
132150
def prepare_config_and_inputs_for_common(self):
133151
config_and_inputs = self.prepare_config_and_inputs()
134152
(
@@ -142,8 +160,12 @@ def prepare_config_and_inputs_for_common(self):
142160

143161
@require_torch
144162
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
145-
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
146-
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
163+
all_model_classes = (DonutSwinModel, DonutSwinForImageClassification) if is_torch_available() else ()
164+
pipeline_model_mapping = (
165+
{"image-feature-extraction": DonutSwinModel, "image-classification": DonutSwinForImageClassification}
166+
if is_torch_available()
167+
else {}
168+
)
147169
fx_compatible = True
148170

149171
test_pruning = False
@@ -167,6 +189,10 @@ def test_model(self):
167189
config_and_inputs = self.model_tester.prepare_config_and_inputs()
168190
self.model_tester.create_and_check_model(*config_and_inputs)
169191

192+
def test_for_image_classification(self):
193+
config_and_inputs = self.model_tester.prepare_config_and_inputs()
194+
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
195+
170196
@unittest.skip(reason="DonutSwin does not use inputs_embeds")
171197
def test_inputs_embeds(self):
172198
pass

0 commit comments

Comments
 (0)