diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index ebc292774..781b7aa9e 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -155,6 +155,7 @@ def ccd_hfield_kernel_builder( geomtype2: int, gjk_iterations: int, epa_iterations: int, + geomgeomid: int, ): """Kernel builder for heightfield CCD collisions (no multiccd args).""" @@ -205,6 +206,7 @@ def ccd_hfield_kernel( geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), naconmax_in: int, + naccdmax_in: int, ncollision_in: wp.array(dtype=int), # In: collision_pair_in: wp.array(dtype=wp.vec2i), @@ -218,6 +220,7 @@ def ccd_hfield_kernel( epa_pr_in: wp.array2d(dtype=wp.vec3), epa_norm2_in: wp.array2d(dtype=float), epa_horizon_in: wp.array2d(dtype=int), + nccd_in: wp.array(dtype=int), # Data out: contact_dist_out: wp.array(dtype=float), contact_pos_out: wp.array(dtype=wp.vec3), @@ -234,18 +237,18 @@ def ccd_hfield_kernel( contact_geomcollisionid_out: wp.array(dtype=int), nacon_out: wp.array(dtype=int), ): - tid = wp.tid() - if tid >= ncollision_in[0]: + collisionid = wp.tid() + if collisionid >= ncollision_in[0]: return - geoms = collision_pair_in[tid] + geoms = collision_pair_in[collisionid] g1 = geoms[0] g2 = geoms[1] if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2: return - worldid = collision_worldid_in[tid] + worldid = collision_worldid_in[collisionid] # height field filter no_hf_collision, xmin, xmax, ymin, ymax, zmin, zmax = _hfield_filter( @@ -269,6 +272,11 @@ def ccd_hfield_kernel( if no_hf_collision: return + ccdid = wp.atomic_add(nccd_in, wp.static(geomgeomid), 1) + if ccdid >= naccdmax_in: + wp.printf("CCD overflow - please increase naccdmax to %u\n", ccdid) + return + _, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( geom_condim, geom_priority, @@ -287,7 +295,7 @@ def ccd_hfield_kernel( pair_friction, collision_pair_in, collision_pairid_in, - tid, + collisionid, worldid, ) @@ -368,16 +376,16 @@ def ccd_hfield_kernel( geom2.margin = margin # EPA memory - epa_vert1 = epa_vert1_in[tid] - epa_vert2 = epa_vert2_in[tid] - epa_vert_index1 = epa_vert_index1_in[tid] - epa_vert_index2 = epa_vert_index2_in[tid] - epa_face = epa_face_in[tid] - epa_pr = epa_pr_in[tid] - epa_norm2 = epa_norm2_in[tid] - epa_horizon = epa_horizon_in[tid] + epa_vert1 = epa_vert1_in[ccdid] + epa_vert2 = epa_vert2_in[ccdid] + epa_vert_index1 = epa_vert_index1_in[ccdid] + epa_vert_index2 = epa_vert_index2_in[ccdid] + epa_face = epa_face_in[ccdid] + epa_pr = epa_pr_in[ccdid] + epa_norm2 = epa_norm2_in[ccdid] + epa_horizon = epa_horizon_in[ccdid] - collision_pairid = collision_pairid_in[tid] + collision_pairid = collision_pairid_in[collisionid] # process all prisms in subgrid count = int(0) @@ -692,6 +700,7 @@ def ccd_kernel_builder( gjk_iterations: int, epa_iterations: int, use_multiccd: bool, + geomgeomid: int, ): """Kernel builder for non-heightfield CCD collisions (no hfield args).""" @@ -725,7 +734,7 @@ def eval_ccd_write_contact( geom2: Geom, geoms: wp.vec2i, worldid: int, - tid: int, + ccdid: int, margin: float, gap: float, condim: int, @@ -773,14 +782,14 @@ def eval_ccd_write_contact( geomtype2, x1, x2, - epa_vert1_in[tid], - epa_vert2_in[tid], - epa_vert_index1_in[tid], - epa_vert_index2_in[tid], - epa_face_in[tid], - epa_pr_in[tid], - epa_norm2_in[tid], - epa_horizon_in[tid], + epa_vert1_in[ccdid], + epa_vert2_in[ccdid], + epa_vert_index1_in[ccdid], + epa_vert_index2_in[ccdid], + epa_face_in[ccdid], + epa_pr_in[ccdid], + epa_norm2_in[ccdid], + epa_horizon_in[ccdid], ) if dist >= 0.0 and pairid[1] == -1: @@ -802,22 +811,22 @@ def eval_ccd_write_contact( if multiccd_idx > -1: ncollision, witness1, witness2 = multicontact( - multiccd_polygon_in[tid], - multiccd_clipped_in[tid], - multiccd_pnormal_in[tid], - multiccd_pdist_in[tid], - multiccd_idx1_in[tid], - multiccd_idx2_in[tid], - multiccd_n1_in[tid], - multiccd_n2_in[tid], - multiccd_endvert_in[tid], - multiccd_face1_in[tid], - multiccd_face2_in[tid], - epa_vert1_in[tid], - epa_vert2_in[tid], - epa_vert_index1_in[tid], - epa_vert_index2_in[tid], - epa_face_in[tid, multiccd_idx], + multiccd_polygon_in[ccdid], + multiccd_clipped_in[ccdid], + multiccd_pnormal_in[ccdid], + multiccd_pdist_in[ccdid], + multiccd_idx1_in[ccdid], + multiccd_idx2_in[ccdid], + multiccd_n1_in[ccdid], + multiccd_n2_in[ccdid], + multiccd_endvert_in[ccdid], + multiccd_face1_in[ccdid], + multiccd_face2_in[ccdid], + epa_vert1_in[ccdid], + epa_vert2_in[ccdid], + epa_vert_index1_in[ccdid], + epa_vert_index2_in[ccdid], + epa_face_in[ccdid, multiccd_idx], w1, w2, geom1, @@ -914,6 +923,7 @@ def ccd_kernel( geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), naconmax_in: int, + naccdmax_in: int, ncollision_in: wp.array(dtype=int), # In: collision_pair_in: wp.array(dtype=wp.vec2i), @@ -938,6 +948,7 @@ def ccd_kernel( multiccd_endvert_in: wp.array2d(dtype=wp.vec3), multiccd_face1_in: wp.array2d(dtype=wp.vec3), multiccd_face2_in: wp.array2d(dtype=wp.vec3), + nccd_in: wp.array(dtype=int), # Data out: contact_dist_out: wp.array(dtype=float), contact_pos_out: wp.array(dtype=wp.vec3), @@ -954,18 +965,23 @@ def ccd_kernel( contact_geomcollisionid_out: wp.array(dtype=int), nacon_out: wp.array(dtype=int), ): - tid = wp.tid() - if tid >= ncollision_in[0]: + collisionid = wp.tid() + if collisionid >= ncollision_in[0]: return - geoms = collision_pair_in[tid] + geoms = collision_pair_in[collisionid] g1 = geoms[0] g2 = geoms[1] if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2: return - worldid = collision_worldid_in[tid] + ccdid = wp.atomic_add(nccd_in, wp.static(geomgeomid), 1) + if ccdid >= naccdmax_in: + wp.printf("CCD overflow - please increase naccdmax to %u\n", ccdid) + return + + worldid = collision_worldid_in[collisionid] _, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( geom_condim, @@ -985,7 +1001,7 @@ def ccd_kernel( pair_friction, collision_pair_in, collision_pairid_in, - tid, + collisionid, worldid, ) @@ -1039,7 +1055,7 @@ def ccd_kernel( geom2, geoms, worldid, - tid, + ccdid, margin, gap, condim, @@ -1049,7 +1065,7 @@ def ccd_kernel( solimp, geom1.pos, geom2.pos, - collision_pairid_in[tid], + collision_pairid_in[collisionid], contact_dist_out, contact_pos_out, contact_frame_out, @@ -1086,21 +1102,22 @@ def convex_narrowphase(m: Model, d: Data, ctx: CollisionContext, collision_table computations for non-existent pair types. """ - def _pair_count(p1: int, p2: int) -> int: - return m.geom_pair_type_count[upper_trid_index(len(GeomType), p1, p2)] + def _pair_count(p1: int, p2: int) -> Tuple[int, int]: + idx = upper_trid_index(len(GeomType), p1, p2) + return m.geom_pair_type_count[idx], idx - ncollision = sum(_pair_count(g[0].value, g[1].value) for g in collision_table) + ncollision = sum(_pair_count(g[0].value, g[1].value)[0] for g in collision_table) # no convex collisions, early return if ncollision == 0: return # compute nmaxpolygon and nmaxmeshdeg given the geom pairs for the model - nboxbox = _pair_count(GeomType.BOX.value, GeomType.BOX.value) + nboxbox, _ = _pair_count(GeomType.BOX.value, GeomType.BOX.value) if (GeomType.BOX, GeomType.BOX) not in collision_table: nboxbox = 0 - nboxmesh = _pair_count(GeomType.BOX.value, GeomType.MESH.value) - nmeshmesh = _pair_count(GeomType.MESH.value, GeomType.MESH.value) + nboxmesh, _ = _pair_count(GeomType.BOX.value, GeomType.MESH.value) + nmeshmesh, _ = _pair_count(GeomType.MESH.value, GeomType.MESH.value) epa_iterations = 16 if nboxbox == ncollision else m.opt.ccd_iterations @@ -1117,22 +1134,25 @@ def _pair_count(p1: int, p2: int) -> int: nmaxpolygon = max(m.nmaxpolygon, minval) nmaxmeshdeg = max(m.nmaxmeshdeg, 3) + # ccd collider count + nccd = wp.zeros(len(GeomType) * (len(GeomType) + 1) // 2, dtype=int) + # epa_vert1: vertices in EPA polytope in geom 1 space - epa_vert1 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3) + epa_vert1 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3) # epa_vert2: vertices in EPA polytope in geom 2 space - epa_vert2 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3) + epa_vert2 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3) # epa_vert_index1: vertex indices in EPA polytope for geom 1 - epa_vert_index1 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=int) + epa_vert_index1 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=int) # epa_vert_index2: vertex indices in EPA polytope for geom 2 (naconmax, 5 + CCDiter) - epa_vert_index2 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=int) + epa_vert_index2 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=int) # epa_face: faces of polytope represented by three indices - epa_face = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int) + epa_face = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int) # epa_pr: projection of origin on polytope faces - epa_pr = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3) + epa_pr = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3) # epa_norm2: epa_pr * epa_pr - epa_norm2 = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=float) + epa_norm2 = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=float) # epa_horizon: index pair (i j) of edges on horizon - epa_horizon = wp.empty(shape=(d.naconmax, MJ_MAX_EPAHORIZON), dtype=int) + epa_horizon = wp.empty(shape=(d.naccdmax, MJ_MAX_EPAHORIZON), dtype=int) # Contact outputs contact_outputs = [ @@ -1156,9 +1176,10 @@ def _pair_count(p1: int, p2: int) -> int: for geom_pair in collision_table: g1 = geom_pair[0].value g2 = geom_pair[1].value - if (g1 == GeomType.HFIELD or g2 == GeomType.HFIELD) and _pair_count(g1, g2): + count, geomgeomid = _pair_count(g1, g2) + if (g1 == GeomType.HFIELD or g2 == GeomType.HFIELD) and count: wp.launch( - ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations), + ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, geomgeomid), dim=d.naconmax, inputs=[ m.opt.ccd_tolerance, @@ -1203,6 +1224,7 @@ def _pair_count(p1: int, p2: int) -> int: d.geom_xpos, d.geom_xmat, d.naconmax, + d.naccdmax, d.ncollision, ctx.collision_pair, ctx.collision_pairid, @@ -1215,41 +1237,43 @@ def _pair_count(p1: int, p2: int) -> int: epa_pr, epa_norm2, epa_horizon, + nccd, ], outputs=contact_outputs, ) # Allocate multiccd arrays only for non-heightfield collisions # multiccd_polygon: clipped contact surface - multiccd_polygon = wp.empty(shape=(d.naconmax, 2 * nmaxpolygon), dtype=wp.vec3) + multiccd_polygon = wp.empty(shape=(d.naccdmax, 2 * nmaxpolygon), dtype=wp.vec3) # multiccd_clipped: clipped contact surface (intermediate) - multiccd_clipped = wp.empty(shape=(d.naconmax, 2 * nmaxpolygon), dtype=wp.vec3) + multiccd_clipped = wp.empty(shape=(d.naccdmax, 2 * nmaxpolygon), dtype=wp.vec3) # multiccd_pnormal: plane normal of clipping polygon - multiccd_pnormal = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3) + multiccd_pnormal = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3) # multiccd_pdist: plane distance of clipping polygon - multiccd_pdist = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=float) + multiccd_pdist = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=float) # multiccd_idx1: list of normal index candidates for Geom 1 - multiccd_idx1 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=int) + multiccd_idx1 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=int) # multiccd_idx2: list of normal index candidates for Geom 2 - multiccd_idx2 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=int) + multiccd_idx2 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=int) # multiccd_n1: list of normal candidates for Geom 1 - multiccd_n1 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3) + multiccd_n1 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3) # multiccd_n2: list of normal candidates for Geom 1 - multiccd_n2 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3) + multiccd_n2 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3) # multiccd_endvert: list of edge vertices candidates - multiccd_endvert = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3) + multiccd_endvert = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3) # multiccd_face1: contact face - multiccd_face1 = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3) + multiccd_face1 = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3) # multiccd_face2: contact face - multiccd_face2 = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3) + multiccd_face2 = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3) # Launch non-heightfield collision kernels (no hfield args, 78 args total) for geom_pair in collision_table: g1 = geom_pair[0].value g2 = geom_pair[1].value - if g1 != GeomType.HFIELD and g2 != GeomType.HFIELD and _pair_count(g1, g2): + count, geomgeomid = _pair_count(g1, g2) + if g1 != GeomType.HFIELD and g2 != GeomType.HFIELD and count: wp.launch( - ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd), + ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd, geomgeomid), dim=d.naconmax, inputs=[ m.opt.ccd_tolerance, @@ -1288,6 +1312,7 @@ def _pair_count(p1: int, p2: int) -> int: d.geom_xpos, d.geom_xmat, d.naconmax, + d.naccdmax, d.ncollision, ctx.collision_pair, ctx.collision_pairid, @@ -1311,6 +1336,7 @@ def _pair_count(p1: int, p2: int) -> int: multiccd_endvert, multiccd_face1, multiccd_face2, + nccd, ], outputs=contact_outputs, ) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index b96ea5406..383599eee 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -656,8 +656,10 @@ def make_data( mjm: mujoco.MjModel, nworld: int = 1, nconmax: Optional[int] = None, + nccdmax: Optional[int] = None, njmax: Optional[int] = None, naconmax: Optional[int] = None, + naccdmax: Optional[int] = None, ) -> types.Data: """Creates a data object on device. @@ -666,9 +668,11 @@ def make_data( nworld: Number of worlds. nconmax: Number of contacts to allocate per world. Contacts exist in large heterogeneous arrays: one world may have more than nconmax contacts. + nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax. njmax: Number of constraints to allocate per world. Constraint arrays are batched by world: no world may have more than njmax constraints. naconmax: Number of contacts to allocate for all worlds. Overrides nconmax. + naccdmax: Maximum number of CCD contacts. Defaults to naconmax. Returns: The data object containing the current state and output arrays (device). @@ -677,21 +681,36 @@ def make_data( if nconmax is None: nconmax = _default_nconmax(mjm) + if nconmax < 0: + raise ValueError("nconmax must be >= 0") + + if nccdmax is None: + nccdmax = nconmax + elif nccdmax < 0: + raise ValueError("nccdmax must be >= 0") + elif nccdmax > nconmax: + raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})") + if njmax is None: njmax = _default_njmax(mjm) + if njmax < 0: + raise ValueError("njmax must be >= 0") + if nworld < 1: raise ValueError(f"nworld must be >= 1") if naconmax is None: - if nconmax < 0: - raise ValueError("nconmax must be >= 0") naconmax = nworld * nconmax elif naconmax < 0: raise ValueError("naconmax must be >= 0") - if njmax < 0: - raise ValueError("njmax must be >= 0") + if naccdmax is None: + naccdmax = nworld * nccdmax + elif naccdmax < 0: + raise ValueError("naccdmax must be >= 0") + elif naccdmax > naconmax: + raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})") sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int}) sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max() @@ -721,6 +740,7 @@ def make_data( "efc": efc, "nworld": nworld, "naconmax": naconmax, + "naccdmax": naccdmax, "njmax": njmax, "qM": None, "qLD": None, @@ -765,8 +785,10 @@ def put_data( mjd: mujoco.MjData, nworld: int = 1, nconmax: Optional[int] = None, + nccdmax: Optional[int] = None, njmax: Optional[int] = None, naconmax: Optional[int] = None, + naccdmax: Optional[int] = None, ) -> types.Data: """Moves data from host to a device. @@ -776,9 +798,11 @@ def put_data( nworld: The number of worlds. nconmax: Number of contacts to allocate per world. Contacts exist in large heterogenous arrays: one world may have more than nconmax contacts. + nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax. njmax: Number of constraints to allocate per world. Constraint arrays are batched by world: no world may have more than njmax constraints. naconmax: Number of contacts to allocate for all worlds. Overrides nconmax. + naccdmax: Maximum number of CCD contacts. Defaults to naconmax. Returns: The data object containing the current state and output arrays (device). @@ -790,23 +814,38 @@ def put_data( if nconmax is None: nconmax = _default_nconmax(mjm, mjd) + if nconmax < 0: + raise ValueError("nconmax must be >= 0") + + if nccdmax is None: + nccdmax = nconmax + elif nccdmax < 0: + raise ValueError("nccdmax must be >= 0") + elif nccdmax > nconmax: + raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})") + if njmax is None: njmax = _default_njmax(mjm, mjd) + if njmax < 0: + raise ValueError("njmax must be >= 0") + if nworld < 1: raise ValueError(f"nworld must be >= 1") if naconmax is None: - if nconmax < 0: - raise ValueError("nconmax must be >= 0") if mjd.ncon > nconmax: raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})") naconmax = nworld * nconmax elif naconmax < mjd.ncon * nworld: raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})") - if njmax < 0: - raise ValueError("njmax must be >= 0") + if naccdmax is None: + naccdmax = nworld * nccdmax + elif naccdmax < 0: + raise ValueError("naccdmax must be >= 0") + elif naccdmax > naconmax: + raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})") if mjd.nefc > njmax: raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})") @@ -882,6 +921,7 @@ def put_data( "efc": efc, "nworld": nworld, "naconmax": naconmax, + "naccdmax": naccdmax, "njmax": njmax, # fields set after initialization: "solver_niter": None, diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index 4ecf590e5..fc5c85efe 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -438,6 +438,28 @@ def test_static_geom_collision_with_put_data(self): box_z = d.xpos.numpy()[0, 1, 2] # world 0, body 1 (box), z coordinate self.assertGreater(box_z, 0.4, msg=f"Box fell through ground plane (z={box_z}, should be > 0.4)") + def test_make_data_nccdmax_exceeds_nconmax(self): + mjm = mujoco.MjModel.from_xml_string("") + with self.assertRaises(ValueError, msg="nccdmax.*nconmax"): + mjwarp.make_data(mjm, nconmax=16, nccdmax=17) + + def test_make_data_naccdmax_exceeds_naconmax(self): + mjm = mujoco.MjModel.from_xml_string("") + with self.assertRaises(ValueError, msg="naccdmax.*naconmax"): + mjwarp.make_data(mjm, nconmax=16, naconmax=16, naccdmax=17) + + def test_put_data_nccdmax_exceeds_nconmax(self): + mjm = mujoco.MjModel.from_xml_string("") + mjd = mujoco.MjData(mjm) + with self.assertRaises(ValueError, msg="nccdmax.*nconmax"): + mjwarp.put_data(mjm, mjd, nconmax=16, nccdmax=17) + + def test_put_data_naccdmax_exceeds_naconmax(self): + mjm = mujoco.MjModel.from_xml_string("") + mjd = mujoco.MjData(mjm) + with self.assertRaises(ValueError, msg="naccdmax.*naconmax"): + mjwarp.put_data(mjm, mjd, nconmax=16, naconmax=16, naccdmax=17) + def test_noslip_solver(self): with self.assertRaises(NotImplementedError): test_data.fixture( diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 9bac275ac..8d50fc8f8 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -1663,6 +1663,7 @@ class Data: warp only fields: nworld: number of worlds naconmax: maximum number of contacts (shared across all worlds) + naccdmax: maximum number of contacts for CCD (all worlds) njmax: maximum number of constraints per world nacon: number of detected contacts (across all worlds) (1,) ncollision: collision count from broadphase (1,) @@ -1749,6 +1750,7 @@ class Data: # warp only fields: nworld: int naconmax: int + naccdmax: int njmax: int nacon: array(1, int) ncollision: array(1, int) diff --git a/mujoco_warp/testspeed.py b/mujoco_warp/testspeed.py index 48ac0bdf7..68fd4e3c2 100644 --- a/mujoco_warp/testspeed.py +++ b/mujoco_warp/testspeed.py @@ -54,6 +54,7 @@ _NWORLD = flags.DEFINE_integer("nworld", 8192, "number of parallel rollouts") _NCONMAX = flags.DEFINE_integer("nconmax", None, "override maximum number of contacts per world") _NJMAX = flags.DEFINE_integer("njmax", None, "override maximum number of constraints per world") +_NCCDMAX = flags.DEFINE_integer("nccdmax", None, "override maximum number of CCD contacts per world") _OVERRIDE = flags.DEFINE_multi_string("override", [], "Model overrides (notation: foo.bar = baz)", short_name="o") _KEYFRAME = flags.DEFINE_integer("keyframe", 0, "keyframe to initialize simulation.") _CLEAR_WARP_CACHE = flags.DEFINE_bool("clear_warp_cache", False, "clear warp caches (kernel, LTO, CUDA compute)") @@ -302,7 +303,7 @@ def _main(argv: Sequence[str]): override_model(mjm, _OVERRIDE.value) m = mjw.put_model(mjm) override_model(m, _OVERRIDE.value) - d = mjw.put_data(mjm, mjd, nworld=_NWORLD.value, nconmax=_NCONMAX.value, njmax=_NJMAX.value) + d = mjw.put_data(mjm, mjd, nworld=_NWORLD.value, nconmax=_NCONMAX.value, njmax=_NJMAX.value, nccdmax=_NCCDMAX.value) rc = None if "rc" in inspect.signature(_FUNCS[_FUNCTION.value]).parameters.keys(): rc = mjw.create_render_context( diff --git a/mujoco_warp/viewer.py b/mujoco_warp/viewer.py index f8647d712..6784718a8 100644 --- a/mujoco_warp/viewer.py +++ b/mujoco_warp/viewer.py @@ -56,6 +56,7 @@ class EngineOptions(enum.IntEnum): _ENGINE = flags.DEFINE_enum_class("engine", EngineOptions.WARP, EngineOptions, "Simulation engine") _NCONMAX = flags.DEFINE_integer("nconmax", None, "Maximum number of contacts.") _NJMAX = flags.DEFINE_integer("njmax", None, "Maximum number of constraints per world.") +_NCCDMAX = flags.DEFINE_integer("nccdmax", None, "Maximum number of CCD contacts per world.") _OVERRIDE = flags.DEFINE_multi_string("override", [], "Model overrides (notation: foo.bar = baz)", short_name="o") _KEYFRAME = flags.DEFINE_integer("keyframe", 0, "keyframe to initialize simulation.") _DEVICE = flags.DEFINE_string("device", None, "override the default Warp device") @@ -148,7 +149,7 @@ def _main(argv: Sequence[str]) -> None: override_model(mjm, _OVERRIDE.value) m = mjw.put_model(mjm) override_model(m, _OVERRIDE.value) - d = mjw.put_data(mjm, mjd, nconmax=_NCONMAX.value, njmax=_NJMAX.value) + d = mjw.put_data(mjm, mjd, nconmax=_NCONMAX.value, njmax=_NJMAX.value, nccdmax=_NCCDMAX.value) graph = _compile_step(m, d) if wp.get_device().is_cuda else None if graph is None: mjw.step(m, d) # warmup step