import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
+from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
    LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg

__all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this

+class Downsample(nn.Module):
+
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+        super().__init__()
+        avg_stride = stride if dilation == 1 else 1
+        if stride > 1 or dilation > 1:
+            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+        else:
+            self.pool = nn.Identity()
+
+        if in_chs != out_chs:
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+        else:
+            self.conv = nn.Identity()
+
+    def forward(self, x):
+        x = self.pool(x)
+        x = self.conv(x)
+        return x
+
+
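The added `Downsample` module gives the block a shape-matching shortcut: average pooling absorbs any spatial stride, and a 1x1 convolution absorbs channel changes. A rough standalone sketch of the same idea, using plain `torch.nn` rather than the timm helpers (the class name and sizes below are illustrative only):

import torch
import torch.nn as nn

class ShortcutDownsample(nn.Module):
    # Illustrative only: avg-pool handles stride, 1x1 conv handles channel change.
    def __init__(self, in_chs, out_chs, stride=1):
        super().__init__()
        self.pool = nn.AvgPool2d(2, stride, ceil_mode=True, count_include_pad=False) if stride > 1 else nn.Identity()
        self.conv = nn.Conv2d(in_chs, out_chs, 1) if in_chs != out_chs else nn.Identity()

    def forward(self, x):
        return self.conv(self.pool(x))

x = torch.randn(1, 96, 56, 56)
print(ShortcutDownsample(96, 192, stride=2)(x).shape)  # expected: torch.Size([1, 192, 28, 28])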
class ConvNeXtBlock(nn.Module):
    """ ConvNeXt Block
    There are two equivalent implementations:
@@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
    Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
    choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
    is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
-
-    Args:
-        in_chs (int): Number of input channels.
-        drop_path (float): Stochastic depth rate. Default: 0.0
-        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(
            self,
-            in_chs,
-            out_chs=None,
-            kernel_size=7,
-            stride=1,
-            dilation=1,
-            mlp_ratio=4,
-            conv_mlp=False,
-            conv_bias=True,
-            use_grn=False,
-            ls_init_value=1e-6,
-            act_layer='gelu',
-            norm_layer=None,
-            drop_path=0.,
+            in_chs: int,
+            out_chs: Optional[int] = None,
+            kernel_size: int = 7,
+            stride: int = 1,
+            dilation: Union[int, Tuple[int, int]] = (1, 1),
+            mlp_ratio: float = 4,
+            conv_mlp: bool = False,
+            conv_bias: bool = True,
+            use_grn: bool = False,
+            ls_init_value: Optional[float] = 1e-6,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: Optional[Callable] = None,
+            drop_path: float = 0.,
    ):
+        """
+
+        Args:
+            in_chs: Block input channels.
+            out_chs: Block output channels (same as in_chs if None).
+            kernel_size: Depthwise convolution kernel size.
+            stride: Stride of depthwise convolution.
+            dilation: Tuple specifying input and output dilation of block.
+            mlp_ratio: MLP expansion ratio.
+            conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
+            conv_bias: Apply bias for all convolution (linear) layers.
+            use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
+            ls_init_value: Layer-scale init values, layer-scale applied if not None.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer (defaults to LN if not specified).
+            drop_path: Stochastic depth probability.
+        """
        super().__init__()
        out_chs = out_chs or in_chs
+        dilation = to_ntuple(2)(dilation)
        act_layer = get_act_layer(act_layer)
        if not norm_layer:
            norm_layer = LayerNorm2d if conv_mlp else LayerNorm
        mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
        self.use_conv_mlp = conv_mlp
        self.conv_dw = create_conv2d(
-            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+            in_chs,
+            out_chs,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation[0],
+            depthwise=True,
+            bias=conv_bias,
+        )
        self.norm = norm_layer(out_chs)
        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+        else:
+            self.shortcut = nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
@@ -116,7 +162,7 @@ def forward(self, x):
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))

-        x = self.drop_path(x) + shortcut
+        x = self.drop_path(x) + self.shortcut(shortcut)
        return x

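With the residual now routed through `self.shortcut`, blocks that change resolution or width still add their input back; when shapes already match, the shortcut is `nn.Identity()` and the forward pass behaves as before. A minimal sketch of that residual add, with stand-ins for the main path and the projected shortcut (layer choices and sizes below are illustrative, not the timm internals):

import torch
import torch.nn as nn

main = nn.Conv2d(96, 192, 7, stride=2, padding=3)   # stand-in for conv_dw -> norm -> mlp
shortcut = nn.Sequential(                           # stand-in for Downsample(96, 192, stride=2)
    nn.AvgPool2d(2, 2, ceil_mode=True, count_include_pad=False),
    nn.Conv2d(96, 192, 1),
)
x = torch.randn(1, 96, 56, 56)
out = main(x) + shortcut(x)                         # residual add stays shape-consistent
print(out.shape)                                    # expected: torch.Size([1, 192, 28, 28])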
@@ -148,8 +194,14 @@ def __init__(
            self.downsample = nn.Sequential(
                norm_layer(in_chs),
                create_conv2d(
-                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
-                    dilation=dilation[0], padding=pad, bias=conv_bias),
+                    in_chs,
+                    out_chs,
+                    kernel_size=ds_ks,
+                    stride=stride,
+                    dilation=dilation[0],
+                    padding=pad,
+                    bias=conv_bias,
+                ),
            )
            in_chs = out_chs
        else: