diff --git a/ML/Pytorch/image_segmentation/semantic_segmentation_unet/model.py b/ML/Pytorch/image_segmentation/semantic_segmentation_unet/model.py index 257ae15c..328361b8 100644 --- a/ML/Pytorch/image_segmentation/semantic_segmentation_unet/model.py +++ b/ML/Pytorch/image_segmentation/semantic_segmentation_unet/model.py @@ -52,17 +52,15 @@ def forward(self, x): x = self.pool(x) x = self.bottleneck(x) - skip_connections = skip_connections[::-1] - for idx in range(0, len(self.ups), 2): - x = self.ups[idx](x) - skip_connection = skip_connections[idx//2] + for idx, skip_connection in enumerate(reversed(skip_connections)): + x = self.ups[2*idx](x) if x.shape != skip_connection.shape: x = TF.resize(x, size=skip_connection.shape[2:]) concat_skip = torch.cat((skip_connection, x), dim=1) - x = self.ups[idx+1](concat_skip) + x = self.ups[2*idx + 1](concat_skip) return self.final_conv(x)