Auto Differentiation Implementation 1/3#1226
Auto Differentiation Implementation 1/3#1226mar-yan24 wants to merge 2 commits intogoogle-deepmind:mainfrom
Conversation
|
@mar-yan24 thank you for this contribution! thanks for scoping this to the smooth dynamics for now. have you identified any key blockers for adding differentiation support to other parts of the code like the collision pipeline or constraint solver? |
|
@erikfrey @adenzler-nvidia @adenzler-nvidia @Kenny-Vilella |
|
Thanks for the review @thowell! Yea I scoped to smooth dynamics after doing a rough survey of the collision and solver code. Within the collision pipeline and constraint solver, there are some blockers I need to look into and do some testing on but these are what I mainly found. For the collision pipeline, the fundamental issue is that collision detection is a discrete geometric query, with several integer configuration vars ( For constraint solver I think I need to spend some more time looking into it, but from what ive seen, the biggest issues for I also don't mind taking a look at If it would help, I can also write a MD file for the larger changes I need to implement and a high-level roadmap for the whole implementation. |
|
Some tile operations have some limitation on the adjoint calculation, this includes:
But most of them should be OK. For |
47a5120 to
eb256a4
Compare
|
I think this should be ready for review, I did some testing in a personal suite and everything looks good. I already found out that multi-step tape AD is not supported properly in Warp so I can probably implement some workarounds for that. Also, you need I will probably request a PR draft soon for the constraint solver implementation which I have basically working just need to run some more tests. Let me know if anything looks sus and I'll take a look. |
|
Thanks for tackling this - I do have a comment about enable_backward in general. MuJoCo Warp is consumed as a library in most cases, so we need to make sure that enable_backwards is properly propagated from user code into the library, and can be toggled correctly. I realize this is currently broken, mostly due to the fact the differentiability does not work at all, so while enabling we should add proper testing to verify that any configuration coming from user code is respected. |
|
We are working on a slightly approximate collision pipeline differentiability solution for Newton, see newton-physics/newton#2164. Maybe some ideas can be applied to MjWarp as well. |
|
@nvtw Thank you for the input, I'll take a look at the newton implementation today and see if there are improvements that can be made for the collision gradients on my end. |
|
hi @mar-yan24 - what a cool PR! Thanks for taking this on. Just as a heads up, we're still gathering requirements from teams / potential partners on what are the active use cases for differentiable simulations, which in turn feeds into thinking through ideal API design and needed code changes. As a result, I think it's going to be a while yet before we're ready to give specific feedback on this PR and get it into a mergeable state. So I want to offer a couple of options in the meantime:
|
|
@erikfrey thanks for the input! Keeping this PR open might be a good call, but if anyone within the community wants to use a more polished version I do have my own fork which has a basic full pipeline working. I think having this PR open or another forum to keep a feedback loop might be nice especially if planned for future merging? I also don't know how much the fork might drift in the future from main but anyone is welcome to try out the AD implementation: https://github.com/mar-yan24/mujoco_warp/tree/mark/autodifferentiation3. If you or the rest of the team have any comments or insights feel free to just pop in and comment here! |
Autodifferentiation Support 1/3
Overview
So a bit ago I was sorta interested in implementing automatic differentiation into MJWarp cause I wanna do a project with diff contact geometry and I had some time on my hands so I decided to begin working on a personal implementation of AD support in MJWarp. I'll probably continue working on this over the month but I'm putting in a draft here to see if I can get some maintainer feedback and maybe discuss if there is still community desire for this as referenced in 'issue' #500.
Basically, these changes add reverse-mode AD support for the smooth dynamics pipeline of MuJoCo Warp. Most people should now be able to compute gradients of scalar loss functions with respect to
qpos,qvel,ctrl, and other state variables by recording awp.Tapeoverkinematics -> fwd_velocity -> fwd_actuation -> euler.The implementation follows a selective
enable_backwardstrategy: only the four modules that participate in the differentiable smooth-dynamics path have backward code generation enabled. All other modules (collision, constraint, solver, sensor, render, ray, etc.) remain atenable_backward: False. This should keep compilation time and binary size normal i think.Architecture
Selective backward generation
enable_backwardsmooth.pyTrueforward.pyTruepassive.pyTruederivative.pyTrueFalseWithin the enabled modules, tile kernels (
wp.launch_tiled) still have per-kernelenable_backward=Falseoverrides (smooth.py lines 1053/2825/2903, forward.py line 309) because cuSolverDx LTO compilation does not support adjoint generation.Kernel compilation time
@Kenny-Vilella raised a good concern in discussion #993 which was about enabling backward globally. As I mentioned earlier, one of the issues we wanna avoid is hella long compile. This selective approach basically generates adjoint kernels for only ~30 smooth-dynamics kernels out of 100+ total. Warp caches compiled kernels, so the cost is one-time per kernel signature.
New modules
grad.py- coordination layer:enable_grad(),disable_grad(),make_diff_data(),diff_step(),diff_forward(),SMOOTH_GRAD_FIELDS.adjoint.py- centralizes@wp.func_gradregistrations. Phase 1 provides a custom adjoint forquat_integrate(avoids gradient singularity at zero angular velocity).grad_test.py- AD test suite: kinematics, fwd_velocity, fwd_actuation, euler_step, quaternion integration, and utility tests.Summary
As a whole, my goal with this is to get some initial feedback and review from both the community and the maintainers on whether or not this project and implementation is feasible. Any words of advice and feedback is appreciated!