1717import torch
1818import torch .nn as nn
1919from monai .networks .blocks import Convolution
20- from monai .networks .layers import Act
20+ from monai .networks .layers import Act , get_pool_layer
2121
2222
2323class MultiScalePatchDiscriminator (nn .Sequential ):
@@ -38,6 +38,8 @@ class MultiScalePatchDiscriminator(nn.Sequential):
3838 spatial_dims: number of spatial dimensions (1D, 2D etc.)
3939 num_channels: number of filters in the first convolutional layer (doubled in each subsequent layer)
4040 in_channels: number of input channels
41+ pooling_method: pooling method to be applied before each discriminator after the first.
42+ If None, ``num_layers_d`` for the i-th discriminator is instead multiplied by i (its 1-based index).
4143 out_channels: number of output channels in each discriminator
4244 kernel_size: kernel size of the convolution layers
4345 activation: activation layer type
@@ -52,10 +54,11 @@ class MultiScalePatchDiscriminator(nn.Sequential):
5254 def __init__ (
5355 self ,
5456 num_d : int ,
55- num_layers_d : int ,
57+ num_layers_d : int | list [ int ] ,
5658 spatial_dims : int ,
5759 num_channels : int ,
5860 in_channels : int ,
61+ pooling_method : str = None ,
5962 out_channels : int = 1 ,
6063 kernel_size : int = 4 ,
6164 activation : str | tuple = (Act .LEAKYRELU , {"negative_slope" : 0.2 }),
@@ -67,31 +70,67 @@ def __init__(
6770 ) -> None :
6871 super ().__init__ ()
6972 self .num_d = num_d
73+ if isinstance (num_layers_d , int ) and pooling_method is None :
74+ # if pooling_method is None, calculate the number of layers for each discriminator by multiplying by the number of discriminators
75+ num_layers_d = [num_layers_d * i for i in range (1 , num_d + 1 )]
76+ elif isinstance (num_layers_d , int ) and pooling_method is not None :
77+ # if pooling_method is not None, the number of layers is the same for all discriminators
78+ num_layers_d = [num_layers_d ] * num_d
7079 self .num_layers_d = num_layers_d
71- self .num_channels = num_channels
80+ assert (
81+ len (self .num_layers_d ) == self .num_d
82+ ), f"MultiScalePatchDiscriminator: num_d { num_d } must match the number of num_layers_d. { num_layers_d } "
83+
7284 self .padding = tuple ([int ((kernel_size - 1 ) / 2 )] * spatial_dims )
85+
86+ if pooling_method is None :
87+ pool = None
88+ else :
89+ pool = get_pool_layer (
90+ (pooling_method , {"kernel_size" : kernel_size , "stride" : 2 , 'padding' : self .padding }), spatial_dims = spatial_dims
91+ )
92+ self .num_channels = num_channels
7393 for i_ in range (self .num_d ):
74- num_layers_d_i = self .num_layers_d * ( i_ + 1 )
94+ num_layers_d_i = self .num_layers_d [ i_ ]
7595 output_size = float (minimum_size_im ) / (2 ** num_layers_d_i )
7696 if output_size < 1 :
7797 raise AssertionError (
7898 "Your image size is too small to take in up to %d discriminators with num_layers = %d."
7999 "Please reduce num_layers, reduce num_D or enter bigger images." % (i_ , num_layers_d_i )
80100 )
81- subnet_d = PatchDiscriminator (
82- spatial_dims = spatial_dims ,
83- num_channels = self .num_channels ,
84- in_channels = in_channels ,
85- out_channels = out_channels ,
86- num_layers_d = num_layers_d_i ,
87- kernel_size = kernel_size ,
88- activation = activation ,
89- norm = norm ,
90- bias = bias ,
91- padding = self .padding ,
92- dropout = dropout ,
93- last_conv_kernel_size = last_conv_kernel_size ,
94- )
101+ if i_ == 0 or pool is None :
102+ subnet_d = PatchDiscriminator (
103+ spatial_dims = spatial_dims ,
104+ num_channels = self .num_channels ,
105+ in_channels = in_channels ,
106+ out_channels = out_channels ,
107+ num_layers_d = num_layers_d_i ,
108+ kernel_size = kernel_size ,
109+ activation = activation ,
110+ norm = norm ,
111+ bias = bias ,
112+ padding = self .padding ,
113+ dropout = dropout ,
114+ last_conv_kernel_size = last_conv_kernel_size ,
115+ )
116+ else :
117+ subnet_d = nn .Sequential (
118+ * [pool ] * i_ ,
119+ PatchDiscriminator (
120+ spatial_dims = spatial_dims ,
121+ num_channels = self .num_channels ,
122+ in_channels = in_channels ,
123+ out_channels = out_channels ,
124+ num_layers_d = num_layers_d_i ,
125+ kernel_size = kernel_size ,
126+ activation = activation ,
127+ norm = norm ,
128+ bias = bias ,
129+ padding = self .padding ,
130+ dropout = dropout ,
131+ last_conv_kernel_size = last_conv_kernel_size ,
132+ ),
133+ )
95134
96135 self .add_module ("discriminator_%d" % i_ , subnet_d )
97136
0 commit comments