Skip to content

Improve gradient stability for dynamics computations#505

Open
flferretti wants to merge 4 commits intomainfrom
gradient-stability
Open

Improve gradient stability for dynamics computations#505
flferretti wants to merge 4 commits intomainfrom
gradient-stability

Conversation

@flferretti
Copy link
Copy Markdown
Collaborator

@flferretti flferretti commented Mar 24, 2026

This PR adds custom JVPs and safe math operations to speed up and improve the stability of autodiff, especially when using single-point precision. Moreover, it substitute the jnp.max in the contact detection with a ReLU function to improve the stability of contact dynamics.


📚 Documentation preview 📚: https://jaxsim--505.org.readthedocs.build//505/

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces gradient-stable math primitives (custom JVPs + “safe” operations) and applies them across contact and rigid-body dynamics codepaths to reduce autodiff discontinuities/instabilities, especially in float32.

Changes:

  • Add new gradient-safe math utilities (smooth_relu, safe_normalize, normalize_quaternion) and linear-algebra helpers (safe_inv, spd_solve, etc.) with custom JVPs.
  • Replace hard maximum/manual normalization patterns in contact dynamics, terrain normals, and quaternion handling with the new primitives.
  • Improve numeric robustness in a few dynamics computations by swapping raw inverses/divisions with guarded versions.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
src/jaxsim/math/safe.py New gradient-safe primitives with custom JVPs (smooth_relu, safe_normalize, normalize_quaternion, safe_divide).
src/jaxsim/math/linalg.py New linear-algebra helpers with custom JVPs (safe_inv, spd_solve, standard_solve, safe_lstsq).
src/jaxsim/math/__init__.py Expose new math helpers at the package level.
src/jaxsim/rbda/contacts/common.py Use smooth_relu for penetration depth to avoid gradient discontinuity at contact onset.
src/jaxsim/rbda/contacts/soft.py Use smooth_relu for non-negative normal force and safe_normalize for tangential direction.
src/jaxsim/rbda/contacts/relaxed_rigid.py Replace fragile divide/inv with safe_divide and safe_inv in regularization term.
src/jaxsim/rbda/mass_inverse.py Replace a raw matrix inverse with safe_inv.
src/jaxsim/terrain/terrain.py Switch terrain normal normalization to safe_normalize.
src/jaxsim/math/rotation.py Use safe_normalize for axis-angle conversion axis handling.
src/jaxsim/math/quaternion.py Normalize quaternions via normalize_quaternion before core ops.
src/jaxsim/api/integrators.py Use normalize_quaternion in integrators instead of manual epsilon-guards.
src/jaxsim/api/data.py Normalize stored/base quaternions via normalize_quaternion.
src/jaxsim/__init__.py Fix .upper() usage and adjust 32-bit precision warning message.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants