Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions tools/demo.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
import argparse
import glob
from pathlib import Path

try:
import open3d
from visual_utils import open3d_vis_utils as V
OPEN3D_FLAG = True
except:
import mayavi.mlab as mlab
from visual_utils import visualize_utils as V
OPEN3D_FLAG = False

import time
import open3d

OPEN3D_FLAG = True
import numpy as np
import torch

from pcdet.config import cfg, cfg_from_yaml_file
from pcdet.datasets import DatasetTemplate
from pcdet.models import build_network, load_data_to_gpu
from pcdet.utils import common_utils
from visual_utils.open3d_vis_utils import draw_scenes


class DemoDataset(DatasetTemplate):
Expand Down Expand Up @@ -90,23 +85,38 @@ def main():
model.load_params_from_file(filename=args.ckpt, logger=logger, to_cpu=True)
model.cuda()
model.eval()

# Initialize visualizer once and reuse it
vis = open3d.visualization.Visualizer()
vis.create_window()
vis.get_render_option().point_size = 1.0
vis.get_render_option().background_color = np.zeros(3)

with torch.no_grad():
for idx, data_dict in enumerate(demo_dataset):
logger.info(f'Visualized sample index: \t{idx + 1}')
logger.info(f'Visualizing sample index: \t{idx + 1}')
data_dict = demo_dataset.collate_batch([data_dict])
load_data_to_gpu(data_dict)
pred_dicts, _ = model.forward(data_dict)

V.draw_scenes(
points=data_dict['points'][:, 1:], ref_boxes=pred_dicts[0]['pred_boxes'],
ref_scores=pred_dicts[0]['pred_scores'], ref_labels=pred_dicts[0]['pred_labels']
# Update visualization
vis = draw_scenes(
vis=vis,
points=data_dict['points'][:, 1:],
ref_boxes=pred_dicts[0]['pred_boxes'],
ref_scores=pred_dicts[0]['pred_scores'],
ref_labels=pred_dicts[0]['pred_labels']
)

if not OPEN3D_FLAG:
mlab.show(stop=True)
# Add small delay (e.g., 50ms) between frames
time.sleep(0.91)

logger.info('Demo done.')
# Check if window is closed
if not vis.poll_events():
break

vis.destroy_window()
logger.info('Demo done.')

if __name__ == '__main__':
main()
main()
19 changes: 9 additions & 10 deletions tools/visual_utils/open3d_vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,16 @@ def get_coor_colors(obj_labels):
return label_rgba


def draw_scenes(points, gt_boxes=None, ref_boxes=None, ref_labels=None, ref_scores=None, point_colors=None, draw_origin=True):
def draw_scenes(vis, points, gt_boxes=None, ref_boxes=None, ref_labels=None, ref_scores=None, point_colors=None, draw_origin=True):
if isinstance(points, torch.Tensor):
points = points.cpu().numpy()
if isinstance(gt_boxes, torch.Tensor):
gt_boxes = gt_boxes.cpu().numpy()
if isinstance(ref_boxes, torch.Tensor):
ref_boxes = ref_boxes.cpu().numpy()

vis = open3d.visualization.Visualizer()
vis.create_window()

vis.get_render_option().point_size = 1.0
vis.get_render_option().background_color = np.zeros(3)
# Clear all geometries
vis.clear_geometries()

# draw origin
if draw_origin:
Expand All @@ -69,9 +66,11 @@ def draw_scenes(points, gt_boxes=None, ref_boxes=None, ref_labels=None, ref_scor
if ref_boxes is not None:
vis = draw_box(vis, ref_boxes, (0, 1, 0), ref_labels, ref_scores)

vis.run()
vis.destroy_window()

# Update the renderer
vis.update_renderer()
vis.poll_events()

return vis

def translate_boxes_to_open3d_instance(gt_boxes):
"""
Expand Down Expand Up @@ -113,4 +112,4 @@ def draw_box(vis, gt_boxes, color=(0, 1, 0), ref_labels=None, score=None):
# if score is not None:
# corners = box3d.get_box_points()
# vis.add_3d_label(corners[5], '%.2f' % score[i])
return vis
return vis