diff --git a/mujoco_warp/_src/derivative.py b/mujoco_warp/_src/derivative.py
index 0ba06cbce..0bb142fed 100644
--- a/mujoco_warp/_src/derivative.py
+++ b/mujoco_warp/_src/derivative.py
@@ -15,11 +15,13 @@
import warp as wp
+from mujoco_warp._src import math
from mujoco_warp._src.types import BiasType
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import DisableBit
from mujoco_warp._src.types import DynType
from mujoco_warp._src.types import GainType
+from mujoco_warp._src.types import JointType
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import TileSet
from mujoco_warp._src.types import vec10f
@@ -240,13 +242,14 @@ def _qderiv_tendon_damping(
@event_scope
-def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d(dtype=float)):
+def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d(dtype=float), flg_rne: bool = True):
"""Analytical derivative of smooth forces w.r.t. velocities.
Args:
m: The model containing kinematic and dynamic information (device).
d: The data object containing the current state and output arrays (device).
out: qM - dt * qDeriv (derivatives of smooth forces w.r.t velocities).
+ flg_rne: Whether to include RNE derivatives.
"""
qMi = m.qM_fullm_i
qMj = m.qM_fullm_j
@@ -321,4 +324,303 @@ def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d(dtype=float)):
outputs=[out],
)
- # TODO(team): rne derivative
+ if flg_rne:
+ rne_vel(m, d, out)
+
+
+@wp.kernel
+def _derivative_com_vel_root(Dcvel_out: wp.array3d(dtype=wp.spatial_vector)):
+ worldid, elementid, k = wp.tid()
+ Dcvel_out[worldid, 0, k][elementid] = 0.0
+
+
+@wp.kernel
+def _derivative_com_vel_level(
+ # Model:
+ nv: int,
+ body_parentid: wp.array(dtype=int),
+ body_jntnum: wp.array(dtype=int),
+ body_jntadr: wp.array(dtype=int),
+ body_dofadr: wp.array(dtype=int),
+ jnt_type: wp.array(dtype=int),
+ # Data in:
+ qvel_in: wp.array2d(dtype=float),
+ cdof_in: wp.array2d(dtype=wp.spatial_vector),
+ # In:
+ body_tree_: wp.array(dtype=int),
+ # Data out:
+ # Out:
+ Dcvel_out: wp.array3d(dtype=wp.spatial_vector),
+ Dcdof_dot_out: wp.array3d(dtype=wp.spatial_vector),
+):
+ worldid, nodeid, k = wp.tid()
+ bodyid = body_tree_[nodeid]
+ dofid = body_dofadr[bodyid]
+ jntid = body_jntadr[bodyid]
+ jntnum = body_jntnum[bodyid]
+ pid = body_parentid[bodyid]
+
+ # Initialize from parent
+ cvel_k = Dcvel_out[worldid, pid, k]
+
+ if jntnum == 0:
+ Dcvel_out[worldid, bodyid, k] = cvel_k
+ return
+
+ qvel = qvel_in[worldid]
+ cdof = cdof_in[worldid]
+
+ for j in range(jntid, jntid + jntnum):
+ jnttype = jnt_type[j]
+
+ if jnttype == JointType.FREE:
+ # cvel += cdof * qvel
+ if k >= dofid and k < dofid + 3:
+ cvel_k += cdof[k]
+ elif k >= dofid + 3 and k < dofid + 6:
+ cvel_k += cdof[k]
+
+ if k < nv:
+ Dcdof_dot_out[worldid, dofid + 3, k] = math.motion_cross(cvel_k, cdof[dofid + 3])
+ Dcdof_dot_out[worldid, dofid + 4, k] = math.motion_cross(cvel_k, cdof[dofid + 4])
+ Dcdof_dot_out[worldid, dofid + 5, k] = math.motion_cross(cvel_k, cdof[dofid + 5])
+
+ dofid += 6
+
+ elif jnttype == JointType.BALL:
+ if k < nv:
+ Dcdof_dot_out[worldid, dofid + 0, k] = math.motion_cross(cvel_k, cdof[dofid + 0])
+ Dcdof_dot_out[worldid, dofid + 1, k] = math.motion_cross(cvel_k, cdof[dofid + 1])
+ Dcdof_dot_out[worldid, dofid + 2, k] = math.motion_cross(cvel_k, cdof[dofid + 2])
+
+ if k >= dofid and k < dofid + 3:
+ cvel_k += cdof[k]
+
+ dofid += 3
+ else:
+ if k < nv:
+ Dcdof_dot_out[worldid, dofid, k] = math.motion_cross(cvel_k, cdof[dofid])
+
+ if k == dofid:
+ cvel_k += cdof[dofid]
+
+ dofid += 1
+
+ Dcvel_out[worldid, bodyid, k] = cvel_k
+
+
+@wp.func
+def _mul_inert_vec(inert: vec10f, vec: wp.spatial_vector) -> wp.spatial_vector:
+ mass = inert[0]
+ h = wp.vec3(inert[1], inert[2], inert[3])
+ # I_3x3 from symmetric values (xx, yy, zz, xy, xz, yz)
+ # row 0: xx, xy, xz
+ # row 1: xy, yy, yz
+ # row 2: xz, yz, zz
+ I = wp.mat33(inert[4], inert[7], inert[8], inert[7], inert[5], inert[9], inert[8], inert[9], inert[6])
+
+ ang = wp.spatial_top(vec)
+ lin = wp.spatial_bottom(vec)
+
+ res_ang = I * ang + wp.cross(h, lin)
+ res_lin = mass * lin - wp.cross(h, ang)
+
+ return wp.spatial_vector(res_ang, res_lin)
+
+
+@wp.kernel
+def _derivative_rne_forward_level(
+ # Model:
+ nv: int,
+ body_parentid: wp.array(dtype=int),
+ body_dofnum: wp.array(dtype=int),
+ body_dofadr: wp.array(dtype=int),
+ # Data in:
+ qvel_in: wp.array2d(dtype=float),
+ cinert_in: wp.array2d(dtype=vec10f),
+ cvel_in: wp.array2d(dtype=wp.spatial_vector),
+ cdof_dot_in: wp.array2d(dtype=wp.spatial_vector),
+ # In:
+ body_tree_: wp.array(dtype=int),
+ Dcvel_in: wp.array3d(dtype=wp.spatial_vector),
+ Dcdof_dot_in: wp.array3d(dtype=wp.spatial_vector),
+ # Out:
+ Dcacc_out: wp.array3d(dtype=wp.spatial_vector),
+ Dcfrcbody_out: wp.array3d(dtype=wp.spatial_vector),
+):
+ worldid, nodeid, k = wp.tid()
+ bodyid = body_tree_[nodeid]
+ dofid = body_dofadr[bodyid]
+ dofnum = body_dofnum[bodyid]
+ pid = body_parentid[bodyid]
+
+ dcacc = Dcacc_out[worldid, pid, k]
+
+ qvel = qvel_in[worldid]
+
+ for j in range(dofid, dofid + dofnum):
+ # Term 1: cdof_dot * d(qvel)/dk
+ if j == k:
+ dcacc += cdof_dot_in[worldid, j]
+
+ # Term 2: Dcdofdot * qvel
+ dcdofdot = Dcdof_dot_in[worldid, j, k]
+ dcacc += dcdofdot * qvel[j]
+
+ Dcacc_out[worldid, bodyid, k] = dcacc
+
+ # Dcfrcbody calculation
+ cinert = cinert_in[worldid, bodyid]
+ cvel = cvel_in[worldid, bodyid]
+ dcvel = Dcvel_in[worldid, bodyid, k]
+
+ # term1 = cinert * dcacc
+ term1 = _mul_inert_vec(cinert, dcacc)
+
+ cinert_cvel = _mul_inert_vec(cinert, cvel)
+ cinert_dcvel = _mul_inert_vec(cinert, dcvel)
+
+ term2 = math.motion_cross_force(dcvel, cinert_cvel) + math.motion_cross_force(cvel, cinert_dcvel)
+
+ Dcfrcbody_out[worldid, bodyid, k] = term1 + term2
+
+
+@wp.kernel
+def _derivative_rne_backward_level(
+ # Model:
+ body_parentid: wp.array(dtype=int),
+ # In:
+ body_tree_: wp.array(dtype=int),
+ # Out:
+ Dcfrcbody_out: wp.array3d(dtype=wp.spatial_vector),
+):
+ worldid, nodeid, k = wp.tid()
+ bodyid = body_tree_[nodeid]
+ pid = body_parentid[bodyid]
+
+ if pid == 0 and bodyid == 0:
+ return # World body has no parent to add to
+
+ val = Dcfrcbody_out[worldid, bodyid, k]
+ wp.atomic_add(Dcfrcbody_out[worldid, pid], k, val)
+
+
+@wp.kernel
+def _derivative_rne_update_sparse(
+ # Model:
+ dof_bodyid: wp.array(dtype=int),
+ # Data in:
+ cdof_in: wp.array2d(dtype=wp.spatial_vector),
+ # In:
+ timestep: wp.array(dtype=float),
+ qMi: wp.array(dtype=int),
+ qMj: wp.array(dtype=int),
+ Dcfrcbody_in: wp.array3d(dtype=wp.spatial_vector),
+ # Out:
+ qDeriv_out: wp.array3d(dtype=float),
+):
+ worldid, elemid = wp.tid()
+ dt = timestep[worldid % timestep.shape[0]]
+
+ i = qMi[elemid]
+ j = qMj[elemid]
+
+ # qDeriv[i, j] -= cdof[i] * Dcfrcbody[body(i), j]
+
+ body_i = dof_bodyid[i]
+ dcfrc = Dcfrcbody_in[worldid, body_i, j]
+ term = wp.dot(cdof_in[worldid, i], dcfrc)
+
+ wp.atomic_add(qDeriv_out[worldid, 0], elemid, -dt * term)
+
+
+@wp.kernel
+def _derivative_rne_update_dense(
+ # Model:
+ dof_bodyid: wp.array(dtype=int),
+ # Data in:
+ cdof_in: wp.array2d(dtype=wp.spatial_vector),
+ # In:
+ timestep: wp.array(dtype=float),
+ Dcfrcbody_in: wp.array3d(dtype=wp.spatial_vector),
+ # Out:
+ qDeriv_out: wp.array3d(dtype=float),
+):
+ worldid, i, j = wp.tid()
+ dt = timestep[worldid % timestep.shape[0]]
+
+ body_i = dof_bodyid[i]
+ dcfrc = Dcfrcbody_in[worldid, body_i, j]
+ term = wp.dot(cdof_in[worldid, i], dcfrc)
+
+ qDeriv_out[worldid, i, j] -= dt * term
+
+
+@event_scope
+def rne_vel(m: Model, d: Data, out: wp.array2d(dtype=float)): # out is qDeriv-like
+ # Temporary dense allocations
+ Dcvel = wp.zeros((d.nworld, m.nbody, m.nv), dtype=wp.spatial_vector)
+ Dcdof_dot = wp.zeros((d.nworld, m.nv, m.nv), dtype=wp.spatial_vector)
+ Dcacc = wp.zeros((d.nworld, m.nbody, m.nv), dtype=wp.spatial_vector)
+ Dcfrcbody = wp.zeros((d.nworld, m.nbody, m.nv), dtype=wp.spatial_vector)
+
+ # Compute Dcvel and Dcdofdot
+ wp.launch(
+ _derivative_com_vel_root,
+ dim=(d.nworld, 1, m.nv),
+ inputs=[Dcvel],
+ outputs=[],
+ )
+
+ for body_tree in m.body_tree:
+ wp.launch(
+ _derivative_com_vel_level,
+ dim=(d.nworld, body_tree.size, m.nv),
+ inputs=[m.nv, m.body_parentid, m.body_jntnum, m.body_jntadr, m.body_dofadr, m.jnt_type, d.qvel, d.cdof, body_tree],
+ outputs=[Dcvel, Dcdof_dot],
+ )
+
+ # Forward pass (Dcacc, Dcfrcbody)
+ for body_tree in m.body_tree:
+ wp.launch(
+ _derivative_rne_forward_level,
+ dim=(d.nworld, body_tree.size, m.nv),
+ inputs=[
+ m.nv,
+ m.body_parentid,
+ m.body_dofnum,
+ m.body_dofadr,
+ d.qvel,
+ d.cinert,
+ d.cvel,
+ d.cdof_dot,
+ body_tree,
+ Dcvel,
+ Dcdof_dot,
+ ],
+ outputs=[Dcacc, Dcfrcbody],
+ )
+
+ # Backward pass (Accumulate Dcfrcbody)
+ for body_tree in reversed(m.body_tree):
+ wp.launch(
+ _derivative_rne_backward_level,
+ dim=(d.nworld, body_tree.size, m.nv),
+ inputs=[m.body_parentid, body_tree],
+ outputs=[Dcfrcbody], # In/Out
+ )
+
+ if m.is_sparse:
+ wp.launch(
+ _derivative_rne_update_sparse,
+ dim=(d.nworld, m.qM_fullm_i.size),
+ inputs=[m.dof_bodyid, d.cdof, m.opt.timestep, m.qM_fullm_i, m.qM_fullm_j, Dcfrcbody],
+ outputs=[out],
+ )
+ else:
+ wp.launch(
+ _derivative_rne_update_dense,
+ dim=(d.nworld, m.nv, m.nv),
+ inputs=[m.dof_bodyid, d.cdof, m.opt.timestep, Dcfrcbody],
+ outputs=[out],
+ )
diff --git a/mujoco_warp/_src/derivative_test.py b/mujoco_warp/_src/derivative_test.py
index da4d4b3bc..105b32f38 100644
--- a/mujoco_warp/_src/derivative_test.py
+++ b/mujoco_warp/_src/derivative_test.py
@@ -42,7 +42,7 @@ def test_smooth_vel(self, jacobian):
mjm, mjd, m, d = test_data.fixture(
xml="""