Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mujoco_warp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@
from mujoco_warp._src.forward import rungekutta4 as rungekutta4
from mujoco_warp._src.forward import step1 as step1
from mujoco_warp._src.forward import step2 as step2
from mujoco_warp._src.grad import SMOOTH_GRAD_FIELDS as SMOOTH_GRAD_FIELDS
from mujoco_warp._src.grad import diff_forward as diff_forward
from mujoco_warp._src.grad import diff_step as diff_step
from mujoco_warp._src.grad import disable_grad as disable_grad
from mujoco_warp._src.grad import enable_grad as enable_grad
from mujoco_warp._src.grad import make_diff_data as make_diff_data
from mujoco_warp._src.inverse import inverse as inverse
from mujoco_warp._src.io import create_render_context as create_render_context
from mujoco_warp._src.io import get_data_into as get_data_into
Expand Down
107 changes: 107 additions & 0 deletions mujoco_warp/_src/adjoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""custom adjoint definitions for MuJoCo Warp autodifferentiation.

This module centralizes all ``@wp.func_grad`` registrations. It must be
imported before any tape recording so that custom gradients are registered
with Warp's AD system.

Import this module via ``grad.py`` dont import it directly
"""

import warp as wp

from mujoco_warp._src import math


@wp.func_grad(math.quat_integrate)
def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat):
"""Custom adjoint avoiding gradient singularity at |v|=0."""
EPS = float(1e-10)
norm_v = wp.length(v)
norm_v_sq = norm_v * norm_v
half_angle = dt * norm_v * 0.5

# sinc-safe rotation quaternion construction
if norm_v > EPS:
s_over_nv = wp.sin(half_angle) / norm_v # sin(dt|v|/2) / |v|
c = wp.cos(half_angle)
# d(s_over_nv)/dv_j = ds_coeff * v_j
ds_coeff = (c * dt * 0.5 - s_over_nv) / norm_v_sq
else:
s_over_nv = dt * 0.5
c = 1.0
# Taylor limit: (c*dt/2 - s_over_nv) / |v|^2 -> -dt^3/24
ds_coeff = -dt * dt * dt / 24.0

q_rot = wp.quat(
c,
s_over_nv * v[0],
s_over_nv * v[1],
s_over_nv * v[2],
)

# recompute forward intermediates
q_len = wp.length(q)
q_inv_len = 1.0 / wp.max(q_len, EPS)
q_n = wp.quat(
q[0] * q_inv_len,
q[1] * q_inv_len,
q[2] * q_inv_len,
q[3] * q_inv_len,
)

q_res = math.mul_quat(q_n, q_rot)
res_len = wp.length(q_res)
res_inv = 1.0 / wp.max(res_len, EPS)

# result = normalize(q_res)
# adj_q_res_k = adj_ret_k / |q_res| - q_res_k * dot(adj_ret, q_res) / |q_res|^3
dot_ar = adj_ret[0] * q_res[0] + adj_ret[1] * q_res[1] + adj_ret[2] * q_res[2] + adj_ret[3] * q_res[3]
res_inv3 = res_inv * res_inv * res_inv
adj_qr = wp.quat(
adj_ret[0] * res_inv - q_res[0] * dot_ar * res_inv3,
adj_ret[1] * res_inv - q_res[1] * dot_ar * res_inv3,
adj_ret[2] * res_inv - q_res[2] * dot_ar * res_inv3,
adj_ret[3] * res_inv - q_res[3] * dot_ar * res_inv3,
)

# q_res = mul_quat(q_n, q_rot)
# adj_q_n = mul_quat(adj_qr, conj(q_rot))
# adj_q_rot = mul_quat(conj(q_n), adj_qr)
q_rot_conj = wp.quat(q_rot[0], -q_rot[1], -q_rot[2], -q_rot[3])
adj_qn = math.mul_quat(adj_qr, q_rot_conj)

q_n_conj = wp.quat(q_n[0], -q_n[1], -q_n[2], -q_n[3])
adj_q_rot = math.mul_quat(q_n_conj, adj_qr)

# q_rot = (c, s_over_nv * v)
# d(c)/dv_j = -s_over_nv * dt/2 * v_j
# d(s_over_nv * v_i)/dv_j = ds_coeff * v_j * v_i + s_over_nv * delta_ij
sv_dot = adj_q_rot[1] * v[0] + adj_q_rot[2] * v[1] + adj_q_rot[3] * v[2]
common = -s_over_nv * dt * 0.5 * adj_q_rot[0] + ds_coeff * sv_dot
adj_v_val = wp.vec3(
common * v[0] + s_over_nv * adj_q_rot[1],
common * v[1] + s_over_nv * adj_q_rot[2],
common * v[2] + s_over_nv * adj_q_rot[3],
)

# adj_dt from q_rot dependency on dt
# d(c)/d(dt) = -sin(half_angle) * norm_v / 2
# d(s_over_nv * v_i)/dt = (c / 2) * v_i
adj_dt_val = adj_q_rot[0] * (-wp.sin(half_angle) * norm_v * 0.5)
adj_dt_val += sv_dot * c * 0.5

# q_n = normalize(q)
# adj_q_k = adj_qn_k / |q| - q_k * dot(adj_qn, q) / |q|^3
dot_aqn = adj_qn[0] * q[0] + adj_qn[1] * q[1] + adj_qn[2] * q[2] + adj_qn[3] * q[3]
q_inv_len3 = q_inv_len * q_inv_len * q_inv_len
adj_q_val = wp.quat(
adj_qn[0] * q_inv_len - q[0] * dot_aqn * q_inv_len3,
adj_qn[1] * q_inv_len - q[1] * dot_aqn * q_inv_len3,
adj_qn[2] * q_inv_len - q[2] * dot_aqn * q_inv_len3,
adj_qn[3] * q_inv_len - q[3] * dot_aqn * q_inv_len3,
)

# accumulate adjoints
wp.adjoint[q] += adj_q_val
wp.adjoint[v] += adj_v_val
wp.adjoint[dt] += adj_dt_val
2 changes: 1 addition & 1 deletion mujoco_warp/_src/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from mujoco_warp._src.types import vec10f
from mujoco_warp._src.warp_util import event_scope

wp.set_module_options({"enable_backward": False})
wp.set_module_options({"enable_backward": True})


@wp.kernel
Expand Down
22 changes: 15 additions & 7 deletions mujoco_warp/_src/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from mujoco_warp._src.warp_util import cache_kernel
from mujoco_warp._src.warp_util import event_scope

wp.set_module_options({"enable_backward": False})
wp.set_module_options({"enable_backward": True})


@wp.kernel
Expand Down Expand Up @@ -214,6 +214,12 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
"""Advance state and time given activation derivatives and acceleration."""
# TODO(team): can we assume static timesteps?

# Clone arrays used as both input and output so that Warp's tape retains the
# original values for correct reverse-mode AD.
act_in = wp.clone(d.act)
qvel_prev = wp.clone(d.qvel)
qpos_prev = wp.clone(d.qpos)

# advance activations
wp.launch(
_next_activation,
Expand All @@ -226,7 +232,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
m.actuator_actlimited,
m.actuator_dynprm,
m.actuator_actrange,
d.act,
act_in,
d.act_dot,
1.0,
True,
Expand All @@ -237,7 +243,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
wp.launch(
_next_velocity,
dim=(d.nworld, m.nv),
inputs=[m.opt.timestep, d.qvel, qacc, 1.0],
inputs=[m.opt.timestep, qvel_prev, qacc, 1.0],
outputs=[d.qvel],
)

Expand All @@ -247,7 +253,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
wp.launch(
_next_position,
dim=(d.nworld, m.njnt),
inputs=[m.opt.timestep, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, qvel_in, 1.0],
inputs=[m.opt.timestep, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, qpos_prev, qvel_in, 1.0],
outputs=[d.qpos],
)

Expand Down Expand Up @@ -774,9 +780,9 @@ def _tendon_actuator_force_clamp(
actfrcrange = tendon_actfrcrange[worldid % tendon_actfrcrange.shape[0], tenid]

if ten_actfrc < actfrcrange[0]:
actuator_force_out[worldid, actid] *= actfrcrange[0] / ten_actfrc
actuator_force_out[worldid, actid] = actuator_force_out[worldid, actid] * (actfrcrange[0] / ten_actfrc)
elif ten_actfrc > actfrcrange[1]:
actuator_force_out[worldid, actid] *= actfrcrange[1] / ten_actfrc
actuator_force_out[worldid, actid] = actuator_force_out[worldid, actid] * (actfrcrange[1] / ten_actfrc)


@wp.kernel
Expand Down Expand Up @@ -911,6 +917,8 @@ def fwd_actuation(m: Model, d: Data):
],
outputs=[d.qfrc_actuator],
)
# clone to break input/output aliasing for correct AD
qfrc_actuator_in = wp.clone(d.qfrc_actuator)
wp.launch(
_qfrc_actuator_gravcomp_limits,
dim=(d.nworld, m.nv),
Expand All @@ -921,7 +929,7 @@ def fwd_actuation(m: Model, d: Data):
m.jnt_actfrcrange,
m.dof_jntid,
d.qfrc_gravcomp,
d.qfrc_actuator,
qfrc_actuator_in,
],
outputs=[d.qfrc_actuator],
)
Expand Down
150 changes: 150 additions & 0 deletions mujoco_warp/_src/grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Autodifferentiation coordination for MuJoCo Warp.

This module provides utilities for enabling Warp's tape-based reverse-mode
automatic differentiation through the MuJoCo Warp physics pipeline.

Usage::

import mujoco_warp as mjw

d = mjw.make_diff_data(mjm) # Data with gradient tracking
tape = wp.Tape()
with tape:
mjw.step(m, d)
wp.launch(loss_kernel, dim=1, inputs=[d.xpos, target, loss])
tape.backward(loss=loss)
grad_ctrl = d.ctrl.grad
"""

