Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ ignore = [
channels = ["conda-forge"]
platforms = ["linux-64", "linux-aarch64", "osx-arm64", "osx-64"]
requires-pixi = ">=0.39.0"
preview = ["pixi-build"]

[tool.pixi.environments]
# We resolve only two groups: cpu and gpu.
Expand Down
199 changes: 186 additions & 13 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,23 +916,29 @@ class LinkParametrizableShape:
Box: ClassVar[int] = 0
Cylinder: ClassVar[int] = 1
Sphere: ClassVar[int] = 2
Mesh: ClassVar[int] = 3


@jax_dataclasses.pytree_dataclass
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class HwLinkMetadata(JaxsimDataclass):
"""
Class storing the hardware parameters of a link.

Attributes:
link_shape: The shape of the link.
0 = box, 1 = cylinder, 2 = sphere, -1 = unsupported.
geometry: The dimensions of the link.
box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0].
0 = box, 1 = cylinder, 2 = sphere, 3 = mesh, -1 = unsupported.
geometry: Shape parameters used by HW parametrization.
box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0],
mesh: cumulative anisotropic scale factors [sx,sy,sz] (initialized to [1,1,1]).
density: The density of the link.
L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G.
L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame.
L_H_pre_mask: The mask indicating the link's child joint indices.
L_H_pre: The homogeneous transforms for child joints.
mesh_vertices: The original centered mesh vertices (Nx3) for mesh shapes, None otherwise.
mesh_faces: The mesh triangle faces (Mx3 integer indices) for mesh shapes, None otherwise.
mesh_offset: The original mesh centroid offset (3D vector) for mesh shapes, None otherwise.
mesh_uri: The path to the mesh file for reference, None otherwise.
"""

link_shape: jtp.Vector
Expand All @@ -942,6 +948,10 @@ class HwLinkMetadata(JaxsimDataclass):
L_H_vis: jtp.Matrix
L_H_pre_mask: jtp.Vector
L_H_pre: jtp.Matrix
mesh_vertices: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_faces: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_offset: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_uri: Static[tuple[str | None, ...] | None]

@classmethod
def empty(cls) -> HwLinkMetadata:
Expand All @@ -954,7 +964,83 @@ def empty(cls) -> HwLinkMetadata:
L_H_vis=jnp.array([], dtype=float),
L_H_pre_mask=jnp.array([], dtype=bool),
L_H_pre=jnp.array([], dtype=float),
mesh_vertices=None,
mesh_faces=None,
mesh_offset=None,
mesh_uri=None,
)

@staticmethod
def compute_mesh_inertia(
vertices: jtp.Matrix, faces: jtp.Matrix, density: jtp.Float
) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
"""
Compute mass, center of mass, and inertia tensor from mesh geometry.

Uses the divergence theorem to compute volumetric properties by integrating
over tetrahedra formed between the mesh surface and the origin.

Args:
vertices: Mesh vertices (Nx3) in the link frame, should be centered.
faces: Triangle face indices (Mx3), integer indices into vertices array.
density: Material density.

Returns:
A tuple containing the computed mass, the CoM position and the 3x3
inertia tensor at the CoM.
"""

# Extract triangles from vertices using face indices
triangles = vertices[faces.astype(int)]
A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]

# Compute signed volume of tetrahedra relative to origin
# vol = 1/6 * (A . (B x C))
tetrahedron_volumes = jnp.sum(A * jnp.cross(B, C), axis=1) / 6.0

total_signed_volume = jnp.sum(tetrahedron_volumes)

# Normalize the global winding sign so positive density yields non-negative mass.
orientation_sign = jnp.where(total_signed_volume < 0, -1.0, 1.0)
tetrahedron_volumes = tetrahedron_volumes * orientation_sign
total_volume = jnp.sum(tetrahedron_volumes)

eps = jnp.asarray(1e-12, dtype=total_volume.dtype)
is_valid_volume = jnp.abs(total_volume) > eps
safe_total_volume = jnp.where(is_valid_volume, total_volume, 1.0)
mass = jnp.where(is_valid_volume, total_volume * density, 0.0)

# Compute center of mass
tet_coms = (A + B + C) / 4.0
com_position = jnp.where(
is_valid_volume,
jnp.sum(tet_coms * tetrahedron_volumes[:, None], axis=0)
/ safe_total_volume,
jnp.zeros(3, dtype=vertices.dtype),
)

# Compute inertia tensor with covariance approach
def compute_tetrahedron_covariance(a, b, c, vol):
s = a + b + c
return (vol / 20.0) * (
jnp.outer(a, a) + jnp.outer(b, b) + jnp.outer(c, c) + jnp.outer(s, s)
)

covariance_matrices = jax.vmap(compute_tetrahedron_covariance)(
A, B, C, tetrahedron_volumes
)
Σ_origin = jnp.sum(covariance_matrices, axis=0)

# Shift to CoM using parallel axis theorem
Σ_com = Σ_origin * density - mass * jnp.outer(com_position, com_position)

