Skip to content

Commit 4436d06

Browse files
David Josef Emmerichs (demmerichs)
David Josef Emmerichs
authored and committed
refactor application of data_dict transforms
1 parent b4dd915 commit 4436d06

File tree

4 files changed

+90
-99
lines changed

4 files changed

+90
-99
lines changed

pcdet/datasets/augmentor/augmentor_utils.py

+58-48
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ...utils import box_utils
66

77

8-
def random_flip_along(dim, gt_boxes, points, return_flip=False, enable=None):
8+
def random_flip_along(dim, return_flip=False, enable=None):
99
"""
1010
Args:
1111
gt_boxes: (*, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
@@ -16,19 +16,31 @@ def random_flip_along(dim, gt_boxes, points, return_flip=False, enable=None):
1616
other_dim = 1 - dim
1717
if enable is None:
1818
enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
19+
1920
if enable:
20-
gt_boxes[..., other_dim] = -gt_boxes[..., other_dim]
21-
gt_boxes[..., 6] = -(gt_boxes[..., 6] + np.pi * dim)
22-
points[..., other_dim] = -points[..., other_dim]
21+
def flip_pointlike(points):
22+
points[..., other_dim] = -points[..., other_dim]
23+
return points
24+
25+
def flip_boxlike(boxes):
26+
boxes[..., other_dim] = -boxes[..., other_dim]
27+
boxes[..., 6] = -(boxes[..., 6] + np.pi * dim)
28+
29+
if boxes.shape[-1] > 7:
30+
boxes[..., 7 + other_dim] = -boxes[..., 7 + other_dim]
31+
32+
return boxes
33+
34+
tfs = dict(point=flip_pointlike, box=flip_boxlike)
35+
else:
36+
tfs = dict()
2337

24-
if gt_boxes.shape[-1] > 7:
25-
gt_boxes[..., 7 + other_dim] = -gt_boxes[..., 7 + other_dim]
2638
if return_flip:
27-
return gt_boxes, points, enable
28-
return gt_boxes, points
39+
return tfs, enable
40+
return tfs
2941

3042

31-
def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotation=None):
43+
def global_rotation(rot_range, return_rot=False, noise_rotation=None):
3244
"""
3345
Args:
3446
gt_boxes: (*, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
@@ -38,61 +50,59 @@ def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotatio
3850
"""
3951
if noise_rotation is None:
4052
noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
41-
points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
42-
gt_boxes[..., 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, ..., 0:3], np.array([noise_rotation]))[0]
43-
gt_boxes[..., 6] += noise_rotation
44-
if gt_boxes.shape[-1] > 7:
45-
gt_boxes[..., 7:9] = common_utils.rotate_points_along_z(
46-
np.concatenate((gt_boxes[..., 7:9], np.zeros((*gt_boxes.shape[:-1], 1))), axis=-1)[np.newaxis, ...],
47-
np.array([noise_rotation])
48-
)[0, ..., 0:2]
53+
54+
def rotate_pointlike(points):
55+
points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
56+
return points
57+
58+
def rotate_boxlike(boxes):
59+
boxes[..., 0:3] = common_utils.rotate_points_along_z(boxes[np.newaxis, ..., 0:3], np.array([noise_rotation]))[0]
60+
boxes[..., 6] += noise_rotation
61+
if boxes.shape[-1] > 7:
62+
boxes[..., 7:9] = common_utils.rotate_points_along_z(
63+
np.concatenate((boxes[..., 7:9], np.zeros((*boxes.shape[:-1], 1))), axis=-1)[np.newaxis, ...],
64+
np.array([noise_rotation])
65+
)[0, ..., 0:2]
66+
return boxes
67+
68+
tfs = dict(point=rotate_pointlike, box=rotate_boxlike)
4969

5070
if return_rot:
51-
return gt_boxes, points, noise_rotation
52-
return gt_boxes, points
71+
return tfs, noise_rotation
72+
return tfs
5373

5474

55-
def global_scaling(gt_boxes, points, scale_range, return_scale=False):
75+
def global_scaling(scale_range, return_scale=False):
5676
"""
5777
Args:
58-
gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
78+
gt_boxes: (*, 7), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
5979
points: (M, 3 + C),
6080
scale_range: [min, max]
6181
Returns:
6282
"""
6383
if scale_range[1] - scale_range[0] < 1e-3:
64-
return gt_boxes, points
84+
noise_scale = sum(scale_range) / len(scale_range)
85+
assert noise_scale == 1.0, (noise_scale, scale_range)
6586
noise_scale = np.random.uniform(scale_range[0], scale_range[1])
66-
points[:, :3] *= noise_scale
67-
gt_boxes[:, :6] *= noise_scale
68-
if gt_boxes.shape[1] > 7:
69-
gt_boxes[:, 7:9] *= noise_scale
70-
71-
if return_scale:
72-
return gt_boxes, points, noise_scale
73-
return gt_boxes, points
7487

75-
def global_scaling_with_roi_boxes(gt_boxes, roi_boxes, points, scale_range, return_scale=False):
76-
"""
77-
Args:
78-
gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
79-
points: (M, 3 + C),
80-
scale_range: [min, max]
81-
Returns:
82-
"""
83-
if scale_range[1] - scale_range[0] < 1e-3:
84-
return gt_boxes, points
85-
noise_scale = np.random.uniform(scale_range[0], scale_range[1])
86-
points[:, :3] *= noise_scale
87-
gt_boxes[:, :6] *= noise_scale
88-
if gt_boxes.shape[1] > 7:
89-
gt_boxes[:, 7:9] *= noise_scale
88+
def scale_pointlike(points):
89+
points[:, :3] *= noise_scale
90+
return points
91+
92+
def scale_boxlike(boxes):
93+
boxes[..., :6] *= noise_scale
94+
if boxes.shape[-1] > 7:
95+
boxes[..., 7:9] *= noise_scale
96+
return boxes
9097

91-
roi_boxes[:,:, [0,1,2,3,4,5,7,8]] *= noise_scale
98+
if noise_scale != 1.0:
99+
tfs = dict(point=scale_pointlike, box=scale_boxlike)
100+
else:
101+
tfs = {}
92102

93103
if return_scale:
94-
return gt_boxes,roi_boxes, points, noise_scale
95-
return gt_boxes, roi_boxes, points
104+
return tfs, noise_scale
105+
return tfs
96106

97107

98108
def random_image_flip_horizontal(image, depth_map, gt_boxes, calib):

pcdet/datasets/augmentor/data_augmentor.py

+13-45
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,15 @@ def __setstate__(self, d):
5656
def random_world_flip(self, data_dict=None, config=None):
5757
if data_dict is None:
5858
return partial(self.random_world_flip, config=config)
59-
gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
6059
for cur_axis in config['ALONG_AXIS_LIST']:
6160
assert cur_axis in ['x', 'y']
6261
cur_dim = ['x', 'y'].index(cur_axis)
63-
gt_boxes, points, enable = augmentor_utils.random_flip_along(
64-
cur_dim, gt_boxes, points, return_flip=True
62+
tfs, enable = augmentor_utils.random_flip_along(
63+
cur_dim, return_flip=True
6564
)
65+
common_utils.apply_data_transform(data_dict, tfs)
6666
data_dict['flip_%s'%cur_axis] = enable
67-
if 'roi_boxes' in data_dict.keys():
68-
data_dict['roi_boxes'], _ = augmentor_utils.random_flip_along(
69-
cur_dim,
70-
data_dict['roi_boxes'],
71-
np.zeros([0,3]),
72-
enable=enable,
73-
)
7467

75-
data_dict['gt_boxes'] = gt_boxes
76-
data_dict['points'] = points
7768
return data_dict
7869

7970
def random_world_rotation(self, data_dict=None, config=None):
@@ -82,38 +73,18 @@ def random_world_rotation(self, data_dict=None, config=None):
8273
rot_range = config['WORLD_ROT_ANGLE']
8374
if not isinstance(rot_range, list):
8475
rot_range = [-rot_range, rot_range]
85-
gt_boxes, points, noise_rot = augmentor_utils.global_rotation(
86-
data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True
76+
tfs, noise_rot = augmentor_utils.global_rotation(
77+
rot_range=rot_range, return_rot=True
8778
)
88-
if 'roi_boxes' in data_dict.keys():
89-
data_dict['roi_boxes'], _ = augmentor_utils.global_rotation(
90-
data_dict['roi_boxes'],
91-
np.zeros([0, 3]),
92-
rot_range=rot_range,
93-
noise_rotation=noise_rot,
94-
)
95-
96-
data_dict['gt_boxes'] = gt_boxes
97-
data_dict['points'] = points
79+
common_utils.apply_data_transform(data_dict, tfs)
9880
data_dict['noise_rot'] = noise_rot
9981
return data_dict
10082

10183
def random_world_scaling(self, data_dict=None, config=None):
10284
if data_dict is None:
10385
return partial(self.random_world_scaling, config=config)
104-
105-
if 'roi_boxes' in data_dict.keys():
106-
gt_boxes, roi_boxes, points, noise_scale = augmentor_utils.global_scaling_with_roi_boxes(
107-
data_dict['gt_boxes'], data_dict['roi_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
108-
)
109-
data_dict['roi_boxes'] = roi_boxes
110-
else:
111-
gt_boxes, points, noise_scale = augmentor_utils.global_scaling(
112-
data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
113-
)
114-
115-
data_dict['gt_boxes'] = gt_boxes
116-
data_dict['points'] = points
86+
tfs, noise_scale = augmentor_utils.global_scaling(scale_range=config['WORLD_SCALE_RANGE'], return_scale=True)
87+
common_utils.apply_data_transform(data_dict, tfs)
11788
data_dict['noise_scale'] = noise_scale
11889
return data_dict
11990

@@ -147,15 +118,12 @@ def random_world_translation(self, data_dict=None, config=None):
147118
np.random.normal(0, noise_translate_std[2], 1),
148119
], dtype=np.float32).T
149120

150-
gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
151-
points[:, :3] += noise_translate
152-
gt_boxes[:, :3] += noise_translate
153-
154-
if 'roi_boxes' in data_dict.keys():
155-
data_dict['roi_boxes'][:, :, :3] += noise_translate
121+
def translate_locationlike(locations):
122+
locations[..., :3] += noise_translate
123+
return locations
156124

157-
data_dict['gt_boxes'] = gt_boxes
158-
data_dict['points'] = points
125+
tfs = dict(point=translate_locationlike, box=translate_locationlike)
126+
common_utils.apply_data_transform(data_dict, tfs)
159127
data_dict['noise_translate'] = noise_translate
160128
return data_dict
161129

pcdet/datasets/processor/data_processor.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def mask_points_and_boxes_outside_range(self, data_dict=None, config=None):
8282

8383
if data_dict.get('points', None) is not None:
8484
mask = common_utils.mask_points_by_range(data_dict['points'], self.point_cloud_range)
85-
data_dict['points'] = data_dict['points'][mask]
85+
tfs = dict(point=lambda x: x[mask])
86+
common_utils.apply_data_transform(data_dict, tfs)
8687

8788
if data_dict.get('gt_boxes', None) is not None and config.REMOVE_OUTSIDE_BOXES and self.training:
8889
mask = box_utils.mask_boxes_outside_range_numpy(
@@ -97,10 +98,9 @@ def shuffle_points(self, data_dict=None, config=None):
9798
return partial(self.shuffle_points, config=config)
9899

99100
if config.SHUFFLE_ENABLED[self.mode]:
100-
points = data_dict['points']
101-
shuffle_idx = np.random.permutation(points.shape[0])
102-
points = points[shuffle_idx]
103-
data_dict['points'] = points
101+
shuffle_idx = np.random.permutation(data_dict['points'].shape[0])
102+
tfs = dict(point=lambda x: x[shuffle_idx])
103+
common_utils.apply_data_transform(data_dict, tfs)
104104

105105
return data_dict
106106

@@ -208,7 +208,8 @@ def sample_points(self, data_dict=None, config=None):
208208
extra_choice = np.random.choice(choice, num_points - len(points), replace=False)
209209
choice = np.concatenate((choice, extra_choice), axis=0)
210210
np.random.shuffle(choice)
211-
data_dict['points'] = points[choice]
211+
tfs = dict(point=lambda x: x[choice])
212+
common_utils.apply_data_transform(data_dict, tfs)
212213
return data_dict
213214

214215
def calculate_grid_size(self, data_dict=None, config=None):

pcdet/utils/common_utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ def drop_info_with_name(info, name):
3232
return ret_info
3333

3434

35+
def apply_data_transform(data_dict, transforms):
    """
    Apply per-type transforms in place to the matching entries of data_dict.

    Args:
        data_dict: dict of sample data. Only the keys 'points', 'gt_boxes' and
            'roi_boxes' are ever touched; every other key is left alone.
        transforms: dict mapping a transform type ('point' and/or 'box') to a
            callable that takes the stored value and returns its replacement.
            'point' is applied to 'points'; 'box' is applied to 'gt_boxes' and
            'roi_boxes'. An empty dict is a no-op.
    Returns:
        None. data_dict is updated in place.
    """
    data_keys = {
        'point': ['points'],
        'box': ['gt_boxes', 'roi_boxes']
    }
    unknown_types = set(transforms) - set(data_keys)
    assert not unknown_types, 'unsupported transform types: %s' % sorted(map(str, unknown_types))

    for tf_type, tf in transforms.items():
        for data_key in data_keys[tf_type]:
            # Skip keys that are absent *or* explicitly None: callers already guard
            # with `data_dict.get(key, None) is not None`, so a None placeholder
            # must not be handed to a transform that indexes into an array.
            if data_dict.get(data_key) is not None:
                data_dict[data_key] = tf(data_dict[data_key])
45+
46+
3547
def rotate_points_along_z(points, angle):
3648
"""
3749
Args:

0 commit comments

Comments
 (0)