Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 70 additions & 52 deletions examples/jaxsim_as_multibody_dynamics_library.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@
"\n",
"# Print the default state.\n",
"W_H_B, s = data.generalized_position\n",
"ν = data.generalized_velocity\n",
"ν = data.generalized_velocity()\n",
"\n",
"print(f\"W_H_B: shape={W_H_B.shape}\\n{W_H_B}\\n\")\n",
"print(f\"s: shape={s.shape}\\n{s}\\n\")\n",
Expand Down Expand Up @@ -304,7 +304,9 @@
")\n",
"\n",
"print(f\"link_forces: shape={references.link_forces(model=model, data=data).shape}\")\n",
"print(f\"joint_force_references: shape={references.joint_force_references(model=model).shape}\")"
"print(\n",
" f\"joint_force_references: shape={references.joint_force_references(model=model).shape}\"\n",
")"
]
},
{
Expand Down Expand Up @@ -382,9 +384,15 @@
"# @title Link 6D Velocity\n",
"\n",
"# JaxSim allows to select the so-called representation of the frame velocity.\n",
"L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body)\n",
"LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed)\n",
"W_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial)\n",
"L_v_WL = js.link.velocity(\n",
" model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body\n",
")\n",
"LW_v_WL = js.link.velocity(\n",
" model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed\n",
")\n",
"W_v_WL = js.link.velocity(\n",
" model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial\n",
")\n",
"\n",
"print(f\"Body-fixed velocity L_v_WL={L_v_WL}\")\n",
"print(f\"Mixed velocity: LW_v_WL={LW_v_WL}\")\n",
Expand All @@ -395,17 +403,21 @@
"# the velocity representation of ν, and an output velocity representation that\n",
"# corresponds to the velocity representation of the desired 6D velocity.\n",
"\n",
"# You can use the following context manager to easily switch between representations.\n",
"with data.switch_velocity_representation(VelRepr.Body):\n",
"\n",
" # Body-fixed generalized velocity.\n",
" B_ν = data.generalized_velocity\n",
"# You can set the output velocity representation of quantities depending\n",
"# on the velocity representation stored in data by passing an argument\n",
"# to the function, as shown below.\n",
"# Body-fixed generalized velocity.\n",
"B_ν = data.generalized_velocity(output_representation=VelRepr.Body)\n",
"\n",
" # Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
" # returning an inertial-fixed link velocity.\n",
" W_J_WL_B = js.link.jacobian(\n",
" model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial\n",
" )\n",
"# Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
"# returning an inertial-fixed link velocity.\n",
"W_J_WL_B = js.link.jacobian(\n",
" model=model,\n",
" data=data,\n",
" link_index=link_index,\n",
" input_representation=VelRepr.Body,\n",
" output_vel_repr=VelRepr.Inertial,\n",
")\n",
"\n",
"# Now the following relation should hold.\n",
"assert jnp.allclose(W_v_WL, W_J_WL_B @ B_ν)"
Expand Down Expand Up @@ -455,9 +467,15 @@
"# @title Frame 6D Velocity\n",
"\n",
"# JaxSim allows to select the so-called representation of the frame velocity.\n",
"F_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body)\n",
"FW_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed)\n",
"W_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial)\n",
"F_v_WF = js.frame.velocity(\n",
" model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body\n",
")\n",
"FW_v_WF = js.frame.velocity(\n",
" model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed\n",
")\n",
"W_v_WF = js.frame.velocity(\n",
" model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\n",
")\n",
"\n",
"print(f\"Body-fixed velocity F_v_WF={F_v_WF}\")\n",
"print(f\"Mixed velocity: FW_v_WF={FW_v_WF}\")\n",
Expand All @@ -468,17 +486,21 @@
"# the velocity representation of ν, and an output velocity representation that\n",
"# corresponds to the velocity representation of the desired 6D velocity.\n",
"\n",
"# You can use the following context manager to easily switch between representations.\n",
"with data.switch_velocity_representation(VelRepr.Body):\n",
"# You can pass the output representation to getters of quantities depending\n",
"# on the velocity representation stored in data by passing an argument.\n",
"\n",
" # Body-fixed generalized velocity.\n",
" B_ν = data.generalized_velocity\n",
"# Body-fixed generalized velocity.\n",
"B_ν = data.generalized_velocity(VelRepr.Body)\n",
"\n",
" # Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
" # returning an inertial-fixed link velocity.\n",
" W_J_WF_B = js.frame.jacobian(\n",
" model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\n",
" )\n",
"# Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
"# returning an inertial-fixed link velocity.\n",
"W_J_WF_B = js.frame.jacobian(\n",
" model=model,\n",
" data=data,\n",
" frame_index=frame_index,\n",
" input_representation=VelRepr.Body,\n",
" output_vel_repr=VelRepr.Inertial,\n",
")\n",
"\n",
"# Now the following relation should hold.\n",
"assert jnp.allclose(W_v_WF, W_J_WF_B @ B_ν)"
Expand Down Expand Up @@ -662,7 +684,7 @@
" joint_accelerations=s̈,\n",
" # To check that f_B works, let's remove the force applied\n",
" # to the base link from the link forces.\n",
" link_forces=f_L.at[0].set(jnp.zeros(6))\n",
" link_forces=f_L.at[0].set(jnp.zeros(6)),\n",
")\n",
"\n",
"print(f\"f_B: shape={f_B.shape}\")\n",
Expand Down Expand Up @@ -825,30 +847,27 @@
},
"outputs": [],
"source": [
"with (\n",
" data.switch_velocity_representation(VelRepr.Mixed),\n",
" references.switch_velocity_representation(VelRepr.Mixed),\n",
"):\n",
"\n",
" # Compute the mixed generalized velocity.\n",
" BW_ν = data.generalized_velocity\n",
"\n",
" # Compute the mixed generalized acceleration.\n",
" BW_ν̇ = jnp.hstack(\n",
" js.model.forward_dynamics(\n",
" model=model,\n",
" data=data,\n",
" link_forces=references.link_forces(model=model, data=data),\n",
" joint_forces=references.joint_force_references(model=model),\n",
" )\n",
"# Compute the mixed generalized velocity.\n",
"BW_ν = data.generalized_velocity()\n",
"\n",
"# Compute the mixed generalized acceleration.\n",
"BW_ν̇ = jnp.hstack(\n",
" js.model.forward_dynamics(\n",
" model=model,\n",
" data=data,\n",
" link_forces=references.link_forces(\n",
" model=model, data=data, output_representation=VelRepr.Mixed\n",
" ),\n",
" joint_forces=references.joint_force_references(model=model),\n",
" )\n",
")\n",
"\n",
" # Compute the mass matrix in mixed representation.\n",
" BW_M = js.model.free_floating_mass_matrix(model=model, data=data)\n",
"# Compute the mass matrix in mixed representation.\n",
"BW_M = js.model.free_floating_mass_matrix(model=model, data=data)\n",
"\n",
" # Compute the contact Jacobian and its derivative.\n",
" Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]\n",
" J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\n",
"# Compute the contact Jacobian and its derivative.\n",
"Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]\n",
"J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\n",
"\n",
"# Compute the Delassus matrix.\n",
"Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(Jl_WC).T)[0]\n",
Expand All @@ -860,9 +879,8 @@
"print(f\"W_H_C: shape={W_H_C.shape}\")\n",
"\n",
"# Compute the linear velocity of the collidable points.\n",
"with data.switch_velocity_representation(VelRepr.Mixed):\n",
" W_ṗ_B = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]\n",
" print(f\"W_ṗ_B: shape={W_ṗ_B.shape}\")\n",
"W_ṗ_B = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]\n",
"print(f\"W_ṗ_B: shape={W_ṗ_B.shape}\")\n",
"\n",
"# Compute the linear acceleration of the collidable points.\n",
"W_p̈_C = 0\n",
Expand Down
Loading