diff --git a/pcdet/datasets/augmentor/augmentor_utils.py b/pcdet/datasets/augmentor/augmentor_utils.py index 3c088e33c..075b44ba4 100644 --- a/pcdet/datasets/augmentor/augmentor_utils.py +++ b/pcdet/datasets/augmentor/augmentor_utils.py @@ -5,109 +5,104 @@ from ...utils import box_utils -def random_flip_along_x(gt_boxes, points, return_flip=False, enable=None): +def random_flip_along(dim, return_flip=False, enable=None): """ Args: - gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] - points: (M, 3 + C) + gt_boxes: (*, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] + points: (*, 3 + C) Returns: """ + assert dim in [0, 1] # corresponds to x-, y-axis respectively + other_dim = 1 - dim if enable is None: enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5]) + if enable: - gt_boxes[:, 1] = -gt_boxes[:, 1] - gt_boxes[:, 6] = -gt_boxes[:, 6] - points[:, 1] = -points[:, 1] - - if gt_boxes.shape[1] > 7: - gt_boxes[:, 8] = -gt_boxes[:, 8] - if return_flip: - return gt_boxes, points, enable - return gt_boxes, points + def flip_pointlike(points): + points[..., other_dim] = -points[..., other_dim] + return points + def flip_boxlike(boxes): + boxes[..., other_dim] = -boxes[..., other_dim] + boxes[..., 6] = -(boxes[..., 6] + np.pi * dim) -def random_flip_along_y(gt_boxes, points, return_flip=False, enable=None): - """ - Args: - gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] - points: (M, 3 + C) - Returns: - """ - if enable is None: - enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5]) - if enable: - gt_boxes[:, 0] = -gt_boxes[:, 0] - gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi) - points[:, 0] = -points[:, 0] + if boxes.shape[-1] > 7: + boxes[..., 7 + other_dim] = -boxes[..., 7 + other_dim] + + return boxes + + tfs = dict(point=flip_pointlike, box=flip_boxlike) + else: + tfs = dict() - if gt_boxes.shape[1] > 7: - gt_boxes[:, 7] = -gt_boxes[:, 7] if return_flip: - return gt_boxes, points, enable - return gt_boxes, points + return tfs, enable + return tfs -def global_rotation(gt_boxes, points, rot_range, return_rot=False, noise_rotation=None): +def global_rotation(rot_range, return_rot=False, noise_rotation=None): """ Args: - gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] + gt_boxes: (*, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] points: (M, 3 + C), rot_range: [min, max] Returns: """ if noise_rotation is None: noise_rotation = np.random.uniform(rot_range[0], rot_range[1]) - points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0] - gt_boxes[:, 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3], np.array([noise_rotation]))[0] - gt_boxes[:, 6] += noise_rotation - if gt_boxes.shape[1] > 7: - gt_boxes[:, 7:9] = common_utils.rotate_points_along_z( - np.hstack((gt_boxes[:, 7:9], np.zeros((gt_boxes.shape[0], 1))))[np.newaxis, :, :], - np.array([noise_rotation]) - )[0][:, 0:2] + + def rotate_pointlike(points): + points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0] + return points + + def rotate_boxlike(boxes): + boxes[..., 0:3] = common_utils.rotate_points_along_z(boxes[np.newaxis, ..., 0:3], np.array([noise_rotation]))[0] + boxes[..., 6] += noise_rotation + if boxes.shape[-1] > 7: + boxes[..., 7:9] = common_utils.rotate_points_along_z( + np.concatenate((boxes[..., 7:9], np.zeros((*boxes.shape[:-1], 1))), axis=-1)[np.newaxis, ...], + np.array([noise_rotation]) + )[0, ..., 0:2] + return boxes + + tfs = dict(point=rotate_pointlike, box=rotate_boxlike) if return_rot: - return gt_boxes, points, noise_rotation - return gt_boxes, points + return tfs, noise_rotation + return tfs -def global_scaling(gt_boxes, points, scale_range, return_scale=False): +def global_scaling(scale_range, return_scale=False): """ Args: - gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading] + gt_boxes: (*, 7), [x, y, z, dx, dy, dz, heading, [vx], [vy]] points: (M, 3 + C), scale_range: [min, max] Returns: """ if scale_range[1] - scale_range[0] < 1e-3: - return gt_boxes, points + noise_scale = sum(scale_range) / len(scale_range) + assert noise_scale == 1.0, (noise_scale, scale_range) noise_scale = np.random.uniform(scale_range[0], scale_range[1]) - points[:, :3] *= noise_scale - gt_boxes[:, :6] *= noise_scale - if gt_boxes.shape[1] > 7: - gt_boxes[:, 7:] *= noise_scale - - if return_scale: - return gt_boxes, points, noise_scale - return gt_boxes, points -def global_scaling_with_roi_boxes(gt_boxes, roi_boxes, points, scale_range, return_scale=False): - """ - Args: - gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading] - points: (M, 3 + C), - scale_range: [min, max] - Returns: - """ - if scale_range[1] - scale_range[0] < 1e-3: - return gt_boxes, points - noise_scale = np.random.uniform(scale_range[0], scale_range[1]) - points[:, :3] *= noise_scale - gt_boxes[:, :6] *= noise_scale - roi_boxes[:,:, [0,1,2,3,4,5,7,8]] *= noise_scale + def scale_pointlike(points): + points[:, :3] *= noise_scale + return points + + def scale_boxlike(boxes): + boxes[..., :6] *= noise_scale + if boxes.shape[-1] > 7: + boxes[..., 7:9] *= noise_scale + return boxes + + if noise_scale != 1.0: + tfs = dict(point=scale_pointlike, box=scale_boxlike) + else: + tfs = {} + if return_scale: - return gt_boxes,roi_boxes, points, noise_scale - return gt_boxes, roi_boxes, points + return tfs, noise_scale + return tfs def random_image_flip_horizontal(image, depth_map, gt_boxes, calib): diff --git a/pcdet/datasets/augmentor/data_augmentor.py b/pcdet/datasets/augmentor/data_augmentor.py index 56acebc81..2d5f64f18 100644 --- a/pcdet/datasets/augmentor/data_augmentor.py +++ b/pcdet/datasets/augmentor/data_augmentor.py @@ -56,22 +56,15 @@ def __setstate__(self, d): def random_world_flip(self, data_dict=None, config=None): if data_dict is None: return partial(self.random_world_flip, config=config) - gt_boxes, points = data_dict['gt_boxes'], data_dict['points'] for cur_axis in config['ALONG_AXIS_LIST']: assert cur_axis in ['x', 'y'] - gt_boxes, points, enable = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)( - gt_boxes, points, return_flip=True + cur_dim = ['x', 'y'].index(cur_axis) + tfs, enable = augmentor_utils.random_flip_along( + cur_dim, return_flip=True ) + common_utils.apply_data_transform(data_dict, tfs) data_dict['flip_%s'%cur_axis] = enable - if 'roi_boxes' in data_dict.keys(): - num_frame, num_rois,dim = data_dict['roi_boxes'].shape - roi_boxes, _, _ = getattr(augmentor_utils, 'random_flip_along_%s' % cur_axis)( - data_dict['roi_boxes'].reshape(-1,dim), np.zeros([1,3]), return_flip=True, enable=enable - ) - data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim) - data_dict['gt_boxes'] = gt_boxes - data_dict['points'] = points return data_dict def random_world_rotation(self, data_dict=None, config=None): @@ -80,36 +73,18 @@ def random_world_rotation(self, data_dict=None, config=None): rot_range = config['WORLD_ROT_ANGLE'] if not isinstance(rot_range, list): rot_range = [-rot_range, rot_range] - gt_boxes, points, noise_rot = augmentor_utils.global_rotation( - data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True + tfs, noise_rot = augmentor_utils.global_rotation( + rot_range=rot_range, return_rot=True ) - if 'roi_boxes' in data_dict.keys(): - num_frame, num_rois,dim = data_dict['roi_boxes'].shape - roi_boxes, _, _ = augmentor_utils.global_rotation( - data_dict['roi_boxes'].reshape(-1, dim), np.zeros([1, 3]), rot_range=rot_range, return_rot=True, noise_rotation=noise_rot) - data_dict['roi_boxes'] = roi_boxes.reshape(num_frame, num_rois,dim) - - data_dict['gt_boxes'] = gt_boxes - data_dict['points'] = points + common_utils.apply_data_transform(data_dict, tfs) data_dict['noise_rot'] = noise_rot return data_dict def random_world_scaling(self, data_dict=None, config=None): if data_dict is None: return partial(self.random_world_scaling, config=config) - - if 'roi_boxes' in data_dict.keys(): - gt_boxes, roi_boxes, points, noise_scale = augmentor_utils.global_scaling_with_roi_boxes( - data_dict['gt_boxes'], data_dict['roi_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True - ) - data_dict['roi_boxes'] = roi_boxes - else: - gt_boxes, points, noise_scale = augmentor_utils.global_scaling( - data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True - ) - - data_dict['gt_boxes'] = gt_boxes - data_dict['points'] = points + tfs, noise_scale = augmentor_utils.global_scaling(scale_range=config['WORLD_SCALE_RANGE'], return_scale=True) + common_utils.apply_data_transform(data_dict, tfs) data_dict['noise_scale'] = noise_scale return data_dict @@ -143,15 +118,12 @@ def random_world_translation(self, data_dict=None, config=None): np.random.normal(0, noise_translate_std[2], 1), ], dtype=np.float32).T - gt_boxes, points = data_dict['gt_boxes'], data_dict['points'] - points[:, :3] += noise_translate - gt_boxes[:, :3] += noise_translate - - if 'roi_boxes' in data_dict.keys(): - data_dict['roi_boxes'][:, :3] += noise_translate + def translate_locationlike(locations): + locations[..., :3] += noise_translate + return locations - data_dict['gt_boxes'] = gt_boxes - data_dict['points'] = points + tfs = dict(point=translate_locationlike, box=translate_locationlike) + common_utils.apply_data_transform(data_dict, tfs) data_dict['noise_translate'] = noise_translate return data_dict diff --git a/pcdet/datasets/processor/data_processor.py b/pcdet/datasets/processor/data_processor.py index 4f72ab532..0e5660048 100644 --- a/pcdet/datasets/processor/data_processor.py +++ b/pcdet/datasets/processor/data_processor.py @@ -82,7 +82,8 @@ def mask_points_and_boxes_outside_range(self, data_dict=None, config=None): if data_dict.get('points', None) is not None: mask = common_utils.mask_points_by_range(data_dict['points'], self.point_cloud_range) - data_dict['points'] = data_dict['points'][mask] + tfs = dict(point=lambda x: x[mask]) + common_utils.apply_data_transform(data_dict, tfs) if data_dict.get('gt_boxes', None) is not None and config.REMOVE_OUTSIDE_BOXES and self.training: mask = box_utils.mask_boxes_outside_range_numpy( @@ -97,10 +98,9 @@ def shuffle_points(self, data_dict=None, config=None): return partial(self.shuffle_points, config=config) if config.SHUFFLE_ENABLED[self.mode]: - points = data_dict['points'] - shuffle_idx = np.random.permutation(points.shape[0]) - points = points[shuffle_idx] - data_dict['points'] = points + shuffle_idx = np.random.permutation(data_dict['points'].shape[0]) + tfs = dict(point=lambda x: x[shuffle_idx]) + common_utils.apply_data_transform(data_dict, tfs) return data_dict @@ -208,7 +208,8 @@ def sample_points(self, data_dict=None, config=None): extra_choice = np.random.choice(choice, num_points - len(points), replace=False) choice = np.concatenate((choice, extra_choice), axis=0) np.random.shuffle(choice) - data_dict['points'] = points[choice] + tfs = dict(point=lambda x: x[choice]) + common_utils.apply_data_transform(data_dict, tfs) return data_dict def calculate_grid_size(self, data_dict=None, config=None): diff --git a/pcdet/utils/common_utils.py b/pcdet/utils/common_utils.py index af70728db..cb573df4c 100644 --- a/pcdet/utils/common_utils.py +++ b/pcdet/utils/common_utils.py @@ -32,10 +32,22 @@ def drop_info_with_name(info, name): return ret_info +def apply_data_transform(data_dict, transforms): + assert set(transforms.keys()).issubset({'point', 'box'}) + data_keys = { + 'point': ['points'], + 'box': ['gt_boxes', 'roi_boxes'] + } + for tf_type, tf in transforms.items(): + for data_key in data_keys[tf_type]: + if data_key in data_dict: + data_dict[data_key] = tf(data_dict[data_key]) + + def rotate_points_along_z(points, angle): """ Args: - points: (B, N, 3 + C) + points: (B, *, 3 + C) angle: (B), angle along z-axis, angle increases x ==> y Returns: @@ -43,6 +55,10 @@ def rotate_points_along_z(points, angle): points, is_numpy = check_numpy_to_torch(points) angle, _ = check_numpy_to_torch(angle) + orig_shape = points.shape + if len(orig_shape) > 3: + points = points.view(orig_shape[0], -1, orig_shape[-1]) + cosa = torch.cos(angle) sina = torch.sin(angle) zeros = angle.new_zeros(points.shape[0]) @@ -54,6 +70,7 @@ def rotate_points_along_z(points, angle): ), dim=1).view(-1, 3, 3).float() points_rot = torch.matmul(points[:, :, 0:3], rot_matrix) points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1) + points_rot = points_rot.view(orig_shape) return points_rot.numpy() if is_numpy else points_rot