-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathsparse4d_detector.py
123 lines (107 loc) · 4.1 KB
/
sparse4d_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) 2024 SparseEnd2End. All rights reserved @author: Thomas Von Wu.
import torch
from inspect import signature
from modules.cnn.base_detector import BaseDetector
from modules.backbone import *
from modules.neck import *
from modules.head import *
from modules.head.sparse4d_blocks.core_blocks import *
from tool.runner.fp16_utils import force_fp32, auto_fp16
# Optional dependency: the deformable-aggregation ops may not be available.
# DFA_VALID records whether the import succeeded so that the detector can
# refuse `use_deformable_func=True` when the ops are missing.
try:
    from .ops import feature_maps_format

    DFA_VALID = True
except Exception:
    # Deliberately broad (any import-time failure disables the feature), but
    # not a bare `except:` so KeyboardInterrupt / SystemExit still propagate.
    DFA_VALID = False
__all__ = ["Sparse4D"]
class Sparse4D(BaseDetector):
    """Detector assembled from an image backbone, an optional neck and a head.

    Sub-modules are instantiated from config dicts (see ``__init__``). An
    optional depth branch adds a dense-depth loss during training, and an
    optional grid mask is applied to the input images before the backbone.
    """

    def __init__(
        self,
        img_backbone,
        img_neck,
        head,
        depth_branch=None,
        use_grid_mask=True,
        use_deformable_func=False,
        init_cfg=None,
    ):
        """Build all sub-modules.

        Args:
            img_backbone: Config dict; its ``"type"`` entry names the class to
                build and the remaining entries are passed as keyword
                arguments to that class.
            img_neck: Config dict in the same format, or ``None`` for no neck.
            head: Config dict in the same format.
            depth_branch: Config dict in the same format, or ``None`` to
                disable the depth branch.
            use_grid_mask: Whether to create and apply a ``GridMask`` to the
                input images in ``extract_feat``.
            use_deformable_func: Whether to pass feature maps through
                ``feature_maps_format`` in ``extract_feat``. Requires the
                optional ops import to have succeeded (``DFA_VALID``).
            init_cfg: Forwarded to ``BaseDetector.__init__``.

        Raises:
            AssertionError: If ``use_deformable_func`` is truthy but the
                optional ops are unavailable.
        """
        super().__init__(init_cfg=init_cfg)

        # =========== build modules ===========
        def build_module(cfg):
            """Instantiate the class named by ``cfg["type"]`` with the other entries."""
            kwargs = cfg.copy()
            type_name = kwargs.pop("type")
            # NOTE(security): eval() executes arbitrary expressions. This is
            # only acceptable because configs are expected to be trusted;
            # never feed untrusted config strings through here.
            return eval(type_name)(**kwargs)

        self.img_backbone = build_module(img_backbone)
        if img_neck is not None:
            self.img_neck = build_module(img_neck)
        else:
            # Always define the attribute: extract_feat() reads it
            # unconditionally and would otherwise raise AttributeError.
            self.img_neck = None
        self.head = build_module(head)
        self.use_grid_mask = use_grid_mask
        if use_deformable_func:
            assert DFA_VALID, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        if depth_branch is not None:
            self.depth_branch = build_module(depth_branch)
        else:
            self.depth_branch = None
        if use_grid_mask:
            self.grid_mask = GridMask(
                True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
            )

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, return_depth=False, metas=None):
        """Run grid mask, backbone and neck, then restore the camera axis.

        Args:
            img: Image tensor. A 5-D tensor is treated as multi-view and its
                first two dimensions are flattened before the backbone
                (e.g. (1, 6, 3, 256, 704) -> (6, 3, 256, 704)); any other
                rank is treated as a single camera.
            return_depth: If true, return a ``(feature_maps, depths)`` tuple
                instead of only ``feature_maps``.
            metas: Optional mapping. Passed to the backbone when its
                ``forward`` accepts a ``metas`` parameter, and queried for
                ``"focal"`` when the depth branch runs.

        Returns:
            ``feature_maps`` (each map reshaped to
            ``(batch, num_cams, *feat.shape[1:])``, then passed through
            ``feature_maps_format`` when ``use_deformable_func`` is set), or
            ``(feature_maps, depths)`` when ``return_depth`` is true.
            ``depths`` is ``None`` if there is no depth branch.
        """
        batch_size = img.shape[0]
        if img.dim() == 5:  # multi-view
            num_cams = img.shape[1]
            img = img.flatten(end_dim=1)
        else:
            num_cams = 1

        if self.use_grid_mask:
            img = self.grid_mask(img)

        # Only backbones whose forward() declares `metas` get the extra args.
        if "metas" in signature(self.img_backbone.forward).parameters:
            feature_maps = self.img_backbone(img, num_cams, metas=metas)
        else:
            feature_maps = self.img_backbone(img)
        if self.img_neck is not None:
            feature_maps = self.img_neck(feature_maps)

        # Materialise a list in every case (a backbone/neck may return a
        # tuple) and split the flattened leading axis back into
        # (batch, num_cams).
        feature_maps = [
            torch.reshape(feat, (batch_size, num_cams) + feat.shape[1:])
            for feat in list(feature_maps)
        ]

        depths = None
        if return_depth and self.depth_branch is not None:
            # metas defaults to None; avoid calling .get on it in that case.
            focal = metas.get("focal") if metas is not None else None
            depths = self.depth_branch(feature_maps, focal)

        if self.use_deformable_func:
            feature_maps = feature_maps_format(feature_maps)
        if return_depth:
            return feature_maps, depths
        return feature_maps

    @force_fp32(apply_to=("img",))
    def forward(self, img, **data):
        """Dispatch to ``forward_train`` or ``forward_test`` based on ``self.training``."""
        if self.training:
            return self.forward_train(img, **data)
        return self.forward_test(img, **data)

    def forward_train(self, img, **data):
        """Compute training losses.

        Returns the head's loss output, with ``"loss_dense_depth"`` added when
        the depth branch produced depths and ``data`` contains ``"gt_depth"``.
        """
        feature_maps, depths = self.extract_feat(img, True, data)
        model_outs = self.head(feature_maps, data)
        output = self.head.loss(model_outs, data)
        if depths is not None and "gt_depth" in data:
            output["loss_dense_depth"] = self.depth_branch.loss(
                depths, data["gt_depth"]
            )
        return output

    def forward_test(self, img, **data):
        """Run inference; a list ``img`` goes to ``aug_test``, anything else to ``simple_test``."""
        if isinstance(img, list):
            return self.aug_test(img, **data)
        return self.simple_test(img, **data)

    def simple_test(self, img, **data):
        """Run a single forward pass and wrap each post-processed result.

        Returns:
            A list with one ``{"img_bbox": result}`` dict per result returned
            by ``self.head.post_process``.
        """
        feature_maps = self.extract_feat(img)
        model_outs = self.head(feature_maps, data)
        results = self.head.post_process(model_outs)
        return [{"img_bbox": result} for result in results]

    def aug_test(self, img, **data):
        """Fake test-time augmentation: use only the first entry of each list input."""
        # Only values of existing keys are replaced, so the dict size is
        # unchanged and updating while iterating is safe.
        for key, value in data.items():
            if isinstance(value, list):
                data[key] = value[0]
        return self.simple_test(img[0], **data)