diff --git a/examples/jaxsim_as_multibody_dynamics_library.ipynb b/examples/jaxsim_as_multibody_dynamics_library.ipynb index 4c73f469b..54d21fb01 100644 --- a/examples/jaxsim_as_multibody_dynamics_library.ipynb +++ b/examples/jaxsim_as_multibody_dynamics_library.ipynb @@ -261,7 +261,7 @@ "\n", "# Print the default state.\n", "W_H_B, s = data.generalized_position\n", - "ν = data.generalized_velocity\n", + "ν = data.generalized_velocity()\n", "\n", "print(f\"W_H_B: shape={W_H_B.shape}\\n{W_H_B}\\n\")\n", "print(f\"s: shape={s.shape}\\n{s}\\n\")\n", @@ -304,7 +304,9 @@ ")\n", "\n", "print(f\"link_forces: shape={references.link_forces(model=model, data=data).shape}\")\n", - "print(f\"joint_force_references: shape={references.joint_force_references(model=model).shape}\")" + "print(\n", + " f\"joint_force_references: shape={references.joint_force_references(model=model).shape}\"\n", + ")" ] }, { @@ -382,9 +384,15 @@ "# @title Link 6D Velocity\n", "\n", "# JaxSim allows to select the so-called representation of the frame velocity.\n", - "L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body)\n", - "LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed)\n", - "W_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial)\n", + "L_v_WL = js.link.velocity(\n", + " model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body\n", + ")\n", + "LW_v_WL = js.link.velocity(\n", + " model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed\n", + ")\n", + "W_v_WL = js.link.velocity(\n", + " model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial\n", + ")\n", "\n", "print(f\"Body-fixed velocity L_v_WL={L_v_WL}\")\n", "print(f\"Mixed velocity: LW_v_WL={LW_v_WL}\")\n", @@ -395,17 +403,21 @@ "# the velocity representation of ν, and an output velocity representation that\n", "# corresponds to the velocity representation of the desired 6D velocity.\n", "\n", - "# You can use the following context manager to easily switch between representations.\n", - "with data.switch_velocity_representation(VelRepr.Body):\n", - "\n", - " # Body-fixed generalized velocity.\n", - " B_ν = data.generalized_velocity\n", + "# You can set the output velocity representation of quantities depending\n", + "# on the velocity representation stored in data by passing an argument\n", + "# to the function, as shown below.\n", + "# Body-fixed generalized velocity.\n", + "B_ν = data.generalized_velocity(output_representation=VelRepr.Body)\n", "\n", - " # Free-floating Jacobian accepting a body-fixed generalized velocity and\n", - " # returning an inertial-fixed link velocity.\n", - " W_J_WL_B = js.link.jacobian(\n", - " model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial\n", - " )\n", + "# Free-floating Jacobian accepting a body-fixed generalized velocity and\n", + "# returning an inertial-fixed link velocity.\n", + "W_J_WL_B = js.link.jacobian(\n", + " model=model,\n", + " data=data,\n", + " link_index=link_index,\n", + " input_representation=VelRepr.Body,\n", + " output_vel_repr=VelRepr.Inertial,\n", + ")\n", "\n", "# Now the following relation should hold.\n", "assert jnp.allclose(W_v_WL, W_J_WL_B @ B_ν)" @@ -455,9 +467,15 @@ "# @title Frame 6D Velocity\n", "\n", "# JaxSim allows to select the so-called representation of the frame velocity.\n", - "F_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body)\n", - "FW_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed)\n", - "W_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial)\n", + "F_v_WF = js.frame.velocity(\n", + " model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body\n", + ")\n", + "FW_v_WF = js.frame.velocity(\n", + " model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed\n", + ")\n", + "W_v_WF = js.frame.velocity(\n", + " model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\n", + ")\n", "\n", "print(f\"Body-fixed velocity F_v_WF={F_v_WF}\")\n", "print(f\"Mixed velocity: FW_v_WF={FW_v_WF}\")\n", @@ -468,17 +486,21 @@ "# the velocity representation of ν, and an output velocity representation that\n", "# corresponds to the velocity representation of the desired 6D velocity.\n", "\n", - "# You can use the following context manager to easily switch between representations.\n", - "with data.switch_velocity_representation(VelRepr.Body):\n", + "# You can pass the output representation to getters of quantities depending\n", + "# on the velocity representation stored in data by passing an argument.\n", "\n", - " # Body-fixed generalized velocity.\n", - " B_ν = data.generalized_velocity\n", + "# Body-fixed generalized velocity.\n", + "B_ν = data.generalized_velocity(VelRepr.Body)\n", "\n", - " # Free-floating Jacobian accepting a body-fixed generalized velocity and\n", - " # returning an inertial-fixed link velocity.\n", - " W_J_WF_B = js.frame.jacobian(\n", - " model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\n", - " )\n", + "# Free-floating Jacobian accepting a body-fixed generalized velocity and\n", + "# returning an inertial-fixed link velocity.\n", + "W_J_WF_B = js.frame.jacobian(\n", + " model=model,\n", + " data=data,\n", + " frame_index=frame_index,\n", + " input_representation=VelRepr.Body,\n", + " output_vel_repr=VelRepr.Inertial,\n", + ")\n", "\n", "# Now the following relation should hold.\n", "assert jnp.allclose(W_v_WF, W_J_WF_B @ B_ν)" @@ -662,7 +684,7 @@ " joint_accelerations=s̈,\n", " # To check that f_B works, let's remove the force applied\n", " # to the base link from the link forces.\n", - " link_forces=f_L.at[0].set(jnp.zeros(6))\n", + " link_forces=f_L.at[0].set(jnp.zeros(6)),\n", ")\n", "\n", "print(f\"f_B: shape={f_B.shape}\")\n", @@ -825,30 +847,27 @@ }, "outputs": [], "source": [ - "with (\n", - " data.switch_velocity_representation(VelRepr.Mixed),\n", - " references.switch_velocity_representation(VelRepr.Mixed),\n", - "):\n", - "\n", - " # Compute the mixed generalized velocity.\n", - " BW_ν = data.generalized_velocity\n", - "\n", - " # Compute the mixed generalized acceleration.\n", - " BW_ν̇ = jnp.hstack(\n", - " js.model.forward_dynamics(\n", - " model=model,\n", - " data=data,\n", - " link_forces=references.link_forces(model=model, data=data),\n", - " joint_forces=references.joint_force_references(model=model),\n", - " )\n", + "# Compute the mixed generalized velocity.\n", + "BW_ν = data.generalized_velocity()\n", + "\n", + "# Compute the mixed generalized acceleration.\n", + "BW_ν̇ = jnp.hstack(\n", + " js.model.forward_dynamics(\n", + " model=model,\n", + " data=data,\n", + " link_forces=references.link_forces(\n", + " model=model, data=data, output_representation=VelRepr.Mixed\n", + " ),\n", + " joint_forces=references.joint_force_references(model=model),\n", " )\n", + ")\n", "\n", - " # Compute the mass matrix in mixed representation.\n", - " BW_M = js.model.free_floating_mass_matrix(model=model, data=data)\n", + "# Compute the mass matrix in mixed representation.\n", + "BW_M = js.model.free_floating_mass_matrix(model=model, data=data)\n", "\n", - " # Compute the contact Jacobian and its derivative.\n", - " Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]\n", - " J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\n", + "# Compute the contact Jacobian and its derivative.\n", + "Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]\n", + "J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\n", "\n", "# Compute the Delassus matrix.\n", "Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(Jl_WC).T)[0]\n", @@ -860,9 +879,8 @@ "print(f\"W_H_C: shape={W_H_C.shape}\")\n", "\n", "# Compute the linear velocity of the collidable points.\n", - "with data.switch_velocity_representation(VelRepr.Mixed):\n", - " W_ṗ_B = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]\n", - " print(f\"W_ṗ_B: shape={W_ṗ_B.shape}\")\n", + "W_ṗ_B = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]\n", + "print(f\"W_ṗ_B: shape={W_ṗ_B.shape}\")\n", "\n", "# Compute the linear acceleration of the collidable points.\n", "W_p̈_C = 0\n", diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index ee85d078b..56489dc3b 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -1,3 +1,5 @@ +import functools + import jax import jax.numpy as jnp @@ -76,10 +78,13 @@ def com_linear_velocity( return G_vl_WG -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def centroidal_momentum( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Vector: r""" Compute the centroidal momentum of the model. @@ -87,6 +92,9 @@ def centroidal_momentum( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The representation of the output centroidal momentum. If None, the active + velocity representation of the data is used. Returns: The centroidal momentum of the model. @@ -98,16 +106,27 @@ def centroidal_momentum( and :math:`C = B` if the active velocity representation is body-fixed. """ - ν = data.generalized_velocity - G_J = centroidal_momentum_jacobian(model=model, data=data) + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + + ν = data.generalized_velocity(output_representation) + G_J = centroidal_momentum_jacobian( + model=model, data=data, output_representation=output_representation + ) return G_J @ ν -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def centroidal_momentum_jacobian( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the Jacobian of the centroidal momentum of the model. @@ -115,19 +134,28 @@ def centroidal_momentum_jacobian( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The representation of the output Jacobian. If None, the active + velocity representation of the data is used. Returns: The Jacobian of the centroidal momentum of the model. Note: The frame corresponding to the output representation of this Jacobian is either - :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed, - or :math:`G[B]`, if the active velocity representation is body-fixed. + :math:`G[W]`, if the selected velocity representation is inertial-fixed or mixed, + or :math:`G[B]`, if the selected velocity representation is body-fixed. Note: This Jacobian is also known in the literature as Centroidal Momentum Matrix. """ + output_representation = ( + output_representation + if output_representation is not None + else data.velocity_representation + ) + # Compute the Jacobian of the total momentum with body-fixed output representation. # We convert the output representation either to G[W] or G[B] below. B_Jh = js.model.total_momentum_jacobian( @@ -139,13 +167,13 @@ def centroidal_momentum_jacobian( W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: + match output_representation: case VelRepr.Inertial | VelRepr.Mixed: W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841 case VelRepr.Body: W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841 case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) # Compute the transform for 6D forces. G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T @@ -153,10 +181,13 @@ def centroidal_momentum_jacobian( return G_Xf_B @ B_Jh -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def locked_centroidal_spatial_inertia( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ): """ Compute the locked centroidal spatial inertia of the model. @@ -164,24 +195,34 @@ def locked_centroidal_spatial_inertia( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The representation of the output spatial inertia. If None, the active + velocity representation of the data is used. Returns: The locked centroidal spatial inertia of the model. """ - with data.switch_velocity_representation(VelRepr.Body): - B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data) + output_representation = ( + output_representation + if output_representation is not None + else data.velocity_representation + ) + + B_Mbb_B = js.model.locked_spatial_inertia( + model=model, data=data, output_representation=VelRepr.Body + ) W_H_B = data._base_transform W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: + match output_representation: case VelRepr.Inertial | VelRepr.Mixed: W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841 case VelRepr.Body: W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841 case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G @@ -191,10 +232,13 @@ def locked_centroidal_spatial_inertia( return G_Xf_B @ B_Mbb_B @ B_Xv_G -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def average_centroidal_velocity( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Vector: r""" Compute the average centroidal velocity of the model. @@ -202,6 +246,9 @@ def average_centroidal_velocity( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The representation of the output average centroidal velocity. If None, the active + velocity representation of the data is used. Returns: The average centroidal velocity of the model. @@ -213,16 +260,27 @@ def average_centroidal_velocity( and :math:`[C] = [B]` if the active velocity representation is body-fixed. """ - ν = data.generalized_velocity - G_J = average_centroidal_velocity_jacobian(model=model, data=data) + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + + ν = data.generalized_velocity(output_representation) + G_J = average_centroidal_velocity_jacobian( + model=model, data=data, output_representation=output_representation + ) return G_J @ ν -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def average_centroidal_velocity_jacobian( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the Jacobian of the average centroidal velocity of the model. @@ -230,6 +288,9 @@ def average_centroidal_velocity_jacobian( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The representation of the output Jacobian. If None, the active + velocity representation of the data is used. Returns: The Jacobian of the average centroidal velocity of the model. @@ -240,16 +301,29 @@ def average_centroidal_velocity_jacobian( or :math:`G[B]`, if the active velocity representation is body-fixed. """ - G_J = centroidal_momentum_jacobian(model=model, data=data) - G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data) + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + + G_J = centroidal_momentum_jacobian( + model=model, data=data, output_representation=output_representation + ) + G_Mbb = locked_centroidal_spatial_inertia( + model=model, data=data, output_representation=output_representation + ) return jnp.linalg.inv(G_Mbb) @ G_J -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def bias_acceleration( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Vector: r""" Compute the bias linear acceleration of the center of mass. @@ -257,6 +331,9 @@ def bias_acceleration( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The representation of the output bias acceleration. If None, the active + velocity representation of the data is used. Returns: The bias linear acceleration of the center of mass in the active representation. @@ -268,12 +345,20 @@ def bias_acceleration( and :math:`[C] = [B]` if the active velocity representation is body-fixed. """ + output_representation = ( + output_representation + if output_representation is not None + else data.velocity_representation + ) + # Compute the pose of all links with forward kinematics. W_H_L = data._link_transforms # Compute the bias acceleration of all links by zeroing the generalized velocity # in the active representation. - v̇_bias_WL = js.model.link_bias_accelerations(model=model, data=data) + v̇_bias_WL = js.model.link_bias_accelerations( + model=model, data=data, output_representation=output_representation + ) def other_representation_to_body( C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector @@ -291,7 +376,7 @@ def other_representation_to_body( # We need here to get the body-fixed bias acceleration of the links. # Since it's computed in the active representation, we need to convert it to body. - match data.velocity_representation: + match output_representation: case VelRepr.Body: L_a_bias_WL = v̇_bias_WL @@ -305,7 +390,10 @@ def other_representation_to_body( L_v_LC = L_v_LW = jax.vmap( # noqa: F841 lambda i: -js.link.velocity( - model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body + model=model, + data=data, + link_index=i, + output_vel_repr=VelRepr.Body, ) )(jnp.arange(model.number_of_links())) @@ -324,7 +412,10 @@ def other_representation_to_body( C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841 lambda i: js.link.velocity( - model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed + model=model, + data=data, + link_index=i, + output_vel_repr=VelRepr.Mixed, ) .at[3:6] .set(jnp.zeros(3)) @@ -354,7 +445,7 @@ def other_representation_to_body( )(jnp.arange(model.number_of_links())) case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) # Compute the bias of the 6D momentum derivative. def bias_momentum_derivative_term( @@ -392,7 +483,7 @@ def bias_momentum_derivative_term( # Compute the position of the CoM. W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: + match output_representation: # G := G[W] = (W_p_CoM, [W]) case VelRepr.Inertial | VelRepr.Mixed: @@ -418,4 +509,4 @@ def bias_momentum_derivative_term( return GB_v̇l_com_bias case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 6c49ba6ee..50eb46518 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -69,8 +69,23 @@ def switch_velocity_representation( Yields: The same object with the new velocity representation. + + Warning: + This context manager is deprecated. Use explicit `output_representation` + parameters on methods like `base_velocity()` and `generalized_velocity()` + instead. """ + import warnings + + warnings.warn( + "switch_velocity_representation() context manager is deprecated. " + "Use explicit input_representation or output_representation parameters instead, e.g., " + "data.base_velocity(output_representation=VelRepr.Inertial)", + DeprecationWarning, + stacklevel=2, + ) + original_representation = self.velocity_representation try: diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index b56c75fab..6ecd2b4d0 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -255,12 +255,13 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) @js.common.named_scope def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Array: r""" @@ -269,6 +270,8 @@ def jacobian( Args: model: The model to consider. data: The data of the considered model. + input_representation: + The input velocity representation of data. output_vel_repr: The output velocity representation of the free-floating jacobian. @@ -283,8 +286,14 @@ def jacobian( rigidly attached to. """ + input_representation = ( + input_representation + if input_representation is not None + else data.velocity_representation + ) + output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) # Get the indices of the enabled collidable points. @@ -298,7 +307,10 @@ def jacobian( # Compute the Jacobians of all links. W_J_WL = js.model.generalized_free_floating_jacobian( - model=model, data=data, output_vel_repr=VelRepr.Inertial + model=model, + data=data, + input_representation=input_representation, + output_vel_repr=VelRepr.Inertial, ) # Compute the contact Jacobian. @@ -348,12 +360,13 @@ def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: return O_J_WC -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) @js.common.named_scope def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" @@ -362,6 +375,8 @@ def jacobian_derivative( Args: model: The model to consider. data: The data of the considered model. + input_representation: + The input velocity representation of data. output_vel_repr: The output velocity representation of the free-floating jacobian derivative. @@ -373,8 +388,14 @@ def jacobian_derivative( velocity representation. """ + input_representation = ( + input_representation + if input_representation is not None + else data.velocity_representation + ) + output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) indices_of_enabled_collidable_points = ( @@ -412,7 +433,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: # Compute the operator to change the representation of ν, and its # time derivative. - match data.velocity_representation: + match input_representation: case VelRepr.Inertial: W_H_W = jnp.eye(4) W_X_W = Adjoint.from_transform(transform=W_H_W) @@ -424,18 +445,18 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: case VelRepr.Body: W_H_B = data._base_transform W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity + B_v_WB = data.base_velocity(input_representation) B_vx_WB = Cross.vx(B_v_WB) - W_Ẋ_B = W_X_B @ B_vx_WB + W_Ẋ_B = W_X_B @ B_vx_WB T = compute_T(model=model, X=W_X_B) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) case VelRepr.Mixed: W_H_B = data._base_transform W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity + BW_v_WB = data.base_velocity(input_representation) BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_vx_W_BW = Cross.vx(BW_v_W_BW) W_Ẋ_BW = W_X_BW @ BW_vx_W_BW @@ -444,23 +465,24 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) case _: - raise ValueError(data.velocity_representation) + raise ValueError(input_representation) # ===================================================== # Compute quantities to adjust the output representation # ===================================================== - with data.switch_velocity_representation(VelRepr.Inertial): - # Compute the Jacobian of the parent link in inertial representation. - W_J_WL_W = js.model.generalized_free_floating_jacobian( - model=model, - data=data, - ) - # Compute the Jacobian derivative of the parent link in inertial representation. - W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative( - model=model, - data=data, - ) + # Compute the Jacobian of the parent link in inertial representation. + W_J_WL_W = js.model.generalized_free_floating_jacobian( + model=model, + data=data, + input_representation=VelRepr.Inertial, + ) + # Compute the Jacobian derivative of the parent link in inertial representation. + W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative( + model=model, + data=data, + input_representation=VelRepr.Inertial, + ) def compute_O_J̇_WC_I( L_p_C: jtp.Vector, diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 113620f89..4149acc02 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -285,15 +285,24 @@ def base_orientation(self) -> jtp.Matrix: W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) return W_Q_B - @property - def base_velocity(self) -> jtp.Vector: + def base_velocity(self, output_representation: VelRepr | None = None) -> jtp.Vector: """ Get the base 6D velocity. + Args: + output_representation: The desired output representation. + If None, uses self.velocity_representation. + Returns: - The base 6D velocity in the active representation. + The base 6D velocity in the specified representation. """ + output_repr = ( + output_representation + if output_representation is not None + else self.velocity_representation + ) + W_v_WB = jnp.concatenate( [self._base_linear_velocity, self._base_angular_velocity], axis=-1 ) @@ -303,7 +312,7 @@ def base_velocity(self) -> jtp.Vector: return ( JaxSimModelData.inertial_to_other_representation( array=W_v_WB, - other_representation=self.velocity_representation, + other_representation=output_repr, transform=W_H_B, is_force=False, ) @@ -323,19 +332,29 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]: return self._base_transform, self.joint_positions - @property - def generalized_velocity(self) -> jtp.Vector: + def generalized_velocity( + self, output_representation: VelRepr | None = None + ) -> jtp.Vector: r""" Get the generalized velocity. :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}` + Args: + output_representation: The desired output representation. + If None, uses self.velocity_representation. + Returns: - The generalized velocity in the active representation. + The generalized velocity in the specified representation. """ return ( - jnp.hstack([self.base_velocity, self.joint_velocities]) + jnp.hstack( + [ + self.base_velocity(output_representation=output_representation), + self.joint_velocities, + ] + ) .squeeze() .astype(float) ) @@ -413,6 +432,7 @@ def replace( base_angular_velocity: jtp.Vector | None = None, base_position: jtp.Vector | None = None, *, + input_representation: VelRepr | None = None, contact_state: dict[str, jtp.Array] | None = None, validate: bool = False, ) -> Self: @@ -420,6 +440,12 @@ def replace( Replace the attributes of the `JaxSimModelData` object. """ + input_representation = ( + input_representation + if input_representation is not None + else self.velocity_representation + ) + if joint_positions is None: joint_positions = self.joint_positions if joint_velocities is None: @@ -467,7 +493,7 @@ def replace( W_v_WB = JaxSimModelData.other_representation_to_inertial( array=jnp.hstack([base_linear_velocity, base_angular_velocity]), - other_representation=self.velocity_representation, + other_representation=input_representation, transform=base_transform, is_force=False, ).astype(float) diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 60ca74086..1e4bc7ba5 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -229,19 +229,20 @@ def velocity( ) # Get the generalized velocity in the input velocity representation. - I_ν = data.generalized_velocity + I_ν = data.generalized_velocity() # Compute the frame velocity in the output velocity representation. return O_J_WF_I @ I_ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) @js.common.named_scope def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, frame_index: jtp.IntLike, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" @@ -251,6 +252,8 @@ def jacobian( model: The model to consider. data: The data of the considered model. frame_index: The index of the frame. + input_representation: + The input velocity representation of the free-floating jacobian. output_vel_repr: The output velocity representation of the free-floating jacobian. @@ -271,8 +274,14 @@ def jacobian( idx=frame_index, ) + input_representation = ( + input_representation + if input_representation is not None + else data.velocity_representation + ) + output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) # Get the index of the parent link. @@ -280,7 +289,11 @@ def jacobian( # Compute the Jacobian of the parent link using body-fixed output representation. L_J_WL = js.link.jacobian( - model=model, data=data, link_index=L, output_vel_repr=VelRepr.Body + model=model, + data=data, + link_index=L, + input_representation=input_representation, + output_vel_repr=VelRepr.Body, ) # Adjust the output representation. @@ -315,13 +328,14 @@ def jacobian( return O_J_WL_I -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) @js.common.named_scope def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, frame_index: jtp.IntLike, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" @@ -331,6 +345,8 @@ def jacobian_derivative( model: The model to consider. data: The data of the considered model. frame_index: The index of the frame. + input_representation: + The input velocity representation of data. output_vel_repr: The output velocity representation of the free-floating jacobian derivative. @@ -351,29 +367,36 @@ def jacobian_derivative( idx=frame_index, ) + input_representation = ( + input_representation + if input_representation is not None + else data.velocity_representation + ) + output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) # Get the index of the parent link. L = idx_of_parent_link(model=model, frame_index=frame_index) - with data.switch_velocity_representation(VelRepr.Inertial): - # Compute the Jacobian of the parent link in inertial representation. - W_J_WL_W = js.link.jacobian( - model=model, - data=data, - link_index=L, - output_vel_repr=VelRepr.Inertial, - ) + # Compute the Jacobian of the parent link in inertial representation. + W_J_WL_W = js.link.jacobian( + model=model, + data=data, + link_index=L, + input_representation=VelRepr.Inertial, + output_vel_repr=VelRepr.Inertial, + ) - # Compute the Jacobian derivative of the parent link in inertial representation. - W_J̇_WL_W = js.link.jacobian_derivative( - model=model, - data=data, - link_index=L, - output_vel_repr=VelRepr.Inertial, - ) + # Compute the Jacobian derivative of the parent link in inertial representation. + W_J̇_WL_W = js.link.jacobian_derivative( + model=model, + data=data, + link_index=L, + input_representation=VelRepr.Inertial, + output_vel_repr=VelRepr.Inertial, + ) # ===================================================== # Compute quantities to adjust the input representation @@ -391,7 +414,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: # Compute the operator to change the representation of ν, and its # time derivative. - match data.velocity_representation: + match input_representation: case VelRepr.Inertial: W_H_W = jnp.eye(4) W_X_W = Adjoint.from_transform(transform=W_H_W) @@ -403,7 +426,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: case VelRepr.Body: W_H_B = data._base_transform W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity + B_v_WB = data.base_velocity(input_representation) B_vx_WB = Cross.vx(B_v_WB) W_Ẋ_B = W_X_B @ B_vx_WB @@ -414,7 +437,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: W_H_B = data._base_transform W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity + BW_v_WB = data.base_velocity(input_representation) BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_vx_W_BW = Cross.vx(BW_v_W_BW) W_Ẋ_BW = W_X_BW @ BW_vx_W_BW @@ -423,7 +446,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) case _: - raise ValueError(data.velocity_representation) + raise ValueError(input_representation) # ===================================================== # Compute quantities to adjust the output representation @@ -437,8 +460,7 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: case VelRepr.Body: W_H_F = transform(model=model, data=data, frame_index=frame_index) O_X_W = F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True) - with data.switch_velocity_representation(VelRepr.Inertial): - W_nu = data.generalized_velocity + W_nu = data.generalized_velocity(VelRepr.Inertial) W_v_WF = W_J_WL_W @ W_nu W_vx_WF = Cross.vx(W_v_WF) O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF # noqa: F841 @@ -448,14 +470,13 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3)) FW_H_W = Transform.inverse(W_H_FW) O_X_W = FW_X_W = Adjoint.from_transform(transform=FW_H_W) - with data.switch_velocity_representation(VelRepr.Mixed): - FW_J_WF_FW = jacobian( - model=model, - data=data, - frame_index=frame_index, - output_vel_repr=VelRepr.Mixed, - ) - FW_v_WF = FW_J_WF_FW @ data.generalized_velocity + FW_J_WF_FW = jacobian( + model=model, + data=data, + frame_index=frame_index, + output_vel_repr=VelRepr.Mixed, + ) + FW_v_WF = FW_J_WF_FW @ data.generalized_velocity(VelRepr.Mixed) W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3]) W_vx_W_FW = Cross.vx(W_v_W_FW) O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW # noqa: F841 diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 5e18ee8d7..8b592d787 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -19,56 +19,56 @@ def semi_implicit_euler_integration( ) -> JaxSimModelData: """Integrate the system state using the semi-implicit Euler method.""" - with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - - # Compute the system acceleration - W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration( - model=model, - data=data, - link_forces=link_forces, - joint_torques=joint_torques, - ) + # Compute the system acceleration + W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration( + model=model, + data=data, + link_forces=link_forces, + joint_torques=joint_torques, + output_representation=jaxsim.VelRepr.Inertial, + ) - dt = model.time_step + dt = model.time_step - # Compute the new generalized velocity. - new_generalized_acceleration = jnp.hstack([W_v̇_WB, s̈]) - new_generalized_velocity = ( - data.generalized_velocity + dt * new_generalized_acceleration - ) + # Compute the new generalized velocity. + new_generalized_acceleration = jnp.hstack([W_v̇_WB, s̈]) + new_generalized_velocity = ( + data.generalized_velocity(jaxsim.VelRepr.Inertial) + + dt * new_generalized_acceleration + ) - # Extract the new base and joint velocities. - W_v_B = new_generalized_velocity[0:6] - ṡ = new_generalized_velocity[6:] + # Extract the new base and joint velocities. + W_v_B = new_generalized_velocity[0:6] + ṡ = new_generalized_velocity[6:] - # Compute the new base position and orientation. - W_ω_WB = new_generalized_velocity[3:6] + # Compute the new base position and orientation. + W_ω_WB = new_generalized_velocity[3:6] - # To obtain the derivative of the base position, we need to subtract - # the skew-symmetric matrix of the base angular velocity times the base position. - # See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9 - W_ṗ_B = new_generalized_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position + # To obtain the derivative of the base position, we need to subtract + # the skew-symmetric matrix of the base angular velocity times the base position. + # See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9 + W_ṗ_B = new_generalized_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position - W_Q̇_B = jaxsim.math.Quaternion.derivative( - quaternion=data.base_orientation, - omega=W_ω_WB, - omega_in_body_fixed=False, - ).squeeze() + W_Q̇_B = jaxsim.math.Quaternion.derivative( + quaternion=data.base_orientation, + omega=W_ω_WB, + omega_in_body_fixed=False, + ).squeeze() - W_p_B = data.base_position + dt * W_ṗ_B - W_Q_B = data.base_orientation + dt * W_Q̇_B + W_p_B = data.base_position + dt * W_ṗ_B + W_Q_B = data.base_orientation + dt * W_Q̇_B - base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B, axis=-1) + base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B, axis=-1) - W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm) + W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm) - s = data.joint_positions + dt * ṡ + s = data.joint_positions + dt * ṡ - integrated_contact_state = jax.tree.map( - lambda x, x_dot: x + dt * x_dot, - data.contact_state, - contact_state_derivative, - ) + integrated_contact_state = jax.tree.map( + lambda x, x_dot: x + dt * x_dot, + data.contact_state, + contact_state_derivative, + ) # TODO: Avoid double replace, e.g. by computing cached value here data = dataclasses.replace( @@ -100,16 +100,17 @@ def rk4_integration( def f(x) -> dict[str, jtp.Matrix]: - with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - - data_ti = data.replace(model=model, **x) + data_ti = data.replace( + model=model, input_representation=jaxsim.VelRepr.Inertial, **x + ) - return js.ode.system_dynamics( - model=model, - data=data_ti, - link_forces=link_forces, - joint_torques=joint_torques, - ) + return js.ode.system_dynamics( + model=model, + data=data_ti, + link_forces=link_forces, + joint_torques=joint_torques, + output_representation=jaxsim.VelRepr.Inertial, + ) base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1) base_quaternion = data._base_quaternion / jnp.where( @@ -191,21 +192,23 @@ def rk4fast_integration( def f(x) -> dict[str, jtp.Matrix]: - with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - - data_ti = data.replace(model=model, **x) + data_ti = data.replace( + model=model, input_representation=jaxsim.VelRepr.Inertial, **x + ) - W_v̇_WB, s̈ = js.model.forward_dynamics_aba( - model=model, - data=data_ti, - joint_forces=joint_torques, - link_forces=W_f_L_total, - ) + W_v̇_WB, s̈ = js.model.forward_dynamics_aba( + model=model, + data=data_ti, + joint_forces=joint_torques, + link_forces=W_f_L_total, + output_representation=jaxsim.VelRepr.Inertial, + ) - W_ṗ_B, W_Q̇_B, ṡ = js.ode.system_position_dynamics( - data=data, - baumgarte_quaternion_regularization=1.0, - ) + W_ṗ_B, W_Q̇_B, ṡ = js.ode.system_position_dynamics( + data=data, + baumgarte_quaternion_regularization=1.0, + output_representation=jaxsim.VelRepr.Inertial, + ) return dict( base_position=W_ṗ_B, diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 3e389ecce..6d656ba56 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -234,12 +234,13 @@ def com_in_inertial_frame(): ) -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" @@ -249,6 +250,8 @@ def jacobian( model: The model to consider. data: The data of the considered model. link_index: The index of the link. + input_representation: + The input velocity representation of the free-floating jacobian. output_vel_repr: The output velocity representation of the free-floating jacobian. @@ -268,8 +271,14 @@ def jacobian( idx=link_index, ) + input_representation = ( + input_representation + if input_representation is not None + else data.velocity_representation + ) + output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) # Compute the doubly-left free-floating full jacobian. @@ -283,7 +292,7 @@ def jacobian( B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B # Adjust the input representation such that `J_WL_I @ I_ν`. - match data.velocity_representation: + match input_representation: case VelRepr.Inertial: W_H_B = data._base_transform B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) @@ -303,7 +312,7 @@ def jacobian( ) case _: - raise ValueError(data.velocity_representation) + raise ValueError(input_representation) B_H_L = B_H_Li[link_index] @@ -378,18 +387,19 @@ def velocity( ) # Get the generalized velocity in the input velocity representation. - I_ν = data.generalized_velocity + I_ν = data.generalized_velocity() # Compute the link velocity in the output velocity representation. return O_J_WL_I @ I_ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" @@ -399,6 +409,8 @@ def jacobian_derivative( model: The model to consider. data: The data of the considered model. link_index: The index of the link. + input_representation: + The input velocity representation of the free-floating jacobian. output_vel_repr: The output velocity representation of the free-floating jacobian derivative. @@ -418,12 +430,21 @@ def jacobian_derivative( idx=link_index, ) + input_representation = ( + input_representation + if input_representation is not None + else data.velocity_representation + ) + output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative( - model=model, data=data, output_vel_repr=output_vel_repr + model=model, + data=data, + input_representation=input_representation, + output_vel_repr=output_vel_repr, )[link_index] return O_J̇_WL_I diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 2b88ed212..bf5929460 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -926,11 +926,12 @@ def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array: # ============================== -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) def generalized_free_floating_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: """ @@ -939,6 +940,8 @@ def generalized_free_floating_jacobian( Args: model: The model to consider. data: The data of the considered model. + input_representation: + The input velocity representation of the free-floating jacobians. output_vel_repr: The output velocity representation of the free-floating jacobians. @@ -951,9 +954,14 @@ def generalized_free_floating_jacobian( flattened 6D forces of the links, are useful to compute the `J.T @ f` product of the multi-body EoM. """ + input_representation = ( + input_representation + if input_representation is not None + else data.velocity_representation + ) output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) # Compute the doubly-left free-floating full jacobian. @@ -966,7 +974,7 @@ def generalized_free_floating_jacobian( # Update the input velocity representation such that v_WL = J_WL_I @ I_ν # ====================================================================== - match data.velocity_representation: + match input_representation: case VelRepr.Inertial: W_H_B = data._base_transform B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) @@ -1049,11 +1057,12 @@ def generalized_free_floating_jacobian( return O_J_WL_I -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) def generalized_free_floating_jacobian_derivative( model: JaxSimModel, data: js.data.JaxSimModelData, *, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: """ @@ -1062,6 +1071,8 @@ def generalized_free_floating_jacobian_derivative( Args: model: The model to consider. data: The data of the considered model. + input_representation: + The input velocity representation of the free-floating jacobian derivatives. output_vel_repr: The output velocity representation of the free-floating jacobian derivatives. @@ -1070,8 +1081,14 @@ def generalized_free_floating_jacobian_derivative( jacobian derivatives of the links. The first axis is the link index. """ + input_representation = ( + input_representation + if input_representation is not None + else data.velocity_representation + ) + output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) # Compute the derivative of the doubly-left free-floating full jacobian. @@ -1111,11 +1128,11 @@ def generalized_free_floating_jacobian_derivative( In = jnp.eye(model.dofs()) On = jnp.zeros(shape=(model.dofs(), model.dofs())) - match data.velocity_representation: + match input_representation: case VelRepr.Inertial: B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True) - W_v_WB = data.base_velocity + W_v_WB = data.base_velocity(input_representation) B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) # Compute the operator to change the representation of ν, and its @@ -1139,7 +1156,7 @@ def generalized_free_floating_jacobian_derivative( BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) - BW_v_WB = data.base_velocity + BW_v_WB = data.base_velocity(input_representation) BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_v_BW_B = BW_v_WB - BW_v_W_BW @@ -1151,7 +1168,7 @@ def generalized_free_floating_jacobian_derivative( Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On) case _: - raise ValueError(data.velocity_representation) + raise ValueError(input_representation) # ====================================================== # Compute quantities to adjust the output representation @@ -1161,8 +1178,7 @@ def generalized_free_floating_jacobian_derivative( case VelRepr.Inertial: O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B) - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity + B_v_WB = data.base_velocity(VelRepr.Body) O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841 @@ -1173,11 +1189,10 @@ def generalized_free_floating_jacobian_derivative( B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B) - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity - L_v_WL = jnp.einsum( - "b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity - ) + B_v_WB = data.base_velocity(VelRepr.Body) + L_v_WL = jnp.einsum( + "b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity(VelRepr.Body) + ) O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841 jnp.einsum("bij,bj->bi", B_X_L, L_v_WL) - B_v_WB @@ -1192,21 +1207,19 @@ def generalized_free_floating_jacobian_derivative( B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B) - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity - - with data.switch_velocity_representation(VelRepr.Mixed): - BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - LW_v_WL = jnp.einsum( - "bij,bj->bi", - LW_X_B, - B_J_WL_B - @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - @ data.generalized_velocity, - ) + B_v_WB = data.base_velocity(VelRepr.Body) + + BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + LW_v_WL = jnp.einsum( + "bij,bj->bi", + LW_X_B, + B_J_WL_B + @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) + @ data.generalized_velocity(VelRepr.Mixed), + ) - LW_v_W_LW = LW_v_WL.at[:, 3:6].set(jnp.zeros_like(LW_v_WL[:, 3:6])) + LW_v_W_LW = LW_v_WL.at[:, 3:6].set(jnp.zeros_like(LW_v_WL[:, 3:6])) LW_v_LW_L = LW_v_WL - LW_v_W_LW LW_v_B_LW = LW_v_WL - jnp.einsum("bij,j->bi", LW_X_B, B_v_WB) - LW_v_LW_L @@ -1270,7 +1283,7 @@ def forward_dynamics( ) -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def forward_dynamics_aba( model: JaxSimModel, @@ -1278,6 +1291,7 @@ def forward_dynamics_aba( *, joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, + output_representation: VelRepr | None = None, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the forward dynamics of the model with the ABA algorithm. @@ -1290,6 +1304,8 @@ def forward_dynamics_aba( link_forces: The link 6D forces to consider as a matrix of shape `(nL, 6)`. The frame in which they are expressed must be `data.velocity_representation`. + output_representation: + The desired output velocity representation of the free-floating acceleration. Returns: A tuple containing the 6D acceleration in the active representation of the @@ -1297,6 +1313,12 @@ def forward_dynamics_aba( considered joint forces and external forces. """ + output_representation = ( + output_representation + if output_representation is not None + else data.velocity_representation + ) + # ============ # Prepare data # ============ @@ -1325,12 +1347,11 @@ def forward_dynamics_aba( ) # Extract the state in inertial-fixed representation. - with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position - W_v_WB = data.base_velocity - W_Q_B = data.base_orientation - s = data.joint_positions - ṡ = data.joint_velocities + W_p_B = data.base_position + W_v_WB = data.base_velocity(VelRepr.Inertial) + W_Q_B = data.base_orientation + s = data.joint_positions + ṡ = data.joint_velocities # Extract the inputs in inertial-fixed representation. W_f_L = references._link_forces @@ -1370,7 +1391,7 @@ def to_active( C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB) - match data.velocity_representation: + match output_representation: case VelRepr.Inertial: # In this case C=W W_H_C = W_H_W = jnp.eye(4) # noqa: F841 @@ -1385,11 +1406,11 @@ def to_active( # In this case C=B[W] W_H_B = data._base_transform W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 - W_ṗ_B = data.base_velocity[0:3] + W_ṗ_B = data.base_velocity(output_representation)[0:3] W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) # We need to convert the derivative of the base velocity to the active # representation. In Mixed representation, this conversion is not a plain @@ -1557,10 +1578,13 @@ def _transform_M_block(M_body: jtp.Matrix, X: jtp.Matrix) -> jtp.Matrix: return jnp.concatenate([top, bottom], axis=0) -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def free_floating_mass_matrix( - model: JaxSimModel, data: js.data.JaxSimModelData + model: JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Matrix: """ Compute the free-floating mass matrix of the model with the CRBA algorithm. @@ -1568,17 +1592,25 @@ def free_floating_mass_matrix( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the mass matrix. Returns: The free-floating mass matrix of the model. """ + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + M_body = jaxsim.rbda.crba( model=model, joint_positions=data.joint_positions, ) - match data.velocity_representation: + match output_representation: case VelRepr.Body: return M_body @@ -1593,13 +1625,16 @@ def free_floating_mass_matrix( return _transform_M_block(M_body, B_X_BW) case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def free_floating_mass_matrix_inverse( - model: JaxSimModel, data: js.data.JaxSimModelData + model: JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Matrix: """ Compute the inverse of the free-floating mass matrix of the model @@ -1608,10 +1643,18 @@ def free_floating_mass_matrix_inverse( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the inverse mass matrix. Returns: The inverse of the free-floating mass matrix of the model. """ + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + M_inv_body = jaxsim.rbda.mass_inverse( model=model, base_position=data.base_position, @@ -1619,7 +1662,7 @@ def free_floating_mass_matrix_inverse( joint_positions=data.joint_positions, ) - match data.velocity_representation: + match output_representation: case VelRepr.Body: return M_inv_body case VelRepr.Inertial: @@ -1632,13 +1675,16 @@ def free_floating_mass_matrix_inverse( return _transform_M_block(M_inv_body, BW_X_B.T) case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def free_floating_coriolis_matrix( - model: JaxSimModel, data: js.data.JaxSimModelData + model: JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Matrix: """ Compute the free-floating Coriolis matrix of the model. @@ -1646,6 +1692,8 @@ def free_floating_coriolis_matrix( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the Coriolis matrix. Returns: The free-floating Coriolis matrix of the model. @@ -1656,17 +1704,26 @@ def free_floating_coriolis_matrix( the Coriolis matrix may be much slower than other quantities. """ + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + # We perform all the calculation in body-fixed. # The Coriolis matrix computed in this representation is converted later # to the active representation stored in data. - with data.switch_velocity_representation(VelRepr.Body): - B_ν = data.generalized_velocity + B_ν = data.generalized_velocity(VelRepr.Body) - # Doubly-left free-floating Jacobian. - L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data) + # Doubly-left free-floating Jacobian. + L_J_WL_B = generalized_free_floating_jacobian( + model=model, data=data, input_representation=VelRepr.Body + ) - # Doubly-left free-floating Jacobian derivative. - L_J̇_WL_B = generalized_free_floating_jacobian_derivative(model=model, data=data) + # Doubly-left free-floating Jacobian derivative. + L_J̇_WL_B = generalized_free_floating_jacobian_derivative( + model=model, data=data, input_representation=VelRepr.Body + ) L_M_L = link_spatial_inertia_matrices(model=model) @@ -1698,7 +1755,7 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: # Adjust the representation of the Coriolis matrix. # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6. - match data.velocity_representation: + match output_representation: case VelRepr.Body: return C_B @@ -1708,14 +1765,14 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n)) - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WB = data.base_velocity - B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) + W_v_WB = data.base_velocity(VelRepr.Inertial) + B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n))) - with data.switch_velocity_representation(VelRepr.Body): - M = free_floating_mass_matrix(model=model, data=data) + M = free_floating_mass_matrix( + model=model, data=data, output_representation=VelRepr.Body + ) C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W) @@ -1727,27 +1784,27 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n)) - with data.switch_velocity_representation(VelRepr.Mixed): - BW_v_WB = data.base_velocity - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + BW_v_WB = data.base_velocity(VelRepr.Mixed) + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_v_BW_B = BW_v_WB - BW_v_W_BW B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n))) - with data.switch_velocity_representation(VelRepr.Body): - M = free_floating_mass_matrix(model=model, data=data) + M = free_floating_mass_matrix( + model=model, data=data, output_representation=VelRepr.Body + ) C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW) return C case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def inverse_dynamics( model: JaxSimModel, @@ -1756,6 +1813,7 @@ def inverse_dynamics( joint_accelerations: jtp.VectorLike | None = None, base_acceleration: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, + output_representation: VelRepr | None = None, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics with the RNEA algorithm. @@ -1769,7 +1827,9 @@ def inverse_dynamics( The base acceleration to consider as a vector of shape `(6,)`. link_forces: The link 6D forces to consider as a matrix of shape `(nL, 6)`. - The frame in which they are expressed must be `data.velocity_representation`. + The frame in which they are expressed must be `output_representation`. + output_representation: + The desired output velocity representation of the base force. Returns: A tuple containing the 6D force in the active representation applied to the @@ -1777,6 +1837,12 @@ def inverse_dynamics( to obtain the considered joint accelerations. """ + output_representation = ( + output_representation + if output_representation is not None + else data.velocity_representation + ) + # ============ # Prepare data # ============ @@ -1816,24 +1882,23 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # In Inertial and Body representations, the cross product is always zero. return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB) - match data.velocity_representation: + match output_representation: case VelRepr.Inertial: W_H_C = W_H_W = jnp.eye(4) # noqa: F841 W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 case VelRepr.Body: W_H_C = W_H_B = data._base_transform - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity + W_v_WC = W_v_WB = data.base_velocity(VelRepr.Inertial) case VelRepr.Mixed: W_H_B = data._base_transform W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 - W_ṗ_B = data.base_velocity[0:3] + W_ṗ_B = data.base_velocity(output_representation)[0:3] W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) # We need to convert the derivative of the base acceleration to the Inertial # representation. In Mixed representation, this conversion is not a plain @@ -1841,7 +1906,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): W_v̇_WB = to_inertial( C_v̇_WB=v̇_WB, W_H_C=W_H_C, - C_v_WB=data.base_velocity, + C_v_WB=data.base_velocity(output_representation), W_v_WC=W_v_WC, ) @@ -1850,16 +1915,15 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): model=model, data=data, link_forces=f_L, - velocity_representation=data.velocity_representation, + velocity_representation=output_representation, ) # Extract the state in inertial-fixed representation. - with data.switch_velocity_representation(VelRepr.Inertial): - W_p_B = data.base_position - W_v_WB = data.base_velocity - W_Q_B = data.base_quaternion - s = data.joint_positions - ṡ = data.joint_velocities + W_p_B = data.base_position + W_v_WB = data.base_velocity(VelRepr.Inertial) + W_Q_B = data.base_quaternion + s = data.joint_positions + ṡ = data.joint_velocities # Extract the inputs in inertial-fixed representation. W_f_L = references._link_forces @@ -1890,7 +1954,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # Express W_f_B in the active representation. f_B = js.data.JaxSimModelData.inertial_to_other_representation( array=W_f_B, - other_representation=data.velocity_representation, + other_representation=output_representation, transform=data._base_transform, is_force=True, ).squeeze() @@ -1898,10 +1962,13 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): return f_B.astype(float), τ.astype(float) -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def free_floating_gravity_forces( - model: JaxSimModel, data: js.data.JaxSimModelData + model: JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Vector: r""" Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model. @@ -1909,15 +1976,23 @@ def free_floating_gravity_forces( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the gravity forces. Returns: The free-floating gravity forces of the model. """ + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + # Build a new state with zeroed velocities. data_rnea = js.data.JaxSimModelData.build( model=model, - velocity_representation=data.velocity_representation, + velocity_representation=output_representation, base_position=data.base_position, base_quaternion=data.base_quaternion, joint_positions=data.joint_positions, @@ -1931,14 +2006,18 @@ def free_floating_gravity_forces( joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())), base_acceleration=jnp.zeros(6), link_forces=jnp.zeros(shape=(model.number_of_links(), 6)), + output_representation=output_representation, ) ).astype(float) -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def free_floating_bias_forces( - model: JaxSimModel, data: js.data.JaxSimModelData + model: JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Vector: r""" Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})` @@ -1947,21 +2026,29 @@ def free_floating_bias_forces( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the bias forces. Returns: The free-floating bias forces of the model. """ + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + # Set the generalized position and generalized velocity. base_linear_velocity, base_angular_velocity = None, None if model.floating_base(): - base_velocity = data.base_velocity + base_velocity = data.base_velocity(output_representation) base_linear_velocity = base_velocity[:3] base_angular_velocity = base_velocity[3:] data_rnea = js.data.JaxSimModelData.build( model=model, - velocity_representation=data.velocity_representation, + velocity_representation=output_representation, base_position=data.base_position, base_quaternion=data.base_quaternion, joint_positions=data.joint_positions, @@ -1978,6 +2065,7 @@ def free_floating_bias_forces( joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())), base_acceleration=jnp.zeros(6), link_forces=jnp.zeros(shape=(model.number_of_links(), 6)), + output_representation=output_representation, ) ).astype(float) @@ -1987,10 +2075,13 @@ def free_floating_bias_forces( # ========================== -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def locked_spatial_inertia( - model: JaxSimModel, data: js.data.JaxSimModelData + model: JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Matrix: """ Compute the locked 6D inertia matrix of the model. @@ -1998,39 +2089,61 @@ def locked_spatial_inertia( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the locked inertia matrix. Returns: The locked 6D inertia matrix of the model. """ - return total_momentum_jacobian(model=model, data=data)[:, 0:6] + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + return total_momentum_jacobian( + model=model, + data=data, + input_representation=output_representation, + output_vel_repr=output_representation, + )[:, 0:6] -@jax.jit + +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope -def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector: +def total_momentum( + model: JaxSimModel, + data: js.data.JaxSimModelData, + output_representation: VelRepr | None = None, +) -> jtp.Vector: """ Compute the total momentum of the model. Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the total momentum. Returns: The total momentum of the model in the active velocity representation. """ - ν = data.generalized_velocity - Jh = total_momentum_jacobian(model=model, data=data) + ν = data.generalized_velocity(output_representation) + Jh = total_momentum_jacobian( + model=model, data=data, output_vel_repr=output_representation + ) return Jh @ ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["input_representation", "output_vel_repr"]) def total_momentum_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, + input_representation: VelRepr | None = None, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: """ @@ -2039,23 +2152,33 @@ def total_momentum_jacobian( Args: model: The model to consider. data: The data of the considered model. + input_representation: The input velocity representation of the data. output_vel_repr: The output velocity representation of the jacobian. Returns: The jacobian of the total momentum of the model in the active representation. """ + input_representation = ( + data.velocity_representation + if input_representation is None + else input_representation + ) + output_vel_repr = ( - output_vel_repr if output_vel_repr is not None else data.velocity_representation + output_vel_repr if output_vel_repr is not None else input_representation ) - if output_vel_repr is data.velocity_representation: - return free_floating_mass_matrix(model=model, data=data)[0:6] + if output_vel_repr is input_representation: + return free_floating_mass_matrix( + model=model, data=data, output_representation=output_vel_repr + )[0:6] - with data.switch_velocity_representation(VelRepr.Body): - B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6] + B_Jh_B = free_floating_mass_matrix( + model=model, data=data, output_representation=VelRepr.Body + )[0:6] - match data.velocity_representation: + match input_representation: case VelRepr.Body: B_Jh = B_Jh_B @@ -2069,7 +2192,7 @@ def total_momentum_jacobian( B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) case _: - raise ValueError(data.velocity_representation) + raise ValueError(input_representation) match output_vel_repr: case VelRepr.Body: @@ -2095,21 +2218,30 @@ def total_momentum_jacobian( @jax.jit @js.common.named_scope -def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector: +def average_velocity( + model: JaxSimModel, + data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, +) -> jtp.Vector: """ Compute the average velocity of the model. Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the average velocity. Returns: The average velocity of the model computed in the base frame and expressed in the active representation. """ - ν = data.generalized_velocity - J = average_velocity_jacobian(model=model, data=data) + ν = data.generalized_velocity(output_representation) + J = average_velocity_jacobian( + model=model, data=data, output_vel_repr=output_representation + ) return J @ ν @@ -2178,11 +2310,13 @@ def average_velocity_jacobian( # ======================== -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def link_bias_accelerations( model: JaxSimModel, data: js.data.JaxSimModelData, + *, + output_representation: VelRepr | None = None, ) -> jtp.Vector: r""" Compute the bias accelerations of the links of the model. @@ -2190,6 +2324,8 @@ def link_bias_accelerations( Args: model: The model to consider. data: The data of the considered model. + output_representation: + The output velocity representation of the bias accelerations. Returns: The bias accelerations of the links of the model. @@ -2200,6 +2336,12 @@ def link_bias_accelerations( It is often called :math:`\dot{J} \boldsymbol{\nu}`. """ + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + # ================================================ # Compute the body-fixed zero base 6D acceleration # ================================================ @@ -2227,33 +2369,28 @@ def other_representation_to_inertial( # because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration # W_a_WB, and intrinsic accelerations can be expressed in different frames through # a simple C_X_W 6D transform. - match data.velocity_representation: + match output_representation: case VelRepr.Inertial: W_H_C = W_H_W = jnp.eye(4) # noqa: F841 W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - with data.switch_velocity_representation(VelRepr.Inertial): - C_v_WB = W_v_WB = data.base_velocity + C_v_WB = W_v_WB = data.base_velocity(VelRepr.Inertial) case VelRepr.Body: W_H_C = W_H_B - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity # noqa: F841 - with data.switch_velocity_representation(VelRepr.Body): - C_v_WB = B_v_WB = data.base_velocity + W_v_WC = W_v_WB = data.base_velocity(VelRepr.Inertial) # noqa: F841 + C_v_WB = B_v_WB = data.base_velocity(VelRepr.Body) case VelRepr.Mixed: W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_H_C = W_H_BW - with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity[0:3] - BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) - W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW) - W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841 - with data.switch_velocity_representation(VelRepr.Mixed): - C_v_WB = BW_v_WB = data.base_velocity # noqa: F841 + W_ṗ_B = data.base_velocity(VelRepr.Mixed)[0:3] + BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW) + W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841 + C_v_WB = BW_v_WB = data.base_velocity(VelRepr.Mixed) # noqa: F841 case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) # Convert a zero 6D acceleration from the active representation to inertial-fixed. W_v̇_WB = other_representation_to_inertial( @@ -2285,9 +2422,8 @@ def other_representation_to_inertial( L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6)) # Store the base velocity. - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity - L_v_WL = L_v_WL.at[0].set(B_v_WB) + B_v_WB = data.base_velocity(VelRepr.Body) + L_v_WL = L_v_WL.at[0].set(B_v_WB) # Get the joint velocities. ṡ = data.joint_velocities @@ -2359,7 +2495,7 @@ def body_to_other_representation( C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L) return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL) - match data.velocity_representation: + match output_representation: case VelRepr.Body: C_H_L = L_H_L = jnp.stack( # noqa: F841 [jnp.eye(4)] * model.number_of_links() @@ -2381,7 +2517,7 @@ def body_to_other_representation( )(L_v_WL) case _: - raise ValueError(data.velocity_representation) + raise ValueError(output_representation) # Convert from body-fixed to the active representation. O_v̇_WL = jax.vmap(body_to_other_representation)( @@ -2430,9 +2566,10 @@ def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Flo The kinetic energy of the model. """ - with data.switch_velocity_representation(velocity_representation=VelRepr.Body): - B_ν = data.generalized_velocity - M_B = free_floating_mass_matrix(model=model, data=data) + B_ν = data.generalized_velocity(VelRepr.Body) + M_B = free_floating_mass_matrix( + model=model, data=data, output_representation=VelRepr.Body + ) K = 0.5 * B_ν.T @ M_B @ B_ν return K.squeeze().astype(float) @@ -2602,7 +2739,7 @@ def update_λ_H_pre(joint_index): # ========== -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def step( model: JaxSimModel, @@ -2610,6 +2747,7 @@ def step( *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, + output_representation: VelRepr | None = None, ) -> js.data.JaxSimModelData: """ Perform a simulation step. @@ -2621,6 +2759,9 @@ def step( link_forces: The 6D forces to apply to the links expressed in same representation of data. joint_force_references: The joint force references to consider. + output_representation: + The output velocity representation of the data object after the step. + Defaults to the same representation of data. Returns: The new data of the model after the simulation step. @@ -2631,6 +2772,12 @@ def step( particularly useful for automatically differentiated logic. """ + output_representation = ( + data.velocity_representation + if output_representation is None + else output_representation + ) + # TODO: some contact models here may want to perform a dynamic filtering of # the enabled collidable points diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 377d00b6b..26f10898d 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -1,3 +1,5 @@ +import functools + import jax import jax.numpy as jnp @@ -13,12 +15,14 @@ # ================================== +@functools.partial(jax.jit, static_argnames=["output_representation"]) def system_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.MatrixLike | None = None, joint_torques: jtp.VectorLike | None = None, + output_representation: VelRepr | None = None, ) -> tuple[jtp.Vector, jtp.Vector, dict[str, jtp.PyTree]]: """ Compute the system acceleration in the active representation. @@ -30,12 +34,21 @@ def system_acceleration( The 6D forces to apply to the links expressed in the same velocity representation of data. joint_torques: The joint torques applied to the joints. + output_representation: + The desired output velocity representation. If None, the current + velocity representation of `data` is used. Returns: A tuple containing the base 6D acceleration in the active representation, the joint accelerations, and the contact state. """ + output_representation = ( + output_representation + if output_representation is not None + else data.velocity_representation + ) + # ==================== # Validate input data # ==================== @@ -110,7 +123,7 @@ def system_acceleration( references = js.references.JaxSimModelReferences.build( model=model, data=data, - velocity_representation=data.velocity_representation, + velocity_representation=output_representation, link_forces=W_f_L_total, ) @@ -126,16 +139,19 @@ def system_acceleration( data=data, joint_forces=joint_torques, link_forces=references.link_forces(model=model, data=data), + output_representation=output_representation, ) return v̇_WB, s̈, contact_state -@jax.jit +@functools.partial(jax.jit, static_argnames=["output_representation"]) @js.common.named_scope def system_position_dynamics( data: js.data.JaxSimModelData, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, + *, + output_representation: VelRepr | None = None, ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]: r""" Compute the dynamics of the system position. @@ -144,6 +160,9 @@ def system_position_dynamics( data: The data of the considered model. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient for adjusting the quaternion norm. + output_representation: + The desired output velocity representation. If None, the current + velocity representation of `data` is used. Returns: A tuple containing the derivative of the base position, the derivative of the @@ -156,10 +175,19 @@ def system_position_dynamics( Where :math:`S(\cdot)` is the skew-symmetric matrix operator. """ + output_representation = ( + output_representation + if output_representation is not None + else data.velocity_representation + ) + ṡ = data.joint_velocities W_Q_B = data.base_orientation - W_ω_WB = data.base_velocity[3:6] - W_ṗ_B = data.base_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position + W_ω_WB = data.base_velocity(output_representation)[3:6] + W_ṗ_B = ( + data.base_velocity(output_representation)[0:3] + + Skew.wedge(W_ω_WB) @ data.base_position + ) W_Q̇_B = Quaternion.derivative( quaternion=W_Q_B, @@ -179,6 +207,7 @@ def system_dynamics( *, link_forces: jtp.Vector | None = None, joint_torques: jtp.Vector | None = None, + output_representation: VelRepr | None = None, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> dict[str, jtp.Vector]: """ @@ -191,6 +220,9 @@ def system_dynamics( The 6D forces to apply to the links expressed in the frame corresponding to the velocity representation of `data`. joint_torques: The joint torques acting on the joints. + output_representation: + The desired output velocity representation. If None, the current + velocity representation of `data` is used. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient used to adjust the norm of the quaternion (only used in integrators not operating on the SO(3) manifold). @@ -201,18 +233,25 @@ def system_dynamics( joint velocities. """ - with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): - W_v̇_WB, s̈, contact_state_derivative = system_acceleration( - model=model, - data=data, - joint_torques=joint_torques, - link_forces=link_forces, - ) + output_representation = ( + output_representation + if output_representation is not None + else data.velocity_representation + ) - W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( - data=data, - baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, - ) + W_v̇_WB, s̈, contact_state_derivative = system_acceleration( + model=model, + data=data, + joint_torques=joint_torques, + link_forces=link_forces, + output_representation=VelRepr.Inertial, + ) + + W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( + data=data, + baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, + output_representation=VelRepr.Inertial, + ) return dict( base_position=W_ṗ_B, diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index e1fc0256f..a8d4a83a9 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -164,20 +164,23 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool: # Extract quantities # ================== @js.common.named_scope - @functools.partial(jax.jit, static_argnames=["link_names"]) + @functools.partial(jax.jit, static_argnames=["link_names", "output_representation"]) def link_forces( self, model: js.model.JaxSimModel | None = None, data: js.data.JaxSimModelData | None = None, link_names: tuple[str, ...] | None = None, + output_representation: VelRepr | None = None, ) -> jtp.Matrix: """ - Return the link forces expressed in the frame of the active representation. + Return the link forces expressed in the specified representation. Args: model: The model to consider. data: The data of the considered model. link_names: The names of the links corresponding to the forces. + output_representation: + The desired output representation. If None, uses self.velocity_representation. Returns: If no model and no link names are provided, the link forces as a @@ -198,10 +201,17 @@ def link_forces( W_f_L = self._link_forces - # Return all link forces in inertial-fixed representation using the implicit + # Determine the output representation + output_repr = ( + output_representation + if output_representation is not None + else self.velocity_representation + ) + + # Return all link forces in the desired representation using the implicit # serialization. if model is None: - if self.velocity_representation is not VelRepr.Inertial: + if output_repr is not VelRepr.Inertial: msg = "Missing model to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) @@ -218,7 +228,7 @@ def link_forces( ) # In inertial-fixed representation, we already have the link forces. - if self.velocity_representation is VelRepr.Inertial: + if output_repr is VelRepr.Inertial: return W_f_L[link_idxs, :] if data is None: @@ -228,14 +238,14 @@ def link_forces( if not_tracing(self._link_forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") - # Helper function to convert a single 6D force to the active representation + # Helper function to convert a single 6D force to the desired representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: return jax.vmap( lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( array=W_f_L, - other_representation=self.velocity_representation, + other_representation=output_repr, transform=W_H_L, is_force=True, ) @@ -347,7 +357,9 @@ def replace(forces: jtp.Vector) -> JaxSimModelReferences: return replace(forces=self._joint_force_references.at[joint_idxs].set(forces)) @js.common.named_scope - @functools.partial(jax.jit, static_argnames=["link_names", "additive"]) + @functools.partial( + jax.jit, static_argnames=["link_names", "additive", "input_representation"] + ) def apply_link_forces( self, forces: jtp.MatrixLike, @@ -355,32 +367,42 @@ def apply_link_forces( data: js.data.JaxSimModelData | None = None, link_names: tuple[str, ...] | str | None = None, additive: bool = False, + input_representation: VelRepr | None = None, ) -> Self: """ Apply the link forces. Args: - forces: The link 6D forces in the active representation. + forces: The link 6D forces in the specified input representation. model: The model to consider, only needed if a link serialization different from the implicit one is used. data: - The data of the considered model, only needed if the velocity + The data of the considered model, only needed if the input representation is not inertial-fixed. link_names: The names of the links corresponding to the forces. additive: Whether to add the forces to the existing ones instead of replacing them. + input_representation: + The representation of the input forces. If None, uses self.velocity_representation. Returns: A new `JaxSimModelReferences` object with the given link forces. Note: - The link forces must be expressed in the active representation. - Then, we always convert and store forces in inertial-fixed representation. + The link forces must be expressed in the specified input representation. + They are always converted and stored internally in inertial-fixed representation. """ f_L = jnp.atleast_2d(forces).astype(float) + # Determine the input representation + input_repr = ( + input_representation + if input_representation is not None + else self.velocity_representation + ) + # Helper function to replace the link forces. def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: return self.replace( @@ -391,7 +413,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # In this case, we allow only to set the inertial 6D forces to all links # using the implicit link serialization. if model is None: - if self.velocity_representation is not VelRepr.Inertial: + if input_repr is not VelRepr.Inertial: msg = "Missing model to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) @@ -421,7 +443,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: ) # If inertial-fixed representation, we can directly store the link forces. - if self.velocity_representation is VelRepr.Inertial: + if input_repr is VelRepr.Inertial: W_f_L = f_L return replace( forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L) @@ -438,10 +460,10 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # Convert a single 6D force to the inertial representation # considering as body the link (i.e. L_f_L and LW_f_L). - # The f_L input is either L_f_L or LW_f_L, depending on the representation. + # The f_L input is either L_f_L or LW_f_L, depending on the input representation. W_f_L = JaxSimModelReferences.other_representation_to_inertial( array=f_L, - other_representation=self.velocity_representation, + other_representation=input_repr, transform=W_H_L[link_idxs] if model.number_of_links() > 1 else W_H_L, is_force=True, ) @@ -497,7 +519,7 @@ def apply_frame_forces( ] exceptions.raise_value_error_if( - condition=~data.valid(model=model), + condition=not data.valid(model=model), msg="The provided data is not valid for the model", ) W_H_Fi = jax.vmap( @@ -530,17 +552,12 @@ def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix: mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links()) W_f_L = mask.T @ W_f_F - with self.switch_velocity_representation( - velocity_representation=VelRepr.Inertial - ): - references = self.apply_link_forces( - model=model, - data=data, - forces=W_f_L, - additive=additive, - ) - - with references.switch_velocity_representation( - velocity_representation=self.velocity_representation - ): - return references + # Apply the forces, specifying that they are already in Inertial representation. + # The output will maintain the current velocity representation. + return self.apply_link_forces( + model=model, + data=data, + forces=W_f_L, + additive=additive, + input_representation=VelRepr.Inertial, + ) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index d35f64d85..d4273103b 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -352,36 +352,41 @@ def compute_contact_forces( # collidable points. W_H_C = js.contact.transforms(model=model, data=data) - with ( - data.switch_velocity_representation(VelRepr.Mixed), - references.switch_velocity_representation(VelRepr.Mixed), - ): - BW_ν = data.generalized_velocity - - BW_ν̇_free = jnp.hstack( - js.model.forward_dynamics_aba( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - joint_forces=references.joint_force_references(model=model), - ) + BW_ν = data.generalized_velocity(VelRepr.Mixed) + + BW_ν̇_free = jnp.hstack( + js.model.forward_dynamics_aba( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_forces=references.joint_force_references(model=model), + output_representation=VelRepr.Mixed, ) + ) - M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) + M_inv = js.model.free_floating_mass_matrix_inverse( + model=model, data=data, output_representation=VelRepr.Mixed + ) - # Compute the linear part of the Jacobian of the collidable points - Jl_WC = jnp.vstack( - jax.vmap(lambda J, δ: J * (δ > 0))( - js.contact.jacobian(model=model, data=data)[:, :3, :], δ - ) + # Compute the linear part of the Jacobian of the collidable points + Jl_WC = jnp.vstack( + jax.vmap(lambda J, δ: J * (δ > 0))( + js.contact.jacobian( + model=model, data=data, input_representation=VelRepr.Mixed + )[:, :3, :], + δ, ) + ) - # Compute the linear part of the Jacobian derivative of the collidable points - J̇l_WC = jnp.vstack( - jax.vmap(lambda J̇, δ: J̇ * (δ > 0))( - js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ - ), - ) + # Compute the linear part of the Jacobian derivative of the collidable points + J̇l_WC = jnp.vstack( + jax.vmap(lambda J̇, δ: J̇ * (δ > 0))( + js.contact.jacobian_derivative( + model=model, data=data, input_representation=VelRepr.Mixed + )[:, :3], + δ, + ), + ) # Compute the Delassus matrix directly using J and J̇. G_contacts = Jl_WC @ M_inv @ Jl_WC.T diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 43d89b32e..a0ecfe893 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -287,27 +287,29 @@ def compute_contact_forces( W_H_C = js.contact.transforms(model=model, data=data) - with ( - references.switch_velocity_representation(VelRepr.Mixed), - data.switch_velocity_representation(VelRepr.Mixed), - ): - # Compute kin-dyn quantities used in the contact model. - BW_ν = data.generalized_velocity - - M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) - - J_WC = js.contact.jacobian(model=model, data=data) - J̇_WC = js.contact.jacobian_derivative(model=model, data=data) - - # Compute the generalized free acceleration. - BW_ν̇_free = jnp.hstack( - js.model.forward_dynamics_aba( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - joint_forces=references.joint_force_references(model=model), - ) + # Compute kin-dyn quantities used in the contact model. + BW_ν = data.generalized_velocity(VelRepr.Mixed) + + M_inv = js.model.free_floating_mass_matrix_inverse( + model=model, data=data, output_representation=VelRepr.Mixed + ) + + J_WC = js.contact.jacobian( + model=model, data=data, output_vel_repr=VelRepr.Mixed + ) + J̇_WC = js.contact.jacobian_derivative( + model=model, data=data, output_vel_repr=VelRepr.Mixed + ) + + # Compute the generalized free acceleration. + BW_ν̇_free = jnp.hstack( + js.model.forward_dynamics_aba( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_forces=references.joint_force_references(model=model), ) + ) # Compute the free linear acceleration of the collidable points. # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. @@ -409,28 +411,29 @@ def update_velocity_after_impact( in_axes=(0, 0, None), )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) - with data.switch_velocity_representation(VelRepr.Mixed): - J_WC = js.contact.jacobian(model, data)[ - indices_of_enabled_collidable_points - ] - M = js.model.free_floating_mass_matrix(model, data) - BW_ν_pre_impact = data.generalized_velocity - - # Compute the impact velocity. - # It may be discontinuous in case new contacts are made. - BW_ν_post_impact = RigidContacts.compute_impact_velocity( - generalized_velocity=BW_ν_pre_impact, - inactive_collidable_points=(δ <= 0), - M=M, - J_WC=J_WC, - ) + J_WC = js.contact.jacobian(model, data, input_representation=VelRepr.Mixed)[ + indices_of_enabled_collidable_points + ] + M = js.model.free_floating_mass_matrix( + model, data, output_representation=VelRepr.Mixed + ) + BW_ν_pre_impact = data.generalized_velocity(VelRepr.Mixed) - BW_ν_post_impact_inertial = data.other_representation_to_inertial( - array=BW_ν_post_impact[0:6], - other_representation=VelRepr.Mixed, - transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)), - is_force=False, - ) + # Compute the impact velocity. + # It may be discontinuous in case new contacts are made. + BW_ν_post_impact = RigidContacts.compute_impact_velocity( + generalized_velocity=BW_ν_pre_impact, + inactive_collidable_points=(δ <= 0), + M=M, + J_WC=J_WC, + ) + + BW_ν_post_impact_inertial = data.other_representation_to_inertial( + array=BW_ν_post_impact[0:6], + other_representation=VelRepr.Mixed, + transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)), + is_force=False, + ) # Reset the generalized velocity. data = dataclasses.replace( diff --git a/src/jaxsim/rbda/kinematic_constraints.py b/src/jaxsim/rbda/kinematic_constraints.py index 9f0a17d38..9fe7abf16 100644 --- a/src/jaxsim/rbda/kinematic_constraints.py +++ b/src/jaxsim/rbda/kinematic_constraints.py @@ -77,14 +77,16 @@ def _compute_constraint_jacobians_batched( matrices. """ - with data.switch_velocity_representation(VelRepr.Body): - # Doubly-left free-floating Jacobian. - L_J_WL_B = js.model.generalized_free_floating_jacobian( - model=model, data=data, output_vel_repr=VelRepr.Body - ) + # Doubly-left free-floating Jacobian. + L_J_WL_B = js.model.generalized_free_floating_jacobian( + model=model, + data=data, + input_representation=VelRepr.Body, + output_vel_repr=VelRepr.Body, + ) - # Link transforms - W_H_L = data._link_transforms + # Link transforms + W_H_L = data._link_transforms def compute_frame_jacobian_mixed(L_J_WL, W_H_L, W_H_F, parent_link_index): """Compute the jacobian of a frame in mixed representation.""" @@ -237,70 +239,69 @@ def compute_constraint_wrenches( velocity_representation=VelRepr.Inertial, ) - with ( - data.switch_velocity_representation(VelRepr.Mixed), - references.switch_velocity_representation(VelRepr.Mixed), - ): - BW_ν = data.generalized_velocity - - # Compute free acceleration without constraints - BW_ν̇_free = jnp.hstack( - js.model.forward_dynamics_aba( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - joint_forces=references.joint_force_references(model=model), - ) - ) - - # Compute mass matrix - M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) + BW_ν = data.generalized_velocity(VelRepr.Mixed) - W_H_constr_pairs = _compute_constraint_transforms_batched( + # Compute free acceleration without constraints + BW_ν̇_free = jnp.hstack( + js.model.forward_dynamics_aba( model=model, data=data, - constraints=kin_constraints, + link_forces=references.link_forces(model=model, data=data), + joint_forces=references.joint_force_references(model=model), + output_representation=VelRepr.Mixed, ) + ) - # Compute constraint jacobians - J_constr = _compute_constraint_jacobians_batched( - model=model, - data=data, - constraints=kin_constraints, - W_H_constraint_pairs=W_H_constr_pairs, - ) + # Compute mass matrix + M_inv = js.model.free_floating_mass_matrix_inverse( + model=model, data=data, output_representation=VelRepr.Mixed + ) - # Compute Baumgarte stabilization term - constr_baumgarte_term = jnp.ravel( - jax.vmap( - _compute_constraint_baumgarte_term, - in_axes=(0, None, 0, 0), - )( - J_constr, - BW_ν, - W_H_constr_pairs, - kin_constraints, - ), - ) + W_H_constr_pairs = _compute_constraint_transforms_batched( + model=model, + data=data, + constraints=kin_constraints, + ) + + # Compute constraint jacobians + J_constr = _compute_constraint_jacobians_batched( + model=model, + data=data, + constraints=kin_constraints, + W_H_constraint_pairs=W_H_constr_pairs, + ) + + # Compute Baumgarte stabilization term + constr_baumgarte_term = jnp.ravel( + jax.vmap( + _compute_constraint_baumgarte_term, + in_axes=(0, None, 0, 0), + )( + J_constr, + BW_ν, + W_H_constr_pairs, + kin_constraints, + ), + ) - # Stack constraint jacobians - J_constr = jnp.vstack(J_constr) + # Stack constraint jacobians + J_constr = jnp.vstack(J_constr) - # Compute Delassus matrix for constraints - G_constraints = J_constr @ M_inv @ J_constr.T + # Compute Delassus matrix for constraints + G_constraints = J_constr @ M_inv @ J_constr.T - # Compute constraint acceleration - # TODO: add J̇_constr with efficient computation - CW_al_free_constr = J_constr @ BW_ν̇_free + # Compute constraint acceleration + # TODO: add J̇_constr with efficient computation + CW_al_free_constr = J_constr @ BW_ν̇_free - # Setup constraint optimization problem - constraint_regularization = regularization * jnp.ones(n_kin_constraints) - R = jnp.diag(constraint_regularization) - A = G_constraints + R - b = CW_al_free_constr + constr_baumgarte_term + # Setup constraint optimization problem + constraint_regularization = regularization * jnp.ones(n_kin_constraints) + R = jnp.diag(constraint_regularization) + A = G_constraints + R + b = CW_al_free_constr + constr_baumgarte_term - # Solve for constraint forces - kin_constr_wrench_mixed = jnp.linalg.solve(A, -b).reshape(-1, 6) + # Solve for constraint forces + kin_constr_wrench_mixed = jnp.linalg.solve(A, -b).reshape(-1, 6) def transform_wrenches_to_inertial(wrench, transform_pair): """ diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index e4fe0bbdf..e970a5291 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -62,7 +62,7 @@ def test_contact_kinematics( W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) # Compute the velocity of the collidable point using the contact Jacobian. - ν = data.generalized_velocity + ν = data.generalized_velocity() CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) CW_vl_WC = jnp.einsum("c6g,g->c6", CW_J_WC, ν)[:, 0:3] @@ -93,7 +93,7 @@ def test_collidable_point_jacobians( W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) # Compute the generalized velocity and the free-floating Jacobian of the frame C. - ν = data.generalized_velocity + ν = data.generalized_velocity() CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) # Compute the velocity of the collidable points using the Jacobians. @@ -169,8 +169,8 @@ def test_contact_jacobian_derivative( base_position=data.base_position, base_quaternion=data.base_orientation, joint_positions=data.joint_positions, - base_linear_velocity=data.base_velocity[0:3], - base_angular_velocity=data.base_velocity[3:6], + base_linear_velocity=data.base_velocity()[0:3], + base_angular_velocity=data.base_velocity()[3:6], joint_velocities=data.joint_velocities, velocity_representation=velocity_representation, ) diff --git a/tests/test_api_data.py b/tests/test_api_data.py index 541544962..1fb61af44 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -1,11 +1,8 @@ import jax -import jax.numpy as jnp -import pytest from numpy.testing import assert_raises import jaxsim.api as js from jaxsim import VelRepr -from jaxsim.utils import Mutability from . import utils from .utils import assert_allclose @@ -21,46 +18,6 @@ def test_data_valid( assert data.valid(model=model) -def test_data_switch_velocity_representation( - jaxsim_models_types: js.model.JaxSimModel, - prng_key: jax.Array, -): - - model = jaxsim_models_types - - _, subkey = jax.random.split(prng_key, num=2) - data = js.data.random_model_data( - model=model, key=subkey, velocity_representation=VelRepr.Inertial - ) - - # ===== - # Tests - # ===== - - new_base_linear_velocity = jnp.array([1.0, -2.0, 3.0]) - old_base_linear_velocity = data._base_linear_velocity - - # The following should not change the original `data` object since it raises. - with pytest.raises(RuntimeError): - with data.switch_velocity_representation( - velocity_representation=VelRepr.Inertial - ): - with data.mutable_context(mutability=Mutability.MUTABLE): - data._base_linear_velocity = new_base_linear_velocity - raise RuntimeError("This is raised on purpose inside this context") - - assert_allclose(data._base_linear_velocity, old_base_linear_velocity) - - # The following instead should result to an updated `data` object. - with ( - data.switch_velocity_representation(velocity_representation=VelRepr.Inertial), - data.mutable_context(mutability=Mutability.MUTABLE), - ): - data._base_linear_velocity = new_base_linear_velocity - - assert_allclose(data._base_linear_velocity, new_base_linear_velocity) - - def test_data_change_velocity_representation( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, @@ -81,42 +38,38 @@ def test_data_change_velocity_representation( model=model, data=data ) - with data.switch_velocity_representation(VelRepr.Mixed): - kin_dyn_mixed = utils.build_kindyncomputations_from_jaxsim_model( - model=model, data=data - ) + kin_dyn_mixed = utils.build_kindyncomputations_from_jaxsim_model( + model=model, data=data, vel_repr=VelRepr.Mixed + ) - with data.switch_velocity_representation(VelRepr.Body): - kin_dyn_body = utils.build_kindyncomputations_from_jaxsim_model( - model=model, data=data - ) + kin_dyn_body = utils.build_kindyncomputations_from_jaxsim_model( + model=model, data=data, vel_repr=VelRepr.Body + ) - assert_allclose(data.base_velocity, kin_dyn_inertial.base_velocity()) + assert_allclose(data.base_velocity(), kin_dyn_inertial.base_velocity()) if not model.floating_base(): return - with data.switch_velocity_representation(VelRepr.Mixed): - assert_allclose(data.base_velocity, kin_dyn_mixed.base_velocity()) - assert_raises( - AssertionError, - assert_allclose, - data.base_velocity[0:3], - data._base_linear_velocity, - ) - assert_allclose(data.base_velocity[3:6], data._base_angular_velocity) - - with data.switch_velocity_representation(VelRepr.Body): - assert_allclose(data.base_velocity, kin_dyn_body.base_velocity()) - assert_raises( - AssertionError, - assert_allclose, - data.base_velocity[0:3], - data._base_linear_velocity, - ) - assert_raises( - AssertionError, - assert_allclose, - data.base_velocity[3:6], - data._base_angular_velocity, - ) + assert_allclose(data.base_velocity(VelRepr.Mixed), kin_dyn_mixed.base_velocity()) + assert_raises( + AssertionError, + assert_allclose, + data.base_velocity(VelRepr.Mixed)[0:3], + data._base_linear_velocity, + ) + assert_allclose(data.base_velocity(VelRepr.Mixed)[3:6], data._base_angular_velocity) + + assert_allclose(data.base_velocity(VelRepr.Body), kin_dyn_body.base_velocity()) + assert_raises( + AssertionError, + assert_allclose, + data.base_velocity(VelRepr.Body)[0:3], + data._base_linear_velocity, + ) + assert_raises( + AssertionError, + assert_allclose, + data.base_velocity(VelRepr.Body)[3:6], + data._base_angular_velocity, + ) diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index 30215ee2e..852a26420 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -214,7 +214,7 @@ def test_frame_jacobian_derivative( # =============== # Get the generalized velocity. - I_ν = data.generalized_velocity + I_ν = data.generalized_velocity() # Compute J̇. O_J̇_WF_I = jax.vmap( @@ -257,11 +257,9 @@ def compute_q(data: js.data.JaxSimModelData) -> jax.Array: return q def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: - with data.switch_velocity_representation(VelRepr.Body): - B_ω_WB = data.base_velocity[3:6] + B_ω_WB = data.base_velocity(VelRepr.Body)[3:6] - with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity[0:3] + W_ṗ_B = data.base_velocity(VelRepr.Mixed)[0:3] W_Q̇_B = Quaternion.derivative( quaternion=data.base_orientation, diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 32546ef02..c1156e3e0 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -184,10 +184,9 @@ def test_link_jacobians( {data.velocity_representation} ): - with data.switch_velocity_representation(other_repr): - kin_dyn_other_repr = utils.build_kindyncomputations_from_jaxsim_model( - model=model, data=data - ) + kin_dyn_other_repr = utils.build_kindyncomputations_from_jaxsim_model( + model=model, data=data, vel_repr=other_repr + ) for link_name, link_idx in zip( model.link_names(), @@ -248,17 +247,17 @@ def test_link_bias_acceleration( W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data) - with data.switch_velocity_representation(VelRepr.Body): - - W_X_L = jax.vmap( - lambda W_H_L: jaxsim.math.Adjoint.from_transform(transform=W_H_L) - )(W_H_L) + W_X_L = jax.vmap( + lambda W_H_L: jaxsim.math.Adjoint.from_transform(transform=W_H_L) + )(W_H_L) - L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data) + L_a_bias_WL = js.model.link_bias_accelerations( + model=model, data=data, output_representation=VelRepr.Body + ) - W_a_bias_WL_converted = jax.vmap( - lambda W_X_L, L_a_bias_WL: W_X_L @ L_a_bias_WL - )(W_X_L, L_a_bias_WL) + W_a_bias_WL_converted = jax.vmap( + lambda W_X_L, L_a_bias_WL: W_X_L @ L_a_bias_WL + )(W_X_L, L_a_bias_WL) assert_allclose(W_a_bias_WL, W_a_bias_WL_converted) @@ -269,19 +268,19 @@ def test_link_bias_acceleration( L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data) - with data.switch_velocity_representation(VelRepr.Inertial): - - L_X_W = jax.vmap( - lambda W_H_L: jaxsim.math.Adjoint.from_transform( - transform=W_H_L, inverse=True - ) - )(W_H_L) + L_X_W = jax.vmap( + lambda W_H_L: jaxsim.math.Adjoint.from_transform( + transform=W_H_L, inverse=True + ) + )(W_H_L) - W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data) + W_a_bias_WL = js.model.link_bias_accelerations( + model=model, data=data, output_representation=VelRepr.Inertial + ) - L_a_bias_WL_converted = jax.vmap( - lambda L_X_W, W_a_bias_WL: L_X_W @ W_a_bias_WL - )(L_X_W, W_a_bias_WL) + L_a_bias_WL_converted = jax.vmap( + lambda L_X_W, W_a_bias_WL: L_X_W @ W_a_bias_WL + )(L_X_W, W_a_bias_WL) assert_allclose(L_a_bias_WL, L_a_bias_WL_converted) @@ -306,7 +305,7 @@ def test_link_jacobian_derivative( # ===== # Get the generalized velocity. - I_ν = data.generalized_velocity + I_ν = data.generalized_velocity() # Compute J̇. O_J̇_WL_I = jax.vmap( @@ -351,11 +350,9 @@ def compute_q(data: js.data.JaxSimModelData) -> jax.Array: def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: - with data.switch_velocity_representation(VelRepr.Body): - B_ω_WB = data.base_velocity[3:6] + B_ω_WB = data.base_velocity(VelRepr.Body)[3:6] - with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity[0:3] + W_ṗ_B = data.base_velocity(VelRepr.Mixed)[0:3] W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, diff --git a/tests/test_api_model.py b/tests/test_api_model.py index aced9cf80..ca3af1e5e 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -112,8 +112,8 @@ def test_model_creation_and_reduction( base_position=data_full.base_position, base_quaternion=data_full.base_orientation, joint_positions=data_full.joint_positions[joint_idxs], - base_linear_velocity=data_full.base_velocity[0:3], - base_angular_velocity=data_full.base_velocity[3:6], + base_linear_velocity=data_full.base_velocity()[0:3], + base_angular_velocity=data_full.base_velocity()[3:6], joint_velocities=data_full.joint_velocities[joint_idxs], velocity_representation=data_full.velocity_representation, ) @@ -370,37 +370,35 @@ def test_model_jacobian( # Get the J.T @ f product in inertial-fixed input/output representation. # We use doubly right-trivialized jacobian with inertial-fixed 6D forces. - with ( - references.switch_velocity_representation(VelRepr.Inertial), - data.switch_velocity_representation(VelRepr.Inertial), - ): - - f = references.link_forces(model=model, data=data) - assert_allclose(f, references._link_forces) + f = references.link_forces( + model=model, data=data, output_representation=VelRepr.Inertial + ) + assert_allclose(f, references._link_forces) - J = js.model.generalized_free_floating_jacobian(model=model, data=data) - JTf_inertial = jnp.einsum("l6g,l6->g", J, f) + J = js.model.generalized_free_floating_jacobian( + model=model, data=data, input_representation=VelRepr.Inertial + ) + JTf_inertial = jnp.einsum("l6g,l6->g", J, f) for vel_repr in (VelRepr.Body, VelRepr.Mixed): - with references.switch_velocity_representation(vel_repr): - - # Get the jacobian having an inertial-fixed input representation (so that - # it computes the same quantity computed above) and an output representation - # compatible with the frame in which the external forces are expressed. - with data.switch_velocity_representation(VelRepr.Inertial): - - J = js.model.generalized_free_floating_jacobian( - model=model, data=data, output_vel_repr=vel_repr - ) - - # Get the forces in the tested representation and compute the product - # O_J_WL_W.T @ O_f, producing a generalized acceleration in W. - # The resulting acceleration can be tested again the one computed before. - with data.switch_velocity_representation(vel_repr): + # Get the jacobian having an inertial-fixed input representation (so that + # it computes the same quantity computed above) and an output representation + # compatible with the frame in which the external forces are expressed. + J = js.model.generalized_free_floating_jacobian( + model=model, + data=data, + input_representation=VelRepr.Inertial, + output_vel_repr=vel_repr, + ) - f = references.link_forces(model=model, data=data) - JTf_other = jnp.einsum("l6g,l6->g", J, f) - assert_allclose(JTf_inertial, JTf_other, err_msg=vel_repr.name) + # Get the forces in the tested representation and compute the product + # O_J_WL_W.T @ O_f, producing a generalized acceleration in W. + # The resulting acceleration can be tested again the one computed before. + f = references.link_forces( + model=model, data=data, output_representation=vel_repr + ) + JTf_other = jnp.einsum("l6g,l6->g", J, f) + assert_allclose(JTf_inertial, JTf_other, err_msg=vel_repr.name) def test_coriolis_matrix( @@ -420,7 +418,7 @@ def test_coriolis_matrix( # Tests # ===== - I_ν = data.generalized_velocity + I_ν = data.generalized_velocity() C = js.model.free_floating_coriolis_matrix(model=model, data=data) h = js.model.free_floating_bias_forces(model=model, data=data) @@ -457,11 +455,8 @@ def compute_q(data: js.data.JaxSimModelData) -> jax.Array: def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: - with data.switch_velocity_representation(VelRepr.Body): - B_ω_WB = data.base_velocity[3:6] - - with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity[0:3] + B_ω_WB = data.base_velocity(VelRepr.Body)[3:6] + W_ṗ_B = data.base_velocity(VelRepr.Mixed)[0:3] W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 05e949e91..ff0e970fd 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -81,7 +81,7 @@ def test_ad_aba( W_p_B = data.base_position W_Q_B = data.base_orientation s = data.joint_positions - W_v_WB = data.base_velocity + W_v_WB = data.base_velocity() ṡ = data.joint_velocities # Inputs. @@ -135,7 +135,7 @@ def test_ad_rnea( W_p_B = data.base_position W_Q_B = data.base_orientation s = data.joint_positions - W_v_WB = data.base_velocity + W_v_WB = data.base_velocity() ṡ = data.joint_velocities # Inputs. @@ -359,7 +359,7 @@ def test_ad_integration( W_p_B = data.base_position W_Q_B = data.base_orientation s = data.joint_positions - W_v_WB = data.base_velocity + W_v_WB = data.base_velocity() ṡ = data.joint_velocities # Inputs. @@ -405,7 +405,7 @@ def step( xf_W_p_B = data_xf.base_position xf_W_Q_B = data_xf.base_orientation xf_s = data_xf.joint_positions - xf_W_v_WB = data_xf.base_velocity + xf_W_v_WB = data_xf.base_velocity() xf_ṡ = data_xf.joint_velocities return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 8b30a4f27..8b8bbecbf 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -55,14 +55,14 @@ def test_box_with_external_forces( ) # Apply a link forces to the base link. - with references.switch_velocity_representation(VelRepr.Body): - references = references.apply_link_forces( - forces=jnp.atleast_2d(L_f), - link_names=model.link_names()[0:1], - model=model, - data=data0, - additive=False, - ) + references = references.apply_link_forces( + forces=jnp.atleast_2d(L_f), + link_names=model.link_names()[0:1], + model=model, + data=data0, + additive=False, + input_representation=VelRepr.Body, + ) # Initialize the simulation horizon. tf = 0.5 @@ -115,30 +115,28 @@ def test_box_with_zero_gravity( velocity_representation=velocity_representation, ) - # Apply a link forces to the base link. - with references.switch_velocity_representation(jaxsim.VelRepr.Mixed): - - # Generate a random linear force. - # We enforce them to be the same for all velocity representations so that - # we can compare their outcomes. - LW_f = 10.0 * ( - jax.random.uniform(jax.random.key(0), shape=(model.number_of_links(), 6)) - .at[:, 3:] - .set(jnp.zeros(3)) - ) + # Generate a random linear force. + # We enforce them to be the same for all velocity representations so that + # we can compare their outcomes. + LW_f = 10.0 * ( + jax.random.uniform(jax.random.key(0), shape=(model.number_of_links(), 6)) + .at[:, 3:] + .set(jnp.zeros(3)) + ) - # Note that the context manager does not switch back the newly created - # `references` (that is not the yielded object) to the original representation. - # In the simulation loop below, we need to make sure that we switch both `data` - # and `references` to the same representation before extracting the information - # passed to the step function. - references = references.apply_link_forces( - forces=jnp.atleast_2d(LW_f), - link_names=model.link_names(), - model=model, - data=data0, - additive=False, - ) + # Note that the context manager does not switch back the newly created + # `references` (that is not the yielded object) to the original representation. + # In the simulation loop below, we need to make sure that we switch both `data` + # and `references` to the same representation before extracting the information + # passed to the step function. + references = references.apply_link_forces( + forces=jnp.atleast_2d(LW_f), + link_names=model.link_names(), + model=model, + data=data0, + additive=False, + input_representation=VelRepr.Mixed, + ) tf = 0.01 T = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int) @@ -148,15 +146,14 @@ def test_box_with_zero_gravity( # ... and step the simulation. for _ in T: - with ( - data.switch_velocity_representation(velocity_representation), - references.switch_velocity_representation(velocity_representation), - ): - data = js.model.step( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - ) + data = js.model.step( + model=model, + data=data, + link_forces=references.link_forces( + model=model, data=data, output_representation=velocity_representation + ), + output_representation=velocity_representation, + ) # Check that the box moved as expected. assert_allclose( diff --git a/tests/utils.py b/tests/utils.py index 51acf090f..10204d356 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -31,6 +31,8 @@ def build_kindyncomputations_from_jaxsim_model( data: js.data.JaxSimModelData, considered_joints: list[str] | None = None, removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, + *, + vel_repr: VelRepr | None = None, ) -> KinDynComputations: """ Build a `KinDynComputations` from `JaxSimModel` and `JaxSimModelData`. @@ -42,6 +44,8 @@ def build_kindyncomputations_from_jaxsim_model( The list of joint names to consider in the `KinDynComputations`. removed_joint_positions: A dictionary defining the positions of the removed joints (default is 0). + vel_repr: + The velocity representation to use in the `KinDynComputations`. Returns: The `KinDynComputations` built from the `JaxSimModel` and `JaxSimModelData`. @@ -51,6 +55,8 @@ def build_kindyncomputations_from_jaxsim_model( """ + vel_repr = vel_repr if vel_repr is not None else data.velocity_representation + if ( isinstance(model.built_from, pathlib.Path) and model.built_from.suffix != ".urdf" @@ -89,7 +95,7 @@ def build_kindyncomputations_from_jaxsim_model( kin_dyn = KinDynComputations.build( urdf=model.built_from, considered_joints=considered_joints, - vel_repr=data.velocity_representation, + vel_repr=vel_repr, gravity=np.array([0, 0, model.gravity]), removed_joint_positions=removed_joint_positions, ) @@ -120,13 +126,12 @@ def store_jaxsim_data_in_kindyncomputations( if kin_dyn.dofs() != data.joint_positions.size: raise ValueError(data) - with data.switch_velocity_representation(kin_dyn.vel_repr): - kin_dyn.set_robot_state( - joint_positions=np.array(data.joint_positions), - joint_velocities=np.array(data.joint_velocities), - base_transform=np.array(data._base_transform), - base_velocity=np.array(data.base_velocity), - ) + kin_dyn.set_robot_state( + joint_positions=np.array(data.joint_positions), + joint_velocities=np.array(data.joint_velocities), + base_transform=np.array(data._base_transform), + base_velocity=np.array(data.base_velocity(kin_dyn.vel_repr)), + ) return kin_dyn