diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py
index e6a10bf38..2cace8a10 100644
--- a/mujoco_warp/__init__.py
+++ b/mujoco_warp/__init__.py
@@ -99,3 +99,7 @@
from mujoco_warp._src.types import State as State
from mujoco_warp._src.types import Statistic as Statistic
from mujoco_warp._src.types import TrnType as TrnType
+from mujoco_warp._src.types import WarningType as WarningType
+from mujoco_warp._src.warning import check_warnings as check_warnings
+from mujoco_warp._src.warning import clear_warnings as clear_warnings
+from mujoco_warp._src.warning import get_warnings as get_warnings
diff --git a/mujoco_warp/_src/benchmark.py b/mujoco_warp/_src/benchmark.py
index eaadae0fa..31cf1e314 100644
--- a/mujoco_warp/_src/benchmark.py
+++ b/mujoco_warp/_src/benchmark.py
@@ -25,6 +25,7 @@
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import Model
from mujoco_warp._src.util_misc import halton
+from mujoco_warp._src.warning import check_warnings
def _sum(stack1, stack2):
@@ -145,6 +146,9 @@ def benchmark(
wp.synchronize()
run_end = time.perf_counter()
+ # Check and emit any overflow warnings, then clear flags
+ check_warnings(d, clear=True)
+
time_vec[i] = run_end - run_beg
if trace:
trace = _sum(trace, tracer.trace())
diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py
index e3e33acc6..c141696fe 100644
--- a/mujoco_warp/_src/collision_convex.py
+++ b/mujoco_warp/_src/collision_convex.py
@@ -34,6 +34,7 @@
from mujoco_warp._src.types import EnableBit
from mujoco_warp._src.types import GeomType
from mujoco_warp._src.types import Model
+from mujoco_warp._src.types import WarningType
from mujoco_warp._src.types import mat43
from mujoco_warp._src.types import mat63
from mujoco_warp._src.types import vec5
@@ -154,6 +155,7 @@ def ccd_hfield_kernel_builder(
geomtype2: int,
gjk_iterations: int,
epa_iterations: int,
+ warning_printf: bool,
):
"""Kernel builder for heightfield CCD collisions (no multiccd args)."""
@@ -232,6 +234,9 @@ def ccd_hfield_kernel(
contact_type_out: wp.array(dtype=int),
contact_geomcollisionid_out: wp.array(dtype=int),
nacon_out: wp.array(dtype=int),
+ # Warning output:
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
):
tid = wp.tid()
if tid >= ncollision_in[0]:
@@ -403,10 +408,13 @@ def ccd_hfield_kernel(
# add both triangles from this cell
for i in range(2):
if count >= MJ_MAXCONPAIR:
- wp.printf(
- "height field collision overflow, number of collisions >= %u - please adjust resolution: \n decrease the number of hfield rows/cols or modify size of colliding geom\n",
- MJ_MAXCONPAIR,
- )
+ if wp.static(warning_printf):
+ wp.printf(
+ "height field collision overflow, number of collisions >= %u - please adjust resolution: \n decrease the number of hfield rows/cols or modify size of colliding geom\n",
+ MJ_MAXCONPAIR,
+ )
+ wp.atomic_max(warning_out, int(WarningType.HFIELD_OVERFLOW), 1)
+ wp.atomic_max(warning_info_out, int(WarningType.HFIELD_OVERFLOW), 0, MJ_MAXCONPAIR)
continue
# add vert
@@ -453,6 +461,9 @@ def ccd_hfield_kernel(
epa_pr,
epa_norm2,
epa_horizon,
+ wp.static(warning_printf),
+ warning_out,
+ warning_info_out,
)
if ncontact == 0:
@@ -692,6 +703,7 @@ def ccd_kernel_builder(
gjk_iterations: int,
epa_iterations: int,
use_multiccd: bool,
+ warning_printf: bool,
):
"""Kernel builder for non-heightfield CCD collisions (no hfield args)."""
@@ -751,6 +763,9 @@ def eval_ccd_write_contact(
contact_type_out: wp.array(dtype=int),
contact_geomcollisionid_out: wp.array(dtype=int),
nacon_out: wp.array(dtype=int),
+ # Warning output:
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
) -> int:
points = mat43()
witness1 = mat43()
@@ -781,6 +796,9 @@ def eval_ccd_write_contact(
epa_pr_in[tid],
epa_norm2_in[tid],
epa_horizon_in[tid],
+ wp.static(warning_printf),
+ warning_out,
+ warning_info_out,
)
if dist >= 0.0 and pairid[1] == -1:
@@ -953,6 +971,9 @@ def ccd_kernel(
contact_type_out: wp.array(dtype=int),
contact_geomcollisionid_out: wp.array(dtype=int),
nacon_out: wp.array(dtype=int),
+ # Warning output:
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
):
tid = wp.tid()
if tid >= ncollision_in[0]:
@@ -1064,6 +1085,8 @@ def ccd_kernel(
contact_type_out,
contact_geomcollisionid_out,
nacon_out,
+ warning_out,
+ warning_info_out,
)
return ccd_kernel
@@ -1150,6 +1173,8 @@ def _pair_count(p1: int, p2: int) -> int:
d.contact.type,
d.contact.geomcollisionid,
d.nacon,
+ d.warning,
+ d.warning_info,
]
# Launch heightfield collision kernels (no multiccd args, 72 args total)
@@ -1158,7 +1183,7 @@ def _pair_count(p1: int, p2: int) -> int:
g2 = geom_pair[1].value
if (g1 == GeomType.HFIELD or g2 == GeomType.HFIELD) and _pair_count(g1, g2):
wp.launch(
- ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations),
+ ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, m.opt.warning_printf),
dim=d.naconmax,
inputs=[
m.opt.ccd_tolerance,
@@ -1249,7 +1274,7 @@ def _pair_count(p1: int, p2: int) -> int:
g2 = geom_pair[1].value
if g1 != GeomType.HFIELD and g2 != GeomType.HFIELD and _pair_count(g1, g2):
wp.launch(
- ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd),
+ ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd, m.opt.warning_printf),
dim=d.naconmax,
inputs=[
m.opt.ccd_tolerance,
diff --git a/mujoco_warp/_src/collision_gjk.py b/mujoco_warp/_src/collision_gjk.py
index 96b550ec7..3c172268d 100644
--- a/mujoco_warp/_src/collision_gjk.py
+++ b/mujoco_warp/_src/collision_gjk.py
@@ -20,6 +20,7 @@
from mujoco_warp._src.collision_primitive import Geom
from mujoco_warp._src.types import GeomType
+from mujoco_warp._src.types import WarningType
from mujoco_warp._src.types import mat43
from mujoco_warp._src.types import mat63
@@ -571,6 +572,10 @@ def gjk(
geomtype2: int,
cutoff: float,
is_discrete: bool,
+ warning_printf: bool,
+ # Data out:
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
) -> GJKResult:
"""Find distance within a tolerance between two geoms."""
cutoff2 = cutoff * cutoff
@@ -667,7 +672,10 @@ def gjk(
cnt += 1
if cnt == gjk_iterations:
- wp.printf("Warning: opt.ccd_iterations, currently set to %d, needs to be increased.\n", gjk_iterations)
+ if warning_printf:
+ wp.printf("Warning: opt.ccd_iterations, currently set to %d, needs to be increased.\n", gjk_iterations)
+ wp.atomic_max(warning_out, int(WarningType.GJK_ITERATIONS), 1)
+ wp.atomic_max(warning_info_out, int(WarningType.GJK_ITERATIONS), 0, gjk_iterations)
result = GJKResult()
@@ -1218,6 +1226,10 @@ def _epa(
geomtype1: int,
geomtype2: int,
is_discrete: bool,
+ warning_printf: bool,
+ # Data out:
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
) -> Tuple[float, wp.vec3, wp.vec3, int]:
"""Recover penetration data from two geoms in contact given an initial polytope."""
upper = FLOAT_MAX
@@ -1288,7 +1300,10 @@ def _epa(
pt.nhorizon = _add_edge(pt, face[1], face[2])
pt.nhorizon = _add_edge(pt, face[2], face[0])
if pt.nhorizon == -1:
- wp.printf("Warning: EPA horizon = %d isn't large enough.\n", pt.horizon.shape[0])
+ if warning_printf:
+ wp.printf("Warning: EPA horizon = %d isn't large enough.\n", pt.horizon.shape[0])
+ wp.atomic_max(warning_out, int(WarningType.EPA_HORIZON), 1)
+ wp.atomic_max(warning_info_out, int(WarningType.EPA_HORIZON), 0, pt.horizon.shape[0])
idx = -1
break
@@ -1305,7 +1320,10 @@ def _epa(
pt.nhorizon = _add_edge(pt, face[1], face[2])
pt.nhorizon = _add_edge(pt, face[2], face[0])
if pt.nhorizon == -1:
- wp.printf("Warning: EPA horizon = %d isn't large enough.\n", pt.horizon.shape[0])
+ if warning_printf:
+ wp.printf("Warning: EPA horizon = %d isn't large enough.\n", pt.horizon.shape[0])
+ wp.atomic_max(warning_out, int(WarningType.EPA_HORIZON), 1)
+ wp.atomic_max(warning_info_out, int(WarningType.EPA_HORIZON), 0, pt.horizon.shape[0])
idx = -1
break
@@ -1333,7 +1351,10 @@ def _epa(
cnt += 1
if cnt == epa_iterations:
- wp.printf("Warning: opt.ccd_iterations, currently set to %d, needs to be increased.\n", gjk_iterations)
+ if warning_printf:
+ wp.printf("Warning: opt.ccd_iterations, currently set to %d, needs to be increased.\n", gjk_iterations)
+ wp.atomic_max(warning_out, int(WarningType.GJK_ITERATIONS), 1)
+ wp.atomic_max(warning_info_out, int(WarningType.GJK_ITERATIONS), 0, gjk_iterations)
# return from valid face
if idx > -1:
@@ -2232,6 +2253,10 @@ def ccd(
face_pr: wp.array(dtype=wp.vec3),
face_norm2: wp.array(dtype=float),
horizon: wp.array(dtype=int),
+ warning_printf: bool,
+ # Data out:
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
) -> Tuple[float, int, wp.vec3, wp.vec3, int]:
"""General convex collision detection via GJK/EPA."""
full_margin1 = 0.0
@@ -2257,7 +2282,21 @@ def ccd(
# special handling for sphere and capsule (shrink to point and line respectively)
if size1 + size2 > 0.0:
cutoff += full_margin1 + full_margin2
- result = gjk(tolerance, gjk_iterations, geom1, geom2, x_1, x_2, geomtype1, geomtype2, cutoff, is_discrete)
+ result = gjk(
+ tolerance,
+ gjk_iterations,
+ geom1,
+ geom2,
+ x_1,
+ x_2,
+ geomtype1,
+ geomtype2,
+ cutoff,
+ is_discrete,
+ warning_printf,
+ warning_out,
+ warning_info_out,
+ )
# shallow penetration, inflate contact
if result.dist > tolerance:
@@ -2273,7 +2312,21 @@ def ccd(
geom2.size = wp.vec3(size2, geom2.size[1], geom2.size[2])
cutoff -= full_margin1 + full_margin2
- result = gjk(tolerance, gjk_iterations, geom1, geom2, x_1, x_2, geomtype1, geomtype2, cutoff, is_discrete)
+ result = gjk(
+ tolerance,
+ gjk_iterations,
+ geom1,
+ geom2,
+ x_1,
+ x_2,
+ geomtype1,
+ geomtype2,
+ cutoff,
+ is_discrete,
+ warning_printf,
+ warning_out,
+ warning_info_out,
+ )
# no penetration depth to recover
if result.dist > tolerance or result.dim < 2:
@@ -2349,7 +2402,20 @@ def ccd(
if pt.status:
return result.dist, 1, result.x1, result.x2, -1
- dist, x1, x2, idx = _epa(tolerance, gjk_iterations, epa_iterations, pt, geom1, geom2, geomtype1, geomtype2, is_discrete)
+ dist, x1, x2, idx = _epa(
+ tolerance,
+ gjk_iterations,
+ epa_iterations,
+ pt,
+ geom1,
+ geom2,
+ geomtype1,
+ geomtype2,
+ is_discrete,
+ warning_printf,
+ warning_out,
+ warning_info_out,
+ )
if idx == -1:
return FLOAT_MAX, 0, wp.vec3(), wp.vec3(), -1
diff --git a/mujoco_warp/_src/collision_gjk_test.py b/mujoco_warp/_src/collision_gjk_test.py
index 1e61de65a..d641cdf58 100644
--- a/mujoco_warp/_src/collision_gjk_test.py
+++ b/mujoco_warp/_src/collision_gjk_test.py
@@ -26,6 +26,7 @@
from mujoco_warp._src.collision_primitive import Geom
from mujoco_warp._src.types import MJ_MAX_EPAFACES
from mujoco_warp._src.types import MJ_MAX_EPAHORIZON
+from mujoco_warp._src.types import NUM_WARNINGS
def _geom_dist(
@@ -108,6 +109,9 @@ def _ccd_kernel(
endvert: wp.array(dtype=wp.vec3),
face1: wp.array(dtype=wp.vec3),
face2: wp.array(dtype=wp.vec3),
+ # Data out:
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
# Out:
dist_out: wp.array(dtype=float),
ncon_out: wp.array(dtype=int),
@@ -200,6 +204,9 @@ def _ccd_kernel(
face_pr,
face_norm2,
horizon,
+ False, # warning_printf
+ warning_out,
+ warning_info_out,
)
if wp.static(multiccd):
@@ -236,6 +243,8 @@ def _ccd_kernel(
dist_out = wp.array(shape=(1,), dtype=float)
ncon_out = wp.array(shape=(1,), dtype=int)
pos_out = wp.array(shape=(2,), dtype=wp.vec3)
+ warning_out = wp.zeros(NUM_WARNINGS, dtype=int)
+ warning_info_out = wp.zeros((NUM_WARNINGS, 2), dtype=int)
wp.launch(
_ccd_kernel,
dim=1,
@@ -280,6 +289,8 @@ def _ccd_kernel(
multiccd_endvert,
multiccd_face1,
multiccd_face2,
+ warning_out,
+ warning_info_out,
],
outputs=[
dist_out,
diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py
index 032070be8..5e71a203a 100644
--- a/mujoco_warp/_src/forward.py
+++ b/mujoco_warp/_src/forward.py
@@ -39,6 +39,7 @@
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import TileSet
from mujoco_warp._src.types import TrnType
+from mujoco_warp._src.types import WarningType
from mujoco_warp._src.types import vec10f
from mujoco_warp._src.warp_util import cache_kernel
from mujoco_warp._src.warp_util import event_scope
@@ -190,37 +191,56 @@ def _next_activation(
act_out[worldid, actid] = act
-@wp.kernel
-def _next_time(
- # Model:
- opt_timestep: wp.array(dtype=float),
- # Data in:
- nefc_in: wp.array(dtype=int),
- time_in: wp.array(dtype=float),
- nworld_in: int,
- naconmax_in: int,
- njmax_in: int,
- nacon_in: wp.array(dtype=int),
- ncollision_in: wp.array(dtype=int),
- # Data out:
- time_out: wp.array(dtype=float),
-):
- worldid = wp.tid()
- time_out[worldid] = time_in[worldid] + opt_timestep[worldid % opt_timestep.shape[0]]
- nefc = nefc_in[worldid]
-
- if nefc > njmax_in:
- wp.printf("nefc overflow - please increase njmax to %u\n", nefc)
-
- if worldid == 0:
- ncollision = ncollision_in[0]
- if ncollision > naconmax_in:
- nconmax = int(wp.ceil(float(ncollision) / float(nworld_in)))
- wp.printf("broadphase overflow - please increase nconmax to %u or naconmax to %u\n", nconmax, ncollision)
+@cache_kernel
+def _next_time(enable_printf: bool):
+ """Creates _next_time kernel with optional printf for warnings."""
- if nacon_in[0] > naconmax_in:
- nconmax = int(wp.ceil(float(nacon_in[0]) / float(nworld_in)))
- wp.printf("narrowphase overflow - please increase nconmax to %u or naconmax to %u\n", nconmax, nacon_in[0])
+ @wp.kernel(module="unique")
+ def next_time(
+ # Model:
+ opt_timestep: wp.array(dtype=float),
+ # Data in:
+ nefc_in: wp.array(dtype=int),
+ time_in: wp.array(dtype=float),
+ nworld_in: int,
+ naconmax_in: int,
+ njmax_in: int,
+ nacon_in: wp.array(dtype=int),
+ ncollision_in: wp.array(dtype=int),
+ # Data out:
+ time_out: wp.array(dtype=float),
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
+ ):
+ worldid = wp.tid()
+ time_out[worldid] = time_in[worldid] + opt_timestep[worldid % opt_timestep.shape[0]]
+ nefc = nefc_in[worldid]
+
+ if nefc > njmax_in:
+ if wp.static(enable_printf):
+ wp.printf("nefc overflow - please increase njmax to %u\n", nefc)
+ wp.atomic_max(warning_out, int(WarningType.NEFC_OVERFLOW), 1)
+ wp.atomic_max(warning_info_out[int(WarningType.NEFC_OVERFLOW)], 0, nefc)
+
+ if worldid == 0:
+ ncollision = ncollision_in[0]
+ if ncollision > naconmax_in:
+ nconmax = int(wp.ceil(float(ncollision) / float(nworld_in)))
+ if wp.static(enable_printf):
+ wp.printf("broadphase overflow - please increase nconmax to %u or naconmax to %u\n", nconmax, ncollision)
+ wp.atomic_max(warning_out, int(WarningType.BROADPHASE_OVERFLOW), 1)
+ wp.atomic_max(warning_info_out[int(WarningType.BROADPHASE_OVERFLOW)], 0, nconmax)
+ wp.atomic_max(warning_info_out[int(WarningType.BROADPHASE_OVERFLOW)], 1, ncollision)
+
+ if nacon_in[0] > naconmax_in:
+ nconmax = int(wp.ceil(float(nacon_in[0]) / float(nworld_in)))
+ if wp.static(enable_printf):
+ wp.printf("narrowphase overflow - please increase nconmax to %u or naconmax to %u\n", nconmax, nacon_in[0])
+ wp.atomic_max(warning_out, int(WarningType.NARROWPHASE_OVERFLOW), 1)
+ wp.atomic_max(warning_info_out[int(WarningType.NARROWPHASE_OVERFLOW)], 0, nconmax)
+ wp.atomic_max(warning_info_out[int(WarningType.NARROWPHASE_OVERFLOW)], 1, nacon_in[0])
+
+ return next_time
def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None):
@@ -263,10 +283,10 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
)
wp.launch(
- _next_time,
+ _next_time(m.opt.warning_printf),
dim=d.nworld,
inputs=[m.opt.timestep, d.nefc, d.time, d.nworld, d.naconmax, d.njmax, d.nacon, d.ncollision],
- outputs=[d.time],
+ outputs=[d.time, d.warning, d.warning_info],
)
wp.copy(d.qacc_warmstart, d.qacc)
diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py
index 624aa2664..7bc591a52 100644
--- a/mujoco_warp/_src/io.py
+++ b/mujoco_warp/_src/io.py
@@ -207,6 +207,9 @@ def _check_friction(name: str, id_: int, condim: int, friction, checks):
else:
opt.contact_sensor_maxmatch = 64
+ # warning_printf: emit overflow warnings via printf (default True)
+ opt.warning_printf = True
+
# place opt on device
for f in dataclasses.fields(types.Option):
if isinstance(f.type, wp.array):
@@ -691,7 +694,10 @@ def make_data(
if njmax < 0:
raise ValueError("njmax must be >= 0")
- sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
+ sizes = dict(
+ {"*": 1, "NUM_WARNINGS": types.NUM_WARNINGS},
+ **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int},
+ )
sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse(mjm) else types.TILE_SIZE_JTDAJ_DENSE
@@ -809,7 +815,10 @@ def put_data(
if mjd.nefc > njmax:
raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})")
- sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
+ sizes = dict(
+ {"*": 1, "NUM_WARNINGS": types.NUM_WARNINGS},
+ **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int},
+ )
sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse(mjm) else types.TILE_SIZE_JTDAJ_DENSE
@@ -890,9 +899,14 @@ def put_data(
"flexedge_J": None,
"nacon": None,
}
+ # Skip fields that shouldn't be copied from MuJoCo data
+ skip_mjd_fields = {"warning", "warning_info"}
for f in dataclasses.fields(types.Data):
if f.name in d_kwargs:
continue
+ if f.name in skip_mjd_fields:
+ d_kwargs[f.name] = _create_array(None, f.type, sizes)
+ continue
val = getattr(mjd, f.name, None)
if val is not None:
shape = val.shape if hasattr(val, "shape") else ()
diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py
index 3cdab0466..13c80e88e 100644
--- a/mujoco_warp/_src/sensor.py
+++ b/mujoco_warp/_src/sensor.py
@@ -36,6 +36,7 @@
from mujoco_warp._src.types import ObjType
from mujoco_warp._src.types import SensorType
from mujoco_warp._src.types import TrnType
+from mujoco_warp._src.types import WarningType
from mujoco_warp._src.types import vec5
from mujoco_warp._src.types import vec6
from mujoco_warp._src.types import vec8
@@ -2207,139 +2208,154 @@ def _check_match(body_parentid: wp.array(dtype=int), body: int, geom: int, objty
return False
-@wp.kernel
-def _contact_match(
- # Model:
- opt_cone: int,
- opt_contact_sensor_maxmatch: int,
- body_parentid: wp.array(dtype=int),
- geom_bodyid: wp.array(dtype=int),
- site_type: wp.array(dtype=int),
- site_size: wp.array(dtype=wp.vec3),
- sensor_objtype: wp.array(dtype=int),
- sensor_objid: wp.array(dtype=int),
- sensor_reftype: wp.array(dtype=int),
- sensor_refid: wp.array(dtype=int),
- sensor_intprm: wp.array2d(dtype=int),
- sensor_contact_adr: wp.array(dtype=int),
- # Data in:
- site_xpos_in: wp.array2d(dtype=wp.vec3),
- site_xmat_in: wp.array2d(dtype=wp.mat33),
- contact_dist_in: wp.array(dtype=float),
- contact_pos_in: wp.array(dtype=wp.vec3),
- contact_frame_in: wp.array(dtype=wp.mat33),
- contact_friction_in: wp.array(dtype=vec5),
- contact_dim_in: wp.array(dtype=int),
- contact_geom_in: wp.array(dtype=wp.vec2i),
- contact_efc_address_in: wp.array2d(dtype=int),
- contact_worldid_in: wp.array(dtype=int),
- contact_type_in: wp.array(dtype=int),
- efc_force_in: wp.array2d(dtype=float),
- njmax_in: int,
- nacon_in: wp.array(dtype=int),
- # Out:
- sensor_contact_nmatch_out: wp.array2d(dtype=int),
- sensor_contact_matchid_out: wp.array3d(dtype=int),
- sensor_contact_criteria_out: wp.array3d(dtype=float),
- sensor_contact_direction_out: wp.array3d(dtype=float),
-):
- contactsensorid, contactid = wp.tid()
- sensorid = sensor_contact_adr[contactsensorid]
+@cache_kernel
+def _contact_match(enable_printf: bool):
+ """Creates _contact_match kernel with optional printf for warnings."""
- if contactid >= nacon_in[0]:
- return
+ @wp.kernel(module="unique")
+ def contact_match(
+ # Model:
+ opt_cone: int,
+ opt_contact_sensor_maxmatch: int,
+ body_parentid: wp.array(dtype=int),
+ geom_bodyid: wp.array(dtype=int),
+ site_type: wp.array(dtype=int),
+ site_size: wp.array(dtype=wp.vec3),
+ sensor_objtype: wp.array(dtype=int),
+ sensor_objid: wp.array(dtype=int),
+ sensor_reftype: wp.array(dtype=int),
+ sensor_refid: wp.array(dtype=int),
+ sensor_intprm: wp.array2d(dtype=int),
+ sensor_contact_adr: wp.array(dtype=int),
+ # Data in:
+ site_xpos_in: wp.array2d(dtype=wp.vec3),
+ site_xmat_in: wp.array2d(dtype=wp.mat33),
+ contact_dist_in: wp.array(dtype=float),
+ contact_pos_in: wp.array(dtype=wp.vec3),
+ contact_frame_in: wp.array(dtype=wp.mat33),
+ contact_friction_in: wp.array(dtype=vec5),
+ contact_dim_in: wp.array(dtype=int),
+ contact_geom_in: wp.array(dtype=wp.vec2i),
+ contact_efc_address_in: wp.array2d(dtype=int),
+ contact_worldid_in: wp.array(dtype=int),
+ contact_type_in: wp.array(dtype=int),
+ efc_force_in: wp.array2d(dtype=float),
+ njmax_in: int,
+ nacon_in: wp.array(dtype=int),
+ # Data out:
+ warning_out: wp.array(dtype=int),
+ warning_info_out: wp.array2d(dtype=int),
+ # Out:
+ sensor_contact_nmatch_out: wp.array2d(dtype=int),
+ sensor_contact_matchid_out: wp.array3d(dtype=int),
+ sensor_contact_criteria_out: wp.array3d(dtype=float),
+ sensor_contact_direction_out: wp.array3d(dtype=float),
+ ):
+ contactsensorid, contactid = wp.tid()
+ sensorid = sensor_contact_adr[contactsensorid]
- if not contact_type_in[contactid] & ContactType.CONSTRAINT:
- return
+ if contactid >= nacon_in[0]:
+ return
- # sensor information
- objtype = sensor_objtype[sensorid]
- objid = sensor_objid[sensorid]
- reftype = sensor_reftype[sensorid]
- refid = sensor_refid[sensorid]
- reduce = sensor_intprm[sensorid, 1]
+ if not contact_type_in[contactid] & ContactType.CONSTRAINT:
+ return
- worldid = contact_worldid_in[contactid]
+ # sensor information
+ objtype = sensor_objtype[sensorid]
+ objid = sensor_objid[sensorid]
+ reftype = sensor_reftype[sensorid]
+ refid = sensor_refid[sensorid]
+ reduce = sensor_intprm[sensorid, 1]
- # site filter
- if objtype == ObjType.SITE:
- if not inside_geom(
- site_xpos_in[worldid, objid], site_xmat_in[worldid, objid], site_size[objid], site_type[objid], contact_pos_in[contactid]
- ):
- return
+ worldid = contact_worldid_in[contactid]
+
+ # site filter
+ if objtype == ObjType.SITE:
+ if not inside_geom(
+ site_xpos_in[worldid, objid],
+ site_xmat_in[worldid, objid],
+ site_size[objid],
+ site_type[objid],
+ contact_pos_in[contactid],
+ ):
+ return
- # unknown-unknown match
- if objtype == ObjType.UNKNOWN and reftype == ObjType.UNKNOWN:
- dir = 1.0
- else:
- # contact information
- geom = contact_geom_in[contactid]
- geom1 = geom[0]
- geom2 = geom[1]
- body1 = geom_bodyid[geom1]
- body2 = geom_bodyid[geom2]
-
- # check match of sensor objects with contact objects
- match11 = _check_match(body_parentid, body1, geom1, objtype, objid)
- match12 = _check_match(body_parentid, body2, geom2, objtype, objid)
- match21 = _check_match(body_parentid, body1, geom1, reftype, refid)
- match22 = _check_match(body_parentid, body2, geom2, reftype, refid)
-
- # if a sensor object is specified, it must be involved in the contact
- if not match11 and not match12:
- return
- if not match21 and not match22:
+ # unknown-unknown match
+ if objtype == ObjType.UNKNOWN and reftype == ObjType.UNKNOWN:
+ dir = 1.0
+ else:
+ # contact information
+ geom = contact_geom_in[contactid]
+ geom1 = geom[0]
+ geom2 = geom[1]
+ body1 = geom_bodyid[geom1]
+ body2 = geom_bodyid[geom2]
+
+ # check match of sensor objects with contact objects
+ match11 = _check_match(body_parentid, body1, geom1, objtype, objid)
+ match12 = _check_match(body_parentid, body2, geom2, objtype, objid)
+ match21 = _check_match(body_parentid, body1, geom1, reftype, refid)
+ match22 = _check_match(body_parentid, body2, geom2, reftype, refid)
+
+ # if a sensor object is specified, it must be involved in the contact
+ if not match11 and not match12:
+ return
+ if not match21 and not match22:
+ return
+
+ # determine direction
+ dir = 1.0
+ if objtype != ObjType.UNKNOWN and reftype != ObjType.UNKNOWN:
+ # both obj1 and obj2 specified: direction depends on order
+ order_regular = match11 and match22
+ order_reverse = match12 and match21
+ if not order_regular and not order_reverse:
+ return
+ if order_reverse and not order_regular:
+ dir = -1.0
+ elif objtype != ObjType.UNKNOWN:
+ if not match11:
+ dir = -1.0
+ elif reftype != ObjType.UNKNOWN:
+ if not match22:
+ dir = -1.0
+
+ contactmatchid = wp.atomic_add(sensor_contact_nmatch_out[worldid], contactsensorid, 1)
+
+ if contactmatchid >= opt_contact_sensor_maxmatch:
+ if wp.static(enable_printf):
+ wp.printf("contact match overflow: please increase Option.contact_sensor_maxmatch to %u\n", contactmatchid)
+ wp.atomic_max(warning_out, int(WarningType.CONTACT_MATCH_OVERFLOW), 1)
+ wp.atomic_max(warning_info_out[int(WarningType.CONTACT_MATCH_OVERFLOW)], 0, contactmatchid)
return
- # determine direction
- dir = 1.0
- if objtype != ObjType.UNKNOWN and reftype != ObjType.UNKNOWN:
- # both obj1 and obj2 specified: direction depends on order
- order_regular = match11 and match22
- order_reverse = match12 and match21
- if not order_regular and not order_reverse:
- return
- if order_reverse and not order_regular:
- dir = -1.0
- elif objtype != ObjType.UNKNOWN:
- if not match11:
- dir = -1.0
- elif reftype != ObjType.UNKNOWN:
- if not match22:
- dir = -1.0
-
- contactmatchid = wp.atomic_add(sensor_contact_nmatch_out[worldid], contactsensorid, 1)
-
- if contactmatchid >= opt_contact_sensor_maxmatch:
- # TODO(team): alternative to wp.printf for reporting overflow?
- wp.printf("contact match overflow: please increase Option.contact_sensor_maxmatch to %u\n", contactmatchid)
- return
+ sensor_contact_matchid_out[worldid, contactsensorid, contactmatchid] = contactid
+
+ if reduce == 1: # mindist
+ sensor_contact_criteria_out[worldid, contactsensorid, contactmatchid] = contact_dist_in[contactid]
+ elif reduce == 2: # maxforce
+ contact_force = support.contact_force_fn(
+ opt_cone,
+ contact_frame_in,
+ contact_friction_in,
+ contact_dim_in,
+ contact_efc_address_in,
+ efc_force_in,
+ njmax_in,
+ nacon_in,
+ worldid,
+ contactid,
+ False,
+ )
+ force_magnitude = (
+ contact_force[0] * contact_force[0] + contact_force[1] * contact_force[1] + contact_force[2] * contact_force[2]
+ )
+ sensor_contact_criteria_out[worldid, contactsensorid, contactmatchid] = -force_magnitude
- sensor_contact_matchid_out[worldid, contactsensorid, contactmatchid] = contactid
-
- if reduce == 1: # mindist
- sensor_contact_criteria_out[worldid, contactsensorid, contactmatchid] = contact_dist_in[contactid]
- elif reduce == 2: # maxforce
- contact_force = support.contact_force_fn(
- opt_cone,
- contact_frame_in,
- contact_friction_in,
- contact_dim_in,
- contact_efc_address_in,
- efc_force_in,
- njmax_in,
- nacon_in,
- worldid,
- contactid,
- False,
- )
- force_magnitude = (
- contact_force[0] * contact_force[0] + contact_force[1] * contact_force[1] + contact_force[2] * contact_force[2]
- )
- sensor_contact_criteria_out[worldid, contactsensorid, contactmatchid] = -force_magnitude
+ # contact direction
+ sensor_contact_direction_out[worldid, contactsensorid, contactmatchid] = dir
- # contact direction
- sensor_contact_direction_out[worldid, contactsensorid, contactmatchid] = dir
+ return contact_match
@cache_kernel
@@ -2463,7 +2479,7 @@ def sensor_acc(m: Model, d: Data):
sensor_contact_criteria.fill_(1.0e32)
wp.launch(
- _contact_match,
+ _contact_match(m.opt.warning_printf),
dim=(m.sensor_contact_adr.size, d.naconmax),
inputs=[
m.opt.cone,
@@ -2493,7 +2509,14 @@ def sensor_acc(m: Model, d: Data):
d.njmax,
d.nacon,
],
- outputs=[sensor_contact_nmatch, sensor_contact_matchid, sensor_contact_criteria, sensor_contact_direction],
+ outputs=[
+ d.warning,
+ d.warning_info,
+ sensor_contact_nmatch,
+ sensor_contact_matchid,
+ sensor_contact_criteria,
+ sensor_contact_direction,
+ ],
)
# sorting
diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py
index 92c6b726d..8867acadf 100644
--- a/mujoco_warp/_src/smooth.py
+++ b/mujoco_warp/_src/smooth.py
@@ -1985,11 +1985,9 @@ def _transmission(
actuator_length_out[worldid, actid] = wp.dot(axis_angle, gearaxis)
for i in range(3):
actuator_moment_out[worldid, actid, vadr + i] = gearaxis[i]
- elif jnt_typ == JointType.SLIDE or jnt_typ == JointType.HINGE:
+ else: # SLIDE or HINGE
actuator_length_out[worldid, actid] = qpos[qadr] * gear[0]
actuator_moment_out[worldid, actid, vadr] = gear[0]
- else:
- wp.printf("unrecognized joint type")
elif trntype == TrnType.SLIDERCRANK:
# get data
trnid = actuator_trnid[actid]
@@ -2194,8 +2192,6 @@ def _transmission(
moment += wp.dot(jacrdif, wrench_rotation)
actuator_moment_out[worldid, actid, i] = moment
- else:
- wp.printf("unhandled transmission type %d\n", trntype)
@wp.kernel
diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py
index 66632d843..52c47f307 100644
--- a/mujoco_warp/_src/types.py
+++ b/mujoco_warp/_src/types.py
@@ -66,6 +66,31 @@ class BlockDim:
qderiv_actuator_dense: int = 32
+class WarningType(enum.IntEnum):
+ """Warning types for kernel-side overflow detection.
+
+ Attributes:
+ NEFC_OVERFLOW: constraint count exceeded njmax
+ BROADPHASE_OVERFLOW: broadphase collision count exceeded naconmax
+ NARROWPHASE_OVERFLOW: narrowphase contact count exceeded naconmax
+ CONTACT_MATCH_OVERFLOW: contact sensor match count exceeded maxmatch
+ GJK_ITERATIONS: GJK algorithm did not converge within iteration limit
+ EPA_HORIZON: EPA horizon buffer overflow
+ HFIELD_OVERFLOW: heightfield collision count exceeded MJ_MAXCONPAIR
+ """
+
+ NEFC_OVERFLOW = 0
+ BROADPHASE_OVERFLOW = 1
+ NARROWPHASE_OVERFLOW = 2
+ CONTACT_MATCH_OVERFLOW = 3
+ GJK_ITERATIONS = 4
+ EPA_HORIZON = 5
+ HFIELD_OVERFLOW = 6
+
+
+NUM_WARNINGS = len(WarningType)
+
+
class BroadphaseType(enum.IntEnum):
"""Type of broadphase algorithm.
@@ -690,6 +715,9 @@ class Option:
zeros out the contacts at each step)
contact_sensor_maxmatch: max number of contacts considered by contact sensor matching criteria
contacts matched after this value is exceded will be ignored
+ warning_printf: if True, emit overflow warnings via printf to stdout in addition to setting
+ warning flags. Default True. Set to False and use check_warnings(d) for
+ programmatic access to warnings without stdout output.
"""
timestep: array("*", float)
@@ -720,6 +748,7 @@ class Option:
graph_conditional: bool
run_collision_detection: bool
contact_sensor_maxmatch: int
+ warning_printf: bool
@dataclasses.dataclass
@@ -1647,6 +1676,8 @@ class Data:
collision_pairid: ids from broadphase (naconmax, 2)
collision_worldid: collision world ids from broadphase (naconmax,)
ncollision: collision count from broadphase (1,)
+ warning: warning flags (accumulated across steps) (NUM_WARNINGS,)
+ warning_info: warning info (suggested values) (NUM_WARNINGS, 2)
"""
solver_niter: array("nworld", int)
@@ -1738,3 +1769,7 @@ class Data:
collision_pairid: array("naconmax", wp.vec2i)
collision_worldid: array("naconmax", int)
ncollision: array(1, int)
+
+ # warp only: warning flags (accumulated across steps, checked on host)
+ warning: array("NUM_WARNINGS", int) # flag per warning type
+ warning_info: array("NUM_WARNINGS", 2, int) # suggested values
diff --git a/mujoco_warp/_src/types_test.py b/mujoco_warp/_src/types_test.py
index 0d1306a3a..1f6d4a6b4 100644
--- a/mujoco_warp/_src/types_test.py
+++ b/mujoco_warp/_src/types_test.py
@@ -45,6 +45,8 @@ def test_field_order(self, mj_class, mjw_class):
# TODO(team): remove this reordering after MjData._all_fields order is fixed
# there's a bug in _all_fields where solver_niter is in the wrong place
mj_fields.insert(0, mj_fields.pop(mj_fields.index("solver_niter")))
+ # MuJoCo's "warning" field is different from our warp-only warning system
+ mj_fields.remove("warning")
mj_set, mjw_set = set(mj_fields), set(mjw_fields)
diff --git a/mujoco_warp/_src/warning.py b/mujoco_warp/_src/warning.py
new file mode 100644
index 000000000..686f71e46
--- /dev/null
+++ b/mujoco_warp/_src/warning.py
@@ -0,0 +1,94 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Warning utilities for kernel-side overflow detection."""
+
+import sys
+from typing import List
+
+from . import types
+
+_WARNING_MESSAGES = {
+ types.WarningType.NEFC_OVERFLOW: "nefc overflow - increase njmax to {0}",
+ types.WarningType.BROADPHASE_OVERFLOW: ("broadphase overflow - increase nconmax to {0} or naconmax to {1}"),
+ types.WarningType.NARROWPHASE_OVERFLOW: ("narrowphase overflow - increase nconmax to {0} or naconmax to {1}"),
+ types.WarningType.CONTACT_MATCH_OVERFLOW: ("contact match overflow - increase Option.contact_sensor_maxmatch to {0}"),
+ types.WarningType.GJK_ITERATIONS: ("GJK did not converge - increase opt.ccd_iterations (currently {0})"),
+ types.WarningType.EPA_HORIZON: "EPA horizon overflow - horizon size {0} insufficient",
+ types.WarningType.HFIELD_OVERFLOW: (
+ "heightfield collision overflow - decrease hfield rows/cols or modify colliding geom size (limit {0})"
+ ),
+}
+
+
+def check_warnings(d: types.Data, clear: bool = True) -> List[str]:
+ """Check warning flags and emit to stderr.
+
+ This function reads the warning flags set by kernels and emits appropriate
+ warning messages to stderr. Warning flags accumulate across simulation steps
+ (using atomic_max), so this should be called after graph execution completes.
+
+ Args:
+ d: The Data object containing warning flags.
+ clear: Whether to clear warning flags after checking. Default True.
+
+ Returns:
+ List of warning message strings that were emitted.
+ """
+ flags = d.warning.numpy()
+ info = d.warning_info.numpy()
+
+ emitted = []
+ for wtype in types.WarningType:
+ if flags[wtype]:
+ msg = _WARNING_MESSAGES[wtype].format(info[wtype, 0], info[wtype, 1])
+ print(f"Warning: {msg}", file=sys.stderr)
+ emitted.append(msg)
+
+ if clear:
+ d.warning.zero_()
+ d.warning_info.zero_()
+
+ return emitted
+
+
+def get_warnings(d: types.Data) -> List[str]:
+ """Get warning messages without emitting or clearing.
+
+ Args:
+ d: The Data object containing warning flags.
+
+ Returns:
+ List of warning message strings.
+ """
+ flags = d.warning.numpy()
+ info = d.warning_info.numpy()
+
+ messages = []
+ for wtype in types.WarningType:
+ if flags[wtype]:
+ msg = _WARNING_MESSAGES[wtype].format(info[wtype, 0], info[wtype, 1])
+ messages.append(msg)
+
+ return messages
+
+
+def clear_warnings(d: types.Data) -> None:
+ """Clear all warning flags.
+
+ Args:
+ d: The Data object containing warning flags.
+ """
+ d.warning.zero_()
+ d.warning_info.zero_()
diff --git a/mujoco_warp/_src/warning_test.py b/mujoco_warp/_src/warning_test.py
new file mode 100644
index 000000000..e9422442e
--- /dev/null
+++ b/mujoco_warp/_src/warning_test.py
@@ -0,0 +1,356 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for warning system."""
+
+import warnings
+
+import mujoco
+import numpy as np
+import warp as wp
+from absl.testing import absltest
+
+import mujoco_warp as mjw
+from mujoco_warp._src import types
+
+
+class WarningTest(absltest.TestCase):
+ def test_warning_arrays_initialized(self):
+ """Tests that warning arrays are properly initialized."""
+ mjm = mujoco.MjModel.from_xml_string("""
+
+
+
+
+
+
+
+
+ """)
+ mjd = mujoco.MjData(mjm)
+
+ d = mjw.put_data(mjm, mjd)
+
+ # Check shapes
+ self.assertEqual(d.warning.shape, (types.NUM_WARNINGS,))
+ self.assertEqual(d.warning_info.shape, (types.NUM_WARNINGS, 2))
+
+ # Check initial values are zero
+ np.testing.assert_array_equal(d.warning.numpy(), np.zeros(types.NUM_WARNINGS, dtype=np.int32))
+ np.testing.assert_array_equal(d.warning_info.numpy(), np.zeros((types.NUM_WARNINGS, 2), dtype=np.int32))
+
+ def test_check_warnings_no_warnings(self):
+ """Tests check_warnings returns empty list when no warnings."""
+ mjm = mujoco.MjModel.from_xml_string("""
+
+
+
+
+
+
+
+
+ """)
+ mjd = mujoco.MjData(mjm)
+
+ m = mjw.put_model(mjm)
+ d = mjw.put_data(mjm, mjd)
+
+ # Run a step - should not trigger any warnings
+ mjw.step(m, d)
+ wp.synchronize()
+
+ # Check warnings
+ result = mjw.get_warnings(d)
+ self.assertEqual(result, [])
+
+ def test_nefc_overflow_warning(self):
+ """Tests that nefc overflow sets warning flag correctly."""
+ # Sphere close to ground with large timestep - contacts quickly
+ mjm = mujoco.MjModel.from_xml_string("""
+
+
+
+
+
+
+
+
+
+
+ """)
+ mjd = mujoco.MjData(mjm)
+ mujoco.mj_forward(mjm, mjd)
+
+ # No initial contacts
+ self.assertEqual(mjd.ncon, 0, "Test setup: should have no initial contacts")
+
+ m = mjw.put_model(mjm)
+ m.opt.warning_printf = False # disable printf in tests
+ # Set njmax very low to trigger overflow when sphere hits ground
+ d = mjw.put_data(mjm, mjd, njmax=1)
+
+ # Run steps until sphere falls and creates contact (~10 steps at 0.01 timestep)
+ for _ in range(20):
+ mjw.step(m, d)
+ wp.synchronize()
+
+ # Check warning flag is set
+ warning_flags = d.warning.numpy()
+ self.assertEqual(warning_flags[types.WarningType.NEFC_OVERFLOW], 1)
+
+ # Check warning info contains suggested value (should be 4 for a single contact with friction)
+ warning_info = d.warning_info.numpy()
+ self.assertGreater(warning_info[types.WarningType.NEFC_OVERFLOW, 0], 1)
+
+ # Check get_warnings returns the message
+ result = mjw.get_warnings(d)
+ self.assertEqual(len(result), 1)
+ self.assertIn("nefc overflow", result[0])
+
+ def test_check_warnings_clears_flags(self):
+ """Tests that check_warnings clears flags by default."""
+ # Sphere close to ground with large timestep
+ mjm = mujoco.MjModel.from_xml_string("""
+
+
+
+
+
+
+
+
+
+
+ """)
+ mjd = mujoco.MjData(mjm)
+ mujoco.mj_forward(mjm, mjd)
+
+ m = mjw.put_model(mjm)
+ m.opt.warning_printf = False # disable printf in tests
+ d = mjw.put_data(mjm, mjd, njmax=1)
+
+ # Run until contact
+ for _ in range(20):
+ mjw.step(m, d)
+ wp.synchronize()
+
+ # First check_warnings should return warnings and clear
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter("always")
+ result1 = mjw.check_warnings(d, clear=True)
+
+ # Second call should return empty (flags cleared)
+ result2 = mjw.get_warnings(d)
+
+ self.assertGreater(len(result1), 0)
+ self.assertEqual(len(result2), 0)
+
+ def test_clear_warnings(self):
+ """Tests clear_warnings utility."""
+ # Sphere close to ground with large timestep
+ mjm = mujoco.MjModel.from_xml_string("""
+
+
+
+
+
+
+
+
+
+
+ """)
+ mjd = mujoco.MjData(mjm)
+ mujoco.mj_forward(mjm, mjd)
+
+ m = mjw.put_model(mjm)
+ m.opt.warning_printf = False # disable printf in tests
+ d = mjw.put_data(mjm, mjd, njmax=1)
+
+ # Run until contact
+ for _ in range(20):
+ mjw.step(m, d)
+ wp.synchronize()
+
+ # Verify warning is set
+ self.assertGreater(len(mjw.get_warnings(d)), 0)
+
+ # Clear warnings
+ mjw.clear_warnings(d)
+
+ # Verify cleared
+ self.assertEqual(len(mjw.get_warnings(d)), 0)
+ np.testing.assert_array_equal(d.warning.numpy(), np.zeros(types.NUM_WARNINGS, dtype=np.int32))
+
+ @absltest.skipIf(not wp.get_device().is_cuda, "Skipping test that requires GPU.")
+ def test_multi_step_graph_captures_mid_graph_warning(self):
+ """Tests that a multi-step graph captures warnings even if they occur mid-graph.
+
+ Captures 20 steps in one graph. The sphere starts close to ground and hits it
+ around step 5-10. The warning from that step should be captured and readable
+ after the graph completes.
+ """
+ # Sphere close to ground - will contact around step 5-10
+ mjm = mujoco.MjModel.from_xml_string("""
+
+
+
+
+
+
+
+
+
+
+ """)
+ mjd = mujoco.MjData(mjm)
+ mujoco.mj_forward(mjm, mjd)
+
+ # No initial contact
+ self.assertEqual(mjd.ncon, 0, "Test setup: should have no initial contacts")
+
+ m = mjw.put_model(mjm)
+ m.opt.warning_printf = False # disable printf in tests
+ d = mjw.put_data(mjm, mjd, njmax=1)
+
+ # Clear warnings
+ mjw.clear_warnings(d)
+
+ # Capture 20 steps as a single graph - warning should occur around step 5-10
+ nsteps = 20
+ with wp.ScopedCapture() as capture:
+ for _ in range(nsteps):
+ mjw.step(m, d)
+ graph = capture.graph
+
+ # Run graph once - warning happens mid-graph
+ wp.capture_launch(graph)
+ wp.synchronize()
+
+ # Check that warning was captured from the step where contact happened
+ warning_flags = d.warning.numpy()
+ self.assertEqual(
+ warning_flags[types.WarningType.NEFC_OVERFLOW], 1, f"Expected nefc overflow warning, got flags: {warning_flags}"
+ )
+
+ # Verify warning info shows the correct suggested value
+ warning_info = d.warning_info.numpy()
+ self.assertEqual(
+ warning_info[types.WarningType.NEFC_OVERFLOW, 0],
+ 4,
+ f"Expected nefc=4 in warning info, got: {warning_info[types.WarningType.NEFC_OVERFLOW]}",
+ )
+
+ # Get warnings - should include the overflow message
+ result = mjw.get_warnings(d)
+ self.assertEqual(len(result), 1)
+ self.assertIn("nefc overflow", result[0])
+ self.assertIn("4", result[0])
+
+ @absltest.skipIf(not wp.get_device().is_cuda, "Skipping test that requires GPU.")
+ def test_single_step_graph_warns_only_when_event_occurs(self):
+ """Tests that a single-step graph only reports warning on the step it happens.
+
+ Captures 1 step as a graph. Runs it multiple times. Verifies that:
+ 1. Before contact: no warnings generated
+ 2. After contact: warnings generated
+ 3. After clearing and running more: warnings still generated (overflow persists)
+ """
+ # Sphere close to ground - will contact around launch 5-10
+ mjm = mujoco.MjModel.from_xml_string("""
+
+
+
+
+
+
+
+
+
+
+ """)
+ mjd = mujoco.MjData(mjm)
+ mujoco.mj_forward(mjm, mjd)
+
+ m = mjw.put_model(mjm)
+ m.opt.warning_printf = False # disable printf in tests
+ d = mjw.put_data(mjm, mjd, njmax=1)
+
+ # Clear warnings
+ mjw.clear_warnings(d)
+
+ # Capture single step
+ with wp.ScopedCapture() as capture:
+ mjw.step(m, d)
+ graph = capture.graph
+
+ # Run a few launches before contact happens (sphere is falling)
+ for _ in range(3):
+ wp.capture_launch(graph)
+ wp.synchronize()
+
+ # Check: no warnings yet (sphere hasn't hit ground)
+ warning_flags_before = d.warning.numpy().copy()
+ self.assertEqual(warning_flags_before[types.WarningType.NEFC_OVERFLOW], 0, "Should have no warning before contact")
+
+ # Clear and run more launches until contact happens
+ mjw.clear_warnings(d)
+ for _ in range(15):
+ wp.capture_launch(graph)
+ wp.synchronize()
+
+ # Check: warning should now be set (contact occurred)
+ warning_flags_after = d.warning.numpy()
+ self.assertEqual(warning_flags_after[types.WarningType.NEFC_OVERFLOW], 1, "Should have warning after contact")
+
+ # Clear and run more - warning should still be generated (overflow persists each step)
+ mjw.clear_warnings(d)
+ for _ in range(5):
+ wp.capture_launch(graph)
+ wp.synchronize()
+
+ warning_flags_final = d.warning.numpy()
+ self.assertEqual(
+ warning_flags_final[types.WarningType.NEFC_OVERFLOW], 1, "Warning should still be generated (overflow persists)"
+ )
+
+ # Check warning message
+ result = mjw.get_warnings(d)
+ self.assertEqual(len(result), 1)
+ self.assertIn("nefc overflow", result[0])
+
+ def test_warning_printf_option(self):
+ """Tests that warning_printf option is available and defaults to True."""
+ mjm = mujoco.MjModel.from_xml_string("""
+
+
+
+
+
+
+
+
+ """)
+
+ m = mjw.put_model(mjm)
+
+ # Default should be True (printf enabled)
+ self.assertTrue(m.opt.warning_printf)
+
+
+if __name__ == "__main__":
+ wp.init()
+ absltest.main()
diff --git a/mujoco_warp/testspeed.py b/mujoco_warp/testspeed.py
index a3d03c536..3d799b7b4 100644
--- a/mujoco_warp/testspeed.py
+++ b/mujoco_warp/testspeed.py
@@ -286,6 +286,7 @@ def _main(argv: Sequence[str]):
with wp.ScopedDevice(_DEVICE.value):
override_model(mjm, _OVERRIDE.value)
m = mjw.put_model(mjm)
+ m.opt.warning_printf = False # use check_warnings instead
override_model(m, _OVERRIDE.value)
d = mjw.put_data(mjm, mjd, nworld=_NWORLD.value, nconmax=_NCONMAX.value, njmax=_NJMAX.value)
if _FORMAT.value == "human":
diff --git a/mujoco_warp/viewer.py b/mujoco_warp/viewer.py
index 0a66ed3c6..f23d4f66f 100644
--- a/mujoco_warp/viewer.py
+++ b/mujoco_warp/viewer.py
@@ -43,6 +43,7 @@
from mujoco_warp._src.io import find_keys
from mujoco_warp._src.io import make_trajectory
from mujoco_warp._src.io import override_model
+from mujoco_warp._src.warning import check_warnings
class EngineOptions(enum.IntEnum):
@@ -147,6 +148,7 @@ def _main(argv: Sequence[str]) -> None:
with wp.ScopedDevice(_DEVICE.value):
override_model(mjm, _OVERRIDE.value)
m = mjw.put_model(mjm)
+ m.opt.warning_printf = False # use check_warnings instead
override_model(m, _OVERRIDE.value)
broadphase, filter = mjw.BroadphaseType(m.opt.broadphase).name, mjw.BroadphaseFilter(m.opt.broadphase_filter).name
solver, cone = mjw.SolverType(m.opt.solver).name, mjw.ConeType(m.opt.cone).name
@@ -189,15 +191,18 @@ def _main(argv: Sequence[str]) -> None:
if mjm.opt != opt:
opt = copy.copy(mjm.opt)
m = mjw.put_model(mjm)
+ m.opt.warning_printf = False # use check_warnings instead
graph = _compile_step(m, d)
if _VIEWER_GLOBAL_STATE["running"]:
wp.capture_launch(graph)
wp.synchronize()
+ check_warnings(d)
elif _VIEWER_GLOBAL_STATE["step_once"]:
_VIEWER_GLOBAL_STATE["step_once"] = False
wp.capture_launch(graph)
wp.synchronize()
+ check_warnings(d)
mjw.get_data_into(mjd, mjm, d)