Skip to content

Commit 12aea6a

Browse files
committed
code
1 parent 64fe0bf commit 12aea6a

15 files changed

+1125
-152
lines changed

.idea/CASED-Tensorflow.iml

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/inspectionProfiles/Project_Default.xml

Lines changed: 23 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/modules.xml

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/vcs.xml

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/workspace.xml

Lines changed: 379 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

CASED_test.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
import time, pickle
2+
from ops import *
3+
from utils import *
4+
from collections import defaultdict
5+
import matplotlib.pyplot as plt
6+
from skimage.util.shape import view_as_blocks as patch_blocks
7+
from math import ceil
8+
class CASED(object) :
9+
def __init__(self, sess, batch_size, checkpoint_dir, result_dir, log_dir):
10+
self.sess = sess
11+
self.dataset_name = 'LUNA16'
12+
self.checkpoint_dir = checkpoint_dir
13+
self.result_dir = result_dir
14+
self.log_dir = log_dir
15+
self.batch_size = batch_size
16+
self.model_name = "CASED" # name for checkpoint
17+
18+
self.c_dim = 1
19+
self.y_dim = 2 # nodule ? or non_nodule ?
20+
self.block_size = 72
21+
22+
def cased_network(self, x, reuse=False, scope='CASED_NETWORK'):
23+
with tf.variable_scope(scope, reuse=reuse) :
24+
x = conv_layer(x, channels=32, kernel=3, stride=1, layer_name='conv1')
25+
up_conv1 = conv_layer(x, channels=32, kernel=3, stride=1, layer_name='up_conv1')
26+
27+
x = max_pooling(up_conv1)
28+
29+
x = conv_layer(x, channels=64, kernel=3, stride=1, layer_name='conv2')
30+
up_conv2 = conv_layer(x, channels=64, kernel=3, stride=1, layer_name='up_conv2')
31+
32+
x = max_pooling(up_conv2)
33+
34+
x = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='conv3')
35+
up_conv3 = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='up_conv3')
36+
37+
x = max_pooling(up_conv3)
38+
39+
x = conv_layer(x, channels=256, kernel=3, stride=1, layer_name='conv4')
40+
x = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='conv5')
41+
42+
x = deconv_layer(x, channels=256, kernel=4, stride=2,layer_name='deconv1')
43+
x = copy_crop(crop_layer=up_conv3, in_layer=x)
44+
45+
x = conv_layer(x, channels=128, kernel=1, stride=1, layer_name='conv6')
46+
x = conv_layer(x, channels=64, kernel=1, stride=1, layer_name='conv7')
47+
48+
x = deconv_layer(x, channels=128, kernel=4, stride=2, layer_name='deconv2')
49+
x = copy_crop(crop_layer=up_conv2, in_layer=x)
50+
51+
x = conv_layer(x, channels=64, kernel=1, stride=1, layer_name='conv8')
52+
x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv9')
53+
54+
x = deconv_layer(x, channels=64, kernel=4, stride=2, layer_name='deconv3')
55+
x = copy_crop(crop_layer=up_conv1, in_layer=x)
56+
57+
x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv10')
58+
x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv11')
59+
60+
logits = conv_layer(x, channels=2, kernel=1, stride=1, activation=None, layer_name='conv12')
61+
62+
x = softmax(logits)
63+
64+
return logits, x
65+
66+
def build_model(self):
67+
68+
bs = self.batch_size
69+
scan_dims = [None, None, None, self.c_dim]
70+
scan_y_dims = [None, None, None, self.y_dim]
71+
72+
""" Graph Input """
73+
# images
74+
self.inputs = tf.placeholder(tf.float32, [bs] + scan_dims, name='patch')
75+
76+
# labels
77+
self.y = tf.placeholder(tf.float32, [bs] + scan_y_dims, name='y') # for loss
78+
79+
self.logits, self.softmax_logits = self.cased_network(self.inputs)
80+
81+
""" Loss function """
82+
self.correct_prediction = tf.equal(tf.argmax(self.softmax_logits, -1), tf.argmax(self.y, -1))
83+
self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
84+
self.sensitivity, self.fp_rate = sensitivity(labels=self.y, logits=self.softmax_logits)
85+
86+
87+
""" Summary """
88+
89+
c_acc = tf.summary.scalar('acc', self.accuracy)
90+
c_recall = tf.summary.scalar('sensitivity', self.sensitivity)
91+
c_fp = tf.summary.scalar('false_positive', self.fp_rate)
92+
self.c_sum = tf.summary.merge([c_acc, c_recall, c_fp])
93+
94+
def test(self):
95+
block_size = self.block_size
96+
# initialize all variables
97+
tf.global_variables_initializer().run()
98+
99+
# saver to save model
100+
self.saver = tf.train.Saver()
101+
102+
# summary writer
103+
self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)
104+
105+
# restore check-point if it exits
106+
could_load, checkpoint_counter = self.load(self.checkpoint_dir)
107+
108+
if could_load:
109+
print(" [*] Load SUCCESS")
110+
else:
111+
print(" [!] Load failed...")
112+
113+
validation_sub_n = 0
114+
subset_name = 'subset' + str(validation_sub_n)
115+
image_paths = glob.glob("/data/jhkim/LUNA16/original/subset" + str(validation_sub_n) + '/*.mhd')
116+
all_scan_num = len(image_paths)
117+
sens_list = None
118+
nan_num = 0
119+
for scan in image_paths :
120+
121+
scan_name = os.path.split(scan)[1].replace('.mhd', '')
122+
scan_npy = '/data2/jhkim/npydata/' + subset_name + '/' + scan_name + '.npy'
123+
label_npy = '/data2/jhkim/npydata/' + subset_name + '/' + scan_name + '.label.npy'
124+
125+
image = np.transpose(np.load(scan_npy))
126+
label = np.transpose(np.load(label_npy))
127+
128+
if np.count_nonzero(label) == 0 :
129+
nan_num += 1
130+
continue
131+
132+
print(np.shape(image))
133+
print(np.shape(label))
134+
135+
pad_list = []
136+
for i in range(3) :
137+
if np.shape(image)[i] % block_size == 0:
138+
pad_l = 0
139+
pad_r = pad_l
140+
else:
141+
q = (ceil(np.shape(image)[i] / block_size) * block_size) - np.shape(image)[i]
142+
143+
if q % 2 == 0:
144+
pad_l = q // 2
145+
pad_r = pad_l
146+
else:
147+
pad_l = q // 2
148+
pad_r = pad_l + 1
149+
150+
pad_list.append(pad_l)
151+
pad_list.append(pad_r)
152+
153+
154+
image = np.pad(image, pad_width=[ [pad_list[0], pad_list[1]], [pad_list[2], pad_list[3]], [pad_list[4], pad_list[5]] ],
155+
mode='constant', constant_values=0)
156+
157+
label = np.pad(label, pad_width=[ [pad_list[0], pad_list[1]], [pad_list[2], pad_list[3]], [pad_list[4], pad_list[5]] ],
158+
mode='constant', constant_values=0)
159+
160+
print('padding !')
161+
print(np.shape(image))
162+
print(np.shape(label))
163+
164+
image_blocks = patch_blocks(image, block_shape=(block_size, block_size, block_size))
165+
len_x = len(image_blocks)
166+
len_y = len(image_blocks[0])
167+
len_z = len(image_blocks[0, 0])
168+
169+
result_scan = None
170+
for x_i in range(len_x):
171+
x = None
172+
for y_i in range(len_y):
173+
y = None
174+
for z_i in range(len_z):
175+
scan = np.expand_dims(np.expand_dims(image_blocks[x_i, y_i, z_i], axis=-1), axis=0) # 1 72 72 72 1
176+
scan = np.pad(scan, pad_width=[[0, 0], [30, 30], [30, 30], [30, 30], [0, 0]], mode='constant', constant_values=0)
177+
test_feed_dict = {
178+
self.inputs: scan
179+
}
180+
181+
logits = self.sess.run(
182+
self.softmax_logits, feed_dict=test_feed_dict
183+
)
184+
logits = np.squeeze(np.argmax(logits, axis=-1), axis=0) # 72 72 72
185+
186+
y = logits if y is None else np.concatenate((y, logits), axis=2)
187+
188+
x = y if x is None else np.concatenate((x, y), axis=1)
189+
190+
result_scan = x if result_scan is None else np.concatenate((result_scan, x), axis=0)
191+
# print(result) # 3d original size
192+
193+
with open('include.pkl', 'rb') as f:
194+
coords_dict = pickle.load(f, encoding='bytes')
195+
196+
with open('exclude.pkl', 'rb') as f:
197+
exclude_dict = pickle.load(f, encoding='bytes')
198+
199+
exclude_coords = exclude_dict[scan_name]
200+
201+
for ex in exclude_coords :
202+
ex[0] = ex[0] + (pad_list[0] + pad_list[1]) // 2
203+
ex[1] = ex[1] + (pad_list[2] + pad_list[3]) // 2
204+
ex[2] = ex[2] + (pad_list[4] + pad_list[5]) // 2
205+
ex_diameter = ex[3]
206+
if ex_diameter < 0.0 :
207+
ex_diameter = 10.0
208+
exclude_position = (ex[0], ex[1], ex[2])
209+
exclude_mask = create_exclude_mask(result_scan.shape, exclude_position, ex_diameter )
210+
result_scan = np.logical_and(result_scan, exclude_mask)
211+
212+
"""
213+
coords = coords_dict[scan_name]
214+
cnt = 0
215+
for c in coords :
216+
print('******** result ********')
217+
print(np.shape(result_scan))
218+
print(np.shape(label))
219+
print(c)
220+
x_coords = c[0] + (pad_list[0] + pad_list[1]) // 2
221+
y_coords = c[1] + (pad_list[2] + pad_list[3]) // 2
222+
z_coords = c[2] + (pad_list[4] + pad_list[5]) // 2
223+
offset = 34
224+
result_scan_img = result_scan[int(x_coords - offset): int(x_coords + offset), int(y_coords - offset):int(y_coords + offset), z_coords]
225+
label_scan = label[int(x_coords - offset): int(x_coords + offset), int(y_coords - offset):int(y_coords + offset), z_coords]
226+
227+
plt.imsave('./image/test_{}_{}.png'.format(scan_num,cnt), result_scan_img, cmap=plt.cm.gray)
228+
plt.imsave('./image/label_{}_{}.png'.format(scan_num,cnt), label_scan, cmap=plt.cm.gray)
229+
cnt += 1
230+
scan_num += 1
231+
"""
232+
233+
if sens_list is None :
234+
sens_list = fp_per_scan(result_scan, label)
235+
else :
236+
sens_list += fp_per_scan(result_scan, label)
237+
238+
fp_list = [0.125, 0.25, 0.5, 1, 2, 4, 8]
239+
sens_list /= (all_scan_num - nan_num)
240+
241+
for i in range(len(fp_list)) :
242+
print('{} : {}'.format(fp_list[i], sens_list[i]))
243+
244+
print('Average sensitivity : {}'.format(np.mean(sens_list)))
245+
246+
247+
248+
249+
250+
@property
251+
def model_dir(self):
252+
return "{}_{}".format(
253+
self.model_name, self.dataset_name)
254+
255+
def save(self, checkpoint_dir, step):
256+
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
257+
258+
if not os.path.exists(checkpoint_dir):
259+
os.makedirs(checkpoint_dir)
260+
261+
self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
262+
263+
def load(self, checkpoint_dir):
264+
import re
265+
print(" [*] Reading checkpoints...")
266+
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
267+
268+
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
269+
if ckpt and ckpt.model_checkpoint_path:
270+
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
271+
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
272+
counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0))
273+
print(" [*] Success to read {}".format(ckpt_name))
274+
return True, counter
275+
else:
276+
print(" [*] Failed to find a checkpoint")
277+
return False, 0

0 commit comments

Comments
 (0)