Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 171 additions & 6 deletions mujoco_warp/_src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,167 @@ def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) ->
return int(valid_sizes[np.searchsorted(valid_sizes, njmax)])


def _body_pair_nnz(mjm: mujoco.MjModel, body1: int, body2: int) -> int:
"""Returns the number of unique DOFs in the kinematic tree union of two bodies."""
body1 = mjm.body_weldid[body1]
body2 = mjm.body_weldid[body2]
da1 = mjm.body_dofadr[body1] + mjm.body_dofnum[body1] - 1
da2 = mjm.body_dofadr[body2] + mjm.body_dofnum[body2] - 1
nnz = 0
while da1 >= 0 or da2 >= 0:
da = max(da1, da2)
if da1 == da:
da1 = mjm.dof_parentid[da1]
if da2 == da:
da2 = mjm.dof_parentid[da2]
nnz += 1
return nnz


def _default_njmax_nnz(mjm: mujoco.MjModel, nconmax: int, njmax: int) -> int:
"""Returns a heuristic estimate for the number of non-zeros in the sparse constraint Jacobian.

Assumes all equality, friction, and limit constraints are active and computes
their non-zeros. For contacts, assumes njmax contact rows at the maximum
body-pair non-zeros from all enabled collision pairs.

Args:
mjm: The model containing kinematic and dynamic information (host).
nconmax: Maximum number of contacts per world.
njmax: Maximum number of constraint rows per world.

Returns:
Estimated number of non-zeros in the constraint Jacobian.
"""
total_nnz = 0

def _eq_bodies(i):
"""Returns body pair for equality constraint i."""
obj1id, obj2id = mjm.eq_obj1id[i], mjm.eq_obj2id[i]
if mjm.eq_objtype[i] == mujoco.mjtObj.mjOBJ_SITE:
return mjm.site_bodyid[obj1id], mjm.site_bodyid[obj2id]
return obj1id, obj2id

# equality constraints (assume all active)
for i in range(mjm.neq):
eq_type = mjm.eq_type[i]

if eq_type == mujoco.mjtEq.mjEQ_CONNECT:
total_nnz += 3 * _body_pair_nnz(mjm, *_eq_bodies(i))

elif eq_type == mujoco.mjtEq.mjEQ_WELD:
total_nnz += 6 * _body_pair_nnz(mjm, *_eq_bodies(i))

elif eq_type == mujoco.mjtEq.mjEQ_JOINT:
total_nnz += 2 if mjm.eq_obj2id[i] >= 0 else 1

elif eq_type == mujoco.mjtEq.mjEQ_TENDON:
obj1id = mjm.eq_obj1id[i]
obj2id = mjm.eq_obj2id[i]
rownnz1 = mjm.ten_J_rownnz[obj1id] if obj1id < mjm.ntendon else 0
if obj2id >= 0 and obj2id < mjm.ntendon:
rowadr1 = mjm.ten_J_rowadr[obj1id]
rowadr2 = mjm.ten_J_rowadr[obj2id]
rownnz2 = mjm.ten_J_rownnz[obj2id]
cols = set()
for j in range(rownnz1):
cols.add(mjm.ten_J_colind[rowadr1 + j])
for j in range(rownnz2):
cols.add(mjm.ten_J_colind[rowadr2 + j])
total_nnz += len(cols)
else:
total_nnz += rownnz1

elif eq_type == mujoco.mjtEq.mjEQ_FLEX:
obj1id = mjm.eq_obj1id[i]
if obj1id < mjm.nflex:
edge_start = mjm.flex_edgeadr[obj1id]
edge_count = mjm.flex_edgenum[obj1id]
for e in range(edge_count):
total_nnz += mjm.flexedge_J_rownnz[edge_start + e]

# friction constraints
total_nnz += (mjm.dof_frictionloss > 0).sum()
for i in range(mjm.ntendon):
if mjm.tendon_frictionloss[i] > 0:
total_nnz += mjm.ten_J_rownnz[i]

