diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index e6a10bf38..2cace8a10 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -99,3 +99,7 @@ from mujoco_warp._src.types import State as State from mujoco_warp._src.types import Statistic as Statistic from mujoco_warp._src.types import TrnType as TrnType +from mujoco_warp._src.types import WarningType as WarningType +from mujoco_warp._src.warning import check_warnings as check_warnings +from mujoco_warp._src.warning import clear_warnings as clear_warnings +from mujoco_warp._src.warning import get_warnings as get_warnings diff --git a/mujoco_warp/_src/benchmark.py b/mujoco_warp/_src/benchmark.py index eaadae0fa..31cf1e314 100644 --- a/mujoco_warp/_src/benchmark.py +++ b/mujoco_warp/_src/benchmark.py @@ -25,6 +25,7 @@ from mujoco_warp._src.types import Data from mujoco_warp._src.types import Model from mujoco_warp._src.util_misc import halton +from mujoco_warp._src.warning import check_warnings def _sum(stack1, stack2): @@ -145,6 +146,9 @@ def benchmark( wp.synchronize() run_end = time.perf_counter() + # Check and emit any overflow warnings, then clear flags + check_warnings(d, clear=True) + time_vec[i] = run_end - run_beg if trace: trace = _sum(trace, tracer.trace()) diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index e3e33acc6..c141696fe 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -34,6 +34,7 @@ from mujoco_warp._src.types import EnableBit from mujoco_warp._src.types import GeomType from mujoco_warp._src.types import Model +from mujoco_warp._src.types import WarningType from mujoco_warp._src.types import mat43 from mujoco_warp._src.types import mat63 from mujoco_warp._src.types import vec5 @@ -154,6 +155,7 @@ def ccd_hfield_kernel_builder( geomtype2: int, gjk_iterations: int, epa_iterations: int, + warning_printf: bool, ): """Kernel builder for heightfield CCD collisions (no multiccd args).""" @@ -232,6 +234,9 @@ def ccd_hfield_kernel( contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), nacon_out: wp.array(dtype=int), + # Warning output: + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), ): tid = wp.tid() if tid >= ncollision_in[0]: @@ -403,10 +408,13 @@ def ccd_hfield_kernel( # add both triangles from this cell for i in range(2): if count >= MJ_MAXCONPAIR: - wp.printf( - "height field collision overflow, number of collisions >= %u - please adjust resolution: \n decrease the number of hfield rows/cols or modify size of colliding geom\n", - MJ_MAXCONPAIR, - ) + if wp.static(warning_printf): + wp.printf( + "height field collision overflow, number of collisions >= %u - please adjust resolution: \n decrease the number of hfield rows/cols or modify size of colliding geom\n", + MJ_MAXCONPAIR, + ) + wp.atomic_max(warning_out, int(WarningType.HFIELD_OVERFLOW), 1) + wp.atomic_max(warning_info_out, int(WarningType.HFIELD_OVERFLOW), 0, MJ_MAXCONPAIR) continue # add vert @@ -453,6 +461,9 @@ def ccd_hfield_kernel( epa_pr, epa_norm2, epa_horizon, + wp.static(warning_printf), + warning_out, + warning_info_out, ) if ncontact == 0: @@ -692,6 +703,7 @@ def ccd_kernel_builder( gjk_iterations: int, epa_iterations: int, use_multiccd: bool, + warning_printf: bool, ): """Kernel builder for non-heightfield CCD collisions (no hfield args).""" @@ -751,6 +763,9 @@ def eval_ccd_write_contact( contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), nacon_out: wp.array(dtype=int), + # Warning output: + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), ) -> int: points = mat43() witness1 = mat43() @@ -781,6 +796,9 @@ def eval_ccd_write_contact( epa_pr_in[tid], epa_norm2_in[tid], epa_horizon_in[tid], + wp.static(warning_printf), + warning_out, + warning_info_out, ) if dist >= 0.0 and pairid[1] == -1: @@ -953,6 +971,9 @@ def ccd_kernel( contact_type_out: wp.array(dtype=int), contact_geomcollisionid_out: wp.array(dtype=int), nacon_out: wp.array(dtype=int), + # Warning output: + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), ): tid = wp.tid() if tid >= ncollision_in[0]: @@ -1064,6 +1085,8 @@ def ccd_kernel( contact_type_out, contact_geomcollisionid_out, nacon_out, + warning_out, + warning_info_out, ) return ccd_kernel @@ -1150,6 +1173,8 @@ def _pair_count(p1: int, p2: int) -> int: d.contact.type, d.contact.geomcollisionid, d.nacon, + d.warning, + d.warning_info, ] # Launch heightfield collision kernels (no multiccd args, 72 args total) @@ -1158,7 +1183,7 @@ def _pair_count(p1: int, p2: int) -> int: g2 = geom_pair[1].value if (g1 == GeomType.HFIELD or g2 == GeomType.HFIELD) and _pair_count(g1, g2): wp.launch( - ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations), + ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, m.opt.warning_printf), dim=d.naconmax, inputs=[ m.opt.ccd_tolerance, @@ -1249,7 +1274,7 @@ def _pair_count(p1: int, p2: int) -> int: g2 = geom_pair[1].value if g1 != GeomType.HFIELD and g2 != GeomType.HFIELD and _pair_count(g1, g2): wp.launch( - ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd), + ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd, m.opt.warning_printf), dim=d.naconmax, inputs=[ m.opt.ccd_tolerance, diff --git a/mujoco_warp/_src/collision_gjk.py b/mujoco_warp/_src/collision_gjk.py index 96b550ec7..3c172268d 100644 --- a/mujoco_warp/_src/collision_gjk.py +++ b/mujoco_warp/_src/collision_gjk.py @@ -20,6 +20,7 @@ from mujoco_warp._src.collision_primitive import Geom from mujoco_warp._src.types import GeomType +from mujoco_warp._src.types import WarningType from mujoco_warp._src.types import mat43 from mujoco_warp._src.types import mat63 @@ -571,6 +572,10 @@ def gjk( geomtype2: int, cutoff: float, is_discrete: bool, + warning_printf: bool, + # Data out: + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), ) -> GJKResult: """Find distance within a tolerance between two geoms.""" cutoff2 = cutoff * cutoff @@ -667,7 +672,10 @@ def gjk( cnt += 1 if cnt == gjk_iterations: - wp.printf("Warning: opt.ccd_iterations, currently set to %d, needs to be increased.\n", gjk_iterations) + if warning_printf: + wp.printf("Warning: opt.ccd_iterations, currently set to %d, needs to be increased.\n", gjk_iterations) + wp.atomic_max(warning_out, int(WarningType.GJK_ITERATIONS), 1) + wp.atomic_max(warning_info_out, int(WarningType.GJK_ITERATIONS), 0, gjk_iterations) result = GJKResult() @@ -1218,6 +1226,10 @@ def _epa( geomtype1: int, geomtype2: int, is_discrete: bool, + warning_printf: bool, + # Data out: + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), ) -> Tuple[float, wp.vec3, wp.vec3, int]: """Recover penetration data from two geoms in contact given an initial polytope.""" upper = FLOAT_MAX @@ -1288,7 +1300,10 @@ def _epa( pt.nhorizon = _add_edge(pt, face[1], face[2]) pt.nhorizon = _add_edge(pt, face[2], face[0]) if pt.nhorizon == -1: - wp.printf("Warning: EPA horizon = %d isn't large enough.\n", pt.horizon.shape[0]) + if warning_printf: + wp.printf("Warning: EPA horizon = %d isn't large enough.\n", pt.horizon.shape[0]) + wp.atomic_max(warning_out, int(WarningType.EPA_HORIZON), 1) + wp.atomic_max(warning_info_out, int(WarningType.EPA_HORIZON), 0, pt.horizon.shape[0]) idx = -1 break @@ -1305,7 +1320,10 @@ def _epa( pt.nhorizon = _add_edge(pt, face[1], face[2]) pt.nhorizon = _add_edge(pt, face[2], face[0]) if pt.nhorizon == -1: - wp.printf("Warning: EPA horizon = %d isn't large enough.\n", pt.horizon.shape[0]) + if warning_printf: + wp.printf("Warning: EPA horizon = %d isn't large enough.\n", pt.horizon.shape[0]) + wp.atomic_max(warning_out, int(WarningType.EPA_HORIZON), 1) + wp.atomic_max(warning_info_out, int(WarningType.EPA_HORIZON), 0, pt.horizon.shape[0]) idx = -1 break @@ -1333,7 +1351,10 @@ def _epa( cnt += 1 if cnt == epa_iterations: - wp.printf("Warning: opt.ccd_iterations, currently set to %d, needs to be increased.\n", gjk_iterations) + if warning_printf: + wp.printf("Warning: opt.ccd_iterations, currently set to %d, needs to be increased.\n", gjk_iterations) + wp.atomic_max(warning_out, int(WarningType.GJK_ITERATIONS), 1) + wp.atomic_max(warning_info_out, int(WarningType.GJK_ITERATIONS), 0, gjk_iterations) # return from valid face if idx > -1: @@ -2232,6 +2253,10 @@ def ccd( face_pr: wp.array(dtype=wp.vec3), face_norm2: wp.array(dtype=float), horizon: wp.array(dtype=int), + warning_printf: bool, + # Data out: + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), ) -> Tuple[float, int, wp.vec3, wp.vec3, int]: """General convex collision detection via GJK/EPA.""" full_margin1 = 0.0 @@ -2257,7 +2282,21 @@ def ccd( # special handling for sphere and capsule (shrink to point and line respectively) if size1 + size2 > 0.0: cutoff += full_margin1 + full_margin2 - result = gjk(tolerance, gjk_iterations, geom1, geom2, x_1, x_2, geomtype1, geomtype2, cutoff, is_discrete) + result = gjk( + tolerance, + gjk_iterations, + geom1, + geom2, + x_1, + x_2, + geomtype1, + geomtype2, + cutoff, + is_discrete, + warning_printf, + warning_out, + warning_info_out, + ) # shallow penetration, inflate contact if result.dist > tolerance: @@ -2273,7 +2312,21 @@ def ccd( geom2.size = wp.vec3(size2, geom2.size[1], geom2.size[2]) cutoff -= full_margin1 + full_margin2 - result = gjk(tolerance, gjk_iterations, geom1, geom2, x_1, x_2, geomtype1, geomtype2, cutoff, is_discrete) + result = gjk( + tolerance, + gjk_iterations, + geom1, + geom2, + x_1, + x_2, + geomtype1, + geomtype2, + cutoff, + is_discrete, + warning_printf, + warning_out, + warning_info_out, + ) # no penetration depth to recover if result.dist > tolerance or result.dim < 2: @@ -2349,7 +2402,20 @@ def ccd( if pt.status: return result.dist, 1, result.x1, result.x2, -1 - dist, x1, x2, idx = _epa(tolerance, gjk_iterations, epa_iterations, pt, geom1, geom2, geomtype1, geomtype2, is_discrete) + dist, x1, x2, idx = _epa( + tolerance, + gjk_iterations, + epa_iterations, + pt, + geom1, + geom2, + geomtype1, + geomtype2, + is_discrete, + warning_printf, + warning_out, + warning_info_out, + ) if idx == -1: return FLOAT_MAX, 0, wp.vec3(), wp.vec3(), -1 diff --git a/mujoco_warp/_src/collision_gjk_test.py b/mujoco_warp/_src/collision_gjk_test.py index 1e61de65a..d641cdf58 100644 --- a/mujoco_warp/_src/collision_gjk_test.py +++ b/mujoco_warp/_src/collision_gjk_test.py @@ -26,6 +26,7 @@ from mujoco_warp._src.collision_primitive import Geom from mujoco_warp._src.types import MJ_MAX_EPAFACES from mujoco_warp._src.types import MJ_MAX_EPAHORIZON +from mujoco_warp._src.types import NUM_WARNINGS def _geom_dist( @@ -108,6 +109,9 @@ def _ccd_kernel( endvert: wp.array(dtype=wp.vec3), face1: wp.array(dtype=wp.vec3), face2: wp.array(dtype=wp.vec3), + # Data out: + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), # Out: dist_out: wp.array(dtype=float), ncon_out: wp.array(dtype=int), @@ -200,6 +204,9 @@ def _ccd_kernel( face_pr, face_norm2, horizon, + False, # warning_printf + warning_out, + warning_info_out, ) if wp.static(multiccd): @@ -236,6 +243,8 @@ def _ccd_kernel( dist_out = wp.array(shape=(1,), dtype=float) ncon_out = wp.array(shape=(1,), dtype=int) pos_out = wp.array(shape=(2,), dtype=wp.vec3) + warning_out = wp.zeros(NUM_WARNINGS, dtype=int) + warning_info_out = wp.zeros((NUM_WARNINGS, 2), dtype=int) wp.launch( _ccd_kernel, dim=1, @@ -280,6 +289,8 @@ def _ccd_kernel( multiccd_endvert, multiccd_face1, multiccd_face2, + warning_out, + warning_info_out, ], outputs=[ dist_out, diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 032070be8..5e71a203a 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -39,6 +39,7 @@ from mujoco_warp._src.types import Model from mujoco_warp._src.types import TileSet from mujoco_warp._src.types import TrnType +from mujoco_warp._src.types import WarningType from mujoco_warp._src.types import vec10f from mujoco_warp._src.warp_util import cache_kernel from mujoco_warp._src.warp_util import event_scope @@ -190,37 +191,56 @@ def _next_activation( act_out[worldid, actid] = act -@wp.kernel -def _next_time( - # Model: - opt_timestep: wp.array(dtype=float), - # Data in: - nefc_in: wp.array(dtype=int), - time_in: wp.array(dtype=float), - nworld_in: int, - naconmax_in: int, - njmax_in: int, - nacon_in: wp.array(dtype=int), - ncollision_in: wp.array(dtype=int), - # Data out: - time_out: wp.array(dtype=float), -): - worldid = wp.tid() - time_out[worldid] = time_in[worldid] + opt_timestep[worldid % opt_timestep.shape[0]] - nefc = nefc_in[worldid] - - if nefc > njmax_in: - wp.printf("nefc overflow - please increase njmax to %u\n", nefc) - - if worldid == 0: - ncollision = ncollision_in[0] - if ncollision > naconmax_in: - nconmax = int(wp.ceil(float(ncollision) / float(nworld_in))) - wp.printf("broadphase overflow - please increase nconmax to %u or naconmax to %u\n", nconmax, ncollision) +@cache_kernel +def _next_time(enable_printf: bool): + """Creates _next_time kernel with optional printf for warnings.""" - if nacon_in[0] > naconmax_in: - nconmax = int(wp.ceil(float(nacon_in[0]) / float(nworld_in))) - wp.printf("narrowphase overflow - please increase nconmax to %u or naconmax to %u\n", nconmax, nacon_in[0]) + @wp.kernel(module="unique") + def next_time( + # Model: + opt_timestep: wp.array(dtype=float), + # Data in: + nefc_in: wp.array(dtype=int), + time_in: wp.array(dtype=float), + nworld_in: int, + naconmax_in: int, + njmax_in: int, + nacon_in: wp.array(dtype=int), + ncollision_in: wp.array(dtype=int), + # Data out: + time_out: wp.array(dtype=float), + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), + ): + worldid = wp.tid() + time_out[worldid] = time_in[worldid] + opt_timestep[worldid % opt_timestep.shape[0]] + nefc = nefc_in[worldid] + + if nefc > njmax_in: + if wp.static(enable_printf): + wp.printf("nefc overflow - please increase njmax to %u\n", nefc) + wp.atomic_max(warning_out, int(WarningType.NEFC_OVERFLOW), 1) + wp.atomic_max(warning_info_out[int(WarningType.NEFC_OVERFLOW)], 0, nefc) + + if worldid == 0: + ncollision = ncollision_in[0] + if ncollision > naconmax_in: + nconmax = int(wp.ceil(float(ncollision) / float(nworld_in))) + if wp.static(enable_printf): + wp.printf("broadphase overflow - please increase nconmax to %u or naconmax to %u\n", nconmax, ncollision) + wp.atomic_max(warning_out, int(WarningType.BROADPHASE_OVERFLOW), 1) + wp.atomic_max(warning_info_out[int(WarningType.BROADPHASE_OVERFLOW)], 0, nconmax) + wp.atomic_max(warning_info_out[int(WarningType.BROADPHASE_OVERFLOW)], 1, ncollision) + + if nacon_in[0] > naconmax_in: + nconmax = int(wp.ceil(float(nacon_in[0]) / float(nworld_in))) + if wp.static(enable_printf): + wp.printf("narrowphase overflow - please increase nconmax to %u or naconmax to %u\n", nconmax, nacon_in[0]) + wp.atomic_max(warning_out, int(WarningType.NARROWPHASE_OVERFLOW), 1) + wp.atomic_max(warning_info_out[int(WarningType.NARROWPHASE_OVERFLOW)], 0, nconmax) + wp.atomic_max(warning_info_out[int(WarningType.NARROWPHASE_OVERFLOW)], 1, nacon_in[0]) + + return next_time def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None): @@ -263,10 +283,10 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) ) wp.launch( - _next_time, + _next_time(m.opt.warning_printf), dim=d.nworld, inputs=[m.opt.timestep, d.nefc, d.time, d.nworld, d.naconmax, d.njmax, d.nacon, d.ncollision], - outputs=[d.time], + outputs=[d.time, d.warning, d.warning_info], ) wp.copy(d.qacc_warmstart, d.qacc) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 624aa2664..7bc591a52 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -207,6 +207,9 @@ def _check_friction(name: str, id_: int, condim: int, friction, checks): else: opt.contact_sensor_maxmatch = 64 + # warning_printf: emit overflow warnings via printf (default True) + opt.warning_printf = True + # place opt on device for f in dataclasses.fields(types.Option): if isinstance(f.type, wp.array): @@ -691,7 +694,10 @@ def make_data( if njmax < 0: raise ValueError("njmax must be >= 0") - sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int}) + sizes = dict( + {"*": 1, "NUM_WARNINGS": types.NUM_WARNINGS}, + **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int}, + ) sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max() sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1)) tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse(mjm) else types.TILE_SIZE_JTDAJ_DENSE @@ -809,7 +815,10 @@ def put_data( if mjd.nefc > njmax: raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})") - sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int}) + sizes = dict( + {"*": 1, "NUM_WARNINGS": types.NUM_WARNINGS}, + **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int}, + ) sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max() sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1)) tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse(mjm) else types.TILE_SIZE_JTDAJ_DENSE @@ -890,9 +899,14 @@ def put_data( "flexedge_J": None, "nacon": None, } + # Skip fields that shouldn't be copied from MuJoCo data + skip_mjd_fields = {"warning", "warning_info"} for f in dataclasses.fields(types.Data): if f.name in d_kwargs: continue + if f.name in skip_mjd_fields: + d_kwargs[f.name] = _create_array(None, f.type, sizes) + continue val = getattr(mjd, f.name, None) if val is not None: shape = val.shape if hasattr(val, "shape") else () diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 3cdab0466..13c80e88e 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -36,6 +36,7 @@ from mujoco_warp._src.types import ObjType from mujoco_warp._src.types import SensorType from mujoco_warp._src.types import TrnType +from mujoco_warp._src.types import WarningType from mujoco_warp._src.types import vec5 from mujoco_warp._src.types import vec6 from mujoco_warp._src.types import vec8 @@ -2207,139 +2208,154 @@ def _check_match(body_parentid: wp.array(dtype=int), body: int, geom: int, objty return False -@wp.kernel -def _contact_match( - # Model: - opt_cone: int, - opt_contact_sensor_maxmatch: int, - body_parentid: wp.array(dtype=int), - geom_bodyid: wp.array(dtype=int), - site_type: wp.array(dtype=int), - site_size: wp.array(dtype=wp.vec3), - sensor_objtype: wp.array(dtype=int), - sensor_objid: wp.array(dtype=int), - sensor_reftype: wp.array(dtype=int), - sensor_refid: wp.array(dtype=int), - sensor_intprm: wp.array2d(dtype=int), - sensor_contact_adr: wp.array(dtype=int), - # Data in: - site_xpos_in: wp.array2d(dtype=wp.vec3), - site_xmat_in: wp.array2d(dtype=wp.mat33), - contact_dist_in: wp.array(dtype=float), - contact_pos_in: wp.array(dtype=wp.vec3), - contact_frame_in: wp.array(dtype=wp.mat33), - contact_friction_in: wp.array(dtype=vec5), - contact_dim_in: wp.array(dtype=int), - contact_geom_in: wp.array(dtype=wp.vec2i), - contact_efc_address_in: wp.array2d(dtype=int), - contact_worldid_in: wp.array(dtype=int), - contact_type_in: wp.array(dtype=int), - efc_force_in: wp.array2d(dtype=float), - njmax_in: int, - nacon_in: wp.array(dtype=int), - # Out: - sensor_contact_nmatch_out: wp.array2d(dtype=int), - sensor_contact_matchid_out: wp.array3d(dtype=int), - sensor_contact_criteria_out: wp.array3d(dtype=float), - sensor_contact_direction_out: wp.array3d(dtype=float), -): - contactsensorid, contactid = wp.tid() - sensorid = sensor_contact_adr[contactsensorid] +@cache_kernel +def _contact_match(enable_printf: bool): + """Creates _contact_match kernel with optional printf for warnings.""" - if contactid >= nacon_in[0]: - return + @wp.kernel(module="unique") + def contact_match( + # Model: + opt_cone: int, + opt_contact_sensor_maxmatch: int, + body_parentid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + site_type: wp.array(dtype=int), + site_size: wp.array(dtype=wp.vec3), + sensor_objtype: wp.array(dtype=int), + sensor_objid: wp.array(dtype=int), + sensor_reftype: wp.array(dtype=int), + sensor_refid: wp.array(dtype=int), + sensor_intprm: wp.array2d(dtype=int), + sensor_contact_adr: wp.array(dtype=int), + # Data in: + site_xpos_in: wp.array2d(dtype=wp.vec3), + site_xmat_in: wp.array2d(dtype=wp.mat33), + contact_dist_in: wp.array(dtype=float), + contact_pos_in: wp.array(dtype=wp.vec3), + contact_frame_in: wp.array(dtype=wp.mat33), + contact_friction_in: wp.array(dtype=vec5), + contact_dim_in: wp.array(dtype=int), + contact_geom_in: wp.array(dtype=wp.vec2i), + contact_efc_address_in: wp.array2d(dtype=int), + contact_worldid_in: wp.array(dtype=int), + contact_type_in: wp.array(dtype=int), + efc_force_in: wp.array2d(dtype=float), + njmax_in: int, + nacon_in: wp.array(dtype=int), + # Data out: + warning_out: wp.array(dtype=int), + warning_info_out: wp.array2d(dtype=int), + # Out: + sensor_contact_nmatch_out: wp.array2d(dtype=int), + sensor_contact_matchid_out: wp.array3d(dtype=int), + sensor_contact_criteria_out: wp.array3d(dtype=float), + sensor_contact_direction_out: wp.array3d(dtype=float), + ): + contactsensorid, contactid = wp.tid() + sensorid = sensor_contact_adr[contactsensorid] - if not contact_type_in[contactid] & ContactType.CONSTRAINT: - return + if contactid >= nacon_in[0]: + return - # sensor information - objtype = sensor_objtype[sensorid] - objid = sensor_objid[sensorid] - reftype = sensor_reftype[sensorid] - refid = sensor_refid[sensorid] - reduce = sensor_intprm[sensorid, 1] + if not contact_type_in[contactid] & ContactType.CONSTRAINT: + return - worldid = contact_worldid_in[contactid] + # sensor information + objtype = sensor_objtype[sensorid] + objid = sensor_objid[sensorid] + reftype = sensor_reftype[sensorid] + refid = sensor_refid[sensorid] + reduce = sensor_intprm[sensorid, 1] - # site filter - if objtype == ObjType.SITE: - if not inside_geom( - site_xpos_in[worldid, objid], site_xmat_in[worldid, objid], site_size[objid], site_type[objid], contact_pos_in[contactid] - ): - return + worldid = contact_worldid_in[contactid] + + # site filter + if objtype == ObjType.SITE: + if not inside_geom( + site_xpos_in[worldid, objid], + site_xmat_in[worldid, objid], + site_size[objid], + site_type[objid], + contact_pos_in[contactid], + ): + return - # unknown-unknown match - if objtype == ObjType.UNKNOWN and reftype == ObjType.UNKNOWN: - dir = 1.0 - else: - # contact information - geom = contact_geom_in[contactid] - geom1 = geom[0] - geom2 = geom[1] - body1 = geom_bodyid[geom1] - body2 = geom_bodyid[geom2] - - # check match of sensor objects with contact objects - match11 = _check_match(body_parentid, body1, geom1, objtype, objid) - match12 = _check_match(body_parentid, body2, geom2, objtype, objid) - match21 = _check_match(body_parentid, body1, geom1, reftype, refid) - match22 = _check_match(body_parentid, body2, geom2, reftype, refid) - - # if a sensor object is specified, it must be involved in the contact - if not match11 and not match12: - return - if not match21 and not match22: + # unknown-unknown match + if objtype == ObjType.UNKNOWN and reftype == ObjType.UNKNOWN: + dir = 1.0 + else: + # contact information + geom = contact_geom_in[contactid] + geom1 = geom[0] + geom2 = geom[1] + body1 = geom_bodyid[geom1] + body2 = geom_bodyid[geom2] + + # check match of sensor objects with contact objects + match11 = _check_match(body_parentid, body1, geom1, objtype, objid) + match12 = _check_match(body_parentid, body2, geom2, objtype, objid) + match21 = _check_match(body_parentid, body1, geom1, reftype, refid) + match22 = _check_match(body_parentid, body2, geom2, reftype, refid) + + # if a sensor object is specified, it must be involved in the contact + if not match11 and not match12: + return + if not match21 and not match22: + return + + # determine direction + dir = 1.0 + if objtype != ObjType.UNKNOWN and reftype != ObjType.UNKNOWN: + # both obj1 and obj2 specified: direction depends on order + order_regular = match11 and match22 + order_reverse = match12 and match21 + if not order_regular and not order_reverse: + return + if order_reverse and not order_regular: + dir = -1.0 + elif objtype != ObjType.UNKNOWN: + if not match11: + dir = -1.0 + elif reftype != ObjType.UNKNOWN: + if not match22: + dir = -1.0 + + contactmatchid = wp.atomic_add(sensor_contact_nmatch_out[worldid], contactsensorid, 1) + + if contactmatchid >= opt_contact_sensor_maxmatch: + if wp.static(enable_printf): + wp.printf("contact match overflow: please increase Option.contact_sensor_maxmatch to %u\n", contactmatchid) + wp.atomic_max(warning_out, int(WarningType.CONTACT_MATCH_OVERFLOW), 1) + wp.atomic_max(warning_info_out[int(WarningType.CONTACT_MATCH_OVERFLOW)], 0, contactmatchid) return - # determine direction - dir = 1.0 - if objtype != ObjType.UNKNOWN and reftype != ObjType.UNKNOWN: - # both obj1 and obj2 specified: direction depends on order - order_regular = match11 and match22 - order_reverse = match12 and match21 - if not order_regular and not order_reverse: - return - if order_reverse and not order_regular: - dir = -1.0 - elif objtype != ObjType.UNKNOWN: - if not match11: - dir = -1.0 - elif reftype != ObjType.UNKNOWN: - if not match22: - dir = -1.0 - - contactmatchid = wp.atomic_add(sensor_contact_nmatch_out[worldid], contactsensorid, 1) - - if contactmatchid >= opt_contact_sensor_maxmatch: - # TODO(team): alternative to wp.printf for reporting overflow? - wp.printf("contact match overflow: please increase Option.contact_sensor_maxmatch to %u\n", contactmatchid) - return + sensor_contact_matchid_out[worldid, contactsensorid, contactmatchid] = contactid + + if reduce == 1: # mindist + sensor_contact_criteria_out[worldid, contactsensorid, contactmatchid] = contact_dist_in[contactid] + elif reduce == 2: # maxforce + contact_force = support.contact_force_fn( + opt_cone, + contact_frame_in, + contact_friction_in, + contact_dim_in, + contact_efc_address_in, + efc_force_in, + njmax_in, + nacon_in, + worldid, + contactid, + False, + ) + force_magnitude = ( + contact_force[0] * contact_force[0] + contact_force[1] * contact_force[1] + contact_force[2] * contact_force[2] + ) + sensor_contact_criteria_out[worldid, contactsensorid, contactmatchid] = -force_magnitude - sensor_contact_matchid_out[worldid, contactsensorid, contactmatchid] = contactid - - if reduce == 1: # mindist - sensor_contact_criteria_out[worldid, contactsensorid, contactmatchid] = contact_dist_in[contactid] - elif reduce == 2: # maxforce - contact_force = support.contact_force_fn( - opt_cone, - contact_frame_in, - contact_friction_in, - contact_dim_in, - contact_efc_address_in, - efc_force_in, - njmax_in, - nacon_in, - worldid, - contactid, - False, - ) - force_magnitude = ( - contact_force[0] * contact_force[0] + contact_force[1] * contact_force[1] + contact_force[2] * contact_force[2] - ) - sensor_contact_criteria_out[worldid, contactsensorid, contactmatchid] = -force_magnitude + # contact direction + sensor_contact_direction_out[worldid, contactsensorid, contactmatchid] = dir - # contact direction - sensor_contact_direction_out[worldid, contactsensorid, contactmatchid] = dir + return contact_match @cache_kernel @@ -2463,7 +2479,7 @@ def sensor_acc(m: Model, d: Data): sensor_contact_criteria.fill_(1.0e32) wp.launch( - _contact_match, + _contact_match(m.opt.warning_printf), dim=(m.sensor_contact_adr.size, d.naconmax), inputs=[ m.opt.cone, @@ -2493,7 +2509,14 @@ def sensor_acc(m: Model, d: Data): d.njmax, d.nacon, ], - outputs=[sensor_contact_nmatch, sensor_contact_matchid, sensor_contact_criteria, sensor_contact_direction], + outputs=[ + d.warning, + d.warning_info, + sensor_contact_nmatch, + sensor_contact_matchid, + sensor_contact_criteria, + sensor_contact_direction, + ], ) # sorting diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 92c6b726d..8867acadf 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -1985,11 +1985,9 @@ def _transmission( actuator_length_out[worldid, actid] = wp.dot(axis_angle, gearaxis) for i in range(3): actuator_moment_out[worldid, actid, vadr + i] = gearaxis[i] - elif jnt_typ == JointType.SLIDE or jnt_typ == JointType.HINGE: + else: # SLIDE or HINGE actuator_length_out[worldid, actid] = qpos[qadr] * gear[0] actuator_moment_out[worldid, actid, vadr] = gear[0] - else: - wp.printf("unrecognized joint type") elif trntype == TrnType.SLIDERCRANK: # get data trnid = actuator_trnid[actid] @@ -2194,8 +2192,6 @@ def _transmission( moment += wp.dot(jacrdif, wrench_rotation) actuator_moment_out[worldid, actid, i] = moment - else: - wp.printf("unhandled transmission type %d\n", trntype) @wp.kernel diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 66632d843..52c47f307 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -66,6 +66,31 @@ class BlockDim: qderiv_actuator_dense: int = 32 +class WarningType(enum.IntEnum): + """Warning types for kernel-side overflow detection. + + Attributes: + NEFC_OVERFLOW: constraint count exceeded njmax + BROADPHASE_OVERFLOW: broadphase collision count exceeded naconmax + NARROWPHASE_OVERFLOW: narrowphase contact count exceeded naconmax + CONTACT_MATCH_OVERFLOW: contact sensor match count exceeded maxmatch + GJK_ITERATIONS: GJK algorithm did not converge within iteration limit + EPA_HORIZON: EPA horizon buffer overflow + HFIELD_OVERFLOW: heightfield collision count exceeded MJ_MAXCONPAIR + """ + + NEFC_OVERFLOW = 0 + BROADPHASE_OVERFLOW = 1 + NARROWPHASE_OVERFLOW = 2 + CONTACT_MATCH_OVERFLOW = 3 + GJK_ITERATIONS = 4 + EPA_HORIZON = 5 + HFIELD_OVERFLOW = 6 + + +NUM_WARNINGS = len(WarningType) + + class BroadphaseType(enum.IntEnum): """Type of broadphase algorithm. @@ -690,6 +715,9 @@ class Option: zeros out the contacts at each step) contact_sensor_maxmatch: max number of contacts considered by contact sensor matching criteria contacts matched after this value is exceded will be ignored + warning_printf: if True, emit overflow warnings via printf to stdout in addition to setting + warning flags. Default True. Set to False and use check_warnings(d) for + programmatic access to warnings without stdout output. """ timestep: array("*", float) @@ -720,6 +748,7 @@ class Option: graph_conditional: bool run_collision_detection: bool contact_sensor_maxmatch: int + warning_printf: bool @dataclasses.dataclass @@ -1647,6 +1676,8 @@ class Data: collision_pairid: ids from broadphase (naconmax, 2) collision_worldid: collision world ids from broadphase (naconmax,) ncollision: collision count from broadphase (1,) + warning: warning flags (accumulated across steps) (NUM_WARNINGS,) + warning_info: warning info (suggested values) (NUM_WARNINGS, 2) """ solver_niter: array("nworld", int) @@ -1738,3 +1769,7 @@ class Data: collision_pairid: array("naconmax", wp.vec2i) collision_worldid: array("naconmax", int) ncollision: array(1, int) + + # warp only: warning flags (accumulated across steps, checked on host) + warning: array("NUM_WARNINGS", int) # flag per warning type + warning_info: array("NUM_WARNINGS", 2, int) # suggested values diff --git a/mujoco_warp/_src/types_test.py b/mujoco_warp/_src/types_test.py index 0d1306a3a..1f6d4a6b4 100644 --- a/mujoco_warp/_src/types_test.py +++ b/mujoco_warp/_src/types_test.py @@ -45,6 +45,8 @@ def test_field_order(self, mj_class, mjw_class): # TODO(team): remove this reordering after MjData._all_fields order is fixed # there's a bug in _all_fields where solver_niter is in the wrong place mj_fields.insert(0, mj_fields.pop(mj_fields.index("solver_niter"))) + # MuJoCo's "warning" field is different from our warp-only warning system + mj_fields.remove("warning") mj_set, mjw_set = set(mj_fields), set(mjw_fields) diff --git a/mujoco_warp/_src/warning.py b/mujoco_warp/_src/warning.py new file mode 100644 index 000000000..686f71e46 --- /dev/null +++ b/mujoco_warp/_src/warning.py @@ -0,0 +1,94 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Warning utilities for kernel-side overflow detection.""" + +import sys +from typing import List + +from . import types + +_WARNING_MESSAGES = { + types.WarningType.NEFC_OVERFLOW: "nefc overflow - increase njmax to {0}", + types.WarningType.BROADPHASE_OVERFLOW: ("broadphase overflow - increase nconmax to {0} or naconmax to {1}"), + types.WarningType.NARROWPHASE_OVERFLOW: ("narrowphase overflow - increase nconmax to {0} or naconmax to {1}"), + types.WarningType.CONTACT_MATCH_OVERFLOW: ("contact match overflow - increase Option.contact_sensor_maxmatch to {0}"), + types.WarningType.GJK_ITERATIONS: ("GJK did not converge - increase opt.ccd_iterations (currently {0})"), + types.WarningType.EPA_HORIZON: "EPA horizon overflow - horizon size {0} insufficient", + types.WarningType.HFIELD_OVERFLOW: ( + "heightfield collision overflow - decrease hfield rows/cols or modify colliding geom size (limit {0})" + ), +} + + +def check_warnings(d: types.Data, clear: bool = True) -> List[str]: + """Check warning flags and emit to stderr. + + This function reads the warning flags set by kernels and emits appropriate + warning messages to stderr. Warning flags accumulate across simulation steps + (using atomic_max), so this should be called after graph execution completes. + + Args: + d: The Data object containing warning flags. + clear: Whether to clear warning flags after checking. Default True. + + Returns: + List of warning message strings that were emitted. + """ + flags = d.warning.numpy() + info = d.warning_info.numpy() + + emitted = [] + for wtype in types.WarningType: + if flags[wtype]: + msg = _WARNING_MESSAGES[wtype].format(info[wtype, 0], info[wtype, 1]) + print(f"Warning: {msg}", file=sys.stderr) + emitted.append(msg) + + if clear: + d.warning.zero_() + d.warning_info.zero_() + + return emitted + + +def get_warnings(d: types.Data) -> List[str]: + """Get warning messages without emitting or clearing. + + Args: + d: The Data object containing warning flags. + + Returns: + List of warning message strings. + """ + flags = d.warning.numpy() + info = d.warning_info.numpy() + + messages = [] + for wtype in types.WarningType: + if flags[wtype]: + msg = _WARNING_MESSAGES[wtype].format(info[wtype, 0], info[wtype, 1]) + messages.append(msg) + + return messages + + +def clear_warnings(d: types.Data) -> None: + """Clear all warning flags. + + Args: + d: The Data object containing warning flags. + """ + d.warning.zero_() + d.warning_info.zero_() diff --git a/mujoco_warp/_src/warning_test.py b/mujoco_warp/_src/warning_test.py new file mode 100644 index 000000000..e9422442e --- /dev/null +++ b/mujoco_warp/_src/warning_test.py @@ -0,0 +1,356 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for warning system.""" + +import warnings + +import mujoco +import numpy as np +import warp as wp +from absl.testing import absltest + +import mujoco_warp as mjw +from mujoco_warp._src import types + + +class WarningTest(absltest.TestCase): + def test_warning_arrays_initialized(self): + """Tests that warning arrays are properly initialized.""" + mjm = mujoco.MjModel.from_xml_string(""" + + + + + + + + + """) + mjd = mujoco.MjData(mjm) + + d = mjw.put_data(mjm, mjd) + + # Check shapes + self.assertEqual(d.warning.shape, (types.NUM_WARNINGS,)) + self.assertEqual(d.warning_info.shape, (types.NUM_WARNINGS, 2)) + + # Check initial values are zero + np.testing.assert_array_equal(d.warning.numpy(), np.zeros(types.NUM_WARNINGS, dtype=np.int32)) + np.testing.assert_array_equal(d.warning_info.numpy(), np.zeros((types.NUM_WARNINGS, 2), dtype=np.int32)) + + def test_check_warnings_no_warnings(self): + """Tests check_warnings returns empty list when no warnings.""" + mjm = mujoco.MjModel.from_xml_string(""" + + + + + + + + + """) + mjd = mujoco.MjData(mjm) + + m = mjw.put_model(mjm) + d = mjw.put_data(mjm, mjd) + + # Run a step - should not trigger any warnings + mjw.step(m, d) + wp.synchronize() + + # Check warnings + result = mjw.get_warnings(d) + self.assertEqual(result, []) + + def test_nefc_overflow_warning(self): + """Tests that nefc overflow sets warning flag correctly.""" + # Sphere close to ground with large timestep - contacts quickly + mjm = mujoco.MjModel.from_xml_string(""" + + + """) + mjd = mujoco.MjData(mjm) + mujoco.mj_forward(mjm, mjd) + + # No initial contacts + self.assertEqual(mjd.ncon, 0, "Test setup: should have no initial contacts") + + m = mjw.put_model(mjm) + m.opt.warning_printf = False # disable printf in tests + # Set njmax very low to trigger overflow when sphere hits ground + d = mjw.put_data(mjm, mjd, njmax=1) + + # Run steps until sphere falls and creates contact (~10 steps at 0.01 timestep) + for _ in range(20): + mjw.step(m, d) + wp.synchronize() + + # Check warning flag is set + warning_flags = d.warning.numpy() + self.assertEqual(warning_flags[types.WarningType.NEFC_OVERFLOW], 1) + + # Check warning info contains suggested value (should be 4 for a single contact with friction) + warning_info = d.warning_info.numpy() + self.assertGreater(warning_info[types.WarningType.NEFC_OVERFLOW, 0], 1) + + # Check get_warnings returns the message + result = mjw.get_warnings(d) + self.assertEqual(len(result), 1) + self.assertIn("nefc overflow", result[0]) + + def test_check_warnings_clears_flags(self): + """Tests that check_warnings clears flags by default.""" + # Sphere close to ground with large timestep + mjm = mujoco.MjModel.from_xml_string(""" + + + """) + mjd = mujoco.MjData(mjm) + mujoco.mj_forward(mjm, mjd) + + m = mjw.put_model(mjm) + m.opt.warning_printf = False # disable printf in tests + d = mjw.put_data(mjm, mjd, njmax=1) + + # Run until contact + for _ in range(20): + mjw.step(m, d) + wp.synchronize() + + # First check_warnings should return warnings and clear + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + result1 = mjw.check_warnings(d, clear=True) + + # Second call should return empty (flags cleared) + result2 = mjw.get_warnings(d) + + self.assertGreater(len(result1), 0) + self.assertEqual(len(result2), 0) + + def test_clear_warnings(self): + """Tests clear_warnings utility.""" + # Sphere close to ground with large timestep + mjm = mujoco.MjModel.from_xml_string(""" + + + """) + mjd = mujoco.MjData(mjm) + mujoco.mj_forward(mjm, mjd) + + m = mjw.put_model(mjm) + m.opt.warning_printf = False # disable printf in tests + d = mjw.put_data(mjm, mjd, njmax=1) + + # Run until contact + for _ in range(20): + mjw.step(m, d) + wp.synchronize() + + # Verify warning is set + self.assertGreater(len(mjw.get_warnings(d)), 0) + + # Clear warnings + mjw.clear_warnings(d) + + # Verify cleared + self.assertEqual(len(mjw.get_warnings(d)), 0) + np.testing.assert_array_equal(d.warning.numpy(), np.zeros(types.NUM_WARNINGS, dtype=np.int32)) + + @absltest.skipIf(not wp.get_device().is_cuda, "Skipping test that requires GPU.") + def test_multi_step_graph_captures_mid_graph_warning(self): + """Tests that a multi-step graph captures warnings even if they occur mid-graph. + + Captures 20 steps in one graph. The sphere starts close to ground and hits it + around step 5-10. The warning from that step should be captured and readable + after the graph completes. + """ + # Sphere close to ground - will contact around step 5-10 + mjm = mujoco.MjModel.from_xml_string(""" + + + """) + mjd = mujoco.MjData(mjm) + mujoco.mj_forward(mjm, mjd) + + # No initial contact + self.assertEqual(mjd.ncon, 0, "Test setup: should have no initial contacts") + + m = mjw.put_model(mjm) + m.opt.warning_printf = False # disable printf in tests + d = mjw.put_data(mjm, mjd, njmax=1) + + # Clear warnings + mjw.clear_warnings(d) + + # Capture 20 steps as a single graph - warning should occur around step 5-10 + nsteps = 20 + with wp.ScopedCapture() as capture: + for _ in range(nsteps): + mjw.step(m, d) + graph = capture.graph + + # Run graph once - warning happens mid-graph + wp.capture_launch(graph) + wp.synchronize() + + # Check that warning was captured from the step where contact happened + warning_flags = d.warning.numpy() + self.assertEqual( + warning_flags[types.WarningType.NEFC_OVERFLOW], 1, f"Expected nefc overflow warning, got flags: {warning_flags}" + ) + + # Verify warning info shows the correct suggested value + warning_info = d.warning_info.numpy() + self.assertEqual( + warning_info[types.WarningType.NEFC_OVERFLOW, 0], + 4, + f"Expected nefc=4 in warning info, got: {warning_info[types.WarningType.NEFC_OVERFLOW]}", + ) + + # Get warnings - should include the overflow message + result = mjw.get_warnings(d) + self.assertEqual(len(result), 1) + self.assertIn("nefc overflow", result[0]) + self.assertIn("4", result[0]) + + @absltest.skipIf(not wp.get_device().is_cuda, "Skipping test that requires GPU.") + def test_single_step_graph_warns_only_when_event_occurs(self): + """Tests that a single-step graph only reports warning on the step it happens. + + Captures 1 step as a graph. Runs it multiple times. Verifies that: + 1. Before contact: no warnings generated + 2. After contact: warnings generated + 3. After clearing and running more: warnings still generated (overflow persists) + """ + # Sphere close to ground - will contact around launch 5-10 + mjm = mujoco.MjModel.from_xml_string(""" + + + """) + mjd = mujoco.MjData(mjm) + mujoco.mj_forward(mjm, mjd) + + m = mjw.put_model(mjm) + m.opt.warning_printf = False # disable printf in tests + d = mjw.put_data(mjm, mjd, njmax=1) + + # Clear warnings + mjw.clear_warnings(d) + + # Capture single step + with wp.ScopedCapture() as capture: + mjw.step(m, d) + graph = capture.graph + + # Run a few launches before contact happens (sphere is falling) + for _ in range(3): + wp.capture_launch(graph) + wp.synchronize() + + # Check: no warnings yet (sphere hasn't hit ground) + warning_flags_before = d.warning.numpy().copy() + self.assertEqual(warning_flags_before[types.WarningType.NEFC_OVERFLOW], 0, "Should have no warning before contact") + + # Clear and run more launches until contact happens + mjw.clear_warnings(d) + for _ in range(15): + wp.capture_launch(graph) + wp.synchronize() + + # Check: warning should now be set (contact occurred) + warning_flags_after = d.warning.numpy() + self.assertEqual(warning_flags_after[types.WarningType.NEFC_OVERFLOW], 1, "Should have warning after contact") + + # Clear and run more - warning should still be generated (overflow persists each step) + mjw.clear_warnings(d) + for _ in range(5): + wp.capture_launch(graph) + wp.synchronize() + + warning_flags_final = d.warning.numpy() + self.assertEqual( + warning_flags_final[types.WarningType.NEFC_OVERFLOW], 1, "Warning should still be generated (overflow persists)" + ) + + # Check warning message + result = mjw.get_warnings(d) + self.assertEqual(len(result), 1) + self.assertIn("nefc overflow", result[0]) + + def test_warning_printf_option(self): + """Tests that warning_printf option is available and defaults to True.""" + mjm = mujoco.MjModel.from_xml_string(""" + + + + + + + + + """) + + m = mjw.put_model(mjm) + + # Default should be True (printf enabled) + self.assertTrue(m.opt.warning_printf) + + +if __name__ == "__main__": + wp.init() + absltest.main() diff --git a/mujoco_warp/testspeed.py b/mujoco_warp/testspeed.py index a3d03c536..3d799b7b4 100644 --- a/mujoco_warp/testspeed.py +++ b/mujoco_warp/testspeed.py @@ -286,6 +286,7 @@ def _main(argv: Sequence[str]): with wp.ScopedDevice(_DEVICE.value): override_model(mjm, _OVERRIDE.value) m = mjw.put_model(mjm) + m.opt.warning_printf = False # use check_warnings instead override_model(m, _OVERRIDE.value) d = mjw.put_data(mjm, mjd, nworld=_NWORLD.value, nconmax=_NCONMAX.value, njmax=_NJMAX.value) if _FORMAT.value == "human": diff --git a/mujoco_warp/viewer.py b/mujoco_warp/viewer.py index 0a66ed3c6..f23d4f66f 100644 --- a/mujoco_warp/viewer.py +++ b/mujoco_warp/viewer.py @@ -43,6 +43,7 @@ from mujoco_warp._src.io import find_keys from mujoco_warp._src.io import make_trajectory from mujoco_warp._src.io import override_model +from mujoco_warp._src.warning import check_warnings class EngineOptions(enum.IntEnum): @@ -147,6 +148,7 @@ def _main(argv: Sequence[str]) -> None: with wp.ScopedDevice(_DEVICE.value): override_model(mjm, _OVERRIDE.value) m = mjw.put_model(mjm) + m.opt.warning_printf = False # use check_warnings instead override_model(m, _OVERRIDE.value) broadphase, filter = mjw.BroadphaseType(m.opt.broadphase).name, mjw.BroadphaseFilter(m.opt.broadphase_filter).name solver, cone = mjw.SolverType(m.opt.solver).name, mjw.ConeType(m.opt.cone).name @@ -189,15 +191,18 @@ def _main(argv: Sequence[str]) -> None: if mjm.opt != opt: opt = copy.copy(mjm.opt) m = mjw.put_model(mjm) + m.opt.warning_printf = False # use check_warnings instead graph = _compile_step(m, d) if _VIEWER_GLOBAL_STATE["running"]: wp.capture_launch(graph) wp.synchronize() + check_warnings(d) elif _VIEWER_GLOBAL_STATE["step_once"]: _VIEWER_GLOBAL_STATE["step_once"] = False wp.capture_launch(graph) wp.synchronize() + check_warnings(d) mjw.get_data_into(mjd, mjm, d)