diff --git a/pyproject.toml b/pyproject.toml
index baf55931f..0a08261f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -200,6 +200,7 @@ ignore = [
channels = ["conda-forge"]
platforms = ["linux-64", "linux-aarch64", "osx-arm64", "osx-64"]
requires-pixi = ">=0.39.0"
+preview = ["pixi-build"]
[tool.pixi.environments]
# We resolve only two groups: cpu and gpu.
diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py
index 14392e82a..2c2ca3e81 100644
--- a/src/jaxsim/api/kin_dyn_parameters.py
+++ b/src/jaxsim/api/kin_dyn_parameters.py
@@ -916,23 +916,29 @@ class LinkParametrizableShape:
Box: ClassVar[int] = 0
Cylinder: ClassVar[int] = 1
Sphere: ClassVar[int] = 2
+ Mesh: ClassVar[int] = 3
-@jax_dataclasses.pytree_dataclass
+@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class HwLinkMetadata(JaxsimDataclass):
"""
Class storing the hardware parameters of a link.
Attributes:
link_shape: The shape of the link.
- 0 = box, 1 = cylinder, 2 = sphere, -1 = unsupported.
- geometry: The dimensions of the link.
- box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0].
+ 0 = box, 1 = cylinder, 2 = sphere, 3 = mesh, -1 = unsupported.
+ geometry: Shape parameters used by HW parametrization.
+ box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0],
+ mesh: cumulative anisotropic scale factors [sx,sy,sz] (initialized to [1,1,1]).
density: The density of the link.
L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G.
L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame.
L_H_pre_mask: The mask indicating the link's child joint indices.
L_H_pre: The homogeneous transforms for child joints.
+ mesh_vertices: The original centered mesh vertices (Nx3) for mesh shapes, None otherwise.
+ mesh_faces: The mesh triangle faces (Mx3 integer indices) for mesh shapes, None otherwise.
+ mesh_offset: The original mesh centroid offset (3D vector) for mesh shapes, None otherwise.
+ mesh_uri: The path to the mesh file for reference, None otherwise.
"""
link_shape: jtp.Vector
@@ -942,6 +948,10 @@ class HwLinkMetadata(JaxsimDataclass):
L_H_vis: jtp.Matrix
L_H_pre_mask: jtp.Vector
L_H_pre: jtp.Matrix
+ mesh_vertices: Static[tuple[HashedNumpyArray | None, ...] | None]
+ mesh_faces: Static[tuple[HashedNumpyArray | None, ...] | None]
+ mesh_offset: Static[tuple[HashedNumpyArray | None, ...] | None]
+ mesh_uri: Static[tuple[str | None, ...] | None]
@classmethod
def empty(cls) -> HwLinkMetadata:
@@ -954,7 +964,83 @@ def empty(cls) -> HwLinkMetadata:
L_H_vis=jnp.array([], dtype=float),
L_H_pre_mask=jnp.array([], dtype=bool),
L_H_pre=jnp.array([], dtype=float),
+ mesh_vertices=None,
+ mesh_faces=None,
+ mesh_offset=None,
+ mesh_uri=None,
+ )
+
+ @staticmethod
+ def compute_mesh_inertia(
+ vertices: jtp.Matrix, faces: jtp.Matrix, density: jtp.Float
+ ) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
+ """
+ Compute mass, center of mass, and inertia tensor from mesh geometry.
+
+ Uses the divergence theorem to compute volumetric properties by integrating
+ over tetrahedra formed between the mesh surface and the origin.
+
+ Args:
+ vertices: Mesh vertices (Nx3) in the link frame, should be centered.
+ faces: Triangle face indices (Mx3), integer indices into vertices array.
+ density: Material density.
+
+ Returns:
+ A tuple containing the computed mass, the CoM position and the 3x3
+ inertia tensor at the CoM.
+ """
+
+ # Extract triangles from vertices using face indices
+ triangles = vertices[faces.astype(int)]
+ A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]
+
+ # Compute signed volume of tetrahedra relative to origin
+ # vol = 1/6 * (A . (B x C))
+ tetrahedron_volumes = jnp.sum(A * jnp.cross(B, C), axis=1) / 6.0
+
+ total_signed_volume = jnp.sum(tetrahedron_volumes)
+
+ # Normalize the global winding sign so positive density yields non-negative mass.
+ orientation_sign = jnp.where(total_signed_volume < 0, -1.0, 1.0)
+ tetrahedron_volumes = tetrahedron_volumes * orientation_sign
+ total_volume = jnp.sum(tetrahedron_volumes)
+
+ eps = jnp.asarray(1e-12, dtype=total_volume.dtype)
+ is_valid_volume = jnp.abs(total_volume) > eps
+ safe_total_volume = jnp.where(is_valid_volume, total_volume, 1.0)
+ mass = jnp.where(is_valid_volume, total_volume * density, 0.0)
+
+ # Compute center of mass
+ tet_coms = (A + B + C) / 4.0
+ com_position = jnp.where(
+ is_valid_volume,
+ jnp.sum(tet_coms * tetrahedron_volumes[:, None], axis=0)
+ / safe_total_volume,
+ jnp.zeros(3, dtype=vertices.dtype),
+ )
+
+ # Compute inertia tensor with covariance approach
+ def compute_tetrahedron_covariance(a, b, c, vol):
+ s = a + b + c
+ return (vol / 20.0) * (
+ jnp.outer(a, a) + jnp.outer(b, b) + jnp.outer(c, c) + jnp.outer(s, s)
+ )
+
+ covariance_matrices = jax.vmap(compute_tetrahedron_covariance)(
+ A, B, C, tetrahedron_volumes
)
+ Σ_origin = jnp.sum(covariance_matrices, axis=0)
+
+ # Shift to CoM using parallel axis theorem
+ Σ_com = Σ_origin * density - mass * jnp.outer(com_position, com_position)
+
+ # Convert covariance to inertia tensor
+ I_com = jnp.trace(Σ_com) * jnp.eye(3, dtype=vertices.dtype) - Σ_com
+ I_com = jnp.where(
+ is_valid_volume, I_com, jnp.zeros((3, 3), dtype=vertices.dtype)
+ )
+
+ return mass, com_position, I_com
@staticmethod
def compute_mass_and_inertia(
@@ -1015,16 +1101,90 @@ def sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]:
return mass, inertia
- def compute_mass_inertia(shape_idx, dims, density):
- return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density)
+ def compute_mass_inertia_primitive(shape_idx, dims, density):
+ def unsupported_case(_):
+ return (
+ jnp.asarray(0.0, dtype=density.dtype),
+ jnp.zeros((3, 3), dtype=density.dtype),
+ )
- mass, inertia = jax.vmap(compute_mass_inertia)(
- hw_link_metadata.link_shape,
- hw_link_metadata.geometry,
- hw_link_metadata.density,
+ def supported_case(idx):
+ return jax.lax.switch(idx, (box, cylinder, sphere), dims, density)
+
+ return jax.lax.cond(
+ shape_idx < 0, unsupported_case, supported_case, shape_idx
+ )
+
+ # For models with mesh data, we need to handle Static heterogeneous mesh arrays
+ has_mesh_data = (
+ hw_link_metadata.mesh_vertices is not None
+ and hw_link_metadata.mesh_faces is not None
)
- return mass, inertia
+ if has_mesh_data:
+ mesh_verts_tuple = hw_link_metadata.mesh_vertices
+ mesh_faces_tuple = hw_link_metadata.mesh_faces
+ n_links = len(mesh_verts_tuple)
+
+ # Build per-link compute functions that capture mesh data in closures
+ # This loop runs once at trace time to build the computation graph
+ compute_fns = []
+ for i in range(n_links):
+ if mesh_verts_tuple[i] is not None:
+ # Capture this link's mesh data in the closure
+ verts_data = jnp.array(mesh_verts_tuple[i].get())
+ faces_data = jnp.array(mesh_faces_tuple[i].get())
+
+ def make_fn(verts, faces):
+ def link_fn(shape, dims, density):
+ def mesh_branch():
+ scaled_vertices = verts * dims
+ mass_m, _, inertia_m = (
+ HwLinkMetadata.compute_mesh_inertia(
+ scaled_vertices, faces, density
+ )
+ )
+ return mass_m, inertia_m
+
+ def primitive_branch():
+ return compute_mass_inertia_primitive(
+ shape, dims, density
+ )
+
+ return jax.lax.cond(
+ shape == LinkParametrizableShape.Mesh,
+ mesh_branch,
+ primitive_branch,
+ )
+
+ return link_fn
+
+ compute_fns.append(make_fn(verts_data, faces_data))
+ else:
+ # No mesh data for this link
+ def link_fn(shape, dims, density):
+ return compute_mass_inertia_primitive(shape, dims, density)
+
+ compute_fns.append(link_fn)
+
+ def compute_single_link(link_idx, shape, dims, density):
+ return jax.lax.switch(link_idx, compute_fns, shape, dims, density)
+
+ masses, inertias = jax.vmap(compute_single_link)(
+ jnp.arange(n_links),
+ hw_link_metadata.link_shape,
+ hw_link_metadata.geometry,
+ hw_link_metadata.density,
+ )
+ else:
+ # No mesh data - pure primitives with simple vmap
+ masses, inertias = jax.vmap(compute_mass_inertia_primitive)(
+ hw_link_metadata.link_shape,
+ hw_link_metadata.geometry,
+ hw_link_metadata.density,
+ )
+
+ return masses, inertias
@staticmethod
def _convert_scaling_to_3d_vector(
@@ -1034,7 +1194,7 @@ def _convert_scaling_to_3d_vector(
Convert scaling factors for specific shape dimensions into a 3D scaling vector.
Args:
- link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder).
+ link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder, mesh).
scaling_factors: The scaling factors for the shape dimensions.
Returns:
@@ -1045,17 +1205,20 @@ def _convert_scaling_to_3d_vector(
- Box: [lx, ly, lz]
- Cylinder: [r, r, l]
- Sphere: [r, r, r]
+ - Mesh: [sx, sy, sz]
"""
# Index mapping for each shape type (link_shapes x 3 dims)
# Box: [lx, ly, lz] -> [0, 1, 2]
# Cylinder: [r, r, l] -> [0, 0, 1]
# Sphere: [r, r, r] -> [0, 0, 0]
+ # Mesh: [sx, sy, sz] -> [0, 1, 2]
shape_indices = jnp.array(
[
[0, 1, 2], # Box
[0, 0, 1], # Cylinder
[0, 0, 0], # Sphere
+ [0, 1, 2], # Mesh
]
)
@@ -1117,9 +1280,19 @@ def box(parent_idx, L_p_C):
]
)
+ def mesh(parent_idx, L_p_C):
+ sx, sy, sz = scaling_factors.dims[parent_idx]
+ return jnp.hstack(
+ [
+ L_p_C[0] * sx,
+ L_p_C[1] * sy,
+ L_p_C[2] * sz,
+ ]
+ )
+
new_positions = jax.vmap(
lambda shape_idx, parent_idx, L_p_C: jax.lax.switch(
- shape_idx, (box, cylinder, sphere), parent_idx, L_p_C
+ shape_idx, (box, cylinder, sphere, mesh), parent_idx, L_p_C
)
)(
parent_link_shapes,
diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py
index 2b88ed212..f40ebd2d5 100644
--- a/src/jaxsim/api/model.py
+++ b/src/jaxsim/api/model.py
@@ -30,6 +30,7 @@
from jaxsim.parsers.descriptions import ModelDescription
from jaxsim.parsers.descriptions.joint import JointDescription
from jaxsim.parsers.descriptions.link import LinkDescription
+from jaxsim.parsers.rod.utils import prepare_mesh_for_parametrization
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
from .common import VelRepr
@@ -385,6 +386,10 @@ def compute_hw_link_metadata(
L_H_vises = []
L_H_pre_masks = []
L_H_pre = []
+ mesh_vertices = []
+ mesh_faces = []
+ mesh_offsets = []
+ mesh_uris = []
# Process each link, only parametrizing those in parametrized_links if provided
for link_description in ordered_links:
@@ -399,6 +404,10 @@ def compute_hw_link_metadata(
L_H_vises.append(jnp.eye(4))
L_H_pre_masks.append([0] * self.number_of_joints())
L_H_pre.append([jnp.eye(4)] * self.number_of_joints())
+ mesh_vertices.append(None)
+ mesh_faces.append(None)
+ mesh_offsets.append(None)
+ mesh_uris.append(None)
continue
rod_link = rod_links_dict.get(link_name)
@@ -432,7 +441,8 @@ def compute_hw_link_metadata(
v
for v in rod_link.visuals()
if isinstance(
- v.geometry.geometry(), (rod.Box, rod.Sphere, rod.Cylinder)
+ v.geometry.geometry(),
+ (rod.Box, rod.Sphere, rod.Cylinder, rod.Mesh),
)
),
None,
@@ -444,21 +454,75 @@ def compute_hw_link_metadata(
geometry = (
supported_visual.geometry.geometry() if supported_visual else None
)
+
if isinstance(geometry, rod.Box):
lx, ly, lz = geometry.size
density = mass / (lx * ly * lz)
geom = [lx, ly, lz]
shape = LinkParametrizableShape.Box
+ mesh_vertices.append(None)
+ mesh_faces.append(None)
+ mesh_offsets.append(None)
+ mesh_uris.append(None)
elif isinstance(geometry, rod.Sphere):
r = geometry.radius
density = mass / (4 / 3 * jnp.pi * r**3)
geom = [r, 0, 0]
shape = LinkParametrizableShape.Sphere
+ mesh_vertices.append(None)
+ mesh_faces.append(None)
+ mesh_offsets.append(None)
+ mesh_uris.append(None)
elif isinstance(geometry, rod.Cylinder):
r, l = geometry.radius, geometry.length
density = mass / (jnp.pi * r**2 * l)
geom = [r, l, 0]
shape = LinkParametrizableShape.Cylinder
+ mesh_vertices.append(None)
+ mesh_faces.append(None)
+ mesh_offsets.append(None)
+ mesh_uris.append(None)
+ elif isinstance(geometry, rod.Mesh):
+ # Load and prepare mesh for parametric scaling
+ try:
+
+ mesh_data = prepare_mesh_for_parametrization(
+ mesh_uri=geometry.uri,
+ scale=geometry.scale,
+ )
+
+ density = (
+ mass / mesh_data["volume"] if mesh_data["volume"] > 0 else 0.0
+ )
+
+ # For meshes, store cumulative scale factors (initially 1.0) in geometry
+ # instead of bounding box extents. This allows proper multiplicative scaling.
+ geom = [1.0, 1.0, 1.0]
+ shape = LinkParametrizableShape.Mesh
+
+ # Store mesh data
+ mesh_vertices.append(mesh_data["vertices"])
+ mesh_faces.append(mesh_data["faces"])
+ mesh_offsets.append(mesh_data["offset"])
+ mesh_uris.append(mesh_data["uri"])
+
+ logging.info(
+ f"Loaded mesh for link '{link_name}': "
+ f"{len(mesh_data['vertices'])} vertices, "
+ f"{len(mesh_data['faces'])} faces, "
+ )
+ except Exception as e:
+ logging.warning(
+ f"Failed to load mesh for link '{link_name}': {e}. "
+ f"Marking as unsupported."
+ )
+ density = 0.0
+ geom = [0, 0, 0]
+ shape = LinkParametrizableShape.Unsupported
+ mesh_vertices.append(None)
+ mesh_faces.append(None)
+ mesh_offsets.append(None)
+ mesh_uris.append(None)
else:
logging.debug(
f"Skipping link '{link_name}' for hardware parametrization due to unsupported geometry."
@@ -466,6 +530,10 @@ def compute_hw_link_metadata(
density = 0.0
geom = [0, 0, 0]
shape = LinkParametrizableShape.Unsupported
+ mesh_vertices.append(None)
+ mesh_faces.append(None)
+ mesh_offsets.append(None)
+ mesh_uris.append(None)
inertial_pose = (
rod_link.inertial.pose.transform() if rod_link else jnp.eye(4)
@@ -501,6 +569,13 @@ def compute_hw_link_metadata(
return HwLinkMetadata.empty()
# Stack collected data into JAX arrays
+ # Handle L_H_pre specially: ensure shape (n_links, n_joints, 4, 4) even when n_joints=0
+ L_H_pre_array = jnp.array(L_H_pre, dtype=float)
+ if self.number_of_joints() == 0:
+ # Reshape from (n_links, 0) to (n_links, 0, 4, 4)
+ n_links = len(L_H_pre)
+ L_H_pre_array = L_H_pre_array.reshape(n_links, 0, 4, 4)
+
return HwLinkMetadata(
link_shape=jnp.array(shapes, dtype=int),
geometry=jnp.array(geoms, dtype=float),
@@ -508,7 +583,34 @@ def compute_hw_link_metadata(
L_H_G=jnp.array(L_H_Gs, dtype=float),
L_H_vis=jnp.array(L_H_vises, dtype=float),
L_H_pre_mask=jnp.array(L_H_pre_masks, dtype=bool),
- L_H_pre=jnp.array(L_H_pre, dtype=float),
+ L_H_pre=L_H_pre_array,
+ mesh_vertices=(
+ tuple(
+ wrappers.HashedNumpyArray(array=v) if v is not None else None
+ for v in mesh_vertices
+ )
+ if any(v is not None for v in mesh_vertices)
+ else None
+ ),
+ mesh_faces=(
+ tuple(
+ wrappers.HashedNumpyArray(array=f) if f is not None else None
+ for f in mesh_faces
+ )
+ if any(f is not None for f in mesh_faces)
+ else None
+ ),
+ mesh_offset=(
+ tuple(
+ wrappers.HashedNumpyArray(array=o) if o is not None else None
+ for o in mesh_offsets
+ )
+ if any(o is not None for o in mesh_offsets)
+ else None
+ ),
+ mesh_uri=(
+ tuple(mesh_uris) if any(u is not None for u in mesh_uris) else None
+ ),
)
def export_updated_model(self) -> str:
@@ -548,6 +650,85 @@ def export_updated_model(self) -> str:
# Iterate over the hardware metadata to update the ROD model
hw_metadata = self.kin_dyn_parameters.hw_link_metadata
+ reduced_link_names = set(self.link_names())
+ reduced_joint_names = set(self.joint_names())
+ unit_scale = np.ones(3, dtype=float)
+ link_scale_factors: dict[str, np.ndarray] = {}
+
+ def collect_link_elements(link) -> list:
+ elements_to_update_raw = (link.visual, link.collision)
+ elements_to_update = []
+ for entry in elements_to_update_raw:
+ if entry is None:
+ continue
+ if isinstance(entry, (list, tuple)):
+ elements_to_update.extend(e for e in entry if e is not None)
+ else:
+ elements_to_update.append(entry)
+ return elements_to_update
+
+ def scale_pose_translation(element, scale_vector):
+ if getattr(element, "pose", None) is None:
+ return
+ transform = np.array(element.pose.transform(), dtype=float)
+ transform[0:3, 3] = scale_vector * transform[0:3, 3]
+ element.pose = rod.Pose.from_transform(
+ transform=transform,
+ relative_to=element.pose.relative_to,
+ )
+
+ def scale_link_elements(
+ elements_to_update: list,
+ scale_vector: np.ndarray,
+ *,
+ mesh_pose: rod.Pose | None = None,
+ mesh_shape_link: bool = False,
+ ) -> None:
+ for element in elements_to_update:
+ if (
+ element is None
+ or not hasattr(element, "geometry")
+ or element.geometry is None
+ ):
+ continue
+
+ geometry = element.geometry
+ if getattr(geometry, "box", None) is not None:
+ current_size = np.array(geometry.box.size, dtype=float)
+ geometry.box.size = tuple(
+ float(v) for v in (current_size * scale_vector).tolist()
+ )
+ scale_pose_translation(element, scale_vector)
+ elif getattr(geometry, "sphere", None) is not None:
+ geometry.sphere.radius = float(
+ float(geometry.sphere.radius) * float(scale_vector[0])
+ )
+ scale_pose_translation(element, scale_vector)
+ elif getattr(geometry, "cylinder", None) is not None:
+ geometry.cylinder.radius = float(
+ float(geometry.cylinder.radius) * float(scale_vector[0])
+ )
+ geometry.cylinder.length = float(
+ float(geometry.cylinder.length) * float(scale_vector[2])
+ )
+ scale_pose_translation(element, scale_vector)
+ elif getattr(geometry, "mesh", None) is not None:
+ base_scale = (
+ np.array(geometry.mesh.scale, dtype=float)
+ if geometry.mesh.scale is not None
+ else unit_scale
+ )
+ geometry.mesh.scale = tuple(
+ float(v) for v in (base_scale * scale_vector).tolist()
+ )
+
+ # Mesh-parametrized reduced links use metadata to preserve
+ # the main visual placement in the exported URDF.
+ if mesh_shape_link and mesh_pose is not None:
+ element.pose = mesh_pose
+ else:
+ scale_pose_translation(element, scale_vector)
+
for link_index, link_name in enumerate(self.link_names()):
if link_name not in links_dict:
continue
@@ -578,41 +759,73 @@ def export_updated_model(self) -> str:
inertia_tensor=inertia_tensor, validate=True
)
- # Update visuals and collisions
- dims = hw_metadata.geometry[link_index]
+ dims = np.array(hw_metadata.geometry[link_index], dtype=float)
+ elements_to_update = collect_link_elements(links_dict[link_name])
+
+ def find_reference_geometry(attr: str, elements: list = elements_to_update):
+ for element in elements:
+ if (
+ element is None
+ or not hasattr(element, "geometry")
+ or element.geometry is None
+ ):
+ continue
+ geometry = getattr(element.geometry, attr, None)
+ if geometry is not None:
+ return geometry
+ return None
+
+ if shape == LinkParametrizableShape.Mesh:
+ scale_vector = dims
+ elif shape == LinkParametrizableShape.Box:
+ ref_box = find_reference_geometry("box")
+ if ref_box is None:
+ scale_vector = unit_scale
+ else:
+ base_size = np.array(ref_box.size, dtype=float)
+ scale_vector = np.divide(
+ dims,
+ base_size,
+ out=np.ones(3, dtype=float),
+ where=np.abs(base_size) > 1e-12,
+ )
+ elif shape == LinkParametrizableShape.Sphere:
+ ref_sphere = find_reference_geometry("sphere")
+ base_radius = (
+ float(ref_sphere.radius) if ref_sphere is not None else 1.0
+ )
+ s = float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0
+ scale_vector = np.array([s, s, s], dtype=float)
+ elif shape == LinkParametrizableShape.Cylinder:
+ ref_cylinder = find_reference_geometry("cylinder")
+ base_radius = (
+ float(ref_cylinder.radius) if ref_cylinder is not None else 1.0
+ )
+ base_length = (
+ float(ref_cylinder.length) if ref_cylinder is not None else 1.0
+ )
+ s_radius = (
+ float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0
+ )
+ s_length = (
+ float(dims[1]) / base_length if abs(base_length) > 1e-12 else 1.0
+ )
+ scale_vector = np.array([s_radius, s_radius, s_length], dtype=float)
+ else:
+ scale_vector = unit_scale
- elements_to_update = (
- links_dict[link_name].visual,
- links_dict[link_name].collision,
- )
+ link_scale_factors[link_name] = np.array(scale_vector, dtype=float)
element_pose = rod.Pose.from_transform(
transform=np.array(hw_metadata.L_H_vis[link_index]),
relative_to=link_name,
)
-
- for element in elements_to_update:
- if element is None:
- continue
-
- # Update geometry
- if shape == LinkParametrizableShape.Box:
- element.geometry.box.size = dims.tolist()
- elif shape == LinkParametrizableShape.Sphere:
- element.geometry.sphere.radius = float(dims[0])
- elif shape == LinkParametrizableShape.Cylinder:
- element.geometry.cylinder.radius = float(dims[0])
- element.geometry.cylinder.length = float(dims[1])
- else:
- # This branch should be unreachable. Unsupported shapes should be
- # filtered out above.
- raise RuntimeError(
- f"Unexpected shape {shape} for link '{link_name}'. "
- "This should never be hit."
- )
-
- # Update pose
- element.pose = element_pose
+ scale_link_elements(
+ elements_to_update=elements_to_update,
+ scale_vector=scale_vector,
+ mesh_pose=element_pose,
+ mesh_shape_link=(shape == LinkParametrizableShape.Mesh),
+ )
# Update joint poses
for joint_index in range(self.number_of_joints()):
@@ -628,6 +841,51 @@ def export_updated_model(self) -> str:
relative_to=link_name,
)
+ # Propagate link scaling to descendants connected through fixed joints.
+ # These links are typically reduced away in the JaxSim model (e.g. feet
+ # attached to ankles) but still exist in the exported URDF tree.
+ updated = True
+ while updated:
+ updated = False
+ for joint in joints_dict.values():
+ if joint.type != "fixed":
+ continue
+ parent_scale = link_scale_factors.get(joint.parent, None)
+ if parent_scale is None or joint.child in link_scale_factors:
+ continue
+ link_scale_factors[joint.child] = np.array(parent_scale, dtype=float)
+ updated = True
+
+ # Scale fixed-joint offsets that are not part of the reduced joint set.
+ for joint_name, joint in joints_dict.items():
+ if joint.type != "fixed" or joint_name in reduced_joint_names:
+ continue
+ parent_scale = link_scale_factors.get(joint.parent, unit_scale)
+ if np.allclose(parent_scale, unit_scale):
+ continue
+ if joint.pose is None:
+ continue
+ transform = np.array(joint.pose.transform(), dtype=float)
+ transform[0:3, 3] = parent_scale * transform[0:3, 3]
+ joint.pose = rod.Pose.from_transform(
+ transform=transform,
+ relative_to=joint.pose.relative_to,
+ )
+
+ # Apply inherited scaling to non-reduced links (typically descendants
+ # connected via fixed joints).
+ for link_name, scale_vector in link_scale_factors.items():
+ if link_name in reduced_link_names:
+ continue
+ if np.allclose(scale_vector, unit_scale):
+ continue
+ if link_name not in links_dict:
+ continue
+ scale_link_elements(
+ elements_to_update=collect_link_elements(links_dict[link_name]),
+ scale_vector=scale_vector,
+ )
+
# Restore continuous joint types for joints with infinite limits
# to ensure valid URDF export (continuous joints should not have limits).
# Continuous joints are internally represented as revolute with infinite
@@ -2490,21 +2748,127 @@ def update_hw_parameters(
has_joints = model.number_of_joints() > 0
- supported_case = lambda hw_metadata, scaling_factors: HwLinkMetadata.apply_scaling(
- hw_metadata=hw_metadata, scaling_factors=scaling_factors, has_joints=has_joints
+ def apply_scaling_single_link(
+ link_shape,
+ geometry,
+ density,
+ L_H_G,
+ L_H_vis,
+ L_H_pre,
+ L_H_pre_mask,
+ scaling_dims,
+ scaling_density,
+ ):
+ """Apply scaling to a single link's numerical data."""
+
+ def scale_supported(_):
+ shape_indices_map = jnp.array([[0, 1, 2], [0, 0, 1], [0, 0, 0], [0, 1, 2]])
+ per_link_indices = shape_indices_map[link_shape]
+ scale_vector = scaling_dims[per_link_indices]
+
+ # Update kinematics
+ G_H_L = jaxsim.math.Transform.inverse(L_H_G)
+ G_H_vis = G_H_L @ L_H_vis
+ G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3])
+ L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3])
+ L_H̅_vis = L_H̅_G @ G_H̅_vis
+
+ # Update shape parameters
+ updated_geom = geometry * scaling_dims
+ updated_dens = density * scaling_density
+
+ return updated_geom, updated_dens, L_H̅_G, L_H̅_vis, scale_vector
+
+ def scale_unsupported(_):
+ return (
+ geometry,
+ density,
+ L_H_G,
+ L_H_vis,
+ jnp.ones_like(scaling_dims),
+ )
+
+ return jax.lax.cond(
+ link_shape == LinkParametrizableShape.Unsupported,
+ scale_unsupported,
+ scale_supported,
+ operand=None,
+ )
+
+ # Vmap over all links for basic scaling
+ (
+ updated_geometry,
+ updated_density,
+ updated_L_H_G,
+ updated_L_H_vis,
+ scale_vectors,
+ ) = jax.vmap(apply_scaling_single_link)(
+ hw_link_metadata.link_shape,
+ hw_link_metadata.geometry,
+ hw_link_metadata.density,
+ hw_link_metadata.L_H_G,
+ hw_link_metadata.L_H_vis,
+ hw_link_metadata.L_H_pre,
+ hw_link_metadata.L_H_pre_mask,
+ scaling_factors.dims,
+ scaling_factors.density,
)
- unsupported_case = lambda hw_metadata, scaling_factors: hw_metadata
-
- # Apply scaling to hw_link_metadata using vmap
- updated_hw_link_metadata = jax.vmap(
- lambda hw_metadata, multipliers: jax.lax.cond(
- hw_metadata.link_shape == LinkParametrizableShape.Unsupported,
- unsupported_case,
- supported_case,
- hw_metadata,
- multipliers,
+
+ # Handle joint transforms separately, only if model has joints
+ def transform_all_joints(operands):
+ """Transform all joint poses across all links."""
+ original_L_H_G, updated_L_H_G, scale_vectors, L_H_pre, L_H_pre_mask = operands
+
+ # Vectorized transformation: (n_links, n_joints, 4, 4)
+ # Express joint transforms in the original CoM frames.
+ # Using the already-scaled L_H_G here introduces a second implicit
+ # scaling term and distorts kinematic chain proportions.
+ G_H_L_all = jax.vmap(jaxsim.math.Transform.inverse)(
+ original_L_H_G
+ ) # (n_links, 4, 4)
+
+ # Use batch matrix multiply with broadcasting
+ # G_H_L_all: (n_links, 4, 4) -> (n_links, 1, 4, 4)
+ # L_H_pre: (n_links, n_joints, 4, 4)
+ # Result: (n_links, n_joints, 4, 4)
+ G_H_pre = G_H_L_all[:, None, :, :] @ L_H_pre
+
+ # Scale translation components
+ G_H̅_pre = G_H_pre.at[:, :, :3, 3].set(
+ jnp.where(
+ L_H_pre_mask[:, :, None],
+ scale_vectors[:, None, :] * G_H_pre[:, :, :3, 3],
+ G_H_pre[:, :, :3, 3],
+ )
)
- )(hw_link_metadata, scaling_factors)
+
+ # Transform back to link frames
+ # updated_L_H_G: (n_links, 4, 4) -> (n_links, 1, 4, 4)
+ # G_H̅_pre: (n_links, n_joints, 4, 4)
+ # Result: (n_links, n_joints, 4, 4)
+ return updated_L_H_G[:, None, :, :] @ G_H̅_pre
+
+ updated_L_H_pre = jax.lax.cond(
+ has_joints,
+ transform_all_joints,
+ lambda operands: operands[3], # Return L_H_pre unchanged
+ operand=(
+ hw_link_metadata.L_H_G,
+ updated_L_H_G,
+ scale_vectors,
+ hw_link_metadata.L_H_pre,
+ hw_link_metadata.L_H_pre_mask,
+ ),
+ )
+
+ # Create updated HwLinkMetadata
+ updated_hw_link_metadata = hw_link_metadata.replace(
+ geometry=updated_geometry,
+ density=updated_density,
+ L_H_G=updated_L_H_G,
+ L_H_vis=updated_L_H_vis,
+ L_H_pre=updated_L_H_pre,
+ )
# Compute mass and inertia once and unpack the results
m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia(
diff --git a/src/jaxsim/math/__init__.py b/src/jaxsim/math/__init__.py
index cf0bcb107..68d4b924b 100644
--- a/src/jaxsim/math/__init__.py
+++ b/src/jaxsim/math/__init__.py
@@ -9,6 +9,5 @@
from .joint_model import JointModel, supported_joint_motion # isort:skip
-
# Define the default standard gravity constant.
STANDARD_GRAVITY = 9.81
diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py
index a295b7fab..025d4b23e 100644
--- a/src/jaxsim/parsers/rod/utils.py
+++ b/src/jaxsim/parsers/rod/utils.py
@@ -278,3 +278,62 @@ def create_mesh_collision(
]
return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L)
+
+
+def prepare_mesh_for_parametrization(
+ mesh_uri: str, scale: tuple[float, float, float] = (1.0, 1.0, 1.0)
+) -> dict:
+ """
+ Load and prepare a mesh for parametric scaling with exact inertia computation.
+
+ This function loads a mesh, ensures it's watertight (crucial for volume/inertia
+ calculation), centers it, and returns the data needed for parametric scaling.
+
+ Args:
+ mesh_uri: URI/path to the mesh file.
+ scale: Initial scale factors to apply (from SDF/URDF).
+
+ Returns:
+ A dictionary containing:
+ - 'vertices': Centered mesh vertices as numpy array (Nx3)
+ - 'faces': Triangle faces as numpy array (Mx3 integer indices)
+ - 'offset': Original mesh centroid offset as numpy array (3,)
+ - 'uri': The mesh URI for reference
+ - 'is_watertight': Boolean indicating if mesh is watertight
+ - 'volume': The volume of the mesh (after scaling)
+ """
+
+ # Load mesh
+ file = pathlib.Path(resolve_local_uri(uri=mesh_uri))
+ file_type = file.suffix.replace(".", "")
+ mesh = trimesh.load_mesh(file, file_type=file_type)
+
+ if mesh.is_empty:
+ raise RuntimeError(f"Failed to process '{file}' with trimesh")
+
+ # Apply initial scale from SDF/URDF
+ mesh.apply_scale(scale)
+
+ # Check and fix watertightness
+ is_watertight = mesh.is_watertight
+ if not is_watertight:
+ logging.warning(
+ f"Mesh {mesh_uri} is not watertight. Computing convex hull for valid inertia."
+ )
+ mesh = mesh.convex_hull
+ is_watertight = True
+
+ # Store original centroid as offset
+ offset = mesh.centroid.copy()
+
+ # Center the mesh
+ mesh.vertices -= offset
+
+ return {
+ "vertices": np.array(mesh.vertices, dtype=np.float64),
+ "faces": np.array(mesh.faces, dtype=np.int32),
+ "offset": np.array(offset, dtype=np.float64),
+ "uri": mesh_uri,
+ "is_watertight": is_watertight,
+ "volume": mesh.volume,
+ }
diff --git a/tests/assets/cube.stl b/tests/assets/cube.stl
new file mode 100644
index 000000000..4030ac997
--- /dev/null
+++ b/tests/assets/cube.stl
@@ -0,0 +1,86 @@
+solid model
+facet normal 0.0 0.0 -1.0
+outer loop
+vertex 20.0 0.0 0.0
+vertex 0.0 -20.0 0.0
+vertex 0.0 0.0 0.0
+endloop
+endfacet
+facet normal 0.0 0.0 -1.0
+outer loop
+vertex 0.0 -20.0 0.0
+vertex 20.0 0.0 0.0
+vertex 20.0 -20.0 0.0
+endloop
+endfacet
+facet normal -0.0 -1.0 -0.0
+outer loop
+vertex 20.0 -20.0 20.0
+vertex 0.0 -20.0 0.0
+vertex 20.0 -20.0 0.0
+endloop
+endfacet
+facet normal -0.0 -1.0 -0.0
+outer loop
+vertex 0.0 -20.0 0.0
+vertex 20.0 -20.0 20.0
+vertex 0.0 -20.0 20.0
+endloop
+endfacet
+facet normal 1.0 0.0 0.0
+outer loop
+vertex 20.0 0.0 0.0
+vertex 20.0 -20.0 20.0
+vertex 20.0 -20.0 0.0
+endloop
+endfacet
+facet normal 1.0 0.0 0.0
+outer loop
+vertex 20.0 -20.0 20.0
+vertex 20.0 0.0 0.0
+vertex 20.0 0.0 20.0
+endloop
+endfacet
+facet normal -0.0 -0.0 1.0
+outer loop
+vertex 20.0 -20.0 20.0
+vertex 0.0 0.0 20.0
+vertex 0.0 -20.0 20.0
+endloop
+endfacet
+facet normal -0.0 -0.0 1.0
+outer loop
+vertex 0.0 0.0 20.0
+vertex 20.0 -20.0 20.0
+vertex 20.0 0.0 20.0
+endloop
+endfacet
+facet normal -1.0 0.0 0.0
+outer loop
+vertex 0.0 0.0 20.0
+vertex 0.0 -20.0 0.0
+vertex 0.0 -20.0 20.0
+endloop
+endfacet
+facet normal -1.0 0.0 0.0
+outer loop
+vertex 0.0 -20.0 0.0
+vertex 0.0 0.0 20.0
+vertex 0.0 0.0 0.0
+endloop
+endfacet
+facet normal -0.0 1.0 0.0
+outer loop
+vertex 0.0 0.0 20.0
+vertex 20.0 0.0 0.0
+vertex 0.0 0.0 0.0
+endloop
+endfacet
+facet normal -0.0 1.0 0.0
+outer loop
+vertex 20.0 0.0 0.0
+vertex 0.0 0.0 20.0
+vertex 20.0 0.0 20.0
+endloop
+endfacet
+endsolid model
diff --git a/tests/assets/mixed_shapes_robot.urdf b/tests/assets/mixed_shapes_robot.urdf
new file mode 100644
index 000000000..a4bb34e11
--- /dev/null
+++ b/tests/assets/mixed_shapes_robot.urdf
@@ -0,0 +1,112 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/assets/test_cube.urdf b/tests/assets/test_cube.urdf
new file mode 100644
index 000000000..1fe648cd6
--- /dev/null
+++ b/tests/assets/test_cube.urdf
@@ -0,0 +1,37 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py
index 1fffcde51..98601f91c 100644
--- a/tests/test_api_model_hw_parametrization.py
+++ b/tests/test_api_model_hw_parametrization.py
@@ -3,11 +3,16 @@
import jax
import jax.numpy as jnp
+import numpy as np
import pytest
import rod
import jaxsim.api as js
-from jaxsim.api.kin_dyn_parameters import HwLinkMetadata, ScalingFactors
+from jaxsim.api.kin_dyn_parameters import (
+ HwLinkMetadata,
+ LinkParametrizableShape,
+ ScalingFactors,
+)
from jaxsim.rbda.contacts import SoftContactsParams
from .utils import assert_allclose
@@ -851,6 +856,119 @@ def test_export_model_with_missing_collision(
assert exported_model.name == model.name(), "Exported model name should match"
# Verify we can build a JaxSim model from the exported URDF
+ _ = js.model.JaxSimModel.build_from_model_description(
+ model_description=exported_urdf, is_urdf=True
+ )
+
+
+def test_export_mesh_scaling_preserves_nonzero_visual_and_joint_origins(
+ tmp_path: pathlib.Path,
+):
+ """
+ Regression test for mesh export:
+ non-identity scaling must preserve non-zero visual/joint origins in the URDF.
+ """
+
+ mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl"
+ if not mesh_file.exists():
+ pytest.skip(f"Test mesh file not found: {mesh_file}")
+
+ urdf_path = tmp_path / "mesh_origin_regression.urdf"
+ urdf_path.write_text(
+ f"""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+""",
+ encoding="utf-8",
+ )
+
+ model = js.model.JaxSimModel.build_from_model_description(
+ model_description=urdf_path,
+ is_urdf=True,
+ parametrized_links=("mesh_link",),
+ )
+
+ mesh_link_idx = js.link.name_to_idx(model=model, link_name="mesh_link")
+ dims = jnp.ones((model.number_of_links(), 3))
+ dims = dims.at[mesh_link_idx].set(jnp.array([1.7, 0.8, 1.3]))
+ scaling = ScalingFactors(dims=dims, density=jnp.ones(model.number_of_links()))
+
+ updated_model = js.model.update_hw_parameters(model=model, scaling_factors=scaling)
+ exported_urdf = updated_model.export_updated_model()
+ root = ET.fromstring(exported_urdf)
+
+ visual_origin = root.find(".//link[@name='mesh_link']/visual/origin")
+ assert visual_origin is not None, "Mesh visual origin must exist in exported URDF"
+ visual_xyz = np.array([float(v) for v in visual_origin.get("xyz").split()])
+
+ expected_visual_xyz = np.array(
+ updated_model.kin_dyn_parameters.hw_link_metadata.L_H_vis[mesh_link_idx][:3, 3]
+ )
+ assert_allclose(
+ visual_xyz,
+ expected_visual_xyz,
+ atol=1e-8,
+ err_msg="Exported mesh visual origin does not match updated metadata",
+ )
+ assert not np.allclose(visual_xyz, np.zeros(3), atol=1e-12)
+
+ joint_origin = root.find(".//joint[@name='base_to_mesh']/origin")
+ assert joint_origin is not None, "Joint origin must exist in exported URDF"
+ joint_xyz = np.array([float(v) for v in joint_origin.get("xyz").split()])
+
+ joint_idx = js.joint.name_to_idx(model=updated_model, joint_name="base_to_mesh")
+ expected_joint_xyz = np.array(
+ updated_model.kin_dyn_parameters.joint_model.λ_H_pre[joint_idx + 1][:3, 3]
+ )
+ assert_allclose(
+ joint_xyz,
+ expected_joint_xyz,
+ atol=1e-8,
+ err_msg="Exported joint origin does not match updated joint transform",
+ )
+ assert not np.allclose(joint_xyz, np.zeros(3), atol=1e-12)
+
reimported_jaxsim_model = js.model.JaxSimModel.build_from_model_description(
model_description=exported_urdf, is_urdf=True
)
@@ -860,3 +978,127 @@ def test_export_model_with_missing_collision(
assert (
reimported_jaxsim_model.number_of_links() == model.number_of_links()
), "Reimported model should have same number of links"
+
+
+# =============================================================================
+# Mesh Scaling Tests
+# =============================================================================
+
+
+def test_mesh_shape_enum():
+ """Test that the Mesh shape type is available in the enum."""
+ assert hasattr(LinkParametrizableShape, "Mesh")
+ assert LinkParametrizableShape.Mesh == 3
+
+
+def test_mixed_shapes_metadata():
+ """Test loading and metadata verification for mixed primitive and mesh shapes."""
+ test_urdf = pathlib.Path(__file__).parent / "assets" / "mixed_shapes_robot.urdf"
+
+ if not test_urdf.exists():
+ pytest.skip(f"Test URDF not found: {test_urdf}")
+
+ mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl"
+ if not mesh_file.exists():
+ pytest.skip(f"Test mesh not found: {mesh_file}")
+
+ # Load model with all link types parametrized
+ model = js.model.JaxSimModel.build_from_model_description(
+ model_description=test_urdf,
+ is_urdf=True,
+ parametrized_links=("box_link", "cylinder_link", "mesh_link", "sphere_link"),
+ )
+
+ assert model.name() == "mixed_shapes_robot"
+ assert model.number_of_links() == 4
+
+ hw_meta = model.kin_dyn_parameters.hw_link_metadata
+
+ # Verify all 4 links are parametrized with correct shape types
+ assert len(hw_meta.link_shape) == 4
+ assert hw_meta.link_shape[0] == LinkParametrizableShape.Box
+ assert hw_meta.link_shape[1] == LinkParametrizableShape.Cylinder
+ assert hw_meta.link_shape[2] == LinkParametrizableShape.Mesh
+ assert hw_meta.link_shape[3] == LinkParametrizableShape.Sphere
+
+ # Verify mesh data exists only for mesh link
+ assert hw_meta.mesh_vertices is not None
+ assert hw_meta.mesh_vertices[0] is None # box
+ assert hw_meta.mesh_vertices[1] is None # cylinder
+ assert hw_meta.mesh_vertices[2] is not None # mesh
+ assert hw_meta.mesh_vertices[3] is None # sphere
+ assert hw_meta.mesh_faces is not None
+ assert hw_meta.mesh_faces[2] is not None # mesh link has faces
+
+
+def test_mixed_shapes_scaling():
+ """Test uniform and non-uniform scaling with mixed primitive and mesh shapes."""
+ test_urdf = pathlib.Path(__file__).parent / "assets" / "mixed_shapes_robot.urdf"
+
+ if not test_urdf.exists():
+ pytest.skip(f"Test URDF not found: {test_urdf}")
+
+ mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl"
+ if not mesh_file.exists():
+ pytest.skip(f"Test mesh not found: {mesh_file}")
+
+ model = js.model.JaxSimModel.build_from_model_description(
+ model_description=test_urdf,
+ is_urdf=True,
+ parametrized_links=("box_link", "cylinder_link", "mesh_link", "sphere_link"),
+ )
+
+ hw_meta = model.kin_dyn_parameters.hw_link_metadata
+ if len(hw_meta.link_shape) == 0:
+ pytest.skip("Hardware parametrization not supported")
+
+ # Get original masses
+ masses_orig = {}
+ for i in range(model.number_of_links()):
+ link_name = js.link.idx_to_name(model=model, link_index=i)
+ masses_orig[link_name] = float(model.kin_dyn_parameters.link_parameters.mass[i])
+
+ # Test uniform scaling (2x), so all links should scaled by 8x
+ uniform_scaling = ScalingFactors(
+ dims=jnp.ones((4, 3)) * 2.0,
+ density=jnp.ones(4),
+ )
+ scaled_uniform = js.model.update_hw_parameters(model, uniform_scaling)
+
+ for i in range(scaled_uniform.number_of_links()):
+ link_name = js.link.idx_to_name(model=scaled_uniform, link_index=i)
+ mass_scaled = float(scaled_uniform.kin_dyn_parameters.link_parameters.mass[i])
+ ratio = mass_scaled / masses_orig[link_name]
+ assert jnp.allclose(
+ ratio, 8.0, rtol=0.1
+ ), f"Uniform scaling: {link_name} expected 8x, got {ratio:.2f}x"
+
+ # Test different scaling factors per link
+ different_scaling = ScalingFactors(
+ dims=jnp.array(
+ [
+ [2.0, 2.0, 2.0], # box: 8x
+ [3.0, 3.0, 3.0], # cylinder: 27x
+ [1.5, 1.5, 1.5], # mesh: 3.375x
+ [2.5, 2.5, 2.5], # sphere: 15.625x
+ ]
+ ),
+ density=jnp.ones(4),
+ )
+ scaled_different = js.model.update_hw_parameters(model, different_scaling)
+
+ expected_ratios = {
+ "box_link": 8.0,
+ "cylinder_link": 27.0,
+ "mesh_link": 3.375,
+ "sphere_link": 15.625,
+ }
+
+ for i in range(scaled_different.number_of_links()):
+ link_name = js.link.idx_to_name(model=scaled_different, link_index=i)
+ mass_scaled = float(scaled_different.kin_dyn_parameters.link_parameters.mass[i])
+ ratio = mass_scaled / masses_orig[link_name]
+ expected = expected_ratios[link_name]
+ assert jnp.allclose(
+ ratio, expected, rtol=0.1
+ ), f"Different scaling: {link_name} expected {expected}x, got {ratio:.2f}x"