# limit constraints (assume all active)
for i in range(mjm.njnt):
if mjm.jnt_limited[i]:
jnt_type = mjm.jnt_type[i]
if jnt_type == mujoco.mjtJoint.mjJNT_BALL:
total_nnz += 3
elif jnt_type in (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE):
total_nnz += 1
for i in range(mjm.ntendon):
if mjm.tendon_limited[i]:
total_nnz += mjm.ten_J_rownnz[i]

# contact constraints: njmax rows at max body-pair non-zeros
max_contact_nnz = 0

# contact pairs
for i in range(mjm.npair):
g1, g2 = mjm.pair_geom1[i], mjm.pair_geom2[i]
b1, b2 = mjm.geom_bodyid[g1], mjm.geom_bodyid[g2]
max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, b1, b2))

# filter geom-geom pairs (unique body pairs, filtered)
body_pair_seen = set()
for i in range(mjm.ngeom):
bi = mjm.geom_bodyid[i]
cti, cai = mjm.geom_contype[i], mjm.geom_conaffinity[i]
for j in range(i + 1, mjm.ngeom):
bj = mjm.geom_bodyid[j]
if bi == bj:
continue
if mjm.body_weldid[bi] == 0 and mjm.body_weldid[bj] == 0:
continue
bp = (min(bi, bj), max(bi, bj))
if bp in body_pair_seen:
continue
ctj, caj = mjm.geom_contype[j], mjm.geom_conaffinity[j]
if not ((cti & caj) or (ctj & cai)):
continue
body_pair_seen.add(bp)
max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, bi, bj))

# flex vertex contacts
for fi in range(mjm.nflex):
fct = mjm.flex_contype[fi]
fca = mjm.flex_conaffinity[fi]

vert_start = mjm.flex_vertadr[fi]
vert_count = mjm.flex_vertnum[fi]
flex_bodies = {mjm.flex_vertbodyid[vert_start + v] for v in range(vert_count)}

geom_bodies = set()
for g in range(mjm.ngeom):
ct, ca = mjm.geom_contype[g], mjm.geom_conaffinity[g]
if (fct & ca) or (ct & fca):
geom_bodies.add(mjm.geom_bodyid[g])

for fb in flex_bodies:
for gb in geom_bodies:
if fb != gb:
max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, fb, gb))

# flex self-collision
if mjm.flex_selfcollide[fi]:
flex_body_list = sorted(flex_bodies)
for idx1 in range(len(flex_body_list)):
for idx2 in range(idx1 + 1, len(flex_body_list)):
max_contact_nnz = max(
max_contact_nnz,
_body_pair_nnz(mjm, flex_body_list[idx1], flex_body_list[idx2]),
)

total_nnz += njmax * max_contact_nnz

return int(min(max(total_nnz, 1), njmax * mjm.nv))


def _resolve_batch_size(na: int | None, n: int | None, nworld: int, default: int) -> int:
if na is not None:
return na
Expand Down Expand Up @@ -747,9 +908,11 @@ def make_data(
sizes["naconmax"] = naconmax
sizes["njmax"] = njmax

# TODO(team): heuristic for constraint Jacobian number of non-zeros
if njmax_nnz is None or not is_sparse(mjm):
njmax_nnz = njmax * mjm.nv
if njmax_nnz is None:
if is_sparse(mjm):
njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax)
else:
njmax_nnz = njmax * mjm.nv

contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)})
contact.efc_address = wp.array(np.full((naconmax, sizes["nmaxpyramid"]), -1, dtype=int), dtype=int)
Expand Down Expand Up @@ -917,9 +1080,11 @@ def put_data(
sizes["naconmax"] = naconmax
sizes["njmax"] = njmax

# TODO(team): heuristic for constraint Jacobian number of non-zeros
if njmax_nnz is None or not is_sparse(mjm):
njmax_nnz = njmax * mjm.nv
if njmax_nnz is None:
if is_sparse(mjm):
njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax)
else:
njmax_nnz = njmax * mjm.nv

# ensure static geom positions are computed
# TODO: remove once MjData creation semantics are fixed
Expand Down
Loading