
Commit cd3ee78

Merge pull request #1715 from huggingface/convnext_shortcut
Add support to ConvNextBlock for downsample and ch expansion. Fix #1699
2 parents: aa8c070 + ad94d73

timm/models/convnext.py

Lines changed: 75 additions & 23 deletions
@@ -45,7 +45,7 @@
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
+from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
@@ -56,6 +56,28 @@
 __all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this


+class Downsample(nn.Module):
+
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+        super().__init__()
+        avg_stride = stride if dilation == 1 else 1
+        if stride > 1 or dilation > 1:
+            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+        else:
+            self.pool = nn.Identity()
+
+        if in_chs != out_chs:
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+        else:
+            self.conv = nn.Identity()
+
+    def forward(self, x):
+        x = self.pool(x)
+        x = self.conv(x)
+        return x
+
+
 class ConvNeXtBlock(nn.Module):
     """ ConvNeXt Block
     There are two equivalent implementations:
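
As a quick orientation (not part of the diff): a minimal sketch of how the new Downsample shortcut behaves, assuming a timm build that already includes this commit.

import torch
from timm.models.convnext import Downsample

# stride-2 shortcut with channel expansion: a 2x2 average pool halves the
# spatial dims, then a 1x1 conv projects 64 -> 128 channels
ds = Downsample(64, 128, stride=2, dilation=1)
print(ds(torch.randn(1, 64, 56, 56)).shape)   # torch.Size([1, 128, 28, 28])

# matching channels and stride 1: both pool and conv stay nn.Identity
ds_id = Downsample(32, 32)
print(ds_id(torch.randn(1, 32, 14, 14)).shape)  # torch.Size([1, 32, 14, 14])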
@@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
     Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
     choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
     is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
-
-    Args:
-        in_chs (int): Number of input channels.
-        drop_path (float): Stochastic depth rate. Default: 0.0
-        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
     """

     def __init__(
             self,
-            in_chs,
-            out_chs=None,
-            kernel_size=7,
-            stride=1,
-            dilation=1,
-            mlp_ratio=4,
-            conv_mlp=False,
-            conv_bias=True,
-            use_grn=False,
-            ls_init_value=1e-6,
-            act_layer='gelu',
-            norm_layer=None,
-            drop_path=0.,
+            in_chs: int,
+            out_chs: Optional[int] = None,
+            kernel_size: int = 7,
+            stride: int = 1,
+            dilation: Union[int, Tuple[int, int]] = (1, 1),
+            mlp_ratio: float = 4,
+            conv_mlp: bool = False,
+            conv_bias: bool = True,
+            use_grn: bool = False,
+            ls_init_value: Optional[float] = 1e-6,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: Optional[Callable] = None,
+            drop_path: float = 0.,
     ):
+        """
+
+        Args:
+            in_chs: Block input channels.
+            out_chs: Block output channels (same as in_chs if None).
+            kernel_size: Depthwise convolution kernel size.
+            stride: Stride of depthwise convolution.
+            dilation: Tuple specifying input and output dilation of block.
+            mlp_ratio: MLP expansion ratio.
+            conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
+            conv_bias: Apply bias for all convolution (linear) layers.
+            use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
+            ls_init_value: Layer-scale init values, layer-scale applied if not None.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer (defaults to LN if not specified).
+            drop_path: Stochastic depth probability.
+        """
         super().__init__()
         out_chs = out_chs or in_chs
+        dilation = to_ntuple(2)(dilation)
         act_layer = get_act_layer(act_layer)
         if not norm_layer:
             norm_layer = LayerNorm2d if conv_mlp else LayerNorm
         mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
         self.use_conv_mlp = conv_mlp
         self.conv_dw = create_conv2d(
-            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+            in_chs,
+            out_chs,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation[0],
+            depthwise=True,
+            bias=conv_bias,
+        )
         self.norm = norm_layer(out_chs)
         self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+        else:
+            self.shortcut = nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
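
A hedged usage sketch of the widened constructor above (again assuming a timm version that contains this change): when out_chs, stride, or the dilation pair differ, the block builds a Downsample shortcut, otherwise an identity.

import torch
from timm.models.convnext import ConvNeXtBlock

# channel expansion plus spatial downsample in a single block now works
blk = ConvNeXtBlock(in_chs=64, out_chs=128, stride=2)
print(type(blk.shortcut).__name__)             # Downsample
print(blk(torch.randn(1, 64, 56, 56)).shape)   # torch.Size([1, 128, 28, 28])

# default case keeps the original behaviour: identity shortcut
blk_id = ConvNeXtBlock(in_chs=64)
print(type(blk_id.shortcut).__name__)          # Identity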
@@ -116,7 +162,7 @@ def forward(self, x):
         if self.gamma is not None:
             x = x.mul(self.gamma.reshape(1, -1, 1, 1))

-        x = self.drop_path(x) + shortcut
+        x = self.drop_path(x) + self.shortcut(shortcut)
         return x

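
In effect, the residual path no longer requires matching shapes; it projects the input when needed. A condensed, illustrative restatement (the helper name residual_add is hypothetical, for exposition only):

# before: y = drop_path(f(x)) + x              # only valid when shapes match
# after:  y = drop_path(f(x)) + shortcut(x)    # shortcut is Identity or Downsample
def residual_add(f, shortcut, drop_path, x):
    return drop_path(f(x)) + shortcut(x)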
@@ -148,8 +194,14 @@ def __init__(
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
                 create_conv2d(
-                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
-                    dilation=dilation[0], padding=pad, bias=conv_bias),
+                    in_chs,
+                    out_chs,
+                    kernel_size=ds_ks,
+                    stride=stride,
+                    dilation=dilation[0],
+                    padding=pad,
+                    bias=conv_bias,
+                ),
             )
             in_chs = out_chs
         else:
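
The stage-level downsample hunk above only reformats the create_conv2d call; its behaviour is unchanged. As a sanity sketch (assuming this commit is installed), a stock ConvNeXt should still build and run as before:

import timm
import torch

model = timm.create_model('convnext_tiny', pretrained=False)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])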
