diff --git a/environment.yml b/environment.yml index 1af32f750..57d41c004 100644 --- a/environment.yml +++ b/environment.yml @@ -11,7 +11,7 @@ dependencies: - jaxlib >= 0.4.26 - jaxlie >= 1.3.0 - jax-dataclasses >= 1.4.0 - - optax == 0.2.3 + - optax >= 0.2.3 - pptree - qpax - rod >= 0.3.3 diff --git a/pyproject.toml b/pyproject.toml index e7d0f5f1d..d2740bf68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "jaxlie >= 1.3.0", "jax_dataclasses >= 1.4.0", "pptree", - "optax == 0.2.3", + "optax >= 0.2.3", "qpax", "rod >= 0.3.3", "typing_extensions ; python_version < '3.12'", diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index f5610bf73..5e20ada54 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -182,10 +182,10 @@ class RelaxedRigidContacts(common.ContactModel): """Relaxed rigid contacts model.""" _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field( - default=("tol", "maxiter", "memory_size"), kw_only=True + default=("tol", "maxiter", "memory_size", "scale_init_precond"), kw_only=True ) _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field( - default=(1e-6, 50, 10), kw_only=True + default=(1e-6, 50, 10, False), kw_only=True ) @property