from typing import Callable, Optional, Sequence

import warp as wp

from mujoco_warp._src import adjoint as _adjoint # noqa: F401 (register custom adjoints)
from mujoco_warp._src import io
from mujoco_warp._src.forward import forward
from mujoco_warp._src.forward import step
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import Model

SMOOTH_GRAD_FIELDS: tuple = (
# primary state, user-controlled inputs
"qpos",
"qvel",
"ctrl",
"act",
"mocap_pos",
"mocap_quat",
"xfrc_applied",
"qfrc_applied",
# position-dependent outputs
"xpos",
"xquat",
"xmat",
"xipos",
"ximat",
"xanchor",
"xaxis",
"geom_xpos",
"geom_xmat",
"site_xpos",
"site_xmat",
"subtree_com",
"cinert",
"crb",
"cdof",
# Velocity-dependent outputs
"cdof_dot",
"cvel",
"subtree_linvel",
"subtree_angmom",
"actuator_velocity",
"ten_velocity",
# body-level intermediate quantities
"cacc",
"cfrc_int",
"cfrc_ext",
# force/acceleration outputs
"qfrc_bias",
"qfrc_spring",
"qfrc_damper",
"qfrc_gravcomp",
"qfrc_fluid",
"qfrc_passive",
"qfrc_actuator",
"qfrc_smooth",
"qacc",
"qacc_smooth",
"actuator_force",
"act_dot",
# inertia matrix
"qM",
"qLD",
"qLDiagInv",
# Tendon
"ten_J",
"ten_length",
# actuator
"actuator_length",
"actuator_moment",
# sensor
"sensordata",
)


