Skip to content

Commit 25e096f

Browse files
authored
added evaluation script for PPHumanSeg model (#130)
* added evaluation script for PPHumanSeg * added quantized model, renamed dataset * minor spacing changes * moved _all variables outside loop and updated accuracy * removed printing for class accuracy and IoU * added 2 transforms * evaluation done on same size tensor as input size with mIoU 0.9085 * final changes * added mIoU and reference
1 parent 8de3653 commit 25e096f

File tree

6 files changed

+262
-5
lines changed

6 files changed

+262
-5
lines changed

models/human_segmentation_pphumanseg/README.md

+12
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ python demo.py --help
2222

2323
![messi](./examples/messi.jpg)
2424

25+
---
26+
Results of accuracy evaluation with [tools/eval](../../tools/eval).
27+
28+
| Models | Accuracy | mIoU |
29+
| ------------------ | -------------- | ------------- |
30+
| PPHumanSeg | 0.9581 | 0.8996 |
31+
| PPHumanSeg quant | 0.4365 | 0.2788 |
32+
33+
34+
\*: 'quant' stands for 'quantized'.
35+
36+
---
2537
## License
2638

2739
All files in this directory are licensed under [Apache 2.0 License](./LICENSE).

models/human_segmentation_pphumanseg/pphumanseg.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self, modelPath, backendId=0, targetId=0):
1919

2020
self._inputNames = ''
2121
self._outputNames = ['save_infer_model/scale_0.tmp_1']
22+
self._currentInputSize = None
2223
self._inputSize = [192, 192]
2324
self._mean = np.array([0.5, 0.5, 0.5])[np.newaxis, np.newaxis, :]
2425
self._std = np.array([0.5, 0.5, 0.5])[np.newaxis, np.newaxis, :]
@@ -36,28 +37,36 @@ def setTarget(self, target_id):
3637
self._model.setPreferableTarget(self._targetId)
3738

3839
def _preprocess(self, image):
40+
41+
image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
42+
43+
self._currentInputSize = image.shape
44+
image = cv.resize(image, (192, 192))
45+
3946
image = image.astype(np.float32, copy=False) / 255.0
4047
image -= self._mean
4148
image /= self._std
4249
return cv.dnn.blobFromImage(image)
4350

4451
def infer(self, image):
45-
assert image.shape[0] == self._inputSize[1], '{} (height of input image) != {} (preset height)'.format(image.shape[0], self._inputSize[1])
46-
assert image.shape[1] == self._inputSize[0], '{} (width of input image) != {} (preset width)'.format(image.shape[1], self._inputSize[0])
4752

4853
# Preprocess
4954
inputBlob = self._preprocess(image)
5055

5156
# Forward
5257
self._model.setInput(inputBlob, self._inputNames)
53-
outputBlob = self._model.forward(self._outputNames)
58+
outputBlob = self._model.forward()
5459

5560
# Postprocess
5661
results = self._postprocess(outputBlob)
5762

5863
return results
5964

6065
def _postprocess(self, outputBlob):
61-
result = np.argmax(outputBlob[0], axis=1).astype(np.uint8)
66+
67+
outputBlob = outputBlob[0]
68+
outputBlob = cv.resize(outputBlob.transpose(1,2,0), (self._currentInputSize[1], self._currentInputSize[0]), interpolation=cv.INTER_LINEAR).transpose(2,0,1)[np.newaxis, ...]
69+
70+
result = np.argmax(outputBlob, axis=1).astype(np.uint8)
6271
return result
6372

tools/eval/README.md

