diff --git a/mujoco_urdf_loader/mjcf_fcn.py b/mujoco_urdf_loader/mjcf_fcn.py index f990109..a023ee6 100644 --- a/mujoco_urdf_loader/mjcf_fcn.py +++ b/mujoco_urdf_loader/mjcf_fcn.py @@ -86,6 +86,7 @@ def add_position_actuator( return mjcf + def add_torque_actuator( mjcf: ET.Element, joint: str, @@ -177,12 +178,12 @@ def add_joint_vel_sensor(mjcf: ET.Element, joint: str, name: str = None) -> ET.E def add_joint_eq( - mjcf: ET.Element, - joint1: str, - joint2: str, - name: str = None, - multiplier: float = 1.0, - offset: float = 0.0 + mjcf: ET.Element, + joint1: str, + joint2: str, + name: str = None, + multiplier: float = 1.0, + offset: float = 0.0, ) -> ET.Element: """Add a joint equality constraint between two joints. @@ -411,3 +412,67 @@ def add_sphere( geom.set("mass", f"{mass}") return mjcf + + +def add_equality_constraints_for_sites( + mjcf: ET.Element, site_pairs: List[tuple], constraint_type: str = "connect" +) -> ET.Element: + """ + Add equality constraints between pairs of sites in MJCF. + + Args: + mjcf (ET.Element): The MJCF file as ElementTree. + site_pairs (List[tuple]): List of tuples with (site1_name, site2_name) to connect. + constraint_type (str): Type of constraint - "connect" or "weld" (default: "connect"). + + Returns: + ET.Element: The modified MJCF file. + """ + # Find or create the equality element + equality = mjcf.find("equality") + if equality is None: + equality = ET.SubElement(mjcf, "equality") + + for site1, site2 in site_pairs: + # Verify both sites exist + site1_elem = mjcf.find(f".//site[@name='{site1}']") + site2_elem = mjcf.find(f".//site[@name='{site2}']") + + if site1_elem is None: + raise ValueError(f"Site {site1} not found in MJCF") + if site2_elem is None: + raise ValueError(f"Site {site2} not found in MJCF") + + # Create the equality constraint + if constraint_type == "connect": + # Connect constraint directly references sites (no anchor needed for sites) + constraint = ET.SubElement(equality, "connect") + constraint.set("site1", site1) + constraint.set("site2", site2) + elif constraint_type == "weld": + # Weld constraint references bodies + # Find parent bodies of the sites + body1 = None + body2 = None + for body in mjcf.findall(".//body"): + if body.find(f".//site[@name='{site1}']") is not None: + body1 = body.attrib.get("name") + if body.find(f".//site[@name='{site2}']") is not None: + body2 = body.attrib.get("name") + + if body1 is None or body2 is None: + raise ValueError( + f"Could not find parent bodies for sites {site1} and {site2}" + ) + + constraint = ET.SubElement(equality, "weld") + constraint.set("body1", body1) + constraint.set("body2", body2) + else: + raise ValueError(f"Unknown constraint type: {constraint_type}") + + print( + f"Created {constraint_type} equality constraint between {site1} and {site2}" + ) + + return mjcf