diff --git a/src/mjlab/viewer/viser/conversions.py b/src/mjlab/viewer/viser/conversions.py index 961302ce0..e22e8cbd4 100644 --- a/src/mjlab/viewer/viser/conversions.py +++ b/src/mjlab/viewer/viser/conversions.py @@ -9,6 +9,15 @@ from mujoco import mj_id2name, mjtGeom, mjtObj from PIL import Image +# Default colors for geoms without materials. +_DEFAULT_COLLISION_COLOR = np.array([204, 102, 102, 128], dtype=np.uint8) +_DEFAULT_VISUAL_COLOR = np.array([31, 128, 230, 255], dtype=np.uint8) + + +def rgba_to_uint8(rgba: np.ndarray) -> np.ndarray: + """Convert RGBA from [0,1] range to [0,255] uint8.""" + return (rgba * 255).astype(np.uint8) + def mujoco_mesh_to_trimesh( mj_model: mujoco.MjModel, geom_idx: int, verbose: bool = False @@ -23,57 +32,38 @@ def mujoco_mesh_to_trimesh( Returns: A trimesh object with texture/material applied if available """ - - # Get the mesh ID for this geometry. mesh_id = mj_model.geom_dataid[geom_idx] - # Get mesh data ranges from MuJoCo. vert_start = int(mj_model.mesh_vertadr[mesh_id]) vert_count = int(mj_model.mesh_vertnum[mesh_id]) face_start = int(mj_model.mesh_faceadr[mesh_id]) face_count = int(mj_model.mesh_facenum[mesh_id]) - # Extract vertices and faces. - # mesh_vert shape: (total_verts_in_model, 3) - # We extract our mesh's vertices. - vertices = mj_model.mesh_vert[ - vert_start : vert_start + vert_count - ] # Shape: (vert_count, 3) + vertices = mj_model.mesh_vert[vert_start : vert_start + vert_count] assert vertices.shape == ( vert_count, 3, ), f"Expected vertices shape ({vert_count}, 3), got {vertices.shape}" - # mesh_face shape: (total_faces_in_model, 3) - # Each face has 3 vertex indices. - faces = mj_model.mesh_face[ - face_start : face_start + face_count - ] # Shape: (face_count, 3) + faces = mj_model.mesh_face[face_start : face_start + face_count] assert faces.shape == ( face_count, 3, ), f"Expected faces shape ({face_count}, 3), got {faces.shape}" - # Check if this mesh has texture coordinates. texcoord_adr = mj_model.mesh_texcoordadr[mesh_id] texcoord_num = mj_model.mesh_texcoordnum[mesh_id] if texcoord_num > 0: - # This mesh has UV coordinates. if verbose: print(f"Mesh has {texcoord_num} texture coordinates") - # Extract texture coordinates. - # mesh_texcoord is a 2D array with shape (nmeshtexcoord, 2). texcoords = mj_model.mesh_texcoord[texcoord_adr : texcoord_adr + texcoord_num] assert texcoords.shape == ( texcoord_num, 2, ), f"Expected texcoords shape ({texcoord_num}, 2), got {texcoords.shape}" - # Get per-face texture coordinate indices. - # For each face vertex, this tells us which texcoord to use. - # mesh_facetexcoord is a 2D array with shape (nmeshface, 3). face_texcoord_idx = mj_model.mesh_facetexcoord[face_start : face_start + face_count] assert face_texcoord_idx.shape == (face_count, 3), ( f"Expected face_texcoord_idx shape ({face_count}, 3), got {face_texcoord_idx.shape}" @@ -83,16 +73,15 @@ def mujoco_mesh_to_trimesh( # we need to duplicate vertices. Each face will get its own 3 vertices. # Duplicate vertices for each face reference. - # faces.flatten() gives us vertex indices in order: [v0_f0, v1_f0, v2_f0, v0_f1, v1_f1, v2_f1, ...] - new_vertices = vertices[faces.flatten()] # Shape: (face_count * 3, 3) + # faces.flatten() gives us vertex indices in order: + # [v0_f0, v1_f0, v2_f0, v0_f1, v1_f1, v2_f1, ...]. + new_vertices = vertices[faces.flatten()] assert new_vertices.shape == ( face_count * 3, 3, ), f"Expected new_vertices shape ({face_count * 3}, 3), got {new_vertices.shape}" - # Get UV coordinates for each duplicated vertex. - # face_texcoord_idx.flatten() gives us texcoord indices in the same order. - new_uvs = texcoords[face_texcoord_idx.flatten()] # Shape: (face_count * 3, 2) + new_uvs = texcoords[face_texcoord_idx.flatten()] assert new_uvs.shape == ( face_count * 3, 2, @@ -100,7 +89,7 @@ def mujoco_mesh_to_trimesh( # Create new faces - now just sequential since vertices are duplicated. # [[0, 1, 2], [3, 4, 5], [6, 7, 8], ...] - new_faces = np.arange(face_count * 3).reshape(-1, 3) # Shape: (face_count, 3) + new_faces = np.arange(face_count * 3).reshape(-1, 3) assert new_faces.shape == ( face_count, 3, @@ -109,14 +98,11 @@ def mujoco_mesh_to_trimesh( # Create the mesh (process=False to preserve all vertices). mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces, process=False) - # Now handle material and texture. matid = mj_model.geom_matid[geom_idx] if matid >= 0 and matid < mj_model.nmat: - # This geometry has a material. - rgba = mj_model.mat_rgba[matid] # Shape: (4,) - # mat_texid is 2D (nmat x mjNTEXROLE), get the RGB/RGBA texture. - # Try RGB first (index 1), then RGBA (index 8). + rgba = mj_model.mat_rgba[matid] + # mat_texid is 2D (nmat x mjNTEXROLE), try RGB first, then RGBA. texid = int(mj_model.mat_texid[matid, int(mujoco.mjtTextureRole.mjTEXROLE_RGB)]) if texid < 0: texid = int( @@ -124,44 +110,31 @@ def mujoco_mesh_to_trimesh( ) if texid >= 0 and texid < mj_model.ntex: - # This material has a texture. if verbose: print(f"Material has texture ID {texid}") - # Extract texture data. tex_width = mj_model.tex_width[texid] tex_height = mj_model.tex_height[texid] tex_nchannel = mj_model.tex_nchannel[texid] tex_adr = mj_model.tex_adr[texid] - - # Calculate texture data size. tex_size = tex_width * tex_height * tex_nchannel - - # Extract raw texture data. tex_data = mj_model.tex_data[tex_adr : tex_adr + tex_size] assert tex_data.shape == (tex_size,), ( f"Expected tex_data shape ({tex_size},), got {tex_data.shape}" ) - # Reshape texture data based on number of channels. - # Note: MuJoCo uses OpenGL convention (origin at bottom-left) - # but GLTF/GLB expects top-left origin, so we flip vertically. + # MuJoCo uses OpenGL convention (origin at bottom-left) but GLTF/GLB + # expects top-left origin, so we flip vertically. if tex_nchannel == 1: - # Grayscale. tex_array = tex_data.reshape(tex_height, tex_width) - # Flip vertically for GLTF convention. tex_array = np.flipud(tex_array) image = Image.fromarray(tex_array.astype(np.uint8), mode="L") elif tex_nchannel == 3: - # RGB. tex_array = tex_data.reshape(tex_height, tex_width, 3) - # Flip vertically for GLTF convention. tex_array = np.flipud(tex_array) image = Image.fromarray(tex_array.astype(np.uint8), mode="RGB") elif tex_nchannel == 4: - # RGBA. tex_array = tex_data.reshape(tex_height, tex_width, 4) - # Flip vertically for GLTF convention. tex_array = np.flipud(tex_array) image = Image.fromarray(tex_array.astype(np.uint8), mode="RGBA") else: @@ -170,7 +143,6 @@ def mujoco_mesh_to_trimesh( image = None if image is not None: - # Create material with texture. # Set PBR properties for proper rendering: # - metallicFactor=0.0: non-metallic (dielectric) material # - roughnessFactor=1.0: fully rough (diffuse) surface @@ -180,34 +152,24 @@ def mujoco_mesh_to_trimesh( metallicFactor=0.0, roughnessFactor=1.0, ) - - # Apply texture visual with UV coordinates. mesh.visual = trimesh.visual.TextureVisuals(uv=new_uvs, material=material) if verbose: print(f"Applied texture: {tex_width}x{tex_height}, {tex_nchannel} channels") else: - # Just use material color - convert from [0,1] to [0,255]. - rgba_255 = (rgba * 255).astype(np.uint8) mesh.visual = trimesh.visual.ColorVisuals( - vertex_colors=np.tile(rgba_255, (len(new_vertices), 1)) + vertex_colors=np.tile(rgba_to_uint8(rgba), (len(new_vertices), 1)) ) else: - # Material but no texture - use material color. if verbose: print(f"Material has no texture, using color: {rgba}") - rgba_255 = (rgba * 255).astype(np.uint8) mesh.visual = trimesh.visual.ColorVisuals( - vertex_colors=np.tile(rgba_255, (len(new_vertices), 1)) + vertex_colors=np.tile(rgba_to_uint8(rgba), (len(new_vertices), 1)) ) else: - # No material - use default color based on collision/visual. is_collision = ( mj_model.geom_contype[geom_idx] != 0 or mj_model.geom_conaffinity[geom_idx] != 0 ) - if is_collision: - color = np.array([204, 102, 102, 128], dtype=np.uint8) # Red-ish for collision. - else: - color = np.array([31, 128, 230, 255], dtype=np.uint8) # Blue-ish for visual. + color = _DEFAULT_COLLISION_COLOR if is_collision else _DEFAULT_VISUAL_COLOR mesh.visual = trimesh.visual.ColorVisuals( vertex_colors=np.tile(color, (len(new_vertices), 1)) @@ -218,43 +180,32 @@ def mujoco_mesh_to_trimesh( ) else: - # No texture coordinates - simpler case. if verbose: print("Mesh has no texture coordinates") - # Create mesh with original vertices and faces (process=False to avoid vertex removal). mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) - # Apply material color if available. matid = mj_model.geom_matid[geom_idx] if matid >= 0 and matid < mj_model.nmat: rgba = mj_model.mat_rgba[matid] - rgba_255 = (rgba * 255).astype(np.uint8) - # Use actual vertex count after mesh creation. mesh.visual = trimesh.visual.ColorVisuals( - vertex_colors=np.tile(rgba_255, (len(mesh.vertices), 1)) + vertex_colors=np.tile(rgba_to_uint8(rgba), (len(mesh.vertices), 1)) ) if verbose: print(f"Applied material color: {rgba}") else: - # Default color. is_collision = ( mj_model.geom_contype[geom_idx] != 0 or mj_model.geom_conaffinity[geom_idx] != 0 ) - if is_collision: - color = np.array([204, 102, 102, 128], dtype=np.uint8) # Red-ish for collision. - else: - color = np.array([31, 128, 230, 255], dtype=np.uint8) # Blue-ish for visual. + color = _DEFAULT_COLLISION_COLOR if is_collision else _DEFAULT_VISUAL_COLOR - # Use actual vertex count after mesh creation. mesh.visual = trimesh.visual.ColorVisuals( vertex_colors=np.tile(color, (len(mesh.vertices), 1)) ) if verbose: print(f"Using default {'collision' if is_collision else 'visual'} color") - # Final sanity checks. assert mesh.vertices.shape[1] == 3, ( f"Vertices should be Nx3, got {mesh.vertices.shape}" ) @@ -268,6 +219,103 @@ def mujoco_mesh_to_trimesh( return mesh +def _create_hfield_mesh(mj_model: mujoco.MjModel, geom_id: int) -> trimesh.Trimesh: + """Create heightfield mesh from MuJoCo hfield data.""" + hfield_id = mj_model.geom_dataid[geom_id] + nrow = mj_model.hfield_nrow[hfield_id] + ncol = mj_model.hfield_ncol[hfield_id] + sx, sy, sz, base = mj_model.hfield_size[hfield_id] + + offset = 0 + for k in range(hfield_id): + offset += mj_model.hfield_nrow[k] * mj_model.hfield_ncol[k] + hfield = mj_model.hfield_data[offset : offset + nrow * ncol].reshape(nrow, ncol) + + x = np.linspace(-sx, sx, ncol) + y = np.linspace(-sy, sy, nrow) + xx, yy = np.meshgrid(x, y) + zz = base + sz * hfield + + vertices = np.column_stack((xx.ravel(), yy.ravel(), zz.ravel())) + + faces = [] + for i in range(nrow - 1): + for j in range(ncol - 1): + i0 = i * ncol + j + i1 = i0 + 1 + i2 = i0 + ncol + i3 = i2 + 1 + faces.append([i0, i2, i1]) + faces.append([i1, i2, i3]) + faces = np.array(faces, dtype=np.int64) + return trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + + +# Dispatch table for primitive shape creation. +_SHAPE_CREATORS = { + mjtGeom.mjGEOM_SPHERE.value: lambda size: trimesh.creation.icosphere( + radius=size[0], subdivisions=2 + ), + mjtGeom.mjGEOM_BOX.value: lambda size: trimesh.creation.box(extents=2.0 * size), + mjtGeom.mjGEOM_CAPSULE.value: lambda size: trimesh.creation.capsule( + radius=size[0], height=2.0 * size[1] + ), + mjtGeom.mjGEOM_CYLINDER.value: lambda size: trimesh.creation.cylinder( + radius=size[0], height=2.0 * size[1] + ), + mjtGeom.mjGEOM_PLANE.value: lambda size: trimesh.creation.box((20, 20, 0.01)), +} + + +def _create_ellipsoid_mesh(size: np.ndarray) -> trimesh.Trimesh: + """Create ellipsoid mesh by scaling a unit sphere.""" + mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.0) + mesh.apply_scale(size) + return mesh + + +def _create_shape_mesh( + shape_type: int, + size: np.ndarray, + rgba: np.ndarray, + mj_model: mujoco.MjModel | None = None, + geom_id: int | None = None, +) -> trimesh.Trimesh: + """Create a mesh for a primitive shape type. + + Args: + shape_type: MuJoCo geom type (mjtGeom enum value) + size: Shape size array (interpretation depends on shape_type) + rgba: RGBA color array (0-1 range) + mj_model: MuJoCo model (required for HFIELD type) + geom_id: Geom index (required for HFIELD type) + + Returns: + Trimesh representation of the shape + """ + alpha_mode = "BLEND" if rgba[3] < 1.0 else "OPAQUE" + material = trimesh.visual.material.PBRMaterial( # type: ignore + baseColorFactor=rgba, + metallicFactor=0.0, + roughnessFactor=1.0, + alphaMode=alpha_mode, + ) + + if shape_type in _SHAPE_CREATORS: + mesh = _SHAPE_CREATORS[shape_type](size) + elif shape_type == mjtGeom.mjGEOM_ELLIPSOID: + mesh = _create_ellipsoid_mesh(size) + elif shape_type == mjtGeom.mjGEOM_HFIELD: + if mj_model is None or geom_id is None: + raise ValueError("mj_model and geom_id required for HFIELD type") + mesh = _create_hfield_mesh(mj_model, geom_id) + else: + raise ValueError(f"Unsupported shape type: {shape_type}") + + mesh.visual = trimesh.visual.TextureVisuals(material=material) # type: ignore + return mesh + + def create_primitive_mesh(mj_model: mujoco.MjModel, geom_id: int) -> trimesh.Trimesh: """Create a mesh for primitive geom types (sphere, box, capsule, cylinder, plane). @@ -398,6 +446,7 @@ def create_primitive_mesh(mj_model: mujoco.MjModel, geom_id: int) -> trimesh.Tri return mesh + def merge_geoms(mj_model: mujoco.MjModel, geom_ids: list[int]) -> trimesh.Trimesh: """Merge multiple geoms into a single trimesh. @@ -519,3 +568,43 @@ def get_body_name(mj_model: mujoco.MjModel, body_id: int) -> str: if not body_name: body_name = f"body_{body_id}" return body_name + + +def merge_sites(mj_model: mujoco.MjModel, site_ids: list[int]) -> trimesh.Trimesh: + """Merge multiple sites into a single trimesh. + + Args: + mj_model: MuJoCo model containing site definitions. + site_ids: List of site indices to merge. + + Returns: + Single merged trimesh with all sites transformed to their local poses. + """ + supported_types = { + mjtGeom.mjGEOM_SPHERE, + mjtGeom.mjGEOM_BOX, + mjtGeom.mjGEOM_CAPSULE, + mjtGeom.mjGEOM_CYLINDER, + mjtGeom.mjGEOM_ELLIPSOID, + } + meshes = [] + for site_id in site_ids: + site_type = int(mj_model.site_type[site_id]) + if site_type not in supported_types: + site_type = int(mjtGeom.mjGEOM_SPHERE) + mesh = _create_shape_mesh( + shape_type=site_type, + size=mj_model.site_size[site_id], + rgba=mj_model.site_rgba[site_id].copy(), + ) + pos = mj_model.site_pos[site_id] + quat = mj_model.site_quat[site_id] + transform = np.eye(4) + transform[:3, :3] = vtf.SO3(quat).as_matrix() + transform[:3, 3] = pos + mesh.apply_transform(transform) + meshes.append(mesh) + + if len(meshes) == 1: + return meshes[0] + return trimesh.util.concatenate(meshes) diff --git a/src/mjlab/viewer/viser/reward_plotter.py b/src/mjlab/viewer/viser/reward_plotter.py index f157fe99d..09b058cd2 100644 --- a/src/mjlab/viewer/viser/reward_plotter.py +++ b/src/mjlab/viewer/viser/reward_plotter.py @@ -20,55 +20,39 @@ def __init__( """Initialize the reward plotter. Args: - server: The Viser server instance - term_names: List of reward term names to plot - history_length: Number of points to keep in history - max_terms: Maximum number of reward terms to plot + server: The Viser server instance. + term_names: List of reward term names to plot. + history_length: Number of points to keep in history. + max_terms: Maximum number of reward terms to plot. """ self._server = server self._history_length = history_length self._max_terms = max_terms - - # State self._term_names = term_names[: self._max_terms] self._histories: dict[str, deque[float]] = {} self._plot_handles: dict[str, viser.GuiUplotHandle] = {} - - # Pre-allocated x-axis array (reused for all plots) self._x_array = np.arange(-history_length + 1, 1, dtype=np.float64) self._folder_handle = None - # Add checkbox to enable/disable reward plots self._enabled_checkbox = self._server.gui.add_checkbox( "Enable reward plots", initial_value=False ) @self._enabled_checkbox.on_update def _(_) -> None: - # Show/hide plots based on checkbox state for handle in self._plot_handles.values(): handle.visible = self._enabled_checkbox.value - # Create individual plot for each reward term for name in self._term_names: - # Initialize history deque for this term self._histories[name] = deque(maxlen=self._history_length) - - # Create initial empty data x_data = np.array([], dtype=np.float64) y_data = np.array([], dtype=np.float64) - # Configure series for this single term series = [ - viser.uplot.Series(label="Steps"), # X-axis - viser.uplot.Series( - label=name, - stroke="#1f77b4", # Blue for all plots - width=2, - ), + viser.uplot.Series(label="Steps"), + viser.uplot.Series(label=name, stroke="#1f77b4", width=2), ] - # Create uPlot chart for this term with title plot_handle = self._server.gui.add_uplot( data=(x_data, y_data), series=tuple(series), @@ -78,9 +62,9 @@ def _(_) -> None: ), "y": viser.uplot.Scale(auto=True), }, - legend=viser.uplot.Legend(show=False), # No legend needed for single series - title=name, # Add title to the plot - aspect=2.0, # Wider aspect ratio for individual plots + legend=viser.uplot.Legend(show=False), + title=name, + aspect=2.0, visible=False, ) @@ -90,38 +74,26 @@ def update(self, reward_terms: list[tuple[str, np.ndarray]]) -> None: """Update the plots with new reward data. Args: - reward_terms: List of (name, value_array) tuples + reward_terms: List of (name, value_array) tuples. """ - # Early return if plots are disabled if not self._enabled_checkbox.value: return if not self._plot_handles or not self._term_names: return - # Update each term's plot individually for name, arr in reward_terms: if name not in self._histories or name not in self._plot_handles: continue value = float(arr[0]) if np.isfinite(value): - # Add to history deque (automatically pops oldest when full) self._histories[name].append(value) - - # Update this term's plot hist = self._histories[name] hist_len = len(hist) - if hist_len > 0: - # Use view of pre-allocated x-array x_data = self._x_array[-hist_len:] - - # Convert deque to numpy array efficiently - # np.fromiter is efficient for converting iterables y_data = np.fromiter(hist, dtype=np.float64, count=hist_len) - - # Update plot data self._plot_handles[name].data = (x_data, y_data) def clear_histories(self) -> None: @@ -129,7 +101,6 @@ def clear_histories(self) -> None: for history in self._histories.values(): history.clear() - # Reset plot data to empty for handle in self._plot_handles.values(): handle.data = (np.array([], dtype=np.float64), np.array([], dtype=np.float64)) @@ -137,7 +108,7 @@ def set_visible(self, visible: bool) -> None: """Set visibility of all plots. Args: - visible: Whether plots should be visible + visible: Whether plots should be visible. """ for handle in self._plot_handles.values(): handle.visible = visible diff --git a/src/mjlab/viewer/viser/scene.py b/src/mjlab/viewer/viser/scene.py index 1453beb26..7ed36d26f 100644 --- a/src/mjlab/viewer/viser/scene.py +++ b/src/mjlab/viewer/viser/scene.py @@ -12,7 +12,6 @@ import viser import viser.transforms as vtf from mujoco import mj_id2name, mjtGeom, mjtObj -from typing_extensions import override from mjlab.viewer.debug_visualizer import DebugVisualizer from mjlab.viewer.viser.conversions import ( @@ -20,24 +19,22 @@ get_body_name, is_fixed_body, merge_geoms, + merge_sites, mujoco_mesh_to_trimesh, + rgba_to_uint8, rotation_matrix_from_vectors, rotation_quat_from_vectors, ) -try: - import mujoco_warp as mjwarp -except ImportError: - mjwarp = None # type: ignore - - -# Viser visualization defaults. -_DEFAULT_FOV_DEGREES = 60 -_DEFAULT_FOV_MIN = 20 -_DEFAULT_FOV_MAX = 150 +_NUM_GEOM_GROUPS = 6 _DEFAULT_ENVIRONMENT_INTENSITY = 0.8 _DEFAULT_CONTACT_POINT_COLOR = (230, 153, 51) _DEFAULT_CONTACT_FORCE_COLOR = (255, 0, 0) +_DEFAULT_FOV_DEGREES = 60 +_DEFAULT_FOV_MIN = 20 +_DEFAULT_FOV_MAX = 150 +_ARROW_SHAFT_LENGTH_RATIO = 0.8 +_ARROW_HEAD_LENGTH_RATIO = 0.2 @dataclass @@ -80,28 +77,38 @@ class ViserMujocoScene(DebugVisualizer): like arrows, ghost meshes, and coordinate frames. """ - # Core. server: viser.ViserServer mj_model: mujoco.MjModel mj_data: mujoco.MjData num_envs: int - # Handles (created once). fixed_bodies_frame: viser.SceneNodeHandle = field(init=False) - mesh_handles_by_group: dict[tuple[int, int], viser.BatchedGlbHandle] = field( + # Key: (body_id, group_id, is_site). Unified storage for geom and site handles. + _batched_handles: dict[tuple[int, int, bool], viser.BatchedGlbHandle] = field( + default_factory=dict + ) + _fixed_handles: dict[tuple[int, int, bool], viser.GlbHandle] = field( default_factory=dict ) contact_point_handle: viser.BatchedMeshHandle | None = None contact_force_shaft_handle: viser.BatchedMeshHandle | None = None contact_force_head_handle: viser.BatchedMeshHandle | None = None - # Visualization settings (set directly or automatically updated by create_options_gui). - env_idx: int = 0 # Current environment index (DebugVisualizer protocol). + _label_handles: dict[str, viser.LabelHandle] = field(default_factory=dict, init=False) + _frame_handles: dict[str, viser.FrameHandle] = field(default_factory=dict, init=False) + + env_idx: int = 0 camera_tracking_enabled: bool = False show_only_selected: bool = False geom_groups_visible: list[bool] = field( default_factory=lambda: [True, True, True, False, False, False] ) + site_groups_visible: list[bool] = field( + default_factory=lambda: [True, True, True, False, False, False] + ) + label_targets: set[str] = field(default_factory=set) + frame_targets: set[str] = field(default_factory=set) + frame_scale: float = 1.0 show_contact_points: bool = False show_contact_forces: bool = False contact_point_color: tuple[int, int, int] = _DEFAULT_CONTACT_POINT_COLOR @@ -110,7 +117,6 @@ class ViserMujocoScene(DebugVisualizer): needs_update: bool = False _tracked_body_id: int | None = field(init=False, default=None) - # Cached visualization state for re-rendering when settings change. _last_body_xpos: np.ndarray | None = None _last_body_xmat: np.ndarray | None = None _last_mocap_pos: np.ndarray | None = None @@ -118,7 +124,6 @@ class ViserMujocoScene(DebugVisualizer): _last_env_idx: int = 0 _last_contacts: list[_Contact] | None = None - # Debug visualization (arrows, ghosts, frames). debug_visualization_enabled: bool = False _scene_offset: np.ndarray = field(default_factory=lambda: np.zeros(3), init=False) _queued_arrows: list[ @@ -156,14 +161,6 @@ def create( Visual geometry is created immediately. Collision geometry is created lazily when first needed. - - Args: - server: Viser server instance. - mj_model: MuJoCo model. - num_envs: Number of parallel environments. - - Returns: - ViserMujocoScene instance with scene populated. """ mj_data = mujoco.MjData(mj_model) @@ -174,22 +171,15 @@ def create( num_envs=num_envs, ) - # Initialize debug visualization data. scene._viz_data = mujoco.MjData(mj_model) - # Configure environment lighting. server.scene.configure_environment_map( environment_intensity=_DEFAULT_ENVIRONMENT_INTENSITY ) - # Create frame for fixed world geometry. scene.fixed_bodies_frame = server.scene.add_frame("/fixed_bodies", show_axes=False) - - # Add fixed geometry (planes, terrain, etc.). scene._add_fixed_geometry() - - # Create mesh handles per geom group. - scene._create_mesh_handles_by_group() + scene._create_batched_handles() # Find first non-fixed body for camera tracking. for body_id in range(mj_model.nbody): @@ -199,231 +189,8 @@ def create( return scene - def _is_collision_geom(self, geom_id: int) -> bool: - """Check if a geom is a collision geom.""" - return ( - self.mj_model.geom_contype[geom_id] != 0 - or self.mj_model.geom_conaffinity[geom_id] != 0 - ) - - def _sync_visibilities(self) -> None: - """Synchronize all handle visibilities based on current flags.""" - # Geom group meshes. - for (_body_id, group_id), handle in self.mesh_handles_by_group.items(): - handle.visible = group_id < 6 and self.geom_groups_visible[group_id] - - # Contact points. - if self.contact_point_handle is not None and not self.show_contact_points: - self.contact_point_handle.visible = False - - # Contact forces. - if not self.show_contact_forces: - if self.contact_force_shaft_handle is not None: - self.contact_force_shaft_handle.visible = False - if self.contact_force_head_handle is not None: - self.contact_force_head_handle.visible = False - - def create_visualization_gui( - self, - camera_distance: float = 3.0, - camera_azimuth: float = 45.0, - camera_elevation: float = 30.0, - show_debug_viz_control: bool = True, - ) -> None: - """Add standard GUI controls that automatically update this scene's settings. - - Args: - camera_distance: Default camera distance from tracked body when tracking is enabled. - camera_azimuth: Default camera azimuth angle in degrees. - camera_elevation: Default camera elevation angle in degrees. - show_debug_viz_control: Whether to show the debug visualization checkbox. - """ - with self.server.gui.add_folder("Visualization"): - slider_fov = self.server.gui.add_slider( - "FOV (°)", - min=_DEFAULT_FOV_MIN, - max=_DEFAULT_FOV_MAX, - step=1, - initial_value=_DEFAULT_FOV_DEGREES, - hint="Vertical FOV of viewer camera, in degrees.", - ) - - @slider_fov.on_update - def _(_) -> None: - for client in self.server.get_clients().values(): - client.camera.fov = np.radians(slider_fov.value) - - @self.server.on_client_connect - def _(client: viser.ClientHandle) -> None: - client.camera.fov = np.radians(slider_fov.value) - - # Environment selection (only if multiple environments). - with self.server.gui.add_folder("Environment"): - # Environment selection slider (if multiple envs). - if self.num_envs > 1: - env_slider = self.server.gui.add_slider( - "Select", - min=0, - max=self.num_envs - 1, - step=1, - initial_value=self.env_idx, - hint=f"Select environment (0-{self.num_envs - 1})", - ) - - @env_slider.on_update - def _(_) -> None: - self.env_idx = int(env_slider.value) - self._request_update() - - show_only_cb = self.server.gui.add_checkbox( - "Hide others", - initial_value=self.show_only_selected, - hint="Show only the selected environment.", - ) - - @show_only_cb.on_update - def _(_) -> None: - self.show_only_selected = show_only_cb.value - self._request_update() - - # Camera tracking controls. - cb_camera_tracking = self.server.gui.add_checkbox( - "Track camera", - initial_value=self.camera_tracking_enabled, - hint="Keep tracked body centered. Use Viser camera controls to adjust view.", - ) - - @cb_camera_tracking.on_update - def _(_) -> None: - self.camera_tracking_enabled = cb_camera_tracking.value - # Snap camera to default view when enabling tracking. - if self.camera_tracking_enabled: - # Convert to radians and calculate camera position. - azimuth_rad = np.deg2rad(camera_azimuth) - elevation_rad = np.deg2rad(camera_elevation) - - # Calculate forward vector from spherical coordinates. - forward = np.array( - [ - np.cos(elevation_rad) * np.cos(azimuth_rad), - np.cos(elevation_rad) * np.sin(azimuth_rad), - np.sin(elevation_rad), - ] - ) - - # Camera position is origin - forward * distance. - camera_pos = -forward * camera_distance - - # Snap all connected clients to this view. - for client in self.server.get_clients().values(): - client.camera.position = camera_pos - client.camera.look_at = np.zeros(3) - - self._request_update() - - # Debug visualization controls (only show if requested). - if show_debug_viz_control: - cb_debug_vis = self.server.gui.add_checkbox( - "Debug visualization", - initial_value=self.debug_visualization_enabled, - hint="Show debug arrows and ghost meshes.", - ) - - @cb_debug_vis.on_update - def _(_) -> None: - self.debug_visualization_enabled = cb_debug_vis.value - # Clear visualizer if hiding. - if not self.debug_visualization_enabled: - self.clear_debug_all() - self._request_update() - - # Contact visualization settings. - with self.server.gui.add_folder("Contacts"): - cb_contact_points = self.server.gui.add_checkbox( - "Points", - initial_value=False, - hint="Toggle contact point visualization.", - ) - contact_point_color = self.server.gui.add_rgb( - "Points Color", initial_value=self.contact_point_color - ) - cb_contact_forces = self.server.gui.add_checkbox( - "Forces", - initial_value=False, - hint="Toggle contact force visualization.", - ) - contact_force_color = self.server.gui.add_rgb( - "Forces Color", initial_value=self.contact_force_color - ) - meansize_input = self.server.gui.add_number( - "Scale", - step=self.mj_model.stat.meansize * 0.01, - initial_value=self.mj_model.stat.meansize, - ) - - @cb_contact_points.on_update - def _(_) -> None: - self.show_contact_points = cb_contact_points.value - self._sync_visibilities() - self._request_update() - - @contact_point_color.on_update - def _(_) -> None: - self.contact_point_color = contact_point_color.value - if self.contact_point_handle is not None: - self.contact_point_handle.remove() - self.contact_point_handle = None - self._request_update() - - @cb_contact_forces.on_update - def _(_) -> None: - self.show_contact_forces = cb_contact_forces.value - self._sync_visibilities() - self._request_update() - - @contact_force_color.on_update - def _(_) -> None: - self.contact_force_color = contact_force_color.value - if self.contact_force_shaft_handle is not None: - self.contact_force_shaft_handle.remove() - self.contact_force_shaft_handle = None - if self.contact_force_head_handle is not None: - self.contact_force_head_handle.remove() - self.contact_force_head_handle = None - self._request_update() - - @meansize_input.on_update - def _(_) -> None: - self.meansize_override = meansize_input.value - self._request_update() - - def create_geom_groups_gui(self, tabs) -> None: - """Add geom groups tab to the given tab group. - - Args: - tabs: The viser tab group to add the geom groups tab to. - """ - with tabs.add_tab("Geoms", icon=viser.Icon.EYE): - for i in range(6): - cb = self.server.gui.add_checkbox( - f"Group {i}", - initial_value=self.geom_groups_visible[i], - hint=f"Show/hide geoms in group {i}", - ) - - @cb.on_update - def _(event, group_idx=i) -> None: - self.geom_groups_visible[group_idx] = event.target.value - self._sync_visibilities() - self._request_update() - def update(self, wp_data, env_idx: int | None = None) -> None: - """Update scene from batched simulation data. - - Args: - wp_data: Batched Warp simulation data (mjwarp.Data). - env_idx: Environment index to visualize. If None, uses self.env_idx. - """ + """Update scene from batched simulation data.""" if env_idx is None: env_idx = self.env_idx @@ -437,19 +204,33 @@ def update(self, wp_data, env_idx: int | None = None) -> None: scene_offset = -tracked_pos contacts = None - if self.show_contact_points or self.show_contact_forces: + mj_data = None + if ( + self.show_contact_points + or self.show_contact_forces + or self.label_targets + or self.frame_targets + ): self.mj_data.qpos[:] = wp_data.qpos.numpy()[env_idx] self.mj_data.qvel[:] = wp_data.qvel.numpy()[env_idx] self.mj_data.mocap_pos[:] = mocap_pos[env_idx] self.mj_data.mocap_quat[:] = mocap_quat[env_idx] mujoco.mj_forward(self.mj_model, self.mj_data) - contacts = self._extract_contacts_from_mjdata(self.mj_data) + mj_data = self.mj_data + if self.show_contact_points or self.show_contact_forces: + contacts = self._extract_contacts_from_mjdata(self.mj_data) self._update_visualization( - body_xpos, body_xmat, mocap_pos, mocap_quat, env_idx, scene_offset, contacts + body_xpos, + body_xmat, + mocap_pos, + mocap_quat, + env_idx, + scene_offset, + contacts, + mj_data, ) - # Update scene offset for debug visualizations and sync arrows, spheres, cylinders if self.debug_visualization_enabled: self._scene_offset = scene_offset self._sync_arrows() @@ -457,11 +238,7 @@ def update(self, wp_data, env_idx: int | None = None) -> None: self._sync_cylinders() def update_from_mjdata(self, mj_data: mujoco.MjData) -> None: - """Update scene from single-environment MuJoCo data. - - Args: - mj_data: Single environment MuJoCo data. - """ + """Update scene from single-environment MuJoCo data.""" body_xpos = mj_data.xpos[None, ...] body_xmat = mj_data.xmat.reshape(-1, 3, 3)[None, ...] mocap_pos = mj_data.mocap_pos[None, ...] @@ -472,16 +249,19 @@ def update_from_mjdata(self, mj_data: mujoco.MjData) -> None: tracked_pos = mj_data.xpos[self._tracked_body_id, :].copy() scene_offset = -tracked_pos - # Always extract contacts for single-environment updates (used by nan_viz). - # This allows toggling contact visualization without needing to scrub timesteps. - # Not performance-critical since this isn't called in tight loops. contacts = self._extract_contacts_from_mjdata(mj_data) self._update_visualization( - body_xpos, body_xmat, mocap_pos, mocap_quat, env_idx, scene_offset, contacts + body_xpos, + body_xmat, + mocap_pos, + mocap_quat, + env_idx, + scene_offset, + contacts, + mj_data, ) - # Update scene offset for debug visualizations and sync arrows, spheres, cylinders if self.debug_visualization_enabled: self._scene_offset = scene_offset self._sync_arrows() @@ -497,30 +277,26 @@ def _update_visualization( env_idx: int, scene_offset: np.ndarray, contacts: list[_Contact] | None, + mj_data: mujoco.MjData | None, ) -> None: """Shared visualization update logic.""" - # Cache visualization state for re-rendering when settings change. self._last_body_xpos = body_xpos self._last_body_xmat = body_xmat self._last_mocap_pos = mocap_pos self._last_mocap_quat = mocap_quat self._last_env_idx = env_idx self._scene_offset = scene_offset - # Only update cached contacts if we have new contact data (don't overwrite with None) if contacts is not None: self._last_contacts = contacts self.fixed_bodies_frame.position = scene_offset with self.server.atomic(): body_xquat = vtf.SO3.from_matrix(body_xmat).wxyz - for (body_id, _group_id), handle in self.mesh_handles_by_group.items(): + for (body_id, _group_id, _is_site), handle in self._batched_handles.items(): if not handle.visible: continue - # Check if this is a mocap body. mocap_id = self.mj_model.body_mocapid[body_id] if mocap_id >= 0: - # Use mocap pos/quat for mocap bodies. - # Note: mocap_quat is already in wxyz format (MuJoCo convention). if self.show_only_selected and self.num_envs > 1: single_pos = mocap_pos[env_idx, mocap_id, :] + scene_offset single_quat = mocap_quat[env_idx, mocap_id, :] @@ -530,7 +306,6 @@ def _update_visualization( handle.batched_positions = mocap_pos[:, mocap_id, :] + scene_offset handle.batched_wxyzs = mocap_quat[:, mocap_id, :] else: - # Use xpos/xmat for regular bodies. if self.show_only_selected and self.num_envs > 1: single_pos = body_xpos[env_idx, body_id, :] + scene_offset single_quat = body_xquat[env_idx, body_id, :] @@ -539,44 +314,35 @@ def _update_visualization( else: handle.batched_positions = body_xpos[..., body_id, :] + scene_offset handle.batched_wxyzs = body_xquat[..., body_id, :] + if contacts is not None: self._update_contact_visualization(contacts, scene_offset) + self._update_annotations(body_xpos, body_xmat, env_idx, scene_offset, mj_data) + self.server.flush() def _request_update(self) -> None: - """Request a visualization update and trigger immediate re-render from cache. - - This is called when visualization settings change to provide immediate feedback. - For viewers with continuous update loops (viser_play), the loop will refresh soon. - For static viewers (nan_viz), this provides the only update mechanism. - """ + """Request a visualization update and trigger immediate re-render from cache.""" self.needs_update = True self.refresh_visualization() def refresh_visualization(self) -> None: - """Re-render the scene using cached visualization data. - - This is useful when visualization settings change (e.g., toggling contacts) - but the underlying simulation data hasn't changed. Clears the needs_update flag. - """ + """Re-render the scene using cached visualization data.""" if ( self._last_body_xpos is None or self._last_body_xmat is None or self._last_mocap_pos is None or self._last_mocap_quat is None ): - return # No cached data yet + return - # Use cached contacts (don't recompute - the data might be stale). - # The next regular update will refresh contacts if needed. contacts = ( self._last_contacts if (self.show_contact_points or self.show_contact_forces) else None ) - # Recalculate scene offset based on current camera tracking state. scene_offset = np.zeros(3) if self.camera_tracking_enabled and self._tracked_body_id is not None: tracked_pos = self._last_body_xpos[ @@ -584,7 +350,6 @@ def refresh_visualization(self) -> None: ].copy() scene_offset = -tracked_pos - # Re-render with cached data (_update_visualization has its own atomic block and flush) self._update_visualization( self._last_body_xpos, self._last_body_xmat, @@ -593,9 +358,36 @@ def refresh_visualization(self) -> None: self._last_env_idx, scene_offset, contacts, + None, ) self.needs_update = False + def _is_collision_geom(self, geom_id: int) -> bool: + """Check if a geom is a collision geom.""" + return ( + self.mj_model.geom_contype[geom_id] != 0 + or self.mj_model.geom_conaffinity[geom_id] != 0 + ) + + def _sync_visibilities(self) -> None: + """Synchronize all handle visibilities based on current flags.""" + for (_body_id, group_id, is_site), handle in self._batched_handles.items(): + visible_list = self.site_groups_visible if is_site else self.geom_groups_visible + handle.visible = group_id < _NUM_GEOM_GROUPS and visible_list[group_id] + + for (_body_id, group_id, is_site), handle in self._fixed_handles.items(): + visible_list = self.site_groups_visible if is_site else self.geom_groups_visible + handle.visible = group_id < _NUM_GEOM_GROUPS and visible_list[group_id] + + if self.contact_point_handle is not None and not self.show_contact_points: + self.contact_point_handle.visible = False + + if not self.show_contact_forces: + if self.contact_force_shaft_handle is not None: + self.contact_force_shaft_handle.visible = False + if self.contact_force_head_handle is not None: + self.contact_force_head_handle.visible = False + def _add_fixed_geometry(self) -> None: """Add fixed world geometry to the scene.""" body_geoms_visual: dict[int, list[int]] = {} @@ -606,17 +398,12 @@ def _add_fixed_geometry(self) -> None: target = body_geoms_collision if self._is_collision_geom(i) else body_geoms_visual target.setdefault(body_id, []).append(i) - # Process all bodies with geoms. all_bodies = set(body_geoms_visual.keys()) | set(body_geoms_collision.keys()) for body_id in all_bodies: - # Get body name. body_name = get_body_name(self.mj_model, body_id) - # Fixed world geometry. We'll assume this is shared between all environments. if is_fixed_body(self.mj_model, body_id): - # Create both visual and collision geoms for fixed bodies (terrain, floor, etc.) - # but show them all since they're static. all_geoms = [] if body_id in body_geoms_visual: all_geoms.extend(body_geoms_visual[body_id]) @@ -626,20 +413,13 @@ def _add_fixed_geometry(self) -> None: if not all_geoms: continue - # Iterate over geoms. nonplane_geom_ids: list[int] = [] for geom_id in all_geoms: geom_type = self.mj_model.geom_type[geom_id] - # Add plane geoms as infinite grids. if geom_type == mjtGeom.mjGEOM_PLANE: geom_name = mj_id2name(self.mj_model, mjtObj.mjOBJ_GEOM, geom_id) self.server.scene.add_grid( f"/fixed_bodies/{body_name}/{geom_name}", - # For infinite grids in viser 1.0.10, the width and height - # parameters determined the region of the grid that can - # receive shadows. We'll just make this really big for now. - # In a future release of Viser these two args should ideally be - # unnecessary. width=2000.0, height=2000.0, infinite_grid=True, @@ -651,7 +431,6 @@ def _add_fixed_geometry(self) -> None: else: nonplane_geom_ids.append(geom_id) - # Handle non-plane geoms. if len(nonplane_geom_ids) > 0: self.server.scene.add_mesh_trimesh( f"/fixed_bodies/{body_name}", @@ -663,52 +442,482 @@ def _add_fixed_geometry(self) -> None: visible=True, ) - def _create_mesh_handles_by_group(self) -> None: - """Create mesh handles for each geom group separately to allow independent toggling.""" - # Group geoms by (body_id, group_id). - body_group_geoms: dict[tuple[int, int], list[int]] = {} + body_group_sites: dict[tuple[int, int], list[int]] = {} + for i in range(self.mj_model.nsite): + body_id = self.mj_model.site_bodyid[i] + if is_fixed_body(self.mj_model, body_id): + site_group = self.mj_model.site_group[i] + key = (body_id, site_group) + body_group_sites.setdefault(key, []).append(i) + for (body_id, group_id), site_ids in body_group_sites.items(): + body_name = get_body_name(self.mj_model, body_id) + visible = group_id < _NUM_GEOM_GROUPS and self.site_groups_visible[group_id] + handle = self.server.scene.add_mesh_trimesh( + f"/fixed_bodies/{body_name}/sites_group{group_id}", + merge_sites(self.mj_model, site_ids), + cast_shadow=False, + receive_shadow=0.2, + position=self.mj_model.body(body_id).pos, + wxyz=self.mj_model.body(body_id).quat, + visible=visible, + ) + self._fixed_handles[(body_id, group_id, True)] = handle + + def _create_batched_handles(self) -> None: + """Create batched mesh handles for geoms and sites on non-fixed bodies.""" + # Collect geoms by (body_id, group_id). + body_group_geoms: dict[tuple[int, int], list[int]] = {} for i in range(self.mj_model.ngeom): body_id = self.mj_model.geom_bodyid[i] - - # Skip fixed world geometry. if is_fixed_body(self.mj_model, body_id): continue + key = (body_id, self.mj_model.geom_group[i]) + body_group_geoms.setdefault(key, []).append(i) - geom_group = self.mj_model.geom_group[i] - key = (body_id, geom_group) + # Collect sites by (body_id, group_id). + body_group_sites: dict[tuple[int, int], list[int]] = {} + for i in range(self.mj_model.nsite): + body_id = self.mj_model.site_bodyid[i] + if is_fixed_body(self.mj_model, body_id): + continue + key = (body_id, self.mj_model.site_group[i]) + body_group_sites.setdefault(key, []).append(i) - if key not in body_group_geoms: - body_group_geoms[key] = [] - body_group_geoms[key].append(i) + default_wxyz = np.array([[1.0, 0.0, 0.0, 0.0]] * self.num_envs) + default_pos = np.array([[0.0, 0.0, 0.0]] * self.num_envs) - # Create handles for each (body, group) combination. with self.server.atomic(): - for (body_id, group_id), geom_indices in body_group_geoms.items(): - # Get body name. + for (body_id, group_id), indices in body_group_geoms.items(): body_name = get_body_name(self.mj_model, body_id) - - # Merge geoms into a single mesh. - mesh = merge_geoms(self.mj_model, geom_indices) + mesh = merge_geoms(self.mj_model, indices) lod_ratio = 1000.0 / mesh.vertices.shape[0] + visible = group_id < _NUM_GEOM_GROUPS and self.geom_groups_visible[group_id] + handle = self.server.scene.add_batched_meshes_trimesh( + f"/bodies/{body_name}/geoms_g{group_id}", + mesh, + batched_wxyzs=default_wxyz, + batched_positions=default_pos, + lod=((2.0, lod_ratio),) if lod_ratio < 0.5 else "off", + visible=visible, + ) + self._batched_handles[(body_id, group_id, False)] = handle - # Check if this group should be visible. - visible = group_id < 6 and self.geom_groups_visible[group_id] - - # Create handle. + for (body_id, group_id), indices in body_group_sites.items(): + body_name = get_body_name(self.mj_model, body_id) + mesh = merge_sites(self.mj_model, indices) + lod_ratio = 1000.0 / mesh.vertices.shape[0] + visible = group_id < _NUM_GEOM_GROUPS and self.site_groups_visible[group_id] handle = self.server.scene.add_batched_meshes_trimesh( - f"/bodies/{body_name}/group{group_id}", + f"/bodies/{body_name}/sites_g{group_id}", mesh, - batched_wxyzs=np.array([1.0, 0.0, 0.0, 0.0])[None].repeat( - self.num_envs, axis=0 - ), - batched_positions=np.array([0.0, 0.0, 0.0])[None].repeat( - self.num_envs, axis=0 - ), + batched_wxyzs=default_wxyz, + batched_positions=default_pos, lod=((2.0, lod_ratio),) if lod_ratio < 0.5 else "off", visible=visible, ) - self.mesh_handles_by_group[(body_id, group_id)] = handle + self._batched_handles[(body_id, group_id, True)] = handle + + def create_env_selector_gui(self) -> None: + """Add environment selector at top level (always visible across all tabs).""" + if self.num_envs > 1: + env_slider = self.server.gui.add_slider( + "Environment", + min=0, + max=self.num_envs - 1, + step=1, + initial_value=self.env_idx, + hint=f"Select environment (0-{self.num_envs - 1})", + ) + + @env_slider.on_update + def _(_) -> None: + self.env_idx = int(env_slider.value) + self._request_update() + + def create_visualization_gui( + self, + camera_distance: float = 3.0, + camera_azimuth: float = 45.0, + camera_elevation: float = 30.0, + show_debug_viz_control: bool = True, + ) -> None: + """Add standard GUI controls that automatically update this scene's settings.""" + with self.server.gui.add_folder("Camera"): + slider_fov = self.server.gui.add_slider( + "FOV (°)", + min=_DEFAULT_FOV_MIN, + max=_DEFAULT_FOV_MAX, + step=1, + initial_value=_DEFAULT_FOV_DEGREES, + hint="Vertical FOV of viewer camera, in degrees.", + ) + + @slider_fov.on_update + def _(_) -> None: + for client in self.server.get_clients().values(): + client.camera.fov = np.radians(slider_fov.value) + + @self.server.on_client_connect + def _(client: viser.ClientHandle) -> None: + client.camera.fov = np.radians(slider_fov.value) + + cb_camera_tracking = self.server.gui.add_checkbox( + "Track body", + initial_value=self.camera_tracking_enabled, + hint="Keep tracked body centered. Use Viser camera controls to adjust view.", + ) + + if self.num_envs > 1: + show_only_cb = self.server.gui.add_checkbox( + "Hide other envs", + initial_value=self.show_only_selected, + hint="Show only the selected environment.", + ) + + @show_only_cb.on_update + def _(_) -> None: + self.show_only_selected = show_only_cb.value + self._request_update() + + @cb_camera_tracking.on_update + def _(_) -> None: + self.camera_tracking_enabled = cb_camera_tracking.value + if self.camera_tracking_enabled: + azimuth_rad = np.deg2rad(camera_azimuth) + elevation_rad = np.deg2rad(camera_elevation) + + forward = np.array( + [ + np.cos(elevation_rad) * np.cos(azimuth_rad), + np.cos(elevation_rad) * np.sin(azimuth_rad), + np.sin(elevation_rad), + ] + ) + + camera_pos = -forward * camera_distance + + for client in self.server.get_clients().values(): + client.camera.position = camera_pos + client.camera.look_at = np.zeros(3) + + self._request_update() + + if show_debug_viz_control: + cb_debug_vis = self.server.gui.add_checkbox( + "Debug visualization", + initial_value=self.debug_visualization_enabled, + hint="Show debug arrows and ghost meshes.", + ) + + @cb_debug_vis.on_update + def _(_) -> None: + self.debug_visualization_enabled = cb_debug_vis.value + if not self.debug_visualization_enabled: + self.clear_debug_all() + self._request_update() + + def create_groups_gui(self, tabs) -> None: + """Add groups tab combining geom and site visibility controls.""" + with tabs.add_tab("Groups", icon=viser.Icon.EYE): + self.server.gui.add_markdown("**Geoms**") + for i in range(_NUM_GEOM_GROUPS): + cb = self.server.gui.add_checkbox( + f"G{i}", + initial_value=self.geom_groups_visible[i], + hint=f"Show/hide geoms in group {i}", + ) + + @cb.on_update + def _(event, group_idx=i) -> None: + self.geom_groups_visible[group_idx] = event.target.value + self._sync_visibilities() + self._request_update() + + self.server.gui.add_markdown("**Sites**") + for i in range(_NUM_GEOM_GROUPS): + cb = self.server.gui.add_checkbox( + f"S{i}", + initial_value=self.site_groups_visible[i], + hint=f"Show/hide sites in group {i}", + ) + + @cb.on_update + def _(event, group_idx=i) -> None: + self.site_groups_visible[group_idx] = event.target.value + self._sync_visibilities() + if "sites" in self.label_targets or "sites" in self.frame_targets: + self._refresh_annotations() + self._request_update() + + def create_overlays_gui(self, tabs) -> None: + """Add overlays tab combining annotations and contacts.""" + with tabs.add_tab("Overlays", icon=viser.Icon.LAYERS_LINKED): + self.server.gui.add_markdown("**Labels**") + cb_site_labels = self.server.gui.add_checkbox( + "Sites", + initial_value="sites" in self.label_targets, + hint="Show text labels for visible sites", + ) + cb_body_labels = self.server.gui.add_checkbox( + "Bodies", + initial_value="bodies" in self.label_targets, + hint="Show text labels for bodies", + ) + + self.server.gui.add_markdown("**Frames**") + cb_site_frames = self.server.gui.add_checkbox( + "Sites", + initial_value="sites" in self.frame_targets, + hint="Show coordinate frames for visible sites", + ) + cb_body_frames = self.server.gui.add_checkbox( + "Bodies", + initial_value="bodies" in self.frame_targets, + hint="Show coordinate frames for bodies", + ) + frame_scale_input = self.server.gui.add_slider( + "Frame scale", + min=0.1, + max=5.0, + step=0.1, + initial_value=self.frame_scale, + hint="Scale multiplier for coordinate frames", + ) + + self.server.gui.add_markdown("**Contacts**") + cb_contact_points = self.server.gui.add_checkbox( + "Points", + initial_value=self.show_contact_points, + hint="Toggle contact point visualization.", + ) + cb_contact_forces = self.server.gui.add_checkbox( + "Forces", + initial_value=self.show_contact_forces, + hint="Toggle contact force visualization.", + ) + contact_scale_input = self.server.gui.add_number( + "Contact scale", + step=self.mj_model.stat.meansize * 0.01, + initial_value=self.meansize_override or self.mj_model.stat.meansize, + hint="Scale for contact visualization", + ) + + @cb_site_labels.on_update + def _(_) -> None: + if cb_site_labels.value: + self.label_targets.add("sites") + else: + self.label_targets.discard("sites") + self._refresh_annotations() + self._request_update() + + @cb_body_labels.on_update + def _(_) -> None: + if cb_body_labels.value: + self.label_targets.add("bodies") + else: + self.label_targets.discard("bodies") + self._refresh_annotations() + self._request_update() + + @cb_site_frames.on_update + def _(_) -> None: + if cb_site_frames.value: + self.frame_targets.add("sites") + else: + self.frame_targets.discard("sites") + self._refresh_annotations() + self._request_update() + + @cb_body_frames.on_update + def _(_) -> None: + if cb_body_frames.value: + self.frame_targets.add("bodies") + else: + self.frame_targets.discard("bodies") + self._refresh_annotations() + self._request_update() + + @frame_scale_input.on_update + def _(_) -> None: + self.frame_scale = frame_scale_input.value + if self.frame_targets: + self._refresh_annotations() + self._request_update() + + @cb_contact_points.on_update + def _(_) -> None: + self.show_contact_points = cb_contact_points.value + self._sync_visibilities() + self._request_update() + + @cb_contact_forces.on_update + def _(_) -> None: + self.show_contact_forces = cb_contact_forces.value + self._sync_visibilities() + self._request_update() + + @contact_scale_input.on_update + def _(_) -> None: + self.meansize_override = contact_scale_input.value + self._request_update() + + def _clear_annotations(self) -> None: + """Remove all annotation handles (labels and frames).""" + for handle in self._label_handles.values(): + handle.remove() + self._label_handles.clear() + + for handle in self._frame_handles.values(): + handle.remove() + self._frame_handles.clear() + + def _refresh_annotations(self) -> None: + """Recreate all annotations based on current label_targets and frame_targets.""" + self._clear_annotations() + + if "sites" in self.label_targets: + for site_id in range(self.mj_model.nsite): + site_group = self.mj_model.site_group[site_id] + if site_group >= _NUM_GEOM_GROUPS or not self.site_groups_visible[site_group]: + continue + site_name = mj_id2name(self.mj_model, mjtObj.mjOBJ_SITE, site_id) + if not site_name: + site_name = f"site_{site_id}" + label = self.server.scene.add_label( + f"/annotations/labels/site_{site_id}", + site_name, + wxyz=(1.0, 0.0, 0.0, 0.0), + position=(0.0, 0.0, 0.0), + ) + self._label_handles[f"site_{site_id}"] = label + + if "bodies" in self.label_targets: + for body_id in range(self.mj_model.nbody): + if is_fixed_body(self.mj_model, body_id): + continue + body_name = get_body_name(self.mj_model, body_id) + label = self.server.scene.add_label( + f"/annotations/labels/body_{body_id}", + body_name, + wxyz=(1.0, 0.0, 0.0, 0.0), + position=(0.0, 0.0, 0.0), + ) + self._label_handles[f"body_{body_id}"] = label + + if "sites" in self.frame_targets: + self._create_frame_handles("sites") + + if "bodies" in self.frame_targets: + self._create_frame_handles("bodies") + + def _create_frame_handles(self, target: str) -> None: + """Create coordinate frame visualization handles for the given target type.""" + meansize = self.mj_model.stat.meansize + frame_length = self.mj_model.vis.scale.framelength * meansize * self.frame_scale + frame_width = self.mj_model.vis.scale.framewidth * meansize * self.frame_scale + + if target == "sites": + for site_id in range(self.mj_model.nsite): + site_group = self.mj_model.site_group[site_id] + if site_group >= _NUM_GEOM_GROUPS or not self.site_groups_visible[site_group]: + continue + key = f"site_frame_{site_id}" + handle = self.server.scene.add_frame( + f"/annotations/frames/{key}", + axes_length=frame_length, + axes_radius=frame_width, + ) + self._frame_handles[key] = handle + else: + for body_id in range(self.mj_model.nbody): + if is_fixed_body(self.mj_model, body_id): + continue + key = f"body_frame_{body_id}" + handle = self.server.scene.add_frame( + f"/annotations/frames/{key}", + axes_length=frame_length, + axes_radius=frame_width, + ) + self._frame_handles[key] = handle + + def _update_annotations( + self, + body_xpos: np.ndarray, + body_xmat: np.ndarray, + env_idx: int, + scene_offset: np.ndarray, + mj_data: mujoco.MjData | None, + ) -> None: + """Update positions of all annotations for the selected environment.""" + if not self.label_targets and not self.frame_targets: + return + + if "sites" in self.label_targets and mj_data is not None: + for site_id in range(self.mj_model.nsite): + key = f"site_{site_id}" + if key not in self._label_handles: + continue + site_world_pos = mj_data.site(site_id).xpos + self._label_handles[key].position = site_world_pos + scene_offset + + if "bodies" in self.label_targets: + for body_id in range(self.mj_model.nbody): + key = f"body_{body_id}" + if key not in self._label_handles: + continue + body_pos = body_xpos[env_idx, body_id, :] + scene_offset + self._label_handles[key].position = body_pos + + if "sites" in self.frame_targets: + self._update_frame_positions( + "sites", body_xpos, body_xmat, env_idx, scene_offset, mj_data + ) + + if "bodies" in self.frame_targets: + self._update_frame_positions( + "bodies", body_xpos, body_xmat, env_idx, scene_offset, mj_data + ) + + def _update_frame_positions( + self, + target: str, + body_xpos: np.ndarray, + body_xmat: np.ndarray, + env_idx: int, + scene_offset: np.ndarray, + mj_data: mujoco.MjData | None, + ) -> None: + """Update frame handle positions and orientations.""" + body_xquat = vtf.SO3.from_matrix(body_xmat).wxyz + + if target == "sites" and mj_data is not None: + for site_id in range(self.mj_model.nsite): + key = f"site_frame_{site_id}" + if key not in self._frame_handles: + continue + + site_world_pos = mj_data.site(site_id).xpos + scene_offset + site_world_mat = mj_data.site(site_id).xmat.reshape(3, 3) + site_world_quat = vtf.SO3.from_matrix(site_world_mat).wxyz + + handle = self._frame_handles[key] + handle.position = site_world_pos + handle.wxyz = site_world_quat + + elif target == "bodies": + for body_id in range(self.mj_model.nbody): + key = f"body_frame_{body_id}" + if key not in self._frame_handles: + continue + + body_pos = body_xpos[env_idx, body_id, :] + scene_offset + body_quat = body_xquat[env_idx, body_id, :] + + handle = self._frame_handles[key] + handle.position = body_pos + handle.wxyz = body_quat def _extract_contacts_from_mjdata(self, mj_data: mujoco.MjData) -> list[_Contact]: """Extract contact data from given MuJoCo data.""" @@ -739,11 +948,9 @@ def _update_contact_visualization( if not contact.included: continue - # Transform force from contact frame to world frame. force_world = contact.frame.T @ contact.force force_mag = np.linalg.norm(force_world) - # Contact point visualization (cylinder). if self.show_contact_points: contact_points.append( _ContactPointVisual( @@ -761,7 +968,6 @@ def _update_contact_visualization( ) ) - # Contact force visualization (arrow shaft + head). if self.show_contact_forces and force_mag > 1e-6: force_dir = force_world / force_mag arrow_length = ( @@ -785,11 +991,15 @@ def _update_contact_visualization( ) ) - # Update or create contact point handle. if contact_points: - positions = np.array([p.position for p in contact_points], dtype=np.float32) - orientations = np.array([p.orientation for p in contact_points], dtype=np.float32) - scales = np.array([p.scale for p in contact_points], dtype=np.float32) + n_points = len(contact_points) + positions = np.empty((n_points, 3), dtype=np.float32) + orientations = np.empty((n_points, 4), dtype=np.float32) + scales = np.empty((n_points, 3), dtype=np.float32) + for i, p in enumerate(contact_points): + positions[i] = p.position + orientations[i] = p.orientation + scales[i] = p.scale if self.contact_point_handle is None: mesh = trimesh.creation.cylinder(radius=1.0, height=1.0) self.contact_point_handle = self.server.scene.add_batched_meshes_simple( @@ -812,22 +1022,21 @@ def _update_contact_visualization( elif self.contact_point_handle is not None: self.contact_point_handle.visible = False - # Update or create contact force handles (shaft and head separately). if contact_forces: - shaft_positions = np.array( - [f.shaft_position for f in contact_forces], dtype=np.float32 - ) - shaft_orientations = np.array( - [f.shaft_orientation for f in contact_forces], dtype=np.float32 - ) - shaft_scales = np.array([f.shaft_scale for f in contact_forces], dtype=np.float32) - head_positions = np.array( - [f.head_position for f in contact_forces], dtype=np.float32 - ) - head_orientations = np.array( - [f.head_orientation for f in contact_forces], dtype=np.float32 - ) - head_scales = np.array([f.head_scale for f in contact_forces], dtype=np.float32) + n_forces = len(contact_forces) + shaft_positions = np.empty((n_forces, 3), dtype=np.float32) + shaft_orientations = np.empty((n_forces, 4), dtype=np.float32) + shaft_scales = np.empty((n_forces, 3), dtype=np.float32) + head_positions = np.empty((n_forces, 3), dtype=np.float32) + head_orientations = np.empty((n_forces, 4), dtype=np.float32) + head_scales = np.empty((n_forces, 3), dtype=np.float32) + for i, f in enumerate(contact_forces): + shaft_positions[i] = f.shaft_position + shaft_orientations[i] = f.shaft_orientation + shaft_scales[i] = f.shaft_scale + head_positions[i] = f.head_position + head_orientations[i] = f.head_orientation + head_scales[i] = f.head_scale if self.contact_force_shaft_handle is None: shaft_mesh = trimesh.creation.cylinder(radius=0.4, height=1.0) shaft_mesh.apply_translation([0, 0, 0.5]) @@ -876,11 +1085,6 @@ def _update_contact_visualization( self.contact_force_head_handle.visible ) = False - # ============================================================================ - # DebugVisualizer Protocol Implementation - # ============================================================================ - - @override def add_arrow( self, start: np.ndarray | torch.Tensor, @@ -889,15 +1093,11 @@ def add_arrow( width: float = 0.015, label: str | None = None, ) -> None: - """Queue an arrow for batched rendering. - - Arrows are not rendered immediately but queued and rendered together - in the next update() call for efficiency. - """ + """Queue an arrow for batched rendering.""" if not self.debug_visualization_enabled: return - del label # Unused. + del label if isinstance(start, torch.Tensor): start = start.cpu().numpy() if isinstance(end, torch.Tensor): @@ -909,10 +1109,8 @@ def add_arrow( if length < 1e-6: return - # Queue the arrow for batched rendering (without scene offset - applied during sync) self._queued_arrows.append((start, end, color, width)) - @override def add_ghost_mesh( self, qpos: np.ndarray | torch.Tensor, @@ -920,32 +1118,20 @@ def add_ghost_mesh( alpha: float = 0.5, label: str | None = None, ) -> None: - """Add a ghost mesh by rendering the robot at a different pose. - - For Viser, we create meshes once and update their poses for efficiency. - - Args: - qpos: Joint positions for the ghost pose - model: MuJoCo model with pre-configured appearance (geom_rgba for colors) - alpha: Transparency override - label: Optional label for this ghost - """ + """Add a ghost mesh by rendering the robot at a different pose.""" if not self.debug_visualization_enabled: return if isinstance(qpos, torch.Tensor): qpos = qpos.cpu().numpy() - # Use model hash to support models with same structure but different colors model_hash = hash((model.ngeom, model.nbody, model.nq)) self._viz_data.qpos[:] = qpos mujoco.mj_forward(model, self._viz_data) - # Use current scene offset scene_offset = self._scene_offset - # Group geoms by body body_geoms: dict[int, list[int]] = {} for i in range(model.ngeom): body_id = model.geom_bodyid[i] @@ -960,18 +1146,15 @@ def add_ghost_mesh( body_geoms[body_id] = [] body_geoms[body_id].append(i) - # Update or create mesh for each body for body_id, geom_indices in body_geoms.items(): body_pos = self._viz_data.xpos[body_id] + scene_offset body_quat = self._mat_to_quat(self._viz_data.xmat[body_id].reshape(3, 3)) - # Check if we already have a handle for this body if body_id in self._ghost_handles: handle = self._ghost_handles[body_id] handle.wxyz = body_quat handle.position = body_pos else: - # Create mesh if not cached if model_hash not in self._ghost_meshes: self._ghost_meshes[model_hash] = {} @@ -1002,9 +1185,8 @@ def add_ghost_mesh( body_name = get_body_name(model, body_id) handle_name = f"/debug/env_{self.env_idx}/ghost/body_{body_name}" - # Extract color from geom (convert RGBA 0-1 to RGB 0-255) rgba = model.geom_rgba[geom_indices[0]].copy() - color_uint8 = (rgba[:3] * 255).astype(np.uint8) + color_uint8 = rgba_to_uint8(rgba[:3]) handle = self.server.scene.add_mesh_simple( handle_name, @@ -1019,7 +1201,6 @@ def add_ghost_mesh( ) self._ghost_handles[body_id] = handle - @override def add_frame( self, position: np.ndarray | torch.Tensor, @@ -1030,26 +1211,11 @@ def add_frame( alpha: float = 1.0, axis_colors: tuple[tuple[float, float, float], ...] | None = None, ) -> None: - """Add a coordinate frame visualization with RGB-colored axes. - - This implementation reuses add_arrow to draw the three axis arrows. - - Args: - position: Position of the frame origin (3D vector) - rotation_matrix: Rotation matrix (3x3) - scale: Scale/length of the axis arrows - label: Optional label for this frame. - axis_radius: Radius of the axis arrows. - alpha: Opacity for all axes (0=transparent, 1=opaque). Note: This implementation - does not support per-arrow transparency. All arrows in the scene will share - the same alpha value. - axis_colors: Optional tuple of 3 RGB colors for X, Y, Z axes. If None, uses - default RGB coloring (X=red, Y=green, Z=blue). - """ + """Add a coordinate frame visualization with RGB-colored axes.""" if not self.debug_visualization_enabled: return - del label # Unused. + del label if isinstance(position, torch.Tensor): position = position.cpu().numpy() @@ -1071,7 +1237,6 @@ def add_frame( width=axis_radius, ) - @override def add_sphere( self, center: np.ndarray | torch.Tensor, @@ -1100,7 +1265,6 @@ def add_sphere( # Queue the sphere for batched rendering self._queued_spheres.append((center.copy(), radius, color)) - @override def add_cylinder( self, start: np.ndarray | torch.Tensor, @@ -1133,7 +1297,6 @@ def add_cylinder( # Queue the cylinder for batched rendering self._queued_cylinders.append((start.copy(), end.copy(), radius, color)) - @override def clear(self) -> None: """Clear all debug visualizations. @@ -1146,13 +1309,9 @@ def clear(self) -> None: self._queued_cylinders.clear() def clear_debug_all(self) -> None: - """Clear all debug visualizations including ghosts. - - Called when switching to a different environment or disabling debug visualization. - """ + """Clear all debug visualizations including ghosts.""" self.clear() - # Remove arrow meshes if self._arrow_shaft_handle is not None: self._arrow_shaft_handle.remove() self._arrow_shaft_handle = None @@ -1178,15 +1337,7 @@ def clear_debug_all(self) -> None: def _create_geom_mesh_from_model( self, mj_model: mujoco.MjModel, geom_id: int ) -> trimesh.Trimesh | None: - """Create a trimesh from a MuJoCo geom using the specified model. - - Args: - mj_model: MuJoCo model containing geom definition - geom_id: Index of the geom to create mesh for - - Returns: - Trimesh representation of the geom, or None if unsupported type - """ + """Create a trimesh from a MuJoCo geom using the specified model.""" geom_type = mj_model.geom_type[geom_id] if geom_type == mjtGeom.mjGEOM_MESH: @@ -1195,16 +1346,11 @@ def _create_geom_mesh_from_model( return create_primitive_mesh(mj_model, geom_id) def _sync_arrows(self) -> None: - """Render all queued arrows using batched meshes. - - This should be called after all debug visualizations have been queued - for the current frame. - """ + """Render all queued arrows using batched meshes.""" if not self.debug_visualization_enabled: return if not self._queued_arrows: - # Remove arrow meshes if no arrows to render if self._arrow_shaft_handle is not None: self._arrow_shaft_handle.remove() self._arrow_shaft_handle = None @@ -1213,19 +1359,14 @@ def _sync_arrows(self) -> None: self._arrow_head_handle = None return - # Create arrow mesh components if needed (unit-sized base meshes) if self._arrow_shaft_mesh is None: - # Unit cylinder: radius=1.0, height=1.0 self._arrow_shaft_mesh = trimesh.creation.cylinder(radius=1.0, height=1.0) - self._arrow_shaft_mesh.apply_translation(np.array([0, 0, 0.5])) # Center at z=0.5 + self._arrow_shaft_mesh.apply_translation(np.array([0, 0, 0.5])) if self._arrow_head_mesh is None: - # Unit cone: radius=2.0, height=1.0 (base at z=0, tip at z=1.0 by default) head_width = 2.0 self._arrow_head_mesh = trimesh.creation.cone(radius=head_width, height=1.0) - # No translation needed - cone already has base at z=0 - # Prepare batched data num_arrows = len(self._queued_arrows) shaft_positions = np.zeros((num_arrows, 3), dtype=np.float32) shaft_wxyzs = np.zeros((num_arrows, 4), dtype=np.float32) @@ -1238,12 +1379,8 @@ def _sync_arrows(self) -> None: head_colors = np.zeros((num_arrows, 3), dtype=np.uint8) z_axis = np.array([0, 0, 1]) - shaft_length_ratio = 0.8 - head_length_ratio = 0.2 - # Apply scene offset to all arrows for i, (start, end, color, width) in enumerate(self._queued_arrows): - # Apply scene offset start_offset = start + self._scene_offset end_offset = end + self._scene_offset @@ -1253,25 +1390,19 @@ def _sync_arrows(self) -> None: rotation_quat = rotation_quat_from_vectors(z_axis, direction) - # Shaft: scale width in XY, length in Z - shaft_length = shaft_length_ratio * length + shaft_length = _ARROW_SHAFT_LENGTH_RATIO * length shaft_positions[i] = start_offset shaft_wxyzs[i] = rotation_quat - shaft_scales[i] = [width, width, shaft_length] # Per-axis scale - shaft_colors[i] = (np.array(color[:3]) * 255).astype(np.uint8) - - # Head: position at end of shaft - # The cone has its base at z=0, so after scaling by head_length, - # the base is still at z=0 in local coords - # We want the base at the end of the shaft (at shaft_length) - head_length = head_length_ratio * length + shaft_scales[i] = [width, width, shaft_length] + shaft_colors[i] = rgba_to_uint8(np.array(color[:3])) + + head_length = _ARROW_HEAD_LENGTH_RATIO * length head_position = start_offset + direction * shaft_length head_positions[i] = head_position head_wxyzs[i] = rotation_quat - head_scales[i] = [width, width, head_length] # Per-axis scale - head_colors[i] = (np.array(color[:3]) * 255).astype(np.uint8) + head_scales[i] = [width, width, head_length] + head_colors[i] = rgba_to_uint8(np.array(color[:3])) - # Check if we need to recreate handles (number of arrows changed) needs_recreation = ( self._arrow_shaft_handle is None or self._arrow_head_handle is None @@ -1279,13 +1410,11 @@ def _sync_arrows(self) -> None: ) if needs_recreation: - # Remove old handles if self._arrow_shaft_handle is not None: self._arrow_shaft_handle.remove() if self._arrow_head_handle is not None: self._arrow_head_handle.remove() - # Create new batched meshes self._arrow_shaft_handle = self.server.scene.add_batched_meshes_simple( f"/debug/env_{self.env_idx}/arrow_shafts", self._arrow_shaft_mesh.vertices, @@ -1310,7 +1439,6 @@ def _sync_arrows(self) -> None: receive_shadow=False, ) else: - # Update existing handles (guaranteed to exist by needs_recreation check) assert self._arrow_shaft_handle is not None assert self._arrow_head_handle is not None diff --git a/src/mjlab/viewer/viser/viewer.py b/src/mjlab/viewer/viser/viewer.py index 832d203b5..459dfb77a 100644 --- a/src/mjlab/viewer/viser/viewer.py +++ b/src/mjlab/viewer/viser/viewer.py @@ -40,7 +40,6 @@ def setup(self) -> None: self._counter = 0 self._needs_update = False - # Create ViserMujocoScene for all 3D visualization (with debug visualization enabled). self._scene = ViserMujocoScene.create( server=self._server, mj_model=sim.mj_model, @@ -48,22 +47,17 @@ def setup(self) -> None: ) self._scene.env_idx = self.cfg.env_idx - self._scene.debug_visualization_enabled = ( - True # Enable debug visualization by default - ) + self._scene.debug_visualization_enabled = True + + self._scene.create_env_selector_gui() - # Create tab group. tabs = self._server.gui.add_tab_group() - # Main tab with simulation controls and display settings. with tabs.add_tab("Controls", icon=viser.Icon.SETTINGS): - # Status display. with self._server.gui.add_folder("Info"): self._status_html = self._server.gui.add_html("") - # Simulation controls. with self._server.gui.add_folder("Simulation"): - # Play/Pause button. self._pause_button = self._server.gui.add_button( "Play" if self._is_paused else "Pause", icon=viser.Icon.PLAYER_PLAY if self._is_paused else viser.Icon.PLAYER_PAUSE, @@ -79,7 +73,6 @@ def _(_) -> None: self._update_status_display() self._needs_update = True - # Reset button. reset_button = self._server.gui.add_button("Reset Environment") @reset_button.on_click @@ -88,7 +81,6 @@ def _(_) -> None: self._update_status_display() self._needs_update = True - # Speed controls. speed_buttons = self._server.gui.add_button_group( "Speed", options=["Slower", "Faster"], @@ -102,7 +94,6 @@ def _(event) -> None: self.increase_speed() self._update_status_display() - # Add standard visualization options from ViserMujocoScene (Environment, Visualization, Contacts, Camera Tracking, Debug Visualization). self._scene.create_visualization_gui( camera_distance=self.cfg.distance, camera_azimuth=self.cfg.azimuth, @@ -111,10 +102,8 @@ def _(event) -> None: self._prev_env_idx = self._scene.env_idx - # Reward plots tab. if hasattr(self.env.unwrapped, "reward_manager"): with tabs.add_tab("Rewards", icon=viser.Icon.CHART_LINE): - # Get reward term names and create reward plotter. term_names = [ name for name, _ in self.env.unwrapped.reward_manager.get_active_iterable_terms( @@ -123,8 +112,8 @@ def _(event) -> None: ] self._reward_plotter = ViserRewardPlotter(self._server, term_names) - # Geom groups tab. - self._scene.create_geom_groups_gui(tabs) + self._scene.create_groups_gui(tabs) + self._scene.create_overlays_gui(tabs) @override def sync_env_to_viewer(self) -> None: @@ -138,7 +127,6 @@ def sync_env_to_viewer(self) -> None: self._prev_env_idx = self._scene.env_idx if self._reward_plotter: self._reward_plotter.clear_histories() - # Clear debug visualizations when switching environments if self._scene.debug_visualization_enabled: self._scene.clear_debug_all() @@ -150,11 +138,10 @@ def sync_env_to_viewer(self) -> None: ) self._reward_plotter.update(terms) - # Update debug visualizations if enabled if self._scene.debug_visualization_enabled and hasattr( self.env.unwrapped, "update_visualizers" ): - self._scene.clear() # Clear queued arrows from previous frame + self._scene.clear() self.env.unwrapped.update_visualizers(self._scene) if self._counter % 2 != 0: