From 38b6f20f6cdad3b535b3e9d0955df303465bb260 Mon Sep 17 00:00:00 2001 From: sergiacosta Date: Fri, 12 Dec 2025 13:23:44 +0100 Subject: [PATCH 1/6] added logic to create graph depending on driver cuda version --- src/mjlab/sim/sim.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/mjlab/sim/sim.py b/src/mjlab/sim/sim.py index d1a8f9a59..1241e666b 100644 --- a/src/mjlab/sim/sim.py +++ b/src/mjlab/sim/sim.py @@ -127,10 +127,16 @@ def __init__( self._model_bridge = WarpBridge(self._wp_model, nworld=self.num_envs) self._data_bridge = WarpBridge(self._wp_data) + + import ctypes + libcuda = ctypes.CDLL("libcuda.so") + version = ctypes.c_int() + libcuda.cuDriverGetVersion(ctypes.byref(version)) + driver_cuda_version = float(f"{version.value // 1000}.{(version.value % 1000) // 10}") self.use_cuda_graph = self.wp_device.is_cuda and wp.is_mempool_enabled( self.wp_device - ) + ) if driver_cuda_version >= 12.4 else False self.create_graph() self.nan_guard = NanGuard(cfg.nan_guard, self.num_envs, self._mj_model) From 81a746787adaf46139eb0dacb837a3529a6858df Mon Sep 17 00:00:00 2001 From: SergiMuac Date: Fri, 12 Dec 2025 13:51:29 +0100 Subject: [PATCH 2/6] code formated --- src/mjlab/sim/sim.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/mjlab/sim/sim.py b/src/mjlab/sim/sim.py index 1241e666b..968a0f0eb 100644 --- a/src/mjlab/sim/sim.py +++ b/src/mjlab/sim/sim.py @@ -127,16 +127,21 @@ def __init__( self._model_bridge = WarpBridge(self._wp_model, nworld=self.num_envs) self._data_bridge = WarpBridge(self._wp_data) - + import ctypes + libcuda = ctypes.CDLL("libcuda.so") version = ctypes.c_int() libcuda.cuDriverGetVersion(ctypes.byref(version)) - driver_cuda_version = float(f"{version.value // 1000}.{(version.value % 1000) // 10}") - - self.use_cuda_graph = self.wp_device.is_cuda and wp.is_mempool_enabled( - self.wp_device - ) if driver_cuda_version >= 12.4 else False + driver_cuda_version = float( + f"{version.value // 1000}.{(version.value % 1000) // 10}" + ) + + self.use_cuda_graph = ( + self.wp_device.is_cuda and wp.is_mempool_enabled(self.wp_device) + if driver_cuda_version >= 12.4 + else False + ) self.create_graph() self.nan_guard = NanGuard(cfg.nan_guard, self.num_envs, self._mj_model) From 201e009e7a6e6515fdb18bc9762b0ca52f849570 Mon Sep 17 00:00:00 2001 From: SergiMuac Date: Fri, 12 Dec 2025 14:58:19 +0100 Subject: [PATCH 3/6] added try/except logic to handle if cuda driver is not installed --- src/mjlab/sim/sim.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/mjlab/sim/sim.py b/src/mjlab/sim/sim.py index 968a0f0eb..50c832a9f 100644 --- a/src/mjlab/sim/sim.py +++ b/src/mjlab/sim/sim.py @@ -130,18 +130,28 @@ def __init__( import ctypes - libcuda = ctypes.CDLL("libcuda.so") - version = ctypes.c_int() - libcuda.cuDriverGetVersion(ctypes.byref(version)) - driver_cuda_version = float( - f"{version.value // 1000}.{(version.value % 1000) // 10}" - ) - - self.use_cuda_graph = ( - self.wp_device.is_cuda and wp.is_mempool_enabled(self.wp_device) - if driver_cuda_version >= 12.4 - else False - ) + try: + libcuda = ctypes.CDLL("libcuda.so") + except OSError: + print("[ERROR] Unable to find libcuda.so.") + libcuda = None + + if libcuda is not None: + version = ctypes.c_int() + libcuda.cuDriverGetVersion(ctypes.byref(version)) + driver_cuda_version = float( + f"{version.value // 1000}.{(version.value % 1000) // 10}" + ) + + self.use_cuda_graph = ( + self.wp_device.is_cuda and wp.is_mempool_enabled(self.wp_device) + if driver_cuda_version >= 12.4 + else False + ) + else: + print("[WARNING] CUDA driver not available, disabling CUDA graphs.") + self.use_cuda_graph = False + self.create_graph() self.nan_guard = NanGuard(cfg.nan_guard, self.num_envs, self._mj_model) From b93de62e9daa737991e7ec7be396b1f4409aa523 Mon Sep 17 00:00:00 2001 From: SergiMuac Date: Fri, 19 Dec 2025 14:05:18 +0100 Subject: [PATCH 4/6] improved cuda version checking --- src/mjlab/sim/sim.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/src/mjlab/sim/sim.py b/src/mjlab/sim/sim.py index 50c832a9f..3f25d1f0e 100644 --- a/src/mjlab/sim/sim.py +++ b/src/mjlab/sim/sim.py @@ -18,6 +18,9 @@ ModelBridge = WarpBridge DataBridge = WarpBridge +# mujoco_warp uses 'wp.capture_while' which strictly requires Driver 12.4+ +_MIN_DRIVER_FOR_CONDITIONAL_GRAPHS = 12.4 + _JACOBIAN_MAP = { "auto": mujoco.mjtJacobian.mjJAC_AUTO, "dense": mujoco.mjtJacobian.mjJAC_DENSE, @@ -128,26 +131,14 @@ def __init__( self._model_bridge = WarpBridge(self._wp_model, nworld=self.num_envs) self._data_bridge = WarpBridge(self._wp_data) - import ctypes - - try: - libcuda = ctypes.CDLL("libcuda.so") - except OSError: - print("[ERROR] Unable to find libcuda.so.") - libcuda = None + driver_ver = wp.context.runtime.driver_version + driver_ver = float(f"{driver_ver[0]}.{driver_ver[1]}") + self.use_cuda_graph = ( + self.wp_device.is_cuda + and wp.is_mempool_enabled(self.wp_device) + and driver_ver >= _MIN_DRIVER_FOR_CONDITIONAL_GRAPHS + ) - if libcuda is not None: - version = ctypes.c_int() - libcuda.cuDriverGetVersion(ctypes.byref(version)) - driver_cuda_version = float( - f"{version.value // 1000}.{(version.value % 1000) // 10}" - ) - - self.use_cuda_graph = ( - self.wp_device.is_cuda and wp.is_mempool_enabled(self.wp_device) - if driver_cuda_version >= 12.4 - else False - ) else: print("[WARNING] CUDA driver not available, disabling CUDA graphs.") self.use_cuda_graph = False From 89aa3eb6983b39882d0c73824b01a190e33d69c0 Mon Sep 17 00:00:00 2001 From: SergiMuac Date: Fri, 19 Dec 2025 14:05:37 +0100 Subject: [PATCH 5/6] added kernel modules warmup --- src/mjlab/sim/sim.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/mjlab/sim/sim.py b/src/mjlab/sim/sim.py index 3f25d1f0e..0d957269b 100644 --- a/src/mjlab/sim/sim.py +++ b/src/mjlab/sim/sim.py @@ -139,9 +139,13 @@ def __init__( and driver_ver >= _MIN_DRIVER_FOR_CONDITIONAL_GRAPHS ) + if self.use_cuda_graph: + print("Warming up CUDA kernels...") + mjwarp.step(self.wp_model, self.wp_data) + wp.synchronize() else: - print("[WARNING] CUDA driver not available, disabling CUDA graphs.") - self.use_cuda_graph = False + print(f"[WARNING] Disabling CUDA Graphs. Current Driver {driver_ver} < 12.4.") + print(" mujoco_warp solver requires 12.4+ for graph loops.") self.create_graph() From aafd1e68c721e07931dc189bb958f4cc658916bd Mon Sep 17 00:00:00 2001 From: SergiMuac Date: Mon, 22 Dec 2025 11:58:49 +0100 Subject: [PATCH 6/6] removed warm up step --- src/mjlab/sim/sim.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/mjlab/sim/sim.py b/src/mjlab/sim/sim.py index 0d957269b..5a300d8c1 100644 --- a/src/mjlab/sim/sim.py +++ b/src/mjlab/sim/sim.py @@ -139,11 +139,7 @@ def __init__( and driver_ver >= _MIN_DRIVER_FOR_CONDITIONAL_GRAPHS ) - if self.use_cuda_graph: - print("Warming up CUDA kernels...") - mjwarp.step(self.wp_model, self.wp_data) - wp.synchronize() - else: + if not self.use_cuda_graph: print(f"[WARNING] Disabling CUDA Graphs. Current Driver {driver_ver} < 12.4.") print(" mujoco_warp solver requires 12.4+ for graph loops.")