import tensorflow as tf
from keras import layers, models
from model_blocks import *
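# NOTE (assumed interface): conv_block, reg_conv_block, gating_signal and
# attention_block come from model_blocks; the signatures below are inferred
# from how they are called in AW_Net, not taken from model_blocks itself:
#   conv_block(x, filter_size, n_filters, dropout_rate, block_id, batch_norm)
#   reg_conv_block(x, filter_size, n_filters, dropout_rate, block_id, batch_norm)
#   gating_signal(x, n_filters, batch_norm)
#   attention_block(skip_tensor, gating, n_filters)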


def AW_Net(input_shape, NUM_CLASSES=4, dropout_rate=0.0, batch_norm=True):
    '''
    AW-Net: attention-gated, W-shaped encoder-decoder network built from the
    blocks in model_blocks. Returns a Keras Model with NUM_CLASSES softmax
    output channels.
    '''
    # network structure
    FILTER_NUM = 16    # number of basic filters for the first layer
    FILTER_SIZE = 3    # size of the convolutional filter
    UP_SAMP_SIZE = 2   # size of upsampling filters

    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Downsampling layers
    # DownRes 1, convolution + pooling
    conv_128 = conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate, 1, batch_norm)
    pool_64 = layers.MaxPooling2D(pool_size=(2, 2))(conv_128)
    # DownRes 2
    conv_64 = conv_block(pool_64, FILTER_SIZE, 2 * FILTER_NUM, dropout_rate, 2, batch_norm)
    pool_32 = layers.MaxPooling2D(pool_size=(2, 2))(conv_64)
    # DownRes 3
    conv_32 = conv_block(pool_32, FILTER_SIZE, 4 * FILTER_NUM, dropout_rate, 3, batch_norm)
    pool_16 = layers.MaxPooling2D(pool_size=(2, 2))(conv_32)
    # DownRes 4
    conv_16 = conv_block(pool_16, FILTER_SIZE, 8 * FILTER_NUM, dropout_rate, 4, batch_norm)
    pool_8 = layers.MaxPooling2D(pool_size=(2, 2))(conv_16)
    # DownRes 5, convolution only
    conv_8 = reg_conv_block(pool_8, FILTER_SIZE, 16 * FILTER_NUM, dropout_rate, 5, batch_norm)

    # W-net layers
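    # The bottleneck output (conv_8) is upsampled once with an attention-gated
    # skip from conv_16, then pooled back down and fused with conv_8 before the
    # main decoder starts; this extra up/down excursion gives the network its W shape.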
35
+ gatingw_16 = gating_signal (conv_8 , 8 * FILTER_NUM , batch_norm )
36
+ attw_16 = attention_block (conv_16 , gatingw_16 , 8 * FILTER_NUM )
37
+ upw_16 = layers .UpSampling2D (size = (UP_SAMP_SIZE , UP_SAMP_SIZE ), data_format = "channels_last" )(conv_8 )
38
+ upw_16 = layers .concatenate ([upw_16 , attw_16 ], axis = 3 )
39
+ up_convw_16 = reg_conv_block (upw_16 , FILTER_SIZE , 8 * FILTER_NUM , dropout_rate ,6 , batch_norm )
40
+
41
+ poolw_8 = layers .MaxPooling2D (pool_size = (2 ,2 ))(up_convw_16 )
42
+ ct_16 = layers .concatenate ([conv_8 , poolw_8 ], axis = 3 )
43
+ convw_16 = reg_conv_block (ct_16 , FILTER_SIZE , 16 * FILTER_NUM , dropout_rate ,7 , batch_norm )
44
+
45
+ # UpRes 6, attention gated concatenation + upsampling + double residual convolution
46
+ gating_16 = gating_signal (convw_16 , 8 * FILTER_NUM , batch_norm )
47
+ att_16 = attention_block (up_convw_16 , gating_16 , 8 * FILTER_NUM )
48
+ up_16 = layers .UpSampling2D (size = (UP_SAMP_SIZE , UP_SAMP_SIZE ), data_format = "channels_last" )(convw_16 )
49
+ up_16 = layers .concatenate ([up_16 , att_16 ], axis = 3 )
50
+ up_conv_16 = conv_block (up_16 , FILTER_SIZE , 8 * FILTER_NUM , dropout_rate ,8 , batch_norm )
51
+ # UpRes 7
52
+ gating_32 = gating_signal (up_conv_16 , 4 * FILTER_NUM , batch_norm )
53
+ att_32 = attention_block (conv_32 , gating_32 , 4 * FILTER_NUM )
54
+ up_32 = layers .UpSampling2D (size = (UP_SAMP_SIZE , UP_SAMP_SIZE ), data_format = "channels_last" )(up_conv_16 )
55
+ up_32 = layers .concatenate ([up_32 , att_32 ], axis = 3 )
56
+ up_conv_32 = conv_block (up_32 , FILTER_SIZE , 4 * FILTER_NUM , dropout_rate ,9 , batch_norm )
57
+ # UpRes 8
58
+ gating_64 = gating_signal (up_conv_32 , 2 * FILTER_NUM , batch_norm )
59
+ att_64 = attention_block (conv_64 , gating_64 , 2 * FILTER_NUM )
60
+ up_64 = layers .UpSampling2D (size = (UP_SAMP_SIZE , UP_SAMP_SIZE ), data_format = "channels_last" )(up_conv_32 )
61
+ up_64 = layers .concatenate ([up_64 , att_64 ], axis = 3 )
62
+ up_conv_64 = conv_block (up_64 , FILTER_SIZE , 2 * FILTER_NUM , dropout_rate ,10 , batch_norm )
63
+ # UpRes 9
64
+ gating_128 = gating_signal (up_conv_64 , FILTER_NUM , batch_norm )
65
+ att_128 = attention_block (conv_128 , gating_128 , FILTER_NUM )
66
+ up_128 = layers .UpSampling2D (size = (UP_SAMP_SIZE , UP_SAMP_SIZE ), data_format = "channels_last" )(up_conv_64 )
67
+ up_128 = layers .concatenate ([up_128 , att_128 ], axis = 3 )
68
+ up_conv_128 = conv_block (up_128 , FILTER_SIZE , FILTER_NUM ,dropout_rate ,11 , batch_norm )
69
+
70
+ # 1*1 convolutional layers
71
+ conv_final = layers .Conv2D (NUM_CLASSES , name = 'conv12' , kernel_size = (1 ,1 ))(up_conv_128 )
72
+ conv_final = layers .BatchNormalization (axis = 3 )(conv_final )
73
+ conv_final = layers .Activation ('softmax' )(conv_final ) #Change to softmax for multichannel
74
+
75
+ # Model integration
76
+ model = models .Model (inputs , conv_final , name = "AW-Net" )
77
+ return model
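

if __name__ == "__main__":
    # Quick sanity check (illustrative only, not part of the original file):
    # 128x128 matches the layer-name suffixes above, while 3 input channels and
    # 4 classes are assumptions. Spatial dimensions should be divisible by 16 so
    # the pooled feature maps line up with the skip connections.
    model = AW_Net(input_shape=(128, 128, 3), NUM_CLASSES=4, dropout_rate=0.0, batch_norm=True)
    model.summary()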