+21
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Supported datasets:
2121
- [LFW](#lfw)
2222
- [ICDAR](#ICDAR2003)
2323
- [IIIT5K](#iiit5k)
24+
- [Mini Supervisely](#mini_supervisely)
2425

2526
## ImageNet
2627

@@ -190,4 +191,24 @@ Run evaluation with the following command:
190191

191192
```shell
192193
python eval.py -m crnn -d iiit5k -dr /path/to/iiit5k
194+
```
195+
196+
197+
## Mini Supervisely
198+
199+
### Prepare data
200+
Please download the mini_supervisely data from [here](https://paddleseg.bj.bcebos.com/humanseg/data/mini_supervisely.zip) which includes the validation dataset and unzip it.
201+
202+
### Evaluation
203+
204+
Run evaluation with the following command :
205+
206+
```shell
207+
python eval.py -m pphumanseg -d mini_supervisely -dr /path/to/pphumanseg
208+
```
209+
210+
Run evaluation on quantized model with the following command :
211+
212+
```shell
213+
python eval.py -m pphumanseg_q -d mini_supervisely -dr /path/to/pphumanseg
193214
```

tools/eval/datasets/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .lfw import LFW
44
from .icdar import ICDAR
55
from .iiit5k import IIIT5K
6+
from .minisupervisely import MiniSupervisely
67

78
class Registery:
89
def __init__(self, name):
@@ -20,4 +21,5 @@ def register(self, item):
2021
DATASETS.register(WIDERFace)
2122
DATASETS.register(LFW)
2223
DATASETS.register(ICDAR)
23-
DATASETS.register(IIIT5K)
24+
DATASETS.register(IIIT5K)
25+
DATASETS.register(MiniSupervisely)
+202
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import os
2+
import cv2 as cv
3+
import numpy as np
4+
from tqdm import tqdm
5+
6+
7+
class MiniSupervisely :
8+
9+
'''
10+
Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/paddleseg/core/val.py
11+
for official evaluation implementation.
12+
'''
13+
14+
def __init__(self, root) :
15+
self.root = root
16+
self.val_path = os.path.join(root, 'val.txt')
17+
self.image_set = self.load_data(self.val_path)
18+
self.num_classes = 2
19+
self.miou = -1
20+
self.class_miou = -1
21+
self.acc = -1
22+
self.class_acc = -1
23+
24+
25+
@property
26+
def name(self):
27+
return self.__class__.__name__
28+
29+
30+
def load_data(self, val_path) :
31+
"""
32+
Load validation image set from val.txt file
33+
Args :
34+
val_path (str) : path to val.txt file
35+
Returns :
36+
image_set (list) : list of image path of input and expected image
37+
"""
38+
39+
image_set = []
40+
with open(val_path, 'r') as f :
41+
for line in f.readlines() :
42+
image_set.append(line.strip().split())
43+
44+
return image_set
45+
46+
47+
def eval(self, model) :
48+
"""
49+
Evaluate model on validation set
50+
Args :
51+
model (object) : PP_HumanSeg model object
52+
"""
53+
54+
intersect_area_all = np.zeros([1], dtype=np.int64)
55+
pred_area_all = np.zeros([1], dtype=np.int64)
56+
label_area_all = np.zeros([1], dtype=np.int64)
57+
58+
pbar = tqdm(self.image_set)
59+
60+
pbar.set_description(
61+
"Evaluating {} with {} val set".format(model.name, self.name))
62+
63+
for input_image, expected_image in pbar :
64+
65+
input_image = cv.imread(os.path.join(self.root, input_image)).astype('float32')
66+
67+
expected_image = cv.imread(os.path.join(self.root, expected_image), cv.IMREAD_GRAYSCALE)[np.newaxis, :, :]
68+
69+
output_image = model.infer(input_image)
70+
71+
intersect_area, pred_area, label_area = self.calculate_area(
72+
output_image.astype('uint32'),
73+
expected_image.astype('uint32'),
74+
self.num_classes)
75+
76+
intersect_area_all = intersect_area_all + intersect_area
77+
pred_area_all = pred_area_all + pred_area
78+
label_area_all = label_area_all + label_area
79+
80+
self.class_iou, self.miou = self.mean_iou(intersect_area_all, pred_area_all,
81+
label_area_all)
82+
self.class_acc, self.acc = self.accuracy(intersect_area_all, pred_area_all)
83+
84+
85+
def get_results(self) :
86+
"""
87+
Get evaluation results
88+
Returns :
89+
miou (float) : mean iou
90+
class_miou (list) : iou on all classes
91+
acc (float) : mean accuracy
92+
class_acc (list) : accuracy on all classes
93+
"""
94+
return self.miou, self.class_miou, self.acc, self.class_acc
95+
96+
97+
def print_result(self) :
98+
"""
99+
Print evaluation results
100+
"""
101+
print("Mean IoU : ", self.miou)
102+
print("Mean Accuracy : ", self.acc)
103+
print("Class IoU : ", self.class_iou)
104+
print("Class Accuracy : ", self.class_acc)
105+
106+
107+
def calculate_area(self,pred, label, num_classes, ignore_index=255):
108+
"""
109+
Calculate intersect, prediction and label area
110+
Args:
111+
pred (Tensor): The prediction by model.
112+
label (Tensor): The ground truth of image.
113+
num_classes (int): The unique number of target classes.
114+
ignore_index (int): Specifies a target value that is ignored. Default: 255.
115+
Returns:
116+
Tensor: The intersection area of prediction and the ground on all class.
117+
Tensor: The prediction area on all class.
118+
Tensor: The ground truth area on all class
119+
"""
120+
121+
122+
if len(pred.shape) == 4:
123+
pred = np.squeeze(pred, axis=1)
124+
if len(label.shape) == 4:
125+
label = np.squeeze(label, axis=1)
126+
if not pred.shape == label.shape:
127+
raise ValueError('Shape of `pred` and `label should be equal, '
128+
'but there are {} and {}.'.format(pred.shape,
129+
label.shape))
130+
131+
mask = label != ignore_index
132+
pred_area = []
133+
label_area = []
134+
intersect_area = []
135+
136+
#iterate over all classes and calculate their respective areas
137+
for i in range(num_classes):
138+
pred_i = np.logical_and(pred == i, mask)
139+
label_i = label == i
140+
intersect_i = np.logical_and(pred_i, label_i)
141+
pred_area.append(np.sum(pred_i.astype('int32')))
142+
label_area.append(np.sum(label_i.astype('int32')))
143+
intersect_area.append(np.sum(intersect_i.astype('int32')))
144+
145+
return intersect_area, pred_area, label_area
146+
147+
148+
def mean_iou(self,intersect_area, pred_area, label_area):
149+
"""
150+
Calculate iou.
151+
Args:
152+
intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.
153+
pred_area (Tensor): The prediction area on all classes.
154+
label_area (Tensor): The ground truth area on all classes.
155+
Returns:
156+
np.ndarray: iou on all classes.
157+
float: mean iou of all classes.
158+
"""
159+
intersect_area = np.array(intersect_area)
160+
pred_area = np.array(pred_area)
161+
label_area = np.array(label_area)
162+
163+
union = pred_area + label_area - intersect_area
164+
165+
class_iou = []
166+
for i in range(len(intersect_area)):
167+
if union[i] == 0:
168+
iou = 0
169+
else:
170+
iou = intersect_area[i] / union[i]
171+
class_iou.append(iou)
172+
173+
miou = np.mean(class_iou)
174+
175+
return np.array(class_iou), miou
176+
177+
178+
def accuracy(self,intersect_area, pred_area):
179+
"""
180+
Calculate accuracy
181+
Args:
182+
intersect_area (Tensor): The intersection area of prediction and ground truth on all classes..
183+
pred_area (Tensor): The prediction area on all classes.
184+
Returns:
185+
np.ndarray: accuracy on all classes.
186+
float: mean accuracy.
187+
"""
188+
189+
intersect_area = np.array(intersect_area)
190+
pred_area = np.array(pred_area)
191+
192+
class_acc = []
193+
for i in range(len(intersect_area)):
194+
if pred_area[i] == 0:
195+
acc = 0
196+
else:
197+
acc = intersect_area[i] / pred_area[i]
198+
class_acc.append(acc)
199+
200+
macc = np.sum(intersect_area) / np.sum(pred_area)
201+
202+
return np.array(class_acc), macc

tools/eval/eval.py

+11
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@
7777
name="CRNN",
7878
topic="text_recognition",
7979
modelPath=os.path.join(root_dir, "models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx")),
80+
pphumanseg=dict(
81+
name="PPHumanSeg",
82+
topic="human_segmentation",
83+
modelPath=os.path.join(root_dir, "models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2021oct.onnx")),
84+
pphumanseg_q=dict(
85+
name="PPHumanSeg",
86+
topic="human_segmentation",
87+
modelPath=os.path.join(root_dir, "models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2021oct-act_int8-wt_int8-quantized.onnx")),
8088
)
8189

8290
datasets = dict(
@@ -97,6 +105,9 @@
97105
iiit5k=dict(
98106
name="IIIT5K",
99107
topic="text_recognition"),
108+
mini_supervisely=dict(
109+
name="MiniSupervisely",
110+
topic="human_segmentation"),
100111
)
101112

102113
def main(args):

0 commit comments

Comments
 (0)