Skip to content

Commit a7eba83

Browse files
ariG23498sayakpaulsguggerRocketknight1amyeroberts
authored
TF implementation of RegNets (huggingface#17554)
* chore: initial commit Copied the torch implementation of regnets and porting the code to tf step by step. Also introduced an output layer which was needed for regnets. * chore: porting the rest of the modules to tensorflow did not change the documentation yet, yet to try the playground on the model * Fix initilizations (#1) * fix: code structure in few cases. * fix: code structure to align tf models. * fix: layer naming, bn layer still remains. * chore: change default epsilon and momentum in bn. * chore: styling nits. * fix: cross-loading bn params. * fix: regnet tf model, integration passing. * add: tests for TF regnet. * fix: code quality related issues. * chore: added rest of the files. * minor additions.. * fix: repo consistency. * fix: regnet tf tests. * chore: reorganize dummy_tf_objects for regnet. * chore: remove checkpoint var. * chore: remov unnecessary files. * chore: run make style. * Update docs/source/en/model_doc/regnet.mdx Co-authored-by: Sylvain Gugger <[email protected]> * chore: PR feedback I. * fix: pt test. thanks to @ydshieh. * New adaptive pooler (#3) * feat: new adaptive pooler Co-authored-by: @Rocketknight1 * chore: remove image_size argument. Co-authored-by: matt <[email protected]> Co-authored-by: matt <[email protected]> * Empty-Commit * chore: remove image_size comment. * chore: remove playground_tf.py * chore: minor changes related to spacing. * chore: make style. * Update src/transformers/models/regnet/modeling_tf_regnet.py Co-authored-by: amyeroberts <[email protected]> * Update src/transformers/models/regnet/modeling_tf_regnet.py Co-authored-by: amyeroberts <[email protected]> * chore: refactored __init__. * chore: copied from -> taken from./g * adaptive pool -> global avg pool, channel check. * chore: move channel check to stem. * pr comments - minor refactor and add regnets to doc tests. * Update src/transformers/models/regnet/modeling_tf_regnet.py Co-authored-by: NielsRogge <[email protected]> * minor fix in the xlayer. * Empty-Commit * chore: removed from_pt=True. Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: matt <[email protected]> Co-authored-by: amyeroberts <[email protected]> Co-authored-by: NielsRogge <[email protected]>
1 parent e6d27ca commit a7eba83

File tree

10 files changed

+943
-5
lines changed

10 files changed

+943
-5
lines changed

docs/source/en/index.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ Flax), PyTorch, and/or TensorFlow.
267267
| RAG | | | | | |
268268
| REALM | | | | | |
269269
| Reformer | | | | | |
270-
| RegNet | | | | | |
270+
| RegNet | | | | | |
271271
| RemBERT | | | | | |
272272
| ResNet | | | | | |
273273
| RetriBERT | | | | | |

docs/source/en/model_doc/regnet.mdx

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ Tips:
2727
- One can use [`AutoFeatureExtractor`] to prepare images for the model.
2828
- The huge 10B model from [Self-supervised Pretraining of Visual Features in the Wild](https://arxiv.org/abs/2103.01988), trained on one billion Instagram images, is available on the [hub](https://huggingface.co/facebook/regnet-y-10b-seer)
2929

30-
This model was contributed by [Francesco](https://huggingface.co/Francesco).
30+
This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of the model
31+
was contributed by [sayakpaul](https://huggingface.com/sayakpaul) and [ariG23498](https://huggingface.com/ariG23498).
3132
The original code can be found [here](https://github.com/facebookresearch/pycls).
3233

3334

@@ -45,4 +46,15 @@ The original code can be found [here](https://github.com/facebookresearch/pycls)
4546
## RegNetForImageClassification
4647

4748
[[autodoc]] RegNetForImageClassification
48-
- forward
49+
- forward
50+
51+
## TFRegNetModel
52+
53+
[[autodoc]] TFRegNetModel
54+
- call
55+
56+
57+
## TFRegNetForImageClassification
58+
59+
[[autodoc]] TFRegNetForImageClassification
60+
- call

src/transformers/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,6 +2334,14 @@
23342334
"TFRagTokenForGeneration",
23352335
]
23362336
)
2337+
_import_structure["models.regnet"].extend(
2338+
[
2339+
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
2340+
"TFRegNetForImageClassification",
2341+
"TFRegNetModel",
2342+
"TFRegNetPreTrainedModel",
2343+
]
2344+
)
23372345
_import_structure["models.rembert"].extend(
23382346
[
23392347
"TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4649,6 +4657,12 @@
46494657
from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
46504658
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
46514659
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
4660+
from .models.regnet import (
4661+
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
4662+
TFRegNetForImageClassification,
4663+
TFRegNetModel,
4664+
TFRegNetPreTrainedModel,
4665+
)
46524666
from .models.rembert import (
46534667
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
46544668
TFRemBertForCausalLM,

src/transformers/modeling_tf_outputs.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,25 @@ class TFBaseModelOutput(ModelOutput):
4646
attentions: Optional[Tuple[tf.Tensor]] = None
4747

4848

49+
@dataclass
50+
class TFBaseModelOutputWithNoAttention(ModelOutput):
51+
"""
52+
Base class for model's outputs, with potential hidden states.
53+
54+
Args:
55+
last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
56+
Sequence of hidden-states at the output of the last layer of the model.
57+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
58+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
59+
one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
60+
61+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
62+
"""
63+
64+
last_hidden_state: tf.Tensor = None
65+
hidden_states: Optional[Tuple[tf.Tensor]] = None
66+
67+
4968
@dataclass
5069
class TFBaseModelOutputWithPooling(ModelOutput):
5170
"""
@@ -80,6 +99,28 @@ class TFBaseModelOutputWithPooling(ModelOutput):
8099
attentions: Optional[Tuple[tf.Tensor]] = None
81100

82101

102+
@dataclass
103+
class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
104+
"""
105+
Base class for model's outputs that also contains a pooling of the last hidden states.
106+
107+
Args:
108+
last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
109+
Sequence of hidden-states at the output of the last layer of the model.
110+
pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
111+
Last layer hidden-state after a pooling operation on the spatial dimensions.
112+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
113+
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
114+
the output of each layer) of shape `(batch_size, num_channels, height, width)`.
115+
116+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
117+
"""
118+
119+
last_hidden_state: tf.Tensor = None
120+
pooler_output: tf.Tensor = None
121+
hidden_states: Optional[Tuple[tf.Tensor]] = None
122+
123+
83124
@dataclass
84125
class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
85126
"""
@@ -825,3 +866,24 @@ class TFSequenceClassifierOutputWithPast(ModelOutput):
825866
past_key_values: Optional[List[tf.Tensor]] = None
826867
hidden_states: Optional[Tuple[tf.Tensor]] = None
827868
attentions: Optional[Tuple[tf.Tensor]] = None
869+
870+
871+
@dataclass
872+
class TFImageClassifierOutputWithNoAttention(ModelOutput):
873+
"""
874+
Base class for outputs of image classification models.
875+
876+
Args:
877+
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
878+
Classification (or regression if config.num_labels==1) loss.
879+
logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
880+
Classification (or regression if config.num_labels==1) scores (before SoftMax).
881+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
882+
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
883+
the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called
884+
feature maps) of the model at the output of each stage.
885+
"""
886+
887+
loss: Optional[tf.Tensor] = None
888+
logits: tf.Tensor = None
889+
hidden_states: Optional[Tuple[tf.Tensor]] = None

src/transformers/models/auto/modeling_tf_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
("openai-gpt", "TFOpenAIGPTModel"),
6363
("opt", "TFOPTModel"),
6464
("pegasus", "TFPegasusModel"),
65+
("regnet", "TFRegNetModel"),
6566
("rembert", "TFRemBertModel"),
6667
("roberta", "TFRobertaModel"),
6768
("roformer", "TFRoFormerModel"),
@@ -173,6 +174,7 @@
173174
# Model for Image-classsification
174175
("convnext", "TFConvNextForImageClassification"),
175176
("data2vec-vision", "TFData2VecVisionForImageClassification"),
177+
("regnet", "TFRegNetForImageClassification"),
176178
("swin", "TFSwinForImageClassification"),
177179
("vit", "TFViTForImageClassification"),
178180
]

src/transformers/models/regnet/__init__.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
from typing import TYPE_CHECKING
1919

2020
# rely on isort to merge the imports
21-
from ...file_utils import _LazyModule, is_torch_available
22-
from ...utils import OptionalDependencyNotAvailable
21+
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
2322

2423

2524
_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
@@ -37,6 +36,19 @@
3736
"RegNetPreTrainedModel",
3837
]
3938

39+
try:
40+
if not is_tf_available():
41+
raise OptionalDependencyNotAvailable()
42+
except OptionalDependencyNotAvailable:
43+
pass
44+
else:
45+
_import_structure["modeling_tf_regnet"] = [
46+
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
47+
"TFRegNetForImageClassification",
48+
"TFRegNetModel",
49+
"TFRegNetPreTrainedModel",
50+
]
51+
4052

4153
if TYPE_CHECKING:
4254
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
@@ -54,6 +66,19 @@
5466
RegNetPreTrainedModel,
5567
)
5668

69+
try:
70+
if not is_tf_available():
71+
raise OptionalDependencyNotAvailable()
72+
except OptionalDependencyNotAvailable:
73+
pass
74+
else:
75+
from .modeling_tf_regnet import (
76+
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
77+
TFRegNetForImageClassification,
78+
TFRegNetModel,
79+
TFRegNetPreTrainedModel,
80+
)
81+
5782

5883
else:
5984
import sys

0 commit comments

Comments
 (0)