diff --git a/pyproject.toml b/pyproject.toml index baf55931f..0a08261f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,6 +200,7 @@ ignore = [ channels = ["conda-forge"] platforms = ["linux-64", "linux-aarch64", "osx-arm64", "osx-64"] requires-pixi = ">=0.39.0" +preview = ["pixi-build"] [tool.pixi.environments] # We resolve only two groups: cpu and gpu. diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 14392e82a..2c2ca3e81 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -916,23 +916,29 @@ class LinkParametrizableShape: Box: ClassVar[int] = 0 Cylinder: ClassVar[int] = 1 Sphere: ClassVar[int] = 2 + Mesh: ClassVar[int] = 3 -@jax_dataclasses.pytree_dataclass +@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class HwLinkMetadata(JaxsimDataclass): """ Class storing the hardware parameters of a link. Attributes: link_shape: The shape of the link. - 0 = box, 1 = cylinder, 2 = sphere, -1 = unsupported. - geometry: The dimensions of the link. - box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0]. + 0 = box, 1 = cylinder, 2 = sphere, 3 = mesh, -1 = unsupported. + geometry: Shape parameters used by HW parametrization. + box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0], + mesh: cumulative anisotropic scale factors [sx,sy,sz] (initialized to [1,1,1]). density: The density of the link. L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G. L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame. L_H_pre_mask: The mask indicating the link's child joint indices. L_H_pre: The homogeneous transforms for child joints. + mesh_vertices: The original centered mesh vertices (Nx3) for mesh shapes, None otherwise. + mesh_faces: The mesh triangle faces (Mx3 integer indices) for mesh shapes, None otherwise. + mesh_offset: The original mesh centroid offset (3D vector) for mesh shapes, None otherwise. + mesh_uri: The path to the mesh file for reference, None otherwise. """ link_shape: jtp.Vector @@ -942,6 +948,10 @@ class HwLinkMetadata(JaxsimDataclass): L_H_vis: jtp.Matrix L_H_pre_mask: jtp.Vector L_H_pre: jtp.Matrix + mesh_vertices: Static[tuple[HashedNumpyArray | None, ...] | None] + mesh_faces: Static[tuple[HashedNumpyArray | None, ...] | None] + mesh_offset: Static[tuple[HashedNumpyArray | None, ...] | None] + mesh_uri: Static[tuple[str | None, ...] | None] @classmethod def empty(cls) -> HwLinkMetadata: @@ -954,7 +964,83 @@ def empty(cls) -> HwLinkMetadata: L_H_vis=jnp.array([], dtype=float), L_H_pre_mask=jnp.array([], dtype=bool), L_H_pre=jnp.array([], dtype=float), + mesh_vertices=None, + mesh_faces=None, + mesh_offset=None, + mesh_uri=None, + ) + + @staticmethod + def compute_mesh_inertia( + vertices: jtp.Matrix, faces: jtp.Matrix, density: jtp.Float + ) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]: + """ + Compute mass, center of mass, and inertia tensor from mesh geometry. + + Uses the divergence theorem to compute volumetric properties by integrating + over tetrahedra formed between the mesh surface and the origin. + + Args: + vertices: Mesh vertices (Nx3) in the link frame, should be centered. + faces: Triangle face indices (Mx3), integer indices into vertices array. + density: Material density. + + Returns: + A tuple containing the computed mass, the CoM position and the 3x3 + inertia tensor at the CoM. + """ + + # Extract triangles from vertices using face indices + triangles = vertices[faces.astype(int)] + A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2] + + # Compute signed volume of tetrahedra relative to origin + # vol = 1/6 * (A . (B x C)) + tetrahedron_volumes = jnp.sum(A * jnp.cross(B, C), axis=1) / 6.0 + + total_signed_volume = jnp.sum(tetrahedron_volumes) + + # Normalize the global winding sign so positive density yields non-negative mass. + orientation_sign = jnp.where(total_signed_volume < 0, -1.0, 1.0) + tetrahedron_volumes = tetrahedron_volumes * orientation_sign + total_volume = jnp.sum(tetrahedron_volumes) + + eps = jnp.asarray(1e-12, dtype=total_volume.dtype) + is_valid_volume = jnp.abs(total_volume) > eps + safe_total_volume = jnp.where(is_valid_volume, total_volume, 1.0) + mass = jnp.where(is_valid_volume, total_volume * density, 0.0) + + # Compute center of mass + tet_coms = (A + B + C) / 4.0 + com_position = jnp.where( + is_valid_volume, + jnp.sum(tet_coms * tetrahedron_volumes[:, None], axis=0) + / safe_total_volume, + jnp.zeros(3, dtype=vertices.dtype), + ) + + # Compute inertia tensor with covariance approach + def compute_tetrahedron_covariance(a, b, c, vol): + s = a + b + c + return (vol / 20.0) * ( + jnp.outer(a, a) + jnp.outer(b, b) + jnp.outer(c, c) + jnp.outer(s, s) + ) + + covariance_matrices = jax.vmap(compute_tetrahedron_covariance)( + A, B, C, tetrahedron_volumes ) + Σ_origin = jnp.sum(covariance_matrices, axis=0) + + # Shift to CoM using parallel axis theorem + Σ_com = Σ_origin * density - mass * jnp.outer(com_position, com_position) + + # Convert covariance to inertia tensor + I_com = jnp.trace(Σ_com) * jnp.eye(3, dtype=vertices.dtype) - Σ_com + I_com = jnp.where( + is_valid_volume, I_com, jnp.zeros((3, 3), dtype=vertices.dtype) + ) + + return mass, com_position, I_com @staticmethod def compute_mass_and_inertia( @@ -1015,16 +1101,90 @@ def sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]: return mass, inertia - def compute_mass_inertia(shape_idx, dims, density): - return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density) + def compute_mass_inertia_primitive(shape_idx, dims, density): + def unsupported_case(_): + return ( + jnp.asarray(0.0, dtype=density.dtype), + jnp.zeros((3, 3), dtype=density.dtype), + ) - mass, inertia = jax.vmap(compute_mass_inertia)( - hw_link_metadata.link_shape, - hw_link_metadata.geometry, - hw_link_metadata.density, + def supported_case(idx): + return jax.lax.switch(idx, (box, cylinder, sphere), dims, density) + + return jax.lax.cond( + shape_idx < 0, unsupported_case, supported_case, shape_idx + ) + + # For models with mesh data, we need to handle Static heterogeneous mesh arrays + has_mesh_data = ( + hw_link_metadata.mesh_vertices is not None + and hw_link_metadata.mesh_faces is not None ) - return mass, inertia + if has_mesh_data: + mesh_verts_tuple = hw_link_metadata.mesh_vertices + mesh_faces_tuple = hw_link_metadata.mesh_faces + n_links = len(mesh_verts_tuple) + + # Build per-link compute functions that capture mesh data in closures + # This loop runs once at trace time to build the computation graph + compute_fns = [] + for i in range(n_links): + if mesh_verts_tuple[i] is not None: + # Capture this link's mesh data in the closure + verts_data = jnp.array(mesh_verts_tuple[i].get()) + faces_data = jnp.array(mesh_faces_tuple[i].get()) + + def make_fn(verts, faces): + def link_fn(shape, dims, density): + def mesh_branch(): + scaled_vertices = verts * dims + mass_m, _, inertia_m = ( + HwLinkMetadata.compute_mesh_inertia( + scaled_vertices, faces, density + ) + ) + return mass_m, inertia_m + + def primitive_branch(): + return compute_mass_inertia_primitive( + shape, dims, density + ) + + return jax.lax.cond( + shape == LinkParametrizableShape.Mesh, + mesh_branch, + primitive_branch, + ) + + return link_fn + + compute_fns.append(make_fn(verts_data, faces_data)) + else: + # No mesh data for this link + def link_fn(shape, dims, density): + return compute_mass_inertia_primitive(shape, dims, density) + + compute_fns.append(link_fn) + + def compute_single_link(link_idx, shape, dims, density): + return jax.lax.switch(link_idx, compute_fns, shape, dims, density) + + masses, inertias = jax.vmap(compute_single_link)( + jnp.arange(n_links), + hw_link_metadata.link_shape, + hw_link_metadata.geometry, + hw_link_metadata.density, + ) + else: + # No mesh data - pure primitives with simple vmap + masses, inertias = jax.vmap(compute_mass_inertia_primitive)( + hw_link_metadata.link_shape, + hw_link_metadata.geometry, + hw_link_metadata.density, + ) + + return masses, inertias @staticmethod def _convert_scaling_to_3d_vector( @@ -1034,7 +1194,7 @@ def _convert_scaling_to_3d_vector( Convert scaling factors for specific shape dimensions into a 3D scaling vector. Args: - link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder). + link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder, mesh). scaling_factors: The scaling factors for the shape dimensions. Returns: @@ -1045,17 +1205,20 @@ def _convert_scaling_to_3d_vector( - Box: [lx, ly, lz] - Cylinder: [r, r, l] - Sphere: [r, r, r] + - Mesh: [sx, sy, sz] """ # Index mapping for each shape type (link_shapes x 3 dims) # Box: [lx, ly, lz] -> [0, 1, 2] # Cylinder: [r, r, l] -> [0, 0, 1] # Sphere: [r, r, r] -> [0, 0, 0] + # Mesh: [sx, sy, sz] -> [0, 1, 2] shape_indices = jnp.array( [ [0, 1, 2], # Box [0, 0, 1], # Cylinder [0, 0, 0], # Sphere + [0, 1, 2], # Mesh ] ) @@ -1117,9 +1280,19 @@ def box(parent_idx, L_p_C): ] ) + def mesh(parent_idx, L_p_C): + sx, sy, sz = scaling_factors.dims[parent_idx] + return jnp.hstack( + [ + L_p_C[0] * sx, + L_p_C[1] * sy, + L_p_C[2] * sz, + ] + ) + new_positions = jax.vmap( lambda shape_idx, parent_idx, L_p_C: jax.lax.switch( - shape_idx, (box, cylinder, sphere), parent_idx, L_p_C + shape_idx, (box, cylinder, sphere, mesh), parent_idx, L_p_C ) )( parent_link_shapes, diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 2b88ed212..f40ebd2d5 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -30,6 +30,7 @@ from jaxsim.parsers.descriptions import ModelDescription from jaxsim.parsers.descriptions.joint import JointDescription from jaxsim.parsers.descriptions.link import LinkDescription +from jaxsim.parsers.rod.utils import prepare_mesh_for_parametrization from jaxsim.utils import JaxsimDataclass, Mutability, wrappers from .common import VelRepr @@ -385,6 +386,10 @@ def compute_hw_link_metadata( L_H_vises = [] L_H_pre_masks = [] L_H_pre = [] + mesh_vertices = [] + mesh_faces = [] + mesh_offsets = [] + mesh_uris = [] # Process each link, only parametrizing those in parametrized_links if provided for link_description in ordered_links: @@ -399,6 +404,10 @@ def compute_hw_link_metadata( L_H_vises.append(jnp.eye(4)) L_H_pre_masks.append([0] * self.number_of_joints()) L_H_pre.append([jnp.eye(4)] * self.number_of_joints()) + mesh_vertices.append(None) + mesh_faces.append(None) + mesh_offsets.append(None) + mesh_uris.append(None) continue rod_link = rod_links_dict.get(link_name) @@ -432,7 +441,8 @@ def compute_hw_link_metadata( v for v in rod_link.visuals() if isinstance( - v.geometry.geometry(), (rod.Box, rod.Sphere, rod.Cylinder) + v.geometry.geometry(), + (rod.Box, rod.Sphere, rod.Cylinder, rod.Mesh), ) ), None, @@ -444,21 +454,75 @@ def compute_hw_link_metadata( geometry = ( supported_visual.geometry.geometry() if supported_visual else None ) + if isinstance(geometry, rod.Box): lx, ly, lz = geometry.size density = mass / (lx * ly * lz) geom = [lx, ly, lz] shape = LinkParametrizableShape.Box + mesh_vertices.append(None) + mesh_faces.append(None) + mesh_offsets.append(None) + mesh_uris.append(None) elif isinstance(geometry, rod.Sphere): r = geometry.radius density = mass / (4 / 3 * jnp.pi * r**3) geom = [r, 0, 0] shape = LinkParametrizableShape.Sphere + mesh_vertices.append(None) + mesh_faces.append(None) + mesh_offsets.append(None) + mesh_uris.append(None) elif isinstance(geometry, rod.Cylinder): r, l = geometry.radius, geometry.length density = mass / (jnp.pi * r**2 * l) geom = [r, l, 0] shape = LinkParametrizableShape.Cylinder + mesh_vertices.append(None) + mesh_faces.append(None) + mesh_offsets.append(None) + mesh_uris.append(None) + elif isinstance(geometry, rod.Mesh): + # Load and prepare mesh for parametric scaling + try: + + mesh_data = prepare_mesh_for_parametrization( + mesh_uri=geometry.uri, + scale=geometry.scale, + ) + + density = ( + mass / mesh_data["volume"] if mesh_data["volume"] > 0 else 0.0 + ) + + # For meshes, store cumulative scale factors (initially 1.0) in geometry + # instead of bounding box extents. This allows proper multiplicative scaling. + geom = [1.0, 1.0, 1.0] + shape = LinkParametrizableShape.Mesh + + # Store mesh data + mesh_vertices.append(mesh_data["vertices"]) + mesh_faces.append(mesh_data["faces"]) + mesh_offsets.append(mesh_data["offset"]) + mesh_uris.append(mesh_data["uri"]) + + logging.info( + f"Loaded mesh for link '{link_name}': " + f"{len(mesh_data['vertices'])} vertices, " + f"{len(mesh_data['faces'])} faces, " + ) + except Exception as e: + logging.warning( + f"Failed to load mesh for link '{link_name}': {e}. " + f"Marking as unsupported." + ) + density = 0.0 + geom = [0, 0, 0] + shape = LinkParametrizableShape.Unsupported + mesh_vertices.append(None) + mesh_faces.append(None) + mesh_offsets.append(None) + mesh_uris.append(None) else: logging.debug( f"Skipping link '{link_name}' for hardware parametrization due to unsupported geometry." @@ -466,6 +530,10 @@ def compute_hw_link_metadata( density = 0.0 geom = [0, 0, 0] shape = LinkParametrizableShape.Unsupported + mesh_vertices.append(None) + mesh_faces.append(None) + mesh_offsets.append(None) + mesh_uris.append(None) inertial_pose = ( rod_link.inertial.pose.transform() if rod_link else jnp.eye(4) @@ -501,6 +569,13 @@ def compute_hw_link_metadata( return HwLinkMetadata.empty() # Stack collected data into JAX arrays + # Handle L_H_pre specially: ensure shape (n_links, n_joints, 4, 4) even when n_joints=0 + L_H_pre_array = jnp.array(L_H_pre, dtype=float) + if self.number_of_joints() == 0: + # Reshape from (n_links, 0) to (n_links, 0, 4, 4) + n_links = len(L_H_pre) + L_H_pre_array = L_H_pre_array.reshape(n_links, 0, 4, 4) + return HwLinkMetadata( link_shape=jnp.array(shapes, dtype=int), geometry=jnp.array(geoms, dtype=float), @@ -508,7 +583,34 @@ def compute_hw_link_metadata( L_H_G=jnp.array(L_H_Gs, dtype=float), L_H_vis=jnp.array(L_H_vises, dtype=float), L_H_pre_mask=jnp.array(L_H_pre_masks, dtype=bool), - L_H_pre=jnp.array(L_H_pre, dtype=float), + L_H_pre=L_H_pre_array, + mesh_vertices=( + tuple( + wrappers.HashedNumpyArray(array=v) if v is not None else None + for v in mesh_vertices + ) + if any(v is not None for v in mesh_vertices) + else None + ), + mesh_faces=( + tuple( + wrappers.HashedNumpyArray(array=f) if f is not None else None + for f in mesh_faces + ) + if any(f is not None for f in mesh_faces) + else None + ), + mesh_offset=( + tuple( + wrappers.HashedNumpyArray(array=o) if o is not None else None + for o in mesh_offsets + ) + if any(o is not None for o in mesh_offsets) + else None + ), + mesh_uri=( + tuple(mesh_uris) if any(u is not None for u in mesh_uris) else None + ), ) def export_updated_model(self) -> str: @@ -548,6 +650,85 @@ def export_updated_model(self) -> str: # Iterate over the hardware metadata to update the ROD model hw_metadata = self.kin_dyn_parameters.hw_link_metadata + reduced_link_names = set(self.link_names()) + reduced_joint_names = set(self.joint_names()) + unit_scale = np.ones(3, dtype=float) + link_scale_factors: dict[str, np.ndarray] = {} + + def collect_link_elements(link) -> list: + elements_to_update_raw = (link.visual, link.collision) + elements_to_update = [] + for entry in elements_to_update_raw: + if entry is None: + continue + if isinstance(entry, (list, tuple)): + elements_to_update.extend(e for e in entry if e is not None) + else: + elements_to_update.append(entry) + return elements_to_update + + def scale_pose_translation(element, scale_vector): + if getattr(element, "pose", None) is None: + return + transform = np.array(element.pose.transform(), dtype=float) + transform[0:3, 3] = scale_vector * transform[0:3, 3] + element.pose = rod.Pose.from_transform( + transform=transform, + relative_to=element.pose.relative_to, + ) + + def scale_link_elements( + elements_to_update: list, + scale_vector: np.ndarray, + *, + mesh_pose: rod.Pose | None = None, + mesh_shape_link: bool = False, + ) -> None: + for element in elements_to_update: + if ( + element is None + or not hasattr(element, "geometry") + or element.geometry is None + ): + continue + + geometry = element.geometry + if getattr(geometry, "box", None) is not None: + current_size = np.array(geometry.box.size, dtype=float) + geometry.box.size = tuple( + float(v) for v in (current_size * scale_vector).tolist() + ) + scale_pose_translation(element, scale_vector) + elif getattr(geometry, "sphere", None) is not None: + geometry.sphere.radius = float( + float(geometry.sphere.radius) * float(scale_vector[0]) + ) + scale_pose_translation(element, scale_vector) + elif getattr(geometry, "cylinder", None) is not None: + geometry.cylinder.radius = float( + float(geometry.cylinder.radius) * float(scale_vector[0]) + ) + geometry.cylinder.length = float( + float(geometry.cylinder.length) * float(scale_vector[2]) + ) + scale_pose_translation(element, scale_vector) + elif getattr(geometry, "mesh", None) is not None: + base_scale = ( + np.array(geometry.mesh.scale, dtype=float) + if geometry.mesh.scale is not None + else unit_scale + ) + geometry.mesh.scale = tuple( + float(v) for v in (base_scale * scale_vector).tolist() + ) + + # Mesh-parametrized reduced links use metadata to preserve + # the main visual placement in the exported URDF. + if mesh_shape_link and mesh_pose is not None: + element.pose = mesh_pose + else: + scale_pose_translation(element, scale_vector) + for link_index, link_name in enumerate(self.link_names()): if link_name not in links_dict: continue @@ -578,41 +759,73 @@ def export_updated_model(self) -> str: inertia_tensor=inertia_tensor, validate=True ) - # Update visuals and collisions - dims = hw_metadata.geometry[link_index] + dims = np.array(hw_metadata.geometry[link_index], dtype=float) + elements_to_update = collect_link_elements(links_dict[link_name]) + + def find_reference_geometry(attr: str, elements: list = elements_to_update): + for element in elements: + if ( + element is None + or not hasattr(element, "geometry") + or element.geometry is None + ): + continue + geometry = getattr(element.geometry, attr, None) + if geometry is not None: + return geometry + return None + + if shape == LinkParametrizableShape.Mesh: + scale_vector = dims + elif shape == LinkParametrizableShape.Box: + ref_box = find_reference_geometry("box") + if ref_box is None: + scale_vector = unit_scale + else: + base_size = np.array(ref_box.size, dtype=float) + scale_vector = np.divide( + dims, + base_size, + out=np.ones(3, dtype=float), + where=np.abs(base_size) > 1e-12, + ) + elif shape == LinkParametrizableShape.Sphere: + ref_sphere = find_reference_geometry("sphere") + base_radius = ( + float(ref_sphere.radius) if ref_sphere is not None else 1.0 + ) + s = float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0 + scale_vector = np.array([s, s, s], dtype=float) + elif shape == LinkParametrizableShape.Cylinder: + ref_cylinder = find_reference_geometry("cylinder") + base_radius = ( + float(ref_cylinder.radius) if ref_cylinder is not None else 1.0 + ) + base_length = ( + float(ref_cylinder.length) if ref_cylinder is not None else 1.0 + ) + s_radius = ( + float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0 + ) + s_length = ( + float(dims[1]) / base_length if abs(base_length) > 1e-12 else 1.0 + ) + scale_vector = np.array([s_radius, s_radius, s_length], dtype=float) + else: + scale_vector = unit_scale - elements_to_update = ( - links_dict[link_name].visual, - links_dict[link_name].collision, - ) + link_scale_factors[link_name] = np.array(scale_vector, dtype=float) element_pose = rod.Pose.from_transform( transform=np.array(hw_metadata.L_H_vis[link_index]), relative_to=link_name, ) - - for element in elements_to_update: - if element is None: - continue - - # Update geometry - if shape == LinkParametrizableShape.Box: - element.geometry.box.size = dims.tolist() - elif shape == LinkParametrizableShape.Sphere: - element.geometry.sphere.radius = float(dims[0]) - elif shape == LinkParametrizableShape.Cylinder: - element.geometry.cylinder.radius = float(dims[0]) - element.geometry.cylinder.length = float(dims[1]) - else: - # This branch should be unreachable. Unsupported shapes should be - # filtered out above. - raise RuntimeError( - f"Unexpected shape {shape} for link '{link_name}'. " - "This should never be hit." - ) - - # Update pose - element.pose = element_pose + scale_link_elements( + elements_to_update=elements_to_update, + scale_vector=scale_vector, + mesh_pose=element_pose, + mesh_shape_link=(shape == LinkParametrizableShape.Mesh), + ) # Update joint poses for joint_index in range(self.number_of_joints()): @@ -628,6 +841,51 @@ def export_updated_model(self) -> str: relative_to=link_name, ) + # Propagate link scaling to descendants connected through fixed joints. + # These links are typically reduced away in the JaxSim model (e.g. feet + # attached to ankles) but still exist in the exported URDF tree. + updated = True + while updated: + updated = False + for joint in joints_dict.values(): + if joint.type != "fixed": + continue + parent_scale = link_scale_factors.get(joint.parent, None) + if parent_scale is None or joint.child in link_scale_factors: + continue + link_scale_factors[joint.child] = np.array(parent_scale, dtype=float) + updated = True + + # Scale fixed-joint offsets that are not part of the reduced joint set. + for joint_name, joint in joints_dict.items(): + if joint.type != "fixed" or joint_name in reduced_joint_names: + continue + parent_scale = link_scale_factors.get(joint.parent, unit_scale) + if np.allclose(parent_scale, unit_scale): + continue + if joint.pose is None: + continue + transform = np.array(joint.pose.transform(), dtype=float) + transform[0:3, 3] = parent_scale * transform[0:3, 3] + joint.pose = rod.Pose.from_transform( + transform=transform, + relative_to=joint.pose.relative_to, + ) + + # Apply inherited scaling to non-reduced links (typically descendants + # connected via fixed joints). + for link_name, scale_vector in link_scale_factors.items(): + if link_name in reduced_link_names: + continue + if np.allclose(scale_vector, unit_scale): + continue + if link_name not in links_dict: + continue + scale_link_elements( + elements_to_update=collect_link_elements(links_dict[link_name]), + scale_vector=scale_vector, + ) + # Restore continuous joint types for joints with infinite limits # to ensure valid URDF export (continuous joints should not have limits). # Continuous joints are internally represented as revolute with infinite @@ -2490,21 +2748,127 @@ def update_hw_parameters( has_joints = model.number_of_joints() > 0 - supported_case = lambda hw_metadata, scaling_factors: HwLinkMetadata.apply_scaling( - hw_metadata=hw_metadata, scaling_factors=scaling_factors, has_joints=has_joints + def apply_scaling_single_link( + link_shape, + geometry, + density, + L_H_G, + L_H_vis, + L_H_pre, + L_H_pre_mask, + scaling_dims, + scaling_density, + ): + """Apply scaling to a single link's numerical data.""" + + def scale_supported(_): + shape_indices_map = jnp.array([[0, 1, 2], [0, 0, 1], [0, 0, 0], [0, 1, 2]]) + per_link_indices = shape_indices_map[link_shape] + scale_vector = scaling_dims[per_link_indices] + + # Update kinematics + G_H_L = jaxsim.math.Transform.inverse(L_H_G) + G_H_vis = G_H_L @ L_H_vis + G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) + L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3]) + L_H̅_vis = L_H̅_G @ G_H̅_vis + + # Update shape parameters + updated_geom = geometry * scaling_dims + updated_dens = density * scaling_density + + return updated_geom, updated_dens, L_H̅_G, L_H̅_vis, scale_vector + + def scale_unsupported(_): + return ( + geometry, + density, + L_H_G, + L_H_vis, + jnp.ones_like(scaling_dims), + ) + + return jax.lax.cond( + link_shape == LinkParametrizableShape.Unsupported, + scale_unsupported, + scale_supported, + operand=None, + ) + + # Vmap over all links for basic scaling + ( + updated_geometry, + updated_density, + updated_L_H_G, + updated_L_H_vis, + scale_vectors, + ) = jax.vmap(apply_scaling_single_link)( + hw_link_metadata.link_shape, + hw_link_metadata.geometry, + hw_link_metadata.density, + hw_link_metadata.L_H_G, + hw_link_metadata.L_H_vis, + hw_link_metadata.L_H_pre, + hw_link_metadata.L_H_pre_mask, + scaling_factors.dims, + scaling_factors.density, ) - unsupported_case = lambda hw_metadata, scaling_factors: hw_metadata - - # Apply scaling to hw_link_metadata using vmap - updated_hw_link_metadata = jax.vmap( - lambda hw_metadata, multipliers: jax.lax.cond( - hw_metadata.link_shape == LinkParametrizableShape.Unsupported, - unsupported_case, - supported_case, - hw_metadata, - multipliers, + + # Handle joint transforms separately, only if model has joints + def transform_all_joints(operands): + """Transform all joint poses across all links.""" + original_L_H_G, updated_L_H_G, scale_vectors, L_H_pre, L_H_pre_mask = operands + + # Vectorized transformation: (n_links, n_joints, 4, 4) + # Express joint transforms in the original CoM frames. + # Using the already-scaled L_H_G here introduces a second implicit + # scaling term and distorts kinematic chain proportions. + G_H_L_all = jax.vmap(jaxsim.math.Transform.inverse)( + original_L_H_G + ) # (n_links, 4, 4) + + # Use batch matrix multiply with broadcasting + # G_H_L_all: (n_links, 4, 4) -> (n_links, 1, 4, 4) + # L_H_pre: (n_links, n_joints, 4, 4) + # Result: (n_links, n_joints, 4, 4) + G_H_pre = G_H_L_all[:, None, :, :] @ L_H_pre + + # Scale translation components + G_H̅_pre = G_H_pre.at[:, :, :3, 3].set( + jnp.where( + L_H_pre_mask[:, :, None], + scale_vectors[:, None, :] * G_H_pre[:, :, :3, 3], + G_H_pre[:, :, :3, 3], + ) ) - )(hw_link_metadata, scaling_factors) + + # Transform back to link frames + # updated_L_H_G: (n_links, 4, 4) -> (n_links, 1, 4, 4) + # G_H̅_pre: (n_links, n_joints, 4, 4) + # Result: (n_links, n_joints, 4, 4) + return updated_L_H_G[:, None, :, :] @ G_H̅_pre + + updated_L_H_pre = jax.lax.cond( + has_joints, + transform_all_joints, + lambda operands: operands[3], # Return L_H_pre unchanged + operand=( + hw_link_metadata.L_H_G, + updated_L_H_G, + scale_vectors, + hw_link_metadata.L_H_pre, + hw_link_metadata.L_H_pre_mask, + ), + ) + + # Create updated HwLinkMetadata + updated_hw_link_metadata = hw_link_metadata.replace( + geometry=updated_geometry, + density=updated_density, + L_H_G=updated_L_H_G, + L_H_vis=updated_L_H_vis, + L_H_pre=updated_L_H_pre, + ) # Compute mass and inertia once and unpack the results m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia( diff --git a/src/jaxsim/math/__init__.py b/src/jaxsim/math/__init__.py index cf0bcb107..68d4b924b 100644 --- a/src/jaxsim/math/__init__.py +++ b/src/jaxsim/math/__init__.py @@ -9,6 +9,5 @@ from .joint_model import JointModel, supported_joint_motion # isort:skip - # Define the default standard gravity constant. STANDARD_GRAVITY = 9.81 diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index a295b7fab..025d4b23e 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -278,3 +278,62 @@ def create_mesh_collision( ] return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L) + + +def prepare_mesh_for_parametrization( + mesh_uri: str, scale: tuple[float, float, float] = (1.0, 1.0, 1.0) +) -> dict: + """ + Load and prepare a mesh for parametric scaling with exact inertia computation. + + This function loads a mesh, ensures it's watertight (crucial for volume/inertia + calculation), centers it, and returns the data needed for parametric scaling. + + Args: + mesh_uri: URI/path to the mesh file. + scale: Initial scale factors to apply (from SDF/URDF). + + Returns: + A dictionary containing: + - 'vertices': Centered mesh vertices as numpy array (Nx3) + - 'faces': Triangle faces as numpy array (Mx3 integer indices) + - 'offset': Original mesh centroid offset as numpy array (3,) + - 'uri': The mesh URI for reference + - 'is_watertight': Boolean indicating if mesh is watertight + - 'volume': The volume of the mesh (after scaling) + """ + + # Load mesh + file = pathlib.Path(resolve_local_uri(uri=mesh_uri)) + file_type = file.suffix.replace(".", "") + mesh = trimesh.load_mesh(file, file_type=file_type) + + if mesh.is_empty: + raise RuntimeError(f"Failed to process '{file}' with trimesh") + + # Apply initial scale from SDF/URDF + mesh.apply_scale(scale) + + # Check and fix watertightness + is_watertight = mesh.is_watertight + if not is_watertight: + logging.warning( + f"Mesh {mesh_uri} is not watertight. Computing convex hull for valid inertia." + ) + mesh = mesh.convex_hull + is_watertight = True + + # Store original centroid as offset + offset = mesh.centroid.copy() + + # Center the mesh + mesh.vertices -= offset + + return { + "vertices": np.array(mesh.vertices, dtype=np.float64), + "faces": np.array(mesh.faces, dtype=np.int32), + "offset": np.array(offset, dtype=np.float64), + "uri": mesh_uri, + "is_watertight": is_watertight, + "volume": mesh.volume, + } diff --git a/tests/assets/cube.stl b/tests/assets/cube.stl new file mode 100644 index 000000000..4030ac997 --- /dev/null +++ b/tests/assets/cube.stl @@ -0,0 +1,86 @@ +solid model +facet normal 0.0 0.0 -1.0 +outer loop +vertex 20.0 0.0 0.0 +vertex 0.0 -20.0 0.0 +vertex 0.0 0.0 0.0 +endloop +endfacet +facet normal 0.0 0.0 -1.0 +outer loop +vertex 0.0 -20.0 0.0 +vertex 20.0 0.0 0.0 +vertex 20.0 -20.0 0.0 +endloop +endfacet +facet normal -0.0 -1.0 -0.0 +outer loop +vertex 20.0 -20.0 20.0 +vertex 0.0 -20.0 0.0 +vertex 20.0 -20.0 0.0 +endloop +endfacet +facet normal -0.0 -1.0 -0.0 +outer loop +vertex 0.0 -20.0 0.0 +vertex 20.0 -20.0 20.0 +vertex 0.0 -20.0 20.0 +endloop +endfacet +facet normal 1.0 0.0 0.0 +outer loop +vertex 20.0 0.0 0.0 +vertex 20.0 -20.0 20.0 +vertex 20.0 -20.0 0.0 +endloop +endfacet +facet normal 1.0 0.0 0.0 +outer loop +vertex 20.0 -20.0 20.0 +vertex 20.0 0.0 0.0 +vertex 20.0 0.0 20.0 +endloop +endfacet +facet normal -0.0 -0.0 1.0 +outer loop +vertex 20.0 -20.0 20.0 +vertex 0.0 0.0 20.0 +vertex 0.0 -20.0 20.0 +endloop +endfacet +facet normal -0.0 -0.0 1.0 +outer loop +vertex 0.0 0.0 20.0 +vertex 20.0 -20.0 20.0 +vertex 20.0 0.0 20.0 +endloop +endfacet +facet normal -1.0 0.0 0.0 +outer loop +vertex 0.0 0.0 20.0 +vertex 0.0 -20.0 0.0 +vertex 0.0 -20.0 20.0 +endloop +endfacet +facet normal -1.0 0.0 0.0 +outer loop +vertex 0.0 -20.0 0.0 +vertex 0.0 0.0 20.0 +vertex 0.0 0.0 0.0 +endloop +endfacet +facet normal -0.0 1.0 0.0 +outer loop +vertex 0.0 0.0 20.0 +vertex 20.0 0.0 0.0 +vertex 0.0 0.0 0.0 +endloop +endfacet +facet normal -0.0 1.0 0.0 +outer loop +vertex 20.0 0.0 0.0 +vertex 0.0 0.0 20.0 +vertex 20.0 0.0 20.0 +endloop +endfacet +endsolid model diff --git a/tests/assets/mixed_shapes_robot.urdf b/tests/assets/mixed_shapes_robot.urdf new file mode 100644 index 000000000..a4bb34e11 --- /dev/null +++ b/tests/assets/mixed_shapes_robot.urdf @@ -0,0 +1,112 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/assets/test_cube.urdf b/tests/assets/test_cube.urdf new file mode 100644 index 000000000..1fe648cd6 --- /dev/null +++ b/tests/assets/test_cube.urdf @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index 1fffcde51..98601f91c 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -3,11 +3,16 @@ import jax import jax.numpy as jnp +import numpy as np import pytest import rod import jaxsim.api as js -from jaxsim.api.kin_dyn_parameters import HwLinkMetadata, ScalingFactors +from jaxsim.api.kin_dyn_parameters import ( + HwLinkMetadata, + LinkParametrizableShape, + ScalingFactors, +) from jaxsim.rbda.contacts import SoftContactsParams from .utils import assert_allclose @@ -851,6 +856,119 @@ def test_export_model_with_missing_collision( assert exported_model.name == model.name(), "Exported model name should match" # Verify we can build a JaxSim model from the exported URDF + _ = js.model.JaxSimModel.build_from_model_description( + model_description=exported_urdf, is_urdf=True + ) + + +def test_export_mesh_scaling_preserves_nonzero_visual_and_joint_origins( + tmp_path: pathlib.Path, +): + """ + Regression test for mesh export: + non-identity scaling must preserve non-zero visual/joint origins in the URDF. + """ + + mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl" + if not mesh_file.exists(): + pytest.skip(f"Test mesh file not found: {mesh_file}") + + urdf_path = tmp_path / "mesh_origin_regression.urdf" + urdf_path.write_text( + f""" + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +""", + encoding="utf-8", + ) + + model = js.model.JaxSimModel.build_from_model_description( + model_description=urdf_path, + is_urdf=True, + parametrized_links=("mesh_link",), + ) + + mesh_link_idx = js.link.name_to_idx(model=model, link_name="mesh_link") + dims = jnp.ones((model.number_of_links(), 3)) + dims = dims.at[mesh_link_idx].set(jnp.array([1.7, 0.8, 1.3])) + scaling = ScalingFactors(dims=dims, density=jnp.ones(model.number_of_links())) + + updated_model = js.model.update_hw_parameters(model=model, scaling_factors=scaling) + exported_urdf = updated_model.export_updated_model() + root = ET.fromstring(exported_urdf) + + visual_origin = root.find(".//link[@name='mesh_link']/visual/origin") + assert visual_origin is not None, "Mesh visual origin must exist in exported URDF" + visual_xyz = np.array([float(v) for v in visual_origin.get("xyz").split()]) + + expected_visual_xyz = np.array( + updated_model.kin_dyn_parameters.hw_link_metadata.L_H_vis[mesh_link_idx][:3, 3] + ) + assert_allclose( + visual_xyz, + expected_visual_xyz, + atol=1e-8, + err_msg="Exported mesh visual origin does not match updated metadata", + ) + assert not np.allclose(visual_xyz, np.zeros(3), atol=1e-12) + + joint_origin = root.find(".//joint[@name='base_to_mesh']/origin") + assert joint_origin is not None, "Joint origin must exist in exported URDF" + joint_xyz = np.array([float(v) for v in joint_origin.get("xyz").split()]) + + joint_idx = js.joint.name_to_idx(model=updated_model, joint_name="base_to_mesh") + expected_joint_xyz = np.array( + updated_model.kin_dyn_parameters.joint_model.λ_H_pre[joint_idx + 1][:3, 3] + ) + assert_allclose( + joint_xyz, + expected_joint_xyz, + atol=1e-8, + err_msg="Exported joint origin does not match updated joint transform", + ) + assert not np.allclose(joint_xyz, np.zeros(3), atol=1e-12) + reimported_jaxsim_model = js.model.JaxSimModel.build_from_model_description( model_description=exported_urdf, is_urdf=True ) @@ -860,3 +978,127 @@ def test_export_model_with_missing_collision( assert ( reimported_jaxsim_model.number_of_links() == model.number_of_links() ), "Reimported model should have same number of links" + + +# ============================================================================= +# Mesh Scaling Tests +# ============================================================================= + + +def test_mesh_shape_enum(): + """Test that the Mesh shape type is available in the enum.""" + assert hasattr(LinkParametrizableShape, "Mesh") + assert LinkParametrizableShape.Mesh == 3 + + +def test_mixed_shapes_metadata(): + """Test loading and metadata verification for mixed primitive and mesh shapes.""" + test_urdf = pathlib.Path(__file__).parent / "assets" / "mixed_shapes_robot.urdf" + + if not test_urdf.exists(): + pytest.skip(f"Test URDF not found: {test_urdf}") + + mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl" + if not mesh_file.exists(): + pytest.skip(f"Test mesh not found: {mesh_file}") + + # Load model with all link types parametrized + model = js.model.JaxSimModel.build_from_model_description( + model_description=test_urdf, + is_urdf=True, + parametrized_links=("box_link", "cylinder_link", "mesh_link", "sphere_link"), + ) + + assert model.name() == "mixed_shapes_robot" + assert model.number_of_links() == 4 + + hw_meta = model.kin_dyn_parameters.hw_link_metadata + + # Verify all 4 links are parametrized with correct shape types + assert len(hw_meta.link_shape) == 4 + assert hw_meta.link_shape[0] == LinkParametrizableShape.Box + assert hw_meta.link_shape[1] == LinkParametrizableShape.Cylinder + assert hw_meta.link_shape[2] == LinkParametrizableShape.Mesh + assert hw_meta.link_shape[3] == LinkParametrizableShape.Sphere + + # Verify mesh data exists only for mesh link + assert hw_meta.mesh_vertices is not None + assert hw_meta.mesh_vertices[0] is None # box + assert hw_meta.mesh_vertices[1] is None # cylinder + assert hw_meta.mesh_vertices[2] is not None # mesh + assert hw_meta.mesh_vertices[3] is None # sphere + assert hw_meta.mesh_faces is not None + assert hw_meta.mesh_faces[2] is not None # mesh link has faces + + +def test_mixed_shapes_scaling(): + """Test uniform and non-uniform scaling with mixed primitive and mesh shapes.""" + test_urdf = pathlib.Path(__file__).parent / "assets" / "mixed_shapes_robot.urdf" + + if not test_urdf.exists(): + pytest.skip(f"Test URDF not found: {test_urdf}") + + mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl" + if not mesh_file.exists(): + pytest.skip(f"Test mesh not found: {mesh_file}") + + model = js.model.JaxSimModel.build_from_model_description( + model_description=test_urdf, + is_urdf=True, + parametrized_links=("box_link", "cylinder_link", "mesh_link", "sphere_link"), + ) + + hw_meta = model.kin_dyn_parameters.hw_link_metadata + if len(hw_meta.link_shape) == 0: + pytest.skip("Hardware parametrization not supported") + + # Get original masses + masses_orig = {} + for i in range(model.number_of_links()): + link_name = js.link.idx_to_name(model=model, link_index=i) + masses_orig[link_name] = float(model.kin_dyn_parameters.link_parameters.mass[i]) + + # Test uniform scaling (2x), so all links should scaled by 8x + uniform_scaling = ScalingFactors( + dims=jnp.ones((4, 3)) * 2.0, + density=jnp.ones(4), + ) + scaled_uniform = js.model.update_hw_parameters(model, uniform_scaling) + + for i in range(scaled_uniform.number_of_links()): + link_name = js.link.idx_to_name(model=scaled_uniform, link_index=i) + mass_scaled = float(scaled_uniform.kin_dyn_parameters.link_parameters.mass[i]) + ratio = mass_scaled / masses_orig[link_name] + assert jnp.allclose( + ratio, 8.0, rtol=0.1 + ), f"Uniform scaling: {link_name} expected 8x, got {ratio:.2f}x" + + # Test different scaling factors per link + different_scaling = ScalingFactors( + dims=jnp.array( + [ + [2.0, 2.0, 2.0], # box: 8x + [3.0, 3.0, 3.0], # cylinder: 27x + [1.5, 1.5, 1.5], # mesh: 3.375x + [2.5, 2.5, 2.5], # sphere: 15.625x + ] + ), + density=jnp.ones(4), + ) + scaled_different = js.model.update_hw_parameters(model, different_scaling) + + expected_ratios = { + "box_link": 8.0, + "cylinder_link": 27.0, + "mesh_link": 3.375, + "sphere_link": 15.625, + } + + for i in range(scaled_different.number_of_links()): + link_name = js.link.idx_to_name(model=scaled_different, link_index=i) + mass_scaled = float(scaled_different.kin_dyn_parameters.link_parameters.mass[i]) + ratio = mass_scaled / masses_orig[link_name] + expected = expected_ratios[link_name] + assert jnp.allclose( + ratio, expected, rtol=0.1 + ), f"Different scaling: {link_name} expected {expected}x, got {ratio:.2f}x"