Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mujoco_urdf_loader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .loader import (
ControlMode,
EqualityConstraintCfg,
FrameQuatSensorCfg,
GyroSensorCfg,
URDFtoMuJoCoLoader,
Expand Down
71 changes: 71 additions & 0 deletions mujoco_urdf_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
add_framequat_sensor,
add_gyro_sensor,
add_camera_to_site,
add_equality_constraints_for_sites,
convert_hinge_to_ball_joints,
)
from mujoco_urdf_loader.urdf_fcn import (
Expand Down Expand Up @@ -51,6 +52,13 @@ class CameraCfg:
fovy: float
name: str

@dataclasses.dataclass
class EqualityConstraintCfg:
"""Configuration for a connect/weld equality constraint between two sites."""
site1: str
site2: str
constraint_type: str = "connect"

@dataclasses.dataclass
class URDFtoMuJoCoLoaderCfg:
controlled_joints: List[str]
Expand All @@ -62,6 +70,7 @@ class URDFtoMuJoCoLoaderCfg:
framequat_sensors_cfg: Union[None, List[Union[FrameQuatSensorCfg, Dict[str, Any]]]] = None
gyro_sensors_cfg: Union[None, List[Union[GyroSensorCfg, Dict[str, Any]]]] = None
cameras_cfg: Union[None, List[Union[CameraCfg, Dict[str, Any]]]] = None
equality_constraints_cfg: Union[None, List[Union[EqualityConstraintCfg, Dict[str, Any]]]] = None
ball_joint_damping: float = 0.0
ball_joint_armature: float = 0.0
ball_joint_frictionloss: float = 0.0
Expand Down Expand Up @@ -168,6 +177,7 @@ def load_urdf(urdf_path: str, mesh_path: str, cfg: URDFtoMuJoCoLoaderCfg):
all_missing_joints_as_sites=cfg.all_missing_joints_as_sites,
framequat_sensors_cfg=cfg.framequat_sensors_cfg,
gyro_sensors_cfg=cfg.gyro_sensors_cfg,
equality_constraints_cfg=cfg.equality_constraints_cfg,
)
else:
mjcf_cfg = cfg
Expand All @@ -189,6 +199,7 @@ def load_urdf(urdf_path: str, mesh_path: str, cfg: URDFtoMuJoCoLoaderCfg):
loader.add_framequat_sensors(cfg.framequat_sensors_cfg)
loader.add_gyro_sensors(cfg.gyro_sensors_cfg)
loader.add_cameras(cfg.cameras_cfg)
loader.add_equality_constraints(cfg.equality_constraints_cfg)
return loader

@staticmethod
Expand Down Expand Up @@ -306,6 +317,66 @@ def add_cameras(
fovy=normalized_cfg.fovy,
)

@staticmethod
def _normalize_equality_constraint_cfg(
eq_cfg: Union[EqualityConstraintCfg, Dict[str, Any]],
) -> EqualityConstraintCfg:
if isinstance(eq_cfg, EqualityConstraintCfg):
return eq_cfg

if not isinstance(eq_cfg, dict):
raise TypeError(
"Each equality constraint configuration must be an "
"EqualityConstraintCfg or a dict with keys site1 and site2."
)

site1 = eq_cfg.get("site1")
site2 = eq_cfg.get("site2")
constraint_type = eq_cfg.get("constraint_type", "connect")

if site1 is None or site2 is None:
raise ValueError(
"Each equality constraint configuration requires site1 and site2."
)

return EqualityConstraintCfg(
site1=site1, site2=site2, constraint_type=constraint_type,
)

def add_equality_constraints(
self,
equality_constraints_cfg: Union[
None, List[Union[EqualityConstraintCfg, Dict[str, Any]]]
] = None,
):
"""Add equality constraints (connect/weld) to the MJCF model.

Uses the existing ``add_equality_constraints_for_sites`` helper to
create ``<connect>`` or ``<weld>`` elements inside ``<equality>``.

Args:
equality_constraints_cfg: List of ``EqualityConstraintCfg``
dataclasses or dicts with keys ``site1``, ``site2``, and
optionally ``constraint_type`` (default ``"connect"``).
If ``None``, no constraints are added.
"""
if equality_constraints_cfg is None:
return

# Group by constraint_type so we can call the helper once per type
by_type: Dict[str, List[tuple]] = {}
for cfg in equality_constraints_cfg:
normalized = self._normalize_equality_constraint_cfg(cfg)
ctype = normalized.constraint_type
by_type.setdefault(ctype, []).append(
(normalized.site1, normalized.site2)
)

for constraint_type, site_pairs in by_type.items():
add_equality_constraints_for_sites(
self.mjcf, site_pairs, constraint_type=constraint_type,
)

