Skip to content

Commit 7c64aa5

Browse files
authored
Fix encoder depth & output stride on DeeplabV3 & DeeplabV3+ (#991)
* fix encoder depth & output stride * fix ruff style * Revert "fix ruff style" This reverts commit 79d5568. * fix encoder depth & output stride * fix ruff style * update deeplabv3+ doc * restored aux_params
1 parent 589583e commit 7c64aa5

File tree

2 files changed

+44
-35
lines changed

2 files changed

+44
-35
lines changed

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

+14-27
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def __init__(
6161
nn.BatchNorm2d(out_channels),
6262
nn.ReLU(),
6363
)
64-
self.out_channels = out_channels
6564

6665
def forward(self, *features):
6766
return super().forward(features[-1])
@@ -79,17 +78,12 @@ def __init__(
7978
aspp_dropout: float,
8079
):
8180
super().__init__()
82-
if encoder_depth not in (3, 4, 5):
81+
if encoder_depth < 3:
8382
raise ValueError(
84-
"Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth)
83+
"Encoder depth for DeepLabV3Plus decoder cannot be less than 3, got {}.".format(
84+
encoder_depth
85+
)
8586
)
86-
if output_stride not in (8, 16):
87-
raise ValueError(
88-
"Output stride should be 8 or 16, got {}.".format(output_stride)
89-
)
90-
91-
self.out_channels = out_channels
92-
self.output_stride = output_stride
9387

9488
self.aspp = nn.Sequential(
9589
ASPP(
@@ -106,17 +100,10 @@ def __init__(
106100
nn.ReLU(),
107101
)
108102

109-
scale_factor = 2 if output_stride == 8 else 4
103+
scale_factor = 4 if output_stride == 16 and encoder_depth > 3 else 2
110104
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
111105

112-
if encoder_depth == 3 and output_stride == 8:
113-
self.highres_input_index = -2
114-
elif encoder_depth == 3 or encoder_depth == 4:
115-
self.highres_input_index = -3
116-
else:
117-
self.highres_input_index = -4
118-
119-
highres_in_channels = encoder_channels[self.highres_input_index]
106+
highres_in_channels = encoder_channels[2]
120107
highres_out_channels = 48 # proposed by authors of paper
121108
self.block1 = nn.Sequential(
122109
nn.Conv2d(
@@ -140,7 +127,7 @@ def __init__(
140127
def forward(self, *features):
141128
aspp_features = self.aspp(features[-1])
142129
aspp_features = self.up(aspp_features)
143-
high_res_features = self.block1(features[self.highres_input_index])
130+
high_res_features = self.block1(features[2])
144131
concat_features = torch.cat([aspp_features, high_res_features], dim=1)
145132
fused_features = self.block2(concat_features)
146133
return fused_features
@@ -240,13 +227,13 @@ def forward(self, x):
240227
class SeparableConv2d(nn.Sequential):
241228
def __init__(
242229
self,
243-
in_channels,
244-
out_channels,
245-
kernel_size,
246-
stride=1,
247-
padding=0,
248-
dilation=1,
249-
bias=True,
230+
in_channels: int,
231+
out_channels: int,
232+
kernel_size: int,
233+
stride: int = 1,
234+
padding: int = 0,
235+
dilation: int = 1,
236+
bias: bool = True,
250237
):
251238
dephtwise_conv = nn.Conv2d(
252239
in_channels,

segmentation_models_pytorch/decoders/deeplabv3/model.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,16 @@ class DeepLabV3(SegmentationModel):
3535
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
3636
**callable** and **None**.
3737
Default is **None**
38-
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
38+
upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity
3939
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
4040
on top of encoder if **aux_params** is not **None** (default). Supported params:
4141
- classes (int): A number of classes
4242
- pooling (str): One of "max", "avg". Default is "avg"
4343
- dropout (float): Dropout factor in [0, 1)
4444
- activation (str): An activation function to apply "sigmoid"/"softmax"
4545
(could be **None** to return logits)
46-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
46+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
47+
Keys with ``None`` values are pruned before passing.
4748
4849
Returns:
4950
``torch.nn.Module``: **DeepLabV3**
@@ -72,6 +73,12 @@ def __init__(
7273
):
7374
super().__init__()
7475

76+
if encoder_output_stride not in [8, 16]:
77+
raise ValueError(
78+
"DeeplabV3 support output stride 8 or 16, got {}.".format(
79+
encoder_output_stride
80+
)
81+
)
7582
self.encoder = get_encoder(
7683
encoder_name,
7784
in_channels=in_channels,
@@ -81,6 +88,14 @@ def __init__(
8188
**kwargs,
8289
)
8390

91+
if upsampling is None:
92+
if encoder_depth <= 3:
93+
scale_factor = 2**encoder_depth
94+
else:
95+
scale_factor = encoder_output_stride
96+
else:
97+
scale_factor = upsampling
98+
8499
self.decoder = DeepLabV3Decoder(
85100
in_channels=self.encoder.out_channels[-1],
86101
out_channels=decoder_channels,
@@ -90,11 +105,11 @@ def __init__(
90105
)
91106

92107
self.segmentation_head = SegmentationHead(
93-
in_channels=self.decoder.out_channels,
108+
in_channels=decoder_channels,
94109
out_channels=classes,
95110
activation=activation,
96111
kernel_size=1,
97-
upsampling=encoder_output_stride if upsampling is None else upsampling,
112+
upsampling=scale_factor,
98113
)
99114

100115
if aux_params is not None:
@@ -129,16 +144,16 @@ class DeepLabV3Plus(SegmentationModel):
129144
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
130145
**callable** and **None**.
131146
Default is **None**
132-
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. In case
133-
**encoder_depth** and **encoder_output_stride** are 3 and 16 resp., set **upsampling** to 2 to preserve.
147+
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity.
134148
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
135149
on top of encoder if **aux_params** is not **None** (default). Supported params:
136150
- classes (int): A number of classes
137151
- pooling (str): One of "max", "avg". Default is "avg"
138152
- dropout (float): Dropout factor in [0, 1)
139153
- activation (str): An activation function to apply "sigmoid"/"softmax"
140154
(could be **None** to return logits)
141-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
155+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
156+
Keys with ``None`` values are pruned before passing.
142157
143158
Returns:
144159
``torch.nn.Module``: **DeepLabV3Plus**
@@ -167,6 +182,13 @@ def __init__(
167182
):
168183
super().__init__()
169184

185+
if encoder_output_stride not in [8, 16]:
186+
raise ValueError(
187+
"DeeplabV3Plus support output stride 8 or 16, got {}.".format(
188+
encoder_output_stride
189+
)
190+
)
191+
170192
self.encoder = get_encoder(
171193
encoder_name,
172194
in_channels=in_channels,
@@ -187,7 +209,7 @@ def __init__(
187209
)
188210

189211
self.segmentation_head = SegmentationHead(
190-
in_channels=self.decoder.out_channels,
212+
in_channels=decoder_channels,
191213
out_channels=classes,
192214
activation=activation,
193215
kernel_size=1,

0 commit comments

Comments
 (0)