import time, pickle
from ops import *
from utils import *
from collections import defaultdict
import matplotlib.pyplot as plt
from skimage.util.shape import view_as_blocks as patch_blocks
from math import ceil

# os, glob, numpy (np), and tensorflow (tf) are assumed to be re-exported by
# ops/utils via the star imports above; import them explicitly to be safe.
import os
import glob
import numpy as np
import tensorflow as tf
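# This module evaluates a CASED-style 3D U-Net nodule detector on LUNA16.
# (CASED: "Curriculum Adaptive Sampling for Extreme Data Imbalance",
# Jesson et al., MICCAI 2017 -- presumed reference, going by the model name
# and the LUNA16 dataset used below.)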
class CASED(object):
    def __init__(self, sess, batch_size, checkpoint_dir, result_dir, log_dir):
        self.sess = sess
        self.dataset_name = 'LUNA16'
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.batch_size = batch_size
        self.model_name = "CASED"  # name for checkpoint

        self.c_dim = 1  # single-channel (grayscale) CT input
        self.y_dim = 2  # two classes: nodule vs. non-nodule
        self.block_size = 72  # side length of the cubic blocks a scan is tiled into

    def cased_network(self, x, reuse=False, scope='CASED_NETWORK'):
        with tf.variable_scope(scope, reuse=reuse):
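            # Contracting path: two 3x3x3 convolutions per stage with
            # max-pooling between stages, doubling the channel count each time.
            # The up_conv* activations are kept for the skip connections below.
            # (conv_layer/max_pooling come from ops.py; the copy_crop calls
            # below imply VALID-padded convolutions, as in the original U-Net.)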
            x = conv_layer(x, channels=32, kernel=3, stride=1, layer_name='conv1')
            up_conv1 = conv_layer(x, channels=32, kernel=3, stride=1, layer_name='up_conv1')

            x = max_pooling(up_conv1)

            x = conv_layer(x, channels=64, kernel=3, stride=1, layer_name='conv2')
            up_conv2 = conv_layer(x, channels=64, kernel=3, stride=1, layer_name='up_conv2')

            x = max_pooling(up_conv2)

            x = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='conv3')
            up_conv3 = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='up_conv3')

            x = max_pooling(up_conv3)

            x = conv_layer(x, channels=256, kernel=3, stride=1, layer_name='conv4')
            x = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='conv5')

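            # Expanding path: stride-2 transposed convolutions upsample the
            # feature maps; copy_crop crops the stored encoder activation to the
            # current spatial size and concatenates it (U-Net skip connection),
            # followed by two 1x1x1 convolutions to mix channels.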
            x = deconv_layer(x, channels=256, kernel=4, stride=2, layer_name='deconv1')
            x = copy_crop(crop_layer=up_conv3, in_layer=x)

            x = conv_layer(x, channels=128, kernel=1, stride=1, layer_name='conv6')
            x = conv_layer(x, channels=64, kernel=1, stride=1, layer_name='conv7')

            x = deconv_layer(x, channels=128, kernel=4, stride=2, layer_name='deconv2')
            x = copy_crop(crop_layer=up_conv2, in_layer=x)

            x = conv_layer(x, channels=64, kernel=1, stride=1, layer_name='conv8')
            x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv9')

            x = deconv_layer(x, channels=64, kernel=4, stride=2, layer_name='deconv3')
            x = copy_crop(crop_layer=up_conv1, in_layer=x)

            x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv10')
            x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv11')

            # 1x1x1 convolution down to two per-voxel class scores
            logits = conv_layer(x, channels=2, kernel=1, stride=1, activation=None, layer_name='conv12')

            x = softmax(logits)

            return logits, x

    def build_model(self):

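        # Spatial dimensions are left as None so the same graph can consume the
        # padded 132^3 test patches fed in test() below (batch size is fixed).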
        bs = self.batch_size
        scan_dims = [None, None, None, self.c_dim]
        scan_y_dims = [None, None, None, self.y_dim]

        """ Graph Input """
        # images
        self.inputs = tf.placeholder(tf.float32, [bs] + scan_dims, name='patch')

        # labels
        self.y = tf.placeholder(tf.float32, [bs] + scan_y_dims, name='y')  # one-hot per-voxel labels

        self.logits, self.softmax_logits = self.cased_network(self.inputs)

        """ Metrics """
        self.correct_prediction = tf.equal(tf.argmax(self.softmax_logits, -1), tf.argmax(self.y, -1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
        self.sensitivity, self.fp_rate = sensitivity(labels=self.y, logits=self.softmax_logits)

        """ Summary """
        c_acc = tf.summary.scalar('acc', self.accuracy)
        c_recall = tf.summary.scalar('sensitivity', self.sensitivity)
        c_fp = tf.summary.scalar('false_positive', self.fp_rate)
        self.c_sum = tf.summary.merge([c_acc, c_recall, c_fp])

    def test(self):
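        # Evaluation pipeline: load a checkpoint, then for every scan in the
        # validation subset: pad it to a multiple of block_size, tile it into
        # 72^3 blocks, run the network on each block (with 30 voxels of zero
        # context on every side), stitch the per-block predictions back
        # together, blank out the excluded annotations, and accumulate
        # sensitivity at fixed false-positive rates.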
        block_size = self.block_size
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore checkpoint if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        validation_sub_n = 0
        subset_name = 'subset' + str(validation_sub_n)
        image_paths = glob.glob("/data/jhkim/LUNA16/original/subset" + str(validation_sub_n) + '/*.mhd')
        all_scan_num = len(image_paths)
        sens_list = None
        nan_num = 0  # scans without any annotated nodule voxels (skipped below)
        for scan in image_paths:

            scan_name = os.path.split(scan)[1].replace('.mhd', '')
            scan_npy = '/data2/jhkim/npydata/' + subset_name + '/' + scan_name + '.npy'
            label_npy = '/data2/jhkim/npydata/' + subset_name + '/' + scan_name + '.label.npy'

            image = np.transpose(np.load(scan_npy))
            label = np.transpose(np.load(label_npy))

            # skip scans with no annotated nodules; they contribute nothing to sensitivity
            if np.count_nonzero(label) == 0:
                nan_num += 1
                continue

            print(np.shape(image))
            print(np.shape(label))

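            # Compute a (near-)symmetric zero-pad per axis so every dimension
            # becomes a multiple of block_size. For example, an axis of length
            # 300 needs 60 extra voxels to reach 360, giving pad_l = pad_r = 30;
            # an odd remainder puts the extra voxel on the right.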
            pad_list = []
            for i in range(3):
                if np.shape(image)[i] % block_size == 0:
                    pad_l = 0
                    pad_r = pad_l
                else:
                    q = (ceil(np.shape(image)[i] / block_size) * block_size) - np.shape(image)[i]

                    if q % 2 == 0:
                        pad_l = q // 2
                        pad_r = pad_l
                    else:
                        pad_l = q // 2
                        pad_r = pad_l + 1

                pad_list.append(pad_l)
                pad_list.append(pad_r)

            image = np.pad(image, pad_width=[[pad_list[0], pad_list[1]], [pad_list[2], pad_list[3]], [pad_list[4], pad_list[5]]],
                           mode='constant', constant_values=0)

            label = np.pad(label, pad_width=[[pad_list[0], pad_list[1]], [pad_list[2], pad_list[3]], [pad_list[4], pad_list[5]]],
                           mode='constant', constant_values=0)

            print('after padding:')
            print(np.shape(image))
            print(np.shape(label))

            # non-overlapping 72^3 blocks; result shape is (nx, ny, nz, 72, 72, 72)
            image_blocks = patch_blocks(image, block_shape=(block_size, block_size, block_size))
            len_x = len(image_blocks)
            len_y = len(image_blocks[0])
            len_z = len(image_blocks[0, 0])

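            # Run the network block by block and stitch the predictions back
            # into a full volume. Each 72^3 block is zero-padded by 30 voxels on
            # every side, so the network sees a 132^3 input; with the VALID
            # convolutions assumed above it emits a 72^3 prediction again.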
            result_scan = None
            for x_i in range(len_x):
                x = None
                for y_i in range(len_y):
                    y = None
                    for z_i in range(len_z):
                        patch = np.expand_dims(np.expand_dims(image_blocks[x_i, y_i, z_i], axis=-1), axis=0)  # (1, 72, 72, 72, 1)
                        patch = np.pad(patch, pad_width=[[0, 0], [30, 30], [30, 30], [30, 30], [0, 0]], mode='constant', constant_values=0)
                        test_feed_dict = {
                            self.inputs: patch
                        }

                        probs = self.sess.run(
                            self.softmax_logits, feed_dict=test_feed_dict
                        )
                        pred = np.squeeze(np.argmax(probs, axis=-1), axis=0)  # (72, 72, 72) binary mask

                        y = pred if y is None else np.concatenate((y, pred), axis=2)

                    x = y if x is None else np.concatenate((x, y), axis=1)

                result_scan = x if result_scan is None else np.concatenate((result_scan, x), axis=0)
            # result_scan now has the same shape as the padded scan

            # annotation dictionaries, keyed by scan name
            # (re-read for every scan; they could be loaded once before the loop)
            with open('include.pkl', 'rb') as f:
                coords_dict = pickle.load(f, encoding='bytes')

            with open('exclude.pkl', 'rb') as f:
                exclude_dict = pickle.load(f, encoding='bytes')

            exclude_coords = exclude_dict[scan_name]

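            # Zero out predictions around findings that the LUNA16 protocol
            # treats as irrelevant, so they count neither as hits nor as false
            # positives. A negative diameter presumably marks an annotation of
            # unknown size; it falls back to a diameter of 10 here.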
            for ex in exclude_coords:
                # shift annotation coordinates by the left pad on each axis
                # ((pad_l + pad_r) // 2 == pad_l for the padding built above)
                ex[0] = ex[0] + (pad_list[0] + pad_list[1]) // 2
                ex[1] = ex[1] + (pad_list[2] + pad_list[3]) // 2
                ex[2] = ex[2] + (pad_list[4] + pad_list[5]) // 2
                ex_diameter = ex[3]
                if ex_diameter < 0.0:
                    ex_diameter = 10.0
                exclude_position = (ex[0], ex[1], ex[2])
                exclude_mask = create_exclude_mask(result_scan.shape, exclude_position, ex_diameter)
                result_scan = np.logical_and(result_scan, exclude_mask)

            """ debugging visualization (disabled); note: scan_num is not defined in this version
            coords = coords_dict[scan_name]
            cnt = 0
            for c in coords :
                print('******** result ********')
                print(np.shape(result_scan))
                print(np.shape(label))
                print(c)
                x_coords = c[0] + (pad_list[0] + pad_list[1]) // 2
                y_coords = c[1] + (pad_list[2] + pad_list[3]) // 2
                z_coords = c[2] + (pad_list[4] + pad_list[5]) // 2
                offset = 34
                result_scan_img = result_scan[int(x_coords - offset): int(x_coords + offset), int(y_coords - offset):int(y_coords + offset), z_coords]
                label_scan = label[int(x_coords - offset): int(x_coords + offset), int(y_coords - offset):int(y_coords + offset), z_coords]

                plt.imsave('./image/test_{}_{}.png'.format(scan_num, cnt), result_scan_img, cmap=plt.cm.gray)
                plt.imsave('./image/label_{}_{}.png'.format(scan_num, cnt), label_scan, cmap=plt.cm.gray)
                cnt += 1
            scan_num += 1
            """

            # fp_per_scan (from utils) appears to return the per-scan
            # sensitivities at the false-positive operating points listed below
            if sens_list is None:
                sens_list = fp_per_scan(result_scan, label)
            else:
                sens_list += fp_per_scan(result_scan, label)

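        # Standard LUNA16 FROC operating points: sensitivity is reported at
        # 0.125, 0.25, 0.5, 1, 2, 4, and 8 false positives per scan, and the
        # seven values are averaged (the challenge's CPM-style score).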
        fp_list = [0.125, 0.25, 0.5, 1, 2, 4, 8]
        sens_list /= (all_scan_num - nan_num)  # average over the evaluated scans

        for i in range(len(fp_list)):
            print('{} : {}'.format(fp_list[i], sens_list[i]))

        print('Average sensitivity : {}'.format(np.mean(sens_list)))

    @property
    def model_dir(self):
        return "{}_{}".format(self.model_name, self.dataset_name)

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # the trailing number in the checkpoint name is the global step
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Successfully read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
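
# -----------------------------------------------------------------------------
# Usage sketch (not part of the original file): how this class is presumably
# driven. The real entry point lives in the project's main script, which is
# not shown here, so the argument values below are illustrative assumptions.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    with tf.Session() as sess:
        cased = CASED(sess,
                      batch_size=1,  # test() feeds one block at a time
                      checkpoint_dir='checkpoint',
                      result_dir='results',
                      log_dir='logs')
        cased.build_model()  # build the graph, metrics, and summaries
        cased.test()         # block-wise evaluation on the validation subset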