diff --git a/src/mjlab/sensor/raycast_sensor.py b/src/mjlab/sensor/raycast_sensor.py index 8b6a835a8..58b609575 100644 --- a/src/mjlab/sensor/raycast_sensor.py +++ b/src/mjlab/sensor/raycast_sensor.py @@ -560,25 +560,6 @@ def initialize( if self._use_cuda_graph: self._create_graph() - def _create_graph(self) -> None: - """Capture CUDA graph for raycast operation.""" - assert self._wp_device is not None and self._wp_device.is_cuda - with wp.ScopedDevice(self._wp_device): - with wp.ScopedCapture() as capture: - rays( - m=self._model.struct, # type: ignore[attr-defined] - d=self._data.struct, # type: ignore[attr-defined] - pnt=self._ray_pnt, - vec=self._ray_vec, - geomgroup=self._geomgroup, - flg_static=True, - bodyexclude=self._ray_bodyexclude, - dist=self._ray_dist, - geomid=self._ray_geomid, - normal=self._ray_normal, - ) - self._raycast_graph = capture.graph - @property def data(self) -> RayCastData: self._perform_raycast() @@ -597,6 +578,105 @@ def data(self) -> RayCastData: def num_rays(self) -> int: return self._num_rays + def debug_vis(self, visualizer: DebugVisualizer) -> None: + if not self.cfg.debug_vis: + return + assert self._data is not None + assert self._local_offsets is not None + assert self._local_directions is not None + + env_idx = visualizer.env_idx + data = self.data + + if self._frame_type == "body": + frame_pos = self._data.xpos[env_idx, self._frame_body_id].cpu().numpy() + frame_mat_tensor = self._data.xmat[env_idx, self._frame_body_id].view(3, 3) + elif self._frame_type == "site": + frame_pos = self._data.site_xpos[env_idx, self._frame_site_id].cpu().numpy() + frame_mat_tensor = self._data.site_xmat[env_idx, self._frame_site_id].view(3, 3) + else: # geom + frame_pos = self._data.geom_xpos[env_idx, self._frame_geom_id].cpu().numpy() + frame_mat_tensor = self._data.geom_xmat[env_idx, self._frame_geom_id].view(3, 3) + + # Apply ray alignment for visualization. + rot_mat_tensor = self._compute_alignment_rotation(frame_mat_tensor.unsqueeze(0))[0] + rot_mat = rot_mat_tensor.cpu().numpy() + + local_offsets_np = self._local_offsets.cpu().numpy() + local_dirs_np = self._local_directions.cpu().numpy() + hit_positions_np = data.hit_pos_w[env_idx].cpu().numpy() + distances_np = data.distances[env_idx].cpu().numpy() + normals_np = data.normals_w[env_idx].cpu().numpy() + + meansize = visualizer.meansize + ray_width = 0.1 * meansize + sphere_radius = self.cfg.viz.hit_sphere_radius * meansize + normal_length = self.cfg.viz.normal_length * meansize + normal_width = 0.1 * meansize + + for i in range(self._num_rays): + origin = frame_pos + rot_mat @ local_offsets_np[i] + hit = distances_np[i] >= 0 + + if hit: + end = hit_positions_np[i] + color = self.cfg.viz.hit_color + else: + direction = rot_mat @ local_dirs_np[i] + end = origin + direction * min(0.5, self.cfg.max_distance * 0.05) + color = self.cfg.viz.miss_color + + if self.cfg.viz.show_rays: + visualizer.add_arrow( + start=origin, + end=end, + color=color, + width=ray_width, + label=f"{self.cfg.name}_ray_{i}", + ) + + if hit: + visualizer.add_sphere( + center=end, + radius=sphere_radius, + color=self.cfg.viz.hit_sphere_color, + label=f"{self.cfg.name}_hit_{i}", + ) + if self.cfg.viz.show_normals: + normal_end = end + normals_np[i] * normal_length + visualizer.add_arrow( + start=end, + end=normal_end, + color=self.cfg.viz.normal_color, + width=normal_width, + label=f"{self.cfg.name}_normal_{i}", + ) + + # Private methods. + + def _create_graph(self) -> None: + """Capture CUDA graph for raycast operation.""" + assert self._wp_device is not None and self._wp_device.is_cuda + with wp.ScopedDevice(self._wp_device): + with wp.ScopedCapture() as capture: + self._raycast_direct() + self._raycast_graph = capture.graph + + def _raycast_direct(self) -> None: + """Execute raycast kernel directly.""" + rays( + m=self._model.struct, # type: ignore[attr-defined] + d=self._data.struct, # type: ignore[attr-defined] + pnt=self._ray_pnt, + vec=self._ray_vec, + geomgroup=self._geomgroup, + flg_static=True, + bodyexclude=self._ray_bodyexclude, + dist=self._ray_dist, + geomid=self._ray_geomid, + normal=self._ray_normal, + ) + def _perform_raycast(self) -> None: assert self._data is not None and self._model is not None assert self._local_offsets is not None and self._local_directions is not None @@ -632,18 +712,7 @@ def _perform_raycast(self) -> None: with wp.ScopedDevice(self._wp_device): wp.capture_launch(self._raycast_graph) else: - rays( - m=self._model.struct, # type: ignore[attr-defined] - d=self._data.struct, # type: ignore[attr-defined] - pnt=self._ray_pnt, - vec=self._ray_vec, - geomgroup=self._geomgroup, - flg_static=True, - bodyexclude=self._ray_bodyexclude, - dist=self._ray_dist, - geomid=self._ray_geomid, - normal=self._ray_normal, - ) + self._raycast_direct() self._distances = wp.to_torch(self._ray_dist) self._normals_w = wp.to_torch(self._ray_normal).view(num_envs, self._num_rays, 3) @@ -728,77 +797,3 @@ def _extract_yaw_rotation(self, rot_mat: torch.Tensor) -> torch.Tensor: yaw_mat[:, 1, 1] = x_proj[:, 0] yaw_mat[:, 2, 2] = 1 return yaw_mat - - def debug_vis(self, visualizer: DebugVisualizer) -> None: - if not self.cfg.debug_vis: - return - assert self._data is not None - assert self._local_offsets is not None - assert self._local_directions is not None - - env_idx = visualizer.env_idx - data = self.data - - if self._frame_type == "body": - frame_pos = self._data.xpos[env_idx, self._frame_body_id].cpu().numpy() - frame_mat_tensor = self._data.xmat[env_idx, self._frame_body_id].view(3, 3) - elif self._frame_type == "site": - frame_pos = self._data.site_xpos[env_idx, self._frame_site_id].cpu().numpy() - frame_mat_tensor = self._data.site_xmat[env_idx, self._frame_site_id].view(3, 3) - else: # geom - frame_pos = self._data.geom_xpos[env_idx, self._frame_geom_id].cpu().numpy() - frame_mat_tensor = self._data.geom_xmat[env_idx, self._frame_geom_id].view(3, 3) - - # Apply ray alignment for visualization. - rot_mat_tensor = self._compute_alignment_rotation(frame_mat_tensor.unsqueeze(0))[0] - rot_mat = rot_mat_tensor.cpu().numpy() - - local_offsets_np = self._local_offsets.cpu().numpy() - local_dirs_np = self._local_directions.cpu().numpy() - hit_positions_np = data.hit_pos_w[env_idx].cpu().numpy() - distances_np = data.distances[env_idx].cpu().numpy() - normals_np = data.normals_w[env_idx].cpu().numpy() - - meansize = visualizer.meansize - ray_width = 0.1 * meansize - sphere_radius = self.cfg.viz.hit_sphere_radius * meansize - normal_length = self.cfg.viz.normal_length * meansize - normal_width = 0.1 * meansize - - for i in range(self._num_rays): - origin = frame_pos + rot_mat @ local_offsets_np[i] - hit = distances_np[i] >= 0 - - if hit: - end = hit_positions_np[i] - color = self.cfg.viz.hit_color - else: - direction = rot_mat @ local_dirs_np[i] - end = origin + direction * min(0.5, self.cfg.max_distance * 0.05) - color = self.cfg.viz.miss_color - - if self.cfg.viz.show_rays: - visualizer.add_arrow( - start=origin, - end=end, - color=color, - width=ray_width, - label=f"{self.cfg.name}_ray_{i}", - ) - - if hit: - visualizer.add_sphere( - center=end, - radius=sphere_radius, - color=self.cfg.viz.hit_sphere_color, - label=f"{self.cfg.name}_hit_{i}", - ) - if self.cfg.viz.show_normals: - normal_end = end + normals_np[i] * normal_length - visualizer.add_arrow( - start=end, - end=normal_end, - color=self.cfg.viz.normal_color, - width=normal_width, - label=f"{self.cfg.name}_normal_{i}", - )