Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 100 additions & 105 deletions src/mjlab/sensor/raycast_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,25 +560,6 @@ def initialize(
if self._use_cuda_graph:
self._create_graph()

def _create_graph(self) -> None:
"""Capture CUDA graph for raycast operation."""
assert self._wp_device is not None and self._wp_device.is_cuda
with wp.ScopedDevice(self._wp_device):
with wp.ScopedCapture() as capture:
rays(
m=self._model.struct, # type: ignore[attr-defined]
d=self._data.struct, # type: ignore[attr-defined]
pnt=self._ray_pnt,
vec=self._ray_vec,
geomgroup=self._geomgroup,
flg_static=True,
bodyexclude=self._ray_bodyexclude,
dist=self._ray_dist,
geomid=self._ray_geomid,
normal=self._ray_normal,
)
self._raycast_graph = capture.graph

@property
def data(self) -> RayCastData:
self._perform_raycast()
Expand All @@ -597,6 +578,105 @@ def data(self) -> RayCastData:
def num_rays(self) -> int:
return self._num_rays

def debug_vis(self, visualizer: DebugVisualizer) -> None:
if not self.cfg.debug_vis:
return
assert self._data is not None
assert self._local_offsets is not None
assert self._local_directions is not None

env_idx = visualizer.env_idx
data = self.data

if self._frame_type == "body":
frame_pos = self._data.xpos[env_idx, self._frame_body_id].cpu().numpy()
frame_mat_tensor = self._data.xmat[env_idx, self._frame_body_id].view(3, 3)
elif self._frame_type == "site":
frame_pos = self._data.site_xpos[env_idx, self._frame_site_id].cpu().numpy()
frame_mat_tensor = self._data.site_xmat[env_idx, self._frame_site_id].view(3, 3)
else: # geom
frame_pos = self._data.geom_xpos[env_idx, self._frame_geom_id].cpu().numpy()
frame_mat_tensor = self._data.geom_xmat[env_idx, self._frame_geom_id].view(3, 3)

# Apply ray alignment for visualization.
rot_mat_tensor = self._compute_alignment_rotation(frame_mat_tensor.unsqueeze(0))[0]
rot_mat = rot_mat_tensor.cpu().numpy()

local_offsets_np = self._local_offsets.cpu().numpy()
local_dirs_np = self._local_directions.cpu().numpy()
hit_positions_np = data.hit_pos_w[env_idx].cpu().numpy()
distances_np = data.distances[env_idx].cpu().numpy()
normals_np = data.normals_w[env_idx].cpu().numpy()

meansize = visualizer.meansize
ray_width = 0.1 * meansize
sphere_radius = self.cfg.viz.hit_sphere_radius * meansize
normal_length = self.cfg.viz.normal_length * meansize
normal_width = 0.1 * meansize

for i in range(self._num_rays):
origin = frame_pos + rot_mat @ local_offsets_np[i]
hit = distances_np[i] >= 0

if hit:
end = hit_positions_np[i]
color = self.cfg.viz.hit_color
else:
direction = rot_mat @ local_dirs_np[i]
end = origin + direction * min(0.5, self.cfg.max_distance * 0.05)
color = self.cfg.viz.miss_color

if self.cfg.viz.show_rays:
visualizer.add_arrow(
start=origin,
end=end,
color=color,
width=ray_width,
label=f"{self.cfg.name}_ray_{i}",
)

if hit:
visualizer.add_sphere(
center=end,
radius=sphere_radius,
color=self.cfg.viz.hit_sphere_color,
label=f"{self.cfg.name}_hit_{i}",
)
if self.cfg.viz.show_normals:
normal_end = end + normals_np[i] * normal_length
visualizer.add_arrow(
start=end,
end=normal_end,
color=self.cfg.viz.normal_color,
width=normal_width,
label=f"{self.cfg.name}_normal_{i}",
)

# Private methods.

def _create_graph(self) -> None:
"""Capture CUDA graph for raycast operation."""
assert self._wp_device is not None and self._wp_device.is_cuda
with wp.ScopedDevice(self._wp_device):
with wp.ScopedCapture() as capture:
self._raycast_direct()
self._raycast_graph = capture.graph

def _raycast_direct(self) -> None:
"""Execute raycast kernel directly."""
rays(
m=self._model.struct, # type: ignore[attr-defined]
d=self._data.struct, # type: ignore[attr-defined]
pnt=self._ray_pnt,
vec=self._ray_vec,
geomgroup=self._geomgroup,
flg_static=True,
bodyexclude=self._ray_bodyexclude,
dist=self._ray_dist,
geomid=self._ray_geomid,
normal=self._ray_normal,
)

def _perform_raycast(self) -> None:
assert self._data is not None and self._model is not None
assert self._local_offsets is not None and self._local_directions is not None
Expand Down Expand Up @@ -632,18 +712,7 @@ def _perform_raycast(self) -> None:
with wp.ScopedDevice(self._wp_device):
wp.capture_launch(self._raycast_graph)
else:
rays(
m=self._model.struct, # type: ignore[attr-defined]
d=self._data.struct, # type: ignore[attr-defined]
pnt=self._ray_pnt,
vec=self._ray_vec,
geomgroup=self._geomgroup,
flg_static=True,
bodyexclude=self._ray_bodyexclude,
dist=self._ray_dist,
geomid=self._ray_geomid,
normal=self._ray_normal,
)
self._raycast_direct()

self._distances = wp.to_torch(self._ray_dist)
self._normals_w = wp.to_torch(self._ray_normal).view(num_envs, self._num_rays, 3)
Expand Down Expand Up @@ -728,77 +797,3 @@ def _extract_yaw_rotation(self, rot_mat: torch.Tensor) -> torch.Tensor:
yaw_mat[:, 1, 1] = x_proj[:, 0]
yaw_mat[:, 2, 2] = 1
return yaw_mat

def debug_vis(self, visualizer: DebugVisualizer) -> None:
if not self.cfg.debug_vis:
return
assert self._data is not None
assert self._local_offsets is not None
assert self._local_directions is not None

env_idx = visualizer.env_idx
data = self.data

if self._frame_type == "body":
frame_pos = self._data.xpos[env_idx, self._frame_body_id].cpu().numpy()
frame_mat_tensor = self._data.xmat[env_idx, self._frame_body_id].view(3, 3)
elif self._frame_type == "site":
frame_pos = self._data.site_xpos[env_idx, self._frame_site_id].cpu().numpy()
frame_mat_tensor = self._data.site_xmat[env_idx, self._frame_site_id].view(3, 3)
else: # geom
frame_pos = self._data.geom_xpos[env_idx, self._frame_geom_id].cpu().numpy()
frame_mat_tensor = self._data.geom_xmat[env_idx, self._frame_geom_id].view(3, 3)

# Apply ray alignment for visualization.
rot_mat_tensor = self._compute_alignment_rotation(frame_mat_tensor.unsqueeze(0))[0]
rot_mat = rot_mat_tensor.cpu().numpy()

local_offsets_np = self._local_offsets.cpu().numpy()
local_dirs_np = self._local_directions.cpu().numpy()
hit_positions_np = data.hit_pos_w[env_idx].cpu().numpy()
distances_np = data.distances[env_idx].cpu().numpy()
normals_np = data.normals_w[env_idx].cpu().numpy()

meansize = visualizer.meansize
ray_width = 0.1 * meansize
sphere_radius = self.cfg.viz.hit_sphere_radius * meansize
normal_length = self.cfg.viz.normal_length * meansize
normal_width = 0.1 * meansize

for i in range(self._num_rays):
origin = frame_pos + rot_mat @ local_offsets_np[i]
hit = distances_np[i] >= 0

if hit:
end = hit_positions_np[i]
color = self.cfg.viz.hit_color
else:
direction = rot_mat @ local_dirs_np[i]
end = origin + direction * min(0.5, self.cfg.max_distance * 0.05)
color = self.cfg.viz.miss_color

if self.cfg.viz.show_rays:
visualizer.add_arrow(
start=origin,
end=end,
color=color,
width=ray_width,
label=f"{self.cfg.name}_ray_{i}",
)

if hit:
visualizer.add_sphere(
center=end,
radius=sphere_radius,
color=self.cfg.viz.hit_sphere_color,
label=f"{self.cfg.name}_hit_{i}",
)
if self.cfg.viz.show_normals:
normal_end = end + normals_np[i] * normal_length
visualizer.add_arrow(
start=end,
end=normal_end,
color=self.cfg.viz.normal_color,
width=normal_width,
label=f"{self.cfg.name}_normal_{i}",
)