Skip to content

Commit 480d303

Browse files
authored
Update PAN Decoder support encoder depth (#999)
* update PAN model support encoder depth * update PAN decoder support encoder depth * add typing and fix ruff style * update PAN test sample size * del print * update decoder
1 parent 7c64aa5 commit 480d303

File tree

3 files changed

+61
-30
lines changed

3 files changed

+61
-30
lines changed

segmentation_models_pytorch/decoders/pan/decoder.py

+50-23
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from collections.abc import Sequence
2+
from typing import Literal
3+
14
import torch
25
import torch.nn as nn
36
import torch.nn.functional as F
@@ -44,7 +47,9 @@ def forward(self, x):
4447

4548

4649
class FPABlock(nn.Module):
47-
def __init__(self, in_channels, out_channels, upscale_mode="bilinear"):
50+
def __init__(
51+
self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
52+
):
4853
super(FPABlock, self).__init__()
4954

5055
self.upscale_mode = upscale_mode
@@ -175,34 +180,56 @@ def forward(self, x, y):
175180

176181
class PANDecoder(nn.Module):
177182
def __init__(
178-
self, encoder_channels, decoder_channels, upscale_mode: str = "bilinear"
183+
self,
184+
encoder_channels: Sequence[int],
185+
encoder_depth: Literal[3, 4, 5],
186+
decoder_channels: int,
187+
upscale_mode: str = "bilinear",
179188
):
180189
super().__init__()
181190

191+
if encoder_depth < 3:
192+
raise ValueError(
193+
"Encoder depth for PAN decoder cannot be less than 3, got {}.".format(
194+
encoder_depth
195+
)
196+
)
197+
198+
encoder_channels = encoder_channels[2:]
199+
182200
self.fpa = FPABlock(
183201
in_channels=encoder_channels[-1], out_channels=decoder_channels
184202
)
185-
self.gau3 = GAUBlock(
186-
in_channels=encoder_channels[-2],
187-
out_channels=decoder_channels,
188-
upscale_mode=upscale_mode,
189-
)
190-
self.gau2 = GAUBlock(
191-
in_channels=encoder_channels[-3],
192-
out_channels=decoder_channels,
193-
upscale_mode=upscale_mode,
194-
)
195-
self.gau1 = GAUBlock(
196-
in_channels=encoder_channels[-4],
197-
out_channels=decoder_channels,
198-
upscale_mode=upscale_mode,
199-
)
203+
204+
if encoder_depth == 5:
205+
self.gau3 = GAUBlock(
206+
in_channels=encoder_channels[2],
207+
out_channels=decoder_channels,
208+
upscale_mode=upscale_mode,
209+
)
210+
if encoder_depth >= 4:
211+
self.gau2 = GAUBlock(
212+
in_channels=encoder_channels[1],
213+
out_channels=decoder_channels,
214+
upscale_mode=upscale_mode,
215+
)
216+
if encoder_depth >= 3:
217+
self.gau1 = GAUBlock(
218+
in_channels=encoder_channels[0],
219+
out_channels=decoder_channels,
220+
upscale_mode=upscale_mode,
221+
)
200222

201223
def forward(self, *features):
202-
bottleneck = features[-1]
203-
x5 = self.fpa(bottleneck) # 1/32
204-
x4 = self.gau3(features[-2], x5) # 1/16
205-
x3 = self.gau2(features[-3], x4) # 1/8
206-
x2 = self.gau1(features[-4], x3) # 1/4
224+
features = features[2:] # remove first and second skip
225+
226+
out = self.fpa(features[-1]) # 1/16 or 1/32
227+
228+
if hasattr(self, "gau3"):
229+
out = self.gau3(features[2], out) # 1/16
230+
if hasattr(self, "gau2"):
231+
out = self.gau2(features[1], out) # 1/8
232+
if hasattr(self, "gau1"):
233+
out = self.gau1(features[0], out) # 1/4
207234

208-
return x2
235+
return out

segmentation_models_pytorch/decoders/pan/model.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Union
1+
from typing import Any, Callable, Literal, Optional, Union
22

33
from segmentation_models_pytorch.base import (
44
ClassificationHead,
@@ -20,6 +20,10 @@ class PAN(SegmentationModel):
2020
Args:
2121
encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
2222
to extract features of different spatial resolution
23+
encoder_depth: The number of stages used in the encoder, in range [3, 5]. Each stage generates features
24+
two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
25+
with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
26+
Default is 5
2327
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
2428
other pretrained weights (see table with available weights for each encoder_name)
2529
encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer.
@@ -52,12 +56,13 @@ class PAN(SegmentationModel):
5256
def __init__(
5357
self,
5458
encoder_name: str = "resnet34",
59+
encoder_depth: Literal[3, 4, 5] = 5,
5560
encoder_weights: Optional[str] = "imagenet",
56-
encoder_output_stride: int = 16,
61+
encoder_output_stride: Literal[16, 32] = 16,
5762
decoder_channels: int = 32,
5863
in_channels: int = 3,
5964
classes: int = 1,
60-
activation: Optional[Union[str, callable]] = None,
65+
activation: Optional[Union[str, Callable]] = None,
6166
upsampling: int = 4,
6267
aux_params: Optional[dict] = None,
6368
**kwargs: dict[str, Any],
@@ -74,14 +79,15 @@ def __init__(
7479
self.encoder = get_encoder(
7580
encoder_name,
7681
in_channels=in_channels,
77-
depth=5,
82+
depth=encoder_depth,
7883
weights=encoder_weights,
7984
output_stride=encoder_output_stride,
8085
**kwargs,
8186
)
8287

8388
self.decoder = PANDecoder(
8489
encoder_channels=self.encoder.out_channels,
90+
encoder_depth=encoder_depth,
8591
decoder_channels=decoder_channels,
8692
)
8793

tests/test_models.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ def get_sample(model_class):
3131
smp.Segformer,
3232
]:
3333
sample = torch.ones([1, 3, 64, 64])
34-
elif model_class == smp.PAN:
35-
sample = torch.ones([2, 3, 256, 256])
36-
elif model_class in [smp.DeepLabV3, smp.DeepLabV3Plus]:
34+
elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]:
3735
sample = torch.ones([2, 3, 128, 128])
3836
elif model_class in [smp.PSPNet, smp.UPerNet]:
3937
# Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input

0 commit comments

Comments
 (0)