diff --git a/src/mjlab/sim/sim.py b/src/mjlab/sim/sim.py index d1a8f9a59..5a300d8c1 100644 --- a/src/mjlab/sim/sim.py +++ b/src/mjlab/sim/sim.py @@ -18,6 +18,9 @@ ModelBridge = WarpBridge DataBridge = WarpBridge +# mujoco_warp uses 'wp.capture_while' which strictly requires Driver 12.4+ +_MIN_DRIVER_FOR_CONDITIONAL_GRAPHS = 12.4 + _JACOBIAN_MAP = { "auto": mujoco.mjtJacobian.mjJAC_AUTO, "dense": mujoco.mjtJacobian.mjJAC_DENSE, @@ -128,9 +131,18 @@ def __init__( self._model_bridge = WarpBridge(self._wp_model, nworld=self.num_envs) self._data_bridge = WarpBridge(self._wp_data) - self.use_cuda_graph = self.wp_device.is_cuda and wp.is_mempool_enabled( - self.wp_device + driver_ver = wp.context.runtime.driver_version + driver_ver = float(f"{driver_ver[0]}.{driver_ver[1]}") + self.use_cuda_graph = ( + self.wp_device.is_cuda + and wp.is_mempool_enabled(self.wp_device) + and driver_ver >= _MIN_DRIVER_FOR_CONDITIONAL_GRAPHS ) + + if not self.use_cuda_graph: + print(f"[WARNING] Disabling CUDA Graphs. Current Driver {driver_ver} < 12.4.") + print(" mujoco_warp solver requires 12.4+ for graph loops.") + self.create_graph() self.nan_guard = NanGuard(cfg.nan_guard, self.num_envs, self._mj_model)