diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py
index 52e502a44..3a53d7de4 100644
--- a/mujoco_warp/__init__.py
+++ b/mujoco_warp/__init__.py
@@ -46,6 +46,12 @@
from mujoco_warp._src.forward import rungekutta4 as rungekutta4
from mujoco_warp._src.forward import step1 as step1
from mujoco_warp._src.forward import step2 as step2
+from mujoco_warp._src.grad import SMOOTH_GRAD_FIELDS as SMOOTH_GRAD_FIELDS
+from mujoco_warp._src.grad import diff_forward as diff_forward
+from mujoco_warp._src.grad import diff_step as diff_step
+from mujoco_warp._src.grad import disable_grad as disable_grad
+from mujoco_warp._src.grad import enable_grad as enable_grad
+from mujoco_warp._src.grad import make_diff_data as make_diff_data
from mujoco_warp._src.inverse import inverse as inverse
from mujoco_warp._src.io import create_render_context as create_render_context
from mujoco_warp._src.io import get_data_into as get_data_into
diff --git a/mujoco_warp/_src/adjoint.py b/mujoco_warp/_src/adjoint.py
new file mode 100644
index 000000000..231152066
--- /dev/null
+++ b/mujoco_warp/_src/adjoint.py
@@ -0,0 +1,107 @@
+"""custom adjoint definitions for MuJoCo Warp autodifferentiation.
+
+This module centralizes all ``@wp.func_grad`` registrations. It must be
+imported before any tape recording so that custom gradients are registered
+with Warp's AD system.
+
+Import this module via ``grad.py``; do not import it directly.
+"""
+
+import warp as wp
+
+from mujoco_warp._src import math
+
+
+@wp.func_grad(math.quat_integrate)
+def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat):
+ """Custom adjoint avoiding gradient singularity at |v|=0.
+
+ Recomputes the forward intermediates as
+ ``normalize(mul_quat(normalize(q), q_rot))`` with
+ ``q_rot = (cos(dt|v|/2), sin(dt|v|/2)/|v| * v)`` and back-propagates
+ ``adj_ret`` through that chain, accumulating into ``wp.adjoint[q]``,
+ ``wp.adjoint[v]`` and ``wp.adjoint[dt]``.
+
+ Quaternion components are indexed scalar-first here (index 0 holds the cosine
+ term).
+
+ NOTE(review): assumes the forward ``math.quat_integrate`` computes exactly
+ this expression (including the normalization of ``q``) — confirm against
+ math.py.
+ """
+ EPS = float(1e-10)
+ norm_v = wp.length(v)
+ norm_v_sq = norm_v * norm_v
+ half_angle = dt * norm_v * 0.5
+
+ # sinc-safe rotation quaternion construction: below EPS the closed-form
+ # limits are used so nothing is divided by |v|
+ if norm_v > EPS:
+ s_over_nv = wp.sin(half_angle) / norm_v # sin(dt|v|/2) / |v|
+ c = wp.cos(half_angle)
+ # d(s_over_nv)/dv_j = ds_coeff * v_j
+ ds_coeff = (c * dt * 0.5 - s_over_nv) / norm_v_sq
+ else:
+ s_over_nv = dt * 0.5
+ c = 1.0
+ # Taylor limit: (c*dt/2 - s_over_nv) / |v|^2 -> -dt^3/24
+ ds_coeff = -dt * dt * dt / 24.0
+
+ q_rot = wp.quat(
+ c,
+ s_over_nv * v[0],
+ s_over_nv * v[1],
+ s_over_nv * v[2],
+ )
+
+ # recompute forward intermediates (lengths are clamped to EPS before inversion)
+ q_len = wp.length(q)
+ q_inv_len = 1.0 / wp.max(q_len, EPS)
+ q_n = wp.quat(
+ q[0] * q_inv_len,
+ q[1] * q_inv_len,
+ q[2] * q_inv_len,
+ q[3] * q_inv_len,
+ )
+
+ q_res = math.mul_quat(q_n, q_rot)
+ res_len = wp.length(q_res)
+ res_inv = 1.0 / wp.max(res_len, EPS)
+
+ # result = normalize(q_res)
+ # adj_q_res_k = adj_ret_k / |q_res| - q_res_k * dot(adj_ret, q_res) / |q_res|^3
+ dot_ar = adj_ret[0] * q_res[0] + adj_ret[1] * q_res[1] + adj_ret[2] * q_res[2] + adj_ret[3] * q_res[3]
+ res_inv3 = res_inv * res_inv * res_inv
+ adj_qr = wp.quat(
+ adj_ret[0] * res_inv - q_res[0] * dot_ar * res_inv3,
+ adj_ret[1] * res_inv - q_res[1] * dot_ar * res_inv3,
+ adj_ret[2] * res_inv - q_res[2] * dot_ar * res_inv3,
+ adj_ret[3] * res_inv - q_res[3] * dot_ar * res_inv3,
+ )
+
+ # q_res = mul_quat(q_n, q_rot)
+ # adj_q_n = mul_quat(adj_qr, conj(q_rot))
+ # adj_q_rot = mul_quat(conj(q_n), adj_qr)
+ q_rot_conj = wp.quat(q_rot[0], -q_rot[1], -q_rot[2], -q_rot[3])
+ adj_qn = math.mul_quat(adj_qr, q_rot_conj)
+
+ q_n_conj = wp.quat(q_n[0], -q_n[1], -q_n[2], -q_n[3])
+ adj_q_rot = math.mul_quat(q_n_conj, adj_qr)
+
+ # q_rot = (c, s_over_nv * v)
+ # d(c)/dv_j = -s_over_nv * dt/2 * v_j
+ # d(s_over_nv * v_i)/dv_j = ds_coeff * v_j * v_i + s_over_nv * delta_ij
+ # sv_dot: vector part of adj_q_rot projected onto v; shared by adj_v and adj_dt
+ sv_dot = adj_q_rot[1] * v[0] + adj_q_rot[2] * v[1] + adj_q_rot[3] * v[2]
+ common = -s_over_nv * dt * 0.5 * adj_q_rot[0] + ds_coeff * sv_dot
+ adj_v_val = wp.vec3(
+ common * v[0] + s_over_nv * adj_q_rot[1],
+ common * v[1] + s_over_nv * adj_q_rot[2],
+ common * v[2] + s_over_nv * adj_q_rot[3],
+ )
+
+ # adj_dt from q_rot dependency on dt
+ # d(c)/d(dt) = -sin(half_angle) * norm_v / 2
+ # d(s_over_nv * v_i)/dt = (c / 2) * v_i
+ adj_dt_val = adj_q_rot[0] * (-wp.sin(half_angle) * norm_v * 0.5)
+ adj_dt_val += sv_dot * c * 0.5
+
+ # q_n = normalize(q)
+ # adj_q_k = adj_qn_k / |q| - q_k * dot(adj_qn, q) / |q|^3
+ dot_aqn = adj_qn[0] * q[0] + adj_qn[1] * q[1] + adj_qn[2] * q[2] + adj_qn[3] * q[3]
+ q_inv_len3 = q_inv_len * q_inv_len * q_inv_len
+ adj_q_val = wp.quat(
+ adj_qn[0] * q_inv_len - q[0] * dot_aqn * q_inv_len3,
+ adj_qn[1] * q_inv_len - q[1] * dot_aqn * q_inv_len3,
+ adj_qn[2] * q_inv_len - q[2] * dot_aqn * q_inv_len3,
+ adj_qn[3] * q_inv_len - q[3] * dot_aqn * q_inv_len3,
+ )
+
+ # accumulate adjoints (added to, not assigned, so other contributions are kept)
+ wp.adjoint[q] += adj_q_val
+ wp.adjoint[v] += adj_v_val
+ wp.adjoint[dt] += adj_dt_val
diff --git a/mujoco_warp/_src/derivative.py b/mujoco_warp/_src/derivative.py
index baf099a4c..5deb1efd2 100644
--- a/mujoco_warp/_src/derivative.py
+++ b/mujoco_warp/_src/derivative.py
@@ -25,7 +25,7 @@
from mujoco_warp._src.types import vec10f
from mujoco_warp._src.warp_util import event_scope
-wp.set_module_options({"enable_backward": False})
+wp.set_module_options({"enable_backward": True})
@wp.kernel
diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py
index 80c5f8f49..ee083ee77 100644
--- a/mujoco_warp/_src/forward.py
+++ b/mujoco_warp/_src/forward.py
@@ -45,7 +45,7 @@
from mujoco_warp._src.warp_util import cache_kernel
from mujoco_warp._src.warp_util import event_scope
-wp.set_module_options({"enable_backward": False})
+wp.set_module_options({"enable_backward": True})
@wp.kernel
@@ -214,6 +214,12 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
"""Advance state and time given activation derivatives and acceleration."""
# TODO(team): can we assume static timesteps?
+ # Clone arrays used as both input and output so that Warp's tape retains the
+ # original values for correct reverse-mode AD.
+ act_in = wp.clone(d.act)
+ qvel_prev = wp.clone(d.qvel)
+ qpos_prev = wp.clone(d.qpos)
+
# advance activations
wp.launch(
_next_activation,
@@ -226,7 +232,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
m.actuator_actlimited,
m.actuator_dynprm,
m.actuator_actrange,
- d.act,
+ act_in,
d.act_dot,
1.0,
True,
@@ -237,7 +243,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
wp.launch(
_next_velocity,
dim=(d.nworld, m.nv),
- inputs=[m.opt.timestep, d.qvel, qacc, 1.0],
+ inputs=[m.opt.timestep, qvel_prev, qacc, 1.0],
outputs=[d.qvel],
)
@@ -247,7 +253,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
wp.launch(
_next_position,
dim=(d.nworld, m.njnt),
- inputs=[m.opt.timestep, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, qvel_in, 1.0],
+ inputs=[m.opt.timestep, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, qpos_prev, qvel_in, 1.0],
outputs=[d.qpos],
)
@@ -774,9 +780,9 @@ def _tendon_actuator_force_clamp(
actfrcrange = tendon_actfrcrange[worldid % tendon_actfrcrange.shape[0], tenid]
if ten_actfrc < actfrcrange[0]:
- actuator_force_out[worldid, actid] *= actfrcrange[0] / ten_actfrc
+ actuator_force_out[worldid, actid] = actuator_force_out[worldid, actid] * (actfrcrange[0] / ten_actfrc)
elif ten_actfrc > actfrcrange[1]:
- actuator_force_out[worldid, actid] *= actfrcrange[1] / ten_actfrc
+ actuator_force_out[worldid, actid] = actuator_force_out[worldid, actid] * (actfrcrange[1] / ten_actfrc)
@wp.kernel
@@ -911,6 +917,8 @@ def fwd_actuation(m: Model, d: Data):
],
outputs=[d.qfrc_actuator],
)
+ # clone to break input/output aliasing for correct AD
+ qfrc_actuator_in = wp.clone(d.qfrc_actuator)
wp.launch(
_qfrc_actuator_gravcomp_limits,
dim=(d.nworld, m.nv),
@@ -921,7 +929,7 @@ def fwd_actuation(m: Model, d: Data):
m.jnt_actfrcrange,
m.dof_jntid,
d.qfrc_gravcomp,
- d.qfrc_actuator,
+ qfrc_actuator_in,
],
outputs=[d.qfrc_actuator],
)
diff --git a/mujoco_warp/_src/grad.py b/mujoco_warp/_src/grad.py
new file mode 100644
index 000000000..7eb6fe24d
--- /dev/null
+++ b/mujoco_warp/_src/grad.py
@@ -0,0 +1,150 @@
+"""Autodifferentiation coordination for MuJoCo Warp.
+
+This module provides utilities for enabling Warp's tape-based reverse-mode
+automatic differentiation through the MuJoCo Warp physics pipeline.
+
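+Usage::
+
+  import warp as wp
+
+  import mujoco_warp as mjw
+
+  # mjm is a mujoco.MjModel; loss_kernel and target are user-defined
+  m = mjw.put_model(mjm)
+  d = mjw.make_diff_data(mjm)  # Data with gradient tracking
+  loss = wp.zeros(1, dtype=float, requires_grad=True)
+  tape = wp.Tape()
+  with tape:
+    mjw.step(m, d)
+    wp.launch(loss_kernel, dim=1, inputs=[d.xpos, target, loss])
+  tape.backward(loss=loss)
+  grad_ctrl = d.ctrl.grad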
+"""
+
+from typing import Callable, Optional, Sequence
+
+import warp as wp
+
+from mujoco_warp._src import adjoint as _adjoint # noqa: F401 (register custom adjoints)
+from mujoco_warp._src import io
+from mujoco_warp._src.forward import forward
+from mujoco_warp._src.forward import step
+from mujoco_warp._src.types import Data
+from mujoco_warp._src.types import Model
+
+# Names of Data attributes that enable_grad() marks with requires_grad=True
+# when no explicit field list is given; disable_grad() iterates the same tuple.
+SMOOTH_GRAD_FIELDS: tuple = (
+ # primary state, user-controlled inputs
+ "qpos",
+ "qvel",
+ "ctrl",
+ "act",
+ "mocap_pos",
+ "mocap_quat",
+ "xfrc_applied",
+ "qfrc_applied",
+ # position-dependent outputs
+ "xpos",
+ "xquat",
+ "xmat",
+ "xipos",
+ "ximat",
+ "xanchor",
+ "xaxis",
+ "geom_xpos",
+ "geom_xmat",
+ "site_xpos",
+ "site_xmat",
+ "subtree_com",
+ "cinert",
+ "crb",
+ "cdof",
+ # velocity-dependent outputs
+ "cdof_dot",
+ "cvel",
+ "subtree_linvel",
+ "subtree_angmom",
+ "actuator_velocity",
+ "ten_velocity",
+ # body-level intermediate quantities
+ "cacc",
+ "cfrc_int",
+ "cfrc_ext",
+ # force/acceleration outputs
+ "qfrc_bias",
+ "qfrc_spring",
+ "qfrc_damper",
+ "qfrc_gravcomp",
+ "qfrc_fluid",
+ "qfrc_passive",
+ "qfrc_actuator",
+ "qfrc_smooth",
+ "qacc",
+ "qacc_smooth",
+ "actuator_force",
+ "act_dot",
+ # inertia matrix
+ "qM",
+ "qLD",
+ "qLDiagInv",
+ # tendon
+ "ten_J",
+ "ten_length",
+ # actuator
+ "actuator_length",
+ "actuator_moment",
+ # sensor
+ "sensordata",
+)
+
+
+def enable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None:
+  """Turns on ``requires_grad`` for the named ``wp.array`` attributes of ``d``.
+
+  Args:
+    d: Data object whose arrays are modified in place.
+    fields: attribute names to enable; ``None`` selects SMOOTH_GRAD_FIELDS.
+
+  Names that are missing on ``d`` or that do not refer to a ``wp.array`` are
+  skipped silently.
+  """
+  selected = SMOOTH_GRAD_FIELDS if fields is None else fields
+  for field_name in selected:
+    candidate = getattr(d, field_name, None)
+    # isinstance() is already False for None, so no separate None check is needed
+    if isinstance(candidate, wp.array):
+      candidate.requires_grad = True
+
+
+def disable_grad(d: Data) -> None:
+  """Disables gradient tracking on every ``wp.array`` attribute of ``d``.
+
+  Covers SMOOTH_GRAD_FIELDS plus any other instance attribute, so arrays that
+  were enabled through a custom ``fields`` list in ``enable_grad`` are switched
+  off as well (previously only SMOOTH_GRAD_FIELDS were reset).
+
+  Args:
+    d: Data object whose arrays are modified in place.
+  """
+  # dict.fromkeys de-duplicates the two name sources while keeping their order;
+  # the getattr default keeps this working for objects without a __dict__
+  names = dict.fromkeys((*SMOOTH_GRAD_FIELDS, *getattr(d, "__dict__", {})))
+  for name in names:
+    arr = getattr(d, name, None)
+    if isinstance(arr, wp.array):
+      arr.requires_grad = False
+
+
+def make_diff_data(
+  mjm,
+  nworld: int = 1,
+  grad_fields: Optional[Sequence[str]] = None,
+  **kwargs,
+) -> Data:
+  """Builds a Data object via ``io.make_data`` and enables gradient tracking on it.
+
+  Args:
+    mjm: model handed to ``io.make_data`` unchanged.
+    nworld: forwarded to ``io.make_data``.
+    grad_fields: forwarded to ``enable_grad`` as ``fields``; ``None`` selects
+      that function's default list.
+    **kwargs: extra keyword arguments forwarded to ``io.make_data``.
+
+  Returns:
+    The newly created Data.
+  """
+  diff_data = io.make_data(mjm, nworld=nworld, **kwargs)
+  enable_grad(diff_data, fields=grad_fields)
+  return diff_data
+
+
+def diff_step(
+  m: Model,
+  d: Data,
+  loss_fn: Callable[[Model, Data], wp.array],
+) -> wp.Tape:
+  """Records ``step(m, d)`` and ``loss_fn(m, d)`` on a fresh tape and back-propagates.
+
+  Args:
+    m: model passed to ``step`` and ``loss_fn``.
+    d: data passed to ``step`` and ``loss_fn``.
+    loss_fn: callable returning the loss array handed to ``Tape.backward``.
+
+  Returns:
+    The tape after ``backward`` has been run.
+  """
+  recorder = wp.Tape()
+  with recorder:
+    step(m, d)
+    scalar_loss = loss_fn(m, d)
+  recorder.backward(loss=scalar_loss)
+  return recorder
+
+
+def diff_forward(
+  m: Model,
+  d: Data,
+  loss_fn: Callable[[Model, Data], wp.array],
+) -> wp.Tape:
+  """Records ``forward(m, d)`` and ``loss_fn(m, d)`` on a fresh tape and back-propagates.
+
+  Same flow as a differentiable step, but calls ``forward`` instead of ``step``
+  (no integration).
+
+  Args:
+    m: model passed to ``forward`` and ``loss_fn``.
+    d: data passed to ``forward`` and ``loss_fn``.
+    loss_fn: callable returning the loss array handed to ``Tape.backward``.
+
+  Returns:
+    The tape after ``backward`` has been run.
+  """
+  recorder = wp.Tape()
+  with recorder:
+    forward(m, d)
+    scalar_loss = loss_fn(m, d)
+  recorder.backward(loss=scalar_loss)
+  return recorder
diff --git a/mujoco_warp/_src/grad_test.py b/mujoco_warp/_src/grad_test.py
new file mode 100644
index 000000000..277284940
--- /dev/null
+++ b/mujoco_warp/_src/grad_test.py
@@ -0,0 +1,566 @@
+"""Tests for autodifferentiation gradients."""
+
+# When run as a script, Python adds this file's directory (_src/) to sys.path,
+# which causes types.py to shadow the stdlib 'types' module. Replace it with
+# the project root so that 'import mujoco_warp' still works.
+import os as _os
+import sys as _sys
+
+# directory containing this test file
+_src_dir = _os.path.dirname(_os.path.abspath(__file__))
+# two directory levels above _src_dir
+_project_root = _os.path.dirname(_os.path.dirname(_src_dir))
+# swap the entry in place (same position in sys.path) instead of appending,
+# so _src_dir no longer appears on the path at all
+if _src_dir in _sys.path:
+ _sys.path[_sys.path.index(_src_dir)] = _project_root
+
+import mujoco
+import numpy as np
+import warp as wp
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import mujoco_warp as mjw
+from mujoco_warp import test_data
+from mujoco_warp._src import math
+from mujoco_warp._src.grad import enable_grad
+
+# tolerance for AD vs finite-difference comparison
+_FD_TOL = 1e-3
+
+# sparse jacobian to avoid tile kernels (which require cuSolverDx)
+_SIMPLE_HINGE_XML = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+_SIMPLE_SLIDE_XML = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+# 3-link chain with mixed joint axes for non-trivial Coriolis gradient.
+# planar 2-link same-axis models have mathematically zero d(qfrc_bias)/d(qvel).
+_3LINK_HINGE_XML = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+_SIMPLE_FREE_XML = """
+
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+
+def _fd_gradient(fn, x_np, eps=1e-3):
+  """Estimates the gradient of scalar ``fn`` at ``x_np`` by central differences.
+
+  Each entry is ``(fn(x + eps * e_i) - fn(x - eps * e_i)) / (2 * eps)``.
+
+  Args:
+    fn: callable mapping an array shaped like ``x_np`` to a scalar.
+    x_np: evaluation point; it is copied, never modified.
+    eps: perturbation applied to one entry at a time.
+
+  Returns:
+    Array with the shape and dtype of ``x_np``.
+  """
+  fd = np.zeros_like(x_np)
+  two_eps = 2.0 * eps
+  for idx in range(x_np.size):
+    fwd, bwd = x_np.copy(), x_np.copy()
+    fwd.flat[idx] += eps
+    bwd.flat[idx] -= eps
+    fd.flat[idx] = (fn(fwd) - fn(bwd)) / two_eps
+  return fd
+
+
+@wp.kernel
+def _sum_xpos_kernel(
+ # Data in:
+ xpos_in: wp.array2d(dtype=wp.vec3),
+ # In:
+ loss: wp.array(dtype=float),
+):
+ # scalar test loss: every (world, body) thread adds x + y + z of its position
+ # into loss[0]; loss is written to via atomic_add although it is listed
+ # under "In"
+ worldid, bodyid = wp.tid()
+ v = xpos_in[worldid, bodyid]
+ wp.atomic_add(loss, 0, v[0] + v[1] + v[2])
+
+
+@wp.kernel
+def _sum_qacc_kernel(
+ # Data in:
+ qacc_in: wp.array2d(dtype=float),
+ # In:
+ loss: wp.array(dtype=float),
+):
+ # scalar test loss: every (world, dof) thread adds its entry into loss[0];
+ # despite the name, the tests in this file also launch it on qfrc_bias and
+ # qfrc_actuator
+ worldid, dofid = wp.tid()
+ wp.atomic_add(loss, 0, qacc_in[worldid, dofid])
+
+
+class GradSmoothTest(parameterized.TestCase):
+ # Each test records a slice of the pipeline on a wp.Tape, reads the AD
+ # gradient of a summed scalar loss, and compares it against _fd_gradient
+ # evaluated on freshly created Data.
+ @parameterized.parameters(
+ ("hinge", _SIMPLE_HINGE_XML),
+ ("slide", _SIMPLE_SLIDE_XML),
+ )
+ def test_kinematics_grad(self, name, xml):
+ """dL/dqpos through kinematics(): loss = sum(xpos)."""
+ mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0)
+ enable_grad(d)
+
+ # AD gradient
+ loss = wp.zeros(1, dtype=float, requires_grad=True)
+ tape = wp.Tape()
+ with tape:
+ mjw.kinematics(m, d)
+ mjw.com_pos(m, d)
+ wp.launch(
+ _sum_xpos_kernel,
+ dim=(d.nworld, m.nbody),
+ inputs=[d.xpos, loss],
+ )
+ tape.backward(loss=loss)
+ # copy before tape.zero() so the values are kept
+ ad_grad = d.qpos.grad.numpy()[0, : mjm.nq].copy()
+ tape.zero()
+
+ # Finite-difference gradient
+ def eval_loss(qpos_np):
+ d_fd = mjw.make_data(mjm)
+ d_fd.qpos = wp.array(qpos_np.reshape(1, -1), dtype=float)
+ mjw.kinematics(m, d_fd)
+ mjw.com_pos(m, d_fd)
+ l = wp.zeros(1, dtype=float)
+ wp.launch(
+ _sum_xpos_kernel,
+ dim=(d_fd.nworld, m.nbody),
+ inputs=[d_fd.xpos, l],
+ )
+ return l.numpy()[0]
+
+ qpos_np = d.qpos.numpy()[0, : mjm.nq]
+ fd_grad = _fd_gradient(eval_loss, qpos_np)
+
+ np.testing.assert_allclose(
+ ad_grad,
+ fd_grad,
+ atol=_FD_TOL,
+ rtol=_FD_TOL,
+ err_msg=f"kinematics grad mismatch ({name})",
+ )
+
+ @parameterized.parameters(
+ ("3link_hinge", _3LINK_HINGE_XML),
+ ("slide", _SIMPLE_SLIDE_XML),
+ )
+ def test_fwd_velocity_grad(self, name, xml):
+ """dL/dqvel through fwd_velocity()."""
+ mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0)
+ enable_grad(d)
+
+ loss = wp.zeros(1, dtype=float, requires_grad=True)
+ tape = wp.Tape()
+ with tape:
+ mjw.kinematics(m, d)
+ mjw.com_pos(m, d)
+ mjw.crb(m, d)
+ mjw.factor_m(m, d)
+ mjw.transmission(m, d)
+ mjw.fwd_velocity(m, d)
+ # loss = sum(qfrc_bias); _sum_qacc_kernel is used as a generic 2-D float sum
+ wp.launch(
+ _sum_qacc_kernel,
+ dim=(d.nworld, m.nv),
+ inputs=[d.qfrc_bias, loss],
+ )
+ tape.backward(loss=loss)
+ ad_grad = d.qvel.grad.numpy()[0, : mjm.nv].copy()
+ tape.zero()
+
+ def eval_loss(qvel_np):
+ d_fd = mjw.make_data(mjm)
+ # Copy qpos from original
+ wp.copy(d_fd.qpos, d.qpos)
+ d_fd.qvel = wp.array(qvel_np.reshape(1, -1), dtype=float)
+ mjw.kinematics(m, d_fd)
+ mjw.com_pos(m, d_fd)
+ mjw.crb(m, d_fd)
+ mjw.factor_m(m, d_fd)
+ mjw.transmission(m, d_fd)
+ mjw.fwd_velocity(m, d_fd)
+ l = wp.zeros(1, dtype=float)
+ wp.launch(
+ _sum_qacc_kernel,
+ dim=(d_fd.nworld, m.nv),
+ inputs=[d_fd.qfrc_bias, l],
+ )
+ return l.numpy()[0]
+
+ qvel_np = d.qvel.numpy()[0, : mjm.nv]
+ fd_grad = _fd_gradient(eval_loss, qvel_np)
+
+ np.testing.assert_allclose(
+ ad_grad,
+ fd_grad,
+ atol=_FD_TOL,
+ rtol=_FD_TOL,
+ err_msg=f"fwd_velocity grad mismatch ({name})",
+ )
+
+ @parameterized.parameters(
+ ("hinge", _SIMPLE_HINGE_XML),
+ )
+ def test_fwd_actuation_grad(self, name, xml):
+ """dL/dctrl through fwd_actuation()."""
+ mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0)
+ enable_grad(d)
+
+ loss = wp.zeros(1, dtype=float, requires_grad=True)
+ tape = wp.Tape()
+ with tape:
+ mjw.kinematics(m, d)
+ mjw.com_pos(m, d)
+ mjw.crb(m, d)
+ mjw.factor_m(m, d)
+ mjw.transmission(m, d)
+ mjw.fwd_velocity(m, d)
+ mjw.fwd_actuation(m, d)
+ # loss = sum(qfrc_actuator)
+ wp.launch(
+ _sum_qacc_kernel,
+ dim=(d.nworld, m.nv),
+ inputs=[d.qfrc_actuator, loss],
+ )
+ tape.backward(loss=loss)
+ ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy()
+ tape.zero()
+
+ def eval_loss(ctrl_np):
+ d_fd = mjw.make_data(mjm)
+ # reuse the state of the AD run; only ctrl is perturbed
+ wp.copy(d_fd.qpos, d.qpos)
+ wp.copy(d_fd.qvel, d.qvel)
+ d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float)
+ mjw.kinematics(m, d_fd)
+ mjw.com_pos(m, d_fd)
+ mjw.crb(m, d_fd)
+ mjw.factor_m(m, d_fd)
+ mjw.transmission(m, d_fd)
+ mjw.fwd_velocity(m, d_fd)
+ mjw.fwd_actuation(m, d_fd)
+ l = wp.zeros(1, dtype=float)
+ wp.launch(
+ _sum_qacc_kernel,
+ dim=(d_fd.nworld, m.nv),
+ inputs=[d_fd.qfrc_actuator, l],
+ )
+ return l.numpy()[0]
+
+ ctrl_np = d.ctrl.numpy()[0, : mjm.nu]
+ fd_grad = _fd_gradient(eval_loss, ctrl_np)
+
+ np.testing.assert_allclose(
+ ad_grad,
+ fd_grad,
+ atol=_FD_TOL,
+ rtol=_FD_TOL,
+ err_msg=f"fwd_actuation grad mismatch ({name})",
+ )
+
+ # the skip condition is evaluated once, when the class body is executed
+ @absltest.skipIf(
+ wp.get_device().is_cuda and wp.get_device().arch < 70,
+ "tile kernels (cuSolverDx) require sm_70+",
+ )
+ def test_euler_step_grad(self):
+ """Full Euler step gradient: dL/dctrl through step()."""
+ xml = _SIMPLE_HINGE_XML
+ mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0)
+ enable_grad(d)
+
+ loss = wp.zeros(1, dtype=float, requires_grad=True)
+ tape = wp.Tape()
+ with tape:
+ mjw.step(m, d)
+ wp.launch(
+ _sum_xpos_kernel,
+ dim=(d.nworld, m.nbody),
+ inputs=[d.xpos, loss],
+ )
+ tape.backward(loss=loss)
+ ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy()
+ tape.zero()
+
+ def eval_loss(ctrl_np):
+ # step() advances the state, so every evaluation starts from a new fixture
+ _, _, _, d_fd = test_data.fixture(xml=xml, keyframe=0)
+ enable_grad(d_fd)
+ d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float)
+ mjw.step(m, d_fd)
+ l = wp.zeros(1, dtype=float)
+ wp.launch(
+ _sum_xpos_kernel,
+ dim=(d_fd.nworld, m.nbody),
+ inputs=[d_fd.xpos, l],
+ )
+ return l.numpy()[0]
+
+ # nominal ctrl comes from the host-side mjd here, since d was stepped above
+ ctrl_np = mjd.ctrl.copy()
+ fd_grad = _fd_gradient(eval_loss, ctrl_np)
+
+ np.testing.assert_allclose(
+ ad_grad,
+ fd_grad,
+ atol=_FD_TOL,
+ rtol=_FD_TOL,
+ err_msg="euler step grad mismatch",
+ )
+
+
+@wp.kernel
+def _quat_integrate_kernel(
+ # In:
+ q_in: wp.array(dtype=wp.quat),
+ v_in: wp.array(dtype=wp.vec3),
+ dt_in: wp.array(dtype=float),
+ # Out:
+ q_out: wp.array(dtype=wp.quat),
+):
+ # element-wise wrapper so math.quat_integrate can be recorded on a tape
+ i = wp.tid()
+ q_out[i] = math.quat_integrate(q_in[i], v_in[i], dt_in[i])
+
+
+@wp.kernel
+def _quat_loss_kernel(
+ # In:
+ q: wp.array(dtype=wp.quat),
+ loss: wp.array(dtype=float),
+):
+ # scalar test loss: adds the four components of q[i] into loss[0];
+ # loss is written to via atomic_add although it is listed under "In"
+ i = wp.tid()
+ v = q[i]
+ wp.atomic_add(loss, 0, v[0] + v[1] + v[2] + v[3])
+
+
+class GradQuaternionTest(parameterized.TestCase):
+  """AD versus finite-difference checks for the ``quat_integrate`` gradient."""
+
+  def _ad_and_fd_grad(self, q_np, v_np, dt_np, wrt):
+    """Returns ``(ad_grad, fd_grad)`` of the summed-component loss.
+
+    Args:
+      q_np: input quaternion, 4 floats.
+      v_np: angular velocity, 3 floats.
+      dt_np: one-element array holding the time step.
+      wrt: ``"q"`` or ``"v"``, the input that is differentiated.
+
+    Raises:
+      ValueError: if ``wrt`` is neither ``"q"`` nor ``"v"``.
+    """
+    if wrt not in ("q", "v"):
+      raise ValueError(f"wrt must be 'q' or 'v', got {wrt!r}")
+
+    q_arr = wp.array([wp.quat(*q_np)], dtype=wp.quat, requires_grad=True)
+    v_arr = wp.array([wp.vec3(*v_np)], dtype=wp.vec3, requires_grad=True)
+    dt_arr = wp.array(dt_np, dtype=float, requires_grad=True)
+    q_out = wp.zeros(1, dtype=wp.quat, requires_grad=True)
+    loss = wp.zeros(1, dtype=float, requires_grad=True)
+
+    tape = wp.Tape()
+    with tape:
+      wp.launch(_quat_integrate_kernel, dim=1, inputs=[q_arr, v_arr, dt_arr, q_out])
+      wp.launch(_quat_loss_kernel, dim=1, inputs=[q_out, loss])
+    tape.backward(loss=loss)
+
+    # copy before tape.zero() so the values are kept
+    ad_grad = (q_arr if wrt == "q" else v_arr).grad.numpy()[0].copy()
+    tape.zero()
+
+    def eval_loss(x_test):
+      # only the differentiated input is perturbed; the other keeps its nominal value
+      q_val = x_test if wrt == "q" else q_np
+      v_val = x_test if wrt == "v" else v_np
+      q_a = wp.array([wp.quat(*q_val)], dtype=wp.quat)
+      v_a = wp.array([wp.vec3(*v_val)], dtype=wp.vec3)
+      dt_a = wp.array(dt_np, dtype=float)
+      qo = wp.zeros(1, dtype=wp.quat)
+      l = wp.zeros(1, dtype=float)
+      wp.launch(_quat_integrate_kernel, dim=1, inputs=[q_a, v_a, dt_a, qo])
+      wp.launch(_quat_loss_kernel, dim=1, inputs=[qo, l])
+      return l.numpy()[0]
+
+    fd_grad = _fd_gradient(eval_loss, q_np if wrt == "q" else v_np)
+    return ad_grad, fd_grad
+
+  def test_quat_integrate_nonzero_vel(self):
+    """quat_integrate gradient at non-zero angular velocity."""
+    q_np = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
+    v_np = np.array([0.1, 0.2, 0.3], dtype=np.float32)
+    dt_np = np.array([0.01], dtype=np.float32)
+
+    ad_grad_v, fd_grad_v = self._ad_and_fd_grad(q_np, v_np, dt_np, wrt="v")
+
+    np.testing.assert_allclose(
+      ad_grad_v,
+      fd_grad_v,
+      atol=5e-3,
+      rtol=5e-2,
+      err_msg="quat_integrate grad w.r.t. v (nonzero)",
+    )
+
+  def test_quat_integrate_zero_vel(self):
+    """quat_integrate gradient at zero angular velocity (singularity test)."""
+    q_np = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
+    v_np = np.array([0.0, 0.0, 0.0], dtype=np.float32)
+    dt_np = np.array([0.01], dtype=np.float32)
+
+    ad_grad_v, fd_grad_v = self._ad_and_fd_grad(q_np, v_np, dt_np, wrt="v")
+
+    # Should not be NaN or Inf
+    self.assertTrue(np.all(np.isfinite(ad_grad_v)), f"quat_integrate grad contains NaN/Inf at zero velocity: {ad_grad_v}")
+
+    np.testing.assert_allclose(
+      ad_grad_v,
+      fd_grad_v,
+      atol=5e-3,
+      rtol=5e-2,
+      err_msg="quat_integrate grad w.r.t. v (zero vel)",
+    )
+
+  def test_quat_integrate_grad_q(self):
+    """quat_integrate gradient w.r.t. input quaternion q."""
+    q_np = np.array([0.9239, 0.3827, 0.0, 0.0], dtype=np.float32)  # ~45 deg rotation
+    v_np = np.array([0.1, 0.2, 0.3], dtype=np.float32)
+    dt_np = np.array([0.01], dtype=np.float32)
+
+    ad_grad_q, fd_grad_q = self._ad_and_fd_grad(q_np, v_np, dt_np, wrt="q")
+
+    np.testing.assert_allclose(
+      ad_grad_q,
+      fd_grad_q,
+      atol=5e-2,
+      rtol=5e-2,
+      err_msg="quat_integrate grad w.r.t. q",
+    )
+
+
+class GradUtilTest(absltest.TestCase):
+  """Checks the requires_grad toggling helpers exported by mujoco_warp."""
+
+  def test_enable_disable_grad(self):
+    """enable_grad / disable_grad toggle requires_grad on Data fields."""
+    model = mujoco.MjModel.from_xml_string(_SIMPLE_HINGE_XML)
+    data = mjw.make_data(model)
+
+    # plain make_data leaves gradient tracking off
+    self.assertFalse(data.qpos.requires_grad)
+
+    mjw.enable_grad(data)
+    for field in ("qpos", "qvel", "ctrl"):
+      self.assertTrue(getattr(data, field).requires_grad)
+
+    mjw.disable_grad(data)
+    self.assertFalse(data.qpos.requires_grad)
+
+  def test_make_diff_data(self):
+    """make_diff_data returns Data with gradient tracking enabled."""
+    model = mujoco.MjModel.from_xml_string(_SIMPLE_HINGE_XML)
+    data = mjw.make_diff_data(model)
+
+    for field in ("qpos", "qvel", "ctrl", "xpos", "qacc"):
+      self.assertTrue(getattr(data, field).requires_grad)
+
+  def test_make_diff_data_custom_fields(self):
+    """make_diff_data with a custom field list."""
+    model = mujoco.MjModel.from_xml_string(_SIMPLE_HINGE_XML)
+    data = mjw.make_diff_data(model, grad_fields=["qpos", "xpos"])
+
+    # requested fields are tracked ...
+    for field in ("qpos", "xpos"):
+      self.assertTrue(getattr(data, field).requires_grad)
+    # ... and fields left out of the list are not
+    for field in ("qvel", "ctrl"):
+      self.assertFalse(getattr(data, field).requires_grad)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py
index 059ba6307..f8f534af1 100644
--- a/mujoco_warp/_src/passive.py
+++ b/mujoco_warp/_src/passive.py
@@ -25,7 +25,7 @@
from mujoco_warp._src.types import Model
from mujoco_warp._src.warp_util import event_scope
-wp.set_module_options({"enable_backward": False})
+wp.set_module_options({"enable_backward": True})
@wp.func
@@ -442,8 +442,8 @@ def _fluid_force(
lfrc_torque -= drag_ang_coef * l_ang
lfrc_force += magnus_force + kutta_force - drag_lin_coef * l_lin
- lfrc_torque *= coef
- lfrc_force *= coef
+ lfrc_torque = lfrc_torque * coef
+ lfrc_force = lfrc_force * coef
# map force/torque from local to world frame: lfrc -> bfrc
torque_global += geom_rot @ lfrc_torque
diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py
index 7354e16c6..0254ec21f 100644
--- a/mujoco_warp/_src/smooth.py
+++ b/mujoco_warp/_src/smooth.py
@@ -38,7 +38,53 @@
from mujoco_warp._src.warp_util import cache_kernel
from mujoco_warp._src.warp_util import event_scope
-wp.set_module_options({"enable_backward": False})
+wp.set_module_options({"enable_backward": True})
+
+
+# kernel_analyzer: off
+@wp.func
+def _process_joint(
+ xpos: wp.vec3,
+ xquat: wp.quat,
+ jntadr: int,
+ jnt_pos_id: int,
+ worldid: int,
+ qpos0: wp.array2d(dtype=float),
+ jnt_type: wp.array(dtype=int),
+ jnt_qposadr: wp.array(dtype=int),
+ jnt_pos: wp.array2d(dtype=wp.vec3),
+ jnt_axis: wp.array2d(dtype=wp.vec3),
+ qpos: wp.array(dtype=float),
+ xanchor_out: wp.array2d(dtype=wp.vec3),
+ xaxis_out: wp.array2d(dtype=wp.vec3),
+):
+ """Process a single joint and return updated xpos, xquat.
+
+ Applies joint ``jntadr`` to the running body frame ``(xpos, xquat)`` and
+ writes the joint anchor and axis, computed from the frame as passed in, to
+ ``xanchor_out[worldid, jntadr]`` and ``xaxis_out[worldid, jntadr]``.
+
+ Only BALL, SLIDE and HINGE joints change the frame; any other joint type
+ returns it unchanged.
+ """
+ qadr = jnt_qposadr[jntadr]
+ jnt_type_ = jnt_type[jntadr]
+ # the modulo wraps worldid when the array has fewer rows than there are worlds
+ jnt_axis_ = jnt_axis[worldid % jnt_axis.shape[0], jntadr]
+ xanchor = math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat) + xpos
+ xaxis = math.rot_vec_quat(jnt_axis_, xquat)
+
+ if jnt_type_ == JointType.BALL:
+ qloc = wp.quat(qpos[qadr + 0], qpos[qadr + 1], qpos[qadr + 2], qpos[qadr + 3])
+ qloc = wp.normalize(qloc)
+ xquat = math.mul_quat(xquat, qloc)
+ # correct for off-center rotation
+ xpos = xanchor - math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat)
+ elif jnt_type_ == JointType.SLIDE:
+ # displacement along the axis, relative to the reference position qpos0
+ xpos = xpos + xaxis * (qpos[qadr] - qpos0[worldid % qpos0.shape[0], qadr])
+ elif jnt_type_ == JointType.HINGE:
+ qpos0_ = qpos0[worldid % qpos0.shape[0], qadr]
+ qloc_ = math.axis_angle_to_quat(jnt_axis_, qpos[qadr] - qpos0_)
+ xquat = math.mul_quat(xquat, qloc_)
+ # correct for off-center rotation
+ xpos = xanchor - math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat)
+
+ xanchor_out[worldid, jntadr] = xanchor
+ xaxis_out[worldid, jntadr] = xaxis
+
+ return xpos, xquat
+
+
+# kernel_analyzer: on
@wp.kernel
@@ -112,31 +158,60 @@ def _kinematics_branch(
xpos = math.rot_vec_quat(xpos, xquat_out[worldid, pid]) + xpos_out[worldid, pid]
xquat = math.mul_quat(xquat_out[worldid, pid], xquat)
- for _ in range(jntnum):
- qadr = jnt_qposadr[jntadr]
- jnt_type_ = jnt_type[jntadr]
- jnt_axis_ = jnt_axis[worldid % jnt_axis.shape[0], jntadr]
- xanchor = math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat) + xpos
- xaxis = math.rot_vec_quat(jnt_axis_, xquat)
-
- if jnt_type_ == JointType.BALL:
- qloc = wp.quat(qpos[qadr + 0], qpos[qadr + 1], qpos[qadr + 2], qpos[qadr + 3])
- qloc = wp.normalize(qloc)
- xquat = math.mul_quat(xquat, qloc)
- # correct for off-center rotation
- xpos = xanchor - math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat)
- elif jnt_type_ == JointType.SLIDE:
- xpos += xaxis * (qpos[qadr] - qpos0[worldid % qpos0.shape[0], qadr])
- elif jnt_type_ == JointType.HINGE:
- qpos0_ = qpos0[worldid % qpos0.shape[0], qadr]
- qloc_ = math.axis_angle_to_quat(jnt_axis_, qpos[qadr] - qpos0_)
- xquat = math.mul_quat(xquat, qloc_)
- # correct for off-center rotation
- xpos = xanchor - math.rot_vec_quat(jnt_pos[jnt_pos_id, jntadr], xquat)
-
- xanchor_out[worldid, jntadr] = xanchor
- xaxis_out[worldid, jntadr] = xaxis
- jntadr += 1
+ # Unrolled joint processing — avoids nested dynamic-range loop which
+ # produces incorrect gradients in Warp's AD.
+ # Unrolled to 6 joints, the same bound as the per-body dof unroll in
+ # _cacc_branch; stopping at 4 silently dropped joints 5 and 6 of a body.
+ # NOTE(review): 6 is assumed to be the per-body maximum (MuJoCo is understood
+ # to reject bodies with more than 6 dofs) — confirm; joints beyond the 6th
+ # would still be skipped.
+ if jntnum >= 1:
+   xpos, xquat = _process_joint(
+     xpos, xquat, jntadr, jnt_pos_id, worldid, qpos0, jnt_type, jnt_qposadr, jnt_pos, jnt_axis, qpos, xanchor_out, xaxis_out
+   )
+ if jntnum >= 2:
+   xpos, xquat = _process_joint(
+     xpos,
+     xquat,
+     jntadr + 1,
+     jnt_pos_id,
+     worldid,
+     qpos0,
+     jnt_type,
+     jnt_qposadr,
+     jnt_pos,
+     jnt_axis,
+     qpos,
+     xanchor_out,
+     xaxis_out,
+   )
+ if jntnum >= 3:
+   xpos, xquat = _process_joint(
+     xpos,
+     xquat,
+     jntadr + 2,
+     jnt_pos_id,
+     worldid,
+     qpos0,
+     jnt_type,
+     jnt_qposadr,
+     jnt_pos,
+     jnt_axis,
+     qpos,
+     xanchor_out,
+     xaxis_out,
+   )
+ if jntnum >= 4:
+   xpos, xquat = _process_joint(
+     xpos,
+     xquat,
+     jntadr + 3,
+     jnt_pos_id,
+     worldid,
+     qpos0,
+     jnt_type,
+     jnt_qposadr,
+     jnt_pos,
+     jnt_axis,
+     qpos,
+     xanchor_out,
+     xaxis_out,
+   )
+ if jntnum >= 5:
+   xpos, xquat = _process_joint(
+     xpos,
+     xquat,
+     jntadr + 4,
+     jnt_pos_id,
+     worldid,
+     qpos0,
+     jnt_type,
+     jnt_qposadr,
+     jnt_pos,
+     jnt_axis,
+     qpos,
+     xanchor_out,
+     xaxis_out,
+   )
+ if jntnum >= 6:
+   xpos, xquat = _process_joint(
+     xpos,
+     xquat,
+     jntadr + 5,
+     jnt_pos_id,
+     worldid,
+     qpos0,
+     jnt_type,
+     jnt_qposadr,
+     jnt_pos,
+     jnt_axis,
+     qpos,
+     xanchor_out,
+     xaxis_out,
+   )
xquat = wp.normalize(xquat)
xpos_out[worldid, bodyid] = xpos
@@ -1127,6 +1202,28 @@ def _rne_cacc_world(m: Model, d: Data):
wp.launch(_cacc_world, dim=[d.nworld], inputs=[m.opt.gravity], outputs=[d.cacc])
+# kernel_analyzer: off
+@wp.func
+def _process_dof_cacc(
+ local_cacc: wp.spatial_vector,
+ dofadr: int,
+ worldid: int,
+ qvel_in: wp.array2d(dtype=float),
+ qacc_in: wp.array2d(dtype=float),
+ cdof_in: wp.array2d(dtype=wp.spatial_vector),
+ cdof_dot_in: wp.array2d(dtype=wp.spatial_vector),
+ flg_acc: bool,
+):
+ """Accumulate one DOF contribution to body acceleration.
+
+ Returns ``local_cacc + cdof_dot * qvel`` for dof ``dofadr`` of world
+ ``worldid``, plus ``cdof * qacc`` for the same dof when ``flg_acc`` is set.
+ """
+ # velocity-dependent term, always added
+ local_cacc += cdof_dot_in[worldid, dofadr] * qvel_in[worldid, dofadr]
+ # acceleration term, only on request
+ if flg_acc:
+ local_cacc += cdof_in[worldid, dofadr] * qacc_in[worldid, dofadr]
+ return local_cacc
+
+
+# kernel_analyzer: on
+
+
@wp.kernel
def _cacc_branch(
# Model:
@@ -1157,10 +1254,22 @@ def _cacc_branch(
bodyid = body_branches[i]
dofnum = body_dofnum[bodyid]
dofadr = body_dofadr[bodyid]
- for j in range(dofnum):
- local_cacc += cdof_dot_in[worldid, dofadr + j] * qvel_in[worldid, dofadr + j]
- if flg_acc:
- local_cacc += cdof_in[worldid, dofadr + j] * qacc_in[worldid, dofadr + j]
+
+ # unrolled dof processing — avoids nested dynamic-range loop which
+ # produces incorrect gradients in warp's AD
+ if dofnum >= 1:
+ local_cacc = _process_dof_cacc(local_cacc, dofadr, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc)
+ if dofnum >= 2:
+ local_cacc = _process_dof_cacc(local_cacc, dofadr + 1, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc)
+ if dofnum >= 3:
+ local_cacc = _process_dof_cacc(local_cacc, dofadr + 2, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc)
+ if dofnum >= 4:
+ local_cacc = _process_dof_cacc(local_cacc, dofadr + 3, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc)
+ if dofnum >= 5:
+ local_cacc = _process_dof_cacc(local_cacc, dofadr + 4, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc)
+ if dofnum >= 6:
+ local_cacc = _process_dof_cacc(local_cacc, dofadr + 5, worldid, qvel_in, qacc_in, cdof_in, cdof_dot_in, flg_acc)
+
cacc_out[worldid, bodyid] = local_cacc
@@ -1216,28 +1325,55 @@ def _rne_cfrc(m: Model, d: Data, flg_cfrc_ext: bool = False):
 @wp.kernel
-def _cfrc_backward(
+def _cfrc_backward_level(
   # Model:
   body_parentid: wp.array(dtype=int),
   # Data in:
   cfrc_int_in: wp.array2d(dtype=wp.spatial_vector),
   # In:
   body_tree_: wp.array(dtype=int),
+  nbody_tree: int,
   # Data out:
   cfrc_int_out: wp.array2d(dtype=wp.spatial_vector),
 ):
-  worldid, nodeid = wp.tid()
-  bodyid = body_tree_[nodeid]
-  pid = body_parentid[bodyid]
-  if bodyid != 0:
-    wp.atomic_add(cfrc_int_out[worldid], pid, cfrc_int_in[worldid, bodyid])
+  # copy input and accumulate child forces to parents in a single kernel
+  # to avoid warp AD output-gradient zeroing when separate copy + atomic_add
+  # target the same array
+  # one thread per (world, body): scan the nbody_tree entries of body_tree_ and gather those whose parent is this body
+  worldid, bodyid = wp.tid()
+  val = cfrc_int_in[worldid, bodyid]
+  for k in range(nbody_tree):
+    child = body_tree_[k]
+    # body 0 is never treated as a child (child != 0), so it is not added to any total
+    if body_parentid[child] == bodyid and child != 0:
+      val = val + cfrc_int_in[worldid, child]
+  cfrc_int_out[worldid, bodyid] = val
 def _rne_cfrc_backward(m: Model, d: Data):
+  """Accumulate cfrc_int from children into parents, one entry of m.body_tree at a time.
+
+  d.cfrc_int is only read here; every level writes into a freshly allocated array.
+
+  Returns:
+    The array written by the last launch (d.cfrc_int itself if m.body_tree is empty).
+  """
+  # accumulate child forces to parents using separate arrays at each level
+  # to avoid warp AD output-gradient zeroing issue (an array that is the
+  # output of multiple wp.launch / wp.copy calls loses gradient from all
+  # but the last operation during backward)
+  current = d.cfrc_int
   for body_tree in reversed(m.body_tree):
+    next_cfrc = wp.zeros_like(current)
+    next_cfrc.requires_grad = True
     wp.launch(
-      _cfrc_backward, dim=[d.nworld, body_tree.size], inputs=[m.body_parentid, d.cfrc_int, body_tree], outputs=[d.cfrc_int]
+      _cfrc_backward_level,
+      dim=[d.nworld, m.nbody],
+      inputs=[m.body_parentid, current, body_tree, body_tree.shape[0]],
+      outputs=[next_cfrc],
     )
+    current = next_cfrc
+  return current
 @wp.kernel
@@ -1270,8 +1397,10 @@ def rne(m: Model, d: Data, flg_acc: bool = False):
_rne_cacc_world(m, d)
_rne_cacc_forward(m, d, flg_acc=flg_acc)
_rne_cfrc(m, d)
- _rne_cfrc_backward(m, d)
- wp.launch(_qfrc_bias, dim=[d.nworld, m.nv], inputs=[m.dof_bodyid, d.cdof, d.cfrc_int], outputs=[d.qfrc_bias])
+ cfrc_total = _rne_cfrc_backward(m, d)
+ wp.launch(_qfrc_bias, dim=[d.nworld, m.nv], inputs=[m.dof_bodyid, d.cdof, cfrc_total], outputs=[d.qfrc_bias])
+ # update d.cfrc_int with accumulated forces for downstream consumers
+ d.cfrc_int = cfrc_total
@wp.kernel
@@ -1578,7 +1707,7 @@ def rne_postconstraint(m: Model, d: Data):
_rne_cfrc(m, d, flg_cfrc_ext=True)
# backward pass over bodies: accumulate cfrc_int from children
- _rne_cfrc_backward(m, d)
+ d.cfrc_int = _rne_cfrc_backward(m, d)
@wp.func
@@ -1748,7 +1877,7 @@ def _tendon_dot(
dot = wp.dot(dpnt, dvel)
dvel += dpnt * (-dot)
if norm > MJ_MINVAL:
- dvel /= norm
+ dvel = dvel / norm
else:
dvel = wp.vec3(0.0)
@@ -1938,6 +2067,64 @@ def _comvel_root(cvel_out: wp.array2d(dtype=wp.spatial_vector)):
   cvel_out[worldid, 0][elementid] = 0.0
+# kernel_analyzer: off
+@wp.func
+def _process_joint_vel(
+  cvel: wp.spatial_vector,
+  dofid: int,
+  jntadr: int,
+  worldid: int,
+  jnt_type: wp.array(dtype=int),
+  qvel: wp.array(dtype=float),
+  cdof: wp.array(dtype=wp.spatial_vector),
+  cdof_dot_out: wp.array2d(dtype=wp.spatial_vector),
+):
+  """Process a single joint for velocity propagation.
+
+  Adds the joint's cdof * qvel terms to cvel and writes cdof_dot for its DOFs.
+  For a free joint, cdof_dot is written only for DOFs 3..5, after DOFs 0..2 have been added to cvel.
+
+  Returns:
+    Updated cvel and the index of the first DOF after this joint.
+  """
+  jnttype = jnt_type[jntadr]
+
+  if jnttype == JointType.FREE:
+    cvel += cdof[dofid + 0] * qvel[dofid + 0]
+    cvel += cdof[dofid + 1] * qvel[dofid + 1]
+    cvel += cdof[dofid + 2] * qvel[dofid + 2]
+
+    cdof_dot_out[worldid, dofid + 3] = math.motion_cross(cvel, cdof[dofid + 3])
+    cdof_dot_out[worldid, dofid + 4] = math.motion_cross(cvel, cdof[dofid + 4])
+    cdof_dot_out[worldid, dofid + 5] = math.motion_cross(cvel, cdof[dofid + 5])
+
+    cvel += cdof[dofid + 3] * qvel[dofid + 3]
+    cvel += cdof[dofid + 4] * qvel[dofid + 4]
+    cvel += cdof[dofid + 5] * qvel[dofid + 5]
+
+    dofid += 6
+  elif jnttype == JointType.BALL:
+    cdof_dot_out[worldid, dofid + 0] = math.motion_cross(cvel, cdof[dofid + 0])
+    cdof_dot_out[worldid, dofid + 1] = math.motion_cross(cvel, cdof[dofid + 1])
+    cdof_dot_out[worldid, dofid + 2] = math.motion_cross(cvel, cdof[dofid + 2])
+
+    cvel += cdof[dofid + 0] * qvel[dofid + 0]
+    cvel += cdof[dofid + 1] * qvel[dofid + 1]
+    cvel += cdof[dofid + 2] * qvel[dofid + 2]
+
+    dofid += 3
+  else:
+    cdof_dot_out[worldid, dofid] = math.motion_cross(cvel, cdof[dofid])
+    cvel += cdof[dofid] * qvel[dofid]
+
+    dofid += 1
+
+  return cvel, dofid
+
+
+# kernel_analyzer: on
+
+
 @wp.kernel
 def _comvel_branch(
   # Model:
@@ -1975,38 +2155,22 @@ def _comvel_branch(
       cvel_out[worldid, bodyid] = cvel
       continue
-    for j in range(jntid, jntid + jntnum):
-      jnttype = jnt_type[j]
-
-      if jnttype == JointType.FREE:
-        cvel += cdof[dofid + 0] * qvel[dofid + 0]
-        cvel += cdof[dofid + 1] * qvel[dofid + 1]
-        cvel += cdof[dofid + 2] * qvel[dofid + 2]
-
-        cdof_dot_out[worldid, dofid + 3] = math.motion_cross(cvel, cdof[dofid + 3])
-        cdof_dot_out[worldid, dofid + 4] = math.motion_cross(cvel, cdof[dofid + 4])
-        cdof_dot_out[worldid, dofid + 5] = math.motion_cross(cvel, cdof[dofid + 5])
-
-        cvel += cdof[dofid + 3] * qvel[dofid + 3]
-        cvel += cdof[dofid + 4] * qvel[dofid + 4]
-        cvel += cdof[dofid + 5] * qvel[dofid + 5]
-
-        dofid += 6
-      elif jnttype == JointType.BALL:
-        cdof_dot_out[worldid, dofid + 0] = math.motion_cross(cvel, cdof[dofid + 0])
-        cdof_dot_out[worldid, dofid + 1] = math.motion_cross(cvel, cdof[dofid + 1])
-        cdof_dot_out[worldid, dofid + 2] = math.motion_cross(cvel, cdof[dofid + 2])
-
-        cvel += cdof[dofid + 0] * qvel[dofid + 0]
-        cvel += cdof[dofid + 1] * qvel[dofid + 1]
-        cvel += cdof[dofid + 2] * qvel[dofid + 2]
-
-        dofid += 3
-      else:
-        cdof_dot_out[worldid, dofid] = math.motion_cross(cvel, cdof[dofid])
-        cvel += cdof[dofid] * qvel[dofid]
-
-        dofid += 1
+    # unrolled joint processing — avoids nested dynamic-range loop which
+    # produces incorrect gradients in warp's AD
+    # a body may carry more than 4 joints (e.g. 3 slide + 3 hinge), so 6 slots are unrolled;
+    # NOTE(review): assumes at most 6 joints per body (MuJoCo's 6-DOF-per-body limit) — confirm
+    if jntnum >= 1:
+      cvel, dofid = _process_joint_vel(cvel, dofid, jntid, worldid, jnt_type, qvel, cdof, cdof_dot_out)
+    if jntnum >= 2:
+      cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 1, worldid, jnt_type, qvel, cdof, cdof_dot_out)
+    if jntnum >= 3:
+      cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 2, worldid, jnt_type, qvel, cdof, cdof_dot_out)
+    if jntnum >= 4:
+      cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 3, worldid, jnt_type, qvel, cdof, cdof_dot_out)
+    if jntnum >= 5:
+      cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 4, worldid, jnt_type, qvel, cdof, cdof_dot_out)
+    if jntnum >= 6:
+      cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 5, worldid, jnt_type, qvel, cdof, cdof_dot_out)
     cvel_out[worldid, bodyid] = cvel
@@ -2598,7 +2756,7 @@ def _transmission_body_moment_scale(
if ncon > 0:
actid = actuator_trntype_body_adr[trnbodyid]
rowadr = moment_rowadr_in[worldid, actid]
- actuator_moment_out[worldid, rowadr + dofid] /= -float(ncon)
+ actuator_moment_out[worldid, rowadr + dofid] = actuator_moment_out[worldid, rowadr + dofid] / -float(ncon)
@event_scope
@@ -2715,7 +2873,7 @@ def _solve_LD_sparse_qLDiag_mul(
out: wp.array2d(dtype=float),
):
worldid, dofid = wp.tid()
- out[worldid, dofid] *= D[worldid, dofid]
+ out[worldid, dofid] = out[worldid, dofid] * D[worldid, dofid]
@wp.kernel
@@ -2933,9 +3091,8 @@ def _subtree_vel_forward(
subtree_linvel_out[worldid, bodyid] = body_mass[body_mass_id, bodyid] * lin
dv = wp.transpose(ximat) @ ang
- dv[0] *= body_inertia[body_inertia_id, bodyid][0]
- dv[1] *= body_inertia[body_inertia_id, bodyid][1]
- dv[2] *= body_inertia[body_inertia_id, bodyid][2]
+ inertia = body_inertia[body_inertia_id, bodyid]
+ dv = wp.vec3(dv[0] * inertia[0], dv[1] * inertia[1], dv[2] * inertia[2])
subtree_angmom_out[worldid, bodyid] = ximat @ dv
subtree_bodyvel_out[worldid, bodyid] = wp.spatial_vector(ang, lin)
@@ -2957,7 +3114,9 @@ def _linear_momentum(
if bodyid:
pid = body_parentid[bodyid]
wp.atomic_add(subtree_linvel_out[worldid], pid, subtree_linvel_in[worldid, bodyid])
- subtree_linvel_out[worldid, bodyid] /= wp.max(MJ_MINVAL, body_subtreemass[worldid % body_subtreemass.shape[0], bodyid])
+ subtree_linvel_out[worldid, bodyid] = subtree_linvel_out[worldid, bodyid] / wp.max(
+ MJ_MINVAL, body_subtreemass[worldid % body_subtreemass.shape[0], bodyid]
+ )
@wp.kernel
@@ -3008,7 +3167,7 @@ def _angular_momentum(
# momentum wrt parent
dx = com - com_parent
dv = linvel - linvel_parent
- dv *= subtreemass
+ dv = dv * subtreemass
dL = wp.cross(dx, dv)
wp.atomic_add(subtree_angmom_out[worldid], pid, dL)