Skip to content

Auto Differentiation Implementation 1/3#1226

Open
mar-yan24 wants to merge 2 commits intogoogle-deepmind:mainfrom
mar-yan24:mark/autodifferentiation
Open

Auto Differentiation Implementation 1/3#1226
mar-yan24 wants to merge 2 commits intogoogle-deepmind:mainfrom
mar-yan24:mark/autodifferentiation

Conversation

@mar-yan24
Copy link

Autodifferentiation Support 1/3

Overview

So a bit ago I was sorta interested in implementing automatic differentiation into MJWarp cause I wanna do a project with diff contact geometry and I had some time on my hands so I decided to begin working on a personal implementation of AD support in MJWarp. I'll probably continue working on this over the month but I'm putting in a draft here to see if I can get some maintainer feedback and maybe discuss if there is still community desire for this as referenced in 'issue' #500.

Basically, these changes add reverse-mode AD support for the smooth dynamics pipeline of MuJoCo Warp. Most people should now be able to compute gradients of scalar loss functions with respect to qpos, qvel, ctrl, and other state variables by recording a wp.Tape over kinematics -> fwd_velocity -> fwd_actuation -> euler.

The implementation follows a selective enable_backward strategy: only the four modules that participate in the differentiable smooth-dynamics path have backward code generation enabled. All other modules (collision, constraint, solver, sensor, render, ray, etc.) remain at enable_backward: False. This should keep compilation time and binary size normal i think.

Architecture

Selective backward generation

Module enable_backward Notes
smooth.py True kinematics, crb, rne, com_vel, etc.
forward.py True fwd_velocity, fwd_actuation, euler/rk4
passive.py True passive forces (spring, damper, fluid)
derivative.py True analytical derivatives (qDeriv)
All others (13+) False collision, constraint, solver, sensor, ...

Within the enabled modules, tile kernels (wp.launch_tiled) still have per-kernel enable_backward=False overrides (smooth.py lines 1053/2825/2903, forward.py line 309) because cuSolverDx LTO compilation does not support adjoint generation.

Kernel compilation time

@Kenny-Vilella raised a good concern in discussion #993 which was about enabling backward globally. As I mentioned earlier, one of the issues we wanna avoid is hella long compile. This selective approach basically generates adjoint kernels for only ~30 smooth-dynamics kernels out of 100+ total. Warp caches compiled kernels, so the cost is one-time per kernel signature.

New modules

  • grad.py - coordination layer: enable_grad(), disable_grad(), make_diff_data(), diff_step(), diff_forward(), SMOOTH_GRAD_FIELDS.
  • adjoint.py - centralizes @wp.func_grad registrations. Phase 1 provides a custom adjoint for quat_integrate (avoids gradient singularity at zero angular velocity).
  • grad_test.py - AD test suite: kinematics, fwd_velocity, fwd_actuation, euler_step, quaternion integration, and utility tests.

Summary

As a whole, my goal with this is to get some initial feedback and review from both the community and the maintainers on whether or not this project and implementation is feasible. Any words of advice and feedback is appreciated!

@thowell thowell linked an issue Mar 16, 2026 that may be closed by this pull request
@thowell thowell self-requested a review March 16, 2026 09:53
@thowell
Copy link
Collaborator

thowell commented Mar 16, 2026

@mar-yan24 thank you for this contribution!

thanks for scoping this to the smooth dynamics for now. have you identified any key blockers for adding differentiation support to other parts of the code like the collision pipeline or constraint solver?

@thowell
Copy link
Collaborator

thowell commented Mar 16, 2026

@erikfrey @adenzler-nvidia
how do we want to think about an api for differentiation?

@adenzler-nvidia @Kenny-Vilella
what are the performance implications for utilizing wp.clone? are these calls something we should consider guarding with wp.static?
what considerations should be made for tile operations and differentiability?

@mar-yan24
Copy link
Author

mar-yan24 commented Mar 16, 2026

Thanks for the review @thowell! Yea I scoped to smooth dynamics after doing a rough survey of the collision and solver code. Within the collision pipeline and constraint solver, there are some blockers I need to look into and do some testing on but these are what I mainly found.

