From 09770385d9802c659053225757f65b03e69eb714 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Fri, 19 Aug 2022 19:44:57 +0800
Subject: [PATCH 1/5] add --pretrained for tools/train.py
---
configs/centerpoint/centerpoint_pillars_016voxel_kitti.yml | 1 -
docs/models/centerpoint/README.md | 2 +-
paddle3d/apis/trainer.py | 5 +++++
tools/train.py | 7 +++++++
4 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/configs/centerpoint/centerpoint_pillars_016voxel_kitti.yml b/configs/centerpoint/centerpoint_pillars_016voxel_kitti.yml
index 5984eae4..42960457 100644
--- a/configs/centerpoint/centerpoint_pillars_016voxel_kitti.yml
+++ b/configs/centerpoint/centerpoint_pillars_016voxel_kitti.yml
@@ -117,7 +117,6 @@ model:
out_channels: [128, 128, 128]
upsample_strides: [0.5, 1, 2]
use_conv_for_no_stride: True
- use_spatial_attn_before_concat: True
bbox_head:
type: CenterHead
in_channels: 384 # sum([128, 128, 128])
diff --git a/docs/models/centerpoint/README.md b/docs/models/centerpoint/README.md
index 348ec423..1acb41ae 100644
--- a/docs/models/centerpoint/README.md
+++ b/docs/models/centerpoint/README.md
@@ -358,7 +358,7 @@ python infer.py --model_file /path/to/centerpoint.pdmodel --params_file /path/to
##
自定义数据集
-请参考文档[自定义数据集格式说明](../../../datasets/custom)准备自定义数据集。
+请参考文档[自定义数据集格式说明](../../../datasets/custom.md)准备自定义数据集。
## Apollo使用教程
基于Paddle3D训练完成的CenterPoint模型可以直接部署到Apollo架构中使用,请参考[教程](https://github.com/ApolloAuto/apollo/blob/master/modules/perception/README_paddle3D_CN.md)
diff --git a/paddle3d/apis/trainer.py b/paddle3d/apis/trainer.py
index 84358932..4e1e6ce3 100644
--- a/paddle3d/apis/trainer.py
+++ b/paddle3d/apis/trainer.py
@@ -22,6 +22,7 @@
from paddle3d.apis.checkpoint import Checkpoint, CheckpointABC
from paddle3d.apis.pipeline import training_step, validation_step
from paddle3d.apis.scheduler import Scheduler, SchedulerABC
+from paddle3d.utils.checkpoint import load_pretrained_model
from paddle3d.utils.logger import logger
from paddle3d.utils.timer import Timer
@@ -101,12 +102,16 @@ def __init__(
train_dataset: Optional[paddle.io.Dataset] = None,
val_dataset: Optional[paddle.io.Dataset] = None,
resume: bool = False,
+ pretrained: str = None,
# TODO: Default parameters should not use mutable objects, there is a risk
checkpoint: Union[dict, CheckpointABC] = dict(),
scheduler: Union[dict, SchedulerABC] = dict(),
dataloader_fn: Union[dict, Callable] = dict()):
self.model = model
+ if pretrained is not None:
+ load_pretrained_model(self.model, pretrained)
+
self.optimizer = optimizer
_dataloader_build_fn = default_dataloader_build_fn(
diff --git a/tools/train.py b/tools/train.py
index 8c3a4fa5..0a9dcb8b 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -84,6 +84,12 @@ def parse_args():
dest='resume',
help='Whether to resume training from checkpoint',
action='store_true')
+ parser.add_argument(
+ '--pretrained',
+ dest='pretrained',
+ help='pretrained weights for model',
+ type=str,
+ default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
@@ -147,6 +153,7 @@ def main(args):
batch_size = dic.pop('batch_size')
dic.update({
'resume': args.resume,
+ 'pretrained': args.pretrained,
'checkpoint': {
'keep_checkpoint_max': args.keep_checkpoint_max,
'save_dir': args.save_dir
From 146bbc34e9ae0b23436c7d21c6914cc55b27ba65 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Mon, 22 Aug 2022 14:21:06 +0800
Subject: [PATCH 2/5] load pretrained for tools/train.py
---
paddle3d/apis/trainer.py | 4 ----
tools/train.py | 11 +++++++----
2 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/paddle3d/apis/trainer.py b/paddle3d/apis/trainer.py
index 4e1e6ce3..70df3f96 100644
--- a/paddle3d/apis/trainer.py
+++ b/paddle3d/apis/trainer.py
@@ -22,7 +22,6 @@
from paddle3d.apis.checkpoint import Checkpoint, CheckpointABC
from paddle3d.apis.pipeline import training_step, validation_step
from paddle3d.apis.scheduler import Scheduler, SchedulerABC
-from paddle3d.utils.checkpoint import load_pretrained_model
from paddle3d.utils.logger import logger
from paddle3d.utils.timer import Timer
@@ -102,15 +101,12 @@ def __init__(
train_dataset: Optional[paddle.io.Dataset] = None,
val_dataset: Optional[paddle.io.Dataset] = None,
resume: bool = False,
- pretrained: str = None,
# TODO: Default parameters should not use mutable objects, there is a risk
checkpoint: Union[dict, CheckpointABC] = dict(),
scheduler: Union[dict, SchedulerABC] = dict(),
dataloader_fn: Union[dict, Callable] = dict()):
self.model = model
- if pretrained is not None:
- load_pretrained_model(self.model, pretrained)
self.optimizer = optimizer
diff --git a/tools/train.py b/tools/train.py
index 0a9dcb8b..282d9345 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -22,6 +22,7 @@
import paddle3d.env as paddle3d_env
from paddle3d.apis.config import Config
from paddle3d.apis.trainer import Trainer
+from paddle3d.utils.checkpoint import load_pretrained_model
from paddle3d.utils.logger import logger
@@ -85,9 +86,9 @@ def parse_args():
help='Whether to resume training from checkpoint',
action='store_true')
parser.add_argument(
- '--pretrained',
- dest='pretrained',
- help='pretrained weights for model',
+ '--model',
+ dest='model',
+ help='pretrained parameters of the model',
type=str,
default=None)
parser.add_argument(
@@ -153,7 +154,6 @@ def main(args):
batch_size = dic.pop('batch_size')
dic.update({
'resume': args.resume,
- 'pretrained': args.pretrained,
'checkpoint': {
'keep_checkpoint_max': args.keep_checkpoint_max,
'save_dir': args.save_dir
@@ -169,6 +169,9 @@ def main(args):
}
})
+ if args.model is not None:
+ load_pretrained_model(cfg.model, args.model)
+
trainer = Trainer(**dic)
trainer.train()
From b1583bfcd8bf8815f1cd1425364a40c8374b34a0 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Mon, 22 Aug 2022 14:21:57 +0800
Subject: [PATCH 3/5] delete needless line
---
paddle3d/apis/trainer.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/paddle3d/apis/trainer.py b/paddle3d/apis/trainer.py
index 70df3f96..84358932 100644
--- a/paddle3d/apis/trainer.py
+++ b/paddle3d/apis/trainer.py
@@ -107,7 +107,6 @@ def __init__(
dataloader_fn: Union[dict, Callable] = dict()):
self.model = model
-
self.optimizer = optimizer
_dataloader_build_fn = default_dataloader_build_fn(
From 3934d8c0dd9eb6a9ed1caf54e437820355d78f77 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Wed, 24 Aug 2022 11:44:07 +0800
Subject: [PATCH 4/5] train dataset would not be initialized during
 evaluation
---
paddle3d/apis/trainer.py | 125 +++++++++----------
show_lidar_pred_on_image.py | 236 ++++++++++++++++++++++++++++++++++++
tools/evaluate.py | 3 +-
3 files changed, 302 insertions(+), 62 deletions(-)
create mode 100644 show_lidar_pred_on_image.py
diff --git a/paddle3d/apis/trainer.py b/paddle3d/apis/trainer.py
index 84358932..bfbb25a6 100644
--- a/paddle3d/apis/trainer.py
+++ b/paddle3d/apis/trainer.py
@@ -113,76 +113,79 @@ def __init__(
**dataloader_fn) if isinstance(dataloader_fn,
dict) else dataloader_fn
- self.train_dataloader = _dataloader_build_fn(train_dataset, self.model)
+ self.train_dataloader = _dataloader_build_fn(
+ train_dataset, self.model) if train_dataset else None
self.eval_dataloader = _dataloader_build_fn(
val_dataset, self.model) if val_dataset else None
self.val_dataset = val_dataset
- self.resume = resume
- vdl_file_name = None
- self.iters_per_epoch = len(self.train_dataloader)
+ if train_dataset:
+ self.resume = resume
+ vdl_file_name = None
+ self.iters_per_epoch = len(self.train_dataloader)
- if iters is None:
- self.epochs = epochs
- self.iters = epochs * self.iters_per_epoch
- self.train_by_epoch = True
- else:
- self.iters = iters
- self.epochs = (iters - 1) // self.iters_per_epoch + 1
- self.train_by_epoch = False
-
- self.cur_iter = 0
- self.cur_epoch = 0
+ if iters is None:
+ self.epochs = epochs
+ self.iters = epochs * self.iters_per_epoch
+ self.train_by_epoch = True
+ else:
+ self.iters = iters
+ self.epochs = (iters - 1) // self.iters_per_epoch + 1
+ self.train_by_epoch = False
- if self.optimizer.__class__.__name__ == 'OneCycleAdam':
- self.optimizer.before_run(max_iters=self.iters)
+ self.cur_iter = 0
+ self.cur_epoch = 0
- self.checkpoint = default_checkpoint_build_fn(
- **checkpoint) if isinstance(checkpoint, dict) else checkpoint
+ if self.optimizer.__class__.__name__ == 'OneCycleAdam':
+ self.optimizer.before_run(max_iters=self.iters)
- if isinstance(scheduler, dict):
- scheduler.setdefault('train_by_epoch', self.train_by_epoch)
- scheduler.setdefault('iters_per_epoch', self.iters_per_epoch)
- self.scheduler = default_scheduler_build_fn(**scheduler)
- else:
- self.scheduler = scheduler
-
- if self.checkpoint is None:
- return
-
- if not self.checkpoint.empty:
- if not resume:
- raise RuntimeError(
- 'The checkpoint {} is not emtpy! Set `resume=True` to continue training or use another dir as checkpoint'
- .format(self.checkpoint.rootdir))
-
- if self.checkpoint.meta.get(
- 'train_by_epoch') != self.train_by_epoch:
- raise RuntimeError(
- 'Unable to resume training since the train_by_epoch is inconsistent with that saved in the checkpoint'
- )
-
- params_dict, opt_dict = self.checkpoint.get()
- self.model.set_dict(params_dict)
- self.optimizer.set_state_dict(opt_dict)
- self.cur_iter = self.checkpoint.meta.get('iters')
- self.cur_epoch = self.checkpoint.meta.get('epochs')
- self.scheduler.step(self.cur_iter)
-
- logger.info(
- 'Resume model from checkpoint {}, current iter set to {}'.
- format(self.checkpoint.rootdir, self.cur_iter))
- vdl_file_name = self.checkpoint.meta['vdl_file_name']
- elif resume:
- logger.warning(
- "Attempt to restore parameters from an empty checkpoint")
+ self.checkpoint = default_checkpoint_build_fn(
+ **checkpoint) if isinstance(checkpoint, dict) else checkpoint
- if env.local_rank == 0:
- self.log_writer = LogWriter(
- logdir=self.checkpoint.rootdir, file_name=vdl_file_name)
- self.checkpoint.record('vdl_file_name',
- os.path.basename(self.log_writer.file_name))
- self.checkpoint.record('train_by_epoch', self.train_by_epoch)
+ if isinstance(scheduler, dict):
+ scheduler.setdefault('train_by_epoch', self.train_by_epoch)
+ scheduler.setdefault('iters_per_epoch', self.iters_per_epoch)
+ self.scheduler = default_scheduler_build_fn(**scheduler)
+ else:
+ self.scheduler = scheduler
+
+ if self.checkpoint is None:
+ return
+
+ if not self.checkpoint.empty:
+ if not resume:
+ raise RuntimeError(
+ 'The checkpoint {} is not emtpy! Set `resume=True` to continue training or use another dir as checkpoint'
+ .format(self.checkpoint.rootdir))
+
+ if self.checkpoint.meta.get(
+ 'train_by_epoch') != self.train_by_epoch:
+ raise RuntimeError(
+ 'Unable to resume training since the train_by_epoch is inconsistent with that saved in the checkpoint'
+ )
+
+ params_dict, opt_dict = self.checkpoint.get()
+ self.model.set_dict(params_dict)
+ self.optimizer.set_state_dict(opt_dict)
+ self.cur_iter = self.checkpoint.meta.get('iters')
+ self.cur_epoch = self.checkpoint.meta.get('epochs')
+ self.scheduler.step(self.cur_iter)
+
+ logger.info(
+ 'Resume model from checkpoint {}, current iter set to {}'.
+ format(self.checkpoint.rootdir, self.cur_iter))
+ vdl_file_name = self.checkpoint.meta['vdl_file_name']
+ elif resume:
+ logger.warning(
+ "Attempt to restore parameters from an empty checkpoint")
+
+ if env.local_rank == 0:
+ self.log_writer = LogWriter(
+ logdir=self.checkpoint.rootdir, file_name=vdl_file_name)
+ self.checkpoint.record(
+ 'vdl_file_name',
+ os.path.basename(self.log_writer.file_name))
+ self.checkpoint.record('train_by_epoch', self.train_by_epoch)
def train(self):
"""
diff --git a/show_lidar_pred_on_image.py b/show_lidar_pred_on_image.py
new file mode 100644
index 00000000..eb3688a3
--- /dev/null
+++ b/show_lidar_pred_on_image.py
@@ -0,0 +1,236 @@
+import argparse
+import os
+import os.path as osp
+
+import cv2
+import numpy as np
+
+from paddle3d.datasets.kitti.kitti_utils import box_lidar_to_camera
+from paddle3d.geometries import BBoxes3D, CoordMode
+from paddle3d.sample import Sample
+
+classmap = {0: 'Car', 1: 'Cyclist', 2: 'Pedestrain'}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--calib_file', dest='calib_file', help='calibration file', type=str)
+ parser.add_argument(
+ '--image_file', dest='image_file', help='image file', type=str)
+ parser.add_argument(
+ '--label_file', dest='label_file', help='label file', type=str)
+ parser.add_argument(
+ '--pred_file',
+ dest='pred_file',
+ help='prediction results file',
+ type=str)
+ parser.add_argument(
+ '--save_dir',
+ dest='save_dir',
+ help='the path to save visualized result',
+ type=str)
+ parser.add_argument(
+ '--draw_threshold',
+ dest='draw_threshold',
+ help=
+ 'prediction whose confidence is lower than threshold would not been shown',
+ type=float)
+ return parser.parse_args()
+
+
+class Calib:
+ def __init__(self, dict_calib):
+ super(Calib, self).__init__()
+ self.P0 = dict_calib['P0'].reshape(3, 4)
+ self.P1 = dict_calib['P1'].reshape(3, 4)
+ self.P2 = dict_calib['P2'].reshape(3, 4)
+ self.P3 = dict_calib['P3'].reshape(3, 4)
+ self.R0_rect = dict_calib['R0_rect'].reshape(3, 3)
+ self.Tr_velo_to_cam = dict_calib['Tr_velo_to_cam'].reshape(3, 4)
+ self.Tr_imu_to_velo = dict_calib['Tr_imu_to_velo'].reshape(3, 4)
+
+
+class Object3d:
+ def __init__(self, content):
+ super(Object3d, self).__init__()
+ lines = content.split()
+ lines = list(filter(lambda x: len(x), lines))
+ self.name, self.truncated, self.occluded, self.alpha = lines[0], float(
+ lines[1]), float(lines[2]), float(lines[3])
+ self.bbox = [lines[4], lines[5], lines[6], lines[7]]
+ self.bbox = np.array([float(x) for x in self.bbox])
+ self.dimensions = [lines[8], lines[9], lines[10]]
+ self.dimensions = np.array([float(x) for x in self.dimensions])
+ self.location = [lines[11], lines[12], lines[13]]
+ self.location = np.array([float(x) for x in self.location])
+ self.rotation_y = float(lines[14])
+ if len(lines) == 16:
+ self.score = float(lines[15])
+
+
+def rot_y(rotation_y):
+ cos = np.cos(rotation_y)
+ sin = np.sin(rotation_y)
+ R = np.array([[cos, 0, sin], [0, 1, 0], [-sin, 0, cos]])
+ return R
+
+
+def parse_gt_info(calib_path, label_path):
+
+ with open(calib_path) as f:
+ lines = f.readlines()
+ lines = list(filter(lambda x: len(x) and x != '\n', lines))
+ dict_calib = {}
+ for line in lines:
+ key, value = line.split(":")
+ dict_calib[key] = np.array([float(x) for x in value.split()])
+ calib = Calib(dict_calib)
+
+ with open(label_path, 'r') as f:
+ lines = f.readlines()
+ lines = list(filter(lambda x: len(x) and x != '\n', lines))
+ obj = [Object3d(x) for x in lines]
+ return calib, obj
+
+
+def predictions_to_kitti_format(pred):
+ num_boxes = pred.bboxes_3d.shape[0]
+ names = np.array([classmap[label] for label in pred.labels])
+ calibs = pred.calibs
+ if pred.bboxes_3d.coordmode != CoordMode.KittiCamera:
+ bboxes_3d = box_lidar_to_camera(pred.bboxes_3d, calibs)
+ else:
+ bboxes_3d = pred.bboxes_3d
+
+ if bboxes_3d.origin != [.5, 1., .5]:
+ bboxes_3d[:, :3] += bboxes_3d[:, 3:6] * (
+ np.array([.5, 1., .5]) - np.array(bboxes_3d.origin))
+ bboxes_3d.origin = [.5, 1., .5]
+
+ loc = bboxes_3d[:, :3]
+ dim = bboxes_3d[:, 3:6]
+
+ contents = []
+ for i in range(num_boxes):
+ # In kitti records, dimensions order is hwl format
+ content = "{} 0 0 0 0 0 0 0 {} {} {} {} {} {} {} {}".format(
+ names[i], dim[i, 2], dim[i, 1], dim[i, 0], loc[i, 0], loc[i, 1],
+ loc[i, 2], bboxes_3d[i, 6], pred.confidences[i])
+ contents.append(content)
+
+ obj = [Object3d(x) for x in contents]
+ return obj
+
+
+def parse_pred_info(pred_path, calib):
+ with open(pred_path, 'r') as f:
+ lines = f.readlines()
+ lines = list(filter(lambda x: len(x) and x != '\n', lines))
+
+ scores = []
+ labels = []
+ boxes_3d = []
+ for res in lines:
+ score = float(res.split("Score: ")[-1].split(" ")[0])
+ label = int(res.split("Label: ")[-1].split(" ")[0])
+ box_3d = res.split("Box(x_c, y_c, z_c, w, l, h, -rot): ")[-1].split(" ")
+ box_3d = [float(b) for b in box_3d]
+ scores.append(score)
+ labels.append(label)
+ boxes_3d.append(box_3d)
+ scores = np.stack(scores)
+ labels = np.stack(labels)
+ boxes_3d = np.stack(boxes_3d)
+ data = Sample(pred_path, 'lidar')
+ data.bboxes_3d = BBoxes3D(boxes_3d)
+ data.bboxes_3d.coordmode = 'Lidar'
+ data.bboxes_3d.origin = [0.5, 0.5, 0.5]
+ data.bboxes_3d.rot_axis = 2
+ data.labels = labels
+ data.confidences = scores
+ data.calibs = calib
+
+ return data
+
+
+def visualize(image_path, calib, obj, title, draw_threshold=None):
+ img = cv2.imread(image_path)
+ for i in range(len(obj)):
+ if obj[i].name in ['Car', 'Pedestrian', 'Cyclist']:
+ if draw_threshold is not None and hasattr(obj[i], 'score'):
+ if obj[i].score < draw_threshold:
+ continue
+ R = rot_y(obj[i].rotation_y)
+ h, w, l = obj[i].dimensions[0], obj[i].dimensions[1], obj[
+ i].dimensions[2]
+ x = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2]
+ y = [0, 0, 0, 0, -h, -h, -h, -h]
+ z = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2]
+ corner_3d = np.vstack([x, y, z])
+ corner_3d = np.dot(R, corner_3d)
+
+ corner_3d[0, :] += obj[i].location[0]
+ corner_3d[1, :] += obj[i].location[1]
+ corner_3d[2, :] += obj[i].location[2]
+
+ corner_3d = np.vstack((corner_3d, np.zeros((1,
+ corner_3d.shape[-1]))))
+ corner_2d = np.dot(calib.P2, corner_3d)
+ corner_2d[0, :] /= corner_2d[2, :]
+ corner_2d[1, :] /= corner_2d[2, :]
+
+ if obj[i].name == 'Car':
+ color = [20, 20, 255]
+ elif obj[i].name == 'Pedestrian':
+ color = [0, 255, 255]
+ else:
+ color = [255, 0, 255]
+
+ thickness = 1
+ for corner_i in range(0, 4):
+ ii, ij = corner_i, (corner_i + 1) % 4
+ corner_2d = corner_2d.astype('int32')
+ cv2.line(img, (corner_2d[0, ii], corner_2d[1, ii]),
+ (corner_2d[0, ij], corner_2d[1, ij]), color, thickness)
+ ii, ij = corner_i + 4, (corner_i + 1) % 4 + 4
+ cv2.line(img, (corner_2d[0, ii], corner_2d[1, ii]),
+ (corner_2d[0, ij], corner_2d[1, ij]), color, thickness)
+ ii, ij = corner_i, corner_i + 4
+ cv2.line(img, (corner_2d[0, ii], corner_2d[1, ii]),
+ (corner_2d[0, ij], corner_2d[1, ij]), color, thickness)
+ box_text = obj[i].name
+ if hasattr(obj[i], 'score'):
+ box_text += ': {:.2}'.format(obj[i].score)
+ cv2.putText(img, box_text,
+ (min(corner_2d[0, :]), min(corner_2d[1, :]) - 2),
+ cv2.FONT_HERSHEY_COMPLEX_SMALL, 0.5, color, 1)
+ cv2.putText(img, title, (int(img.shape[1] / 2), 20),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 100, 0), 2)
+
+ return img
+
+
+def main(args):
+ calib, gt_obj = parse_gt_info(args.calib_file, args.label_file)
+ gt_image = visualize(args.image_file, calib, gt_obj, title='GroundTruth')
+ pred = parse_pred_info(args.pred_file, [
+ calib.P0, calib.P1, calib.P2, calib.P3, calib.R0_rect,
+ calib.Tr_velo_to_cam, calib.Tr_imu_to_velo
+ ])
+ preds = predictions_to_kitti_format(pred)
+ pred_image = visualize(
+ args.image_file,
+ calib,
+ preds,
+ title='Prediction',
+ draw_threshold=args.draw_threshold)
+ show_image = np.vstack([gt_image, pred_image])
+ cv2.imwrite(
+ osp.join(args.save_dir,
+ osp.split(args.image_file)[-1]), show_image)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args)
diff --git a/tools/evaluate.py b/tools/evaluate.py
index 692dfa1b..e5a124f3 100644
--- a/tools/evaluate.py
+++ b/tools/evaluate.py
@@ -80,7 +80,8 @@ def main(args):
'dataloader_fn': {
'batch_size': batch_size,
'num_workers': args.num_workers
- }
+ },
+ 'train_dataset': None
})
if args.model is not None:
From e514330060260b3c603edaca4992d6638c0777f1 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Wed, 24 Aug 2022 11:46:18 +0800
Subject: [PATCH 5/5] delete needless file
---
show_lidar_pred_on_image.py | 236 ------------------------------------
1 file changed, 236 deletions(-)
delete mode 100644 show_lidar_pred_on_image.py
diff --git a/show_lidar_pred_on_image.py b/show_lidar_pred_on_image.py
deleted file mode 100644
index eb3688a3..00000000
--- a/show_lidar_pred_on_image.py
+++ /dev/null
@@ -1,236 +0,0 @@
-import argparse
-import os
-import os.path as osp
-
-import cv2
-import numpy as np
-
-from paddle3d.datasets.kitti.kitti_utils import box_lidar_to_camera
-from paddle3d.geometries import BBoxes3D, CoordMode
-from paddle3d.sample import Sample
-
-classmap = {0: 'Car', 1: 'Cyclist', 2: 'Pedestrain'}
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--calib_file', dest='calib_file', help='calibration file', type=str)
- parser.add_argument(
- '--image_file', dest='image_file', help='image file', type=str)
- parser.add_argument(
- '--label_file', dest='label_file', help='label file', type=str)
- parser.add_argument(
- '--pred_file',
- dest='pred_file',
- help='prediction results file',
- type=str)
- parser.add_argument(
- '--save_dir',
- dest='save_dir',
- help='the path to save visualized result',
- type=str)
- parser.add_argument(
- '--draw_threshold',
- dest='draw_threshold',
- help=
- 'prediction whose confidence is lower than threshold would not been shown',
- type=float)
- return parser.parse_args()
-
-
-class Calib:
- def __init__(self, dict_calib):
- super(Calib, self).__init__()
- self.P0 = dict_calib['P0'].reshape(3, 4)
- self.P1 = dict_calib['P1'].reshape(3, 4)
- self.P2 = dict_calib['P2'].reshape(3, 4)
- self.P3 = dict_calib['P3'].reshape(3, 4)
- self.R0_rect = dict_calib['R0_rect'].reshape(3, 3)
- self.Tr_velo_to_cam = dict_calib['Tr_velo_to_cam'].reshape(3, 4)
- self.Tr_imu_to_velo = dict_calib['Tr_imu_to_velo'].reshape(3, 4)
-
-
-class Object3d:
- def __init__(self, content):
- super(Object3d, self).__init__()
- lines = content.split()
- lines = list(filter(lambda x: len(x), lines))
- self.name, self.truncated, self.occluded, self.alpha = lines[0], float(
- lines[1]), float(lines[2]), float(lines[3])
- self.bbox = [lines[4], lines[5], lines[6], lines[7]]
- self.bbox = np.array([float(x) for x in self.bbox])
- self.dimensions = [lines[8], lines[9], lines[10]]
- self.dimensions = np.array([float(x) for x in self.dimensions])
- self.location = [lines[11], lines[12], lines[13]]
- self.location = np.array([float(x) for x in self.location])
- self.rotation_y = float(lines[14])
- if len(lines) == 16:
- self.score = float(lines[15])
-
-
-def rot_y(rotation_y):
- cos = np.cos(rotation_y)
- sin = np.sin(rotation_y)
- R = np.array([[cos, 0, sin], [0, 1, 0], [-sin, 0, cos]])
- return R
-
-
-def parse_gt_info(calib_path, label_path):
-
- with open(calib_path) as f:
- lines = f.readlines()
- lines = list(filter(lambda x: len(x) and x != '\n', lines))
- dict_calib = {}
- for line in lines:
- key, value = line.split(":")
- dict_calib[key] = np.array([float(x) for x in value.split()])
- calib = Calib(dict_calib)
-
- with open(label_path, 'r') as f:
- lines = f.readlines()
- lines = list(filter(lambda x: len(x) and x != '\n', lines))
- obj = [Object3d(x) for x in lines]
- return calib, obj
-
-
-def predictions_to_kitti_format(pred):
- num_boxes = pred.bboxes_3d.shape[0]
- names = np.array([classmap[label] for label in pred.labels])
- calibs = pred.calibs
- if pred.bboxes_3d.coordmode != CoordMode.KittiCamera:
- bboxes_3d = box_lidar_to_camera(pred.bboxes_3d, calibs)
- else:
- bboxes_3d = pred.bboxes_3d
-
- if bboxes_3d.origin != [.5, 1., .5]:
- bboxes_3d[:, :3] += bboxes_3d[:, 3:6] * (
- np.array([.5, 1., .5]) - np.array(bboxes_3d.origin))
- bboxes_3d.origin = [.5, 1., .5]
-
- loc = bboxes_3d[:, :3]
- dim = bboxes_3d[:, 3:6]
-
- contents = []
- for i in range(num_boxes):
- # In kitti records, dimensions order is hwl format
- content = "{} 0 0 0 0 0 0 0 {} {} {} {} {} {} {} {}".format(
- names[i], dim[i, 2], dim[i, 1], dim[i, 0], loc[i, 0], loc[i, 1],
- loc[i, 2], bboxes_3d[i, 6], pred.confidences[i])
- contents.append(content)
-
- obj = [Object3d(x) for x in contents]
- return obj
-
-
-def parse_pred_info(pred_path, calib):
- with open(pred_path, 'r') as f:
- lines = f.readlines()
- lines = list(filter(lambda x: len(x) and x != '\n', lines))
-
- scores = []
- labels = []
- boxes_3d = []
- for res in lines:
- score = float(res.split("Score: ")[-1].split(" ")[0])
- label = int(res.split("Label: ")[-1].split(" ")[0])
- box_3d = res.split("Box(x_c, y_c, z_c, w, l, h, -rot): ")[-1].split(" ")
- box_3d = [float(b) for b in box_3d]
- scores.append(score)
- labels.append(label)
- boxes_3d.append(box_3d)
- scores = np.stack(scores)
- labels = np.stack(labels)
- boxes_3d = np.stack(boxes_3d)
- data = Sample(pred_path, 'lidar')
- data.bboxes_3d = BBoxes3D(boxes_3d)
- data.bboxes_3d.coordmode = 'Lidar'
- data.bboxes_3d.origin = [0.5, 0.5, 0.5]
- data.bboxes_3d.rot_axis = 2
- data.labels = labels
- data.confidences = scores
- data.calibs = calib
-
- return data
-
-
-def visualize(image_path, calib, obj, title, draw_threshold=None):
- img = cv2.imread(image_path)
- for i in range(len(obj)):
- if obj[i].name in ['Car', 'Pedestrian', 'Cyclist']:
- if draw_threshold is not None and hasattr(obj[i], 'score'):
- if obj[i].score < draw_threshold:
- continue
- R = rot_y(obj[i].rotation_y)
- h, w, l = obj[i].dimensions[0], obj[i].dimensions[1], obj[
- i].dimensions[2]
- x = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2]
- y = [0, 0, 0, 0, -h, -h, -h, -h]
- z = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2]
- corner_3d = np.vstack([x, y, z])
- corner_3d = np.dot(R, corner_3d)
-
- corner_3d[0, :] += obj[i].location[0]
- corner_3d[1, :] += obj[i].location[1]
- corner_3d[2, :] += obj[i].location[2]
-
- corner_3d = np.vstack((corner_3d, np.zeros((1,
- corner_3d.shape[-1]))))
- corner_2d = np.dot(calib.P2, corner_3d)
- corner_2d[0, :] /= corner_2d[2, :]
- corner_2d[1, :] /= corner_2d[2, :]
-
- if obj[i].name == 'Car':
- color = [20, 20, 255]
- elif obj[i].name == 'Pedestrian':
- color = [0, 255, 255]
- else:
- color = [255, 0, 255]
-
- thickness = 1
- for corner_i in range(0, 4):
- ii, ij = corner_i, (corner_i + 1) % 4
- corner_2d = corner_2d.astype('int32')
- cv2.line(img, (corner_2d[0, ii], corner_2d[1, ii]),
- (corner_2d[0, ij], corner_2d[1, ij]), color, thickness)
- ii, ij = corner_i + 4, (corner_i + 1) % 4 + 4
- cv2.line(img, (corner_2d[0, ii], corner_2d[1, ii]),
- (corner_2d[0, ij], corner_2d[1, ij]), color, thickness)
- ii, ij = corner_i, corner_i + 4
- cv2.line(img, (corner_2d[0, ii], corner_2d[1, ii]),
- (corner_2d[0, ij], corner_2d[1, ij]), color, thickness)
- box_text = obj[i].name
- if hasattr(obj[i], 'score'):
- box_text += ': {:.2}'.format(obj[i].score)
- cv2.putText(img, box_text,
- (min(corner_2d[0, :]), min(corner_2d[1, :]) - 2),
- cv2.FONT_HERSHEY_COMPLEX_SMALL, 0.5, color, 1)
- cv2.putText(img, title, (int(img.shape[1] / 2), 20),
- cv2.FONT_HERSHEY_SIMPLEX, 0.75, (255, 100, 0), 2)
-
- return img
-
-
-def main(args):
- calib, gt_obj = parse_gt_info(args.calib_file, args.label_file)
- gt_image = visualize(args.image_file, calib, gt_obj, title='GroundTruth')
- pred = parse_pred_info(args.pred_file, [
- calib.P0, calib.P1, calib.P2, calib.P3, calib.R0_rect,
- calib.Tr_velo_to_cam, calib.Tr_imu_to_velo
- ])
- preds = predictions_to_kitti_format(pred)
- pred_image = visualize(
- args.image_file,
- calib,
- preds,
- title='Prediction',
- draw_threshold=args.draw_threshold)
- show_image = np.vstack([gt_image, pred_image])
- cv2.imwrite(
- osp.join(args.save_dir,
- osp.split(args.image_file)[-1]), show_image)
-
-
-if __name__ == '__main__':
- args = parse_args()
- main(args)