def enable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None:
"""Enables gradient tracking on Data arrays."""
if fields is None:
fields = SMOOTH_GRAD_FIELDS
for name in fields:
arr = getattr(d, name, None)
if arr is not None and isinstance(arr, wp.array):
arr.requires_grad = True


def disable_grad(d: Data) -> None:
"""Disables gradient tracking on all Data arrays."""
for name in SMOOTH_GRAD_FIELDS:
arr = getattr(d, name, None)
if arr is not None and isinstance(arr, wp.array):
arr.requires_grad = False


def make_diff_data(
mjm,
nworld: int = 1,
grad_fields: Optional[Sequence[str]] = None,
**kwargs,
) -> Data:
"""Creates a Data object with gradient tracking enabled."""
d = io.make_data(mjm, nworld=nworld, **kwargs)
enable_grad(d, fields=grad_fields)
return d


def diff_step(
m: Model,
d: Data,
loss_fn: Callable[[Model, Data], wp.array],
) -> wp.Tape:
"""Runs a differentiable physics step."""
tape = wp.Tape()
with tape:
step(m, d)
loss = loss_fn(m, d)
tape.backward(loss=loss)
return tape


def diff_forward(
m: Model,
d: Data,
loss_fn: Callable[[Model, Data], wp.array],
) -> wp.Tape:
"""Runs differentiable forward dynamics (no integration)."""
tape = wp.Tape()
with tape:
forward(m, d)
loss = loss_fn(m, d)
tape.backward(loss=loss)
return tape
Loading