Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def build_from_model_description(
considered_joints: Sequence[str] | None = None,
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None,
parametrized_links: tuple[str, ...] | None = None,
) -> JaxSimModel:
"""
Build a Model object from a model description.
Expand Down Expand Up @@ -170,6 +171,8 @@ def build_from_model_description(
constraints:
An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered.
Note that constraints can be used only with RelaxedRigidContacts.
parametrized_links:
The optional list of links to be parametrized. If None, all links are parametrized.

Returns:
The built Model object.
Expand Down Expand Up @@ -202,6 +205,7 @@ def build_from_model_description(
integrator=integrator,
gravity=-gravity,
constraints=constraints,
parametrized_links=parametrized_links,
)

# Store the origin of the model, in case downstream logic needs it.
Expand All @@ -212,7 +216,9 @@ def build_from_model_description(
# TODO: move the building of the metadata to KinDynParameters.build()
# and use the model_description instead of model.built_from.
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
model.kin_dyn_parameters.hw_link_metadata = model.compute_hw_link_metadata()
model.kin_dyn_parameters.hw_link_metadata = model.compute_hw_link_metadata(
parametrized_links=parametrized_links
)

return model

Expand All @@ -230,6 +236,7 @@ def build(
integrator: IntegratorType | None = None,
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None,
parametrized_links: tuple[str, ...] | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new parameter is not used inside this function. Are we missing to pass it somewhere here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sorry, I added this line with the intention of moving the instantiation of hardware link metadata in the build method rather than inside the build_from_model_description. If you agree I can add a commit for that in this PR

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes thanks!

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened an issue at https://github.com/ami-iit/element_differentiable_simulators_for_codesign/issues/30. We will address this in the future as it requires deeper investigation. I'll leave this discussion unresolved

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect thanks!

) -> JaxSimModel:
"""
Build a Model object from an intermediate model description.
Expand All @@ -254,6 +261,8 @@ def build(
gravity: The gravity constant.
constraints:
An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered.
parametrized_links:
The optional list of links to be parametrized. If None, all links are parametrized.

Returns:
The built Model object.
Expand Down Expand Up @@ -320,10 +329,17 @@ def build(

return model

def compute_hw_link_metadata(self) -> HwLinkMetadata:
def compute_hw_link_metadata(
self, parametrized_links: tuple[str, ...] | None = None
) -> HwLinkMetadata:
"""
Compute the parametric metadata of the links in the model.

Args:
parametrized_links:
An optional tuple of link names to be parametrized. If None,
all links will be parametrized.

Returns:
An instance of HwLinkMetadata containing the metadata of all links.
"""
Expand Down Expand Up @@ -370,19 +386,20 @@ def compute_hw_link_metadata(self) -> HwLinkMetadata:
L_H_pre_masks = []
L_H_pre = []

# Process each link
# Process each link, only parametrizing those in parametrized_links if provided
for link_description in ordered_links:
link_name = link_description.name

if link_name not in self.link_names():
logging.debug(
f"Skipping link '{link_name}' for hardware parametrization as it is not part of the JaxSim model."
)

if link_name not in rod_links_dict:
logging.debug(
f"Skipping link '{link_name}' for hardware parametrization as it is not part of the ROD model."
)
if parametrized_links is not None and link_name not in parametrized_links:
# Mark as unsupported for non-parametrized links
shapes.append(LinkParametrizableShape.Unsupported)
geoms.append([0, 0, 0])
densities.append(0.0)
L_H_Gs.append(jnp.eye(4))
L_H_vises.append(jnp.eye(4))
L_H_pre_masks.append([0] * self.number_of_joints())
L_H_pre.append([jnp.eye(4)] * self.number_of_joints())
continue

rod_link = rod_links_dict.get(link_name)
link_index = int(js.link.name_to_idx(model=self, link_name=link_name))
Expand Down
21 changes: 21 additions & 0 deletions tests/test_api_model_hw_parametrization.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,27 @@ def test_unsupported_link_cases():
err_msg="Sphere radius must match the first visual",
)

# Test selective parametrization: only 'supported_link' and 'double_visual_link' should be parametrized
selective_model = js.model.JaxSimModel.build_from_model_description(
multi_link_urdf, is_urdf=True, parametrized_links=("double_visual_link")
)
selective_metadata = selective_model.kin_dyn_parameters.hw_link_metadata

# Check that only the selected links are parametrized
link_indices = {name: idx for idx, name in enumerate(selective_model.link_names())}
assert (
selective_metadata.link_shape[link_indices["supported_link"]]
== LinkParametrizableShape.Unsupported
), "Selected supported_link should be parametrized as Box"
assert (
selective_metadata.link_shape[link_indices["double_visual_link"]]
== LinkParametrizableShape.Sphere
), "Selected double_visual_link should be parametrized as Sphere"
assert (
selective_metadata.link_shape[link_indices["unsupported_link"]]
== LinkParametrizableShape.Unsupported
), "Non-selected unsupported_link should be marked as Unsupported"


def test_export_continuous_joint_handling():
"""
Expand Down