diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index be3f88d5f..404586e8d 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1509,6 +1509,33 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp return W_H_LL +def _transform_M_block(M_body: jtp.Matrix, X: jtp.Matrix) -> jtp.Matrix: + """ + Apply invTᵀ M_body invT with invT = diag(X, I_n), without forming invT. + + Args: + M_body: (6+n, 6+n) mass matrix (inverse) in body representation. + X: (6, 6) adjoint (e.g. B_X_W or B_X_BW). + + Returns: + M_repr: (6+n, 6+n) mass matrix (inverse) in the new representation. + """ + + # invTᵀ M invT with invT = diag(X, I): + # Mbb' = Xᵀ Mbb X + # Mbj' = Xᵀ Mbj + # Mjb' = Mjb X + # Mjj' = Mjj + Mbb_t = X.T @ M_body[:6, :6] @ X + Mbj_t = X.T @ M_body[:6, 6:] + Mjb_t = M_body[6:, :6] @ X + Mjj_t = M_body[6:, 6:] + + top = jnp.concatenate([Mbb_t, Mbj_t], axis=1) + bottom = jnp.concatenate([Mjb_t, Mjj_t], axis=1) + return jnp.concatenate([top, bottom], axis=0) + + @jax.jit @js.common.named_scope def free_floating_mass_matrix( @@ -1535,18 +1562,54 @@ def free_floating_mass_matrix( return M_body case VelRepr.Inertial: - B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True) - invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) + B_X_W = Adjoint.from_transform(transform=data.base_transform, inverse=True) - return invT.T @ M_body @ invT + return _transform_M_block(M_body, B_X_W) case VelRepr.Mixed: - BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3)) + BW_H_B = data.base_transform.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - return invT.T @ M_body @ invT + return _transform_M_block(M_body, B_X_BW) + case _: + raise ValueError(data.velocity_representation) + + +@jax.jit +@js.common.named_scope +def free_floating_mass_matrix_inverse( + model: JaxSimModel, data: js.data.JaxSimModelData +) -> jtp.Matrix: + """ + Compute the inverse of the free-floating mass matrix of the model + with the CRBA algorithm. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + The inverse of the free-floating mass matrix of the model. + """ + M_inv_body = jaxsim.rbda.mass_inverse( + model=model, + base_position=data.base_position, + base_quaternion=data.base_orientation, + joint_positions=data.joint_positions, + ) + + match data.velocity_representation: + case VelRepr.Body: + return M_inv_body + case VelRepr.Inertial: + W_X_B = Adjoint.from_transform(transform=data.base_transform) + + return _transform_M_block(M_inv_body, W_X_B.T) + case VelRepr.Mixed: + B_H_BW = data.base_transform.at[0:3, 3].set(jnp.zeros(3)) + BW_X_B = Adjoint.from_transform(transform=B_H_BW) + return _transform_M_block(M_inv_body, BW_X_B.T) case _: raise ValueError(data.velocity_representation) diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 5e0af2a66..07ef33884 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -9,4 +9,5 @@ jacobian_full_doubly_left, ) from .kinematic_constraints import compute_constraint_wrenches +from .mass_inverse import mass_inverse from .rnea import rnea diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 0b08082ce..d35f64d85 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -367,7 +367,7 @@ def compute_contact_forces( ) ) - M = js.model.free_floating_mass_matrix(model=model, data=data) + M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) # Compute the linear part of the Jacobian of the collidable points Jl_WC = jnp.vstack( @@ -383,9 +383,6 @@ def compute_contact_forces( ), ) - # Compute the Delassus matrix for contacts (mixed representation). - M_inv = jnp.linalg.pinv(M) - # Compute the Delassus matrix directly using J and J̇. G_contacts = Jl_WC @ M_inv @ Jl_WC.T diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 568b6ec32..3d6cf5bea 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -294,7 +294,7 @@ def compute_contact_forces( # Compute kin-dyn quantities used in the contact model. BW_ν = data.generalized_velocity - M = js.model.free_floating_mass_matrix(model=model, data=data) + M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) J_WC = js.contact.jacobian(model=model, data=data) J̇_WC = js.contact.jacobian_derivative(model=model, data=data) @@ -329,7 +329,7 @@ def compute_contact_forces( ).flatten() # Compute the Delassus matrix. - delassus_matrix = _delassus_matrix(M=M, J_WC=J_WC) + delassus_matrix = _delassus_matrix(M_inv=M_inv, J_WC=J_WC) # Initialize regularization term of the Delassus matrix for # better numerical conditioning. @@ -460,14 +460,14 @@ def update_contact_state( @staticmethod def _delassus_matrix( - M: jtp.MatrixLike, + M_inv: jtp.MatrixLike, J_WC: jtp.MatrixLike, ) -> jtp.Matrix: sl = jnp.s_[:, 0:3, :] J_WC_lin = jnp.vstack(J_WC[sl]) - delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T + delassus_matrix = J_WC_lin @ M_inv @ J_WC_lin.T return delassus_matrix diff --git a/src/jaxsim/rbda/kinematic_constraints.py b/src/jaxsim/rbda/kinematic_constraints.py index 6cb6ee93b..9f0a17d38 100644 --- a/src/jaxsim/rbda/kinematic_constraints.py +++ b/src/jaxsim/rbda/kinematic_constraints.py @@ -254,7 +254,7 @@ def compute_constraint_wrenches( ) # Compute mass matrix - M = js.model.free_floating_mass_matrix(model=model, data=data) + M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) W_H_constr_pairs = _compute_constraint_transforms_batched( model=model, @@ -287,7 +287,7 @@ def compute_constraint_wrenches( J_constr = jnp.vstack(J_constr) # Compute Delassus matrix for constraints - G_constraints = J_constr @ jnp.linalg.solve(M, J_constr.T) + G_constraints = J_constr @ M_inv @ J_constr.T # Compute constraint acceleration # TODO: add J̇_constr with efficient computation diff --git a/src/jaxsim/rbda/mass_inverse.py b/src/jaxsim/rbda/mass_inverse.py new file mode 100644 index 000000000..dcd31be76 --- /dev/null +++ b/src/jaxsim/rbda/mass_inverse.py @@ -0,0 +1,233 @@ +import jax +import jax.numpy as jnp +import jaxlie + +import jaxsim.api as js +import jaxsim.typing as jtp + +from . import utils + + +def mass_inverse( + model: js.model.JaxSimModel, + *, + base_position: jtp.VectorLike, + base_quaternion: jtp.VectorLike, + joint_positions: jtp.VectorLike, +) -> jtp.Matrix: + """ + Compute the inverse of the mass matrix using an ABA-like algorithm. + The implementation follows the approach described in https://laas.hal.science/hal-01790934v2. + + Args: + model: The model to consider. + base_position: The position of the base link. + base_quaternion: The orientation of the base link (w, x, y, z). + joint_positions: The positions of the joints. + + Returns: + The inverse of the mass matrix. + """ + + W_p_B, W_Q_B, s, _, _, _, _, _, _, _ = utils.process_inputs( + model=model, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + ) + + # Get the 6D spatial inertia matrices of all links. + I_A = js.model.link_spatial_inertia_matrices(model=model) + + # Get the parent array λ(i). + # λ[0] ~ -1 (world) + # λ[i] = parent link index for link i. + λ = model.kin_dyn_parameters.parent_array + + # Compute the base transform. + W_H_B = jaxlie.SE3.from_rotation_and_translation( + rotation=jaxlie.SO3(wxyz=W_Q_B), + translation=W_p_B, + ) + + # Compute the parent-to-child adjoints of the joints. + # These transforms define the relative kinematics of the entire model, including + # the base transform for both floating-base and fixed-base models. + i_X_λi = model.kin_dyn_parameters.joint_transforms( + joint_positions=s, + base_transform=W_H_B.as_matrix(), + ) + + # Extract the joint motion subspaces. + S = model.kin_dyn_parameters.motion_subspaces + + NB = model.number_of_links() + N = model.number_of_joints() + + # Total generalized velocities: 6 base + N. + nv = N + 6 + + # Allocate buffers. + F = jnp.zeros((NB, 6, nv), dtype=float) + P = jnp.zeros((NB, 6, nv), dtype=float) + U = jnp.zeros((NB, 6), dtype=float) + D = jnp.zeros((NB,), dtype=float) + + # Pre-allocate mass matrix inverse + M_inv = jnp.zeros((nv, nv), dtype=float) + + # Pre-compute indices. + idx_fwd = jnp.arange(1, NB) + idx_rev = jnp.arange(NB - 1, 0, -1) + + # ============= + # Backward Pass + # ============= + + BackwardPassCarry = tuple[ + jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix + ] + backward_pass_carry: BackwardPassCarry = (I_A, F, U, D, M_inv) + + def loop_backward_pass( + carry: BackwardPassCarry, i: jtp.Int + ) -> tuple[BackwardPassCarry, None]: + I_A, F, U, D, M_inv = carry + + Si = jnp.squeeze(S[i], axis=-1) + Fi = F[i] + Xi = i_X_λi[i] + parent = λ[i] + + Ui = I_A[i] @ Si + Di = jnp.dot(Si, Ui) + + U = U.at[i].set(Ui) + D = D.at[i].set(Di) + + # Row index in ν for joint i: 6 + (i - 1) + r = 6 + (i - 1) + + Minv_row = M_inv[r] + + # Diagonal element + Minv_row = Minv_row.at[r].add(1.0 / Di) + + # Off-diagonals: Minv[r,:] -= (1/Di) * Sᵢᵀ Fᵢ + sTFi = jnp.einsum("s,sn->n", Si, Fi) + Minv_row = Minv_row - sTFi / Di + + M_inv = M_inv.at[r].set(Minv_row) + + # Propagate to parent if any (parent >= 0) + def propagate(IA_F): + I_A_, F_ = IA_F + + Ui_col = Ui[:, None] + + # F_a_i = F_i + U_i * Minv[r,:] + Fa_i = Fi + Ui_col @ Minv_row[None, :] + + # F_parent += Xᵢᵀ F_a_i + F_parent_new = F_[parent] + Xi.T @ Fa_i + F_ = F_.at[parent].set(F_parent_new) + + # I_a_i = IAi - U_i D_i^{-1} U_iᵀ + Ia_i = I_A[i] - jnp.outer(Ui, Ui) / Di + + # I_A[parent] += Xᵢᵀ I_a_i Xᵢ + I_parent_new = I_A_[parent] + Xi.T @ Ia_i @ Xi + I_A_ = I_A_.at[parent].set(I_parent_new) + + return I_A_, F_ + + I_A, F = jax.lax.cond( + parent >= 0, + propagate, + lambda IA_F: IA_F, + (I_A, F), + ) + + return (I_A, F, U, D, M_inv), None + + (I_A, F, U, D, M_inv), _ = jax.lax.scan( + loop_backward_pass, backward_pass_carry, idx_rev + ) + + S0 = jnp.eye(6, dtype=float) + U0 = I_A[0] @ S0 + D0 = S0.T @ U0 + D0_inv = jnp.linalg.inv(D0) + + # Base rows 0..5 in ν + base_rows = slice(0, 6) + + # Diagonal base block + M_inv = M_inv.at[base_rows, base_rows].add(D0_inv) + + # Off-diagonal base contribution: M_inv[base,:] -= D0^{-T} F[0] + term0 = D0_inv.T @ F[0] + M_inv = M_inv.at[base_rows, :].add(-term0) + + # ============ + # Forward Pass + # ============ + + # Initialize P_0 = S0 * Minv[base,:] = I * Minv[base,:] + Minv_base = M_inv[base_rows, :] + P = P.at[0].set(Minv_base) + + ForwardPassCarry = tuple[jtp.Matrix, jtp.Matrix] + forward_pass_carry: ForwardPassCarry = (M_inv, P) + + def loop_forward_pass( + carry: ForwardPassCarry, i: jtp.Int + ) -> tuple[ForwardPassCarry, None]: + M_inv, P = carry + + Si = jnp.squeeze(S[i], axis=-1) + Ui = U[i] + Di = D[i] + Xi = i_X_λi[i] + parent = λ[i] + + P_parent = jax.lax.cond( + parent >= 0, + lambda P_: P_[parent], + lambda P_: jnp.zeros_like(P_[i]), + P, + ) + + # Row index in ν for joint i + r = 6 + (i - 1) + + # Row update: M_inv[r,:] -= D_i^{-1} U_iᵀ Xᵢ P_parent + def update_row(Minv_): + X_P = Xi @ P_parent + UiT_XP = jnp.einsum("s,sn->n", Ui, X_P) + Minv_row = Minv_[r, :] - UiT_XP / Di + return Minv_.at[r, :].set(Minv_row) + + M_inv = jax.lax.cond( + parent >= 0, + update_row, + lambda Minv_: Minv_, + M_inv, + ) + + Minv_row = M_inv[r, :] + + # P_i = S_i Minv[r,:] + Xᵢ P_parent + Pi = jnp.expand_dims(Si, 1) @ jnp.expand_dims(Minv_row, 0) + Pi = Pi + Xi @ P_parent + + P = P.at[i].set(Pi) + + return (M_inv, P), None + + (M_inv, P), _ = jax.lax.scan(loop_forward_pass, forward_pass_carry, idx_fwd) + + # Symmetrize numerically + M_inv = 0.5 * (M_inv + M_inv.T) + + return M_inv diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 65415fafb..6fef83cf7 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -326,6 +326,11 @@ def test_model_rbda( ) assert_allclose(Jν_idt, Jν_js) + # Mass matrix inverse via RBDA + M_inv_js = js.model.free_floating_mass_matrix_inverse(model=model, data=data) + M_inv_idt = jnp.linalg.inv(M_idt) + assert_allclose(M_inv_idt[sl, sl], M_inv_js[sl, sl]) + def test_model_jacobian( jaxsim_models_types: js.model.JaxSimModel,