Skip to content

Commit 0635650

Browse files
committed
p3ch15 Add missing cpp, etc. files, update p2ch14 code
1 parent 5d8b081 commit 0635650

File tree

12 files changed

+476
-51
lines changed

12 files changed

+476
-51
lines changed

p2ch11/model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@ def _init_weights(self):
3636
nn.ConvTranspose2d,
3737
nn.ConvTranspose3d,
3838
}:
39-
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
39+
nn.init.kaiming_normal_(
40+
m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
41+
)
4042
if m.bias is not None:
41-
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
43+
fan_in, fan_out = \
44+
nn.init._calculate_fan_in_and_fan_out(m.weight.data)
4245
bound = 1 / math.sqrt(fan_out)
4346
nn.init.normal_(m.bias, -bound, bound)
4447

4548

49+
4650
def forward(self, input_batch):
4751
bn_output = self.tail_batchnorm(input_batch)
4852

@@ -64,9 +68,13 @@ class LunaBlock(nn.Module):
6468
def __init__(self, in_channels, conv_channels):
6569
super().__init__()
6670

67-
self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
71+
self.conv1 = nn.Conv3d(
72+
in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
73+
)
6874
self.relu1 = nn.ReLU(inplace=True)
69-
self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
75+
self.conv2 = nn.Conv3d(
76+
conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
77+
)
7078
self.relu2 = nn.ReLU(inplace=True)
7179

7280
self.maxpool = nn.MaxPool3d(2, 2)

p2ch14/check_nodule_fp_rate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,13 +413,15 @@ def clusterSegmentationOutput(self, series_uid, ct, clean_g):
413413
clean_a = clean_g.cpu().numpy()
414414
candidateLabel_a, candidate_count = measure.label(clean_a)
415415
centerIrc_list = measure.center_of_mass(
416-
ct.hu_a + 1001,
416+
ct.hu_a.clip(-1000, 1000) + 1001,
417417
labels=candidateLabel_a,
418418
index=list(range(1, candidate_count+1)),
419419
)
420420

