From fd6da8c5548ace3f3651999222a69afb773f603a Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Tue, 17 Mar 2026 21:14:16 +0000 Subject: [PATCH] heuristic for estimating the number of non-zeros in constraint_jacobian --- mujoco_warp/_src/io.py | 177 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 171 insertions(+), 6 deletions(-) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 2e95e868f..eea85d43b 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -670,6 +670,167 @@ def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) -> return int(valid_sizes[np.searchsorted(valid_sizes, njmax)]) +def _body_pair_nnz(mjm: mujoco.MjModel, body1: int, body2: int) -> int: + """Returns the number of unique DOFs in the kinematic tree union of two bodies.""" + body1 = mjm.body_weldid[body1] + body2 = mjm.body_weldid[body2] + da1 = mjm.body_dofadr[body1] + mjm.body_dofnum[body1] - 1 + da2 = mjm.body_dofadr[body2] + mjm.body_dofnum[body2] - 1 + nnz = 0 + while da1 >= 0 or da2 >= 0: + da = max(da1, da2) + if da1 == da: + da1 = mjm.dof_parentid[da1] + if da2 == da: + da2 = mjm.dof_parentid[da2] + nnz += 1 + return nnz + + +def _default_njmax_nnz(mjm: mujoco.MjModel, nconmax: int, njmax: int) -> int: + """Returns a heuristic estimate for the number of non-zeros in the sparse constraint Jacobian. + + Assumes all equality, friction, and limit constraints are active and computes + their non-zeros. For contacts, assumes njmax contact rows at the maximum + body-pair non-zeros from all enabled collision pairs. + + Args: + mjm: The model containing kinematic and dynamic information (host). + nconmax: Maximum number of contacts per world. + njmax: Maximum number of constraint rows per world. + + Returns: + Estimated number of non-zeros in the constraint Jacobian. + """ + total_nnz = 0 + + def _eq_bodies(i): + """Returns body pair for equality constraint i.""" + obj1id, obj2id = mjm.eq_obj1id[i], mjm.eq_obj2id[i] + if mjm.eq_objtype[i] == mujoco.mjtObj.mjOBJ_SITE: + return mjm.site_bodyid[obj1id], mjm.site_bodyid[obj2id] + return obj1id, obj2id + + # equality constraints (assume all active) + for i in range(mjm.neq): + eq_type = mjm.eq_type[i] + + if eq_type == mujoco.mjtEq.mjEQ_CONNECT: + total_nnz += 3 * _body_pair_nnz(mjm, *_eq_bodies(i)) + + elif eq_type == mujoco.mjtEq.mjEQ_WELD: + total_nnz += 6 * _body_pair_nnz(mjm, *_eq_bodies(i)) + + elif eq_type == mujoco.mjtEq.mjEQ_JOINT: + total_nnz += 2 if mjm.eq_obj2id[i] >= 0 else 1 + + elif eq_type == mujoco.mjtEq.mjEQ_TENDON: + obj1id = mjm.eq_obj1id[i] + obj2id = mjm.eq_obj2id[i] + rownnz1 = mjm.ten_J_rownnz[obj1id] if obj1id < mjm.ntendon else 0 + if obj2id >= 0 and obj2id < mjm.ntendon: + rowadr1 = mjm.ten_J_rowadr[obj1id] + rowadr2 = mjm.ten_J_rowadr[obj2id] + rownnz2 = mjm.ten_J_rownnz[obj2id] + cols = set() + for j in range(rownnz1): + cols.add(mjm.ten_J_colind[rowadr1 + j]) + for j in range(rownnz2): + cols.add(mjm.ten_J_colind[rowadr2 + j]) + total_nnz += len(cols) + else: + total_nnz += rownnz1 + + elif eq_type == mujoco.mjtEq.mjEQ_FLEX: + obj1id = mjm.eq_obj1id[i] + if obj1id < mjm.nflex: + edge_start = mjm.flex_edgeadr[obj1id] + edge_count = mjm.flex_edgenum[obj1id] + for e in range(edge_count): + total_nnz += mjm.flexedge_J_rownnz[edge_start + e] + + # friction constraints + total_nnz += (mjm.dof_frictionloss > 0).sum() + for i in range(mjm.ntendon): + if mjm.tendon_frictionloss[i] > 0: + total_nnz += mjm.ten_J_rownnz[i] + + # limit constraints (assume all active) + for i in range(mjm.njnt): + if mjm.jnt_limited[i]: + jnt_type = mjm.jnt_type[i] + if jnt_type == mujoco.mjtJoint.mjJNT_BALL: + total_nnz += 3 + elif jnt_type in (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE): + total_nnz += 1 + for i in range(mjm.ntendon): + if mjm.tendon_limited[i]: + total_nnz += mjm.ten_J_rownnz[i] + + # contact constraints: njmax rows at max body-pair non-zeros + max_contact_nnz = 0 + + # contact pairs + for i in range(mjm.npair): + g1, g2 = mjm.pair_geom1[i], mjm.pair_geom2[i] + b1, b2 = mjm.geom_bodyid[g1], mjm.geom_bodyid[g2] + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, b1, b2)) + + # filter geom-geom pairs (unique body pairs, filtered) + body_pair_seen = set() + for i in range(mjm.ngeom): + bi = mjm.geom_bodyid[i] + cti, cai = mjm.geom_contype[i], mjm.geom_conaffinity[i] + for j in range(i + 1, mjm.ngeom): + bj = mjm.geom_bodyid[j] + if bi == bj: + continue + if mjm.body_weldid[bi] == 0 and mjm.body_weldid[bj] == 0: + continue + bp = (min(bi, bj), max(bi, bj)) + if bp in body_pair_seen: + continue + ctj, caj = mjm.geom_contype[j], mjm.geom_conaffinity[j] + if not ((cti & caj) or (ctj & cai)): + continue + body_pair_seen.add(bp) + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, bi, bj)) + + # flex vertex contacts + for fi in range(mjm.nflex): + fct = mjm.flex_contype[fi] + fca = mjm.flex_conaffinity[fi] + + vert_start = mjm.flex_vertadr[fi] + vert_count = mjm.flex_vertnum[fi] + flex_bodies = {mjm.flex_vertbodyid[vert_start + v] for v in range(vert_count)} + + geom_bodies = set() + for g in range(mjm.ngeom): + ct, ca = mjm.geom_contype[g], mjm.geom_conaffinity[g] + if (fct & ca) or (ct & fca): + geom_bodies.add(mjm.geom_bodyid[g]) + + for fb in flex_bodies: + for gb in geom_bodies: + if fb != gb: + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, fb, gb)) + + # flex self-collision + if mjm.flex_selfcollide[fi]: + flex_body_list = sorted(flex_bodies) + for idx1 in range(len(flex_body_list)): + for idx2 in range(idx1 + 1, len(flex_body_list)): + max_contact_nnz = max( + max_contact_nnz, + _body_pair_nnz(mjm, flex_body_list[idx1], flex_body_list[idx2]), + ) + + total_nnz += njmax * max_contact_nnz + + return int(min(max(total_nnz, 1), njmax * mjm.nv)) + + def _resolve_batch_size(na: int | None, n: int | None, nworld: int, default: int) -> int: if na is not None: return na @@ -748,9 +909,11 @@ def make_data( sizes["naconmax"] = naconmax sizes["njmax"] = njmax - # TODO(team): heuristic for constraint Jacobian number of non-zeros - if njmax_nnz is None or not SPARSE_CONSTRAINT_JACOBIAN: - njmax_nnz = njmax * mjm.nv + if njmax_nnz is None: + if SPARSE_CONSTRAINT_JACOBIAN: + njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) + else: + njmax_nnz = njmax * mjm.nv contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)}) contact.efc_address = wp.array(np.full((naconmax, sizes["nmaxpyramid"]), -1, dtype=int), dtype=int) @@ -918,9 +1081,11 @@ def put_data( sizes["naconmax"] = naconmax sizes["njmax"] = njmax - # TODO(team): heuristic for constraint Jacobian number of non-zeros - if njmax_nnz is None or not SPARSE_CONSTRAINT_JACOBIAN: - njmax_nnz = njmax * mjm.nv + if njmax_nnz is None: + if SPARSE_CONSTRAINT_JACOBIAN: + njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) + else: + njmax_nnz = njmax * mjm.nv # ensure static geom positions are computed # TODO: remove once MjData creation semantics are fixed