# Convert covariance to inertia tensor
I_com = jnp.trace(Σ_com) * jnp.eye(3, dtype=vertices.dtype) - Σ_com
I_com = jnp.where(
is_valid_volume, I_com, jnp.zeros((3, 3), dtype=vertices.dtype)
)

return mass, com_position, I_com

@staticmethod
def compute_mass_and_inertia(
Expand Down Expand Up @@ -1015,16 +1101,90 @@ def sphere(dims, density) -> tuple[jtp.Float, jtp.Matrix]:

return mass, inertia

def compute_mass_inertia(shape_idx, dims, density):
return jax.lax.switch(shape_idx, (box, cylinder, sphere), dims, density)
def compute_mass_inertia_primitive(shape_idx, dims, density):
def unsupported_case(_):
return (
jnp.asarray(0.0, dtype=density.dtype),
jnp.zeros((3, 3), dtype=density.dtype),
)

mass, inertia = jax.vmap(compute_mass_inertia)(
hw_link_metadata.link_shape,
hw_link_metadata.geometry,
hw_link_metadata.density,
def supported_case(idx):
return jax.lax.switch(idx, (box, cylinder, sphere), dims, density)

return jax.lax.cond(
shape_idx < 0, unsupported_case, supported_case, shape_idx
)

# For models with mesh data, we need to handle Static heterogeneous mesh arrays
has_mesh_data = (
hw_link_metadata.mesh_vertices is not None
and hw_link_metadata.mesh_faces is not None
)

return mass, inertia
if has_mesh_data:
mesh_verts_tuple = hw_link_metadata.mesh_vertices
mesh_faces_tuple = hw_link_metadata.mesh_faces
n_links = len(mesh_verts_tuple)

# Build per-link compute functions that capture mesh data in closures
# This loop runs once at trace time to build the computation graph
compute_fns = []
for i in range(n_links):
if mesh_verts_tuple[i] is not None:
# Capture this link's mesh data in the closure
verts_data = jnp.array(mesh_verts_tuple[i].get())
faces_data = jnp.array(mesh_faces_tuple[i].get())

def make_fn(verts, faces):
def link_fn(shape, dims, density):
def mesh_branch():
scaled_vertices = verts * dims
mass_m, _, inertia_m = (
HwLinkMetadata.compute_mesh_inertia(
scaled_vertices, faces, density
)
)
return mass_m, inertia_m

def primitive_branch():
return compute_mass_inertia_primitive(
shape, dims, density
)

return jax.lax.cond(
shape == LinkParametrizableShape.Mesh,
mesh_branch,
primitive_branch,
)

return link_fn

compute_fns.append(make_fn(verts_data, faces_data))
else:
# No mesh data for this link
def link_fn(shape, dims, density):
return compute_mass_inertia_primitive(shape, dims, density)

compute_fns.append(link_fn)

def compute_single_link(link_idx, shape, dims, density):
return jax.lax.switch(link_idx, compute_fns, shape, dims, density)

masses, inertias = jax.vmap(compute_single_link)(
jnp.arange(n_links),
hw_link_metadata.link_shape,
hw_link_metadata.geometry,
hw_link_metadata.density,
)
else:
# No mesh data - pure primitives with simple vmap
masses, inertias = jax.vmap(compute_mass_inertia_primitive)(
hw_link_metadata.link_shape,
hw_link_metadata.geometry,
hw_link_metadata.density,
)

return masses, inertias

@staticmethod
def _convert_scaling_to_3d_vector(
Expand All @@ -1034,7 +1194,7 @@ def _convert_scaling_to_3d_vector(
Convert scaling factors for specific shape dimensions into a 3D scaling vector.

Args:
link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder).
link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder, mesh).
scaling_factors: The scaling factors for the shape dimensions.

Returns:
Expand All @@ -1045,17 +1205,20 @@ def _convert_scaling_to_3d_vector(
- Box: [lx, ly, lz]
- Cylinder: [r, r, l]
- Sphere: [r, r, r]
- Mesh: [sx, sy, sz]
"""

# Index mapping for each shape type (link_shapes x 3 dims)
# Box: [lx, ly, lz] -> [0, 1, 2]
# Cylinder: [r, r, l] -> [0, 0, 1]
# Sphere: [r, r, r] -> [0, 0, 0]
# Mesh: [sx, sy, sz] -> [0, 1, 2]
shape_indices = jnp.array(
[
[0, 1, 2], # Box
[0, 0, 1], # Cylinder
[0, 0, 0], # Sphere
[0, 1, 2], # Mesh
]
)

Expand Down Expand Up @@ -1117,9 +1280,19 @@ def box(parent_idx, L_p_C):
]
)

def mesh(parent_idx, L_p_C):
sx, sy, sz = scaling_factors.dims[parent_idx]
return jnp.hstack(
[
L_p_C[0] * sx,
L_p_C[1] * sy,
L_p_C[2] * sz,
]
)

new_positions = jax.vmap(
lambda shape_idx, parent_idx, L_p_C: jax.lax.switch(
shape_idx, (box, cylinder, sphere), parent_idx, L_p_C
shape_idx, (box, cylinder, sphere, mesh), parent_idx, L_p_C
)
)(
parent_link_shapes,
Expand Down
Loading