@staticmethod
def get_missing_joint_sites(
robot_urdf: ET.Element,
Expand Down
76 changes: 70 additions & 6 deletions mujoco_urdf_loader/mjcf_fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def add_position_actuator(

return mjcf


def add_torque_actuator(
mjcf: ET.Element,
joint: str,
Expand Down Expand Up @@ -281,12 +282,12 @@ def add_framequat_sensor(mjcf: ET.Element, objname: str, objtype: str = 'site',


def add_joint_eq(
mjcf: ET.Element,
joint1: str,
joint2: str,
name: str = None,
multiplier: float = 1.0,
offset: float = 0.0
mjcf: ET.Element,
joint1: str,
joint2: str,
name: str = None,
multiplier: float = 1.0,
offset: float = 0.0,
) -> ET.Element:
"""Add a joint equality constraint between two joints.

Expand Down Expand Up @@ -575,3 +576,66 @@ def convert_hinge_to_ball_joints(
joint_elem.set("armature", str(armature))
joint_elem.set("frictionloss", str(frictionloss))
return mjcf

def add_equality_constraints_for_sites(
mjcf: ET.Element, site_pairs: List[tuple], constraint_type: str = "connect"
) -> ET.Element:
"""
Add equality constraints between pairs of sites in MJCF.

Args:
mjcf (ET.Element): The MJCF file as ElementTree.
site_pairs (List[tuple]): List of tuples with (site1_name, site2_name) to connect.
constraint_type (str): Type of constraint - "connect" or "weld" (default: "connect").

Returns:
ET.Element: The modified MJCF file.
"""
# Find or create the equality element
equality = mjcf.find("equality")
if equality is None:
equality = ET.SubElement(mjcf, "equality")

for site1, site2 in site_pairs:
# Verify both sites exist
site1_elem = mjcf.find(f".//site[@name='{site1}']")
site2_elem = mjcf.find(f".//site[@name='{site2}']")

if site1_elem is None:
raise ValueError(f"Site {site1} not found in MJCF")
if site2_elem is None:
raise ValueError(f"Site {site2} not found in MJCF")

# Create the equality constraint
if constraint_type == "connect":
# Connect constraint directly references sites (no anchor needed for sites)
constraint = ET.SubElement(equality, "connect")
constraint.set("site1", site1)
constraint.set("site2", site2)
elif constraint_type == "weld":
# Weld constraint references bodies
# Find parent bodies of the sites
body1 = None
body2 = None
for body in mjcf.findall(".//body"):
if body.find(f".//site[@name='{site1}']") is not None:
body1 = body.attrib.get("name")
if body.find(f".//site[@name='{site2}']") is not None:
body2 = body.attrib.get("name")

if body1 is None or body2 is None:
raise ValueError(
f"Could not find parent bodies for sites {site1} and {site2}"
)

constraint = ET.SubElement(equality, "weld")
constraint.set("body1", body1)
constraint.set("body2", body2)
else:
raise ValueError(f"Unknown constraint type: {constraint_type}")

print(
f"Created {constraint_type} equality constraint between {site1} and {site2}"
)

return mjcf
117 changes: 117 additions & 0 deletions tests/test_loader_equality_constraints_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import xml.etree.ElementTree as ET

import pytest

from mujoco_urdf_loader.loader import (
EqualityConstraintCfg,
URDFtoMuJoCoLoader,
URDFtoMuJoCoLoaderCfg,
)


def _make_empty_mjcf() -> ET.Element:
return ET.fromstring(
"""
<mujoco model="test_model">
<worldbody>
<body name="body_a">
<site name="site_a" pos="0 0 0" quat="1 0 0 0"/>
</body>
<body name="body_b">
<site name="site_b" pos="0.1 0.2 0.3" quat="1 0 0 0"/>
</body>
</worldbody>
</mujoco>
"""
)


def test_add_equality_constraints_none_keeps_model_unchanged():
loader = URDFtoMuJoCoLoader(
_make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[])
)

loader.add_equality_constraints(None)

assert loader.mjcf.find(".//equality") is None


def test_add_equality_constraints_accepts_list_of_dataclasses():
loader = URDFtoMuJoCoLoader(
_make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[])
)

loader.add_equality_constraints(
[EqualityConstraintCfg(site1="site_a", site2="site_b")]
)

connect = loader.mjcf.find(".//equality/connect")
assert connect is not None
assert connect.attrib["site1"] == "site_a"
assert connect.attrib["site2"] == "site_b"


def test_add_equality_constraints_accepts_dict():
loader = URDFtoMuJoCoLoader(
_make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[])
)

loader.add_equality_constraints(
[{"site1": "site_a", "site2": "site_b"}]
)

connect = loader.mjcf.find(".//equality/connect")
assert connect is not None
assert connect.attrib["site1"] == "site_a"
assert connect.attrib["site2"] == "site_b"


def test_add_equality_constraints_multiple():
loader = URDFtoMuJoCoLoader(
_make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[])
)

loader.add_equality_constraints(
[
EqualityConstraintCfg(site1="site_a", site2="site_b"),
EqualityConstraintCfg(site1="site_b", site2="site_a"),
]
)

connects = loader.mjcf.findall(".//equality/connect")
assert len(connects) == 2
assert connects[0].attrib["site1"] == "site_a"
assert connects[1].attrib["site1"] == "site_b"


def test_add_equality_constraints_weld_type():
loader = URDFtoMuJoCoLoader(
_make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[])
)

loader.add_equality_constraints(
[EqualityConstraintCfg(site1="site_a", site2="site_b", constraint_type="weld")]
)

weld = loader.mjcf.find(".//equality/weld")
assert weld is not None
assert weld.attrib["body1"] == "body_a"
assert weld.attrib["body2"] == "body_b"


def test_add_equality_constraints_raises_on_missing_fields():
loader = URDFtoMuJoCoLoader(
_make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[])
)

with pytest.raises(ValueError):
loader.add_equality_constraints([{"site1": "site_a"}])


def test_add_equality_constraints_raises_on_invalid_type():
loader = URDFtoMuJoCoLoader(
_make_empty_mjcf(), URDFtoMuJoCoLoaderCfg(controlled_joints=[])
)

with pytest.raises(TypeError):
loader.add_equality_constraints(["not_a_valid_config"])
Loading