diff --git a/.gitignore b/.gitignore index 2440c4d..67d7537 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,9 @@ dmypy.json # pixi environments .pixi *.egg-info + +# Hatch-VCS version file +src/rod/_version.py + +# VS Code settings +.vscode/ diff --git a/README.md b/README.md index b8342eb..389ddcf 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ This will automatically install `sdformat` and `gz-tools`.
Using pixi -[`pixi`](https://pixi.sh) definetly provides the quickest way to start using ROD. You can run the tests by executing: +[`pixi`](https://pixi.sh) definitely provides the quickest way to start using ROD. You can run the tests by executing: ```bash pixi run test diff --git a/environment.yml b/environment.yml index bcbfa9a..c2b5b78 100644 --- a/environment.yml +++ b/environment.yml @@ -2,6 +2,7 @@ name: rod channels: - conda-forge dependencies: + # Core runtime dependencies - coloredlogs - mashumaro - numpy @@ -10,11 +11,22 @@ dependencies: - scipy - trimesh - xmltodict + + # Development and testing - black - isort + - pytest + - pytest-icdiff + + # Optional dependencies - pptree - idyntree - - pytest - robot_descriptions + + # Gazebo/SDF processing - libgz-tools2 - libsdformat13 + + # Python packaging + - hatchling + - hatch-vcs diff --git a/pixi.lock b/pixi.lock index 6f9010f..cff631e 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ad5c217ec910b6fb75d18e454db364169b66ecc039e8aa706a75ad269661e1df -size 100593 +oid sha256:1aaa0f98fa9b782f244f6cc6c3e0daab5b34a0f387430b620b5104e4e01ecfcd +size 92591 diff --git a/pyproject.toml b/pyproject.toml index a6308b0..187aa0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,9 @@ dependencies = [ "xmltodict", ] +[project.scripts] +rod = "rod.__main__:main" + [project.optional-dependencies] style = [ "black ~= 24.0", @@ -87,18 +90,19 @@ Tracker = "https://github.com/ami-iit/rod/issues" # =========== [build-system] -build-backend = "setuptools.build_meta" +build-backend = "hatchling.build" requires = [ - "setuptools>=64", - "setuptools-scm[toml]>=8", - "wheel", + "hatchling", "hatch-vcs" ] -[tool.setuptools] -package-dir = { "" = "src" } +[tool.hatch.version] +source = "vcs" + +[tool.hatch.build.hooks.vcs] +version-file = "src/rod/_version.py" -[tool.setuptools_scm] -local_scheme = "dirty-tag" +[tool.hatch.build.targets.wheel] +packages = ["src/rod"] # ================= # Style and testing @@ -158,7 +162,7 @@ ignore = [ "E731", # Do not assign a `lambda` expression, use a `def` "E741", # Ambiguous variable name "I001", # Import block is unsorted or unformatted - "RUF003", # Ambigous unicode character in comment + "RUF003", # Ambiguous unicode character in comment ] [tool.ruff.lint.per-file-ignores] diff --git a/setup.py b/setup.py deleted file mode 100644 index b908cbe..0000000 --- a/setup.py +++ /dev/null @@ -1,3 +0,0 @@ -import setuptools - -setuptools.setup() diff --git a/src/rod/__init__.py b/src/rod/__init__.py index 0291906..07bb996 100644 --- a/src/rod/__init__.py +++ b/src/rod/__init__.py @@ -135,3 +135,8 @@ def check_compatible_sdformat(specification_version: str) -> None: check_compatible_sdformat(specification_version="1.10") del check_compatible_sdformat + +try: + from rod._version import __version__ +except ImportError: + __version__ = "unknown" diff --git a/src/rod/__main__.py b/src/rod/__main__.py index 5cece3f..5786420 100644 --- a/src/rod/__main__.py +++ b/src/rod/__main__.py @@ -1,19 +1,36 @@ import argparse -import importlib.metadata import logging import sys +from typing import NoReturn from rod import logging as rodlogging +try: + from rod._version import __version__ +except ImportError: + # Fallback for development installations + try: + import importlib.metadata -def main() -> None: + __version__ = importlib.metadata.version("rod") + except importlib.metadata.PackageNotFoundError: + __version__ = "unknown" + + +def main() -> NoReturn: """ Main function of the ROD command line interface. """ parser = argparse.ArgumentParser( prog="rod", - description="ROD: The ultimate Python tool for RObot Descriptions processing.", - usage="%(prog)s [options] file", + description="ROD: The ultimate Python tool for RObot Descriptions processing.\n" + "Load, parse, convert, and manipulate robot description files (SDF/URDF).", + epilog="Examples:\n" + " rod -f robot.urdf --show # Display robot model structure\n" + " rod -f robot.sdf -o robot.urdf # Convert SDF to URDF\n" + " rod -f robot.urdf -o robot.sdf # Convert URDF to SDF\n" + " rod --version # Show version information", + formatter_class=argparse.RawDescriptionHelpFormatter, ) # Version. @@ -21,7 +38,8 @@ def main() -> None: "-V", "--version", action="version", - version=f"%(prog)s {importlib.metadata.version('rod')}", + version=f"%(prog)s {__version__}", + help="Show version information and exit.", ) # Verbose output. @@ -29,7 +47,7 @@ def main() -> None: "-vv", "--verbose", action="store_true", - help="enable verbose output.", + help="Enable verbose output with detailed logging information.", ) # File to parse. @@ -37,7 +55,8 @@ def main() -> None: "-f", "--file", type=str, - help="path to the file to parse.", + metavar="PATH", + help="Path to the robot description file to parse (supports .sdf, .urdf).", required=False, ) @@ -46,7 +65,7 @@ def main() -> None: "-s", "--show", action="store_true", - help="show the robot model attributes.", + help="Display the parsed robot model structure and properties in a human-readable format.", ) # Option to output a URDF or SDF file. @@ -54,13 +73,13 @@ def main() -> None: "-o", "--output", type=str, - help="Output file path.", + metavar="PATH", + help="Output file path for format conversion (.urdf or .sdf extension required).", ) args = parser.parse_args() log_level = logging.DEBUG if args.verbose else logging.INFO - logging.basicConfig(level=log_level) from rod.urdf.exporter import UrdfExporter @@ -110,6 +129,9 @@ def main() -> None: rodlogging.exception(f"Error writing output file: {e}") sys.exit(1) + # Exit successfully if we reach here + sys.exit(0) + if __name__ == "__main__": main() diff --git a/src/rod/kinematics/kinematic_tree.py b/src/rod/kinematics/kinematic_tree.py index abbe68f..f4c887a 100644 --- a/src/rod/kinematics/kinematic_tree.py +++ b/src/rod/kinematics/kinematic_tree.py @@ -3,7 +3,6 @@ import copy import dataclasses import functools -from collections.abc import Sequence import numpy as np @@ -15,29 +14,26 @@ @dataclasses.dataclass(frozen=True) class KinematicTree(DirectedTree): model: rod.Model - joints: list[TreeEdge] = dataclasses.field(default_factory=list) frames: list[TreeFrame] = dataclasses.field(default_factory=list) def __post_init__(self): - # Initialize base class super().__post_init__() + self._assign_indices() + + def _assign_indices(self): + """Assign indices to frames and joints for fast access.""" - # Ge the index of the last link - last_node_idx = list(iter(self))[-1].index + last_node_idx = len(self) - 1 # Assign frames indices. The frames indexing continues the link indexing. - for frame_idx, node in enumerate(self.frames): - node.index = last_node_idx + 1 + frame_idx + for idx, frame in enumerate(self.frames): + frame.index = last_node_idx + 1 + idx # Assign joint indices. The joint index matches the index of its child link. for joint in self.joints: joint.index = self.links_dict[joint.child.name()].index - # Order the lists containing joints and frames - self.joints.sort(key=lambda j: j.index) - self.frames.sort(key=lambda f: f.index) - def link_names(self) -> list[str]: return [node.name() for node in self] @@ -49,262 +45,116 @@ def joint_names(self) -> list[str]: @staticmethod def build(model: rod.Model, is_top_level: bool = True) -> KinematicTree: - logging.debug(msg=f"Building kinematic tree of model '{model.name}'") - - if model.model is not None: - msg = "Model composition is not yet supported. Ignoring all sub-models." - logging.warning(msg=msg) - - # Copy the model and make reference frames explicit, i.e. add the pose element - # to all models, links, joints, frames. - # In this build method, we don't require any specific FrameConvention since - # converting a tree to a new convention would need to build the tree first. - model = copy.deepcopy(model) - model.resolve_frames(is_top_level=is_top_level, explicit_frames=True) - # Generally speaking, a rod.Model describes a DAG (directed acyclic graph). - # Since we do not yet support closed loops / parallel kinematic structures, - # we can restrict the family of supported rod.Models to those describing - # a tree, i.e. a DAG whose nodes have a single parent. - # The links of the model are the nodes of the tree, and joints its edges. - # The root of the tree is defined by the canonical link of the model. - # The model could also define additional elements called frames that are - # pseudo-nodes attached to either a node or another pseudo-node (another frame). - - # In our tree, links are the nodes and joints the edges. - # Create a dict mapping link names to tree nodes, for easy retrieval. - nodes_links_dict: dict[str, DirectedTreeNode] = { - # Add one node for each link of the model - **{link.name: DirectedTreeNode(_source=link) for link in model.links()}, - # Add special world node, that will become a frame later - TreeFrame.WORLD: DirectedTreeNode( - _source=rod.Link( - name=TreeFrame.WORLD, - pose=rod.Pose(relative_to=TreeFrame.WORLD), - ) - ), - } + logging.debug(f"Building kinematic tree for model '{model.name}'") + model = KinematicTree._prepare_model(model, is_top_level) + nodes_links_dict, nodes_frames_dict = KinematicTree._create_nodes(model) + edges_dict = KinematicTree._create_edges(model, nodes_links_dict) - # Get the canonical link of the model. - # The canonical link defines the implicit coordinate frame of the model, - # and by default the implicit __model__ frame is attached to it. - # Note: selecting the wrong canonical link would produce an invalid tree having - # unconnected links and joints, situation that would raise an error. - # Note: after building the tree from the rod.Model, it will be possible to change - # the canonical link (also known as base link), operation that performs a - # tree re-balancing that would produce a new model having its __model__ - # frame not attached to its canonical link. - root_node_name = model.get_canonical_link() - logging.debug(msg=f"Selecting '{root_node_name}' as canonical link") - assert root_node_name in nodes_links_dict, root_node_name - - # Furthermore, existing frames are extra elements that could be optionally - # attached to the kinematic tree (but by default they're not part of it). - # Create a dict mapping frame names to frame nodes, for easy retrieval. - nodes_frames_dict: dict[str, TreeFrame] = { - # Add a frame node for each frame in the model - **{frame.name: TreeFrame(_source=frame) for frame in model.frames()}, - # Add implicit frames used in the SDF specification (__model__). - # The following frames are attached to the first link found in the model - # description and never moved, so that all elements expressing their pose - # w.r.t. these frames always remain valid. - TreeFrame.MODEL: TreeFrame( - _source=rod.Frame( - name=TreeFrame.MODEL, - attached_to=root_node_name, - pose=model.pose, - ), - ), - } + return KinematicTree( + root=nodes_links_dict[model.get_canonical_link()], + joints=list(edges_dict.values()), + frames=list(nodes_frames_dict.values()), + model=model, + ) - # Check that links and frames have unique names - assert len( - set(list(nodes_links_dict.keys()) + list(nodes_frames_dict.keys())) - ) == (len(nodes_links_dict) + len(nodes_frames_dict)) + @staticmethod + def _prepare_model(model: rod.Model, is_top_level: bool) -> rod.Model: + """Prepare the model by resolving frames and ensuring tree structure.""" - # Use joints to connect nodes by defining their parent and children - for joint in model.joints(): - if joint.child == TreeFrame.WORLD: - msg = f"A joint cannot have '{TreeFrame.WORLD}' as child" - raise RuntimeError(msg) + model = copy.deepcopy(model) + model.resolve_frames(is_top_level=is_top_level, explicit_frames=True) - # Get the parent and child nodes of the joint - child_node = nodes_links_dict[joint.child] - parent_node = nodes_links_dict[joint.parent] + if model.model: + logging.warning("Model composition not supported. Ignoring sub-models.") + model.model = None - # Check that the dict is correct - assert child_node.name() == joint.child, (child_node.name(), joint.child) - assert parent_node.name() == joint.parent, ( - parent_node.name(), - joint.parent, - ) + return model - # Assign to each child node their parent - child_node.parent = parent_node + @staticmethod + def _create_nodes( + model: rod.Model, + ) -> tuple[dict[str, DirectedTreeNode], dict[str, TreeFrame]]: + """Create nodes for links and frames.""" - # Assign to each node their children, and make sure they are unique - if child_node.name() not in {n.name() for n in parent_node.children}: - parent_node.children.append(child_node) - - # Compute the tree traversal with BFS algorithm. - # If the model is fixed-base, the world node is not part of the tree and the - # joint connecting to world will be removed. - all_node_names_in_tree = [ - n.name() - for n in list( - KinematicTree.breadth_first_search( - root=nodes_links_dict[root_node_name] - ) - ) - ] - - # Get all the joints part of the kinematic tree ... - joints_in_tree_names = [ - j.name - for j in model.joints() - if {j.parent, j.child}.issubset(all_node_names_in_tree) - ] - joints_in_tree = [j for j in model.joints() if j.name in joints_in_tree_names] - - # ... and those that are not - joints_not_in_tree = [ - j for j in model.joints() if j.name not in joints_in_tree_names - ] - - # A valid rod.Model does not have any dangling link and any unconnected joints. - # Here we check that the rod.Model contains a valid tree representation. - found_num_extra_joints = len(joints_not_in_tree) - expected_num_extra_joints = 1 if model.is_fixed_base() else 0 - - if found_num_extra_joints != expected_num_extra_joints: - if model.is_fixed_base() and found_num_extra_joints == 0: - raise RuntimeError("Failed to find joint connecting the model to world") - - unexpected_joint_names = [j.name for j in joints_not_in_tree] - raise RuntimeError(f"Found unexpected joints: {unexpected_joint_names}") - - # Handle connection to world of fixed-base models - if model.is_fixed_base(): - assert len(joints_not_in_tree) == 1 - world_to_base_joint = joints_not_in_tree[0] - - # Create a temporary edge so that we can reuse the logic implemented for - # the link lumping process - world_to_base_edge = TreeEdge( - parent=nodes_links_dict[world_to_base_joint.parent], - child=nodes_links_dict[world_to_base_joint.child], - _source=world_to_base_joint, + nodes_links_dict = { + link.name: DirectedTreeNode(_source=link) for link in model.links() + } + nodes_links_dict[TreeFrame.WORLD] = DirectedTreeNode( + _source=rod.Link( + name=TreeFrame.WORLD, pose=rod.Pose(relative_to=TreeFrame.WORLD) ) + ) + nodes_frames_dict = { + frame.name: TreeFrame(_source=frame) for frame in model.frames() + } - # Produce new nodes and frames by removing the edge connecting base to world. - # One of the additional frame will be the world frame. - new_base_node, additional_frames = KinematicTree.remove_edge( - edge=world_to_base_edge, keep_parent=False + nodes_frames_dict[TreeFrame.MODEL] = TreeFrame( + _source=rod.Frame( + name=TreeFrame.MODEL, + attached_to=model.get_canonical_link(), + pose=model.pose, ) - assert any(f.name() == TreeFrame.WORLD for f in additional_frames) + ) - # Replace the former base node with the new base node - nodes_links_dict[new_base_node.name()] = new_base_node + return nodes_links_dict, nodes_frames_dict - # Add all the additional frames created by the edge removal process - nodes_frames_dict = { - **nodes_frames_dict, - **{f.name(): f for f in additional_frames}, - } + @staticmethod + def _create_edges( + model: rod.Model, nodes_links_dict: dict[str, DirectedTreeNode] + ) -> dict[str, TreeEdge]: + """Create edges (joints) connecting nodes.""" - # Remove the world node from the nodes dictionary since it was - # converted to frame and already added to the frames dictionary - world_node = nodes_links_dict.pop(TreeFrame.WORLD) - assert world_node is not None + edges_dict = {} - else: - # Remove the world node from the nodes dictionary since it's unconnected... - world_node = nodes_links_dict.pop(TreeFrame.WORLD) + for joint in model.joints(): - # ... and add it as an explicit frame attached to the root node - nodes_frames_dict[world_node.name()] = TreeFrame.from_node( - node=world_node, attached_to=nodes_links_dict[root_node_name] - ) + if joint.child == TreeFrame.WORLD: + raise RuntimeError(f"Joint cannot have '{TreeFrame.WORLD}' as child") - # Create an edge for all joints - edges_dict = { - joint.name: TreeEdge( - parent=nodes_links_dict[joint.parent], - child=nodes_links_dict[joint.child], - _source=joint, + parent_node = nodes_links_dict[joint.parent] + child_node = nodes_links_dict[joint.child] + child_node.parent = parent_node + parent_node.children.append(child_node) + edges_dict[joint.name] = TreeEdge( + parent=parent_node, child=child_node, _source=joint ) - for joint in joints_in_tree - } - - # Build the tree, it assigns indices upon construction - tree = KinematicTree( - root=nodes_links_dict[root_node_name], - joints=list(edges_dict.values()), - frames=list(nodes_frames_dict.values()), - model=model, - ) - return tree + return edges_dict @staticmethod def remove_edge( edge: TreeEdge, keep_parent: bool = True - ) -> tuple[DirectedTreeNode, Sequence[TreeFrame]]: - # Removed node: the node to remove. - # Replaced node: the node removed and replaced with the new node. - # New node: the new node that combines the removed and replaced nodes. - - if keep_parent: - # Lump child into parent. - # This is the default behavior. - removed_node = edge.child - replaced_node = edge.parent - new_node = dataclasses.replace(replaced_node) - - else: - # Lump parent into child. - # Can be useful when lumping the special world node into the base. - removed_node = edge.parent - replaced_node = edge.child - new_node = dataclasses.replace(replaced_node, parent=removed_node.parent) - - # Convert the removed edge to frame - removed_edge_as_frame = TreeFrame.from_edge(edge=edge, attached_to=new_node) + ) -> tuple[DirectedTreeNode, list[TreeFrame]]: + """Remove an edge and lump inertial properties.""" - # Convert the removed node as frame + removed_node, replaced_node = ( + (edge.child, edge.parent) if keep_parent else (edge.parent, edge.child) + ) + new_node = dataclasses.replace( + replaced_node, + parent=replaced_node.parent if keep_parent else removed_node.parent, + ) + removed_edge_as_frame = TreeFrame.from_edge(edge=edge, attached_to=new_node) removed_node_as_frame = TreeFrame.from_node( node=removed_node, attached_to=removed_edge_as_frame ) - - # Create a list with all new frames resulting from the edge removal process new_frames = [removed_node_as_frame, removed_edge_as_frame] - # Check if a link has non-trivial inertial parameters - def has_zero_inertial(link: rod.Link) -> bool: - if not isinstance(link, rod.Link): - return True - - if link.inertial is None: - return True - - return np.allclose(link.inertial.mass, 0.0) and np.allclose( - link.inertial.inertia, np.zeros(shape=(3, 3)) - ) - - # The new node has the same inertial parameters of the removed node if the - # removed node has zero inertial parameters. - # In this case, the new node is equivalent to the removed one and its name can - # be the same. We return the new node and the two new frames (the removed node - # -either parent or child- and the removed edge). - if has_zero_inertial(link=removed_node._source): + if KinematicTree._has_zero_inertial(removed_node._source): return new_node, new_frames - # ======================== - # Lump inertial properties - # ======================== - raise NotImplementedError("Inertial parameters lumping") + @staticmethod + def _has_zero_inertial(link: rod.Link) -> bool: + """Check if a link has zero inertial parameters.""" + if not isinstance(link, rod.Link) or link.inertial is None: + return True + return np.allclose(link.inertial.mass, 0.0) and np.allclose( + link.inertial.inertia, np.zeros((3, 3)) + ) + @functools.cached_property def links_dict(self) -> dict[str, DirectedTreeNode]: return self.nodes_dict diff --git a/src/rod/kinematics/tree_transforms.py b/src/rod/kinematics/tree_transforms.py index 49b9ef7..a4ff8d7 100644 --- a/src/rod/kinematics/tree_transforms.py +++ b/src/rod/kinematics/tree_transforms.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import dataclasses import numpy as np @@ -13,110 +12,102 @@ @dataclasses.dataclass class TreeTransforms: - kinematic_tree: KinematicTree = dataclasses.field( - default_factory=dataclasses.dataclass(init=False) - ) - _transform_cache: dict[str, npt.NDArray] = dataclasses.field(default_factory=dict) - @staticmethod - def build( - model: rod.Model, - is_top_level: bool = True, - ) -> TreeTransforms: - model = copy.deepcopy(model) + kinematic_tree: KinematicTree + _cache: dict[str, npt.NDArray] = dataclasses.field(default_factory=dict, init=False) + + @classmethod + def from_model(cls, model: rod.Model, is_top_level: bool = True) -> TreeTransforms: + return cls(KinematicTree.build(model=model, is_top_level=is_top_level)) - # Make sure that all elements have a pose attribute with explicit 'relative_to'. - model.resolve_frames(is_top_level=is_top_level, explicit_frames=True) + @staticmethod + def build(model: rod.Model, is_top_level: bool = True) -> TreeTransforms: + kinematic_tree = KinematicTree.build(model=model, is_top_level=is_top_level) # Build the kinematic tree and return the TreeTransforms object. - return TreeTransforms( - kinematic_tree=KinematicTree.build(model=model, is_top_level=is_top_level) - ) + return TreeTransforms(kinematic_tree=kinematic_tree) def transform(self, name: str) -> npt.NDArray: - if name in self._transform_cache: - return self._transform_cache[name] - - self._transform_cache[name] = self._compute_transform(name=name) - return self._transform_cache[name] - - def _compute_transform(self, name: str) -> npt.NDArray: - match name: - case TreeFrame.WORLD: - - return np.eye(4) - - case name if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}: - - relative_to = self.kinematic_tree.model.pose.relative_to - assert relative_to in {None, ""}, (relative_to, name) - return self.kinematic_tree.model.pose.transform() - - case name if name in self.kinematic_tree.joint_names(): - - edge = self.kinematic_tree.joints_dict[name] - assert edge.name() == name - - # Get the pose of the frame in which the node's pose is expressed - assert edge._source.pose.relative_to not in {"", None} - x_H_E = edge._source.pose.transform() - W_H_x = self.transform(name=edge._source.pose.relative_to) - - # Compute the world-to-node transform - # TODO: this assumes all joint positions to be 0 - W_H_E = W_H_x @ x_H_E - - return W_H_E - - case name if name in self.kinematic_tree.link_names(): - - element = self.kinematic_tree.links_dict[name] - - assert element.name() == name - assert element._source.pose.relative_to not in {"", None} - - # Get the pose of the frame in which the link's pose is expressed. - x_H_L = element._source.pose.transform() - W_H_x = self.transform(name=element._source.pose.relative_to) - - # Compute the world transform of the link. - W_H_L = W_H_x @ x_H_L - return W_H_L - - case name if name in self.kinematic_tree.frame_names(): - - element = self.kinematic_tree.frames_dict[name] - - assert element.name() == name - assert element._source.pose.relative_to not in {"", None} - - # Get the pose of the frame in which the frame's pose is expressed. - x_H_F = element._source.pose.transform() - W_H_x = self.transform(name=element._source.pose.relative_to) - - # Compute the world transform of the frame. - W_H_F = W_H_x @ x_H_F - return W_H_F - - case _: - raise ValueError(name) - - def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: - - world_H_name = self.transform(name=name) - world_H_relative_to = self.transform(name=relative_to) - - return TreeTransforms.inverse(world_H_relative_to) @ world_H_name + """Get world transform, computing and caching path to root as needed.""" + if name in self._cache: + return self._cache[name] + + # Build path from name to root, stopping at first cached element + path = [] + current = name + while current and current not in self._cache: + path.append(current) + current = self._get_parent(current) + + # Compute transforms from root down + base_transform = self._cache.get(current, np.eye(4)) + for element in reversed(path): + base_transform = base_transform @ self._get_local_transform(element) + self._cache[element] = base_transform + + return self._cache[name] + + def relative_transform(self, from_frame: str, to_frame: str) -> npt.NDArray: + """Transform from one frame to another.""" + return self.inverse(self.transform(from_frame)) @ self.transform(to_frame) + + def invalidate(self, name: str) -> None: + """Remove cached transform and all dependents.""" + to_remove = {key for key in self._cache if self._depends_on(key, name)} + for key in to_remove: + del self._cache[key] + + def clear_cache(self) -> None: + self._cache.clear() + + def _get_parent(self, name: str) -> str | None: + """Get parent frame for any element.""" + if name == TreeFrame.WORLD: + return None + + if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}: + parent = self.kinematic_tree.model.pose.relative_to + return TreeFrame.WORLD if parent in {None, ""} else parent + + # Search through all element types + for element_dict in [ + self.kinematic_tree.joints_dict, + self.kinematic_tree.links_dict, + self.kinematic_tree.frames_dict, + ]: + if name in element_dict: + parent = element_dict[name]._source.pose.relative_to + return TreeFrame.WORLD if parent in {"", None} else parent + + raise ValueError(f"Unknown element: {name}") + + def _get_local_transform(self, name: str) -> npt.NDArray: + """Get local transform for any element.""" + if name == TreeFrame.WORLD: + return np.eye(4) + + if name in {TreeFrame.MODEL, self.kinematic_tree.model.name}: + return self.kinematic_tree.model.pose.transform() + + # Search through all element types + for element_dict in [ + self.kinematic_tree.joints_dict, + self.kinematic_tree.links_dict, + self.kinematic_tree.frames_dict, + ]: + if name in element_dict: + return element_dict[name]._source.pose.transform() + + raise ValueError(f"Unknown element: {name}") + + def _depends_on(self, child: str, ancestor: str) -> bool: + """Check if child depends on ancestor in transform chain.""" + current = child + while current and current != ancestor: + current = self._get_parent(current) + return current == ancestor @staticmethod - def inverse(transform: npt.NDArray) -> npt.NDArray: - - R = transform[0:3, 0:3] - p = np.vstack(transform[0:3, 3]) - - return np.block( - [ - [R.T, -R.T @ p], - [0, 0, 0, 1], - ] - ) + def inverse(T: npt.NDArray) -> npt.NDArray: + R, p = T[:3, :3], T[:3, 3:4] + return np.block([[R.T, -R.T @ p], [np.zeros((1, 3)), np.ones((1, 1))]]) diff --git a/src/rod/sdf/common.py b/src/rod/sdf/common.py index 90b7132..f751536 100644 --- a/src/rod/sdf/common.py +++ b/src/rod/sdf/common.py @@ -4,7 +4,9 @@ from typing import Any import mashumaro +import numpy as np import numpy.typing as npt +from scipy.spatial.transform import Rotation as R from .element import Element @@ -71,8 +73,6 @@ def rpy(self) -> list[float]: return self.pose[3:6] def transform(self) -> npt.NDArray: - import numpy as np - from scipy.spatial.transform import Rotation as R # Transform Euler angles to DCM matrix. # The rpy sequence included in URDF and SDF implements the x-y-z Tait-Bryan diff --git a/src/rod/tree/directed_tree.py b/src/rod/tree/directed_tree.py index 423d718..f295abc 100644 --- a/src/rod/tree/directed_tree.py +++ b/src/rod/tree/directed_tree.py @@ -1,5 +1,6 @@ import dataclasses import functools +from collections import deque from collections.abc import Callable, Iterable, Sequence from typing import Any @@ -35,17 +36,13 @@ def breadth_first_search( root: DirectedTreeNode, sort_children: Callable[[Any], Any] | None = lambda node: node.name(), ) -> Iterable[DirectedTreeNode]: - queue = [root] - - # We assume that nodes have a unique name, and we mark a node as visited by - # storing its name. This assumption speeds up considerably object comparison. - visited = [] - visited.append(root.name) + queue = deque([root]) + visited = {root.name} yield root - while len(queue) > 0: - node = queue.pop(0) + while queue: + node = queue.popleft() # Note: sorting the nodes with their name so that the order of children # insertion does not matter when assigning the node index @@ -53,7 +50,7 @@ def breadth_first_search( if child.name in visited: continue - visited.append(child.name) + visited.add(child.name) queue.append(child) yield child @@ -71,23 +68,18 @@ def pretty_print(self) -> None: def __getitem__( self, key: int | slice | str ) -> DirectedTreeNode | list[DirectedTreeNode]: - # Get the nodes' dictionary (already inserted in order following BFS) - nodes_dict = self.nodes_dict - if isinstance(key, str): - if key not in nodes_dict.keys(): + if key not in self.nodes_dict: raise KeyError(key) - - return nodes_dict[key] + return self.nodes_dict[key] if isinstance(key, int): - if key > len(nodes_dict): + if key >= len(self): raise IndexError(key) - - return list(nodes_dict.values())[key] + return self.nodes[key] if isinstance(key, slice): - return list(nodes_dict.values())[key] + return self.nodes[key] raise TypeError(type(key).__name__) @@ -102,9 +94,9 @@ def __reversed__(self) -> Iterable[DirectedTreeNode]: def __contains__(self, item: str | DirectedTreeNode) -> bool: if isinstance(item, str): - return item in self.nodes_dict.keys() + return item in self.nodes_dict if isinstance(item, DirectedTreeNode): - return item.name() in self.nodes_dict.keys() + return item.name() in self.nodes_dict raise TypeError(type(item).__name__) diff --git a/src/rod/urdf/exporter.py b/src/rod/urdf/exporter.py index d7783c8..1600f4d 100644 --- a/src/rod/urdf/exporter.py +++ b/src/rod/urdf/exporter.py @@ -1,442 +1,297 @@ +from __future__ import annotations + import abc import copy import dataclasses -from typing import Any, ClassVar +import logging +from typing import ( + Any, + ClassVar, + TypeAlias, + TypeVar, +) import numpy as np import xmltodict import rod -from rod import logging +# Type aliases for better readability +FramesList: TypeAlias = list[dict[str, Any]] +JointsList: TypeAlias = list[dict[str, Any]] +PreserveJointsOption: TypeAlias = bool | list[str] +T = TypeVar("T") # Generic type for optional handling + +_ZERO_POSE = np.zeros(6) -@dataclasses.dataclass -class UrdfExporter(abc.ABC): - """Resources to convert an in-memory ROD model to URDF.""" - # The string to use for each indentation level. - indent: str = " " +def _fmt(values) -> str: + """Format a sequence of numbers as a space-separated string.""" + return " ".join(str(v) for v in values) - # Whether to include indentation and newlines in the output. - pretty: bool = False - # Whether to inject additional `` elements in the resulting URDF - # to preserve fixed joints in case of re-loading into sdformat. - # If a list of strings is passed, only the listed fixed joints will be preserved. - gazebo_preserve_fixed_joints: bool | list[str] = False +@dataclasses.dataclass +class UrdfExporter(abc.ABC): + """Resources to convert an in-memory ROD model to URDF with elegant Pythonic patterns.""" + + indent: str = " " # String to use for each indentation level + pretty: bool = False # Whether to include indentation and newlines + gazebo_preserve_fixed_joints: PreserveJointsOption = False # Joints to preserve - SupportedSdfJointTypes: ClassVar[set[str]] = { + # Class constants + SUPPORTED_JOINT_TYPES: ClassVar[set[str]] = { "revolute", "continuous", "prismatic", "fixed", } - DefaultMaterial: ClassVar[dict[str, Any]] = { + DEFAULT_MATERIAL: ClassVar[dict[str, Any]] = { "@name": "default_material", "color": { - "@rgba": " ".join(np.array([1, 1, 1, 1], dtype=str)), + "@rgba": _fmt([1, 1, 1, 1]), }, } - @staticmethod + def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str: + """Convert an in-memory SDF model to a URDF string. + + Args: + sdf: The SDF model parsed by ROD to convert. + + Returns: + The URDF string representing the converted SDF model. + """ + # Work with a copy to avoid modifying the original + sdf = copy.deepcopy(sdf) + + # Get the model (handle both Sdf and Model types) + model = self._extract_model(sdf) + logging.debug(f"Converting model '{model.name}' to URDF") + + # Prepare the model + self._prepare_model_for_conversion(model) + + # Process frames and get extra elements + extra_links, extra_joints = self._process_frames(model) + + # Handle fixed joints preservation + preserved_joints = self._get_preserved_fixed_joints(model) + + # Build and return the URDF + return self._build_urdf_string( + model, extra_links, extra_joints, preserved_joints + ) + + @classmethod def sdf_to_urdf_string( + cls, sdf: rod.Sdf | rod.Model, pretty: bool = False, indent: str = " ", - gazebo_preserve_fixed_joints: bool | list[str] = False, + gazebo_preserve_fixed_joints: PreserveJointsOption = False, ) -> str: + """Legacy method maintained for backward compatibility.""" - msg = "This method is deprecated, please use '{}' instead." - logging.warning(msg.format("UrdfExporter.to_urdf_string")) + logging.warning( + "This method is deprecated, please use 'UrdfExporter.to_urdf_string' instead." + ) - return UrdfExporter( + exporter = cls( pretty=pretty, indent=indent, gazebo_preserve_fixed_joints=gazebo_preserve_fixed_joints, - ).to_urdf_string(sdf=sdf) - - @staticmethod - def _get_urdf_joint_type(joint: rod.Joint) -> str: - """ - Get the URDF joint type, converting revolute joints with infinite limits to continuous. - - sdformat converts URDF continuous joints to SDF revolute joints with infinite limits, - so we need to convert them back to continuous when exporting to URDF. - """ - if ( - joint.type == "revolute" - and joint.axis is not None - and joint.axis.limit is not None - and (joint.axis.limit.lower is None or np.isinf(joint.axis.limit.lower)) - and (joint.axis.limit.upper is None or np.isinf(joint.axis.limit.upper)) - ): - return "continuous" - return joint.type + ) - @staticmethod - def _joint_to_urdf_dict(joint: rod.Joint) -> dict[str, Any]: - """ - Convert a ROD joint to a URDF joint dictionary. + return exporter.to_urdf_string(sdf) - Args: - joint: The ROD joint to convert. + def _extract_model(self, sdf: rod.Sdf | rod.Model) -> rod.Model: + """Extract the model from an SDF object or return the model directly.""" - Returns: - A dictionary representing the joint in URDF format. - """ - # Compute the corrected joint type once - urdf_joint_type = UrdfExporter._get_urdf_joint_type(joint) + if isinstance(sdf, rod.Model): + return sdf - return { - "@name": joint.name, - "@type": urdf_joint_type, - "origin": { - "@xyz": " ".join(map(str, joint.pose.xyz)), - "@rpy": " ".join(map(str, joint.pose.rpy)), - }, - "parent": {"@link": joint.parent}, - "child": {"@link": joint.child}, - **( - {"axis": {"@xyz": " ".join(map(str, joint.axis.xyz.xyz))}} - if joint.axis is not None - and joint.axis.xyz is not None - and urdf_joint_type != "fixed" - else {} - ), - # calibration: does not have any SDF corresponding element - **( - { - "dynamics": { - **( - {"@damping": joint.axis.dynamics.damping} - if joint.axis.dynamics.damping is not None - else {} - ), - **( - {"@friction": joint.axis.dynamics.friction} - if joint.axis.dynamics.friction is not None - else {} - ), - } - } - if joint.axis is not None - and joint.axis.dynamics is not None - and {joint.axis.dynamics.damping, joint.axis.dynamics.friction} - != {None} - and urdf_joint_type != "fixed" - else {} - ), - **( - { - "limit": { - **( - {"@effort": joint.axis.limit.effort} - if joint.axis.limit.effort is not None - else ( - {"@effort": np.finfo(np.float32).max} - if urdf_joint_type - in {"revolute", "prismatic", "continuous"} - else {} - ) - ), - **( - {"@velocity": joint.axis.limit.velocity} - if joint.axis.limit.velocity is not None - else ( - {"@velocity": np.finfo(np.float32).max} - if urdf_joint_type - in {"revolute", "prismatic", "continuous"} - else {} - ) - ), - **( - {"@lower": joint.axis.limit.lower} - if joint.axis.limit.lower is not None - and not np.isinf(joint.axis.limit.lower) - and urdf_joint_type in {"revolute", "prismatic"} - else {} - ), - **( - {"@upper": joint.axis.limit.upper} - if joint.axis.limit.upper is not None - and not np.isinf(joint.axis.limit.upper) - and urdf_joint_type in {"revolute", "prismatic"} - else {} - ), - }, - } - if joint.axis is not None - and joint.axis.limit is not None - and urdf_joint_type != "fixed" - else {} - ), - # mimic: does not have any SDF corresponding element - # safety_controller: does not have any SDF corresponding element - } + if len(sdf.models()) > 1: + raise RuntimeError("URDF only supports one robot element") - def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str: - """ - Convert an in-memory SDF model to a URDF string. + return sdf.models()[0] - Args: - sdf: The SDF model parsed by ROD to convert. + def _prepare_model_for_conversion(self, model: rod.Model) -> None: + """Prepare the model for conversion to URDF format.""" - Returns: - The URDF string representing the converted SDF model. - """ + # Remove all poses that could be assumed being implicit + model.resolve_frames(is_top_level=True, explicit_frames=False) - # Operate on a copy of the sdf object - sdf = copy.deepcopy(sdf) + # Handle sub-models (not supported in URDF) + if model.models(): + logging.warning(f"Ignoring unsupported sub-models of model '{model.name}'") + model.model = None - if isinstance(sdf, rod.Sdf) and len(sdf.models()) > 1: - raise RuntimeError("URDF only supports one robot element") + # Check model pose validity + self._validate_model_pose(model) - # Get the model - model = sdf if isinstance(sdf, rod.Model) else sdf.models()[0] - logging.debug(f"Converting model '{model.name}' to URDF") + # Process canonical link + self._process_canonical_link(model) - # Remove all poses that could be assumed being implicit - model.resolve_frames(is_top_level=True, explicit_frames=False) + # Convert all poses to use the URDF frames convention + model.switch_frame_convention( + frame_convention=rod.FrameConvention.Urdf, + explicit_frames=True, + attach_frames_to_links=True, + ) - # Model composition is not supported, ignoring sub-models - if len(model.models()) > 0: - msg = f"Ignoring unsupported sub-models of model '{model.name}'" - logging.warning(msg=msg) + # Clean up link poses (in URDF, links are attached to parent joint frames) + for link in model.links(): + if link.pose is not None and not np.allclose(link.pose.pose, _ZERO_POSE): + logging.warning(f"Ignoring non-trivial pose of link '{link.name}'") + link.pose = None - model.model = None + def _validate_model_pose(self, model: rod.Model) -> None: + """Validate the model pose for URDF compatibility.""" - # Check that the model pose has no reference frame (implicit frame is world) if model.pose is not None and model.pose.relative_to not in {"", None}: raise RuntimeError("Invalid model pose") - # If the model pose is not zero, warn that it will be ignored. - # In fact, the pose wrt world of the canonical link (base) will be used instead. if ( model.is_fixed_base() and model.pose is not None - and not np.allclose(model.pose.pose, np.zeros(6)) + and not np.allclose(model.pose.pose, _ZERO_POSE) ): logging.warning("Ignoring non-trivial pose of fixed-base model") model.pose = None - # Get the canonical link of the model - logging.debug(f"Detected '{model.get_canonical_link()}' as root link") - canonical_link: rod.Link = {l.name: l for l in model.links()}[ - model.get_canonical_link() - ] + def _process_canonical_link(self, model: rod.Model) -> None: + """Process the canonical link of the model.""" - # If the canonical link has a custom pose, notify that it will be ignored. - # In fact, it might happen that the canonical link has a custom pose w.r.t. - # the __model__ frame. In SDF, the __model__frame defines the default reference - # of a model, instead in URDF this reference is represented by the root link - # (that is, by definition, the SDF canonical link). + canonical_link_name = model.get_canonical_link() + logging.debug(f"Detected '{canonical_link_name}' as root link") + + canonical_link = next(l for l in model.links() if l.name == canonical_link_name) + + # Check if canonical link has a custom pose if ( not model.is_fixed_base() and canonical_link.pose is not None - and not np.allclose(canonical_link.pose.pose, np.zeros(6)) + and not np.allclose(canonical_link.pose.pose, _ZERO_POSE) ): - msg = "Ignoring non-trivial pose of canonical link '{name}'" - logging.warning(msg.format(name=canonical_link.name)) + logging.warning( + f"Ignoring non-trivial pose of canonical link '{canonical_link.name}'" + ) canonical_link.pose = None - # Convert all poses to use the Urdf frames convention. - # This process drastically simplifies extracting compatible kinematic transforms. - # Furthermore, it post-processes frames such that they get directly attached to - # a real link (instead of being attached to other frames). - model.switch_frame_convention( - frame_convention=rod.FrameConvention.Urdf, - explicit_frames=True, - attach_frames_to_links=True, - ) - - # ============================================ - # Convert SDF frames to URDF equivalent chains - # ============================================ + def _process_frames(self, model: rod.Model) -> tuple[FramesList, JointsList]: + """Convert SDF frames to URDF equivalent chains.""" - # Initialize the containers of extra links and joints - extra_links_from_frames: list[dict[str, Any]] = [] - extra_joints_from_frames: list[dict[str, Any]] = [] + extra_links = [] + extra_joints = [] - # Since URDF does not support plain frames as SDF, we convert all frames - # to (fixed_joint->dummy_link) sequences for frame in model.frames(): + dummy_link = self._create_dummy_link(frame) + new_joint = self._create_dummy_joint(frame, dummy_link["@name"]) - # New dummy link with same name of the frame - dummy_link = { - "@name": frame.name, - "inertial": { - "origin": { - "@xyz": "0 0 0", - "@rpy": "0 0 0", - }, - "mass": {"@value": 0.0}, - "inertia": { - "@ixx": 0.0, - "@ixy": 0.0, - "@ixz": 0.0, - "@iyy": 0.0, - "@iyz": 0.0, - "@izz": 0.0, - }, - }, - } + logging.debug( + f"Processing frame '{frame.name}': created new dummy chain " + f"{frame.attached_to}->({new_joint['@name']})->{dummy_link['@name']}" + ) - # Note: the pose of the frame in FrameConvention.Urdf already - # refers to the parent link, so we can directly use it. - assert frame.pose.relative_to == frame.attached_to - - # New joint connecting the link to which the frame is attached - # to the new dummy link. - new_joint = { - "@name": f"{frame.attached_to}_to_{dummy_link['@name']}", - "@type": "fixed", - "parent": {"@link": frame.attached_to}, - "child": {"@link": dummy_link["@name"]}, - "origin": { - "@xyz": " ".join(np.array(frame.pose.xyz, dtype=str)), - "@rpy": " ".join(np.array(frame.pose.rpy, dtype=str)), + extra_links.append(dummy_link) + extra_joints.append(new_joint) + + return extra_links, extra_joints + + def _create_dummy_link(self, frame: rod.Frame) -> dict[str, Any]: + """Create a dummy link for a frame.""" + + return { + "@name": frame.name, + "inertial": { + "origin": {"@xyz": "0 0 0", "@rpy": "0 0 0"}, + "mass": {"@value": 0.0}, + "inertia": { + "@ixx": 0.0, + "@ixy": 0.0, + "@ixz": 0.0, + "@iyy": 0.0, + "@iyz": 0.0, + "@izz": 0.0, }, - } + }, + } - logging.debug( - "Processing frame '{}': created new dummy chain {}->({})->{}".format( - frame.name, - frame.attached_to, - new_joint["@name"], - dummy_link["@name"], - ) - ) + def _create_dummy_joint( + self, frame: rod.Frame, dummy_link_name: str + ) -> dict[str, Any]: + """Create a dummy joint for a frame.""" - extra_links_from_frames.append(dummy_link) - extra_joints_from_frames.append(new_joint) + # The pose of the frame in FrameConvention.Urdf refers to the parent link + assert frame.pose.relative_to == frame.attached_to - # ===================== - # Preserve fixed joints - # ===================== + return { + "@name": f"{frame.attached_to}_to_{dummy_link_name}", + "@type": "fixed", + "parent": {"@link": frame.attached_to}, + "child": {"@link": dummy_link_name}, + "origin": { + "@xyz": _fmt(frame.pose.xyz), + "@rpy": _fmt(frame.pose.rpy), + }, + } - # This attribute could either be list of fixed joint names to preserve, - # or a boolean to preserve all fixed joints. - gazebo_preserve_fixed_joints = copy.copy(self.gazebo_preserve_fixed_joints) + def _get_preserved_fixed_joints(self, model: rod.Model) -> list[str]: + """Get the list of fixed joints to preserve.""" - # If it is a boolean, automatically populate the list with all fixed joints. - if gazebo_preserve_fixed_joints is True: - gazebo_preserve_fixed_joints = [ - j.name for j in model.joints() if j.type == "fixed" - ] + preserve_option = copy.copy(self.gazebo_preserve_fixed_joints) - if gazebo_preserve_fixed_joints is False: - gazebo_preserve_fixed_joints = [] + # Convert boolean option to list of joint names + if preserve_option is True: + preserve_option = [j.name for j in model.joints() if j.type == "fixed"] + elif preserve_option is False: + preserve_option = [] - assert isinstance(gazebo_preserve_fixed_joints, list) + assert isinstance(preserve_option, list) - # Check that all fixed joints to preserve are actually present in the model. - for fixed_joint_name in gazebo_preserve_fixed_joints: - logging.debug(f"Preserving fixed joint '{fixed_joint_name}'") - all_model_joint_names = {j.name for j in model.joints()} - if fixed_joint_name not in all_model_joint_names: - raise RuntimeError(f"Joint '{fixed_joint_name}' not found in the model") + # Validate that all specified joints exist + model_joint_names = {j.name for j in model.joints()} + for joint_name in preserve_option: + logging.debug(f"Preserving fixed joint '{joint_name}'") + if joint_name not in model_joint_names: + raise RuntimeError(f"Joint '{joint_name}' not found in the model") - # =================== - # Convert SDF to URDF - # =================== + return preserve_option - # In URDF, links are directly attached to the frame of their parent joint - for link in model.links(): - if link.pose is not None and not np.allclose(link.pose.pose, np.zeros(6)): - msg = "Ignoring non-trivial pose of link '{name}'" - logging.warning(msg.format(name=link.name)) - link.pose = None + def _build_urdf_string( + self, + model: rod.Model, + extra_links: FramesList, + extra_joints: JointsList, + preserved_joints: list[str], + ) -> str: + """Build the URDF string from the model and additional elements.""" - # Define the 'world' link used for fixed-base models + # Define world link for fixed-base models world_link = rod.Link(name="world") + world_link_dict = [world_link.to_dict()] if model.is_fixed_base() else [] - # Create a new dict in xmldict format with only the elements supported by URDF + # Create the URDF dictionary urdf_dict = { "robot": { - **{"@name": model.name}, - # http://wiki.ros.org/urdf/XML/link - "link": ([world_link.to_dict()] if model.is_fixed_base() else []) - + [ - { - "@name": l.name, - "inertial": { - "origin": { - "@xyz": " ".join(map(str, l.inertial.pose.xyz)), - "@rpy": " ".join(map(str, l.inertial.pose.rpy)), - }, - "mass": {"@value": l.inertial.mass}, - "inertia": { - "@ixx": l.inertial.inertia.ixx, - "@ixy": l.inertial.inertia.ixy, - "@ixz": l.inertial.inertia.ixz, - "@iyy": l.inertial.inertia.iyy, - "@iyz": l.inertial.inertia.iyz, - "@izz": l.inertial.inertia.izz, - }, - }, - "visual": [ - { - "@name": v.name, - "origin": { - "@xyz": " ".join(map(str, v.pose.xyz)), - "@rpy": " ".join(map(str, v.pose.rpy)), - }, - "geometry": UrdfExporter._rod_geometry_to_xmltodict( - geometry=v.geometry - ), - **( - { - "material": UrdfExporter._rod_material_to_xmltodict( - material=v.material - ) - } - if v.material is not None - else {} - ), - } - for v in l.visuals() - ], - "collision": [ - { - "@name": c.name, - "origin": { - "@xyz": " ".join(map(str, c.pose.xyz)), - "@rpy": " ".join(map(str, c.pose.rpy)), - }, - "geometry": UrdfExporter._rod_geometry_to_xmltodict( - geometry=c.geometry - ), - } - for c in l.collisions() - ], - } - for l in model.links() - ] - # Add the extra links resulting from the frame->dummy_link conversion - + extra_links_from_frames, - # http://wiki.ros.org/urdf/XML/joint - "joint": [ - UrdfExporter._joint_to_urdf_dict(j) - for j in model.joints() - if j.type in UrdfExporter.SupportedSdfJointTypes - ] - # Add the extra joints resulting from the frame->link conversion - + extra_joints_from_frames, - # Extra gazebo-related elements - # https://classic.gazebosim.org/tutorials?tut=ros_urdf - # https://github.com/gazebosim/sdformat/issues/199#issuecomment-622127508 + "@name": model.name, + "link": world_link_dict + + self._create_link_elements(model) + + extra_links, + "joint": self._create_joint_elements(model) + extra_joints, "gazebo": [ { "@reference": fixed_joint, "preserveFixedJoint": "true", "disableFixedJointLumping": "true", } - for fixed_joint in gazebo_preserve_fixed_joints + for fixed_joint in preserved_joints ], } } + # Convert to XML string return xmltodict.unparse( input_dict=urdf_dict, pretty=self.pretty, @@ -444,57 +299,222 @@ def to_urdf_string(self, sdf: rod.Sdf | rod.Model) -> str: short_empty_elements=True, ) - @staticmethod - def _rod_geometry_to_xmltodict(geometry: rod.Geometry) -> dict[str, Any]: + def _create_link_elements(self, model: rod.Model) -> list[dict[str, Any]]: + """Create the link elements for the URDF.""" + + return [ + { + "@name": link.name, + "inertial": self._create_inertial_element(link), + "visual": self._create_visual_elements(link), + "collision": self._create_collision_elements(link), + } + for link in model.links() + ] + + def _create_inertial_element(self, link: rod.Link) -> dict[str, Any]: + """Create an inertial element for a link.""" + return { - **( - {"box": {"@size": " ".join(np.array(geometry.box.size, dtype=str))}} - if geometry.box is not None - else {} - ), - **( - { - "cylinder": { - "@radius": geometry.cylinder.radius, - "@length": geometry.cylinder.length, - } - } - if geometry.cylinder is not None - else {} - ), - **( - {"sphere": {"@radius": geometry.sphere.radius}} - if geometry.sphere is not None - else {} - ), - **( - { - "mesh": { - "@filename": geometry.mesh.uri, - "@scale": " ".join(map(str, geometry.mesh.scale)), - } - } - if geometry.mesh is not None - else {} - ), + "origin": { + "@xyz": _fmt(link.inertial.pose.xyz), + "@rpy": _fmt(link.inertial.pose.rpy), + }, + "mass": {"@value": link.inertial.mass}, + "inertia": { + "@ixx": link.inertial.inertia.ixx, + "@ixy": link.inertial.inertia.ixy, + "@ixz": link.inertial.inertia.ixz, + "@iyy": link.inertial.inertia.iyy, + "@iyz": link.inertial.inertia.iyz, + "@izz": link.inertial.inertia.izz, + }, } + def _create_visual_elements(self, link: rod.Link) -> list[dict[str, Any]]: + """Create visual elements for a link.""" + + return [ + { + "@name": visual.name, + "origin": { + "@xyz": _fmt(visual.pose.xyz), + "@rpy": _fmt(visual.pose.rpy), + }, + "geometry": self._rod_geometry_to_xmltodict(visual.geometry), + **( + {"material": self._rod_material_to_xmltodict(visual.material)} + if visual.material is not None + else {} + ), + } + for visual in link.visuals() + ] + + def _create_collision_elements(self, link: rod.Link) -> list[dict[str, Any]]: + """Create collision elements for a link.""" + + return [ + { + "@name": collision.name, + "origin": { + "@xyz": _fmt(collision.pose.xyz), + "@rpy": _fmt(collision.pose.rpy), + }, + "geometry": self._rod_geometry_to_xmltodict(collision.geometry), + } + for collision in link.collisions() + ] + + def _create_joint_elements(self, model: rod.Model) -> list[dict[str, Any]]: + """Create the joint elements for the URDF.""" + + return [ + self._joint_to_dict(joint) + for joint in model.joints() + if joint.type in self.SUPPORTED_JOINT_TYPES + ] + @staticmethod - def _rod_material_to_xmltodict(material: rod.Material) -> dict[str, Any]: + def _get_urdf_joint_type(joint: rod.Joint) -> str: + """ + Get the URDF joint type, converting revolute joints with infinite limits to continuous. + + sdformat converts URDF continuous joints to SDF revolute joints with infinite limits, + so we need to convert them back to continuous when exporting to URDF. + """ + if ( + joint.type == "revolute" + and joint.axis is not None + and joint.axis.limit is not None + and (joint.axis.limit.lower is None or np.isinf(joint.axis.limit.lower)) + and (joint.axis.limit.upper is None or np.isinf(joint.axis.limit.upper)) + ): + return "continuous" + return joint.type + + def _joint_to_dict(self, joint: rod.Joint) -> dict[str, Any]: + """Convert a joint to a dictionary representation.""" + + urdf_joint_type = self._get_urdf_joint_type(joint) + + joint_dict = { + "@name": joint.name, + "@type": urdf_joint_type, + "origin": { + "@xyz": _fmt(joint.pose.xyz), + "@rpy": _fmt(joint.pose.rpy), + }, + "parent": {"@link": joint.parent}, + "child": {"@link": joint.child}, + } + + # Add axis if needed and not a fixed joint + if ( + joint.axis is not None + and joint.axis.xyz is not None + and urdf_joint_type != "fixed" + ): + joint_dict["axis"] = {"@xyz": _fmt(joint.axis.xyz.xyz)} + + # Add dynamics if available and not a fixed joint + if ( + joint.axis is not None + and joint.axis.dynamics is not None + and {joint.axis.dynamics.damping, joint.axis.dynamics.friction} != {None} + and urdf_joint_type != "fixed" + ): + + dynamics = {} + if joint.axis.dynamics.damping is not None: + dynamics["@damping"] = joint.axis.dynamics.damping + if joint.axis.dynamics.friction is not None: + dynamics["@friction"] = joint.axis.dynamics.friction + + if dynamics: + joint_dict["dynamics"] = dynamics + + # Add limits if needed and not a fixed joint + limit_dict = {} + + if ( + joint.axis is not None + and joint.axis.limit is not None + and urdf_joint_type in {"revolute", "prismatic", "continuous"} + ): + + # Effort and velocity get defaults if not specified + limit_dict["@effort"] = self._get_or_default( + joint.axis.limit.effort, np.finfo(np.float32).max + ) + limit_dict["@velocity"] = self._get_or_default( + joint.axis.limit.velocity, np.finfo(np.float32).max + ) + + # Lower and upper only for revolute and prismatic joints (not continuous) + if urdf_joint_type in {"revolute", "prismatic"}: + if joint.axis.limit.lower is not None and not np.isinf( + joint.axis.limit.lower + ): + limit_dict["@lower"] = joint.axis.limit.lower + if joint.axis.limit.upper is not None and not np.isinf( + joint.axis.limit.upper + ): + limit_dict["@upper"] = joint.axis.limit.upper + + if limit_dict: + joint_dict["limit"] = limit_dict + + return joint_dict + + @staticmethod + def _get_or_default(value: T | None, default: T) -> T: + """Return the value if not None, otherwise the default.""" + + return value if value is not None else default + + @staticmethod + def _rod_geometry_to_xmltodict(geometry: rod.Geometry) -> dict[str, Any]: + """Convert ROD geometry to XML dictionary format.""" + + result = {} + + if geometry.box is not None: + result["box"] = {"@size": _fmt(geometry.box.size)} + elif geometry.cylinder is not None: + result["cylinder"] = { + "@radius": geometry.cylinder.radius, + "@length": geometry.cylinder.length, + } + elif geometry.sphere is not None: + result["sphere"] = {"@radius": geometry.sphere.radius} + elif geometry.mesh is not None: + result["mesh"] = { + "@filename": geometry.mesh.uri, + "@scale": _fmt(geometry.mesh.scale), + } + + return result + + @classmethod + def _rod_material_to_xmltodict(cls, material: rod.Material) -> dict[str, Any]: + """Convert ROD material to XML dictionary format.""" + if material.script is not None: - msg = "Material scripts are not supported, returning default material" - logging.info(msg=msg) - return UrdfExporter.DefaultMaterial + logging.info( + "Material scripts are not supported, returning default material" + ) + return cls.DEFAULT_MATERIAL if material.diffuse is None: - msg = "Material diffuse color is not defined, returning default material" - logging.info(msg=msg) - return UrdfExporter.DefaultMaterial + logging.info( + "Material diffuse color is not defined, returning default material" + ) + return cls.DEFAULT_MATERIAL return { - "@name": f"color_{hash(' '.join(map(str, material.diffuse)))}", + "@name": f"color_{hash(_fmt(material.diffuse))}", "color": { - "@rgba": " ".join(map(str, material.diffuse)), + "@rgba": _fmt(material.diffuse), }, - # "texture": {"@filename": None}, # TODO } diff --git a/src/rod/utils/frame_convention.py b/src/rod/utils/frame_convention.py index 7a6a652..ea339ec 100644 --- a/src/rod/utils/frame_convention.py +++ b/src/rod/utils/frame_convention.py @@ -1,8 +1,7 @@ import enum -from collections import defaultdict import rod -from rod import logging +from rod.kinematics.tree_transforms import TreeTransforms class FrameConvention(enum.IntEnum): @@ -18,307 +17,188 @@ def switch_frame_convention( is_top_level: bool = True, attach_frames_to_links: bool = True, ) -> None: + """Switch the frame convention of a model.""" - # Resolve all implicit reference frames using Sdf convention + # Resolve all implicit reference frames model.resolve_frames(is_top_level=is_top_level, explicit_frames=True) - # ============================= - # Initialize forward kinematics - # ============================= - - from rod.kinematics.tree_transforms import TreeTransforms - - # Create the object to compute the kinematics of the tree. + # Initialize kinematics kin = TreeTransforms.build(model=model, is_top_level=is_top_level) - # ===================================== - # Update frames to be attached to links - # ===================================== - - # Update the //frame/attached_to attribute of all frames so that they are - # directly attached to links. + # Attach frames to links if requested if attach_frames_to_links: - for frame in model.frames(): - # Find the link to which the frame is attached to following recursively - # the //frame/attached_to attribute. - parent_link = find_parent_link_of_frame(frame=frame, model=model) - - # Compute the transform between the model and the frame. - model_H_frame = ( - kin.relative_transform( - relative_to="__model__", name=frame.pose.relative_to - ) - @ frame.pose.transform() - ) + _attach_frames_to_links(model, kin) - # Compute the transform between the parent link and the model. - parent_link_H_model = kin.relative_transform( - relative_to=parent_link, name="__model__" - ) + # Get frame mapping functions + frame_fn = _get_frame_function(frame_convention, model) - # Update the frame such that it is attached_to a link, populating the - # pose with the correct transform between the parent link and the frame. - frame.attached_to = parent_link - frame.pose = rod.Pose.from_transform( - relative_to=parent_link, - transform=parent_link_H_model @ model_H_frame, - ) + # Process all elements + _process_model_elements(model, kin, frame_fn, is_top_level) - # ============================================================= - # Define the default reference frames of the different elements - # ============================================================= - - match frame_convention: - case FrameConvention.World: - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "world" - reference_frame_frames = lambda f: "world" - reference_frame_joints = lambda j: "world" - reference_frame_visuals = lambda v: "world" - reference_frame_inertials = lambda i, parent_link: "world" - reference_frame_collisions = lambda c: "world" - reference_frame_link_canonical = "world" - - case FrameConvention.Model: - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "__model__" - reference_frame_frames = lambda f: "__model__" - reference_frame_joints = lambda j: "__model__" - reference_frame_visuals = lambda v: "__model__" - reference_frame_inertials = lambda i, parent_link: "__model__" - reference_frame_collisions = lambda c: "__model__" - reference_frame_link_canonical = "__model__" - - case FrameConvention.Sdf: - - visual_name_to_parent_link = { - visual_name: parent_link - for d in [ - {v.name: link for v in link.visuals()} for link in model.links() - ] - for visual_name, parent_link in d.items() - } - - collision_name_to_parent_link = { - collision_name: parent_link - for d in [ - {c.name: link for c in link.collisions()} for link in model.links() - ] - for collision_name, parent_link in d.items() - } - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: "__model__" - reference_frame_frames = lambda f: f.attached_to - reference_frame_joints = lambda j: joint.child - reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name - reference_frame_inertials = lambda i, parent_link: parent_link.name - reference_frame_collisions = lambda c: collision_name_to_parent_link[ - c.name - ].name - reference_frame_link_canonical = "__model__" - - case FrameConvention.Urdf: - - visual_name_to_parent_link = { - visual_name: parent_link - for d in [ - {v.name: link for v in link.visuals()} for link in model.links() - ] - for visual_name, parent_link in d.items() - } - - collision_name_to_parent_link = { - collision_name: parent_link - for d in [ - {c.name: link for c in link.collisions()} for link in model.links() - ] - for collision_name, parent_link in d.items() - } - - link_name_to_parent_joint_names = defaultdict(list) - - for j in model.joints(): - if j.child != model.get_canonical_link(): - link_name_to_parent_joint_names[j.child].append(j.name) - else: - # The pose of the canonical link is used to define the origin of - # the URDF joint connecting the world to the robot - assert model.is_fixed_base() - link_name_to_parent_joint_names[j.child].append("world") - - reference_frame_model = lambda m: "world" - reference_frame_links = lambda l: link_name_to_parent_joint_names[l.name][0] - reference_frame_frames = lambda f: f.attached_to - reference_frame_joints = lambda j: j.parent - reference_frame_visuals = lambda v: visual_name_to_parent_link[v.name].name - reference_frame_inertials = lambda i, parent_link: parent_link.name - reference_frame_collisions = lambda c: collision_name_to_parent_link[ - c.name - ].name - - if model.is_fixed_base(): - canonical_link = {l.name: l for l in model.links()}[ - model.get_canonical_link() - ] - reference_frame_link_canonical = reference_frame_links(l=canonical_link) - else: - reference_frame_link_canonical = "__model__" - - case _: - raise ValueError(frame_convention) - - # ========================================= - # Process the reference frames of the model - # ========================================= - - if is_top_level: - assert model.pose.relative_to in {"", None} - else: - # Adjust the reference frame of the sub-model - if model.pose.relative_to != reference_frame_model: - x_H_model = model.pose.transform() - target_H_x = kin.relative_transform( - relative_to=reference_frame_model(m=model), - name=model.pose.relative_to, - ) +def _attach_frames_to_links(model: rod.Model, kin) -> None: + """Attach all frames directly to links.""" + for frame in model.frames(): + parent_link = find_parent_link_of_frame(frame, model) - model.pose = rod.Pose.from_transform( - relative_to=reference_frame_model(m=model), - transform=target_H_x @ x_H_model, + model_H_frame = ( + kin.relative_transform( + from_frame="__model__", to_frame=frame.pose.relative_to ) - - # Adjust the reference frames of all sub-models - for sub_model in model.models(): - logging.info( - f"Model composition not yet supported, ignoring '{model.name}/{sub_model.name}'" + @ frame.pose.transform() ) - - # Adjust the reference frames of all joints - for joint in model.joints(): - x_H_joint = joint.pose.transform() - target_H_x = kin.relative_transform( - relative_to=reference_frame_joints(j=joint), - name=joint.pose.relative_to, + parent_link_H_model = kin.relative_transform( + from_frame=parent_link, to_frame="__model__" ) - joint.pose = rod.Pose.from_transform( - relative_to=reference_frame_joints(j=joint), - transform=target_H_x @ x_H_joint, + frame.attached_to = parent_link + frame.pose = rod.Pose.from_transform( + relative_to=parent_link, + transform=parent_link_H_model @ model_H_frame, ) - # Adjust the reference frames of all frames - for frame in model.frames(): - x_H_frame = frame.pose.transform() - target_H_x = kin.relative_transform( - relative_to=reference_frame_frames(f=frame), - name=frame.pose.relative_to, - ) - frame.pose = rod.Pose.from_transform( - relative_to=reference_frame_frames(f=frame), - transform=target_H_x @ x_H_frame, - ) +def _get_frame_function(convention: FrameConvention, model: rod.Model): + """Get frame mapping function for the convention.""" + + # Pre-compute mappings for SDF and URDF + if convention in (FrameConvention.Sdf, FrameConvention.Urdf): + links = model.links() + visual_map = { + v.name: link.name for link in links for v in link.visuals() + } + collision_map = { + c.name: link.name for link in links for c in link.collisions() + } + + if convention == FrameConvention.Urdf: + joint_map = {} + canonical = model.get_canonical_link() + for joint in model.joints(): + joint_map[joint.child] = ( + "world" + if joint.child == canonical and model.is_fixed_base() + else joint.name + ) + canonical_ref = joint_map.get(canonical, "__model__") + + def get_target_frame(element_type: str, element, parent_link=None): + dispatch = { + FrameConvention.World: { + None: "world", + }, + FrameConvention.Model: { + "model": "world", + None: "__model__", + }, + FrameConvention.Sdf: { + "model": "world", + "link": "__model__", + "frame": lambda e: e.attached_to, + "joint": lambda e: e.child, + "visual": lambda e: visual_map[e.name], + "collision": lambda e: collision_map[e.name], + "inertial": lambda e: parent_link.name, + "canonical": "__model__", + }, + FrameConvention.Urdf: { + "model": "world", + "link": lambda e: joint_map[e.name], + "frame": lambda e: e.attached_to, + "joint": lambda e: e.parent, + "visual": lambda e: visual_map[e.name], + "collision": lambda e: collision_map[e.name], + "inertial": lambda e: parent_link.name, + "canonical": canonical_ref, + }, + } + + table = dispatch[convention] + value = table.get(element_type, table.get(None)) + + return value(element) if callable(value) else value + + return get_target_frame + + +def _transform_pose(kin, pose, target_frame: str): + """Transform a pose to target frame.""" + target_H_current = kin.relative_transform( + from_frame=target_frame, to_frame=pose.relative_to + ) + return rod.Pose.from_transform( + relative_to=target_frame, + transform=target_H_current @ pose.transform(), + ) + + +def _process_model_elements( + model: rod.Model, kin, frame_fn, is_top_level: bool +) -> None: + """Process all model elements with the frame function.""" + canonical = model.get_canonical_link() - # Adjust the reference frames of all links - for link in model.links(): - relative_to = ( - reference_frame_links(l=link) - if link.name != model.get_canonical_link() - else reference_frame_link_canonical - ) + # Model pose for sub-models + if not is_top_level: + target = frame_fn("model", model) + if model.pose.relative_to != target: + model.pose = _transform_pose(kin, model.pose, target) + + # Process all elements + for joint in model.joints(): + joint.pose = _transform_pose(kin, joint.pose, frame_fn("joint", joint)) + for frame in model.frames(): + frame.pose = _transform_pose(kin, frame.pose, frame_fn("frame", frame)) + + for link in model.links(): # Link pose - x_H_link = link.pose.transform() - target_H_x = kin.relative_transform( - relative_to=relative_to, - name=link.pose.relative_to, - ) - link.pose = rod.Pose.from_transform( - relative_to=relative_to, - transform=target_H_x @ x_H_link, + target = ( + frame_fn("canonical", link) + if link.name == canonical + else frame_fn("link", link) ) + link.pose = _transform_pose(kin, link.pose, target) - # Inertial pose - x_H_inertial = link.inertial.pose.transform() - target_H_x = kin.relative_transform( - relative_to=reference_frame_inertials(i=link.inertial, parent_link=link), - name=link.inertial.pose.relative_to, - ) - link.inertial.pose = rod.Pose.from_transform( - relative_to=reference_frame_inertials(i=link.inertial, parent_link=link), - transform=target_H_x @ x_H_inertial, + # Link elements + link.inertial.pose = _transform_pose( + kin, link.inertial.pose, frame_fn("inertial", link.inertial, link) ) - # Visuals pose for visual in link.visuals(): - x_H_visual = visual.pose.transform() - target_H_x = kin.relative_transform( - relative_to=reference_frame_visuals(v=visual), - name=visual.pose.relative_to, - ) + visual.pose = _transform_pose(kin, visual.pose, frame_fn("visual", visual)) - visual.pose = rod.Pose.from_transform( - relative_to=reference_frame_visuals(v=visual), - transform=target_H_x @ x_H_visual, - ) - - # Collisions pose for collision in link.collisions(): - x_H_collision = collision.pose.transform() - target_H_x = kin.relative_transform( - relative_to=reference_frame_collisions(c=collision), - name=collision.pose.relative_to, - ) - - collision.pose = rod.Pose.from_transform( - relative_to=reference_frame_collisions(c=collision), - transform=target_H_x @ x_H_collision, + collision.pose = _transform_pose( + kin, collision.pose, frame_fn("collision", collision) ) def find_parent_link_of_frame(frame: rod.Frame, model: rod.Model) -> str: - - links_dict = {l.name: l for l in model.links()} - frames_dict = {f.name: f for f in model.frames()} - joints_dict = {j.name: j for j in model.joints()} - sub_models_dict = {m.name: m for m in model.models()} - - assert isinstance(frame, rod.Frame) - - match frame.attached_to: - case anchor if anchor in links_dict: - parent = links_dict[frame.attached_to] - - case anchor if anchor in frames_dict: - parent = frames_dict[frame.attached_to] - - case anchor if anchor in {model.name, "__model__"}: - return model.get_canonical_link() - - case anchor if anchor in joints_dict: - raise ValueError("Frames cannot be attached to joints") - - case anchor if anchor in sub_models_dict: - raise RuntimeError("Model composition not yet supported") - - case _: - raise RuntimeError( - f"Failed to find element with name '{frame.attached_to}'" - ) - - # At this point, the parent is either a link or another frame. - assert isinstance(parent, rod.Link | rod.Frame) - - match parent: - # If the parent is a link, can stop searching. - case parent if isinstance(parent, rod.Link): - return parent.name - - # If the parent is another frame, keep looking for the parent link. - case parent if isinstance(parent, rod.Frame): - return find_parent_link_of_frame(frame=parent, model=model) - - raise RuntimeError("This recursive function should never arrive here.") + """Find the parent link of a frame.""" + # Create lookup dicts + elements = { + **{l.name: l for l in model.links()}, + **{f.name: f for f in model.frames()}, + **{j.name: j for j in model.joints()}, + **{m.name: m for m in model.models()}, + } + + # Handle special cases + if frame.attached_to in {model.name, "__model__"}: + return model.get_canonical_link() + + if frame.attached_to not in elements: + raise RuntimeError(f"Element '{frame.attached_to}' not found") + + parent = elements[frame.attached_to] + + # Check parent type + if isinstance(parent, rod.Link): + return parent.name + elif isinstance(parent, rod.Frame): + return find_parent_link_of_frame(parent, model) # Recursive + elif isinstance(parent, rod.Joint): + raise ValueError("Frames cannot be attached to joints") + else: + raise RuntimeError("Model composition not yet supported") diff --git a/src/rod/utils/gazebo.py b/src/rod/utils/gazebo.py index 9cc8b8e..eac3c0d 100644 --- a/src/rod/utils/gazebo.py +++ b/src/rod/utils/gazebo.py @@ -1,3 +1,4 @@ +import functools import os import pathlib import shutil @@ -49,6 +50,12 @@ def has_gazebo() -> bool: except Exception: return False + @staticmethod + @functools.lru_cache(maxsize=128) + def _get_cached_processing(content: str) -> str: + """Internal cached processing method.""" + return GazeboHelper._process_with_sdformat_uncached(content) + @staticmethod def process_model_description_with_sdformat( model_description: str | pathlib.Path, @@ -82,32 +89,41 @@ def process_model_description_with_sdformat( model_description_string = model_description # ================================ - # Process the string with sdformat + # Use caching for repeated content # ================================ + return GazeboHelper._get_cached_processing(model_description_string) + + @staticmethod + def _process_with_sdformat_uncached(model_description_string: str) -> str: + """Internal method for actual SDF processing without caching.""" # Get the Gazebo Sim executable (raises exception if not found) gazebo_executable = GazeboHelper.get_gazebo_executable() - # Operate on a file stored in a temporary directory. - # This is necessary on windows because the file has to be closed before - # it can be processed by the sdformat executable. - # As soon as 3.12 will be the minimum supported version, we can use just - # NamedTemporaryFile with the new delete_on_close=False parameter. - with tempfile.TemporaryDirectory() as tmp: - - with tempfile.NamedTemporaryFile( - mode="w+", suffix=".xml", dir=tmp, delete=False - ) as fp: + fd, tmp_path = tempfile.mkstemp(suffix=".xml", prefix="rod_sdf_") + temp_file = pathlib.Path(tmp_path) - fp.write(model_description_string) - fp.close() + try: + # Write via the file descriptor returned by mkstemp to avoid a + # separate open() call (the fd is already open and exclusive). + with os.fdopen(fd, "w", encoding="utf-8") as f: + f.write(model_description_string) + # Process with optimized subprocess call try: cp = subprocess.run( - [str(gazebo_executable), "sdf", "-p", fp.name], + [str(gazebo_executable), "sdf", "-p", str(temp_file)], text=True, capture_output=True, check=True, + bufsize=-1, # Use system default buffer size + env=dict( + os.environ, + **{ + "IGN_PARTITION": "rod_processing", # Avoid interference with running Gazebo + "GZ_PARTITION": "rod_processing", + }, + ), ) except subprocess.CalledProcessError as e: if e.returncode != 0: @@ -116,11 +132,31 @@ def process_model_description_with_sdformat( "Failed to process the input with sdformat" ) from e + finally: + temp_file.unlink(missing_ok=True) + # Get the resulting SDF string sdf_string = cp.stdout # There might be warnings in the output, so we remove them by finding the # first tag and ignoring everything before it - sdf_string = sdf_string[sdf_string.find(" list[str]: + """Process multiple SDF descriptions in batch for better performance.""" + results = [] + for desc in descriptions: + results.append(GazeboHelper.process_model_description_with_sdformat(desc)) + return results + + @classmethod + def clear_cache(cls) -> None: + """Clear the processing cache to free memory.""" + cls._get_cached_processing.cache_clear()