Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
0177118
wp.clone ad protection
mar-yan24 Mar 17, 2026
8d87902
initial solver ad implementation
mar-yan24 Mar 17, 2026
0b1f649
full constraint solver implicit diff implementation
mar-yan24 Mar 18, 2026
c14961d
initial implementation of smooth contact ad
mar-yan24 Mar 20, 2026
3a0c83f
fixed docstring order
mar-yan24 Mar 20, 2026
c290bdb
actearly derivative (#1143)
thowell Mar 20, 2026
7ffd2bc
Fix multi flex indexing (#1249)
StafaH Mar 20, 2026
80e146c
Optimization: qderiv_actuator_passive_actuation (#1243)
Kenny-Vilella Mar 21, 2026
4839fc7
Fix benchmark path typo in help usage examples (#1254)
shi-eric Mar 22, 2026
04a8657
0.5 * gauss_cost (#1251)
thowell Mar 23, 2026
138c30e
Merge branch 'google-deepmind:main' into mark/autodifferentiation3
mar-yan24 Mar 23, 2026
a25140d
Flex rendering improvements (#1250)
StafaH Mar 24, 2026
573d406
constraint jacobian update for io.put_data (#1255)
thowell Mar 24, 2026
04fc3e1
heuristic for estimating the number of non-zeros in constraint_jacobi…
thowell Mar 24, 2026
36ab664
Merge branch 'google-deepmind:main' into mark/autodifferentiation3
mar-yan24 Mar 24, 2026
9b0af62
fix freejoint zerograd bug and add enable_backward test
mar-yan24 Mar 24, 2026
4e79dc6
add smooth contact autodifferentation and freejoin zerograd fix
mar-yan24 Mar 27, 2026
21032d4
fix merge
mar-yan24 Mar 28, 2026
ef7d754
add diagnostic instrumentation to adjont.py
mar-yan24 Mar 29, 2026
a57223e
fix integrator gradient chain
mar-yan24 Mar 29, 2026
1998f9c
add integrator grad path tests
mar-yan24 Mar 29, 2026
800fd66
mass matrix inverse solve
mar-yan24 Mar 29, 2026
e7e46dd
workaround for tape all
mar-yan24 Apr 1, 2026
a812bfb
more intermediate list
mar-yan24 Apr 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion contrib/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Usage: mjwarp-render <mjcf XML path> [flags]

Example:
mjwarp-render benchmark/humanoid/humanoid.xml --nworld=1 --cam=0 --width=512 --height=512
mjwarp-render benchmarks/humanoid/humanoid.xml --nworld=1 --cam=0 --width=512 --height=512
"""

import sys
Expand All @@ -42,6 +42,7 @@
_HEIGHT = flags.DEFINE_integer("height", 512, "render height (pixels)")
_RENDER_RGB = flags.DEFINE_bool("rgb", True, "render RGB image")
_RENDER_DEPTH = flags.DEFINE_bool("depth", True, "render depth image")
_RENDER_SEG = flags.DEFINE_bool("seg", False, "render segmentation image")
_USE_TEXTURES = flags.DEFINE_bool("textures", True, "use textures")
_USE_SHADOWS = flags.DEFINE_bool("shadows", False, "use shadows")
_DEVICE = flags.DEFINE_string("device", None, "override the default Warp device")
Expand Down Expand Up @@ -207,6 +208,7 @@ def _main(argv: Sequence[str]):
(render_width, render_height),
_RENDER_RGB.value,
_RENDER_DEPTH.value,
_RENDER_SEG.value,
_USE_TEXTURES.value,
_USE_SHADOWS.value,
enabled_geom_groups=[0, 1, 2],
Expand Down
1 change: 1 addition & 0 deletions mujoco_warp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from mujoco_warp._src.forward import rungekutta4 as rungekutta4
from mujoco_warp._src.forward import step1 as step1
from mujoco_warp._src.forward import step2 as step2
from mujoco_warp._src.grad import COLLISION_GRAD_FIELDS as COLLISION_GRAD_FIELDS
from mujoco_warp._src.grad import SMOOTH_GRAD_FIELDS as SMOOTH_GRAD_FIELDS
from mujoco_warp._src.grad import SOLVER_GRAD_FIELDS as SOLVER_GRAD_FIELDS
from mujoco_warp._src.grad import diff_forward as diff_forward
Expand Down
226 changes: 201 additions & 25 deletions mujoco_warp/_src/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Import this module via ``grad.py`` dont import it directly
"""

import os

import warp as wp

from mujoco_warp._src import math
Expand All @@ -15,6 +17,110 @@
from mujoco_warp._src.block_cholesky import create_blocked_cholesky_solve_func
from mujoco_warp._src.warp_util import cache_kernel

# ---------------------------------------------------------------------------
# Phase 3: efc-level gradient kernels for collision chain
# ---------------------------------------------------------------------------


@wp.kernel
def _efc_J_grad_kernel(
# In:
v: wp.array2d(dtype=float),
efc_force: wp.array2d(dtype=float),
nefc: wp.array(dtype=int),
nv: int,
njmax: int,
# Out:
efc_J_grad_out: wp.array3d(dtype=float),
):
"""Compute adj_efc_J[i, j] = v[j] * efc_force[i].

From KKT: F(qacc) = M*qacc - qfrc_smooth - J^T*f = 0
The derivative of J^T*f w.r.t. J[i,j] is f[i] * delta, and the
adjoint vector v gives the sensitivity: adj_J[i,j] = v[j] * f[i].
"""
worldid, efcid, dofid = wp.tid()
if efcid < nefc[worldid] and dofid < nv:
efc_J_grad_out[worldid, efcid, dofid] = v[worldid, dofid] * efc_force[worldid, efcid]


@wp.kernel
def _efc_pos_grad_kernel(
# In:
efc_aref_grad: wp.array2d(dtype=float),
contact_solref: wp.array(dtype=wp.vec2),
contact_solimp: wp.array(dtype=types.vec5),
contact_includemargin: wp.array(dtype=float),
contact_dist: wp.array(dtype=float),
contact_efc_address: wp.array2d(dtype=int),
contact_worldid: wp.array(dtype=int),
contact_type: wp.array(dtype=int),
nacon: wp.array(dtype=int),
opt_timestep: wp.array(dtype=float),
opt_disableflags: int,
# Out:
efc_pos_grad_out: wp.array2d(dtype=float),
):
"""Compute adj_efc_pos from adj_efc_aref.

From efc_aref = -k * imp * pos - b * vel, d(aref)/d(pos) = -k*imp.
So adj_efc_pos = adj_efc_aref * (-k * imp).
We iterate over contacts and their first dimension (normal direction).
"""
conid = wp.tid()
if conid >= nacon[0]:
return
if not (contact_type[conid] & 1): # ContactType.CONSTRAINT
return

efcid = contact_efc_address[conid, 0]
if efcid < 0:
return

worldid = contact_worldid[conid]
timestep = opt_timestep[worldid % opt_timestep.shape[0]]

solref = contact_solref[conid]
solimp = contact_solimp[conid]
includemargin = contact_includemargin[conid]
pos_val = contact_dist[conid] - includemargin

# Recompute k and imp (same as _efc_row)
timeconst = solref[0]
dampratio = solref[1]
dmin = solimp[0]
dmax = solimp[1]
width = solimp[2]
mid = solimp[3]
power = solimp[4]

if not (opt_disableflags & types.DisableBit.REFSAFE):
timeconst = wp.max(timeconst, 2.0 * timestep)

dmin = wp.clamp(dmin, types.MJ_MINIMP, types.MJ_MAXIMP)
dmax = wp.clamp(dmax, types.MJ_MINIMP, types.MJ_MAXIMP)
width = wp.max(types.MJ_MINVAL, width)
mid = wp.clamp(mid, types.MJ_MINIMP, types.MJ_MAXIMP)
power = wp.max(1.0, power)

dmax_sq = dmax * dmax
k = 1.0 / (dmax_sq * timeconst * timeconst * dampratio * dampratio)
k = wp.where(solref[0] <= 0.0, -solref[0] / dmax_sq, k)

imp_x = wp.abs(pos_val) / width
imp_a = (1.0 / wp.pow(mid, power - 1.0)) * wp.pow(imp_x, power)
imp_b = 1.0 - (1.0 / wp.pow(1.0 - mid, power - 1.0)) * wp.pow(1.0 - imp_x, power)
imp_y = wp.where(imp_x < mid, imp_a, imp_b)
imp = dmin + imp_y * (dmax - dmin)
imp = wp.clamp(imp, dmin, dmax)
imp = wp.where(imp_x > 1.0, dmax, imp)

# d(aref)/d(pos) = -k * imp
daref_dpos = -k * imp

adj_aref = efc_aref_grad[worldid, efcid]
efc_pos_grad_out[worldid, efcid] = adj_aref * daref_dpos


@wp.func_grad(math.quat_integrate)
def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat):
Expand Down Expand Up @@ -129,6 +235,17 @@ def _copy_grad_kernel(
dst[worldid, dofid] = src[worldid, dofid]


@wp.kernel
def _accumulate_grad_kernel(
# In:
src: wp.array2d(dtype=float),
# Out:
dst: wp.array2d(dtype=float),
):
worldid, dofid = wp.tid()
dst[worldid, dofid] = dst[worldid, dofid] + src[worldid, dofid]


@cache_kernel
def _adjoint_cholesky_tile(nv: int):
@wp.kernel(module="unique", enable_backward=False)
Expand Down Expand Up @@ -182,9 +299,7 @@ def kernel(
out: wp.array3d(dtype=float),
):
worldid = wp.tid()
wp.static(create_blocked_cholesky_func(tile_size))(
H[worldid], nv_runtime, hfactor_tmp[worldid]
)
wp.static(create_blocked_cholesky_func(tile_size))(H[worldid], nv_runtime, hfactor_tmp[worldid])
wp.static(create_blocked_cholesky_solve_func(tile_size, matrix_size))(
hfactor_tmp[worldid], b[worldid], nv_runtime, out[worldid]
)
Expand Down Expand Up @@ -219,9 +334,7 @@ def _solve_hessian_system(m: types.Model, d: types.Data, b, out):
if d.solver_hfactor.shape[1] > 0:
# Solve-only using stored Cholesky factor
wp.launch_tiled(
_adjoint_cholesky_blocked(
types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad
),
_adjoint_cholesky_blocked(types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad),
dim=d.nworld,
inputs=[d.solver_hfactor, b_3d, m.nv],
outputs=[out_3d],
Expand All @@ -237,53 +350,116 @@ def _solve_hessian_system(m: types.Model, d: types.Data, b, out):
inputs=[m.nv],
outputs=[d.solver_h],
)
hfactor_tmp = wp.zeros(
(d.nworld, m.nv_pad, m.nv_pad), dtype=float
)
hfactor_tmp = wp.zeros((d.nworld, m.nv_pad, m.nv_pad), dtype=float)
wp.launch_tiled(
_adjoint_cholesky_full_blocked(
types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad
),
_adjoint_cholesky_full_blocked(types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad),
dim=d.nworld,
inputs=[d.solver_h, b_3d, m.nv, hfactor_tmp],
outputs=[out_3d],
block_dim=m.block_dim.update_gradient_cholesky_blocked,
)


def solver_implicit_adjoint(m: types.Model, d: types.Data):
def solver_implicit_adjoint(m: types.Model, d: types.Data, qacc_array=None, qacc_smooth_ref=None):
"""Implicit differentiation adjoint for constraint solver.

Called during tape backward. Reads d.qacc.grad (set by downstream),
solves H*v = adj_qacc, writes d.qacc_smooth.grad = M*v.
Called during tape backward. Reads qacc_array.grad (set by downstream
integrator adjoint), solves H*v = adj_qacc, accumulates into
qacc_smooth_ref.grad += M*v.

Args:
qacc_array: The array whose .grad contains the incoming adjoint.
Defaults to d.qacc when called from diff_forward().
Integrators pass their local qacc array when it differs
from d.qacc (e.g. euler with implicit damping).
qacc_smooth_ref: The qacc_smooth array whose .grad receives the
accumulated adjoint. Captured at record time for
correct gradient isolation when intermediate arrays
are cloned between substeps. Defaults to d.qacc_smooth.
"""
nv = m.nv
if nv == 0:
return

if qacc_array is None:
qacc_array = d.qacc

if qacc_smooth_ref is None:
qacc_smooth_ref = d.qacc_smooth

adj_qacc = qacc_array.grad
if adj_qacc is None:
return

if os.environ.get("MJW_DEBUG_ADJOINT") == "1":
import torch

adj_norm = wp.to_torch(adj_qacc).norm().item()
print(f"[adjoint] |adj_qacc|={adj_norm:.6e}, njmax={d.njmax}")

if d.njmax == 0:
# Solver was identity (qacc = qacc_smooth), copy adjoint through
# Solver was identity (qacc = qacc_smooth), accumulate adjoint through
wp.launch(
_copy_grad_kernel,
_accumulate_grad_kernel,
dim=(d.nworld, nv),
inputs=[d.qacc.grad],
outputs=[d.qacc_smooth.grad],
inputs=[adj_qacc],
outputs=[qacc_smooth_ref.grad],
)
return

if m.opt.solver != types.SolverType.NEWTON:
# CG solver: no Hessian stored, fall back to identity
wp.launch(
_copy_grad_kernel,
_accumulate_grad_kernel,
dim=(d.nworld, nv),
inputs=[d.qacc.grad],
outputs=[d.qacc_smooth.grad],
inputs=[adj_qacc],
outputs=[qacc_smooth_ref.grad],
)
return

# Solve H * v = adj_qacc
v = wp.zeros((d.nworld, m.nv_pad), dtype=float)
_solve_hessian_system(m, d, d.qacc.grad, v)
_solve_hessian_system(m, d, adj_qacc, v)

# adj_qacc_smooth += M * v (accumulate, not overwrite)
tmp = wp.zeros((d.nworld, m.nv_pad), dtype=float)
support.mul_m(m, d, tmp, v)
wp.launch(
_accumulate_grad_kernel,
dim=(d.nworld, nv),
inputs=[tmp],
outputs=[qacc_smooth_ref.grad],
)

# adj_qacc_smooth = M * v
support.mul_m(m, d, d.qacc_smooth.grad, v)
# Phase 3: compute efc-level gradients for collision chain
if d.njmax > 0:
efc_J = d.efc.J
if hasattr(efc_J, "grad") and efc_J.grad is not None:
wp.launch(
_efc_J_grad_kernel,
dim=(d.nworld, d.njmax_pad, m.nv_pad),
inputs=[v, d.efc.force, d.nefc, m.nv, d.njmax],
outputs=[efc_J.grad],
)

efc_aref = d.efc.aref
efc_pos = d.efc.pos
if hasattr(efc_aref, "grad") and efc_aref.grad is not None and hasattr(efc_pos, "grad") and efc_pos.grad is not None:
wp.launch(
_efc_pos_grad_kernel,
dim=d.naconmax,
inputs=[
efc_aref.grad,
d.contact.solref,
d.contact.solimp,
d.contact.includemargin,
d.contact.dist,
d.contact.efc_address,
d.contact.worldid,
d.contact.type,
d.nacon,
m.opt.timestep,
m.opt.disableflags,
],
outputs=[efc_pos.grad],
)
Loading