-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinference.py
118 lines (105 loc) · 4.53 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
This module implements the instance-segmentation and frame-to-frame tracking logic.
"""
import postprocessing
import utils
import numpy as np
from target import Target
class InferenceModel:
    """Run a segmentation model on frame pairs and link detected instances into tracks.

    Attributes:
        model: Wrapped model; only its ``predict`` method is used here.
        params: Configuration object. This class reads NUM_CLASSES,
            EMBEDDING_DIM, OUTPUT_SIZE, MASK_AREA_THRESHOLD and IOU_THRESHOLD
            from it, and forwards it to ``postprocessing.embedding_to_instance``.
        frames: One list of ``Target`` objects per processed frame.
        highest_id: Largest track id handed out so far (-1 before any track).
    """

    # Number of horizontal panels packed into the model output. get_mask_pair
    # reads them (left to right) as current, occluded-current, previous and
    # occluded-previous instance maps.
    _NUM_PANELS = 4

    def __init__(self, model, params):
        self.model = model
        self.params = params
        self.frames = []
        # Initialized here (not only in track_on_sequence) so that
        # update_track can be called directly without raising AttributeError.
        self.highest_id = -1

    def predict(self, x):
        """Return the raw output of ``self.model.predict(x)``."""
        return self.model.predict(x)

    def segment(self, x):
        """Decode the model output into embeddings, class ids and instance clusters.

        The squeezed model output is read as ``4 * NUM_CLASSES`` class-score
        channels followed by ``4 * EMBEDDING_DIM`` embedding channels, one
        channel group per panel. Each panel's group is laid out side by side
        horizontally, producing maps ``OUTPUT_SIZE`` tall and
        ``4 * OUTPUT_SIZE`` wide. The squeezed output is expected to be
        (OUTPUT_SIZE, OUTPUT_SIZE, channels); TODO confirm against the model.

        Args:
            x: Model input, e.g. as built by ``utils.prep_double_frame``.

        Returns:
            Tuple ``(embedding, class_ids, clusters)``:
            ``embedding`` has shape (OS, 4*OS, EMBEDDING_DIM);
            ``class_ids`` is the per-pixel argmax over classes, shape (OS, 4*OS);
            ``clusters`` is the result of ``postprocessing.embedding_to_instance``.
        """
        n_cls = self.params.NUM_CLASSES
        n_emb = self.params.EMBEDDING_DIM
        size = self.params.OUTPUT_SIZE
        panels = self._NUM_PANELS

        outputs = np.squeeze(self.predict(x))
        class_scores = np.zeros((size, size * panels, n_cls))
        embedding = np.zeros((size, size * panels, n_emb))
        # All class-score channels come first; embedding channels follow them.
        emb_start = n_cls * panels
        for i in range(panels):
            # Copy this panel's channel slice into its horizontal slice.
            cols = slice(size * i, size * (i + 1))
            cls_lo = n_cls * i
            emb_lo = emb_start + n_emb * i
            class_scores[:, cols, :] = outputs[:, :, cls_lo:cls_lo + n_cls]
            embedding[:, cols, :] = outputs[:, :, emb_lo:emb_lo + n_emb]

        class_ids = np.argmax(class_scores, axis=-1)
        clusters = postprocessing.embedding_to_instance(embedding, class_ids, self.params)
        return embedding, class_ids, clusters

    def get_mask_pair(self, x):
        """Return matching (previous, current) amodal masks for model input ``x``.

        Runs ``segment`` and pairs instance ids across its four panels
        (see ``_pair_masks``).

        Returns:
            Tuple ``(prev_masks, masks)`` of equal-length lists of boolean
            arrays; ``prev_masks[k]`` and ``masks[k]`` come from the same
            instance id.
        """
        _, _, cluster_all_class = self.segment(x)
        return self._pair_masks(cluster_all_class)

    def _pair_masks(self, cluster_all_class):
        """Split a 4-panel instance-id map into per-instance amodal mask pairs.

        Panels, left to right and each OUTPUT_SIZE wide: current instances,
        occluded current instances, previous instances, occluded previous
        instances. For each id in ``1..max(cluster_all_class)``, the amodal
        mask is the union of that id's visible and occluded pixels. A pair is
        kept only if both masks have more than MASK_AREA_THRESHOLD pixels.
        Id 0 is never selected.
        """
        size = self.params.OUTPUT_SIZE
        threshold = self.params.MASK_AREA_THRESHOLD

        # The panel slices do not depend on the instance id, so take them once.
        instance_mask = cluster_all_class[:, 0:size]
        occ_instance_mask = cluster_all_class[:, size:2 * size]
        prev_instance_mask = cluster_all_class[:, 2 * size:3 * size]
        occ_prev_instance_mask = cluster_all_class[:, 3 * size:4 * size]

        num_instance = int(np.max(cluster_all_class))
        amodal_prev_masks = []
        amodal_masks = []
        for mask_id in range(1, num_instance + 1):
            amodal_mask = np.logical_or(
                instance_mask == mask_id, occ_instance_mask == mask_id)
            amodal_prev_mask = np.logical_or(
                prev_instance_mask == mask_id, occ_prev_instance_mask == mask_id)
            # np.sum over a boolean mask counts its pixels.
            if np.sum(amodal_mask) > threshold and np.sum(amodal_prev_mask) > threshold:
                amodal_masks.append(amodal_mask)
                amodal_prev_masks.append(amodal_prev_mask)
        return amodal_prev_masks, amodal_masks

    def _next_id(self):
        """Reserve and return a fresh track id (one more than the last issued)."""
        self.highest_id += 1
        return self.highest_id

    def update_track(self, x):
        """Extend ``self.frames`` with the tracked targets for one frame pair.

        First call (no stored frames): every mask pair starts a new track, and
        two frames (previous and current) are appended.

        Later calls: each previous-frame mask is scored with ``utils.iou``
        against the targets of the last stored frame. The first target scoring
        above IOU_THRESHOLD passes its id to the current-frame mask; otherwise
        a new id is issued. Exactly one frame is appended.

        NOTE(review): matching is greedy and targets are not marked as used,
        so two detections overlapping the same target can inherit the same id.
        """
        masks_0, masks_1 = self.get_mask_pair(x)

        if not self.frames:
            # Step 1: initialize tracks with every mask in the first frame.
            frame_0 = []
            frame_1 = []
            for mask_0, mask_1 in zip(masks_0, masks_1):
                track_id = self._next_id()
                frame_0.append(Target(mask_0, track_id))
                frame_1.append(Target(mask_1, track_id))
            self.frames.append(frame_0)
            self.frames.append(frame_1)
            return

        # Step 2: match the current pair with the previous frame.
        prev_frame = self.frames[-1]
        threshold = self.params.IOU_THRESHOLD
        frame = []
        for mask_0, mask_1 in zip(masks_0, masks_1):
            # A match means mask_0 is already tracked in the previous frame,
            # so only mask_1 needs to be linked to that track.
            track_id = next(
                (target.id for target in prev_frame
                 if utils.iou(target.mask, mask_0) > threshold),
                None)
            if track_id is None:
                # No previous mask matches this detection: start a new track.
                track_id = self._next_id()
            frame.append(Target(mask_1, track_id))
        self.frames.append(frame)

    def track_on_sequence(self, sequence):
        """Track instances over each adjacent pair of items in ``sequence``.

        Resets the tracking state. Each pair ``(sequence[i], sequence[i + 1])``
        is fed through ``utils.prep_double_frame`` and ``update_track``.
        Returns the list of per-frame targets; a sequence shorter than two
        items yields an empty list.
        """
        self.frames = []
        self.highest_id = -1
        for prev_image_info, image_info in zip(sequence, sequence[1:]):
            x, _ = utils.prep_double_frame(prev_image_info, image_info)
            self.update_track(x)
        return self.frames