For the collision pipeline, the fundamental issue is that collision detection is a discrete geometric query, with several integer configuration vars (cltype/clface/clcorner) controlling branching with no gradient, and atomic counters make the contact count data-dependent. There are also some algorithmic non-differentiabilities, which im pretty sure is not fixable just by enabling backward (enable_backward) on the existing code. Probably gotta bypass the discrete pipeline with smooth distance proxies and custom adjoins in adjoint.py.

For constraint solver I think I need to spend some more time looking into it, but from what ive seen, the biggest issues for enable_backward permission are wp.capture_while (runs until all worlds converse, so number of iterations varies per world -> tape needs fixed computation graph to replay backwards), constraint activation (constraint active depends on pos < 0, the discontinuity needs to be avoided, maybe add small perturbation to qpos?), and wp.tile_cholesky (does not support LTO adjoin generation, no flow for grad).

I also don't mind taking a look at wp.clone and tile operations for performance optimization. I don't really know as much on performance optimization but im a student I have a decent amount of time on my hands lol.

If it would help, I can also write a MD file for the larger changes I need to implement and a high-level roadmap for the whole implementation.

@Kenny-Vilella
Copy link
Collaborator

Some tile operations have some limitation on the adjoint calculation, this includes:

  • wp.tile_cholesky
  • wp.tile_matmul
  • wp.tile_*_solve

But most of them should be OK.

For wp.clone, it's a memory operation so it may be expensive depending on the size of the array.
One thing to consider is whether it is faster to spin up a kernel, but it's not a major issue.

@mar-yan24 mar-yan24 force-pushed the mark/autodifferentiation branch from 47a5120 to eb256a4 Compare March 17, 2026 03:52
@mar-yan24 mar-yan24 marked this pull request as ready for review March 20, 2026 03:43
@mar-yan24
Copy link
Author

I think this should be ready for review, I did some testing in a personal suite and everything looks good. I already found out that multi-step tape AD is not supported properly in Warp so I can probably implement some workarounds for that. Also, you need sm_70+ version to run on GPU cause Tile Cholesky needs cuSolverDx.

I will probably request a PR draft soon for the constraint solver implementation which I have basically working just need to run some more tests. Let me know if anything looks sus and I'll take a look.

@adenzler-nvidia
Copy link
Collaborator

Thanks for tackling this - I do have a comment about enable_backward in general. MuJoCo Warp is consumed as a library in most cases, so we need to make sure that enable_backwards is properly propagated from user code into the library, and can be toggled correctly. I realize this is currently broken, mostly due to the fact the differentiability does not work at all, so while enabling we should add proper testing to verify that any configuration coming from user code is respected.

@nvtw
Copy link
Collaborator

nvtw commented Mar 23, 2026

We are working on a slightly approximate collision pipeline differentiability solution for Newton, see newton-physics/newton#2164. Maybe some ideas can be applied to MjWarp as well.

@mar-yan24
Copy link
Author

@nvtw Thank you for the input, I'll take a look at the newton implementation today and see if there are improvements that can be made for the collision gradients on my end.

@erikfrey
Copy link
Collaborator

hi @mar-yan24 - what a cool PR! Thanks for taking this on.

Just as a heads up, we're still gathering requirements from teams / potential partners on what are the active use cases for differentiable simulations, which in turn feeds into thinking through ideal API design and needed code changes.

As a result, I think it's going to be a while yet before we're ready to give specific feedback on this PR and get it into a mergeable state. So I want to offer a couple of options in the meantime:

  • If you'd like to keep this PR open, we can come back to it from time to time as we collect new requirements or get feedback from the community
  • If you have a specific research aim in mind and you're gung ho to complete this engineering, you could also consider maintaining a fork of mjwarp that supports autodiff, and we would be happy to refer folks to it who are interested in autodifferentiation in the meantime.

@mar-yan24
Copy link
Author

@erikfrey thanks for the input! Keeping this PR open might be a good call, but if anyone within the community wants to use a more polished version I do have my own fork which has a basic full pipeline working. I think having this PR open or another forum to keep a feedback loop might be nice especially if planned for future merging? I also don't know how much the fork might drift in the future from main but anyone is welcome to try out the AD implementation: https://github.com/mar-yan24/mujoco_warp/tree/mark/autodifferentiation3.

If you or the rest of the team have any comments or insights feel free to just pop in and comment here!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Differentiability

6 participants