From 3afce7b604ceac289572f681248c2485673753c5 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Wed, 4 Feb 2026 16:37:21 +0000 Subject: [PATCH 1/2] nccd --- mujoco_warp/_src/collision_convex.py | 188 +++++++++++++++------------ mujoco_warp/_src/io.py | 30 +++++ mujoco_warp/_src/types.py | 2 + mujoco_warp/testspeed.py | 3 +- mujoco_warp/viewer.py | 3 +- 5 files changed, 143 insertions(+), 83 deletions(-) diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 7d01e50c4..aa41a23d5 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -197,6 +197,7 @@ def ccd_hfield_kernel_builder( geomtype2: int, gjk_iterations: int, epa_iterations: int, + geomgeomid: int, ): """Kernel builder for heightfield CCD collisions (no multiccd args).""" @@ -246,6 +247,7 @@ def ccd_hfield_kernel( pair_friction: wp.array2d(dtype=vec5), # Data in: naconmax_in: int, + naccdmax_in: int, geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), collision_pair_in: wp.array(dtype=wp.vec2i), @@ -264,6 +266,7 @@ def ccd_hfield_kernel( epa_index_in: wp.array2d(dtype=int), epa_map_in: wp.array2d(dtype=int), epa_horizon_in: wp.array2d(dtype=int), + nccd_in: wp.array(dtype=int), # Data out: nacon_out: wp.array(dtype=int), contact_dist_out: wp.array(dtype=float), @@ -280,18 +283,18 @@ def ccd_hfield_kernel( contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), ): - tid = wp.tid() - if tid >= ncollision_in[0]: + collisionid = wp.tid() + if collisionid >= ncollision_in[0]: return - geoms = collision_pair_in[tid] + geoms = collision_pair_in[collisionid] g1 = geoms[0] g2 = geoms[1] if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2: return - worldid = collision_worldid_in[tid] + worldid = collision_worldid_in[collisionid] # height field filter no_hf_collision, xmin, xmax, ymin, ymax, zmin, zmax = _hfield_filter( @@ -300,6 +303,11 @@ def ccd_hfield_kernel( if no_hf_collision: return + ccdid = wp.atomic_add(nccd_in, wp.static(geomgeomid), 1) + if ccdid >= naccdmax_in: + wp.printf("CCD overflow - please increase naccdmax to %u\n", ccdid) + return + _, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( geom_condim, geom_priority, @@ -318,7 +326,7 @@ def ccd_hfield_kernel( pair_friction, collision_pair_in, collision_pairid_in, - tid, + collisionid, worldid, ) @@ -388,19 +396,19 @@ def ccd_hfield_kernel( geom2.margin = margin # EPA memory - epa_vert = epa_vert_in[tid] - epa_vert1 = epa_vert1_in[tid] - epa_vert2 = epa_vert2_in[tid] - epa_vert_index1 = epa_vert_index1_in[tid] - epa_vert_index2 = epa_vert_index2_in[tid] - epa_face = epa_face_in[tid] - epa_pr = epa_pr_in[tid] - epa_norm2 = epa_norm2_in[tid] - epa_index = epa_index_in[tid] - epa_map = epa_map_in[tid] - epa_horizon = epa_horizon_in[tid] - - collision_pairid = collision_pairid_in[tid] + epa_vert = epa_vert_in[ccdid] + epa_vert1 = epa_vert1_in[ccdid] + epa_vert2 = epa_vert2_in[ccdid] + epa_vert_index1 = epa_vert_index1_in[ccdid] + epa_vert_index2 = epa_vert_index2_in[ccdid] + epa_face = epa_face_in[ccdid] + epa_pr = epa_pr_in[ccdid] + epa_norm2 = epa_norm2_in[ccdid] + epa_index = epa_index_in[ccdid] + epa_map = epa_map_in[ccdid] + epa_horizon = epa_horizon_in[ccdid] + + collision_pairid = collision_pairid_in[collisionid] # process all prisms in subgrid count = int(0) @@ -709,6 +717,7 @@ def ccd_kernel_builder( gjk_iterations: int, epa_iterations: int, use_multiccd: bool, + geomgeomid: int, ): """Kernel builder for non-heightfield CCD collisions (no hfield args).""" @@ -746,7 +755,7 @@ def eval_ccd_write_contact( geom2: Geom, geoms: wp.vec2i, worldid: int, - tid: int, + ccdid: int, margin: float, gap: float, condim: int, @@ -795,17 +804,17 @@ def eval_ccd_write_contact( geomtype2, x1, x2, - epa_vert_in[tid], - epa_vert1_in[tid], - epa_vert2_in[tid], - epa_vert_index1_in[tid], - epa_vert_index2_in[tid], - epa_face_in[tid], - epa_pr_in[tid], - epa_norm2_in[tid], - epa_index_in[tid], - epa_map_in[tid], - epa_horizon_in[tid], + epa_vert_in[ccdid], + epa_vert1_in[ccdid], + epa_vert2_in[ccdid], + epa_vert_index1_in[ccdid], + epa_vert_index2_in[ccdid], + epa_face_in[ccdid], + epa_pr_in[ccdid], + epa_norm2_in[ccdid], + epa_index_in[ccdid], + epa_map_in[ccdid], + epa_horizon_in[ccdid], ) if dist >= 0.0 and pairid[1] == -1: @@ -822,22 +831,22 @@ def eval_ccd_write_contact( and (geomtype2 == GeomType.BOX or (geomtype2 == GeomType.MESH and geom2.mesh_polyadr > -1)) ): ncontact, witness1, witness2 = multicontact( - multiccd_polygon_in[tid], - multiccd_clipped_in[tid], - multiccd_pnormal_in[tid], - multiccd_pdist_in[tid], - multiccd_idx1_in[tid], - multiccd_idx2_in[tid], - multiccd_n1_in[tid], - multiccd_n2_in[tid], - multiccd_endvert_in[tid], - multiccd_face1_in[tid], - multiccd_face2_in[tid], - epa_vert1_in[tid], - epa_vert2_in[tid], - epa_vert_index1_in[tid], - epa_vert_index2_in[tid], - epa_face_in[tid, idx], + multiccd_polygon_in[ccdid], + multiccd_clipped_in[ccdid], + multiccd_pnormal_in[ccdid], + multiccd_pdist_in[ccdid], + multiccd_idx1_in[ccdid], + multiccd_idx2_in[ccdid], + multiccd_n1_in[ccdid], + multiccd_n2_in[ccdid], + multiccd_endvert_in[ccdid], + multiccd_face1_in[ccdid], + multiccd_face2_in[ccdid], + epa_vert1_in[ccdid], + epa_vert2_in[ccdid], + epa_vert_index1_in[ccdid], + epa_vert_index2_in[ccdid], + epa_face_in[ccdid, idx], w1, w2, geom1, @@ -932,6 +941,7 @@ def ccd_kernel( pair_friction: wp.array2d(dtype=vec5), # Data in: naconmax_in: int, + naccdmax_in: int, geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), collision_pair_in: wp.array(dtype=wp.vec2i), @@ -961,6 +971,7 @@ def ccd_kernel( multiccd_endvert_in: wp.array2d(dtype=wp.vec3), multiccd_face1_in: wp.array2d(dtype=wp.vec3), multiccd_face2_in: wp.array2d(dtype=wp.vec3), + nccd_in: wp.array(dtype=int), # Data out: nacon_out: wp.array(dtype=int), contact_dist_out: wp.array(dtype=float), @@ -977,18 +988,23 @@ def ccd_kernel( contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), ): - tid = wp.tid() - if tid >= ncollision_in[0]: + collisionid = wp.tid() + if collisionid >= ncollision_in[0]: return - geoms = collision_pair_in[tid] + geoms = collision_pair_in[collisionid] g1 = geoms[0] g2 = geoms[1] if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2: return - worldid = collision_worldid_in[tid] + ccdid = wp.atomic_add(nccd_in, wp.static(geomgeomid), 1) + if ccdid >= naccdmax_in: + wp.printf("CCD overflow - please increase naccdmax to %u\n", ccdid) + return + + worldid = collision_worldid_in[collisionid] _, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( geom_condim, @@ -1008,7 +1024,7 @@ def ccd_kernel( pair_friction, collision_pair_in, collision_pairid_in, - tid, + collisionid, worldid, ) @@ -1066,7 +1082,7 @@ def ccd_kernel( geom2, geoms, worldid, - tid, + ccdid, margin, gap, condim, @@ -1077,7 +1093,7 @@ def ccd_kernel( geom1.pos, geom2.pos, 0, - collision_pairid_in[tid], + collision_pairid_in[collisionid], contact_dist_out, contact_pos_out, contact_frame_out, @@ -1144,10 +1160,11 @@ def convex_narrowphase(m: Model, d: Data): """ def _pair_count(p1: int, p2: int) -> int: - return m.geom_pair_type_count[upper_trid_index(len(GeomType), p1, p2)] + idx = upper_trid_index(len(GeomType), p1, p2) + return m.geom_pair_type_count[idx], idx # no convex collisions, early return - if not any(_pair_count(g[0].value, g[1].value) for g in _CONVEX_COLLISION_PAIRS): + if not any(_pair_count(g[0].value, g[1].value)[0] for g in _CONVEX_COLLISION_PAIRS): return epa_iterations = m.opt.ccd_iterations @@ -1157,28 +1174,31 @@ def _pair_count(p1: int, p2: int) -> int: nmaxpolygon = m.nmaxpolygon if use_multiccd else 0 nmaxmeshdeg = m.nmaxmeshdeg if use_multiccd else 0 + # ccd collider count + nccd = wp.zeros(len(GeomType) * (len(GeomType) + 1) // 2, dtype=int) + # epa_vert: vertices in EPA polytope in Minkowski space - epa_vert = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3) + epa_vert = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3) # epa_vert1: vertices in EPA polytope in geom 1 space - epa_vert1 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3) + epa_vert1 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3) # epa_vert2: vertices in EPA polytope in geom 2 space - epa_vert2 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3) + epa_vert2 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3) # epa_vert_index1: vertex indices in EPA polytope for geom 1 - epa_vert_index1 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=int) + epa_vert_index1 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=int) # epa_vert_index2: vertex indices in EPA polytope for geom 2 (naconmax, 5 + CCDiter) - epa_vert_index2 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=int) + epa_vert_index2 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=int) # epa_face: faces of polytope represented by three indices - epa_face = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3i) + epa_face = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3i) # epa_pr: projection of origin on polytope faces - epa_pr = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3) + epa_pr = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3) # epa_norm2: epa_pr * epa_pr - epa_norm2 = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=float) + epa_norm2 = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=float) # epa_index: index of face in polytope map - epa_index = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int) + epa_index = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int) # epa_map: status of faces in polytope - epa_map = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int) + epa_map = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int) # epa_horizon: index pair (i j) of edges on horizon - epa_horizon = wp.empty(shape=(d.naconmax, 2 * MJ_MAX_EPAHORIZON), dtype=int) + epa_horizon = wp.empty(shape=(d.naccdmax, 2 * MJ_MAX_EPAHORIZON), dtype=int) # Contact outputs contact_outputs = [ @@ -1202,9 +1222,10 @@ def _pair_count(p1: int, p2: int) -> int: for geom_pair in _HFIELD_COLLISION_PAIRS: g1 = geom_pair[0].value g2 = geom_pair[1].value - if _pair_count(g1, g2): + count, geomgeomid = _pair_count(g1, g2) + if count: wp.launch( - ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations), + ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, geomgeomid), dim=d.naconmax, inputs=[ m.opt.ccd_tolerance, @@ -1248,6 +1269,7 @@ def _pair_count(p1: int, p2: int) -> int: m.pair_gap, m.pair_friction, d.naconmax, + d.naccdmax, d.geom_xpos, d.geom_xmat, d.collision_pair, @@ -1265,41 +1287,43 @@ def _pair_count(p1: int, p2: int) -> int: epa_index, epa_map, epa_horizon, + nccd, ], outputs=contact_outputs, ) # Allocate multiccd arrays only for non-heightfield collisions # multiccd_polygon: clipped contact surface - multiccd_polygon = wp.empty(shape=(d.naconmax, 2 * nmaxpolygon), dtype=wp.vec3) + multiccd_polygon = wp.empty(shape=(d.naccdmax, 2 * nmaxpolygon), dtype=wp.vec3) # multiccd_clipped: clipped contact surface (intermediate) - multiccd_clipped = wp.empty(shape=(d.naconmax, 2 * nmaxpolygon), dtype=wp.vec3) + multiccd_clipped = wp.empty(shape=(d.naccdmax, 2 * nmaxpolygon), dtype=wp.vec3) # multiccd_pnormal: plane normal of clipping polygon - multiccd_pnormal = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3) + multiccd_pnormal = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3) # multiccd_pdist: plane distance of clipping polygon - multiccd_pdist = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=float) + multiccd_pdist = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=float) # multiccd_idx1: list of normal index candidates for Geom 1 - multiccd_idx1 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=int) + multiccd_idx1 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=int) # multiccd_idx2: list of normal index candidates for Geom 2 - multiccd_idx2 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=int) + multiccd_idx2 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=int) # multiccd_n1: list of normal candidates for Geom 1 - multiccd_n1 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3) + multiccd_n1 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3) # multiccd_n2: list of normal candidates for Geom 1 - multiccd_n2 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3) + multiccd_n2 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3) # multiccd_endvert: list of edge vertices candidates - multiccd_endvert = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3) + multiccd_endvert = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3) # multiccd_face1: contact face - multiccd_face1 = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3) + multiccd_face1 = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3) # multiccd_face2: contact face - multiccd_face2 = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3) + multiccd_face2 = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3) # Launch non-heightfield collision kernels (no hfield args, 78 args total) for geom_pair in _NON_HFIELD_COLLISION_PAIRS: g1 = geom_pair[0].value g2 = geom_pair[1].value - if _pair_count(g1, g2): + count, geomgeomid = _pair_count(g1, g2) + if count: wp.launch( - ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd), + ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd, geomgeomid), dim=d.naconmax, inputs=[ m.opt.ccd_tolerance, @@ -1336,6 +1360,7 @@ def _pair_count(p1: int, p2: int) -> int: m.pair_gap, m.pair_friction, d.naconmax, + d.naccdmax, d.geom_xpos, d.geom_xmat, d.collision_pair, @@ -1364,6 +1389,7 @@ def _pair_count(p1: int, p2: int) -> int: multiccd_endvert, multiccd_face1, multiccd_face2, + nccd, ], outputs=contact_outputs, ) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index e2ed2d238..ae51076ec 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -590,8 +590,10 @@ def make_data( mjm: mujoco.MjModel, nworld: int = 1, nconmax: Optional[int] = None, + nccdmax: Optional[int] = None, njmax: Optional[int] = None, naconmax: Optional[int] = None, + naccdmax: Optional[int] = None, ) -> types.Data: """Creates a data object on device. @@ -600,9 +602,11 @@ def make_data( nworld: Number of worlds. nconmax: Number of contacts to allocate per world. Contacts exist in large heterogeneous arrays: one world may have more than nconmax contacts. + nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax. njmax: Number of constraints to allocate per world. Constraint arrays are batched by world: no world may have more than njmax constraints. naconmax: Number of contacts to allocate for all worlds. Overrides nconmax. + naccdmax: Maximum number of CCD contacts. Defaults to naconmax. Returns: The data object containing the current state and output arrays (device). @@ -611,6 +615,9 @@ def make_data( if nconmax is None: nconmax = _default_nconmax(mjm) + if nccdmax is None: + nccdmax = nconmax + if njmax is None: njmax = _default_njmax(mjm) @@ -624,6 +631,13 @@ def make_data( elif naconmax < 0: raise ValueError("naconmax must be >= 0") + if naccdmax is None: + if nccdmax < 0: + raise ValueError("nccdmax must be >= 0") + naccdmax = nworld * nccdmax + elif naccdmax < 0: + raise ValueError("naccdmax must be >= 0") + if njmax < 0: raise ValueError("njmax must be >= 0") @@ -655,6 +669,7 @@ def make_data( "efc": efc, "nworld": nworld, "naconmax": naconmax, + "naccdmax": naccdmax, "njmax": njmax, "qM": None, "qLD": None, @@ -695,8 +710,10 @@ def put_data( mjd: mujoco.MjData, nworld: int = 1, nconmax: Optional[int] = None, + nccdmax: Optional[int] = None, njmax: Optional[int] = None, naconmax: Optional[int] = None, + naccdmax: Optional[int] = None, ) -> types.Data: """Moves data from host to a device. @@ -706,9 +723,11 @@ def put_data( nworld: The number of worlds. nconmax: Number of contacts to allocate per world. Contacts exist in large heterogenous arrays: one world may have more than nconmax contacts. + nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax. njmax: Number of constraints to allocate per world. Constraint arrays are batched by world: no world may have more than njmax constraints. naconmax: Number of contacts to allocate for all worlds. Overrides nconmax. + naccdmax: Maximum number of CCD contacts. Defaults to naconmax. Returns: The data object containing the current state and output arrays (device). @@ -720,6 +739,9 @@ def put_data( if nconmax is None: nconmax = _default_nconmax(mjm, mjd) + if nccdmax is None: + nccdmax = nconmax + if njmax is None: njmax = _default_njmax(mjm, mjd) @@ -735,6 +757,13 @@ def put_data( elif naconmax < mjd.ncon * nworld: raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})") + if naccdmax is None: + if nccdmax < 0: + raise ValueError("nccdmax must be >= 0") + naccdmax = nworld * nccdmax + elif naccdmax < 0: + raise ValueError("naccdmax must be >= 0") + if njmax < 0: raise ValueError("njmax must be >= 0") @@ -812,6 +841,7 @@ def put_data( "efc": efc, "nworld": nworld, "naconmax": naconmax, + "naccdmax": naccdmax, "njmax": njmax, # fields set after initialization: "solver_niter": None, diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 6ff16d703..036478253 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -1634,6 +1634,7 @@ class Data: warp only fields: nworld: number of worlds naconmax: maximum number of contacts (shared across all worlds) + naccdmax: Maximum number of CCD contacts for any collider (all worlds) njmax: maximum number of constraints per world nacon: number of detected contacts (across all worlds) (1,) ne_connect: number of equality connect constraints (nworld,) @@ -1730,6 +1731,7 @@ class Data: # warp only fields: nworld: int naconmax: int + naccdmax: int njmax: int nacon: array(1, int) ne_connect: array("nworld", int) diff --git a/mujoco_warp/testspeed.py b/mujoco_warp/testspeed.py index 53dd9cf3b..23e269b5a 100644 --- a/mujoco_warp/testspeed.py +++ b/mujoco_warp/testspeed.py @@ -48,6 +48,7 @@ _NWORLD = flags.DEFINE_integer("nworld", 8192, "number of parallel rollouts") _NCONMAX = flags.DEFINE_integer("nconmax", None, "override maximum number of contacts for all worlds") _NJMAX = flags.DEFINE_integer("njmax", None, "override maximum number of constraints per world") +_NCCDMAX = flags.DEFINE_integer("nccdmax", None, "override maximum number of CCD contacts per world") _OVERRIDE = flags.DEFINE_multi_string("override", [], "Model overrides (notation: foo.bar = baz)", short_name="o") _KEYFRAME = flags.DEFINE_integer("keyframe", 0, "keyframe to initialize simulation.") _CLEAR_KERNEL_CACHE = flags.DEFINE_bool("clear_kernel_cache", False, "clear kernel cache (to calculate full JIT time)") @@ -176,7 +177,7 @@ def _main(argv: Sequence[str]): f" solver: {solver} cone: {cone} iterations: {iterations} {ls_str}\n" f" integrator: {integrator} graph_conditional: {m.opt.graph_conditional}" ) - d = mjw.put_data(mjm, mjd, nworld=_NWORLD.value, nconmax=_NCONMAX.value, njmax=_NJMAX.value) + d = mjw.put_data(mjm, mjd, nworld=_NWORLD.value, nconmax=_NCONMAX.value, njmax=_NJMAX.value, nccdmax=_NCCDMAX.value) print(f"Data\n nworld: {d.nworld} naconmax: {d.naconmax} njmax: {d.njmax}\n") print(f"Rolling out {_NSTEP.value} steps at dt = {m.opt.timestep.numpy()[0]:.3f}...") diff --git a/mujoco_warp/viewer.py b/mujoco_warp/viewer.py index cd8680498..13d516a99 100644 --- a/mujoco_warp/viewer.py +++ b/mujoco_warp/viewer.py @@ -55,6 +55,7 @@ class EngineOptions(enum.IntEnum): _ENGINE = flags.DEFINE_enum_class("engine", EngineOptions.WARP, EngineOptions, "Simulation engine") _NCONMAX = flags.DEFINE_integer("nconmax", None, "Maximum number of contacts.") _NJMAX = flags.DEFINE_integer("njmax", None, "Maximum number of constraints per world.") +_NCCDMAX = flags.DEFINE_integer("nccdmax", None, "Maximum number of CCD contacts per world.") _OVERRIDE = flags.DEFINE_multi_string("override", [], "Model overrides (notation: foo.bar = baz)", short_name="o") _KEYFRAME = flags.DEFINE_integer("keyframe", 0, "keyframe to initialize simulation.") _DEVICE = flags.DEFINE_string("device", None, "override the default Warp device") @@ -152,7 +153,7 @@ def _main(argv: Sequence[str]) -> None: f" solver: {solver} cone: {cone} iterations: {iterations} {ls_str}\n" f" integrator: {integrator} graph_conditional: {m.opt.graph_conditional}" ) - d = mjw.put_data(mjm, mjd, nconmax=_NCONMAX.value, njmax=_NJMAX.value) + d = mjw.put_data(mjm, mjd, nconmax=_NCONMAX.value, njmax=_NJMAX.value, nccdmax=_NCCDMAX.value) print(f"Data\n nworld: {d.nworld} nconmax: {d.naconmax / d.nworld} njmax: {d.njmax}\n") graph = _compile_step(m, d) print(f"MuJoCo Warp simulating with dt = {m.opt.timestep.numpy()[0]:.3f}...") From 2693c8eee35fae506415b129e74648697633a63e Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Wed, 11 Feb 2026 12:55:46 +0000 Subject: [PATCH 2/2] address erikfrey pr comments --- mujoco_warp/_src/collision_convex.py | 8 +++--- mujoco_warp/_src/io.py | 38 ++++++++++++++++++---------- mujoco_warp/_src/io_test.py | 22 ++++++++++++++++ mujoco_warp/_src/types.py | 2 +- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 9e967fb22..781b7aa9e 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -1102,7 +1102,7 @@ def convex_narrowphase(m: Model, d: Data, ctx: CollisionContext, collision_table computations for non-existent pair types. """ - def _pair_count(p1: int, p2: int) -> int: + def _pair_count(p1: int, p2: int) -> Tuple[int, int]: idx = upper_trid_index(len(GeomType), p1, p2) return m.geom_pair_type_count[idx], idx @@ -1113,11 +1113,11 @@ def _pair_count(p1: int, p2: int) -> int: return # compute nmaxpolygon and nmaxmeshdeg given the geom pairs for the model - nboxbox = _pair_count(GeomType.BOX.value, GeomType.BOX.value)[0] + nboxbox, _ = _pair_count(GeomType.BOX.value, GeomType.BOX.value) if (GeomType.BOX, GeomType.BOX) not in collision_table: nboxbox = 0 - nboxmesh = _pair_count(GeomType.BOX.value, GeomType.MESH.value)[0] - nmeshmesh = _pair_count(GeomType.MESH.value, GeomType.MESH.value)[0] + nboxmesh, _ = _pair_count(GeomType.BOX.value, GeomType.MESH.value) + nmeshmesh, _ = _pair_count(GeomType.MESH.value, GeomType.MESH.value) epa_iterations = 16 if nboxbox == ncollision else m.opt.ccd_iterations diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index c00064e6f..383599eee 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -681,31 +681,36 @@ def make_data( if nconmax is None: nconmax = _default_nconmax(mjm) + if nconmax < 0: + raise ValueError("nconmax must be >= 0") + if nccdmax is None: nccdmax = nconmax + elif nccdmax < 0: + raise ValueError("nccdmax must be >= 0") + elif nccdmax > nconmax: + raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})") if njmax is None: njmax = _default_njmax(mjm) + if njmax < 0: + raise ValueError("njmax must be >= 0") + if nworld < 1: raise ValueError(f"nworld must be >= 1") if naconmax is None: - if nconmax < 0: - raise ValueError("nconmax must be >= 0") naconmax = nworld * nconmax elif naconmax < 0: raise ValueError("naconmax must be >= 0") if naccdmax is None: - if nccdmax < 0: - raise ValueError("nccdmax must be >= 0") naccdmax = nworld * nccdmax elif naccdmax < 0: raise ValueError("naccdmax must be >= 0") - - if njmax < 0: - raise ValueError("njmax must be >= 0") + elif naccdmax > naconmax: + raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})") sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int}) sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max() @@ -809,18 +814,26 @@ def put_data( if nconmax is None: nconmax = _default_nconmax(mjm, mjd) + if nconmax < 0: + raise ValueError("nconmax must be >= 0") + if nccdmax is None: nccdmax = nconmax + elif nccdmax < 0: + raise ValueError("nccdmax must be >= 0") + elif nccdmax > nconmax: + raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})") if njmax is None: njmax = _default_njmax(mjm, mjd) + if njmax < 0: + raise ValueError("njmax must be >= 0") + if nworld < 1: raise ValueError(f"nworld must be >= 1") if naconmax is None: - if nconmax < 0: - raise ValueError("nconmax must be >= 0") if mjd.ncon > nconmax: raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})") naconmax = nworld * nconmax @@ -828,14 +841,11 @@ def put_data( raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})") if naccdmax is None: - if nccdmax < 0: - raise ValueError("nccdmax must be >= 0") naccdmax = nworld * nccdmax elif naccdmax < 0: raise ValueError("naccdmax must be >= 0") - - if njmax < 0: - raise ValueError("njmax must be >= 0") + elif naccdmax > naconmax: + raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})") if mjd.nefc > njmax: raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})") diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index 4ecf590e5..fc5c85efe 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -438,6 +438,28 @@ def test_static_geom_collision_with_put_data(self): box_z = d.xpos.numpy()[0, 1, 2] # world 0, body 1 (box), z coordinate self.assertGreater(box_z, 0.4, msg=f"Box fell through ground plane (z={box_z}, should be > 0.4)") + def test_make_data_nccdmax_exceeds_nconmax(self): + mjm = mujoco.MjModel.from_xml_string("") + with self.assertRaises(ValueError, msg="nccdmax.*nconmax"): + mjwarp.make_data(mjm, nconmax=16, nccdmax=17) + + def test_make_data_naccdmax_exceeds_naconmax(self): + mjm = mujoco.MjModel.from_xml_string("") + with self.assertRaises(ValueError, msg="naccdmax.*naconmax"): + mjwarp.make_data(mjm, nconmax=16, naconmax=16, naccdmax=17) + + def test_put_data_nccdmax_exceeds_nconmax(self): + mjm = mujoco.MjModel.from_xml_string("") + mjd = mujoco.MjData(mjm) + with self.assertRaises(ValueError, msg="nccdmax.*nconmax"): + mjwarp.put_data(mjm, mjd, nconmax=16, nccdmax=17) + + def test_put_data_naccdmax_exceeds_naconmax(self): + mjm = mujoco.MjModel.from_xml_string("") + mjd = mujoco.MjData(mjm) + with self.assertRaises(ValueError, msg="naccdmax.*naconmax"): + mjwarp.put_data(mjm, mjd, nconmax=16, naconmax=16, naccdmax=17) + def test_noslip_solver(self): with self.assertRaises(NotImplementedError): test_data.fixture( diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 580c279c4..8d50fc8f8 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -1663,7 +1663,7 @@ class Data: warp only fields: nworld: number of worlds naconmax: maximum number of contacts (shared across all worlds) - naccdmax: Maximum number of CCD contacts for any collider (all worlds) + naccdmax: maximum number of contacts for CCD (all worlds) njmax: maximum number of constraints per world nacon: number of detected contacts (across all worlds) (1,) ncollision: collision count from broadphase (1,)