add implicit differentiation for constraint solver 2/3#1

Open
mar-yan24 wants to merge 1 commit into mark/autodifferentiation from mark/autodifferentiation2

Conversation

@mar-yan24

Add Implicit Differentiation for Constraint Solver (2/3)

This PR is phase 2/3 of the AD implementation in mjwarp. It depends on PR google-deepmind#1226 merging and passing first.

TLDR

Phase 2 enables reverse-mode automatic differentiation through MuJoCo Warp's constraint solver. The solver relies on iterative convergence, discontinuous constraint activation, and tile Cholesky operations, none of which Warp's tape can differentiate directly. Instead, this PR applies implicit differentiation to bridge the gradient gap.

The main changes are:

  • Implicit differentiation through Newton solver via stored Hessian
  • Custom Cholesky-based adjoint solve
  • Identity fallback for CG/unconstrained cases

Retained Solver State

The solver's Hessian and Cholesky factor are retained on Data for reuse during the backward pass:

  • Data.solver_h: Hessian matrix H = M + J^T * diag(D_active) * J
  • Data.solver_hfactor: Cholesky factor of H (nv > 32 blocked path only)
  • Data.solver_Jaref: Reference force from last Newton iteration

These are allocated in io.py:make_data() and aliased into SolverContext during the forward solve.
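A minimal Python sketch of the retained state described above (the real `Data` struct holds Warp device arrays and `make_data` lives in `io.py`; only the three `solver_*` field names come from this PR, the rest is illustrative):

```python
import numpy as np
from dataclasses import dataclass

# Hypothetical sketch: the real mjwarp Data holds Warp arrays, not NumPy
# arrays. Field names solver_h / solver_hfactor / solver_Jaref are from the PR.
@dataclass
class Data:
    nv: int
    solver_h: np.ndarray        # Hessian H = M + J^T * diag(D_active) * J
    solver_hfactor: np.ndarray  # Cholesky factor of H (blocked path only)
    solver_Jaref: np.ndarray    # reference force from last Newton iteration

def make_data(nv: int) -> Data:
    # Allocated once up front, then aliased into SolverContext during the
    # forward solve and reused by the backward pass.
    return Data(
        nv=nv,
        solver_h=np.zeros((nv, nv)),
        solver_hfactor=np.zeros((nv, nv)),
        solver_Jaref=np.zeros(nv),
    )
```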

Implicit Differentiation Adjoint

tape.record_func is the main Warp mechanism used here. It records a callable into the tape's execution list; when tape.backward() runs, each callable executes at its recorded position in the reversed list.
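The reversed-execution behavior can be illustrated with a toy pure-Python tape (this is a sketch of the concept, not Warp's actual Tape class):

```python
# Toy tape: operations are recorded in forward order and replayed in reverse
# during backward(), which is why a callable recorded right after the solver
# fires before the (skipped) solver adjoint would.
class Tape:
    def __init__(self):
        self._ops = []

    def record_func(self, backward_fn, label):
        # Store the callable at its position in the forward execution list.
        self._ops.append((label, backward_fn))

    def backward(self):
        order = []
        for label, fn in reversed(self._ops):
            fn()  # each recorded callable executes here
            order.append(label)
        return order

tape = Tape()
for name in ["solver.solve", "RECORD_FUNC", "sensor_acc", "euler"]:
    tape.record_func(lambda: None, name)
print(tape.backward())  # ['euler', 'sensor_acc', 'RECORD_FUNC', 'solver.solve']
```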

Tape Order

Forward:

[smooth dynamics] -> [solver.solve] -> [RECORD_FUNC] -> [sensor_acc] -> [euler]

Backward:

[euler adj] -> [sensor_acc adj] -> [RECORD_FUNC fires] -> [solver adj SKIPPED] -> [smooth adj]

When the callable executes:

  1. d.qacc.grad already has dL/dqacc from integration backward
  2. Callable solves H * v = dL/dqacc
  3. Writes d.qacc_smooth.grad = M * v
  4. fwd_acceleration backward reads d.qacc_smooth.grad and propagates
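Steps 1-4 can be sketched in NumPy, assuming (consistently with the adjoint described) that the solver's fixed point satisfies H * qacc = M * qacc_smooth + const, so the adjoint is dL/dqacc_smooth = M * H^{-1} * dL/dqacc. The real implementation uses Warp tile kernels and the retained Cholesky factor; here the factorization is done directly:

```python
import numpy as np

# NumPy sketch of the adjoint solve. H stands in for Data.solver_h and
# L_chol for Data.solver_hfactor.
rng = np.random.default_rng(0)
nv = 4
A = np.eye(nv) + 0.1 * rng.standard_normal((nv, nv))
M = A @ A.T                          # SPD mass matrix
J = rng.standard_normal((2, nv))     # constraint Jacobian
D = np.diag([3.0, 5.0])              # active-constraint weights
H = M + J.T @ D @ J                  # Hessian H = M + J^T diag(D_active) J

g_qacc = rng.standard_normal(nv)     # step 1: dL/dqacc from integration backward
L_chol = np.linalg.cholesky(H)       # retained Cholesky factor of H
v = np.linalg.solve(L_chol.T, np.linalg.solve(L_chol, g_qacc))  # step 2: H v = dL/dqacc
g_qacc_smooth = M @ v                # step 3: write dL/dqacc_smooth = M v

# Sanity check against the dense Jacobian: dqacc/dqacc_smooth = H^{-1} M, so
# dL/dqacc_smooth = (H^{-1} M)^T g = M H^{-1} g (M and H are symmetric).
ref = (np.linalg.inv(H) @ M).T @ g_qacc
assert np.allclose(g_qacc_smooth, ref)
```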

Limitations

Several known limitations are listed in the original PR comment. The one specific to this PR is the fixed active-set assumption: implicit differentiation assumes small perturbations don't change which constraints are active.
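A minimal 1-D illustration of why this matters, using a hypothetical clamped contact force f(x) = max(0, -k*x): the gradient computed with a frozen active set is only valid while the activation pattern at x doesn't change.

```python
# Fixed active-set gradients are locally correct but break across an
# activation boundary (hypothetical 1-D example, not mjwarp code).
def force(x, k=10.0):
    return max(0.0, -k * x)

def grad_fixed_active_set(x, k=10.0):
    # Differentiate assuming the activation pattern at x stays fixed.
    return -k if -k * x > 0 else 0.0

x = -0.1                      # constraint active at x
g = grad_fixed_active_set(x)  # -k

small = force(x + 1e-3) - force(x)   # stays active: matches g * 1e-3
large = force(x + 0.2) - force(x)    # crosses activation: g * 0.2 is wrong
```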

