Closed
Changes from 5 commits
20 changes: 18 additions & 2 deletions src/mjlab/sim/sim.py
@@ -18,6 +18,9 @@
ModelBridge = WarpBridge
DataBridge = WarpBridge

# mujoco_warp uses 'wp.capture_while' which strictly requires Driver 12.4+
_MIN_DRIVER_FOR_CONDITIONAL_GRAPHS = 12.4

_JACOBIAN_MAP = {
"auto": mujoco.mjtJacobian.mjJAC_AUTO,
"dense": mujoco.mjtJacobian.mjJAC_DENSE,
@@ -128,9 +131,22 @@ def __init__(
self._model_bridge = WarpBridge(self._wp_model, nworld=self.num_envs)
self._data_bridge = WarpBridge(self._wp_data)

self.use_cuda_graph = self.wp_device.is_cuda and wp.is_mempool_enabled(
self.wp_device
driver_ver = wp.context.runtime.driver_version
driver_ver = float(f"{driver_ver[0]}.{driver_ver[1]}")
self.use_cuda_graph = (
self.wp_device.is_cuda
and wp.is_mempool_enabled(self.wp_device)
and driver_ver >= _MIN_DRIVER_FOR_CONDITIONAL_GRAPHS
)

if self.use_cuda_graph:
print("Warming up CUDA kernels...")
mjwarp.step(self.wp_model, self.wp_data)
wp.synchronize()
Collaborator

Is this part necessary?

Contributor Author

I can't confirm that this step is unnecessary, because I don't have a setup that reproduces the scenario. In theory, if a CUDA driver accepts graph capture but does not support lazy module loading, preloading all modules before enabling capture would prevent a potential crash. I don't know whether any released driver actually behaves this way.
Empirically, the CUDA versions that pass the initial runtime check also appear to support lazy module loading, so this failure mode is likely already covered.
On the machines I tested, everything works without the warm-up step, so I'll remove it and push the change. If the problem shows up later, it can be handled in a separate pull request.

else:
print(f"[WARNING] Disabling CUDA Graphs. Current Driver {driver_ver} < 12.4.")
print(" mujoco_warp solver requires 12.4+ for graph loops.")

self.create_graph()

self.nan_guard = NanGuard(cfg.nan_guard, self.num_envs, self._mj_model)
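
A note on the version check in this hunk: building a float from `f"{major}.{minor}"` misorders two-digit minor versions, since `float("12.10")` is `12.1`, which compares below `12.4`. The `else` branch also prints the driver warning when graphs are disabled for another reason, such as a non-CUDA device. The sketch below is one possible alternative, not the code in this PR. It assumes the driver version is available as a `(major, minor)` tuple, and the names `cuda_graph_decision` and `MIN_DRIVER_FOR_CONDITIONAL_GRAPHS` are hypothetical.

```python
from typing import Optional, Tuple

# Assumption: same 12.4 threshold as the diff, stored as (major, minor)
# so that tuple comparison orders 12.10 after 12.4.
MIN_DRIVER_FOR_CONDITIONAL_GRAPHS = (12, 4)


def cuda_graph_decision(
    is_cuda: bool,
    mempool_enabled: bool,
    driver_version: Optional[Tuple[int, int]],
) -> Tuple[bool, str]:
    """Return (use_cuda_graph, reason). Hypothetical helper for illustration."""
    if not is_cuda:
        return False, "device is not CUDA"
    if not mempool_enabled:
        return False, "memory pool is disabled"
    if driver_version is None:
        return False, "driver version is unknown"
    if tuple(driver_version) < MIN_DRIVER_FOR_CONDITIONAL_GRAPHS:
        major, minor = driver_version
        need_major, need_minor = MIN_DRIVER_FOR_CONDITIONAL_GRAPHS
        return False, f"driver {major}.{minor} < {need_major}.{need_minor}"
    return True, "ok"
```

With this shape, the caller can print the returned reason only when the first element is `False`, so the warning names the actual cause instead of always blaming the driver.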