diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index 52e502a44..3a53d7de4 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -46,6 +46,12 @@ from mujoco_warp._src.forward import rungekutta4 as rungekutta4 from mujoco_warp._src.forward import step1 as step1 from mujoco_warp._src.forward import step2 as step2 +from mujoco_warp._src.grad import SMOOTH_GRAD_FIELDS as SMOOTH_GRAD_FIELDS +from mujoco_warp._src.grad import diff_forward as diff_forward +from mujoco_warp._src.grad import diff_step as diff_step +from mujoco_warp._src.grad import disable_grad as disable_grad +from mujoco_warp._src.grad import enable_grad as enable_grad +from mujoco_warp._src.grad import make_diff_data as make_diff_data from mujoco_warp._src.inverse import inverse as inverse from mujoco_warp._src.io import create_render_context as create_render_context from mujoco_warp._src.io import get_data_into as get_data_into diff --git a/mujoco_warp/_src/adjoint.py b/mujoco_warp/_src/adjoint.py new file mode 100644 index 000000000..231152066 --- /dev/null +++ b/mujoco_warp/_src/adjoint.py @@ -0,0 +1,107 @@ +"""custom adjoint definitions for MuJoCo Warp autodifferentiation. + +This module centralizes all ``@wp.func_grad`` registrations. It must be +imported before any tape recording so that custom gradients are registered +with Warp's AD system. + +Import this module via ``grad.py`` dont import it directly +""" + +import warp as wp + +from mujoco_warp._src import math + + +@wp.func_grad(math.quat_integrate) +def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat): + """Custom adjoint avoiding gradient singularity at |v|=0.""" + EPS = float(1e-10) + norm_v = wp.length(v) + norm_v_sq = norm_v * norm_v + half_angle = dt * norm_v * 0.5 + + # sinc-safe rotation quaternion construction + if norm_v > EPS: + s_over_nv = wp.sin(half_angle) / norm_v # sin(dt|v|/2) / |v| + c = wp.cos(half_angle) + # d(s_over_nv)/dv_j = ds_coeff * v_j + ds_coeff = (c * dt * 0.5 - s_over_nv) / norm_v_sq + else: + s_over_nv = dt * 0.5 + c = 1.0 + # Taylor limit: (c*dt/2 - s_over_nv) / |v|^2 -> -dt^3/24 + ds_coeff = -dt * dt * dt / 24.0 + + q_rot = wp.quat( + c, + s_over_nv * v[0], + s_over_nv * v[1], + s_over_nv * v[2], + ) + + # recompute forward intermediates + q_len = wp.length(q) + q_inv_len = 1.0 / wp.max(q_len, EPS) + q_n = wp.quat( + q[0] * q_inv_len, + q[1] * q_inv_len, + q[2] * q_inv_len, + q[3] * q_inv_len, + ) + + q_res = math.mul_quat(q_n, q_rot) + res_len = wp.length(q_res) + res_inv = 1.0 / wp.max(res_len, EPS) + + # result = normalize(q_res) + # adj_q_res_k = adj_ret_k / |q_res| - q_res_k * dot(adj_ret, q_res) / |q_res|^3 + dot_ar = adj_ret[0] * q_res[0] + adj_ret[1] * q_res[1] + adj_ret[2] * q_res[2] + adj_ret[3] * q_res[3] + res_inv3 = res_inv * res_inv * res_inv + adj_qr = wp.quat( + adj_ret[0] * res_inv - q_res[0] * dot_ar * res_inv3, + adj_ret[1] * res_inv - q_res[1] * dot_ar * res_inv3, + adj_ret[2] * res_inv - q_res[2] * dot_ar * res_inv3, + adj_ret[3] * res_inv - q_res[3] * dot_ar * res_inv3, + ) + + # q_res = mul_quat(q_n, q_rot) + # adj_q_n = mul_quat(adj_qr, conj(q_rot)) + # adj_q_rot = mul_quat(conj(q_n), adj_qr) + q_rot_conj = wp.quat(q_rot[0], -q_rot[1], -q_rot[2], -q_rot[3]) + adj_qn = math.mul_quat(adj_qr, q_rot_conj) + + q_n_conj = wp.quat(q_n[0], -q_n[1], -q_n[2], -q_n[3]) + adj_q_rot = math.mul_quat(q_n_conj, adj_qr) + + # q_rot = (c, s_over_nv * v) + # d(c)/dv_j = -s_over_nv * dt/2 * v_j + # d(s_over_nv * v_i)/dv_j = ds_coeff * v_j * v_i + s_over_nv * delta_ij + sv_dot = adj_q_rot[1] * v[0] + adj_q_rot[2] * v[1] + adj_q_rot[3] * v[2] + common = -s_over_nv * dt * 0.5 * adj_q_rot[0] + ds_coeff * sv_dot + adj_v_val = wp.vec3( + common * v[0] + s_over_nv * adj_q_rot[1], + common * v[1] + s_over_nv * adj_q_rot[2], + common * v[2] + s_over_nv * adj_q_rot[3], + ) + + # adj_dt from q_rot dependency on dt + # d(c)/d(dt) = -sin(half_angle) * norm_v / 2 + # d(s_over_nv * v_i)/dt = (c / 2) * v_i + adj_dt_val = adj_q_rot[0] * (-wp.sin(half_angle) * norm_v * 0.5) + adj_dt_val += sv_dot * c * 0.5 + + # q_n = normalize(q) + # adj_q_k = adj_qn_k / |q| - q_k * dot(adj_qn, q) / |q|^3 + dot_aqn = adj_qn[0] * q[0] + adj_qn[1] * q[1] + adj_qn[2] * q[2] + adj_qn[3] * q[3] + q_inv_len3 = q_inv_len * q_inv_len * q_inv_len + adj_q_val = wp.quat( + adj_qn[0] * q_inv_len - q[0] * dot_aqn * q_inv_len3, + adj_qn[1] * q_inv_len - q[1] * dot_aqn * q_inv_len3, + adj_qn[2] * q_inv_len - q[2] * dot_aqn * q_inv_len3, + adj_qn[3] * q_inv_len - q[3] * dot_aqn * q_inv_len3, + ) + + # accumulate adjoints + wp.adjoint[q] += adj_q_val + wp.adjoint[v] += adj_v_val + wp.adjoint[dt] += adj_dt_val diff --git a/mujoco_warp/_src/derivative.py b/mujoco_warp/_src/derivative.py index baf099a4c..5deb1efd2 100644 --- a/mujoco_warp/_src/derivative.py +++ b/mujoco_warp/_src/derivative.py @@ -25,7 +25,7 @@ from mujoco_warp._src.types import vec10f from mujoco_warp._src.warp_util import event_scope -wp.set_module_options({"enable_backward": False}) +wp.set_module_options({"enable_backward": True}) @wp.kernel diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 80c5f8f49..ee083ee77 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -45,7 +45,7 @@ from mujoco_warp._src.warp_util import cache_kernel from mujoco_warp._src.warp_util import event_scope -wp.set_module_options({"enable_backward": False}) +wp.set_module_options({"enable_backward": True}) @wp.kernel @@ -214,6 +214,12 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) """Advance state and time given activation derivatives and acceleration.""" # TODO(team): can we assume static timesteps? + # Clone arrays used as both input and output so that Warp's tape retains the + # original values for correct reverse-mode AD. + act_in = wp.clone(d.act) + qvel_prev = wp.clone(d.qvel) + qpos_prev = wp.clone(d.qpos) + # advance activations wp.launch( _next_activation, @@ -226,7 +232,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) m.actuator_actlimited, m.actuator_dynprm, m.actuator_actrange, - d.act, + act_in, d.act_dot, 1.0, True, @@ -237,7 +243,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) wp.launch( _next_velocity, dim=(d.nworld, m.nv), - inputs=[m.opt.timestep, d.qvel, qacc, 1.0], + inputs=[m.opt.timestep, qvel_prev, qacc, 1.0], outputs=[d.qvel], ) @@ -247,7 +253,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) wp.launch( _next_position, dim=(d.nworld, m.njnt), - inputs=[m.opt.timestep, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, qvel_in, 1.0], + inputs=[m.opt.timestep, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, qpos_prev, qvel_in, 1.0], outputs=[d.qpos], ) @@ -774,9 +780,9 @@ def _tendon_actuator_force_clamp( actfrcrange = tendon_actfrcrange[worldid % tendon_actfrcrange.shape[0], tenid] if ten_actfrc < actfrcrange[0]: - actuator_force_out[worldid, actid] *= actfrcrange[0] / ten_actfrc + actuator_force_out[worldid, actid] = actuator_force_out[worldid, actid] * (actfrcrange[0] / ten_actfrc) elif ten_actfrc > actfrcrange[1]: - actuator_force_out[worldid, actid] *= actfrcrange[1] / ten_actfrc + actuator_force_out[worldid, actid] = actuator_force_out[worldid, actid] * (actfrcrange[1] / ten_actfrc) @wp.kernel @@ -911,6 +917,8 @@ def fwd_actuation(m: Model, d: Data): ], outputs=[d.qfrc_actuator], ) + # clone to break input/output aliasing for correct AD + qfrc_actuator_in = wp.clone(d.qfrc_actuator) wp.launch( _qfrc_actuator_gravcomp_limits, dim=(d.nworld, m.nv), @@ -921,7 +929,7 @@ def fwd_actuation(m: Model, d: Data): m.jnt_actfrcrange, m.dof_jntid, d.qfrc_gravcomp, - d.qfrc_actuator, + qfrc_actuator_in, ], outputs=[d.qfrc_actuator], ) diff --git a/mujoco_warp/_src/grad.py b/mujoco_warp/_src/grad.py new file mode 100644 index 000000000..7eb6fe24d --- /dev/null +++ b/mujoco_warp/_src/grad.py @@ -0,0 +1,150 @@ +"""Autodifferentiation coordination for MuJoCo Warp. + +This module provides utilities for enabling Warp's tape-based reverse-mode +automatic differentiation through the MuJoCo Warp physics pipeline. + +Usage:: + + import mujoco_warp as mjw + + d = mjw.make_diff_data(mjm) # Data with gradient tracking + tape = wp.Tape() + with tape: + mjw.step(m, d) + wp.launch(loss_kernel, dim=1, inputs=[d.xpos, target, loss]) + tape.backward(loss=loss) + grad_ctrl = d.ctrl.grad +""" + +from typing import Callable, Optional, Sequence + +import warp as wp + +from mujoco_warp._src import adjoint as _adjoint # noqa: F401 (register custom adjoints) +from mujoco_warp._src import io +from mujoco_warp._src.forward import forward +from mujoco_warp._src.forward import step +from mujoco_warp._src.types import Data +from mujoco_warp._src.types import Model + +SMOOTH_GRAD_FIELDS: tuple = ( + # primary state, user-controlled inputs + "qpos", + "qvel", + "ctrl", + "act", + "mocap_pos", + "mocap_quat", + "xfrc_applied", + "qfrc_applied", + # position-dependent outputs + "xpos", + "xquat", + "xmat", + "xipos", + "ximat", + "xanchor", + "xaxis", + "geom_xpos", + "geom_xmat", + "site_xpos", + "site_xmat", + "subtree_com", + "cinert", + "crb", + "cdof", + # Velocity-dependent outputs + "cdof_dot", + "cvel", + "subtree_linvel", + "subtree_angmom", + "actuator_velocity", + "ten_velocity", + # body-level intermediate quantities + "cacc", + "cfrc_int", + "cfrc_ext", + # force/acceleration outputs + "qfrc_bias", + "qfrc_spring", + "qfrc_damper", + "qfrc_gravcomp", + "qfrc_fluid", + "qfrc_passive", + "qfrc_actuator", + "qfrc_smooth", + "qacc", + "qacc_smooth", + "actuator_force", + "act_dot", + # inertia matrix + "qM", + "qLD", + "qLDiagInv", + # Tendon + "ten_J", + "ten_length", + # actuator + "actuator_length", + "actuator_moment", + # sensor + "sensordata", +) + + +def enable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None: + """Enables gradient tracking on Data arrays.""" + if fields is None: + fields = SMOOTH_GRAD_FIELDS + for name in fields: + arr = getattr(d, name, None) + if arr is not None and isinstance(arr, wp.array): + arr.requires_grad = True + + +def disable_grad(d: Data) -> None: + """Disables gradient tracking on all Data arrays.""" + for name in SMOOTH_GRAD_FIELDS: + arr = getattr(d, name, None) + if arr is not None and isinstance(arr, wp.array): + arr.requires_grad = False + + +def make_diff_data( + mjm, + nworld: int = 1, + grad_fields: Optional[Sequence[str]] = None, + **kwargs, +) -> Data: + """Creates a Data object with gradient tracking enabled.""" + d = io.make_data(mjm, nworld=nworld, **kwargs) + enable_grad(d, fields=grad_fields) + return d + + +def diff_step( + m: Model, + d: Data, + loss_fn: Callable[[Model, Data], wp.array], +) -> wp.Tape: + """Runs a differentiable physics step.""" + tape = wp.Tape() + with tape: + step(m, d) + loss = loss_fn(m, d) + tape.backward(loss=loss) + return tape + + +def diff_forward( + m: Model, + d: Data, + loss_fn: Callable[[Model, Data], wp.array], +) -> wp.Tape: + """Runs differentiable forward dynamics (no integration).""" + tape = wp.Tape() + with tape: + forward(m, d) + loss = loss_fn(m, d) + tape.backward(loss=loss) + return tape diff --git a/mujoco_warp/_src/grad_test.py b/mujoco_warp/_src/grad_test.py new file mode 100644 index 000000000..277284940 --- /dev/null +++ b/mujoco_warp/_src/grad_test.py @@ -0,0 +1,566 @@ +"""Tests for autodifferentiation gradients.""" + +# When run as a script, Python adds this file's directory (_src/) to sys.path, +# which causes types.py to shadow the stdlib 'types' module. Replace it with +# the project root so that 'import mujoco_warp' still works. +import os as _os +import sys as _sys + +_src_dir = _os.path.dirname(_os.path.abspath(__file__)) +_project_root = _os.path.dirname(_os.path.dirname(_src_dir)) +if _src_dir in _sys.path: + _sys.path[_sys.path.index(_src_dir)] = _project_root + +import mujoco +import numpy as np +import warp as wp +from absl.testing import absltest +from absl.testing import parameterized + +import mujoco_warp as mjw +from mujoco_warp import test_data +from mujoco_warp._src import math +from mujoco_warp._src.grad import enable_grad + +# tolerance for AD vs finite-difference comparison +_FD_TOL = 1e-3 + +# sparse jacobian to avoid tile kernels (which require cuSolverDx) +_SIMPLE_HINGE_XML = """ + + + + + + + + + + + + + + + + + + + + +""" + +_SIMPLE_SLIDE_XML = """ + + + + + + + + + + + + + + + +""" + +# 3-link chain with mixed joint axes for non-trivial Coriolis gradient. +# planar 2-link same-axis models have mathematically zero d(qfrc_bias)/d(qvel). +_3LINK_HINGE_XML = """ + + + + + + + + + + + + + + + + + + + + + + + + + +""" + +_SIMPLE_FREE_XML = """ + + + + + + + + + + + + +""" + + +def _fd_gradient(fn, x_np, eps=1e-3): + """Central-difference gradient of scalar fn w.r.t. x_np.""" + grad = np.zeros_like(x_np) + for i in range(x_np.size): + x_plus = x_np.copy() + x_minus = x_np.copy() + x_plus.flat[i] += eps + x_minus.flat[i] -= eps + grad.flat[i] = (fn(x_plus) - fn(x_minus)) / (2.0 * eps) + return grad + + +@wp.kernel +def _sum_xpos_kernel( + # Data in: + xpos_in: wp.array2d(dtype=wp.vec3), + # In: + loss: wp.array(dtype=float), +): + worldid, bodyid = wp.tid() + v = xpos_in[worldid, bodyid] + wp.atomic_add(loss, 0, v[0] + v[1] + v[2]) + + +@wp.kernel +def _sum_qacc_kernel( + # Data in: + qacc_in: wp.array2d(dtype=float), + # In: + loss: wp.array(dtype=float), +): + worldid, dofid = wp.tid() + wp.atomic_add(loss, 0, qacc_in[worldid, dofid]) + + +class GradSmoothTest(parameterized.TestCase): + @parameterized.parameters( + ("hinge", _SIMPLE_HINGE_XML), + ("slide", _SIMPLE_SLIDE_XML), + ) + def test_kinematics_grad(self, name, xml): + """dL/dqpos through kinematics(): loss = sum(xpos).""" + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + # AD gradient + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.kinematics(m, d) + mjw.com_pos(m, d) + wp.launch( + _sum_xpos_kernel, + dim=(d.nworld, m.nbody), + inputs=[d.xpos, loss], + ) + tape.backward(loss=loss) + ad_grad = d.qpos.grad.numpy()[0, : mjm.nq].copy() + tape.zero() + + # Finite-difference gradient + def eval_loss(qpos_np): + d_fd = mjw.make_data(mjm) + d_fd.qpos = wp.array(qpos_np.reshape(1, -1), dtype=float) + mjw.kinematics(m, d_fd) + mjw.com_pos(m, d_fd) + l = wp.zeros(1, dtype=float) + wp.launch( + _sum_xpos_kernel, + dim=(d_fd.nworld, m.nbody), + inputs=[d_fd.xpos, l], + ) + return l.numpy()[0] + + qpos_np = d.qpos.numpy()[0, : mjm.nq] + fd_grad = _fd_gradient(eval_loss, qpos_np) + + np.testing.assert_allclose( + ad_grad, + fd_grad, + atol=_FD_TOL, + rtol=_FD_TOL, + err_msg=f"kinematics grad mismatch ({name})", + ) + + @parameterized.parameters( + ("3link_hinge", _3LINK_HINGE_XML), + ("slide", _SIMPLE_SLIDE_XML), + ) + def test_fwd_velocity_grad(self, name, xml): + """dL/dqvel through fwd_velocity().""" + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.kinematics(m, d) + mjw.com_pos(m, d) + mjw.crb(m, d) + mjw.factor_m(m, d) + mjw.transmission(m, d) + mjw.fwd_velocity(m, d) + wp.launch( + _sum_qacc_kernel, + dim=(d.nworld, m.nv), + inputs=[d.qfrc_bias, loss], + ) + tape.backward(loss=loss) + ad_grad = d.qvel.grad.numpy()[0, : mjm.nv].copy() + tape.zero() + + def eval_loss(qvel_np): + d_fd = mjw.make_data(mjm) + # Copy qpos from original + wp.copy(d_fd.qpos, d.qpos) + d_fd.qvel = wp.array(qvel_np.reshape(1, -1), dtype=float) + mjw.kinematics(m, d_fd) + mjw.com_pos(m, d_fd) + mjw.crb(m, d_fd) + mjw.factor_m(m, d_fd) + mjw.transmission(m, d_fd) + mjw.fwd_velocity(m, d_fd) + l = wp.zeros(1, dtype=float) + wp.launch( + _sum_qacc_kernel, + dim=(d_fd.nworld, m.nv), + inputs=[d_fd.qfrc_bias, l], + ) + return l.numpy()[0] + + qvel_np = d.qvel.numpy()[0, : mjm.nv] + fd_grad = _fd_gradient(eval_loss, qvel_np) + + np.testing.assert_allclose( + ad_grad, + fd_grad, + atol=_FD_TOL, + rtol=_FD_TOL, + err_msg=f"fwd_velocity grad mismatch ({name})", + ) + + @parameterized.parameters( + ("hinge", _SIMPLE_HINGE_XML), + ) + def test_fwd_actuation_grad(self, name, xml): + """dL/dctrl through fwd_actuation().""" + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.kinematics(m, d) + mjw.com_pos(m, d) + mjw.crb(m, d) + mjw.factor_m(m, d) + mjw.transmission(m, d) + mjw.fwd_velocity(m, d) + mjw.fwd_actuation(m, d) + wp.launch( + _sum_qacc_kernel, + dim=(d.nworld, m.nv), + inputs=[d.qfrc_actuator, loss], + ) + tape.backward(loss=loss) + ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() + tape.zero() + + def eval_loss(ctrl_np): + d_fd = mjw.make_data(mjm) + wp.copy(d_fd.qpos, d.qpos) + wp.copy(d_fd.qvel, d.qvel) + d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) + mjw.kinematics(m, d_fd) + mjw.com_pos(m, d_fd) + mjw.crb(m, d_fd) + mjw.factor_m(m, d_fd) + mjw.transmission(m, d_fd) + mjw.fwd_velocity(m, d_fd) + mjw.fwd_actuation(m, d_fd) + l = wp.zeros(1, dtype=float) + wp.launch( + _sum_qacc_kernel, + dim=(d_fd.nworld, m.nv), + inputs=[d_fd.qfrc_actuator, l], + ) + return l.numpy()[0] + + ctrl_np = d.ctrl.numpy()[0, : mjm.nu] + fd_grad = _fd_gradient(eval_loss, ctrl_np) + + np.testing.assert_allclose( + ad_grad, + fd_grad, + atol=_FD_TOL, + rtol=_FD_TOL, + err_msg=f"fwd_actuation grad mismatch ({name})", + ) + + @absltest.skipIf( + wp.get_device().is_cuda and wp.get_device().arch < 70, + "tile kernels (cuSolverDx) require sm_70+", + ) + def test_euler_step_grad(self): + """Full Euler step gradient: dL/dctrl through step().""" + xml = _SIMPLE_HINGE_XML + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.step(m, d) + wp.launch( + _sum_xpos_kernel, + dim=(d.nworld, m.nbody), + inputs=[d.xpos, loss], + ) + tape.backward(loss=loss) + ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() + tape.zero() + + def eval_loss(ctrl_np): + _, _, _, d_fd = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d_fd) + d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) + mjw.step(m, d_fd) + l = wp.zeros(1, dtype=float) + wp.launch( + _sum_xpos_kernel, + dim=(d_fd.nworld, m.nbody), + inputs=[d_fd.xpos, l], + ) + return l.numpy()[0] + + ctrl_np = mjd.ctrl.copy() + fd_grad = _fd_gradient(eval_loss, ctrl_np) + + np.testing.assert_allclose( + ad_grad, + fd_grad, + atol=_FD_TOL, + rtol=_FD_TOL, + err_msg="euler step grad mismatch", + ) + + +@wp.kernel +def _quat_integrate_kernel( + # In: + q_in: wp.array(dtype=wp.quat), + v_in: wp.array(dtype=wp.vec3), + dt_in: wp.array(dtype=float), + # Out: + q_out: wp.array(dtype=wp.quat), +): + i = wp.tid() + q_out[i] = math.quat_integrate(q_in[i], v_in[i], dt_in[i]) + + +@wp.kernel +def _quat_loss_kernel( + # In: + q: wp.array(dtype=wp.quat), + loss: wp.array(dtype=float), +): + i = wp.tid() + v = q[i] + wp.atomic_add(loss, 0, v[0] + v[1] + v[2] + v[3]) + + +class GradQuaternionTest(parameterized.TestCase): + def test_quat_integrate_nonzero_vel(self): + """quat_integrate gradient at non-zero angular velocity.""" + q_np = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) + v_np = np.array([0.1, 0.2, 0.3], dtype=np.float32) + dt_np = np.array([0.01], dtype=np.float32) + + q_arr = wp.array([wp.quat(*q_np)], dtype=wp.quat, requires_grad=True) + v_arr = wp.array([wp.vec3(*v_np)], dtype=wp.vec3, requires_grad=True) + dt_arr = wp.array(dt_np, dtype=float, requires_grad=True) + q_out = wp.zeros(1, dtype=wp.quat, requires_grad=True) + loss = wp.zeros(1, dtype=float, requires_grad=True) + + tape = wp.Tape() + with tape: + wp.launch(_quat_integrate_kernel, dim=1, inputs=[q_arr, v_arr, dt_arr, q_out]) + wp.launch(_quat_loss_kernel, dim=1, inputs=[q_out, loss]) + tape.backward(loss=loss) + + ad_grad_v = v_arr.grad.numpy()[0].copy() + tape.zero() + + # Finite-difference + def eval_loss_v(v_test): + q_a = wp.array([wp.quat(*q_np)], dtype=wp.quat) + v_a = wp.array([wp.vec3(*v_test)], dtype=wp.vec3) + dt_a = wp.array(dt_np, dtype=float) + qo = wp.zeros(1, dtype=wp.quat) + l = wp.zeros(1, dtype=float) + wp.launch(_quat_integrate_kernel, dim=1, inputs=[q_a, v_a, dt_a, qo]) + wp.launch(_quat_loss_kernel, dim=1, inputs=[qo, l]) + return l.numpy()[0] + + fd_grad_v = _fd_gradient(eval_loss_v, v_np) + + np.testing.assert_allclose( + ad_grad_v, + fd_grad_v, + atol=5e-3, + rtol=5e-2, + err_msg="quat_integrate grad w.r.t. v (nonzero)", + ) + + def test_quat_integrate_zero_vel(self): + """quat_integrate gradient at zero angular velocity (singularity test).""" + q_np = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) + v_np = np.array([0.0, 0.0, 0.0], dtype=np.float32) + dt_np = np.array([0.01], dtype=np.float32) + + q_arr = wp.array([wp.quat(*q_np)], dtype=wp.quat, requires_grad=True) + v_arr = wp.array([wp.vec3(*v_np)], dtype=wp.vec3, requires_grad=True) + dt_arr = wp.array(dt_np, dtype=float, requires_grad=True) + q_out = wp.zeros(1, dtype=wp.quat, requires_grad=True) + loss = wp.zeros(1, dtype=float, requires_grad=True) + + tape = wp.Tape() + with tape: + wp.launch(_quat_integrate_kernel, dim=1, inputs=[q_arr, v_arr, dt_arr, q_out]) + wp.launch(_quat_loss_kernel, dim=1, inputs=[q_out, loss]) + tape.backward(loss=loss) + + ad_grad_v = v_arr.grad.numpy()[0].copy() + tape.zero() + + # Should not be NaN or Inf + self.assertTrue(np.all(np.isfinite(ad_grad_v)), f"quat_integrate grad contains NaN/Inf at zero velocity: {ad_grad_v}") + + # Finite-difference + def eval_loss_v(v_test): + q_a = wp.array([wp.quat(*q_np)], dtype=wp.quat) + v_a = wp.array([wp.vec3(*v_test)], dtype=wp.vec3) + dt_a = wp.array(dt_np, dtype=float) + qo = wp.zeros(1, dtype=wp.quat) + l = wp.zeros(1, dtype=float) + wp.launch(_quat_integrate_kernel, dim=1, inputs=[q_a, v_a, dt_a, qo]) + wp.launch(_quat_loss_kernel, dim=1, inputs=[qo, l]) + return l.numpy()[0] + + fd_grad_v = _fd_gradient(eval_loss_v, v_np) + + np.testing.assert_allclose( + ad_grad_v, + fd_grad_v, + atol=5e-3, + rtol=5e-2, + err_msg="quat_integrate grad w.r.t. v (zero vel)", + ) + + def test_quat_integrate_grad_q(self): + """quat_integrate gradient w.r.t. input quaternion q.""" + q_np = np.array([0.9239, 0.3827, 0.0, 0.0], dtype=np.float32) # ~45 deg rotation + v_np = np.array([0.1, 0.2, 0.3], dtype=np.float32) + dt_np = np.array([0.01], dtype=np.float32) + + q_arr = wp.array([wp.quat(*q_np)], dtype=wp.quat, requires_grad=True) + v_arr = wp.array([wp.vec3(*v_np)], dtype=wp.vec3, requires_grad=True) + dt_arr = wp.array(dt_np, dtype=float, requires_grad=True) + q_out = wp.zeros(1, dtype=wp.quat, requires_grad=True) + loss = wp.zeros(1, dtype=float, requires_grad=True) + + tape = wp.Tape() + with tape: + wp.launch(_quat_integrate_kernel, dim=1, inputs=[q_arr, v_arr, dt_arr, q_out]) + wp.launch(_quat_loss_kernel, dim=1, inputs=[q_out, loss]) + tape.backward(loss=loss) + + ad_grad_q = q_arr.grad.numpy()[0].copy() + tape.zero() + + def eval_loss_q(q_test): + q_a = wp.array([wp.quat(*q_test)], dtype=wp.quat) + v_a = wp.array([wp.vec3(*v_np)], dtype=wp.vec3) + dt_a = wp.array(dt_np, dtype=float) + qo = wp.zeros(1, dtype=wp.quat) + l = wp.zeros(1, dtype=float) + wp.launch(_quat_integrate_kernel, dim=1, inputs=[q_a, v_a, dt_a, qo]) + wp.launch(_quat_loss_kernel, dim=1, inputs=[qo, l]) + return l.numpy()[0] + + fd_grad_q = _fd_gradient(eval_loss_q, q_np) + + np.testing.assert_allclose( + ad_grad_q, + fd_grad_q, + atol=5e-2, + rtol=5e-2, + err_msg="quat_integrate grad w.r.t. q", + ) + + +class GradUtilTest(absltest.TestCase): + def test_enable_disable_grad(self): + """enable_grad / disable_grad toggle requires_grad on Data fields.""" + mjm = mujoco.MjModel.from_xml_string(_SIMPLE_HINGE_XML) + d = mjw.make_data(mjm) + + # Initially, requires_grad should be False + self.assertFalse(d.qpos.requires_grad) + + mjw.enable_grad(d) + self.assertTrue(d.qpos.requires_grad) + self.assertTrue(d.qvel.requires_grad) + self.assertTrue(d.ctrl.requires_grad) + + mjw.disable_grad(d) + self.assertFalse(d.qpos.requires_grad) + + def test_make_diff_data(self): + """make_diff_data returns Data with gradient tracking enabled.""" + mjm = mujoco.MjModel.from_xml_string(_SIMPLE_HINGE_XML) + d = mjw.make_diff_data(mjm) + + self.assertTrue(d.qpos.requires_grad) + self.assertTrue(d.qvel.requires_grad) + self.assertTrue(d.ctrl.requires_grad) + self.assertTrue(d.xpos.requires_grad) + self.assertTrue(d.qacc.requires_grad) + + def test_make_diff_data_custom_fields(self): + """make_diff_data with a custom field list.""" + mjm = mujoco.MjModel.from_xml_string(_SIMPLE_HINGE_XML) + d = mjw.make_diff_data(mjm, grad_fields=["qpos", "xpos"]) + + self.assertTrue(d.qpos.requires_grad) + self.assertTrue(d.xpos.requires_grad) + self.assertFalse(d.qvel.requires_grad) + self.assertFalse(d.ctrl.requires_grad) + + +if __name__ == "__main__": + absltest.main() diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index 059ba6307..f8f534af1 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -25,7 +25,7 @@ from mujoco_warp._src.types import Model from mujoco_warp._src.warp_util import event_scope -wp.set_module_options({"enable_backward": False}) +wp.set_module_options({"enable_backward": True}) @wp.func @@ -442,8 +442,8 @@ def _fluid_force( lfrc_torque -= drag_ang_coef * l_ang lfrc_force += magnus_force + kutta_force - drag_lin_coef * l_lin - lfrc_torque *= coef - lfrc_force *= coef + lfrc_torque = lfrc_torque * coef + lfrc_force = lfrc_force * coef # map force/torque from local to world frame: lfrc -> bfrc torque_global += geom_rot @ lfrc_torque diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 7354e16c6..0254ec21f 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -38,7 +38,53 @@ from mujoco_warp._src.warp_util import cache_kernel from mujoco_warp._src.warp_util import event_scope -wp.set_module_options({"enable_backward": False}) +wp.set_module_options({"enable_backward": True}) + + +# kernel_analyzer: off +@wp.func +def _process_joint( + xpos: wp.vec3, + xquat: wp.quat, + jntadr: int, + jnt_pos_id: int, + worldid: int, + qpos0: wp.array2d(dtype=float), + jnt_type: wp.array(dtype=int), + jnt_qposadr: wp.array(dtype=int), + jnt_pos: wp.array2d(dtype=wp.vec3), + jnt_axis: wp.array2d(dtype=wp.vec3), + qpos: wp.array(dtype=float), + xanchor_out: wp.array2d(dtype=wp.vec3), + xaxis_out: wp.array2d(dtype=wp.vec3), +): + """Process a single joint and return updated xpos, xquat.""" + qadr = jnt_qposadr[jntadr] + jnt_type_ = jnt_type[jntadr] + jnt_axis_ = jnt_axis[worldid % jnt_axis.shape[0], jntadr] + xanchor = math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat) + xpos + xaxis = math.rot_vec_quat(jnt_axis_, xquat) + + if jnt_type_ == JointType.BALL: + qloc = wp.quat(qpos[qadr + 0], qpos[qadr + 1], qpos[qadr + 2], qpos[qadr + 3]) + qloc = wp.normalize(qloc) + xquat = math.mul_quat(xquat, qloc) + xpos = xanchor - math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat) + elif jnt_type_ == JointType.SLIDE: + xpos = xpos + xaxis * (qpos[qadr] - qpos0[worldid % qpos0.shape[0], qadr]) + elif jnt_type_ == JointType.HINGE: + qpos0_ = qpos0[worldid % qpos0.shape[0], qadr] + qloc_ = math.axis_angle_to_quat(jnt_axis_, qpos[qadr] - qpos0_) + xquat = math.mul_quat(xquat, qloc_) + xpos = xanchor - math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat) + + xanchor_out[worldid, jntadr] = xanchor + xaxis_out[worldid, jntadr] = xaxis + + return xpos, xquat + + +# kernel_analyzer: on @wp.kernel @@ -112,31 +158,60 @@ def _kinematics_branch( xpos = math.rot_vec_quat(xpos, xquat_out[worldid, pid]) + xpos_out[worldid, pid] xquat = math.mul_quat(xquat_out[worldid, pid], xquat) - for _ in range(jntnum): - qadr = jnt_qposadr[jntadr] - jnt_type_ = jnt_type[jntadr] - jnt_axis_ = jnt_axis[worldid % jnt_axis.shape[0], jntadr] - xanchor = math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat) + xpos - xaxis = math.rot_vec_quat(jnt_axis_, xquat) - - if jnt_type_ == JointType.BALL: - qloc = wp.quat(qpos[qadr + 0], qpos[qadr + 1], qpos[qadr + 2], qpos[qadr + 3]) - qloc = wp.normalize(qloc) - xquat = math.mul_quat(xquat, qloc) - # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat) - elif jnt_type_ == JointType.SLIDE: - xpos += xaxis * (qpos[qadr] - qpos0[worldid % qpos0.shape[0], qadr]) - elif jnt_type_ == JointType.HINGE: - qpos0_ = qpos0[worldid % qpos0.shape[0], qadr] - qloc_ = math.axis_angle_to_quat(jnt_axis_, qpos[qadr] - qpos0_) - xquat = math.mul_quat(xquat, qloc_) - # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat) - - xanchor_out[worldid, jntadr] = xanchor - xaxis_out[worldid, jntadr] = xaxis - jntadr += 1 + # Unrolled joint processing — avoids nested dynamic-range loop which + # produces incorrect gradients in Warp's AD. + if jntnum >= 1: + xpos, xquat = _process_joint( + xpos, xquat, jntadr, jnt_pos_id, worldid, qpos0, jnt_type, jnt_qposadr, jnt_pos, jnt_axis, qpos, xanchor_out, xaxis_out + ) + if jntnum >= 2: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 1, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) + if jntnum >= 3: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 2, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) + if jntnum >= 4: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 3, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) xquat = wp.normalize(xquat) xpos_out[worldid, bodyid] = xpos @@ -1127,6 +1202,28 @@ def _rne_cacc_world(m: Model, d: Data): wp.launch(_cacc_world, dim=[d.nworld], inputs=[m.opt.gravity], outputs=[d.cacc]) +# kernel_analyzer: off +@wp.func +def _process_dof_cacc( + local_cacc: wp.spatial_vector, + dofadr: int, + worldid: int, + qvel_in: wp.array2d(dtype=float), + qacc_in: wp.array2d(dtype=float), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + cdof_dot_in: wp.array2d(dtype=wp.spatial_vector), + flg_acc: bool, +): + """Accumulate one DOF contribution to body acceleration.""" + local_cacc += cdof_dot_in[worldid, dofadr] * qvel_in[worldid, dofadr] + if flg_acc: + local_cacc += cdof_in[worldid, dofadr] * qacc_in[worldid, dofadr] + return local_cacc + + +# kernel_analyzer: on + + @wp.kernel def _cacc_branch( # Model: @@ -1157,10 +1254,22 @@ def _cacc_branch( bodyid = body_branches[i] dofnum = body_dofnum[bodyid] dofadr = body_dofadr[bodyid] - for j in range(dofnum): - local_cacc += cdof_dot_in[worldid, dofadr + j] * qvel_in[worldid, dofadr + j] - if flg_acc: - local_cacc += cdof_in[worldid, dofadr + j] * qacc_in[worldid, dofadr + j] + + # unrolled dof processing — avoids nested dynamic-range loop which + # produces incorrect gradients in warp's AD + if dofnum >= 1: + local_cacc = _process_dof_cacc(local_cacc, dofadr, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc) + if dofnum >= 2: + local_cacc = _process_dof_cacc(local_cacc, dofadr + 1, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc) + if dofnum >= 3: + local_cacc = _process_dof_cacc(local_cacc, dofadr + 2, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc) + if dofnum >= 4: + local_cacc = _process_dof_cacc(local_cacc, dofadr + 3, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc) + if dofnum >= 5: + local_cacc = _process_dof_cacc(local_cacc, dofadr + 4, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc) + if dofnum >= 6: + local_cacc = _process_dof_cacc(local_cacc, dofadr + 5, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc) + cacc_out[worldid, bodyid] = local_cacc @@ -1216,28 +1325,46 @@ def _rne_cfrc(m: Model, d: Data, flg_cfrc_ext: bool = False): @wp.kernel -def _cfrc_backward( +def _cfrc_backward_level( # Model: body_parentid: wp.array(dtype=int), # Data in: cfrc_int_in: wp.array2d(dtype=wp.spatial_vector), # In: body_tree_: wp.array(dtype=int), + nbody_tree: int, # Data out: cfrc_int_out: wp.array2d(dtype=wp.spatial_vector), ): - worldid, nodeid = wp.tid() - bodyid = body_tree_[nodeid] - pid = body_parentid[bodyid] - if bodyid != 0: - wp.atomic_add(cfrc_int_out[worldid], pid, cfrc_int_in[worldid, bodyid]) + # copy input and accumulate child forces to parents in a single kernel + # to avoid warp AD output-gradient zeroing when separate copy + atomic_add + # target the same array + worldid, bodyid = wp.tid() + val = cfrc_int_in[worldid, bodyid] + for k in range(nbody_tree): + child = body_tree_[k] + if body_parentid[child] == bodyid and child != 0: + val = val + cfrc_int_in[worldid, child] + cfrc_int_out[worldid, bodyid] = val def _rne_cfrc_backward(m: Model, d: Data): + # accumulate child forces to parents using separate arrays at each level + # to avoid warp AD output-gradient zeroing issue (an array that is the + # output of multiple wp.launch / wp.copy calls loses gradient from all + # but the last operation during backward) + current = d.cfrc_int for body_tree in reversed(m.body_tree): + next_cfrc = wp.zeros_like(current) + next_cfrc.requires_grad = True wp.launch( - _cfrc_backward, dim=[d.nworld, body_tree.size], inputs=[m.body_parentid, d.cfrc_int, body_tree], outputs=[d.cfrc_int] + _cfrc_backward_level, + dim=[d.nworld, m.nbody], + inputs=[m.body_parentid, current, body_tree, body_tree.shape[0]], + outputs=[next_cfrc], ) + current = next_cfrc + return current @wp.kernel @@ -1270,8 +1397,10 @@ def rne(m: Model, d: Data, flg_acc: bool = False): _rne_cacc_world(m, d) _rne_cacc_forward(m, d, flg_acc=flg_acc) _rne_cfrc(m, d) - _rne_cfrc_backward(m, d) - wp.launch(_qfrc_bias, dim=[d.nworld, m.nv], inputs=[m.dof_bodyid, d.cdof, d.cfrc_int], outputs=[d.qfrc_bias]) + cfrc_total = _rne_cfrc_backward(m, d) + wp.launch(_qfrc_bias, dim=[d.nworld, m.nv], inputs=[m.dof_bodyid, d.cdof, cfrc_total], outputs=[d.qfrc_bias]) + # update d.cfrc_int with accumulated forces for downstream consumers + d.cfrc_int = cfrc_total @wp.kernel @@ -1578,7 +1707,7 @@ def rne_postconstraint(m: Model, d: Data): _rne_cfrc(m, d, flg_cfrc_ext=True) # backward pass over bodies: accumulate cfrc_int from children - _rne_cfrc_backward(m, d) + d.cfrc_int = _rne_cfrc_backward(m, d) @wp.func @@ -1748,7 +1877,7 @@ def _tendon_dot( dot = wp.dot(dpnt, dvel) dvel += dpnt * (-dot) if norm > MJ_MINVAL: - dvel /= norm + dvel = dvel / norm else: dvel = wp.vec3(0.0) @@ -1938,6 +2067,57 @@ def _comvel_root(cvel_out: wp.array2d(dtype=wp.spatial_vector)): cvel_out[worldid, 0][elementid] = 0.0 +# kernel_analyzer: off +@wp.func +def _process_joint_vel( + cvel: wp.spatial_vector, + dofid: int, + jntadr: int, + worldid: int, + jnt_type: wp.array(dtype=int), + qvel: wp.array(dtype=float), + cdof: wp.array(dtype=wp.spatial_vector), + cdof_dot_out: wp.array2d(dtype=wp.spatial_vector), +): + """Process a single joint for velocity propagation, return updated cvel and dofid.""" + jnttype = jnt_type[jntadr] + + if jnttype == JointType.FREE: + cvel += cdof[dofid + 0] * qvel[dofid + 0] + cvel += cdof[dofid + 1] * qvel[dofid + 1] + cvel += cdof[dofid + 2] * qvel[dofid + 2] + + cdof_dot_out[worldid, dofid + 3] = math.motion_cross(cvel, cdof[dofid + 3]) + cdof_dot_out[worldid, dofid + 4] = math.motion_cross(cvel, cdof[dofid + 4]) + cdof_dot_out[worldid, dofid + 5] = math.motion_cross(cvel, cdof[dofid + 5]) + + cvel += cdof[dofid + 3] * qvel[dofid + 3] + cvel += cdof[dofid + 4] * qvel[dofid + 4] + cvel += cdof[dofid + 5] * qvel[dofid + 5] + + dofid += 6 + elif jnttype == JointType.BALL: + cdof_dot_out[worldid, dofid + 0] = math.motion_cross(cvel, cdof[dofid + 0]) + cdof_dot_out[worldid, dofid + 1] = math.motion_cross(cvel, cdof[dofid + 1]) + cdof_dot_out[worldid, dofid + 2] = math.motion_cross(cvel, cdof[dofid + 2]) + + cvel += cdof[dofid + 0] * qvel[dofid + 0] + cvel += cdof[dofid + 1] * qvel[dofid + 1] + cvel += cdof[dofid + 2] * qvel[dofid + 2] + + dofid += 3 + else: + cdof_dot_out[worldid, dofid] = math.motion_cross(cvel, cdof[dofid]) + cvel += cdof[dofid] * qvel[dofid] + + dofid += 1 + + return cvel, dofid + + +# kernel_analyzer: on + + @wp.kernel def _comvel_branch( # Model: @@ -1975,38 +2155,16 @@ def _comvel_branch( cvel_out[worldid, bodyid] = cvel continue - for j in range(jntid, jntid + jntnum): - jnttype = jnt_type[j] - - if jnttype == JointType.FREE: - cvel += cdof[dofid + 0] * qvel[dofid + 0] - cvel += cdof[dofid + 1] * qvel[dofid + 1] - cvel += cdof[dofid + 2] * qvel[dofid + 2] - - cdof_dot_out[worldid, dofid + 3] = math.motion_cross(cvel, cdof[dofid + 3]) - cdof_dot_out[worldid, dofid + 4] = math.motion_cross(cvel, cdof[dofid + 4]) - cdof_dot_out[worldid, dofid + 5] = math.motion_cross(cvel, cdof[dofid + 5]) - - cvel += cdof[dofid + 3] * qvel[dofid + 3] - cvel += cdof[dofid + 4] * qvel[dofid + 4] - cvel += cdof[dofid + 5] * qvel[dofid + 5] - - dofid += 6 - elif jnttype == JointType.BALL: - cdof_dot_out[worldid, dofid + 0] = math.motion_cross(cvel, cdof[dofid + 0]) - cdof_dot_out[worldid, dofid + 1] = math.motion_cross(cvel, cdof[dofid + 1]) - cdof_dot_out[worldid, dofid + 2] = math.motion_cross(cvel, cdof[dofid + 2]) - - cvel += cdof[dofid + 0] * qvel[dofid + 0] - cvel += cdof[dofid + 1] * qvel[dofid + 1] - cvel += cdof[dofid + 2] * qvel[dofid + 2] - - dofid += 3 - else: - cdof_dot_out[worldid, dofid] = math.motion_cross(cvel, cdof[dofid]) - cvel += cdof[dofid] * qvel[dofid] - - dofid += 1 + # unrolled joint processing — avoids nested dynamic-range loop which + # produces incorrect gradients in warp's AD + if jntnum >= 1: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 2: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 1, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 3: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 2, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 4: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 3, worldid, jnt_type, qvel, cdof, cdof_dot_out) cvel_out[worldid, bodyid] = cvel @@ -2598,7 +2756,7 @@ def _transmission_body_moment_scale( if ncon > 0: actid = actuator_trntype_body_adr[trnbodyid] rowadr = moment_rowadr_in[worldid, actid] - actuator_moment_out[worldid, rowadr + dofid] /= -float(ncon) + actuator_moment_out[worldid, rowadr + dofid] = actuator_moment_out[worldid, rowadr + dofid] / -float(ncon) @event_scope @@ -2715,7 +2873,7 @@ def _solve_LD_sparse_qLDiag_mul( out: wp.array2d(dtype=float), ): worldid, dofid = wp.tid() - out[worldid, dofid] *= D[worldid, dofid] + out[worldid, dofid] = out[worldid, dofid] * D[worldid, dofid] @wp.kernel @@ -2933,9 +3091,8 @@ def _subtree_vel_forward( subtree_linvel_out[worldid, bodyid] = body_mass[body_mass_id, bodyid] * lin dv = wp.transpose(ximat) @ ang - dv[0] *= body_inertia[body_inertia_id, bodyid][0] - dv[1] *= body_inertia[body_inertia_id, bodyid][1] - dv[2] *= body_inertia[body_inertia_id, bodyid][2] + inertia = body_inertia[body_inertia_id, bodyid] + dv = wp.vec3(dv[0] * inertia[0], dv[1] * inertia[1], dv[2] * inertia[2]) subtree_angmom_out[worldid, bodyid] = ximat @ dv subtree_bodyvel_out[worldid, bodyid] = wp.spatial_vector(ang, lin) @@ -2957,7 +3114,9 @@ def _linear_momentum( if bodyid: pid = body_parentid[bodyid] wp.atomic_add(subtree_linvel_out[worldid], pid, subtree_linvel_in[worldid, bodyid]) - subtree_linvel_out[worldid, bodyid] /= wp.max(MJ_MINVAL, body_subtreemass[worldid % body_subtreemass.shape[0], bodyid]) + subtree_linvel_out[worldid, bodyid] = subtree_linvel_out[worldid, bodyid] / wp.max( + MJ_MINVAL, body_subtreemass[worldid % body_subtreemass.shape[0], bodyid] + ) @wp.kernel @@ -3008,7 +3167,7 @@ def _angular_momentum( # momentum wrt parent dx = com - com_parent dv = linvel - linvel_parent - dv *= subtreemass + dv = dv * subtreemass dL = wp.cross(dx, dv) wp.atomic_add(subtree_angmom_out[worldid], pid, dL)