# SS_eval.py (119 lines, 103 loc, 4.74 KB)
import os
import cv2
import argparse
import numpy as np
import torch
import torch.nn as nn
from SS_config1 import config
from utils.pyt_utils import ensure_dir, link_file, load_model, parse_devices
from utils.visualize import print_iou, show_img
from engine.evaluator import Evaluator
from engine.logger import get_logger
from utils.metric import hist_info, compute_score
from dataloader.SS_Dataset import RGBXDataset
from models.builder import EncoderDecoder as segmodel
from dataloader.SS_dataloader import ValPre
logger = get_logger()
class SegEvaluator(Evaluator):
    """Evaluator that scores segmentation predictions one sample at a time.

    ``func_per_iteration`` predicts a single sample and returns its confusion
    statistics; ``compute_metric`` sums those statistics over all samples and
    formats the final report.
    """

    def _class_colors(self):
        """Return the dataset's class colour table.

        The attribute is called when it is callable and returned as-is
        otherwise, so both a method and a pre-built table are supported.
        """
        colors = self.dataset.get_class_colors
        return colors() if callable(colors) else colors

    def _save_prediction(self, pred, name):
        """Write the raw and the colourised prediction for one sample.

        The raw label map goes to ``self.save_path`` and the colourised image
        to ``self.save_path + '_color'``, both as ``<name>.png``.
        """
        color_dir = self.save_path + '_color'
        ensure_dir(self.save_path)
        ensure_dir(color_dir)

        fn = name + '.png'

        # Save the coloured result. The colour table is treated as RGB
        # triples; cv2.imwrite expects BGR, so the channels are reversed
        # while filling the lookup table. Label values without a colour map
        # to black.
        class_colors = np.asarray(self._class_colors(), dtype=np.uint8)
        class_colors = class_colors.reshape(-1, 3)[:256]
        lookup = np.zeros((256, 3), dtype=np.uint8)
        lookup[:len(class_colors)] = class_colors[:, ::-1]
        cv2.imwrite(os.path.join(color_dir, fn), lookup[pred.astype(np.uint8)])

        # Save the raw result.
        cv2.imwrite(os.path.join(self.save_path, fn), pred)
        logger.info('Save the image ' + fn)

    def func_per_iteration(self, data, device):
        """Predict one sample and return its confusion statistics.

        Args:
            data: Mapping describing one sample. Only the order of its values
                is used: the second value is the label map, the
                second-to-last is the sample name, and the last value is
                discarded. Every other value is forwarded to the model.
            device: Device handed to ``sliding_eval_rgbX``.

        Returns:
            dict with the keys ``'hist'``, ``'labeled'`` and ``'correct'`` as
            produced by ``hist_info``.

        Raises:
            ValueError: If ``data`` has fewer than four values, in which case
                label, name and model inputs cannot all be present.
        """
        values = list(data.values())
        if len(values) < 4:
            raise ValueError(
                'expected at least 4 entries in data, got %d' % len(values))

        name = values[-2]
        label = values[1]
        # Drop the label and the two trailing entries. According to the
        # original author's note the remaining list holds rgb, model_x and
        # view_img, in that order.
        inputs = values[:1] + values[2:-2]

        pred = self.sliding_eval_rgbX(inputs, config.eval_crop_size,
                                      config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred, label)
        results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp,
                        'correct': correct_tmp}

        if self.save_path is not None:
            self._save_prediction(pred, name)

        if self.show_image:
            # NOTE(review): show_img is assumed to want the colour table
            # itself (as the save branch does) and the RGB image as its
            # picture argument -- confirm against utils.visualize.
            colors = self._class_colors()
            image = inputs[0]
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean,
                                label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            cv2.waitKey(0)

        return results_dict

    def compute_metric(self, results):
        """Sum the per-sample statistics and return the formatted report.

        Args:
            results: Iterable of the dicts returned by
                ``func_per_iteration``.

        Returns:
            Whatever ``print_iou`` returns for the accumulated scores.
        """
        hist = np.zeros((config.num_classes, config.num_classes))
        correct = 0
        labeled = 0
        for d in results:
            hist += d['hist']
            correct += d['correct']
            labeled += d['labeled']

        iou, _mean_iou, _, freq_iou, mean_pixel_acc, pixel_acc = compute_score(
            hist, correct, labeled)
        # Read the class names from the evaluator's own dataset rather than a
        # script-level global, so the class also works when imported.
        result_line = print_iou(iou, freq_iou, mean_pixel_acc, pixel_acc,
                                self.dataset.class_names, show_no_back=False)
        return result_line
if __name__ == "__main__":
    # Command-line options. `epochs` is forwarded to `segmentor.run`,
    # `devices` is parsed by `parse_devices`, and the remaining flags are
    # handed to the evaluator unchanged.
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epochs', default='last', type=str)
    parser.add_argument('-d', '--devices', default='0', type=str)
    parser.add_argument('-v', '--verbose', default=False, action='store_true')
    parser.add_argument('--show_image', '-s', default=False,
                        action='store_true')
    parser.add_argument('--save_path', '-p', default=None)

    args = parser.parse_args()
    all_dev = parse_devices(args.devices)

    # No criterion is passed: the network is only used for inference here.
    network = segmodel(cfg=config, criterion=None, norm_layer=nn.BatchNorm2d)

    # Dataset description, taken entirely from the project config. The key
    # spellings (including the capitalised 'View_path_format') are read by
    # the dataset class and must not be changed here.
    data_setting = {
        'rgb_root': config.rgb_root_folder,
        'rgb_format': config.rgb_format,
        'gt_root': config.gt_root_folder,
        'gt_format': config.gt_format,
        'transform_gt': config.gt_transform,
        'x_root': config.x_root_folder,
        'x_format': config.x_format,
        'x_single_channel': config.x_is_single_channel,
        'class_names': config.class_names,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'view_path': config.rgb_view_path,
        'View_path_format': config.rgb_view_format,
        'view_list': config.view_list,
    }
    val_pre = ValPre()
    dataset = RGBXDataset(data_setting, 'val', val_pre)

    # Evaluation needs no gradients.
    with torch.no_grad():
        segmentor = SegEvaluator(dataset, config.num_classes, config.norm_mean,
                                 config.norm_std, network,
                                 config.eval_scale_array, config.eval_flip,
                                 all_dev, args.verbose, args.save_path,
                                 args.show_image)
        segmentor.run(config.checkpoint_dir, args.epochs, config.val_log_file,
                      config.link_val_log_file)