Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 69 additions & 6 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,33 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp
return W_H_LL


def _transform_M_block(M_body: jtp.Matrix, X: jtp.Matrix) -> jtp.Matrix:
"""
Apply invTᵀ M_body invT with invT = diag(X, I_n), without forming invT.

Args:
M_body: (6+n, 6+n) mass matrix (inverse) in body representation.
X: (6, 6) adjoint (e.g. B_X_W or B_X_BW).

Returns:
M_repr: (6+n, 6+n) mass matrix (inverse) in the new representation.
"""

# invTᵀ M invT with invT = diag(X, I):
# Mbb' = Xᵀ Mbb X
# Mbj' = Xᵀ Mbj
# Mjb' = Mjb X
# Mjj' = Mjj
Mbb_t = X.T @ M_body[:6, :6] @ X
Mbj_t = X.T @ M_body[:6, 6:]
Mjb_t = M_body[6:, :6] @ X
Mjj_t = M_body[6:, 6:]

top = jnp.concatenate([Mbb_t, Mbj_t], axis=1)
bottom = jnp.concatenate([Mjb_t, Mjj_t], axis=1)
return jnp.concatenate([top, bottom], axis=0)


@jax.jit
@js.common.named_scope
def free_floating_mass_matrix(
Expand All @@ -1535,18 +1562,54 @@ def free_floating_mass_matrix(
return M_body

case VelRepr.Inertial:
B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True)
invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
B_X_W = Adjoint.from_transform(transform=data.base_transform, inverse=True)

return invT.T @ M_body @ invT
return _transform_M_block(M_body, B_X_W)

case VelRepr.Mixed:
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
BW_H_B = data.base_transform.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))

return invT.T @ M_body @ invT
return _transform_M_block(M_body, B_X_BW)
case _:
raise ValueError(data.velocity_representation)


@jax.jit
@js.common.named_scope
def free_floating_mass_matrix_inverse(
model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the inverse of the free-floating mass matrix of the model
with the CRBA algorithm.

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The inverse of the free-floating mass matrix of the model.
"""
M_inv_body = jaxsim.rbda.mass_inverse(
model=model,
base_position=data.base_position,
base_quaternion=data.base_orientation,
joint_positions=data.joint_positions,
)

match data.velocity_representation:
case VelRepr.Body:
return M_inv_body
case VelRepr.Inertial:
W_X_B = Adjoint.from_transform(transform=data.base_transform)

return _transform_M_block(M_inv_body, W_X_B.T)
case VelRepr.Mixed:
B_H_BW = data.base_transform.at[0:3, 3].set(jnp.zeros(3))
BW_X_B = Adjoint.from_transform(transform=B_H_BW)

return _transform_M_block(M_inv_body, BW_X_B.T)
case _:
raise ValueError(data.velocity_representation)

Expand Down
1 change: 1 addition & 0 deletions src/jaxsim/rbda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
jacobian_full_doubly_left,
)
from .kinematic_constraints import compute_constraint_wrenches
from .mass_inverse import mass_inverse
from .rnea import rnea
5 changes: 1 addition & 4 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def compute_contact_forces(
)
)

M = js.model.free_floating_mass_matrix(model=model, data=data)
M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)

# Compute the linear part of the Jacobian of the collidable points
Jl_WC = jnp.vstack(
Expand All @@ -383,9 +383,6 @@ def compute_contact_forces(
),
)

# Compute the Delassus matrix for contacts (mixed representation).
M_inv = jnp.linalg.pinv(M)

# Compute the Delassus matrix directly using J and J̇.
G_contacts = Jl_WC @ M_inv @ Jl_WC.T

Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def compute_contact_forces(
# Compute kin-dyn quantities used in the contact model.
BW_ν = data.generalized_velocity

M = js.model.free_floating_mass_matrix(model=model, data=data)
M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)

J_WC = js.contact.jacobian(model=model, data=data)
J̇_WC = js.contact.jacobian_derivative(model=model, data=data)
Expand Down Expand Up @@ -329,7 +329,7 @@ def compute_contact_forces(
).flatten()

# Compute the Delassus matrix.
delassus_matrix = _delassus_matrix(M=M, J_WC=J_WC)
delassus_matrix = _delassus_matrix(M_inv=M_inv, J_WC=J_WC)

# Initialize regularization term of the Delassus matrix for
# better numerical conditioning.
Expand Down Expand Up @@ -460,14 +460,14 @@ def update_contact_state(

@staticmethod
def _delassus_matrix(
M: jtp.MatrixLike,
M_inv: jtp.MatrixLike,
J_WC: jtp.MatrixLike,
) -> jtp.Matrix:

sl = jnp.s_[:, 0:3, :]
J_WC_lin = jnp.vstack(J_WC[sl])

delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
delassus_matrix = J_WC_lin @ M_inv @ J_WC_lin.T
return delassus_matrix


Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/rbda/kinematic_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def compute_constraint_wrenches(
)

# Compute mass matrix
M = js.model.free_floating_mass_matrix(model=model, data=data)
M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)

W_H_constr_pairs = _compute_constraint_transforms_batched(
model=model,
Expand Down Expand Up @@ -287,7 +287,7 @@ def compute_constraint_wrenches(
J_constr = jnp.vstack(J_constr)

# Compute Delassus matrix for constraints
G_constraints = J_constr @ jnp.linalg.solve(M, J_constr.T)
G_constraints = J_constr @ M_inv @ J_constr.T

# Compute constraint acceleration
# TODO: add J̇_constr with efficient computation
Expand Down
Loading