Improve gradient stability for dynamics computations#505
Open
flferretti wants to merge 4 commits intomainfrom
Open
Improve gradient stability for dynamics computations#505flferretti wants to merge 4 commits intomainfrom
flferretti wants to merge 4 commits intomainfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR introduces gradient-stable math primitives (custom JVPs + “safe” operations) and applies them across contact and rigid-body dynamics codepaths to reduce autodiff discontinuities/instabilities, especially in float32.
Changes:
- Add new gradient-safe math utilities (
smooth_relu,safe_normalize,normalize_quaternion) and linear-algebra helpers (safe_inv,spd_solve, etc.) with custom JVPs. - Replace hard
maximum/manual normalization patterns in contact dynamics, terrain normals, and quaternion handling with the new primitives. - Improve numeric robustness in a few dynamics computations by swapping raw inverses/divisions with guarded versions.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
src/jaxsim/math/safe.py |
New gradient-safe primitives with custom JVPs (smooth_relu, safe_normalize, normalize_quaternion, safe_divide). |
src/jaxsim/math/linalg.py |
New linear-algebra helpers with custom JVPs (safe_inv, spd_solve, standard_solve, safe_lstsq). |
src/jaxsim/math/__init__.py |
Expose new math helpers at the package level. |
src/jaxsim/rbda/contacts/common.py |
Use smooth_relu for penetration depth to avoid gradient discontinuity at contact onset. |
src/jaxsim/rbda/contacts/soft.py |
Use smooth_relu for non-negative normal force and safe_normalize for tangential direction. |
src/jaxsim/rbda/contacts/relaxed_rigid.py |
Replace fragile divide/inv with safe_divide and safe_inv in regularization term. |
src/jaxsim/rbda/mass_inverse.py |
Replace a raw matrix inverse with safe_inv. |
src/jaxsim/terrain/terrain.py |
Switch terrain normal normalization to safe_normalize. |
src/jaxsim/math/rotation.py |
Use safe_normalize for axis-angle conversion axis handling. |
src/jaxsim/math/quaternion.py |
Normalize quaternions via normalize_quaternion before core ops. |
src/jaxsim/api/integrators.py |
Use normalize_quaternion in integrators instead of manual epsilon-guards. |
src/jaxsim/api/data.py |
Normalize stored/base quaternions via normalize_quaternion. |
src/jaxsim/__init__.py |
Fix .upper() usage and adjust 32-bit precision warning message. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
5dde61f to
41d25ae
Compare
41d25ae to
ed8f0f5
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR adds custom JVPs and safe math operations to speed up and improve the stability of autodiff, especially when using single-point precision. Moreover, it substitute the
jnp.maxin the contact detection with a ReLU function to improve the stability of contact dynamics.📚 Documentation preview 📚: https://jaxsim--505.org.readthedocs.build//505/