|
| 1 | +from collections.abc import Sequence |
| 2 | +from typing import Literal |
| 3 | + |
1 | 4 | import torch
|
2 | 5 | import torch.nn as nn
|
3 | 6 | import torch.nn.functional as F
|
@@ -44,7 +47,9 @@ def forward(self, x):
|
44 | 47 |
|
45 | 48 |
|
46 | 49 | class FPABlock(nn.Module):
|
47 |
| - def __init__(self, in_channels, out_channels, upscale_mode="bilinear"): |
| 50 | + def __init__( |
| 51 | + self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear" |
| 52 | + ): |
48 | 53 | super(FPABlock, self).__init__()
|
49 | 54 |
|
50 | 55 | self.upscale_mode = upscale_mode
|
@@ -175,34 +180,56 @@ def forward(self, x, y):
|
175 | 180 |
|
class PANDecoder(nn.Module):
    """Decoder that chains one ``FPABlock`` with up to three ``GAUBlock`` stages.

    The first two entries of ``encoder_channels`` (and, in ``forward``, the
    first two feature maps) are discarded. Of what remains, the last entry
    feeds the ``FPABlock`` and the earlier entries feed ``GAUBlock`` stages,
    which are applied from the deepest level to the shallowest.

    Args:
        encoder_channels: Channel counts of the encoder feature maps, in the
            same order as the feature maps later passed to ``forward``.
            Presumably ``encoder_depth + 1`` entries (input included) --
            TODO confirm against the encoder implementation.
        encoder_depth: Number of encoder stages; must be 3, 4 or 5.
        decoder_channels: Output channel count of the FPA block and of every
            GAU block.
        upscale_mode: Forwarded to every ``GAUBlock``.

    Raises:
        ValueError: If ``encoder_depth`` is less than 3 or greater than 5.
    """

    def __init__(
        self,
        encoder_channels: Sequence[int],
        encoder_depth: Literal[3, 4, 5],
        decoder_channels: int,
        upscale_mode: str = "bilinear",
    ):
        super().__init__()

        if encoder_depth < 3:
            raise ValueError(
                f"Encoder depth for PAN decoder cannot be less than 3, got {encoder_depth}."
            )
        # Only three GAU stages exist. A deeper encoder would leave levels
        # between the FPA input (last feature) and the GAU inputs unused, so
        # reject it up front instead of silently skipping them.
        if encoder_depth > 5:
            raise ValueError(
                f"Encoder depth for PAN decoder cannot be greater than 5, got {encoder_depth}."
            )

        # Drop the first two entries; indices below are relative to the rest.
        encoder_channels = encoder_channels[2:]

        # NOTE(review): ``upscale_mode`` is not forwarded to FPABlock, which
        # therefore uses its own default -- confirm this is intentional.
        self.fpa = FPABlock(
            in_channels=encoder_channels[-1], out_channels=decoder_channels
        )

        # The attribute names gau1/gau2/gau3 are kept stable; a stage is only
        # created when the encoder is deep enough to provide its input.
        if encoder_depth == 5:
            self.gau3 = GAUBlock(
                in_channels=encoder_channels[2],
                out_channels=decoder_channels,
                upscale_mode=upscale_mode,
            )
        if encoder_depth >= 4:
            self.gau2 = GAUBlock(
                in_channels=encoder_channels[1],
                out_channels=decoder_channels,
                upscale_mode=upscale_mode,
            )
        # encoder_depth >= 3 is guaranteed by the validation above.
        self.gau1 = GAUBlock(
            in_channels=encoder_channels[0],
            out_channels=decoder_channels,
            upscale_mode=upscale_mode,
        )

    def forward(self, *features):
        """Run the FPA block on the last feature map, then the available GAU stages.

        Args:
            *features: Encoder feature maps, in the same order as the
                ``encoder_channels`` given to the constructor.

        Returns:
            The output of ``gau1``.
        """
        features = features[2:]  # remove first and second skip

        out = self.fpa(features[-1])  # 1/16 or 1/32

        # gau3 / gau2 exist only for deeper encoders (see __init__).
        if hasattr(self, "gau3"):
            out = self.gau3(features[2], out)  # 1/16
        if hasattr(self, "gau2"):
            out = self.gau2(features[1], out)  # 1/8
        out = self.gau1(features[0], out)  # 1/4

        return out
0 commit comments