421+
421422
candidateInfo_list = []
422423
for i, center_irc in enumerate(centerIrc_list):
424+
assert np.isfinite(center_irc).all(), repr([series_uid, i, candidate_count, (ct.hu_a[candidateLabel_a == i+1]).sum(), center_irc])
423425
center_xyz = irc2xyz(
424426
center_irc,
425427
ct.origin_xyz,

p2ch14/dsets.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,12 @@ def __len__(self):
381381

382382
def __getitem__(self, ndx):
383383
if self.ratio_int:
384-
if ndx % 4 < 2:
385-
candidateInfo_tup = self.mal_list[(ndx // 3) % len(self.mal_list)]
386-
elif ndx % 4 == 2:
387-
candidateInfo_tup = self.ben_list[(ndx // 3) % len(self.ben_list)]
384+
if ndx % 2 != 0:
385+
candidateInfo_tup = self.mal_list[(ndx // 2) % len(self.mal_list)]
386+
elif ndx % 4 == 0:
387+
candidateInfo_tup = self.ben_list[(ndx // 4) % len(self.ben_list)]
388388
else:
389-
candidateInfo_tup = self.neg_list[(ndx // 3) % len(self.neg_list)]
389+
candidateInfo_tup = self.neg_list[(ndx // 4) % len(self.neg_list)]
390390
else:
391391
if ndx >= len(self.ben_list):
392392
candidateInfo_tup = self.mal_list[ndx - len(self.ben_list)]

p2ch14/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def augment3d(inp):
2020
if random.random() > 0.5:
2121
transform_t[i,i] *= -1
2222
if True: #'offset' in augmentation_dict:
23-
offset_float = 0.1 # 8 # augmentation_dict['offset']
23+
offset_float = 0.1
2424
random_float = (random.random() * 2 - 1)
2525
transform_t[3,i] = offset_float * random_float
2626
if True:

p2ch14/nodule_analysis.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,15 @@
3131
logging.getLogger("p2ch14.dsets").setLevel(logging.WARNING)
3232

3333
def print_confusion(label, confusions, do_mal):
34+
row_labels = ['Non-Nodules', 'Benign', 'Malignant']
35+
3436
if do_mal:
35-
col_labels = ['', 'Complete Miss', 'Filtered', 'Benign', 'Malignant']
36-
row_labels = ['Non-Nodules', 'Benign', 'Malignant']
37+
col_labels = ['', 'Complete Miss', 'Filtered Out', 'Pred. Benign', 'Pred. Malignant']
3738
else:
38-
col_labels = ['', 'Complete Miss', 'Filtered', 'Detected']
39-
row_labels = ['Non-Nodules', 'Nodules']
40-
confusions[-2] += confusions[-1]
39+
col_labels = ['', 'Complete Miss', 'Filtered Out', 'Pred. Nodule']
4140
confusions[:, -2] += confusions[:, -1]
42-
confusions = confusions[:-1, :-1]
43-
cell_width = 14
41+
confusions = confusions[:, :-1]
42+
cell_width = 16
4443
f = '{:>' + str(cell_width) + '}'
4544
print(label)
4645
print(' | '.join([f.format(s) for s in col_labels]))
@@ -72,7 +71,7 @@ def match_and_score(detections, truth, threshold=0.5, threshold_mal=0.5):
7271
confusion = np.zeros((3, 4), dtype=np.int)
7372
if len(detected_xyz) == 0:
7473
for tn in true_nodules:
75-
confusiion[2 if tn.isMal_bool else 1, 0] += 1
74+
confusion[2 if tn.isMal_bool else 1, 0] += 1
7675
elif len(truth_xyz) == 0:
7776
for dc in detected_classes:
7877
confusion[0, dc] += 1
@@ -124,7 +123,7 @@ def __init__(self, sys_argv=None):
124123
parser.add_argument('--segmentation-path',
125124
help="Path to the saved segmentation model",
126125
nargs='?',
127-
default=None,
126+
default='data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state',
128127
)
129128

130129
parser.add_argument('--cls-model',
@@ -135,13 +134,14 @@ def __init__(self, sys_argv=None):
135134
parser.add_argument('--classification-path',
136135
help="Path to the saved classification model",
137136
nargs='?',
138-
default=None,
137+
default='data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state',
139138
)
140139

141140
parser.add_argument('--malignancy-model',
142141
help="What to model class name to use for the malignancy classifier.",
143142
action='store',
144-
default='ModifiedLunaModel',
143+
default='LunaModel',
144+
# default='ModifiedLunaModel',
145145
)
146146
parser.add_argument('--malignancy-path',
147147
help="Path to the saved malignancy classification model",
@@ -303,7 +303,6 @@ def main(self):
303303
val_list = sorted(series_set & val_set)
304304

305305

306-
candidateInfo_list = []
307306
candidateInfo_dict = getCandidateInfoDict()
308307
series_iter = enumerateWithEstimate(
309308
val_list + train_list,
@@ -314,10 +313,8 @@ def main(self):
314313
ct = getCt(series_uid)
315314
mask_a = self.segmentCt(ct, series_uid)
316315

317-
candidateInfo_list = self.clusterSegmentationOutput(
318-
series_uid,
319-
ct,
320-
mask_a,
316+
candidateInfo_list = self.groupSegmentationOutput(
317+
series_uid, ct, mask_a
321318
)
322319
classifications_list = self.classifyCandidates(ct, candidateInfo_list)
323320

@@ -339,7 +336,6 @@ def main(self):
339336
print_confusion("Total", all_confusion, self.malignancy_model is not None)
340337

341338

342-
343339
def classifyCandidates(self, ct, candidateInfo_list):
344340
cls_dl = self.initClassificationDl(candidateInfo_list)
345341
classifications_list = []
@@ -348,49 +344,50 @@ def classifyCandidates(self, ct, candidateInfo_list):
348344

349345
input_g = input_t.to(self.device)
350346
with torch.no_grad():
351-
_, probability_g = self.cls_model(input_g)
347+
_, probability_nodule_g = self.cls_model(input_g)
352348
if self.malignancy_model is not None:
353349
_, probability_mal_g = self.malignancy_model(input_g)
354350
else:
355-
probability_mal_g = torch.zeros_like(probability_g)
351+
probability_mal_g = torch.zeros_like(probability_nodule_g)
356352

357-
for center_irc, prob, prob_mal in zip(center_list,
358-
probability_g[:,1].tolist(),
359-
probability_mal_g[:,1].tolist()
360-
):
353+
zip_iter = zip(
354+
center_list,
355+
probability_nodule_g[:,1].tolist(),
356+
probability_mal_g[:,1].tolist(),
357+
)
358+
for center_irc, prob_nodule, prob_mal in zip_iter:
361359
center_xyz = irc2xyz(
362360
center_irc,
363361
direction_a=ct.direction_a,
364362
origin_xyz=ct.origin_xyz,
365-
vxSize_xyz=ct.vxSize_xyz)
366-
classifications_list.append((prob, prob_mal, center_xyz, center_irc))
363+
vxSize_xyz=ct.vxSize_xyz,
364+
)
365+
cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
366+
classifications_list.append(cls_tup)
367367
return classifications_list
368368

369369
def segmentCt(self, ct, series_uid):
370370
with torch.no_grad():
371371
output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
372372
seg_dl = self.initSegmentationDl(series_uid)
373373
for batch_tup in seg_dl:
374-
input_t = batch_tup[0]
375-
ndx_list = batch_tup[4]
374+
input_t, label_t, series_list, slice_ndx_list = batch_tup
376375

377376
input_g = input_t.to(self.device)
378377
prediction_g = self.seg_model(input_g)
379378

380-
for i, sample_ndx in enumerate(ndx_list):
381-
output_a[sample_ndx] = prediction_g[i].cpu().numpy()
379+
for i, slice_ndx in enumerate(slice_ndx_list):
380+
output_a[slice_ndx] = prediction_g[i].cpu().numpy()
382381

383-
# mask_a = output_a > 0.25
384382
mask_a = output_a > 0.5
385-
# mask_a = morphology.binary_erosion(mask_a, iterations=1)
386-
# mask_a = morphology.binary_dilation(mask_a, iterations=2)
383+
mask_a = morphology.binary_erosion(mask_a, iterations=1)
387384

388385
return mask_a
389386

390-
def clusterSegmentationOutput(self, series_uid, ct, clean_a):
387+
def groupSegmentationOutput(self, series_uid, ct, clean_a):
391388
candidateLabel_a, candidate_count = measurements.label(clean_a)
392389
centerIrc_list = measurements.center_of_mass(
393-
ct.hu_a + 1001,
390+
ct.hu_a.clip(-1000, 1000) + 1001,
394391
labels=candidateLabel_a,
395392
index=np.arange(1, candidate_count+1),
396393
)

p2ch14/training.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77

88
import numpy as np
9+
from matplotlib import pyplot
910

1011
from torch.utils.tensorboard import SummaryWriter
1112

@@ -121,11 +122,19 @@ def initModel(self):
121122

122123
if self.cli_args.finetune:
123124
d = torch.load(self.cli_args.finetune, map_location='cpu')
124-
model_blocks = [n for n, subm in model.named_children()
125-
if len(list(subm.parameters())) > 0]
125+
model_blocks = [
126+
n for n, subm in model.named_children()
127+
if len(list(subm.parameters())) > 0
128+
]
126129
finetune_blocks = model_blocks[-self.cli_args.finetune_depth:]
127130
log.info(f"finetuning from {self.cli_args.finetune}, blocks {' '.join(finetune_blocks)}")
128-
model.load_state_dict(d['model_state'])
131+
model.load_state_dict(
132+
{
133+
k: v for k,v in d['model_state'].items()
134+
if k.split('.')[0] not in model_blocks[-1]
135+
},
136+
strict=False,
137+
)
129138
for n, p in model.named_parameters():
130139
if n.split('.')[0] not in finetune_blocks:
131140
p.requires_grad_(False)
@@ -138,7 +147,7 @@ def initModel(self):
138147

139148
def initOptimizer(self):
140149
lr = 0.003 if self.cli_args.finetune else 0.001
141-
return SGD(self.model.parameters(), weight_decay=1e-4, lr=lr)
150+
return SGD(self.model.parameters(), lr=lr, weight_decay=1e-4)
142151
#return Adam(self.model.parameters(), lr=3e-4)
143152

144153
def initTrainDl(self):
@@ -398,12 +407,21 @@ def logMetrics(
398407
metrics_dict['pr/f1_score'] = \
399408
2 * (precision * recall) / (precision + recall)
400409

410+
threshold = torch.linspace(1, 0)
411+
tpr = (metrics_t[None, METRICS_PRED_P_NDX, posLabel_mask] >= threshold[:, None]).sum(1).float() / pos_count
412+
fpr = (metrics_t[None, METRICS_PRED_P_NDX, negLabel_mask] >= threshold[:, None]).sum(1).float() / neg_count
413+
fp_diff = fpr[1:]-fpr[:-1]
414+
tp_avg = (tpr[1:]+tpr[:-1])/2
415+
auc = (fp_diff * tp_avg).sum()
416+
metrics_dict['auc'] = auc
417+
401418
log.info(
402419
("E{} {:8} {loss/all:.4f} loss, "
403420
+ "{correct/all:-5.1f}% correct, "
404421
+ "{pr/precision:.4f} precision, "
405422
+ "{pr/recall:.4f} recall, "
406-
+ "{pr/f1_score:.4f} f1 score"
423+
+ "{pr/f1_score:.4f} f1 score, "
424+
+ "{auc:.4f} auc"
407425
).format(
408426
epoch_ndx,
409427
mode_str,
@@ -461,6 +479,11 @@ def logMetrics(
461479
key = key.replace('neg', neg)
462480
writer.add_scalar(key, value, self.totalTrainingSamples_count)
463481

482+
fig = pyplot.figure()
483+
pyplot.plot(fpr, tpr)
484+
writer.add_figure('roc', fig, self.totalTrainingSamples_count)
485+
486+
writer.add_scalar('auc', auc, self.totalTrainingSamples_count)
464487
# # tag::logMetrics_writer_prcurve[]
465488
# writer.add_pr_curve(
466489
# 'pr',
@@ -485,7 +508,10 @@ def logMetrics(
485508
bins=bins
486509
)
487510

488-
score = metrics_dict['pr/f1_score']
511+
if not self.cli_args.malignant:
512+
score = metrics_dict['pr/f1_score']
513+
else:
514+
score = metrics_dict['auc']
489515

490516
return score
491517

p3ch15/CMakeLists.txt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
2+
project(cyclegan-jit)
3+
4+
find_package(Torch REQUIRED)
5+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
6+
7+
add_executable(cyclegan-jit cyclegan_jit.cpp)
8+
target_link_libraries(cyclegan-jit pthread jpeg X11)
9+
target_link_libraries(cyclegan-jit "${TORCH_LIBRARIES}")
10+
set_property(TARGET cyclegan-jit PROPERTY CXX_STANDARD 14)
11+
12+
add_executable(cyclegan-cpp-api cyclegan_cpp_api.cpp)
13+
target_link_libraries(cyclegan-cpp-api pthread jpeg X11)
14+
target_link_libraries(cyclegan-cpp-api "${TORCH_LIBRARIES}")
15+
set_property(TARGET cyclegan-cpp-api PROPERTY CXX_STANDARD 14)
16+
17+
# The following code block is suggested to be used on Windows.
18+
# According to https://github.com/pytorch/pytorch/issues/25457,
19+
# the DLLs need to be copied to avoid memory errors.
20+
if (MSVC)
21+
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
22+
add_custom_command(TARGET cyclegan-jit
23+
POST_BUILD
24+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
25+
${TORCH_DLLS}
26+
$<TARGET_FILE_DIR:example-app>)
27+
endif (MSVC)

0 commit comments

Comments
 (0)