diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 878159e52..3f0d523fe 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -141,6 +141,7 @@ def build_from_model_description( considered_joints: Sequence[str] | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None, + parametrized_links: tuple[str, ...] | None = None, ) -> JaxSimModel: """ Build a Model object from a model description. @@ -170,6 +171,8 @@ def build_from_model_description( constraints: An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered. Note that constraints can be used only with RelaxedRigidContacts. + parametrized_links: + The optional list of links to be parametrized. If None, all links are parametrized. Returns: The built Model object. @@ -202,6 +205,7 @@ def build_from_model_description( integrator=integrator, gravity=-gravity, constraints=constraints, + parametrized_links=parametrized_links, ) # Store the origin of the model, in case downstream logic needs it. @@ -212,7 +216,9 @@ def build_from_model_description( # TODO: move the building of the metadata to KinDynParameters.build() # and use the model_description instead of model.built_from. with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - model.kin_dyn_parameters.hw_link_metadata = model.compute_hw_link_metadata() + model.kin_dyn_parameters.hw_link_metadata = model.compute_hw_link_metadata( + parametrized_links=parametrized_links + ) return model @@ -230,6 +236,7 @@ def build( integrator: IntegratorType | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None, + parametrized_links: tuple[str, ...] | None = None, ) -> JaxSimModel: """ Build a Model object from an intermediate model description. @@ -254,6 +261,8 @@ def build( gravity: The gravity constant. constraints: An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered. + parametrized_links: + The optional list of links to be parametrized. If None, all links are parametrized. Returns: The built Model object. @@ -320,10 +329,17 @@ def build( return model - def compute_hw_link_metadata(self) -> HwLinkMetadata: + def compute_hw_link_metadata( + self, parametrized_links: tuple[str, ...] | None = None + ) -> HwLinkMetadata: """ Compute the parametric metadata of the links in the model. + Args: + parametrized_links: + An optional tuple of link names to be parametrized. If None, + all links will be parametrized. + Returns: An instance of HwLinkMetadata containing the metadata of all links. """ @@ -370,19 +386,20 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata: L_H_pre_masks = [] L_H_pre = [] - # Process each link + # Process each link, only parametrizing those in parametrized_links if provided for link_description in ordered_links: link_name = link_description.name - if link_name not in self.link_names(): - logging.debug( - f"Skipping link '{link_name}' for hardware parametrization as it is not part of the JaxSim model." - ) - - if link_name not in rod_links_dict: - logging.debug( - f"Skipping link '{link_name}' for hardware parametrization as it is not part of the ROD model." - ) + if parametrized_links is not None and link_name not in parametrized_links: + # Mark as unsupported for non-parametrized links + shapes.append(LinkParametrizableShape.Unsupported) + geoms.append([0, 0, 0]) + densities.append(0.0) + L_H_Gs.append(jnp.eye(4)) + L_H_vises.append(jnp.eye(4)) + L_H_pre_masks.append([0] * self.number_of_joints()) + L_H_pre.append([jnp.eye(4)] * self.number_of_joints()) + continue rod_link = rod_links_dict.get(link_name) link_index = int(js.link.name_to_idx(model=self, link_name=link_name)) diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index 945715283..07b34db66 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -656,6 +656,27 @@ def test_unsupported_link_cases(): err_msg="Sphere radius must match the first visual", ) + # Test selective parametrization: only 'supported_link' and 'double_visual_link' should be parametrized + selective_model = js.model.JaxSimModel.build_from_model_description( + multi_link_urdf, is_urdf=True, parametrized_links=("double_visual_link") + ) + selective_metadata = selective_model.kin_dyn_parameters.hw_link_metadata + + # Check that only the selected links are parametrized + link_indices = {name: idx for idx, name in enumerate(selective_model.link_names())} + assert ( + selective_metadata.link_shape[link_indices["supported_link"]] + == LinkParametrizableShape.Unsupported + ), "Selected supported_link should be parametrized as Box" + assert ( + selective_metadata.link_shape[link_indices["double_visual_link"]] + == LinkParametrizableShape.Sphere + ), "Selected double_visual_link should be parametrized as Sphere" + assert ( + selective_metadata.link_shape[link_indices["unsupported_link"]] + == LinkParametrizableShape.Unsupported + ), "Non-selected unsupported_link should be marked as Unsupported" + def test_export_continuous_joint_handling(): """