
Commit 02b3d17

Add files via upload
1 parent bab6ee8 commit 02b3d17

File tree

4 files changed: +247 -0 lines changed


dataloader.py

Lines changed: 31 additions & 0 deletions
import keras
import numpy as np


class SpinePTXT(keras.utils.Sequence):
    """Helper to iterate over the data (as NumPy arrays)."""

    def __init__(self, batch_size, img_size, input_img_paths, mask_img_paths, num_classes):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.mask_img_paths = mask_img_paths
        self.num_classes = num_classes

    def __len__(self):
        return len(self.mask_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Returns the (input, target) tuple corresponding to batch #idx."""
        i = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]
        batch_mask_img_paths = self.mask_img_paths[i : i + self.batch_size]
        # Inputs: (batch, H, W, 3) float32 arrays loaded from .npy files
        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            img = np.load(path)
            x[j] = img
        # Targets: one-hot masks of shape (batch, H, W, num_classes)
        y = np.zeros((self.batch_size,) + self.img_size + (self.num_classes,), dtype="uint8")
        for j, path in enumerate(batch_mask_img_paths):
            msk = np.load(path)
            y[j] = msk
        return x, y
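
A minimal smoke test for the Sequence, assuming the per-slice .npy layout that np.load implies and the E:\BraTS_data directories used in main.py; the glob patterns are illustrative, not part of the commit:

# Hypothetical smoke test (paths and extension are assumptions).
import glob
from dataloader import SpinePTXT

img_paths = sorted(glob.glob("E:\\BraTS_data\\Image\\*.npy"))  # assumed extension
msk_paths = sorted(glob.glob("E:\\BraTS_data\\Mask\\*.npy"))   # assumed extension
gen = SpinePTXT(batch_size=8, img_size=(128, 128),
                input_img_paths=img_paths, mask_img_paths=msk_paths,
                num_classes=4)
x, y = gen[0]
print(x.shape, y.shape)  # expected: (8, 128, 128, 3) (8, 128, 128, 4)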

main.py

Lines changed: 53 additions & 0 deletions
from model import *
from dataloader import *
import os
import random


def main(input_dir, mask_dir, image_height, image_width, image_channel, num_classes, batch_size, epochs, val_samples):

    img_size = (image_height, image_width)
    input_img_paths = sorted(
        [
            os.path.join(input_dir, fname)
            for fname in os.listdir(input_dir)
        ]
    )
    mask_img_paths = sorted(
        [
            os.path.join(mask_dir, fname)
            for fname in os.listdir(mask_dir)
        ]
    )
    # Shuffle both lists with the same seed so image/mask pairs stay aligned
    random.Random(1337).shuffle(input_img_paths)
    random.Random(1337).shuffle(mask_img_paths)
    train_input_img_paths = input_img_paths[:-val_samples]
    train_mask_img_paths = mask_img_paths[:-val_samples]
    val_input_img_paths = input_img_paths[-val_samples:]
    val_mask_img_paths = mask_img_paths[-val_samples:]

    # Instantiate data Sequences for each split
    train_gen = SpinePTXT(
        batch_size, img_size, train_input_img_paths, train_mask_img_paths, num_classes
    )
    val_gen = SpinePTXT(batch_size, img_size, val_input_img_paths, val_mask_img_paths, num_classes)

    model = AW_Net((image_height, image_width, image_channel), num_classes, dropout_rate=0.0, batch_norm=True)

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    history = model.fit(train_gen, validation_data=val_gen, epochs=epochs)


if __name__ == "__main__":
    input_dir = "E:\\BraTS_data\\Image"
    mask_dir = "E:\\BraTS_data\\Mask"
    image_height = 128
    image_width = 128
    image_channel = 3
    num_classes = 4
    batch_size = 8
    epochs = 200
    val_samples = 250
    main(input_dir, mask_dir, image_height, image_width, image_channel, num_classes, batch_size, epochs, val_samples)
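
The script trains for 200 epochs but never saves weights. One possible extension is checkpointing with standard Keras callbacks; a sketch under that assumption (the filename aw_net_best.h5 is a placeholder, not from this commit):

# Hypothetical addition: keep the best weights by validation loss.
from keras.callbacks import ModelCheckpoint, EarlyStopping

callbacks = [
    ModelCheckpoint("aw_net_best.h5", monitor="val_loss", save_best_only=True),
    EarlyStopping(monitor="val_loss", patience=20, restore_best_weights=True),
]
history = model.fit(train_gen, validation_data=val_gen, epochs=epochs,
                    callbacks=callbacks)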

model.py

Lines changed: 77 additions & 0 deletions
import tensorflow as tf
from tensorflow.keras import models, layers
from model_blocks import *


def AW_Net(input_shape, NUM_CLASSES=4, dropout_rate=0.0, batch_norm=True):
    '''
    AW-Net: an attention-gated U-Net with an extra W-shaped
    encoder/decoder pass between the bottleneck and the main decoder.
    '''
    # network structure
    FILTER_NUM = 16    # number of basic filters for the first layer
    FILTER_SIZE = 3    # size of the convolutional filter
    UP_SAMP_SIZE = 2   # size of upsampling filters

    inputs = layers.Input(input_shape, dtype=tf.float32)

    # Downsampling layers
    # DownRes 1, convolution + pooling
    conv_128 = conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate, 1, batch_norm)
    pool_64 = layers.MaxPooling2D(pool_size=(2, 2))(conv_128)
    # DownRes 2
    conv_64 = conv_block(pool_64, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, 2, batch_norm)
    pool_32 = layers.MaxPooling2D(pool_size=(2, 2))(conv_64)
    # DownRes 3
    conv_32 = conv_block(pool_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, 3, batch_norm)
    pool_16 = layers.MaxPooling2D(pool_size=(2, 2))(conv_32)
    # DownRes 4
    conv_16 = conv_block(pool_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, 4, batch_norm)
    pool_8 = layers.MaxPooling2D(pool_size=(2, 2))(conv_16)
    # DownRes 5, convolution only (bottleneck)
    conv_8 = reg_conv_block(pool_8, FILTER_SIZE, 16*FILTER_NUM, dropout_rate, 5, batch_norm)

    # W-net layers: one extra up/down pass before the main decoder
    gatingw_16 = gating_signal(conv_8, 8*FILTER_NUM, batch_norm)
    attw_16 = attention_block(conv_16, gatingw_16, 8*FILTER_NUM)
    upw_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(conv_8)
    upw_16 = layers.concatenate([upw_16, attw_16], axis=3)
    up_convw_16 = reg_conv_block(upw_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, 6, batch_norm)

    poolw_8 = layers.MaxPooling2D(pool_size=(2, 2))(up_convw_16)
    ct_16 = layers.concatenate([conv_8, poolw_8], axis=3)
    convw_16 = reg_conv_block(ct_16, FILTER_SIZE, 16*FILTER_NUM, dropout_rate, 7, batch_norm)

    # UpRes 6, attention gated concatenation + upsampling + double residual convolution
    gating_16 = gating_signal(convw_16, 8*FILTER_NUM, batch_norm)
    att_16 = attention_block(up_convw_16, gating_16, 8*FILTER_NUM)
    up_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(convw_16)
    up_16 = layers.concatenate([up_16, att_16], axis=3)
    up_conv_16 = conv_block(up_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, 8, batch_norm)
    # UpRes 7
    gating_32 = gating_signal(up_conv_16, 4*FILTER_NUM, batch_norm)
    att_32 = attention_block(conv_32, gating_32, 4*FILTER_NUM)
    up_32 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_16)
    up_32 = layers.concatenate([up_32, att_32], axis=3)
    up_conv_32 = conv_block(up_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, 9, batch_norm)
    # UpRes 8
    gating_64 = gating_signal(up_conv_32, 2*FILTER_NUM, batch_norm)
    att_64 = attention_block(conv_64, gating_64, 2*FILTER_NUM)
    up_64 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_32)
    up_64 = layers.concatenate([up_64, att_64], axis=3)
    up_conv_64 = conv_block(up_64, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, 10, batch_norm)
    # UpRes 9
    gating_128 = gating_signal(up_conv_64, FILTER_NUM, batch_norm)
    att_128 = attention_block(conv_128, gating_128, FILTER_NUM)
    up_128 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_64)
    up_128 = layers.concatenate([up_128, att_128], axis=3)
    up_conv_128 = conv_block(up_128, FILTER_SIZE, FILTER_NUM, dropout_rate, 11, batch_norm)

    # 1x1 convolutional layer mapping features to per-pixel class scores
    conv_final = layers.Conv2D(NUM_CLASSES, name='conv12', kernel_size=(1, 1))(up_conv_128)
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation('softmax')(conv_final)  # softmax for multiclass output

    # Model integration
    model = models.Model(inputs, conv_final, name="AW-Net")
    return model
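
A quick sanity check of the builder, using the 128x128x3 input shape from main.py; illustrative only, not part of the commit:

# Hypothetical sanity check: output should match the input's spatial size
# with one channel per class.
model = AW_Net((128, 128, 3), NUM_CLASSES=4, dropout_rate=0.0, batch_norm=True)
model.summary()
print(model.output_shape)  # expected: (None, 128, 128, 4)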

model_blocks.py

Lines changed: 86 additions & 0 deletions
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras import backend as K


def conv_block(x, filter_size, size, dropout, num, batch_norm=False):
    """Two same-padded convolutions with optional batch norm and dropout."""
    conv = layers.Conv2D(size, (filter_size, filter_size), padding="same")(x)
    if batch_norm is True:
        conv = layers.BatchNormalization(axis=3)(conv)
    conv = layers.Activation("relu")(conv)

    conv = layers.Conv2D(size, (filter_size, filter_size), padding="same", name="conv" + str(num))(conv)
    if batch_norm is True:
        conv = layers.BatchNormalization(axis=3)(conv)
    conv = layers.Activation("relu")(conv)

    if dropout > 0:
        conv = layers.Dropout(dropout)(conv)

    return conv


def reg_conv_block(x, filter_size, size, dropout, num, batch_norm=False):
    """Same as conv_block, but with L1 kernel regularization on both convolutions."""
    conv = layers.Conv2D(size, (filter_size, filter_size), kernel_regularizer='l1', padding="same")(x)
    if batch_norm is True:
        conv = layers.BatchNormalization(axis=3)(conv)
    conv = layers.Activation("relu")(conv)

    conv = layers.Conv2D(size, (filter_size, filter_size), kernel_regularizer='l1', padding="same", name="conv" + str(num))(conv)
    if batch_norm is True:
        conv = layers.BatchNormalization(axis=3)(conv)
    conv = layers.Activation("relu")(conv)

    if dropout > 0:
        conv = layers.Dropout(dropout)(conv)

    return conv


def gating_signal(inputs, out_size, batch_norm=False):
    """
    Resize the down-layer feature map to the same channel dimension as the
    up-layer feature map using a 1x1 convolution.
    :return: the gating feature map with the same dimension as the up-layer feature map
    """
    x = layers.Conv2D(out_size, (1, 1), padding='same')(inputs)
    if batch_norm:
        x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x


def attention_block(x, gating, inter_shape):
    """Additive attention gate: weights the skip connection x by coefficients
    derived from the gating signal."""
    shape_x = K.int_shape(x)
    shape_g = K.int_shape(gating)

    # Bring the x signal to the same spatial shape as the gating signal
    theta_x = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)
    shape_theta_x = K.int_shape(theta_x)

    # Bring the gating signal to the same number of filters as inter_shape
    phi_g = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    upsample_g = layers.Conv2DTranspose(inter_shape, (3, 3),
                                        strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),
                                        padding='same')(phi_g)

    # Additive attention: relu(theta_x + phi_g) -> 1x1 conv -> sigmoid coefficients
    concat_xg = layers.add([upsample_g, theta_x])
    act_xg = layers.Activation('relu')(concat_xg)
    psi = layers.Conv2D(1, (1, 1), padding='same')(act_xg)
    sigmoid_xg = layers.Activation('sigmoid')(psi)
    shape_sigmoid = K.int_shape(sigmoid_xg)
    upsample_psi = layers.UpSampling2D(size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]))(sigmoid_xg)

    # Broadcast the single-channel attention map across all channels of x
    upsample_psi = repeat_elem(upsample_psi, shape_x[3])

    y = layers.multiply([upsample_psi, x])

    result = layers.Conv2D(shape_x[3], (1, 1), padding='same')(y)
    result_bn = layers.BatchNormalization()(result)
    return result_bn


def repeat_elem(tensor, rep):
    # Lambda layer that repeats the elements of a tensor along axis 3 by a
    # factor of rep. E.g. a (None, 256, 256, 3) tensor becomes
    # (None, 256, 256, 6) for rep=2.
    return layers.Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=3),
                         arguments={'repnum': rep})(tensor)
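
A shape check for the attention gate on dummy inputs, mirroring the 8*FILTER_NUM stage of model.py; the sizes below are illustrative assumptions:

# Hypothetical shape check, not part of the commit: the gate should return
# a tensor with the same shape as the skip connection it weights.
skip = layers.Input((64, 64, 128))     # encoder skip connection
coarse = layers.Input((32, 32, 256))   # coarser decoder feature map
g = gating_signal(coarse, 128, batch_norm=True)
att = attention_block(skip, g, 128)
m = models.Model([skip, coarse], att)
print(m.output_shape)  # expected: (None, 64, 64, 128)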
