diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py
index ebc292774..781b7aa9e 100644
--- a/mujoco_warp/_src/collision_convex.py
+++ b/mujoco_warp/_src/collision_convex.py
@@ -155,6 +155,7 @@ def ccd_hfield_kernel_builder(
geomtype2: int,
gjk_iterations: int,
epa_iterations: int,
+ geomgeomid: int,
):
"""Kernel builder for heightfield CCD collisions (no multiccd args)."""
@@ -205,6 +206,7 @@ def ccd_hfield_kernel(
geom_xpos_in: wp.array2d(dtype=wp.vec3),
geom_xmat_in: wp.array2d(dtype=wp.mat33),
naconmax_in: int,
+ naccdmax_in: int,
ncollision_in: wp.array(dtype=int),
# In:
collision_pair_in: wp.array(dtype=wp.vec2i),
@@ -218,6 +220,7 @@ def ccd_hfield_kernel(
epa_pr_in: wp.array2d(dtype=wp.vec3),
epa_norm2_in: wp.array2d(dtype=float),
epa_horizon_in: wp.array2d(dtype=int),
+ nccd_in: wp.array(dtype=int),
# Data out:
contact_dist_out: wp.array(dtype=float),
contact_pos_out: wp.array(dtype=wp.vec3),
@@ -234,18 +237,18 @@ def ccd_hfield_kernel(
contact_geomcollisionid_out: wp.array(dtype=int),
nacon_out: wp.array(dtype=int),
):
- tid = wp.tid()
- if tid >= ncollision_in[0]:
+ collisionid = wp.tid()
+ if collisionid >= ncollision_in[0]:
return
- geoms = collision_pair_in[tid]
+ geoms = collision_pair_in[collisionid]
g1 = geoms[0]
g2 = geoms[1]
if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2:
return
- worldid = collision_worldid_in[tid]
+ worldid = collision_worldid_in[collisionid]
# height field filter
no_hf_collision, xmin, xmax, ymin, ymax, zmin, zmax = _hfield_filter(
@@ -269,6 +272,11 @@ def ccd_hfield_kernel(
if no_hf_collision:
return
+ ccdid = wp.atomic_add(nccd_in, wp.static(geomgeomid), 1)
+ if ccdid >= naccdmax_in:
+ wp.printf("CCD overflow - please increase naccdmax to %u\n", ccdid)
+ return
+
_, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params(
geom_condim,
geom_priority,
@@ -287,7 +295,7 @@ def ccd_hfield_kernel(
pair_friction,
collision_pair_in,
collision_pairid_in,
- tid,
+ collisionid,
worldid,
)
@@ -368,16 +376,16 @@ def ccd_hfield_kernel(
geom2.margin = margin
# EPA memory
- epa_vert1 = epa_vert1_in[tid]
- epa_vert2 = epa_vert2_in[tid]
- epa_vert_index1 = epa_vert_index1_in[tid]
- epa_vert_index2 = epa_vert_index2_in[tid]
- epa_face = epa_face_in[tid]
- epa_pr = epa_pr_in[tid]
- epa_norm2 = epa_norm2_in[tid]
- epa_horizon = epa_horizon_in[tid]
+ epa_vert1 = epa_vert1_in[ccdid]
+ epa_vert2 = epa_vert2_in[ccdid]
+ epa_vert_index1 = epa_vert_index1_in[ccdid]
+ epa_vert_index2 = epa_vert_index2_in[ccdid]
+ epa_face = epa_face_in[ccdid]
+ epa_pr = epa_pr_in[ccdid]
+ epa_norm2 = epa_norm2_in[ccdid]
+ epa_horizon = epa_horizon_in[ccdid]
- collision_pairid = collision_pairid_in[tid]
+ collision_pairid = collision_pairid_in[collisionid]
# process all prisms in subgrid
count = int(0)
@@ -692,6 +700,7 @@ def ccd_kernel_builder(
gjk_iterations: int,
epa_iterations: int,
use_multiccd: bool,
+ geomgeomid: int,
):
"""Kernel builder for non-heightfield CCD collisions (no hfield args)."""
@@ -725,7 +734,7 @@ def eval_ccd_write_contact(
geom2: Geom,
geoms: wp.vec2i,
worldid: int,
- tid: int,
+ ccdid: int,
margin: float,
gap: float,
condim: int,
@@ -773,14 +782,14 @@ def eval_ccd_write_contact(
geomtype2,
x1,
x2,
- epa_vert1_in[tid],
- epa_vert2_in[tid],
- epa_vert_index1_in[tid],
- epa_vert_index2_in[tid],
- epa_face_in[tid],
- epa_pr_in[tid],
- epa_norm2_in[tid],
- epa_horizon_in[tid],
+ epa_vert1_in[ccdid],
+ epa_vert2_in[ccdid],
+ epa_vert_index1_in[ccdid],
+ epa_vert_index2_in[ccdid],
+ epa_face_in[ccdid],
+ epa_pr_in[ccdid],
+ epa_norm2_in[ccdid],
+ epa_horizon_in[ccdid],
)
if dist >= 0.0 and pairid[1] == -1:
@@ -802,22 +811,22 @@ def eval_ccd_write_contact(
if multiccd_idx > -1:
ncollision, witness1, witness2 = multicontact(
- multiccd_polygon_in[tid],
- multiccd_clipped_in[tid],
- multiccd_pnormal_in[tid],
- multiccd_pdist_in[tid],
- multiccd_idx1_in[tid],
- multiccd_idx2_in[tid],
- multiccd_n1_in[tid],
- multiccd_n2_in[tid],
- multiccd_endvert_in[tid],
- multiccd_face1_in[tid],
- multiccd_face2_in[tid],
- epa_vert1_in[tid],
- epa_vert2_in[tid],
- epa_vert_index1_in[tid],
- epa_vert_index2_in[tid],
- epa_face_in[tid, multiccd_idx],
+ multiccd_polygon_in[ccdid],
+ multiccd_clipped_in[ccdid],
+ multiccd_pnormal_in[ccdid],
+ multiccd_pdist_in[ccdid],
+ multiccd_idx1_in[ccdid],
+ multiccd_idx2_in[ccdid],
+ multiccd_n1_in[ccdid],
+ multiccd_n2_in[ccdid],
+ multiccd_endvert_in[ccdid],
+ multiccd_face1_in[ccdid],
+ multiccd_face2_in[ccdid],
+ epa_vert1_in[ccdid],
+ epa_vert2_in[ccdid],
+ epa_vert_index1_in[ccdid],
+ epa_vert_index2_in[ccdid],
+ epa_face_in[ccdid, multiccd_idx],
w1,
w2,
geom1,
@@ -914,6 +923,7 @@ def ccd_kernel(
geom_xpos_in: wp.array2d(dtype=wp.vec3),
geom_xmat_in: wp.array2d(dtype=wp.mat33),
naconmax_in: int,
+ naccdmax_in: int,
ncollision_in: wp.array(dtype=int),
# In:
collision_pair_in: wp.array(dtype=wp.vec2i),
@@ -938,6 +948,7 @@ def ccd_kernel(
multiccd_endvert_in: wp.array2d(dtype=wp.vec3),
multiccd_face1_in: wp.array2d(dtype=wp.vec3),
multiccd_face2_in: wp.array2d(dtype=wp.vec3),
+ nccd_in: wp.array(dtype=int),
# Data out:
contact_dist_out: wp.array(dtype=float),
contact_pos_out: wp.array(dtype=wp.vec3),
@@ -954,18 +965,23 @@ def ccd_kernel(
contact_geomcollisionid_out: wp.array(dtype=int),
nacon_out: wp.array(dtype=int),
):
- tid = wp.tid()
- if tid >= ncollision_in[0]:
+ collisionid = wp.tid()
+ if collisionid >= ncollision_in[0]:
return
- geoms = collision_pair_in[tid]
+ geoms = collision_pair_in[collisionid]
g1 = geoms[0]
g2 = geoms[1]
if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2:
return
- worldid = collision_worldid_in[tid]
+ ccdid = wp.atomic_add(nccd_in, wp.static(geomgeomid), 1)
+ if ccdid >= naccdmax_in:
+ wp.printf("CCD overflow - please increase naccdmax to %u\n", ccdid)
+ return
+
+ worldid = collision_worldid_in[collisionid]
_, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params(
geom_condim,
@@ -985,7 +1001,7 @@ def ccd_kernel(
pair_friction,
collision_pair_in,
collision_pairid_in,
- tid,
+ collisionid,
worldid,
)
@@ -1039,7 +1055,7 @@ def ccd_kernel(
geom2,
geoms,
worldid,
- tid,
+ ccdid,
margin,
gap,
condim,
@@ -1049,7 +1065,7 @@ def ccd_kernel(
solimp,
geom1.pos,
geom2.pos,
- collision_pairid_in[tid],
+ collision_pairid_in[collisionid],
contact_dist_out,
contact_pos_out,
contact_frame_out,
@@ -1086,21 +1102,22 @@ def convex_narrowphase(m: Model, d: Data, ctx: CollisionContext, collision_table
computations for non-existent pair types.
"""
- def _pair_count(p1: int, p2: int) -> int:
- return m.geom_pair_type_count[upper_trid_index(len(GeomType), p1, p2)]
+ def _pair_count(p1: int, p2: int) -> Tuple[int, int]:
+ idx = upper_trid_index(len(GeomType), p1, p2)
+ return m.geom_pair_type_count[idx], idx
- ncollision = sum(_pair_count(g[0].value, g[1].value) for g in collision_table)
+ ncollision = sum(_pair_count(g[0].value, g[1].value)[0] for g in collision_table)
# no convex collisions, early return
if ncollision == 0:
return
# compute nmaxpolygon and nmaxmeshdeg given the geom pairs for the model
- nboxbox = _pair_count(GeomType.BOX.value, GeomType.BOX.value)
+ nboxbox, _ = _pair_count(GeomType.BOX.value, GeomType.BOX.value)
if (GeomType.BOX, GeomType.BOX) not in collision_table:
nboxbox = 0
- nboxmesh = _pair_count(GeomType.BOX.value, GeomType.MESH.value)
- nmeshmesh = _pair_count(GeomType.MESH.value, GeomType.MESH.value)
+ nboxmesh, _ = _pair_count(GeomType.BOX.value, GeomType.MESH.value)
+ nmeshmesh, _ = _pair_count(GeomType.MESH.value, GeomType.MESH.value)
epa_iterations = 16 if nboxbox == ncollision else m.opt.ccd_iterations
@@ -1117,22 +1134,25 @@ def _pair_count(p1: int, p2: int) -> int:
nmaxpolygon = max(m.nmaxpolygon, minval)
nmaxmeshdeg = max(m.nmaxmeshdeg, 3)
+ # ccd collider count
+ nccd = wp.zeros(len(GeomType) * (len(GeomType) + 1) // 2, dtype=int)
+
# epa_vert1: vertices in EPA polytope in geom 1 space
- epa_vert1 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3)
+ epa_vert1 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3)
# epa_vert2: vertices in EPA polytope in geom 2 space
- epa_vert2 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3)
+ epa_vert2 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3)
# epa_vert_index1: vertex indices in EPA polytope for geom 1
- epa_vert_index1 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=int)
+ epa_vert_index1 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=int)
# epa_vert_index2: vertex indices in EPA polytope for geom 2 (naconmax, 5 + CCDiter)
- epa_vert_index2 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=int)
+ epa_vert_index2 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=int)
# epa_face: faces of polytope represented by three indices
- epa_face = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int)
+ epa_face = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int)
# epa_pr: projection of origin on polytope faces
- epa_pr = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3)
+ epa_pr = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3)
# epa_norm2: epa_pr * epa_pr
- epa_norm2 = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=float)
+ epa_norm2 = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=float)
# epa_horizon: index pair (i j) of edges on horizon
- epa_horizon = wp.empty(shape=(d.naconmax, MJ_MAX_EPAHORIZON), dtype=int)
+ epa_horizon = wp.empty(shape=(d.naccdmax, MJ_MAX_EPAHORIZON), dtype=int)
# Contact outputs
contact_outputs = [
@@ -1156,9 +1176,10 @@ def _pair_count(p1: int, p2: int) -> int:
for geom_pair in collision_table:
g1 = geom_pair[0].value
g2 = geom_pair[1].value
- if (g1 == GeomType.HFIELD or g2 == GeomType.HFIELD) and _pair_count(g1, g2):
+ count, geomgeomid = _pair_count(g1, g2)
+ if (g1 == GeomType.HFIELD or g2 == GeomType.HFIELD) and count:
wp.launch(
- ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations),
+ ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, geomgeomid),
dim=d.naconmax,
inputs=[
m.opt.ccd_tolerance,
@@ -1203,6 +1224,7 @@ def _pair_count(p1: int, p2: int) -> int:
d.geom_xpos,
d.geom_xmat,
d.naconmax,
+ d.naccdmax,
d.ncollision,
ctx.collision_pair,
ctx.collision_pairid,
@@ -1215,41 +1237,43 @@ def _pair_count(p1: int, p2: int) -> int:
epa_pr,
epa_norm2,
epa_horizon,
+ nccd,
],
outputs=contact_outputs,
)
# Allocate multiccd arrays only for non-heightfield collisions
# multiccd_polygon: clipped contact surface
- multiccd_polygon = wp.empty(shape=(d.naconmax, 2 * nmaxpolygon), dtype=wp.vec3)
+ multiccd_polygon = wp.empty(shape=(d.naccdmax, 2 * nmaxpolygon), dtype=wp.vec3)
# multiccd_clipped: clipped contact surface (intermediate)
- multiccd_clipped = wp.empty(shape=(d.naconmax, 2 * nmaxpolygon), dtype=wp.vec3)
+ multiccd_clipped = wp.empty(shape=(d.naccdmax, 2 * nmaxpolygon), dtype=wp.vec3)
# multiccd_pnormal: plane normal of clipping polygon
- multiccd_pnormal = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3)
+ multiccd_pnormal = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3)
# multiccd_pdist: plane distance of clipping polygon
- multiccd_pdist = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=float)
+ multiccd_pdist = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=float)
# multiccd_idx1: list of normal index candidates for Geom 1
- multiccd_idx1 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=int)
+ multiccd_idx1 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=int)
# multiccd_idx2: list of normal index candidates for Geom 2
- multiccd_idx2 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=int)
+ multiccd_idx2 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=int)
# multiccd_n1: list of normal candidates for Geom 1
- multiccd_n1 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3)
+ multiccd_n1 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3)
# multiccd_n2: list of normal candidates for Geom 1
- multiccd_n2 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3)
+ multiccd_n2 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3)
# multiccd_endvert: list of edge vertices candidates
- multiccd_endvert = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3)
+ multiccd_endvert = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3)
# multiccd_face1: contact face
- multiccd_face1 = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3)
+ multiccd_face1 = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3)
# multiccd_face2: contact face
- multiccd_face2 = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3)
+ multiccd_face2 = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3)
# Launch non-heightfield collision kernels (no hfield args, 78 args total)
for geom_pair in collision_table:
g1 = geom_pair[0].value
g2 = geom_pair[1].value
- if g1 != GeomType.HFIELD and g2 != GeomType.HFIELD and _pair_count(g1, g2):
+ count, geomgeomid = _pair_count(g1, g2)
+ if g1 != GeomType.HFIELD and g2 != GeomType.HFIELD and count:
wp.launch(
- ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd),
+ ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd, geomgeomid),
dim=d.naconmax,
inputs=[
m.opt.ccd_tolerance,
@@ -1288,6 +1312,7 @@ def _pair_count(p1: int, p2: int) -> int:
d.geom_xpos,
d.geom_xmat,
d.naconmax,
+ d.naccdmax,
d.ncollision,
ctx.collision_pair,
ctx.collision_pairid,
@@ -1311,6 +1336,7 @@ def _pair_count(p1: int, p2: int) -> int:
multiccd_endvert,
multiccd_face1,
multiccd_face2,
+ nccd,
],
outputs=contact_outputs,
)
diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py
index b96ea5406..383599eee 100644
--- a/mujoco_warp/_src/io.py
+++ b/mujoco_warp/_src/io.py
@@ -656,8 +656,10 @@ def make_data(
mjm: mujoco.MjModel,
nworld: int = 1,
nconmax: Optional[int] = None,
+ nccdmax: Optional[int] = None,
njmax: Optional[int] = None,
naconmax: Optional[int] = None,
+ naccdmax: Optional[int] = None,
) -> types.Data:
"""Creates a data object on device.
@@ -666,9 +668,11 @@ def make_data(
nworld: Number of worlds.
nconmax: Number of contacts to allocate per world. Contacts exist in large
heterogeneous arrays: one world may have more than nconmax contacts.
+ nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax.
njmax: Number of constraints to allocate per world. Constraint arrays are
batched by world: no world may have more than njmax constraints.
naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
+ naccdmax: Maximum number of CCD contacts. Defaults to naconmax.
Returns:
The data object containing the current state and output arrays (device).
@@ -677,21 +681,36 @@ def make_data(
if nconmax is None:
nconmax = _default_nconmax(mjm)
+ if nconmax < 0:
+ raise ValueError("nconmax must be >= 0")
+
+ if nccdmax is None:
+ nccdmax = nconmax
+ elif nccdmax < 0:
+ raise ValueError("nccdmax must be >= 0")
+ elif nccdmax > nconmax:
+ raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")
+
if njmax is None:
njmax = _default_njmax(mjm)
+ if njmax < 0:
+ raise ValueError("njmax must be >= 0")
+
if nworld < 1:
raise ValueError(f"nworld must be >= 1")
if naconmax is None:
- if nconmax < 0:
- raise ValueError("nconmax must be >= 0")
naconmax = nworld * nconmax
elif naconmax < 0:
raise ValueError("naconmax must be >= 0")
- if njmax < 0:
- raise ValueError("njmax must be >= 0")
+ if naccdmax is None:
+ naccdmax = nworld * nccdmax
+ elif naccdmax < 0:
+ raise ValueError("naccdmax must be >= 0")
+ elif naccdmax > naconmax:
+ raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})")
sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
@@ -721,6 +740,7 @@ def make_data(
"efc": efc,
"nworld": nworld,
"naconmax": naconmax,
+ "naccdmax": naccdmax,
"njmax": njmax,
"qM": None,
"qLD": None,
@@ -765,8 +785,10 @@ def put_data(
mjd: mujoco.MjData,
nworld: int = 1,
nconmax: Optional[int] = None,
+ nccdmax: Optional[int] = None,
njmax: Optional[int] = None,
naconmax: Optional[int] = None,
+ naccdmax: Optional[int] = None,
) -> types.Data:
"""Moves data from host to a device.
@@ -776,9 +798,11 @@ def put_data(
nworld: The number of worlds.
nconmax: Number of contacts to allocate per world. Contacts exist in large
heterogenous arrays: one world may have more than nconmax contacts.
+ nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax.
njmax: Number of constraints to allocate per world. Constraint arrays are
batched by world: no world may have more than njmax constraints.
naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
+ naccdmax: Maximum number of CCD contacts. Defaults to naconmax.
Returns:
The data object containing the current state and output arrays (device).
@@ -790,23 +814,38 @@ def put_data(
if nconmax is None:
nconmax = _default_nconmax(mjm, mjd)
+ if nconmax < 0:
+ raise ValueError("nconmax must be >= 0")
+
+ if nccdmax is None:
+ nccdmax = nconmax
+ elif nccdmax < 0:
+ raise ValueError("nccdmax must be >= 0")
+ elif nccdmax > nconmax:
+ raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")
+
if njmax is None:
njmax = _default_njmax(mjm, mjd)
+ if njmax < 0:
+ raise ValueError("njmax must be >= 0")
+
if nworld < 1:
raise ValueError(f"nworld must be >= 1")
if naconmax is None:
- if nconmax < 0:
- raise ValueError("nconmax must be >= 0")
if mjd.ncon > nconmax:
raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})")
naconmax = nworld * nconmax
elif naconmax < mjd.ncon * nworld:
raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})")
- if njmax < 0:
- raise ValueError("njmax must be >= 0")
+ if naccdmax is None:
+ naccdmax = nworld * nccdmax
+ elif naccdmax < 0:
+ raise ValueError("naccdmax must be >= 0")
+ elif naccdmax > naconmax:
+ raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})")
if mjd.nefc > njmax:
raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})")
@@ -882,6 +921,7 @@ def put_data(
"efc": efc,
"nworld": nworld,
"naconmax": naconmax,
+ "naccdmax": naccdmax,
"njmax": njmax,
# fields set after initialization:
"solver_niter": None,
diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py
index 4ecf590e5..fc5c85efe 100644
--- a/mujoco_warp/_src/io_test.py
+++ b/mujoco_warp/_src/io_test.py
@@ -438,6 +438,28 @@ def test_static_geom_collision_with_put_data(self):
box_z = d.xpos.numpy()[0, 1, 2] # world 0, body 1 (box), z coordinate
self.assertGreater(box_z, 0.4, msg=f"Box fell through ground plane (z={box_z}, should be > 0.4)")
+ def test_make_data_nccdmax_exceeds_nconmax(self):
+ mjm = mujoco.MjModel.from_xml_string("")
+ with self.assertRaises(ValueError, msg="nccdmax.*nconmax"):
+ mjwarp.make_data(mjm, nconmax=16, nccdmax=17)
+
+ def test_make_data_naccdmax_exceeds_naconmax(self):
+ mjm = mujoco.MjModel.from_xml_string("")
+ with self.assertRaises(ValueError, msg="naccdmax.*naconmax"):
+ mjwarp.make_data(mjm, nconmax=16, naconmax=16, naccdmax=17)
+
+ def test_put_data_nccdmax_exceeds_nconmax(self):
+ mjm = mujoco.MjModel.from_xml_string("")
+ mjd = mujoco.MjData(mjm)
+ with self.assertRaises(ValueError, msg="nccdmax.*nconmax"):
+ mjwarp.put_data(mjm, mjd, nconmax=16, nccdmax=17)
+
+ def test_put_data_naccdmax_exceeds_naconmax(self):
+ mjm = mujoco.MjModel.from_xml_string("")
+ mjd = mujoco.MjData(mjm)
+ with self.assertRaises(ValueError, msg="naccdmax.*naconmax"):
+ mjwarp.put_data(mjm, mjd, nconmax=16, naconmax=16, naccdmax=17)
+
def test_noslip_solver(self):
with self.assertRaises(NotImplementedError):
test_data.fixture(
diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py
index 9bac275ac..8d50fc8f8 100644
--- a/mujoco_warp/_src/types.py
+++ b/mujoco_warp/_src/types.py
@@ -1663,6 +1663,7 @@ class Data:
warp only fields:
nworld: number of worlds
naconmax: maximum number of contacts (shared across all worlds)
+ naccdmax: maximum number of contacts for CCD (all worlds)
njmax: maximum number of constraints per world
nacon: number of detected contacts (across all worlds) (1,)
ncollision: collision count from broadphase (1,)
@@ -1749,6 +1750,7 @@ class Data:
# warp only fields:
nworld: int
naconmax: int
+ naccdmax: int
njmax: int
nacon: array(1, int)
ncollision: array(1, int)
diff --git a/mujoco_warp/testspeed.py b/mujoco_warp/testspeed.py
index 48ac0bdf7..68fd4e3c2 100644
--- a/mujoco_warp/testspeed.py
+++ b/mujoco_warp/testspeed.py
@@ -54,6 +54,7 @@
_NWORLD = flags.DEFINE_integer("nworld", 8192, "number of parallel rollouts")
_NCONMAX = flags.DEFINE_integer("nconmax", None, "override maximum number of contacts per world")
_NJMAX = flags.DEFINE_integer("njmax", None, "override maximum number of constraints per world")
+_NCCDMAX = flags.DEFINE_integer("nccdmax", None, "override maximum number of CCD contacts per world")
_OVERRIDE = flags.DEFINE_multi_string("override", [], "Model overrides (notation: foo.bar = baz)", short_name="o")
_KEYFRAME = flags.DEFINE_integer("keyframe", 0, "keyframe to initialize simulation.")
_CLEAR_WARP_CACHE = flags.DEFINE_bool("clear_warp_cache", False, "clear warp caches (kernel, LTO, CUDA compute)")
@@ -302,7 +303,7 @@ def _main(argv: Sequence[str]):
override_model(mjm, _OVERRIDE.value)
m = mjw.put_model(mjm)
override_model(m, _OVERRIDE.value)
- d = mjw.put_data(mjm, mjd, nworld=_NWORLD.value, nconmax=_NCONMAX.value, njmax=_NJMAX.value)
+ d = mjw.put_data(mjm, mjd, nworld=_NWORLD.value, nconmax=_NCONMAX.value, njmax=_NJMAX.value, nccdmax=_NCCDMAX.value)
rc = None
if "rc" in inspect.signature(_FUNCS[_FUNCTION.value]).parameters.keys():
rc = mjw.create_render_context(
diff --git a/mujoco_warp/viewer.py b/mujoco_warp/viewer.py
index f8647d712..6784718a8 100644
--- a/mujoco_warp/viewer.py
+++ b/mujoco_warp/viewer.py
@@ -56,6 +56,7 @@ class EngineOptions(enum.IntEnum):
_ENGINE = flags.DEFINE_enum_class("engine", EngineOptions.WARP, EngineOptions, "Simulation engine")
_NCONMAX = flags.DEFINE_integer("nconmax", None, "Maximum number of contacts.")
_NJMAX = flags.DEFINE_integer("njmax", None, "Maximum number of constraints per world.")
+_NCCDMAX = flags.DEFINE_integer("nccdmax", None, "Maximum number of CCD contacts per world.")
_OVERRIDE = flags.DEFINE_multi_string("override", [], "Model overrides (notation: foo.bar = baz)", short_name="o")
_KEYFRAME = flags.DEFINE_integer("keyframe", 0, "keyframe to initialize simulation.")
_DEVICE = flags.DEFINE_string("device", None, "override the default Warp device")
@@ -148,7 +149,7 @@ def _main(argv: Sequence[str]) -> None:
override_model(mjm, _OVERRIDE.value)
m = mjw.put_model(mjm)
override_model(m, _OVERRIDE.value)
- d = mjw.put_data(mjm, mjd, nconmax=_NCONMAX.value, njmax=_NJMAX.value)
+ d = mjw.put_data(mjm, mjd, nconmax=_NCONMAX.value, njmax=_NJMAX.value, nccdmax=_NCCDMAX.value)
graph = _compile_step(m, d) if wp.get_device().is_cuda else None
if graph is None:
mjw.step(m, d) # warmup step