-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel.py
77 lines (68 loc) · 3.99 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import tensorflow as tf
from keras.models import Model
from keras.layers import Input
from model_blocks import*
def AW_Net(input_shape, NUM_CLASSES=4, dropout_rate=0.0, batch_norm=True,
           filter_num=16, filter_size=3, up_samp_size=2):
    '''
    Build the AW-Net: an attention-gated, W-shaped U-Net variant.

    Structure: a 4-level encoder (conv + max-pool), a bottleneck, an extra
    down/up "W" branch around the bottleneck (the *w_* layers below), then a
    4-level decoder where every skip connection passes through an attention
    gate before concatenation. Output is a per-pixel softmax.

    Args:
        input_shape: spatial input shape, e.g. (128, 128, 3). The _128/_64/...
            suffixes on local names assume a 128x128 input but the graph
            itself only requires dimensions divisible by 16.
        NUM_CLASSES: number of output channels for the final softmax.
        dropout_rate: dropout rate forwarded to every conv block (0 disables).
        batch_norm: whether conv blocks apply batch normalization.
        filter_num: filters at the first encoder level; each deeper level
            doubles it (previously hard-coded as FILTER_NUM = 16).
        filter_size: convolution kernel size (previously FILTER_SIZE = 3).
        up_samp_size: decoder upsampling factor (previously UP_SAMP_SIZE = 2).

    Returns:
        A keras Model named "AW-Net" mapping the input tensor to a
        (H, W, NUM_CLASSES) softmax map.
    '''
    # Keep the original local aliases so the body reads like the paper/config.
    FILTER_NUM = filter_num
    FILTER_SIZE = filter_size
    UP_SAMP_SIZE = up_samp_size
    # Use the explicitly imported Input/Model rather than relying on `layers`
    # and `models` leaking in through `from model_blocks import *`.
    inputs = Input(input_shape, dtype=tf.float32)
    # Downsampling layers
    # DownRes 1, convolution + pooling
    conv_128 = conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate, 1, batch_norm)
    pool_64 = layers.MaxPooling2D(pool_size=(2, 2))(conv_128)
    # DownRes 2
    conv_64 = conv_block(pool_64, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, 2, batch_norm)
    pool_32 = layers.MaxPooling2D(pool_size=(2, 2))(conv_64)
    # DownRes 3
    conv_32 = conv_block(pool_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, 3, batch_norm)
    pool_16 = layers.MaxPooling2D(pool_size=(2, 2))(conv_32)
    # DownRes 4
    conv_16 = conv_block(pool_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, 4, batch_norm)
    pool_8 = layers.MaxPooling2D(pool_size=(2, 2))(conv_16)
    # DownRes 5, bottleneck: convolution only, no pooling
    conv_8 = reg_conv_block(pool_8, FILTER_SIZE, 16*FILTER_NUM, dropout_rate, 5, batch_norm)
    # W-net branch: one attention-gated up step followed by a pooled re-merge
    # with the bottleneck, giving the network its "W" shape.
    gatingw_16 = gating_signal(conv_8, 8*FILTER_NUM, batch_norm)
    attw_16 = attention_block(conv_16, gatingw_16, 8*FILTER_NUM)
    upw_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(conv_8)
    upw_16 = layers.concatenate([upw_16, attw_16], axis=3)
    up_convw_16 = reg_conv_block(upw_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, 6, batch_norm)
    poolw_8 = layers.MaxPooling2D(pool_size=(2, 2))(up_convw_16)
    ct_16 = layers.concatenate([conv_8, poolw_8], axis=3)
    convw_16 = reg_conv_block(ct_16, FILTER_SIZE, 16*FILTER_NUM, dropout_rate, 7, batch_norm)
    # UpRes 6, attention gated concatenation + upsampling + double residual convolution
    gating_16 = gating_signal(convw_16, 8*FILTER_NUM, batch_norm)
    att_16 = attention_block(up_convw_16, gating_16, 8*FILTER_NUM)
    up_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(convw_16)
    up_16 = layers.concatenate([up_16, att_16], axis=3)
    up_conv_16 = conv_block(up_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, 8, batch_norm)
    # UpRes 7
    gating_32 = gating_signal(up_conv_16, 4*FILTER_NUM, batch_norm)
    att_32 = attention_block(conv_32, gating_32, 4*FILTER_NUM)
    up_32 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_16)
    up_32 = layers.concatenate([up_32, att_32], axis=3)
    up_conv_32 = conv_block(up_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, 9, batch_norm)
    # UpRes 8
    gating_64 = gating_signal(up_conv_32, 2*FILTER_NUM, batch_norm)
    att_64 = attention_block(conv_64, gating_64, 2*FILTER_NUM)
    up_64 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_32)
    up_64 = layers.concatenate([up_64, att_64], axis=3)
    up_conv_64 = conv_block(up_64, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, 10, batch_norm)
    # UpRes 9, back to full resolution
    gating_128 = gating_signal(up_conv_64, FILTER_NUM, batch_norm)
    att_128 = attention_block(conv_128, gating_128, FILTER_NUM)
    up_128 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_64)
    up_128 = layers.concatenate([up_128, att_128], axis=3)
    up_conv_128 = conv_block(up_128, FILTER_SIZE, FILTER_NUM, dropout_rate, 11, batch_norm)
    # 1x1 convolution to NUM_CLASSES channels, then BN and softmax
    # (softmax rather than sigmoid: multi-class, one channel per class).
    conv_final = layers.Conv2D(NUM_CLASSES, name='conv12', kernel_size=(1, 1))(up_conv_128)
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation('softmax')(conv_final)  # Change to softmax for multichannel
    # Model integration
    model = Model(inputs, conv_final, name="AW-Net")
    return model