diff --git a/.gitignore b/.gitignore
index 10f31fb1..60db3b91 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
*.egg-info/
env/
.history
+.vscode
# Python byte-compiled / optimized / DLL files
__pycache__/
diff --git a/AUTHORS b/AUTHORS
new file mode 100644
index 00000000..aff29ad2
--- /dev/null
+++ b/AUTHORS
@@ -0,0 +1,10 @@
+# This is the list of significant contributors to mjWarp.
+#
+# This does not necessarily list everyone who has contributed code,
+# especially since many employees of one corporation may be contributing.
+# To see the full list of contributors, see the revision history in
+# source control.
+
+Google LLC
+NVIDIA Corporation
+
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..e85d9451
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,25 @@
+# How to Contribute
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows [Google's Open Source Community
+Guidelines](https://opensource.google/conduct/).
diff --git a/README.md b/README.md
index b7b4f65f..5345dced 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-# mjWarp
+# MJWarp
-MuJoCo implemented in Warp.
+MuJoCo implemented in NVIDIA Warp.
# Installing for development
@@ -32,47 +32,62 @@ Should print out something like `XX passed in XX.XXs` at the end!
Benchmark as follows:
```bash
-mjx-testspeed --function=forward --is_sparse=True --mjcf=humanoid/humanoid.xml --batch_size=8192
+mjwarp-testspeed --function=step --mjcf=humanoid/humanoid.xml --batch_size=8192
```
-Some relevant benchmarks on an NVIDIA GeForce RTX 4090:
+To get a full trace of the physics steps (e.g. timings of the subcomponents) run the following:
-## forward steps / sec (smooth dynamics only)
-
-27 dofs per humanoid, 8k batch size.
-
-| Num Humanoids | MJX | mjWarp dense | mjWarp sparse |
-| ----------------| -------| ------------- | -------------- |
-| 1 | 7.9M | 15.6M | 13.7M |
-| 2 | 2.6M | 7.4M | 7.8M |
-| 3 | 2.2M | 4.6M | 5.3M |
-| 4 | 1.5M | 3.3M | 4.1M |
-| 5 | 1.1M | ❌ | 3.2M |
-
-# Ideas for what to try next
-
-## 1. Unroll steps
-
-In the Pure JAX benchmark, we can tell JAX to unroll some number of FK steps (in the benchmarks above, `unroll=4`). This has a big impact on performance. If we change `unroll` from 4 to 1, pure JAX performance at 8k batch drops from 50M to 33M steps/sec.
-
-Is there some way that we can improve Warp performance in the same way? If I know ahead of time that I am going to call FK in a loop 1000 times, can I somehow inject unroll primitives?
-
-## 2. Different levels of parallelism
-
-The current approach parallelizes over body kinematic tree depth. We could go either direction: remove body parallism (fewer kernel launches), or parallelize over joints instead (more launches, more parallelism).
-
-## 3. Tiling
+```bash
+mjwarp-testspeed --function=step --mjcf=humanoid/humanoid.xml --batch_size=8192 --event_trace=True
+```
-It looks like a thing! Should we use it? Will it help?
+`humanoid.xml` has been carefully optimized for MJX in the following ways:
-## 4. Quaternions
+* Newton solver iterations are capped at 1, linesearch iterations capped at 4
+* Only foot<>floor collisions are turned on, producing at most 8 contact points
+* The damping term in the Euler integrator (which invokes another `factor_m` and `solve_m`) is disabled
-Why oh why did Warp make quaternions x,y,z,w? In order to be obstinate I wrote my own quaternion math. Is this slower than using the Warp quaternion primitives?
+By comparing MJWarp to MJX on this model, we are comparing MJWarp to the very best that MJX can do.
-## 5. `wp.static`
+For many (most) MuJoCo models, particularly ones that haven't been carefully tuned, MJX will
+do much worse.
-Haven't tried this at all - curious to see if it helps.
+## physics steps / sec
-## 6. Other stuff?
+NVIDIA GeForce RTX 4090, 27 dofs, ncon=8, 8k batch size.
-Should I be playing with `block_dim`? Is my method for timing OK or did I misunderstand how `wp.synchronize` works? Is there something about allocating that I should be aware of? What am I not thinking of?
+```
+Summary for 8192 parallel rollouts
+
+ Total JIT time: 0.82 s
+ Total simulation time: 2.98 s
+ Total steps per second: 2,753,173
+ Total realtime factor: 13,765.87 x
+ Total time per step: 363.22 ns
+
+Event trace:
+
+step: 361.41 (MJX: 316.58 ns)
+  forward: 359.15
+    fwd_position: 52.58
+      kinematics: 19.36 (MJX: 16.45 ns)
+      com_pos: 7.80 (MJX: 12.37 ns)
+      crb: 12.44 (MJX: 27.91 ns)
+      factor_m: 6.40 (MJX: 27.48 ns)
+      collision: 4.07 (MJX: 1.23 ns)
+      make_constraint: 6.32 (MJX: 42.39 ns)
+      transmission: 1.30 (MJX: 3.54 ns)
+    fwd_velocity: 26.52
+      com_vel: 8.44 (MJX: 9.38 ns)
+      passive: 1.06 (MJX: 3.22 ns)
+      rne: 10.96 (MJX: 16.75 ns)
+    fwd_actuation: 2.74 (MJX: 3.93 ns)
+    fwd_acceleration: 11.90
+      xfrc_accumulate: 3.83 (MJX: 6.81 ns)
+      solve_m: 6.92 (MJX: 8.88 ns)
+    solve: 264.38 (MJX: 153.29 ns)
+      mul_m: 5.93
+      _linesearch_iterative: 43.15
+        mul_m: 3.66
+  euler: 1.74 (MJX: 3.78 ns)
+```
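The summary figures above are related by simple arithmetic, which can be handy when comparing runs. The sketch below recomputes the time per step and the realtime factor from the reported steps per second. It assumes a simulation timestep of 0.005 s, which is not stated in this README and is only inferred from the reported numbers; the variable names are illustrative and are not part of `mjwarp-testspeed`.

```python
# Figures copied from the benchmark summary above.
steps_per_second = 2_753_173
batch_size = 8192

# Assumption: the model's simulation timestep is 0.005 s (not stated in this README).
timestep_s = 0.005

# Time per step is the reciprocal of total throughput, expressed in nanoseconds.
time_per_step_ns = 1e9 / steps_per_second

# Realtime factor: simulated seconds produced per wall-clock second, summed over the batch.
realtime_factor = steps_per_second * timestep_s

# Throughput attributable to a single rollout in the batch.
steps_per_second_per_world = steps_per_second / batch_size

print(f"time per step:   {time_per_step_ns:.2f} ns")
print(f"realtime factor: {realtime_factor:,.2f} x")
print(f"per-world rate:  {steps_per_second_per_world:.1f} steps/s")
```

Under that timestep assumption, the recomputed values should land close to the reported time per step and realtime factor. Note that both are totals across all 8192 rollouts, not per-rollout figures.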
diff --git a/contrib/README.md b/contrib/README.md
new file mode 100644
index 00000000..9f273766
--- /dev/null
+++ b/contrib/README.md
@@ -0,0 +1,5 @@
+# Contrib
+
+Contrib is a home for experiments, helper scripts and other tools that are not officially part of MJWarp.
+
+The contents of this directory are subject to change with no notice.
\ No newline at end of file
diff --git a/contrib/apptronik_apollo_locomotion.ipynb b/contrib/apptronik_apollo_locomotion.ipynb
new file mode 100644
index 00000000..79b16f4f
--- /dev/null
+++ b/contrib/apptronik_apollo_locomotion.ipynb
@@ -0,0 +1,791 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "j2cz906V7d0X"
+ },
+ "source": [
+ "# Training Apptronik Apollo using MuJoCo Warp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XpAzX-oCJcSJ"
+ },
+ "outputs": [],
+ "source": [
+ "# you may need some extra deps for this colab:\n",
+ "# pip install \"jax[cuda12_local]\"\n",
+ "# pip install playground\n",
+ "# pip install matplotlib\n",
+ "# to run this on your local machine, do the following:\n",
+ "# pip install jupyter\n",
+ "# jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0 --no-browser\n",
+ "\n",
+ "import dataclasses\n",
+ "import datetime\n",
+ "import functools\n",
+ "import os\n",
+ "import time\n",
+ "from typing import Any, Dict, Optional, Union\n",
+ "\n",
+ "import jax\n",
+ "import mediapy as media\n",
+ "import mujoco\n",
+ "import numpy as np\n",
+ "import warp as wp\n",
+ "from etils import epath\n",
+ "from jax import numpy as jp\n",
+ "from ml_collections import config_dict\n",
+ "from mujoco import mjx\n",
+ "from mujoco_playground._src import mjx_env\n",
+ "from mujoco_playground._src import reward\n",
+ "from mujoco_playground._src.dm_control_suite import common\n",
+ "from warp.jax_experimental.ffi import jax_callable\n",
+ "\n",
+ "import mujoco_warp as mjwarp\n",
+ "from mujoco_warp._src.warp_util import kernel_copy\n",
+ "\n",
+ "# this ensures JAX embeds Warp kernels into its own computation graph:\n",
+ "os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_graph_min_graph_size=1\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-BVYPgjk7xcm"
+ },
+ "outputs": [],
+ "source": [
+ "# We'll grab the Apptronik model from MuJoCo Menagerie, then remove some\n",
+ "# MJX-specific changes that exist in the XML that MJWarp doesn't need\n",
+ "# (such as explicit contacts, really tight ls_iterations etc.)\n",
+ "\n",
+ "mjx_env.ensure_menagerie_exists()\n",
+ "\n",
+ "contrib_xml_dir = epath.resource_path(\"mujoco_warp\").parent / \"contrib/xml\"\n",
+ "apptronik_dir = mjx_env.EXTERNAL_DEPS_PATH / \"mujoco_menagerie/apptronik_apollo/\"\n",
+ "\n",
+ "! cp {contrib_xml_dir / 'apptronik_apollo.xml'} {apptronik_dir}\n",
+ "! cp {contrib_xml_dir / 'scene.xml'} {apptronik_dir}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ixAOK05EkQNg"
+ },
+ "outputs": [],
+ "source": [
+ "# MJWarp is not yet fully connected to JAX. For now we will use the function\n",
+ "# `mjwarp_step` below, which we wrap in Warp's handy `jax_callable` function\n",
+ "# that converts Warp kernels to JAX operations.\n",
+ "#\n",
+ "# After we build a proper JAX wrapper for MJWarp, this code will disappear.\n",
+ "\n",
+ "NWORLD = 8192\n",
+ "NCONMAX = 81920\n",
+ "NJMAX = NCONMAX * 4\n",
+ "\n",
+ "xml_path = apptronik_dir / \"scene.xml\"\n",
+ "mjm = mujoco.MjModel.from_xml_path(xml_path.as_posix())\n",
+ "mjm.opt.iterations = 5\n",
+ "mjm.opt.ls_iterations = 10\n",
+ "mjd = mujoco.MjData(mjm)\n",
+ "mujoco.mj_resetDataKeyframe(mjm, mjd, 0)\n",
+ "mujoco.mj_forward(mjm, mjd)\n",
+ "m = mjwarp.put_model(mjm)\n",
+ "d = mjwarp.put_data(mjm, mjd, nworld=NWORLD, nconmax=NCONMAX, njmax=NJMAX)\n",
+ "\n",
+ "\n",
+ "def mjwarp_step(\n",
+ " ctrl: wp.array(dtype=wp.float32, ndim=2),\n",
+ " qpos_in: wp.array(dtype=wp.float32, ndim=2),\n",
+ " qvel_in: wp.array(dtype=wp.float32, ndim=2),\n",
+ " qacc_warmstart_in: wp.array(dtype=wp.float32, ndim=2),\n",
+ " qpos_out: wp.array(dtype=wp.float32, ndim=2),\n",
+ " qvel_out: wp.array(dtype=wp.float32, ndim=2),\n",
+ " xpos_out: wp.array(dtype=wp.vec3, ndim=2),\n",
+ " xmat_out: wp.array(dtype=wp.mat33, ndim=2),\n",
+ " qacc_warmstart_out: wp.array(dtype=wp.float32, ndim=2),\n",
+ " subtree_com_out: wp.array(dtype=wp.vec3, ndim=2),\n",
+ " cvel_out: wp.array(dtype=wp.spatial_vector, ndim=2),\n",
+ " site_xpos_out: wp.array(dtype=wp.vec3, ndim=2),\n",
+ "):\n",
+ " kernel_copy(d.ctrl, ctrl)\n",
+ " kernel_copy(d.qpos, qpos_in)\n",
+ " kernel_copy(d.qvel, qvel_in)\n",
+ " kernel_copy(d.qacc_warmstart, qacc_warmstart_in)\n",
+ "\n",
+ " # TODO(team): remove this hard-coded substep count\n",
+ " # ctrl_dt / sim_dt == 4\n",
+ " for i in range(4):\n",
+ " mjwarp.step(m, d)\n",
+ " kernel_copy(qpos_out, d.qpos)\n",
+ " kernel_copy(qvel_out, d.qvel)\n",
+ " kernel_copy(xpos_out, d.xpos)\n",
+ " kernel_copy(xmat_out, d.xmat)\n",
+ " kernel_copy(qacc_warmstart_out, d.qacc_warmstart)\n",
+ " kernel_copy(subtree_com_out, d.subtree_com)\n",
+ " kernel_copy(cvel_out, d.cvel)\n",
+ " kernel_copy(site_xpos_out, d.site_xpos)\n",
+ "\n",
+ "\n",
+ "jax_mjwarp_step = jax_callable(\n",
+ " mjwarp_step,\n",
+ " num_outputs=8,\n",
+ " output_dims={\n",
+ " \"qpos_out\": (NWORLD, mjm.nq),\n",
+ " \"qvel_out\": (NWORLD, mjm.nv),\n",
+ " \"xpos_out\": (NWORLD, mjm.nbody, 3),\n",
+ " \"xmat_out\": (NWORLD, mjm.nbody, 3, 3),\n",
+ " \"qacc_warmstart_out\": (NWORLD, mjm.nv),\n",
+ " \"subtree_com_out\": (NWORLD, mjm.nbody, 3),\n",
+ " \"cvel_out\": (NWORLD, mjm.nbody, 6),\n",
+ " \"site_xpos_out\": (NWORLD, mjm.nsite, 3),\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "# the functions below allow us to call MJWarp step inside jax vmap:\n",
+ "\n",
+ "\n",
+ "@jax.custom_batching.custom_vmap\n",
+ "def step(d: mjx.Data):\n",
+ " return d\n",
+ "\n",
+ "\n",
+ "@step.def_vmap\n",
+ "def step_vmap_rule(axis_size, in_batched, d: mjx.Data):\n",
+ " if in_batched[0].ctrl:\n",
+ " assert d.ctrl.shape[0] == axis_size\n",
+ " else:\n",
+ " d = d.replace(ctrl=jp.tile(d.ctrl, (axis_size, 1)))\n",
+ " params = {f.name: None for f in dataclasses.fields(mjx.Data)}\n",
+ " params[\"ctrl\"] = True\n",
+ " params[\"qpos\"] = True\n",
+ " params[\"qvel\"] = True\n",
+ " params[\"xpos\"] = True\n",
+ " params[\"xmat\"] = True\n",
+ " params[\"qacc_warmstart\"] = True\n",
+ " params[\"subtree_com\"] = True\n",
+ " params[\"cvel\"] = True\n",
+ " params[\"site_xpos\"] = True\n",
+ " out_batched = mjx.Data(**params)\n",
+ "\n",
+ " qpos, qvel, xpos, xmat, qacc_warmstart, subtree_com, cvel, site_xpos = (\n",
+ " jax_mjwarp_step(d.ctrl, d.qpos, d.qvel, d.qacc_warmstart)\n",
+ " )\n",
+ " d = d.replace(\n",
+ " qpos=qpos,\n",
+ " qvel=qvel,\n",
+ " xpos=xpos,\n",
+ " xmat=xmat,\n",
+ " qacc_warmstart=qacc_warmstart,\n",
+ " subtree_com=subtree_com,\n",
+ " cvel=cvel,\n",
+ " site_xpos=site_xpos,\n",
+ " )\n",
+ " return d, out_batched\n",
+ "\n",
+ "\n",
+ "def init(qpos, ctrl) -> mjx.Data:\n",
+ " init_params = {f.name: None for f in dataclasses.fields(mjx.Data)}\n",
+ " init_params[\"qpos\"] = qpos\n",
+ " init_params[\"ctrl\"] = ctrl\n",
+ " init_params[\"qvel\"] = jp.zeros(m.nv)\n",
+ " init_params[\"xpos\"] = jp.zeros((m.nbody, 3))\n",
+ " init_params[\"xmat\"] = jp.tile(jp.zeros((3, 3)), (m.nbody, 1, 1))\n",
+ " init_params[\"qacc_warmstart\"] = jp.array(mjd.qacc_warmstart)\n",
+ " init_params[\"subtree_com\"] = jp.zeros((m.nbody, 3))\n",
+ " init_params[\"cvel\"] = jp.zeros((m.nbody, 6))\n",
+ " init_params[\"site_xpos\"] = jp.zeros((m.nsite, 3))\n",
+ " return mjx.Data(**init_params)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AaYQk2lu7Bus"
+ },
+ "outputs": [],
+ "source": [
+ "# An environment for training an Apptronik Apollo to walk.\n",
+ "# This is the same format as environments in MuJoCo Playground.\n",
+ "\n",
+ "\n",
+ "def default_config() -> config_dict.ConfigDict:\n",
+ " return config_dict.create(\n",
+ " ctrl_dt=0.02,\n",
+ " sim_dt=0.005,\n",
+ " episode_length=1000,\n",
+ " action_repeat=1,\n",
+ " action_scale=0.5,\n",
+ " soft_joint_pos_limit_factor=0.95,\n",
+ " reward_config=config_dict.create(\n",
+ " scales=config_dict.create(\n",
+ " # Tracking related rewards.\n",
+ " tracking_lin_vel=1.0,\n",
+ " tracking_ang_vel=0.75,\n",
+ " # Base related rewards.\n",
+ " ang_vel_xy=-0.15,\n",
+ " orientation=-2.0,\n",
+ " # Energy related rewards.\n",
+ " action_rate=0.0,\n",
+ " # Feet related rewards.\n",
+ " feet_air_time=2.0,\n",
+ " feet_slip=-0.25,\n",
+ " feet_phase=1.0,\n",
+ " # Other rewards.\n",
+ " termination=-5.0,\n",
+ " # Pose related rewards.\n",
+ " joint_deviation_knee=-0.1,\n",
+ " joint_deviation_hip=-0.25,\n",
+ " dof_pos_limits=-1.0,\n",
+ " pose=-0.1,\n",
+ " ),\n",
+ " tracking_sigma=0.25,\n",
+ " max_foot_height=0.15,\n",
+ " base_height_target=0.5,\n",
+ " ),\n",
+ " lin_vel_x=[1.0, 1.0],\n",
+ " lin_vel_y=[0.0, 0.0],\n",
+ " ang_vel_yaw=[0.0, 0.0],\n",
+ " )\n",
+ "\n",
+ "\n",
+ "class Joystick(mjx_env.MjxEnv):\n",
+ " \"\"\"Track a joystick command.\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " config: config_dict.ConfigDict = default_config(),\n",
+ " config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,\n",
+ " ):\n",
+ " super().__init__(config, config_overrides)\n",
+ " self._post_init()\n",
+ "\n",
+ " def _post_init(self) -> None:\n",
+ " self._init_q = jp.array(mjm.keyframe(\"stand\").qpos)\n",
+ " self._default_pose = self._init_q[7:]\n",
+ "\n",
+ " # Note: First joint is freejoint.\n",
+ " self._lowers, self._uppers = self.mj_model.jnt_range[1:].T\n",
+ " c = (self._lowers + self._uppers) / 2\n",
+ " r = self._uppers - self._lowers\n",
+ " self._soft_lowers = c - 0.5 * r * self._config.soft_joint_pos_limit_factor\n",
+ " self._soft_uppers = c + 0.5 * r * self._config.soft_joint_pos_limit_factor\n",
+ "\n",
+ " hip_joints = [\"l_hip_ie\", \"l_hip_aa\", \"r_hip_ie\", \"r_hip_aa\"]\n",
+ " hip_indices = [mjm.joint(j).qposadr - 7 for j in hip_joints]\n",
+ " self._hip_indices = jp.array(hip_indices)\n",
+ "\n",
+ " knee_joints = [\"l_knee_fe\", \"r_knee_fe\"]\n",
+ " knee_indices = [mjm.joint(j).qposadr - 7 for j in knee_joints]\n",
+ " self._knee_indices = jp.array(knee_indices)\n",
+ "\n",
+ " self._head_body_id = mjm.body(\"neck_pitch_link\").id\n",
+ " self._torso_id = mjm.body(\"torso_link\").id\n",
+ "\n",
+ " feet_sites = [\"l_foot_fr\", \"l_foot_br\", \"l_foot_fl\", \"l_foot_bl\"]\n",
+ " feet_sites += [\"r_foot_fr\", \"r_foot_br\", \"r_foot_fl\", \"r_foot_bl\"]\n",
+ " feet_site_ids = [mjm.site(s).id for s in feet_sites]\n",
+ " self._feet_site_id = jp.array(feet_site_ids)\n",
+ " self._feet_contact_z = 0.003\n",
+ "\n",
+ " self._floor_geom_id = mjm.geom(\"floor\").id\n",
+ "\n",
+ " def reset(self, rng: jax.Array) -> mjx_env.State:\n",
+ " qpos = self._init_q\n",
+ "\n",
+ " data = init(qpos=qpos, ctrl=qpos[7:])\n",
+ "\n",
+ " # Phase, freq=U(1.25, 1.5)\n",
+ " rng, key = jax.random.split(rng)\n",
+ " gait_freq = jax.random.uniform(key, (1,), minval=1.25, maxval=1.5)\n",
+ " phase_dt = 2 * jp.pi * self.dt * gait_freq\n",
+ " phase = jp.array([0, jp.pi])\n",
+ "\n",
+ " rng, cmd_rng = jax.random.split(rng)\n",
+ " cmd = self.sample_command(cmd_rng)\n",
+ "\n",
+ " info = {\n",
+ " \"rng\": rng,\n",
+ " \"step\": 0,\n",
+ " \"command\": cmd,\n",
+ " \"last_act\": jp.zeros(mjm.nu),\n",
+ " \"last_last_act\": jp.zeros(mjm.nu),\n",
+ " \"motor_targets\": jp.zeros(mjm.nu),\n",
+ " \"feet_air_time\": jp.zeros(2),\n",
+ " \"last_contact\": jp.zeros(2, dtype=bool),\n",
+ " # Phase related.\n",
+ " \"phase_dt\": phase_dt,\n",
+ " \"phase\": phase,\n",
+ " }\n",
+ "\n",
+ " metrics = {}\n",
+ " for k in self._config.reward_config.scales.keys():\n",
+ " metrics[f\"reward/{k}\"] = jp.zeros(())\n",
+ "\n",
+ " contact = data.site_xpos[self._feet_site_id, 2] < self._feet_contact_z\n",
+ " contact = jp.array([contact[0:4].any(), contact[4:8].any()])\n",
+ " obs = self._get_obs(data, info, contact)\n",
+ " reward, done = jp.zeros(2)\n",
+ " return mjx_env.State(data, obs, reward, done, metrics, info)\n",
+ "\n",
+ " def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:\n",
+ " state.info[\"rng\"], _ = jax.random.split(state.info[\"rng\"], 2)\n",
+ "\n",
+ " ctrl = self._default_pose + action * self._config.action_scale\n",
+ " data = state.data\n",
+ " data = data.replace(ctrl=ctrl)\n",
+ " data = step(data)\n",
+ " state.info[\"motor_targets\"] = ctrl\n",
+ "\n",
+ " contact = data.site_xpos[self._feet_site_id, 2] < self._feet_contact_z\n",
+ " contact = jp.array([contact[0:4].any(), contact[4:8].any()])\n",
+ " contact_filt = contact | state.info[\"last_contact\"]\n",
+ " first_contact = (state.info[\"feet_air_time\"] > 0.0) * contact_filt\n",
+ " state.info[\"feet_air_time\"] += self.dt\n",
+ "\n",
+ " obs = self._get_obs(data, state.info, contact)\n",
+ " done = self._get_termination(data)\n",
+ "\n",
+ " rewards = self._get_reward(\n",
+ " data, action, state.info, state.metrics, done, first_contact, contact\n",
+ " )\n",
+ " rewards = {k: v * self._config.reward_config.scales[k] for k, v in rewards.items()}\n",
+ " reward = sum(rewards.values()) * self.dt\n",
+ "\n",
+ " state.info[\"step\"] += 1\n",
+ " phase_tp1 = state.info[\"phase\"] + state.info[\"phase_dt\"]\n",
+ " state.info[\"phase\"] = jp.fmod(phase_tp1 + jp.pi, 2 * jp.pi) - jp.pi\n",
+ " state.info[\"last_last_act\"] = state.info[\"last_act\"]\n",
+ " state.info[\"last_act\"] = action\n",
+ " state.info[\"rng\"], cmd_rng = jax.random.split(state.info[\"rng\"])\n",
+ " state.info[\"command\"] = jp.where(\n",
+ " state.info[\"step\"] > 500,\n",
+ " self.sample_command(cmd_rng),\n",
+ " state.info[\"command\"],\n",
+ " )\n",
+ " state.info[\"step\"] = jp.where(\n",
+ " done | (state.info[\"step\"] > 500),\n",
+ " 0,\n",
+ " state.info[\"step\"],\n",
+ " )\n",
+ " state.info[\"feet_air_time\"] *= ~contact\n",
+ " state.info[\"last_contact\"] = contact\n",
+ " for k, v in rewards.items():\n",
+ " state.metrics[f\"reward/{k}\"] = v\n",
+ "\n",
+ " done = done.astype(reward.dtype)\n",
+ " state = state.replace(data=data, obs=obs, reward=reward, done=done)\n",
+ " return state\n",
+ "\n",
+ " def _get_termination(self, data: mjx.Data) -> jax.Array:\n",
+ " fall_termination = data.xpos[self._head_body_id, 2] < 1.0\n",
+ " return fall_termination\n",
+ "\n",
+ " def _get_obs(\n",
+ " self, data: mjx.Data, info: dict[str, Any], contact: jax.Array\n",
+ " ) -> mjx_env.Observation:\n",
+ " cos = jp.cos(info[\"phase\"])\n",
+ " sin = jp.sin(info[\"phase\"])\n",
+ " phase = jp.concatenate([cos, sin])\n",
+ "\n",
+ " return jp.hstack(\n",
+ " [\n",
+ " data.qpos,\n",
+ " data.qvel,\n",
+ " data.cvel.ravel(),\n",
+ " data.xpos.ravel(),\n",
+ " data.xmat.ravel(),\n",
+ " phase,\n",
+ " info[\"command\"],\n",
+ " info[\"last_act\"],\n",
+ " info[\"feet_air_time\"],\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " def _get_reward(\n",
+ " self,\n",
+ " data: mjx.Data,\n",
+ " action: jax.Array,\n",
+ " info: dict[str, Any],\n",
+ " metrics: dict[str, Any],\n",
+ " done: jax.Array,\n",
+ " first_contact: jax.Array,\n",
+ " contact: jax.Array,\n",
+ " ) -> dict[str, jax.Array]:\n",
+ " del metrics # Unused.\n",
+ " return {\n",
+ " # Tracking rewards.\n",
+ " \"tracking_lin_vel\": self._reward_tracking_lin_vel(\n",
+ " info[\"command\"], self._get_global_linvel(data, self._torso_id)\n",
+ " ),\n",
+ " \"tracking_ang_vel\": self._reward_tracking_ang_vel(\n",
+ " info[\"command\"], self._get_global_angvel(data, self._torso_id)\n",
+ " ),\n",
+ " # Base-related rewards.\n",
+ " \"ang_vel_xy\": self._cost_ang_vel_xy(\n",
+ " self._get_global_angvel(data, self._torso_id)\n",
+ " ),\n",
+ " \"orientation\": self._cost_orientation(self._get_z_frame(data, self._torso_id)),\n",
+ " # Energy related rewards.\n",
+ " \"action_rate\": self._cost_action_rate(\n",
+ " action, info[\"last_act\"], info[\"last_last_act\"]\n",
+ " ),\n",
+ " # Feet related rewards.\n",
+ " \"feet_slip\": self._cost_feet_slip(data, contact, info),\n",
+ " \"feet_air_time\": self._reward_feet_air_time(\n",
+ " info[\"feet_air_time\"], first_contact, info[\"command\"]\n",
+ " ),\n",
+ " \"feet_phase\": self._reward_feet_phase(\n",
+ " data,\n",
+ " info[\"phase\"],\n",
+ " self._config.reward_config.max_foot_height,\n",
+ " info[\"command\"],\n",
+ " ),\n",
+ " # Pose related rewards.\n",
+ " \"joint_deviation_hip\": self._cost_joint_deviation_hip(\n",
+ " data.qpos[7:], info[\"command\"]\n",
+ " ),\n",
+ " \"joint_deviation_knee\": self._cost_joint_deviation_knee(data.qpos[7:]),\n",
+ " \"dof_pos_limits\": self._cost_joint_pos_limits(data.qpos[7:]),\n",
+ " \"pose\": self._cost_pose(data.qpos[7:]),\n",
+ " # Other rewards.\n",
+ " \"termination\": self._cost_termination(done),\n",
+ " }\n",
+ "\n",
+ " def _get_global_angvel(self, data: mjx.Data, bodyid: int):\n",
+ " return data.cvel[bodyid, :3]\n",
+ "\n",
+ " def _get_global_linvel(self, data: mjx.Data, bodyid: int):\n",
+ " offset = data.xpos[bodyid] - data.subtree_com[mjm.body_rootid[bodyid]]\n",
+ " xang = data.cvel[bodyid, :3]\n",
+ " xvel = data.cvel[bodyid, 3:] + jp.cross(offset, xang)\n",
+ " return xvel\n",
+ "\n",
+ " def _get_z_frame(self, data: mjx.Data, bodyid: int):\n",
+ " return data.xmat[bodyid, :, 2]\n",
+ "\n",
+ " # Tracking rewards.\n",
+ "\n",
+ " def _reward_tracking_lin_vel(\n",
+ " self,\n",
+ " commands: jax.Array,\n",
+ " local_vel: jax.Array,\n",
+ " ) -> jax.Array:\n",
+ " lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))\n",
+ " return jp.exp(-lin_vel_error / self._config.reward_config.tracking_sigma)\n",
+ "\n",
+ " def _reward_tracking_ang_vel(\n",
+ " self,\n",
+ " commands: jax.Array,\n",
+ " ang_vel: jax.Array,\n",
+ " ) -> jax.Array:\n",
+ " ang_vel_error = jp.square(commands[2] - ang_vel[2])\n",
+ " return jp.exp(-ang_vel_error / self._config.reward_config.tracking_sigma)\n",
+ "\n",
+ " # Base-related rewards.\n",
+ "\n",
+ " def _cost_ang_vel_xy(self, global_angvel_torso: jax.Array) -> jax.Array:\n",
+ " return jp.sum(jp.square(global_angvel_torso[:2]))\n",
+ "\n",
+ " def _cost_orientation(self, torso_zaxis: jax.Array) -> jax.Array:\n",
+ " return jp.sum(jp.square(torso_zaxis - jp.array([0.0, 0.0, 1.0])))\n",
+ "\n",
+ " def _cost_base_height(self, base_height: jax.Array) -> jax.Array:\n",
+ " return jp.square(base_height - self._config.reward_config.base_height_target)\n",
+ "\n",
+ " # Energy related rewards.\n",
+ "\n",
+ " def _cost_action_rate(\n",
+ " self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array\n",
+ " ) -> jax.Array:\n",
+ " del last_last_act # Unused.\n",
+ " return jp.sum(jp.square(act - last_act))\n",
+ "\n",
+ " # Feet related rewards.\n",
+ "\n",
+ " def _cost_feet_slip(\n",
+ " self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]\n",
+ " ) -> jax.Array:\n",
+ " del info # Unused.\n",
+ " body_vel = self._get_global_linvel(data, self._torso_id)[:2]\n",
+ " reward = jp.sum(jp.linalg.norm(body_vel, axis=-1) * contact)\n",
+ " return reward\n",
+ "\n",
+ " def _reward_feet_air_time(\n",
+ " self,\n",
+ " air_time: jax.Array,\n",
+ " first_contact: jax.Array,\n",
+ " commands: jax.Array,\n",
+ " threshold_min: float = 0.2,\n",
+ " threshold_max: float = 0.5,\n",
+ " ) -> jax.Array:\n",
+ " del commands # Unused.\n",
+ " air_time = (air_time - threshold_min) * first_contact\n",
+ " air_time = jp.clip(air_time, max=threshold_max - threshold_min)\n",
+ " reward = jp.sum(air_time)\n",
+ " return reward\n",
+ "\n",
+ " def get_rz(\n",
+ " phi: Union[jax.Array, float], swing_height: Union[jax.Array, float] = 0.08\n",
+ " ) -> jax.Array:\n",
+ " def cubic_bezier_interpolation(y_start, y_end, x):\n",
+ " y_diff = y_end - y_start\n",
+ " bezier = x**3 + 3 * (x**2 * (1 - x))\n",
+ " return y_start + y_diff * bezier\n",
+ "\n",
+ " x = (phi + jp.pi) / (2 * jp.pi)\n",
+ " stance = cubic_bezier_interpolation(0, swing_height, 2 * x)\n",
+ " swing = cubic_bezier_interpolation(swing_height, 0, 2 * x - 1)\n",
+ " return jp.where(x <= 0.5, stance, swing)\n",
+ "\n",
+ " def _reward_feet_phase(\n",
+ " self,\n",
+ " data: mjx.Data,\n",
+ " phase: jax.Array,\n",
+ " foot_height: jax.Array,\n",
+ " command: jax.Array,\n",
+ " ) -> jax.Array:\n",
+ " # Reward for tracking the desired foot height.\n",
+ " foot_pos = data.site_xpos[self._feet_site_id]\n",
+ " foot_pos = jp.array(\n",
+ " [jp.mean(foot_pos[0:4], axis=0), jp.mean(foot_pos[4:8], axis=0)]\n",
+ " )\n",
+ " foot_z = foot_pos[..., -1]\n",
+ " rz = Joystick.get_rz(phase, swing_height=foot_height)\n",
+ " error = jp.sum(jp.square(foot_z - rz))\n",
+ " reward = jp.exp(-error / 0.01)\n",
+ " body_linvel = self._get_global_linvel(data, self._torso_id)[:2]\n",
+ " body_angvel = self._get_global_angvel(data, self._torso_id)[2]\n",
+ " linvel_mask = jp.logical_or(\n",
+ " jp.linalg.norm(body_linvel) > 0.1,\n",
+ " jp.abs(body_angvel) > 0.1,\n",
+ " )\n",
+ " mask = jp.logical_or(linvel_mask, jp.linalg.norm(command) > 0.01)\n",
+ " reward *= mask\n",
+ " return reward\n",
+ "\n",
+ " # Pose-related rewards.\n",
+ "\n",
+ " def _cost_joint_deviation_hip(self, qpos: jax.Array, cmd: jax.Array) -> jax.Array:\n",
+ " error = qpos[self._hip_indices] - self._default_pose[self._hip_indices]\n",
+ " # Allow roll deviation when lateral velocity is high.\n",
+ " weight = jp.where(\n",
+ " cmd[1] > 0.1,\n",
+ " jp.array([0.0, 1.0, 0.0, 1.0]),\n",
+ " jp.array([1.0, 1.0, 1.0, 1.0]),\n",
+ " )\n",
+ " cost = jp.sum(jp.abs(error) * weight)\n",
+ " return cost\n",
+ "\n",
+ " def _cost_joint_deviation_knee(self, qpos: jax.Array) -> jax.Array:\n",
+ " error = qpos[self._knee_indices] - self._default_pose[self._knee_indices]\n",
+ " return jp.sum(jp.abs(error))\n",
+ "\n",
+ " def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array:\n",
+ " out_of_limits = -jp.clip(qpos - self._soft_lowers, None, 0.0)\n",
+ " out_of_limits += jp.clip(qpos - self._soft_uppers, 0.0, None)\n",
+ " return jp.sum(out_of_limits)\n",
+ "\n",
+ " def _cost_pose(self, qpos: jax.Array) -> jax.Array:\n",
+ " return jp.sum(jp.square(qpos - self._default_pose))\n",
+ "\n",
+ " # Other rewards.\n",
+ "\n",
+ " def _cost_termination(self, done: jax.Array) -> jax.Array:\n",
+ " return done\n",
+ "\n",
+ " def sample_command(self, rng: jax.Array) -> jax.Array:\n",
+ " rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)\n",
+ "\n",
+ " lin_vel_x = jax.random.uniform(\n",
+ " rng1, minval=self._config.lin_vel_x[0], maxval=self._config.lin_vel_x[1]\n",
+ " )\n",
+ " lin_vel_y = jax.random.uniform(\n",
+ " rng2, minval=self._config.lin_vel_y[0], maxval=self._config.lin_vel_y[1]\n",
+ " )\n",
+ " ang_vel_yaw = jax.random.uniform(\n",
+ " rng3,\n",
+ " minval=self._config.ang_vel_yaw[0],\n",
+ " maxval=self._config.ang_vel_yaw[1],\n",
+ " )\n",
+ "\n",
+ " return jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw])\n",
+ "\n",
+ " @property\n",
+ " def xml_path(self) -> str:\n",
+ " return xml_path.as_posix()\n",
+ "\n",
+ " @property\n",
+ " def action_size(self) -> int:\n",
+ " return mjm.nu\n",
+ "\n",
+ " @property\n",
+ " def mj_model(self) -> mujoco.MjModel:\n",
+ " return mjm\n",
+ "\n",
+ " @property\n",
+ " def mjx_model(self) -> mjx.Model:\n",
+ " return None # unused"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "aoWmBHXiU-HW"
+ },
+ "outputs": [],
+ "source": [
+ "# Train the environment using Brax PPO\n",
+ "\n",
+ "from IPython.display import HTML, clear_output, display\n",
+ "from brax.training.agents.ppo import networks as ppo_networks\n",
+ "from brax.training.agents.ppo import train as ppo\n",
+ "from matplotlib import pyplot as plt\n",
+ "from mujoco_playground import wrapper\n",
+ "from mujoco_playground.config import locomotion_params\n",
+ "\n",
+ "ppo_params = config_dict.create(\n",
+ " num_timesteps=50_000_000,\n",
+ " num_evals=5,\n",
+ " reward_scaling=1.0,\n",
+ " clipping_epsilon=0.2,\n",
+ " episode_length=1000,\n",
+ " normalize_observations=True,\n",
+ " action_repeat=1,\n",
+ " unroll_length=20,\n",
+ " num_minibatches=32,\n",
+ " num_updates_per_batch=4,\n",
+ " discounting=0.97,\n",
+ " learning_rate=3e-4,\n",
+ " entropy_cost=0.005,\n",
+ " num_envs=NWORLD,\n",
+ " num_eval_envs=NWORLD,\n",
+ " batch_size=256,\n",
+ " max_grad_norm=1.0,\n",
+ " # network_factory=config_dict.create(\n",
+ " # policy_hidden_layer_sizes=(512, 256, 128),\n",
+ " # value_hidden_layer_sizes=(512, 256, 128),\n",
+ " # ),\n",
+ ")\n",
+ "\n",
+ "x_data, y_data, y_dataerr = [], [], []\n",
+ "times = [datetime.datetime.now()]\n",
+ "\n",
+ "\n",
+ "def progress(num_steps, metrics):\n",
+ " times.append(datetime.datetime.now())\n",
+ " x_data.append(num_steps)\n",
+ " y_data.append(metrics[\"eval/episode_reward\"])\n",
+ " y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
+ "\n",
+ " plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
+ " plt.ylim([0, 30])\n",
+ " plt.xlabel(\"# environment steps\")\n",
+ " plt.ylabel(\"reward per episode\")\n",
+ " plt.title(f\"y={y_data[-1]:.3f}\")\n",
+ " plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
+ "\n",
+ " display(plt.gcf())\n",
+ " clear_output(wait=True)\n",
+ "\n",
+ "\n",
+ "ppo_training_params = dict(ppo_params)\n",
+ "network_factory = ppo_networks.make_ppo_networks\n",
+ "if \"network_factory\" in ppo_params:\n",
+ " del ppo_training_params[\"network_factory\"]\n",
+ " network_factory = functools.partial(\n",
+ " ppo_networks.make_ppo_networks, **ppo_params.network_factory\n",
+ " )\n",
+ "\n",
+ "train_fn = functools.partial(\n",
+ " ppo.train,\n",
+ " **dict(ppo_training_params),\n",
+ " network_factory=network_factory,\n",
+ " progress_fn=progress,\n",
+ ")\n",
+ "\n",
+ "env = Joystick()\n",
+ "\n",
+ "make_inference_fn, params, metrics = train_fn(\n",
+ " environment=env,\n",
+ " wrap_env_fn=wrapper.wrap_for_brax_training,\n",
+ ")\n",
+ "\n",
+ "print(f\"time to jit: {times[1] - times[0]}\")\n",
+ "print(f\"time to train: {times[-1] - times[1]}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gGRhELoRjtQs"
+ },
+ "outputs": [],
+ "source": [
+ "rng = jax.random.PRNGKey(0)\n",
+ "\n",
+ "jit_reset = jax.jit(env.reset)\n",
+ "\n",
+ "\n",
+ "def unroll(state):\n",
+ " inference_fn = make_inference_fn(params, deterministic=True)\n",
+ " rng = jax.random.PRNGKey(0)\n",
+ "\n",
+ " def single_step(state, _):\n",
+ " action, _ = inference_fn(state.obs, rng)\n",
+ " action = jp.tile(action, (NWORLD, 1))\n",
+ " state = jax.tree.map(lambda x: jp.tile(x, (NWORLD,) + (1,) * len(x.shape)), state)\n",
+ " state = jax.vmap(env.step)(state, action)\n",
+ " state = jax.tree.map(lambda x: x[0], state)\n",
+ "\n",
+ " return state, state\n",
+ "\n",
+ " _, states = jax.lax.scan(single_step, state, length=1000)\n",
+ "\n",
+ " return states\n",
+ "\n",
+ "\n",
+ "rollout = jax.jit(unroll)(jit_reset(rng))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_kj9fCiDn6j6"
+ },
+ "outputs": [],
+ "source": [
+ "rollout_arr = [jax.tree.map(lambda x, i=i: x[i], rollout) for i in range(400)]\n",
+ "frames = env.render(rollout_arr, camera=\"track\", width=640, height=480)\n",
+ "media.show_video(frames, fps=1.0 / env.dt)"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/contrib/jax_unroll.py b/contrib/jax_unroll.py
new file mode 100644
index 00000000..ec6db438
--- /dev/null
+++ b/contrib/jax_unroll.py
@@ -0,0 +1,86 @@
+# Before running this, you may need to install JAX, most likely
+# backed by your local CUDA install:
+#
+# pip install --upgrade "jax[cuda12_local]"
+
+import os
+import time
+
+import jax
+import mujoco
+import numpy as np
+import warp as wp
+from etils import epath
+from jax import numpy as jp
+from warp.jax_experimental.ffi import jax_callable
+
+import mujoco_warp as mjwarp
+from mujoco_warp._src.warp_util import kernel_copy
+
+os.environ["XLA_FLAGS"] = "--xla_gpu_graph_min_graph_size=1"
+
+NWORLDS = 8192
+UNROLL_LENGTH = 1000
+
+wp.clear_kernel_cache()
+
+path = epath.resource_path("mujoco_warp") / "test_data" / "humanoid/humanoid.xml"
+mjm = mujoco.MjModel.from_xml_path(path.as_posix())
+mjm.opt.iterations = 1
+mjm.opt.ls_iterations = 4
+mjd = mujoco.MjData(mjm)
+# give the system a little kick to ensure we have non-identity rotations
+mjd.qvel = np.random.uniform(-0.01, 0.01, mjm.nv)
+mujoco.mj_step(mjm, mjd, 3) # let dynamics get state significantly non-zero
+mujoco.mj_forward(mjm, mjd)
+m = mjwarp.put_model(mjm)
+d = mjwarp.put_data(mjm, mjd, nworld=NWORLDS, nconmax=131012, njmax=131012 * 4)
+
+
+def warp_step(
+ qpos_in: wp.array(dtype=wp.float32, ndim=2),
+ qvel_in: wp.array(dtype=wp.float32, ndim=2),
+ qpos_out: wp.array(dtype=wp.float32, ndim=2),
+ qvel_out: wp.array(dtype=wp.float32, ndim=2),
+):
+ kernel_copy(d.qpos, qpos_in)
+ kernel_copy(d.qvel, qvel_in)
+ mjwarp.step(m, d)
+ kernel_copy(qpos_out, d.qpos)
+ kernel_copy(qvel_out, d.qvel)
+
+
+warp_step_fn = jax_callable(
+ warp_step,
+ num_outputs=2,
+ output_dims={"qpos_out": (NWORLDS, mjm.nq), "qvel_out": (NWORLDS, mjm.nv)},
+)
+
+jax_qpos = jp.tile(jp.array(m.qpos0), (NWORLDS, 1))
+jax_qvel = jp.zeros((NWORLDS, m.nv))
+
+
+def unroll(qpos, qvel):
+ def step(carry, _):
+ qpos, qvel = carry
+ qpos, qvel = warp_step_fn(qpos, qvel)
+ return (qpos, qvel), None
+
+ (qpos, qvel), _ = jax.lax.scan(step, (qpos, qvel), length=UNROLL_LENGTH)
+
+ return qpos, qvel
+
+
+jax_unroll_fn = jax.jit(unroll).lower(jax_qpos, jax_qvel).compile()
+
+# warm up:
+jax.block_until_ready(jax_unroll_fn(jax_qpos, jax_qvel))
+
+beg = time.perf_counter()
+final_qpos, final_qvel = jax_unroll_fn(jax_qpos, jax_qvel)
+jax.block_until_ready((final_qpos, final_qvel))
+end = time.perf_counter()
+
+run_time = end - beg
+
+print(f"Total steps per second: {NWORLDS * UNROLL_LENGTH / run_time:,.0f}")
diff --git a/contrib/xml/apptronik_apollo.xml b/contrib/xml/apptronik_apollo.xml
new file mode 100644
index 00000000..31acc2c4
--- /dev/null
+++ b/contrib/xml/apptronik_apollo.xml
@@ -0,0 +1,397 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/contrib/xml/scene.xml b/contrib/xml/scene.xml
new file mode 100644
index 00000000..d8e44e1f
--- /dev/null
+++ b/contrib/xml/scene.xml
@@ -0,0 +1,35 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mujoco/mjx/__init__.py b/mujoco/mjx/__init__.py
deleted file mode 100644
index 54db1ce0..00000000
--- a/mujoco/mjx/__init__.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright 2025 The Physics-Next Project Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Public API for MJX."""
-
-from ._src.collision_driver import broad_phase
-from ._src.constraint import make_constraint
-from ._src.forward import euler
-from ._src.forward import forward
-from ._src.forward import fwd_actuation
-from ._src.forward import fwd_acceleration
-from ._src.forward import fwd_position
-from ._src.forward import fwd_velocity
-from ._src.forward import step
-from ._src.io import make_data
-from ._src.io import put_data
-from ._src.io import put_model
-from ._src.passive import passive
-from ._src.smooth import com_pos
-from ._src.smooth import com_vel
-from ._src.smooth import crb
-from ._src.smooth import factor_m
-from ._src.smooth import kinematics
-from ._src.smooth import rne
-from ._src.smooth import solve_m
-from ._src.smooth import transmission
-from ._src.solver import solve
-from ._src.support import is_sparse
-from ._src.support import mul_m
-from ._src.support import xfrc_accumulate
-from ._src.test_util import benchmark
-from ._src.types import *
diff --git a/mujoco/mjx/_src/__init__.py b/mujoco/mjx/_src/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/mujoco/mjx/_src/broad_phase_test.py b/mujoco/mjx/_src/broad_phase_test.py
deleted file mode 100644
index 59e2150c..00000000
--- a/mujoco/mjx/_src/broad_phase_test.py
+++ /dev/null
@@ -1,282 +0,0 @@
-# Copyright 2025 The Physics-Next Project Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Tests for broad phase functions."""
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import mujoco
-from mujoco import mjx
-import numpy as np
-import warp as wp
-
-from . import test_util
-
-BoxType = wp.types.matrix(shape=(2, 3), dtype=wp.float32)
-
-
-# Helper function to initialize a box
-def init_box(min_x, min_y, min_z, max_x, max_y, max_z):
- center = wp.vec3((min_x + max_x) / 2, (min_y + max_y) / 2, (min_z + max_z) / 2)
- size = wp.vec3(max_x - min_x, max_y - min_y, max_z - min_z)
- box = wp.types.matrix(shape=(2, 3), dtype=wp.float32)(
- [center.x, center.y, center.z, size.x, size.y, size.z]
- )
- return box
-
-
-def overlap(
- a: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
- b: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
-) -> bool:
- # Extract centers and sizes
- a_center = a[0]
- a_size = a[1]
- b_center = b[0]
- b_size = b[1]
-
- # Calculate min/max from center and size
- a_min = a_center - 0.5 * a_size
- a_max = a_center + 0.5 * a_size
- b_min = b_center - 0.5 * b_size
- b_max = b_center + 0.5 * b_size
-
- return not (
- a_min.x > b_max.x
- or b_min.x > a_max.x
- or a_min.y > b_max.y
- or b_min.y > a_max.y
- or a_min.z > b_max.z
- or b_min.z > a_max.z
- )
-
-
-def transform_aabb(
- aabb: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
- pos: wp.vec3,
- rot: wp.mat33,
-) -> wp.types.matrix(shape=(2, 3), dtype=wp.float32):
- # Extract center and half-extents from AABB
- center = aabb[0]
- half_extents = aabb[1] * 0.5
-
- # Get absolute values of rotation matrix columns
- right = wp.vec3(wp.abs(rot[0, 0]), wp.abs(rot[0, 1]), wp.abs(rot[0, 2]))
- up = wp.vec3(wp.abs(rot[1, 0]), wp.abs(rot[1, 1]), wp.abs(rot[1, 2]))
- forward = wp.vec3(wp.abs(rot[2, 0]), wp.abs(rot[2, 1]), wp.abs(rot[2, 2]))
-
- # Compute world space half-extents
- world_extents = (
- right * half_extents.x + up * half_extents.y + forward * half_extents.z
- )
-
- # Transform center
- new_center = rot @ center + pos
-
- # Return new AABB as matrix with center and full size
- result = BoxType()
- result[0] = wp.vec3(new_center.x, new_center.y, new_center.z)
- result[1] = wp.vec3(
- world_extents.x * 2.0, world_extents.y * 2.0, world_extents.z * 2.0
- )
- return result
-
-
-def find_overlaps_brute_force(worldId: int, num_boxes_per_world: int, boxes, pos, rot):
- """
- Finds overlapping bounding boxes using the brute-force O(n^2) algorithm.
-
- Returns:
- List of tuples [(idx1, idx2)] where idx1 and idx2 are indices of overlapping boxes.
- """
- overlaps = []
-
- for i in range(num_boxes_per_world):
- box_a = boxes[i]
- box_a = transform_aabb(box_a, pos[worldId, i], rot[worldId, i])
-
- for j in range(i + 1, num_boxes_per_world):
- box_b = boxes[j]
- box_b = transform_aabb(box_b, pos[worldId, j], rot[worldId, j])
-
- # Use the overlap function to check for overlap
- if overlap(box_a, box_b):
- overlaps.append((i, j)) # Store indices of overlapping boxes
-
- return overlaps
-
-
-def find_overlaps_brute_force_batched(
- num_worlds: int, num_boxes_per_world: int, boxes, pos, rot
-):
- """
- Finds overlapping bounding boxes using the brute-force O(n^2) algorithm.
-
- Returns:
- List of tuples [(idx1, idx2)] where idx1 and idx2 are indices of overlapping boxes.
- """
-
- overlaps = []
-
- for worldId in range(num_worlds):
- overlaps.append(
- find_overlaps_brute_force(worldId, num_boxes_per_world, boxes, pos, rot)
- )
-
- # Show progress bar for brute force computation
- # from tqdm import tqdm
-
- # for worldId in tqdm(range(num_worlds), desc="Computing overlaps"):
- # overlaps.append(find_overlaps_brute_force(worldId, num_boxes_per_world, boxes))
-
- return overlaps
-
-
-class MultiIndexList:
- def __init__(self):
- self.data = {}
-
- def __setitem__(self, key, value):
- worldId, i = key
- if worldId not in self.data:
- self.data[worldId] = []
- if i >= len(self.data[worldId]):
- self.data[worldId].extend([None] * (i - len(self.data[worldId]) + 1))
- self.data[worldId][i] = value
-
- def __getitem__(self, key):
- worldId, i = key
- return self.data[worldId][i] # Raises KeyError if not found
-
-
-class BroadPhaseTest(parameterized.TestCase):
- def test_broad_phase(self):
- """Tests broad phase."""
- _, mjd, m, d = test_util.fixture("humanoid/humanoid.xml")
-
- # Create some test boxes
- num_worlds = d.nworld
- num_boxes_per_world = m.ngeom
- # print(f"num_worlds: {num_worlds}, num_boxes_per_world: {num_boxes_per_world}")
-
- # Parameters for random box generation
- sample_space_origin = wp.vec3(-10.0, -10.0, -10.0) # Origin of the bounding volume
- sample_space_size = wp.vec3(20.0, 20.0, 20.0) # Size of the bounding volume
- min_edge_length = 0.5 # Minimum edge length of random boxes
- max_edge_length = 5.0 # Maximum edge length of random boxes
-
- boxes_list = []
-
- # Set random seed for reproducibility
- import random
-
- random.seed(11)
-
- # Generate random boxes for each world
- for _ in range(num_boxes_per_world):
- # Generate random position within bounding volume
- pos_x = sample_space_origin.x + random.random() * sample_space_size.x
- pos_y = sample_space_origin.y + random.random() * sample_space_size.y
- pos_z = sample_space_origin.z + random.random() * sample_space_size.z
-
- # Generate random box dimensions between min and max edge lengths
- size_x = min_edge_length + random.random() * (max_edge_length - min_edge_length)
- size_y = min_edge_length + random.random() * (max_edge_length - min_edge_length)
- size_z = min_edge_length + random.random() * (max_edge_length - min_edge_length)
-
- # Create box with random position and size
- boxes_list.append(
- init_box(pos_x, pos_y, pos_z, pos_x + size_x, pos_y + size_y, pos_z + size_z)
- )
-
- # Generate random positions and orientations for each box
- pos = []
- rot = []
- for _ in range(num_worlds * num_boxes_per_world):
- # Random position within bounding volume
- pos_x = sample_space_origin.x + random.random() * sample_space_size.x
- pos_y = sample_space_origin.y + random.random() * sample_space_size.y
- pos_z = sample_space_origin.z + random.random() * sample_space_size.z
- pos.append(wp.vec3(pos_x, pos_y, pos_z))
- # pos.append(wp.vec3(0, 0, 0))
-
- # Random rotation matrix
- rx = random.random() * 6.28318530718 # 2*pi
- ry = random.random() * 6.28318530718
- rz = random.random() * 6.28318530718
- axis = wp.vec3(rx, ry, rz)
- axis = axis / wp.length(axis) # normalize axis
- angle = random.random() * 6.28318530718 # random angle between 0 and 2*pi
- rot.append(wp.quat_to_matrix(wp.quat_from_axis_angle(axis, angle)))
- # rot.append(wp.quat_to_matrix(wp.quat_from_axis_angle(wp.vec3(1, 0, 0), float(0))))
-
- # Convert pos and rot to MultiIndexList format
- pos_multi = MultiIndexList()
- rot_multi = MultiIndexList()
-
- # Populate the MultiIndexLists using pos and rot data
- idx = 0
- for world_idx in range(num_worlds):
- for i in range(num_boxes_per_world):
- pos_multi[world_idx, i] = pos[idx]
- rot_multi[world_idx, i] = rot[idx]
- idx += 1
-
- brute_force_overlaps = find_overlaps_brute_force_batched(
- num_worlds, num_boxes_per_world, boxes_list, pos_multi, rot_multi
- )
-
- # Test the broad phase by setting custom aabb data
- d.geom_aabb = wp.array(
- boxes_list, dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32)
- )
- d.geom_aabb = d.geom_aabb.reshape((num_boxes_per_world))
- d.geom_xpos = wp.array(pos, dtype=wp.vec3)
- d.geom_xpos = d.geom_xpos.reshape((num_worlds, num_boxes_per_world))
- d.geom_xmat = wp.array(rot, dtype=wp.mat33)
- d.geom_xmat = d.geom_xmat.reshape((num_worlds, num_boxes_per_world))
-
- mjx.broad_phase(m, d)
-
- result = d.broadphase_pairs
- result_count = d.result_count
-
- # Get numpy arrays from result and result_count
- result_np = result.numpy()
- result_count_np = result_count.numpy()
-
- # Iterate over each world
- for world_idx in range(num_worlds):
- # Get number of collisions for this world
- num_collisions = result_count_np[world_idx]
- print(f"Number of collisions for world {world_idx}: {num_collisions}")
-
- list = brute_force_overlaps[world_idx]
- assert len(list) == num_collisions, "Number of collisions does not match"
-
- # Print each collision pair
- for i in range(num_collisions):
- pair = result_np[world_idx][i]
-
- # Convert pair to tuple for comparison
- pair_tuple = (int(pair[0]), int(pair[1]))
- assert pair_tuple in list, (
- f"Collision pair {pair_tuple} not found in brute force results"
- )
-
-
-if __name__ == "__main__":
- wp.init()
- absltest.main()
diff --git a/mujoco/mjx/_src/collision_driver.py b/mujoco/mjx/_src/collision_driver.py
deleted file mode 100644
index 4567898d..00000000
--- a/mujoco/mjx/_src/collision_driver.py
+++ /dev/null
@@ -1,388 +0,0 @@
-# Copyright 2025 The Physics-Next Project Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import warp as wp
-
-from .types import Model
-from .types import Data
-
-BoxType = wp.types.matrix(shape=(2, 3), dtype=wp.float32)
-
-
-# TODO: Verify that this is corect
-@wp.func
-def transform_aabb(
- aabb: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
- pos: wp.vec3,
- rot: wp.mat33,
-) -> wp.types.matrix(shape=(2, 3), dtype=wp.float32):
- # Extract center and extents from AABB
- center = aabb[0]
- extents = aabb[1]
-
- absRot = rot
- absRot[0, 0] = wp.abs(rot[0, 0])
- absRot[0, 1] = wp.abs(rot[0, 1])
- absRot[0, 2] = wp.abs(rot[0, 2])
- absRot[1, 0] = wp.abs(rot[1, 0])
- absRot[1, 1] = wp.abs(rot[1, 1])
- absRot[1, 2] = wp.abs(rot[1, 2])
- absRot[2, 0] = wp.abs(rot[2, 0])
- absRot[2, 1] = wp.abs(rot[2, 1])
- absRot[2, 2] = wp.abs(rot[2, 2])
- world_extents = extents * absRot
-
- # Transform center
- new_center = rot @ center + pos
-
- # Return new AABB as matrix with center and full size
- result = BoxType()
- result[0] = wp.vec3(new_center.x, new_center.y, new_center.z)
- result[1] = wp.vec3(world_extents.x, world_extents.y, world_extents.z)
- return result
-
-
-@wp.func
-def overlap(
- a: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
- b: wp.types.matrix(shape=(2, 3), dtype=wp.float32),
-) -> bool:
- # Extract centers and sizes
- a_center = a[0]
- a_size = a[1]
- b_center = b[0]
- b_size = b[1]
-
- # Calculate min/max from center and size
- a_min = a_center - 0.5 * a_size
- a_max = a_center + 0.5 * a_size
- b_min = b_center - 0.5 * b_size
- b_max = b_center + 0.5 * b_size
-
- return not (
- a_min.x > b_max.x
- or b_min.x > a_max.x
- or a_min.y > b_max.y
- or b_min.y > a_max.y
- or a_min.z > b_max.z
- or b_min.z > a_max.z
- )
-
-
-@wp.kernel
-def broad_phase_project_boxes_onto_sweep_direction_kernel(
- boxes: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1),
- box_translations: wp.array(dtype=wp.vec3, ndim=2),
- box_rotations: wp.array(dtype=wp.mat33, ndim=2),
- data_start: wp.array(dtype=wp.float32, ndim=2),
- data_end: wp.array(dtype=wp.float32, ndim=2),
- data_indexer: wp.array(dtype=wp.int32, ndim=2),
- direction: wp.vec3,
- abs_dir: wp.vec3,
- result_count: wp.array(dtype=wp.int32, ndim=1),
-):
- worldId, i = wp.tid()
-
- box = boxes[i] # box is a vector6
- box = transform_aabb(box, box_translations[worldId, i], box_rotations[worldId, i])
- box_center = box[0]
- box_size = box[1]
- center = wp.dot(direction, box_center)
- d = wp.dot(box_size, abs_dir)
- f = center - d
-
- # Store results in the data arrays
- data_start[worldId, i] = f
- data_end[worldId, i] = center + d
- data_indexer[worldId, i] = i
-
- if i == 0:
- result_count[worldId] = 0 # Initialize result count to 0
-
-
-@wp.kernel
-def reorder_bounding_boxes_kernel(
- boxes: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1),
- box_translations: wp.array(dtype=wp.vec3, ndim=2),
- box_rotations: wp.array(dtype=wp.mat33, ndim=2),
- boxes_sorted: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=2),
- data_indexer: wp.array(dtype=wp.int32, ndim=2),
-):
- worldId, i = wp.tid()
-
- # Get the index from the data indexer
- mapped = data_indexer[worldId, i]
-
- # Get the box from the original boxes array
- box = boxes[mapped]
- box = transform_aabb(
- box, box_translations[worldId, mapped], box_rotations[worldId, mapped]
- )
-
- # Reorder the box into the sorted array
- boxes_sorted[worldId, i] = box
-
-
-@wp.func
-def find_first_greater_than(
- worldId: int,
- starts: wp.array(dtype=wp.float32, ndim=2),
- value: wp.float32,
- low: int,
- high: int,
-) -> int:
- while low < high:
- mid = (low + high) >> 1
- if starts[worldId, mid] > value:
- high = mid
- else:
- low = mid + 1
- return low
-
-
-@wp.kernel
-def broad_phase_sweep_and_prune_prepare_kernel(
- num_boxes_per_world: int,
- data_start: wp.array(dtype=wp.float32, ndim=2),
- data_end: wp.array(dtype=wp.float32, ndim=2),
- indexer: wp.array(dtype=wp.int32, ndim=2),
- cumulative_sum: wp.array(dtype=wp.int32, ndim=2),
-):
- worldId, i = wp.tid() # Get the thread ID
-
- # Get the index of the current bounding box
- idx1 = indexer[worldId, i]
-
- end = data_end[worldId, idx1]
- limit = find_first_greater_than(worldId, data_start, end, i + 1, num_boxes_per_world)
- limit = wp.min(num_boxes_per_world - 1, limit)
-
- # Calculate the range of boxes for the sweep and prune process
- count = limit - i
-
- # Store the cumulative sum for the current box
- cumulative_sum[worldId, i] = count
-
-
-@wp.func
-def find_right_most_index_int(
- starts: wp.array(dtype=wp.int32, ndim=1), value: wp.int32, low: int, high: int
-) -> int:
- while low < high:
- mid = (low + high) >> 1
- if starts[mid] > value:
- high = mid
- else:
- low = mid + 1
- return high
-
-
-@wp.func
-def find_indices(
- id: int, cumulative_sum: wp.array(dtype=wp.int32, ndim=1), length: int
-) -> wp.vec2i:
- # Perform binary search to find the right most index
- i = find_right_most_index_int(cumulative_sum, id, 0, length)
-
- # Get the baseId, and compute the offset and j
- if i > 0:
- base_id = cumulative_sum[i - 1]
- else:
- base_id = 0
- offset = id - base_id
- j = i + offset + 1
-
- return wp.vec2i(i, j)
-
-
-@wp.kernel
-def broad_phase_sweep_and_prune_kernel(
- num_threads: int,
- length: int,
- num_boxes_per_world: int,
- max_num_overlaps_per_world: int,
- cumulative_sum: wp.array(dtype=wp.int32, ndim=1),
- data_indexer: wp.array(dtype=wp.int32, ndim=2),
- data_result: wp.array(dtype=wp.vec2i, ndim=2),
- result_count: wp.array(dtype=wp.int32, ndim=1),
- boxes_sorted: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=2),
-):
- threadId = wp.tid() # Get thread ID
- if length > 0:
- total_num_work_packages = cumulative_sum[length - 1]
- else:
- total_num_work_packages = 0
-
- while threadId < total_num_work_packages:
- # Get indices for current and next box pair
- ij = find_indices(threadId, cumulative_sum, length)
- i = ij.x
- j = ij.y
-
- worldId = i // num_boxes_per_world
- i = i % num_boxes_per_world
-
- # world_id_j = j // num_boxes_per_world
- j = j % num_boxes_per_world
-
- # assert worldId == world_id_j, "Only boxes in the same world can be compared"
- # TODO: Remove print if debugging is done
- # if worldId != world_id_j:
- # print("Only boxes in the same world can be compared")
-
- idx1 = data_indexer[worldId, i]
-
- box1 = boxes_sorted[worldId, i]
-
- idx2 = data_indexer[worldId, j]
-
- # Check if the boxes overlap
- if idx1 != idx2 and overlap(box1, boxes_sorted[worldId, j]):
- pair = wp.vec2i(wp.min(idx1, idx2), wp.max(idx1, idx2))
-
- id = wp.atomic_add(result_count, worldId, 1)
-
- if id < max_num_overlaps_per_world:
- data_result[worldId, id] = pair
-
- threadId += num_threads
-
-
-def broad_phase(m: Model, d: Data) -> Data:
- """Broad phase collision detection."""
-
- # Directional vectors for sweep
- # TODO: Improve picking of direction
- direction = wp.vec3(0.5935, 0.7790, 0.1235)
- direction = wp.normalize(direction)
- abs_dir = wp.vec3(abs(direction.x), abs(direction.y), abs(direction.z))
-
- wp.launch(
- kernel=broad_phase_project_boxes_onto_sweep_direction_kernel,
- dim=(d.nworld, m.ngeom),
- inputs=[
- d.geom_aabb,
- d.geom_xpos,
- d.geom_xmat,
- d.data_start,
- d.data_end,
- d.data_indexer,
- direction,
- abs_dir,
- d.result_count,
- ],
- )
-
- segmented_sort_available = hasattr(wp.utils, "segmented_sort_pairs")
-
- if segmented_sort_available:
- # print("Using segmented sort")
- wp.utils.segmented_sort_pairs(
- d.data_start,
- d.data_indexer,
- m.ngeom * d.nworld,
- d.segment_indices,
- d.nworld,
- )
- else:
- # Sort each world's segment separately
- for world_id in range(d.nworld):
- start_idx = world_id * m.ngeom
-
- # Create temporary arrays for sorting
- temp_data_start = wp.zeros(
- m.ngeom * 2,
- dtype=d.data_start.dtype,
- )
- temp_data_indexer = wp.zeros(
- m.ngeom * 2,
- dtype=d.data_indexer.dtype,
- )
-
- # Copy data to temporary arrays
- wp.copy(
- temp_data_start,
- d.data_start,
- 0,
- start_idx,
- m.ngeom,
- )
- wp.copy(
- temp_data_indexer,
- d.data_indexer,
- 0,
- start_idx,
- m.ngeom,
- )
-
- # Sort the temporary arrays
- wp.utils.radix_sort_pairs(temp_data_start, temp_data_indexer, m.ngeom)
-
- # Copy sorted data back
- wp.copy(
- d.data_start,
- temp_data_start,
- start_idx,
- 0,
- m.ngeom,
- )
- wp.copy(
- d.data_indexer,
- temp_data_indexer,
- start_idx,
- 0,
- m.ngeom,
- )
-
- wp.launch(
- kernel=reorder_bounding_boxes_kernel,
- dim=(d.nworld, m.ngeom),
- inputs=[d.geom_aabb, d.geom_xpos, d.geom_xmat, d.boxes_sorted, d.data_indexer],
- )
-
- wp.launch(
- kernel=broad_phase_sweep_and_prune_prepare_kernel,
- dim=(d.nworld, m.ngeom),
- inputs=[
- m.ngeom,
- d.data_start,
- d.data_end,
- d.data_indexer,
- d.ranges,
- ],
- )
-
- # The scan (scan = cumulative sum, either inclusive or exclusive depending on the last argument) is used for load balancing among the threads
- wp.utils.array_scan(d.ranges.reshape(-1), d.cumulative_sum, True)
-
- # Estimate how many overlap checks need to be done - assumes each box has to be compared to 5 other boxes (and batched over all worlds)
- num_sweep_threads = 5 * d.nworld * m.ngeom
- wp.launch(
- kernel=broad_phase_sweep_and_prune_kernel,
- dim=num_sweep_threads,
- inputs=[
- num_sweep_threads,
- d.nworld * m.ngeom,
- m.ngeom,
- d.max_num_overlaps_per_world,
- d.cumulative_sum,
- d.data_indexer,
- d.broadphase_pairs,
- d.result_count,
- d.boxes_sorted,
- ],
- )
-
- return d
diff --git a/mujoco/mjx/_src/forward.py b/mujoco/mjx/_src/forward.py
deleted file mode 100644
index 875b7f06..00000000
--- a/mujoco/mjx/_src/forward.py
+++ /dev/null
@@ -1,366 +0,0 @@
-# Copyright 2025 The Physics-Next Project Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from typing import Optional
-
-import warp as wp
-import mujoco
-
-from . import constraint
-from . import math
-from . import passive
-from . import smooth
-from . import solver
-
-from .types import array2df, array3df
-from .types import Model
-from .types import Data
-from .types import MJ_MINVAL
-from .types import DisableBit
-from .types import JointType
-from .types import DynType
-from .support import xfrc_accumulate
-
-
-def _advance(
- m: Model,
- d: Data,
- act_dot: wp.array,
- qacc: wp.array,
- qvel: Optional[wp.array] = None,
-) -> Data:
- """Advance state and time given activation derivatives and acceleration."""
-
- # TODO(team): can we assume static timesteps?
-
- @wp.kernel
- def next_activation(
- m: Model,
- d: Data,
- act_dot_in: array2df,
- ):
- worldId, actid = wp.tid()
-
- # get the high/low range for each actuator state
- limited = m.actuator_actlimited[actid]
- range_low = wp.select(limited, -wp.inf, m.actuator_actrange[actid][0])
- range_high = wp.select(limited, wp.inf, m.actuator_actrange[actid][1])
-
- # get the actual actuation - skip if -1 (means stateless actuator)
- act_adr = m.actuator_actadr[actid]
- if act_adr == -1:
- return
-
- acts = d.act[worldId]
- acts_dot = act_dot_in[worldId]
-
- act = acts[act_adr]
- act_dot = acts_dot[act_adr]
-
- # check dynType
- dyn_type = m.actuator_dyntype[actid]
- dyn_prm = m.actuator_dynprm[actid][0]
-
- # advance the actuation
- if dyn_type == wp.static(DynType.FILTEREXACT.value):
- tau = wp.select(dyn_prm < MJ_MINVAL, dyn_prm, MJ_MINVAL)
- act = act + act_dot * tau * (1.0 - wp.exp(-m.opt.timestep / tau))
- else:
- act = act + act_dot * m.opt.timestep
-
- # apply limits
- wp.clamp(act, range_low, range_high)
-
- acts[act_adr] = act
-
- @wp.kernel
- def advance_velocities(m: Model, d: Data, qacc: array2df):
- worldId, tid = wp.tid()
- d.qvel[worldId, tid] = d.qvel[worldId, tid] + qacc[worldId, tid] * m.opt.timestep
-
- @wp.kernel
- def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df):
- worldId, jntid = wp.tid()
-
- jnt_type = m.jnt_type[jntid]
- qpos_adr = m.jnt_qposadr[jntid]
- dof_adr = m.jnt_dofadr[jntid]
- qpos = d.qpos[worldId]
- qvel = qvel_in[worldId]
-
- if jnt_type == wp.static(JointType.FREE.value):
- qpos_pos = wp.vec3(qpos[qpos_adr], qpos[qpos_adr + 1], qpos[qpos_adr + 2])
- qvel_lin = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2])
-
- qpos_new = qpos_pos + m.opt.timestep * qvel_lin
-
- qpos_quat = wp.quat(
- qpos[qpos_adr + 3],
- qpos[qpos_adr + 4],
- qpos[qpos_adr + 5],
- qpos[qpos_adr + 6],
- )
- qvel_ang = wp.vec3(qvel[dof_adr + 3], qvel[dof_adr + 4], qvel[dof_adr + 5])
-
- qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang, m.opt.timestep)
-
- qpos[qpos_adr] = qpos_new[0]
- qpos[qpos_adr + 1] = qpos_new[1]
- qpos[qpos_adr + 2] = qpos_new[2]
- qpos[qpos_adr + 3] = qpos_quat_new[0]
- qpos[qpos_adr + 4] = qpos_quat_new[1]
- qpos[qpos_adr + 5] = qpos_quat_new[2]
- qpos[qpos_adr + 6] = qpos_quat_new[3]
-
- elif jnt_type == wp.static(JointType.BALL.value): # ball joint
- qpos_quat = wp.quat(
- qpos[qpos_adr],
- qpos[qpos_adr + 1],
- qpos[qpos_adr + 2],
- qpos[qpos_adr + 3],
- )
- qvel_ang = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2])
-
- qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang, m.opt.timestep)
-
- qpos[qpos_adr] = qpos_quat_new[0]
- qpos[qpos_adr + 1] = qpos_quat_new[1]
- qpos[qpos_adr + 2] = qpos_quat_new[2]
- qpos[qpos_adr + 3] = qpos_quat_new[3]
-
- else: # if jnt_type in (JointType.HINGE, JointType.SLIDE):
- qpos[qpos_adr] = qpos[qpos_adr] + m.opt.timestep * qvel[dof_adr]
-
- # skip if no stateful actuators.
- if m.na:
- wp.launch(next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot])
-
- wp.launch(advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc])
-
- # advance positions with qvel if given, d.qvel otherwise (semi-implicit)
- if qvel is not None:
- qvel_in = qvel
- else:
- qvel_in = d.qvel
-
- wp.launch(integrate_joint_positions, dim=(d.nworld, m.njnt), inputs=[m, d, qvel_in])
-
- d.time = d.time + m.opt.timestep
- return d
-
-
-def euler(m: Model, d: Data) -> Data:
- """Euler integrator, semi-implicit in velocity."""
- # integrate damping implicitly
-
- def add_damping_sum_qfrc(m: Model, d: Data, is_sparse: bool):
- @wp.kernel
- def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data):
- worldId, tid = wp.tid()
-
- dof_Madr = m.dof_Madr[tid]
- d.qM_integration[worldId, 0, dof_Madr] += m.opt.timestep * m.dof_damping[dof_Madr]
-
- d.qfrc_integration[worldId, tid] = (
- d.qfrc_smooth[worldId, tid] + d.qfrc_constraint[worldId, tid]
- )
-
- @wp.kernel
- def add_damping_sum_qfrc_kernel_dense(m: Model, d: Data):
- worldid, i, j = wp.tid()
-
- damping = wp.select(i == j, 0.0, m.opt.timestep * m.dof_damping[i])
- d.qM_integration[worldid, i, j] = d.qM[worldid, i, j] + damping
-
- if i == 0:
- d.qfrc_integration[worldid, j] = (
- d.qfrc_smooth[worldid, j] + d.qfrc_constraint[worldid, j]
- )
-
- if is_sparse:
- wp.copy(d.qM_integration, d.qM)
- wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d])
- else:
- wp.launch(
- add_damping_sum_qfrc_kernel_dense, dim=(d.nworld, m.nv, m.nv), inputs=[m, d]
- )
-
- if not m.opt.disableflags & DisableBit.EULERDAMP.value:
- add_damping_sum_qfrc(m, d, m.opt.is_sparse)
- smooth.factor_i(m, d, d.qM_integration, d.qLD_integration, d.qLDiagInv_integration)
- smooth.solve_LD(
- m,
- d,
- d.qLD_integration,
- d.qLDiagInv_integration,
- d.qacc_integration,
- d.qfrc_integration,
- )
- return _advance(m, d, d.act_dot, d.qacc_integration)
-
- return _advance(m, d, d.act_dot, d.qacc)
-
-
-def fwd_position(m: Model, d: Data):
- """Position-dependent computations."""
-
- smooth.kinematics(m, d)
- smooth.com_pos(m, d)
- # TODO(team): smooth.camlight
- # TODO(team): smooth.tendon
- smooth.crb(m, d)
- smooth.factor_m(m, d)
- # TODO(team): collision_driver.collision
- constraint.make_constraint(m, d)
- smooth.transmission(m, d)
-
-
-def fwd_velocity(m: Model, d: Data):
- """Velocity-dependent computations."""
-
- # TODO(team): tile operations?
- d.actuator_velocity.zero_()
-
- @wp.kernel
- def _actuator_velocity(d: Data):
- worldid, actid, dofid = wp.tid()
- moment = d.actuator_moment[worldid, actid]
- qvel = d.qvel[worldid]
- wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid])
-
- wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d])
-
- smooth.com_vel(m, d)
- passive.passive(m, d)
- smooth.rne(m, d)
-
-
-def fwd_actuation(m: Model, d: Data):
- """Actuation-dependent computations."""
- if not m.nu:
- return
-
- # TODO support stateful actuators
-
- @wp.kernel
- def _force(
- m: Model,
- ctrl: array2df,
- # outputs
- force: array2df,
- ):
- worldid, dofid = wp.tid()
- gain = m.actuator_gainprm[dofid, 0]
- bias = m.actuator_biasprm[dofid, 0]
- # TODO support gain types other than FIXED
- c = ctrl[worldid, dofid]
- if m.actuator_ctrllimited[dofid]:
- r = m.actuator_ctrlrange[dofid]
- c = wp.clamp(c, r[0], r[1])
- f = gain * c + bias
- if m.actuator_forcelimited[dofid]:
- r = m.actuator_forcerange[dofid]
- f = wp.clamp(f, r[0], r[1])
- force[worldid, dofid] = f
-
- wp.launch(
- _force, dim=[d.nworld, m.nu], inputs=[m, d.ctrl], outputs=[d.actuator_force]
- )
-
- @wp.kernel
- def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df):
- worldid, vid = wp.tid()
-
- s = float(0.0)
- for uid in range(m.nu):
- # TODO consider using Tile API or transpose moment for better access pattern
- s += moment[worldid, uid, vid] * force[worldid, uid]
- jntid = m.dof_jntid[vid]
- if m.jnt_actfrclimited[jntid]:
- r = m.jnt_actfrcrange[jntid]
- s = wp.clamp(s, r[0], r[1])
- qfrc[worldid, vid] = s
-
- wp.launch(
- _qfrc,
- dim=(d.nworld, m.nv),
- inputs=[m, d.actuator_moment, d.actuator_force],
- outputs=[d.qfrc_actuator],
- )
-
- # TODO actuator-level gravity compensation, skip if added as passive force
-
- return d
-
-
-def fwd_acceleration(m: Model, d: Data):
- """Add up all non-constraint forces, compute qacc_smooth."""
-
- qfrc_applied = d.qfrc_applied
- qfrc_accumulated = xfrc_accumulate(m, d)
-
- @wp.kernel
- def _qfrc_smooth(
- d: Data,
- qfrc_applied: wp.array(ndim=2, dtype=wp.float32),
- qfrc_accumulated: wp.array(ndim=2, dtype=wp.float32),
- ):
- worldid, dofid = wp.tid()
- d.qfrc_smooth[worldid, dofid] = (
- d.qfrc_passive[worldid, dofid]
- - d.qfrc_bias[worldid, dofid]
- + d.qfrc_actuator[worldid, dofid]
- + qfrc_applied[worldid, dofid]
- + qfrc_accumulated[worldid, dofid]
- )
-
- wp.launch(
- _qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d, qfrc_applied, qfrc_accumulated]
- )
-
- smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth)
-
-
-def forward(m: Model, d: Data):
- """Forward dynamics."""
-
- fwd_position(m, d)
- # TODO(team): sensor.sensor_pos
- fwd_velocity(m, d)
- # TODO(team): sensor.sensor_vel
- fwd_actuation(m, d)
- fwd_acceleration(m, d)
- # TODO(team): sensor.sensor_acc
-
- if d.njmax == 0:
- wp.copy(d.qacc, d.qacc_smooth)
- else:
- solver.solve(m, d)
-
-
-def step(m: Model, d: Data):
- """Advance simulation."""
- forward(m, d)
-
- if m.opt.integrator == mujoco.mjtIntegrator.mjINT_EULER:
- euler(m, d)
- elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_RK4:
- # TODO(team): rungekutta4
- raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.")
- elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST:
- # TODO(team): implicit
- raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.")
- else:
- raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.")
diff --git a/mujoco/mjx/_src/solver.py b/mujoco/mjx/_src/solver.py
deleted file mode 100644
index ee0eea28..00000000
--- a/mujoco/mjx/_src/solver.py
+++ /dev/null
@@ -1,749 +0,0 @@
-import warp as wp
-import mujoco
-from . import smooth
-from . import support
-from . import types
-
-
-@wp.struct
-class Context:
- Jaref: wp.array(dtype=wp.float32, ndim=1)
- Ma: wp.array(dtype=wp.float32, ndim=2)
- grad: wp.array(dtype=wp.float32, ndim=2)
- grad_dot: wp.array(dtype=wp.float32, ndim=1)
- Mgrad: wp.array(dtype=wp.float32, ndim=2)
- search: wp.array(dtype=wp.float32, ndim=2)
- search_dot: wp.array(dtype=wp.float32, ndim=1)
- gauss: wp.array(dtype=wp.float32, ndim=1)
- cost: wp.array(dtype=wp.float32, ndim=1)
- prev_cost: wp.array(dtype=wp.float32, ndim=1)
- solver_niter: wp.array(dtype=wp.int32, ndim=1)
- active: wp.array(dtype=wp.int32, ndim=1)
- gtol: wp.array(dtype=wp.float32, ndim=1)
- mv: wp.array(dtype=wp.float32, ndim=2)
- jv: wp.array(dtype=wp.float32, ndim=1)
- quad: wp.array(dtype=wp.vec3f, ndim=1)
- quad_gauss: wp.array(dtype=wp.vec3f, ndim=1)
- quad_total: wp.array(dtype=wp.vec3f, ndim=1)
- h: wp.array(dtype=wp.float32, ndim=3)
- alpha: wp.array(dtype=wp.float32, ndim=1)
- prev_grad: wp.array(dtype=wp.float32, ndim=2)
- prev_Mgrad: wp.array(dtype=wp.float32, ndim=2)
- beta: wp.array(dtype=wp.float32, ndim=1)
- beta_num: wp.array(dtype=wp.float32, ndim=1)
- beta_den: wp.array(dtype=wp.float32, ndim=1)
- done: wp.array(dtype=wp.int32, ndim=1)
-
-
-def _context(m: types.Model, d: types.Data) -> Context:
- ctx = Context()
- ctx.Jaref = wp.empty(shape=(d.njmax,), dtype=wp.float32)
- ctx.Ma = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32)
- ctx.grad = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32)
- ctx.grad_dot = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.Mgrad = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32)
- ctx.search = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32)
- ctx.search_dot = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.gauss = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.cost = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.prev_cost = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.solver_niter = wp.empty(shape=(d.nworld,), dtype=wp.int32)
- ctx.active = wp.empty(shape=(d.njmax,), dtype=wp.int32)
- ctx.gtol = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.mv = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32)
- ctx.jv = wp.empty(shape=(d.njmax,), dtype=wp.float32)
- ctx.quad = wp.empty(shape=(d.njmax,), dtype=wp.vec3f)
- ctx.quad_gauss = wp.empty(shape=(d.nworld,), dtype=wp.vec3f)
- ctx.quad_total = wp.empty(shape=(d.nworld,), dtype=wp.vec3f)
- ctx.h = wp.empty(shape=(d.nworld, m.nv, m.nv), dtype=wp.float32)
- ctx.alpha = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.prev_grad = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32)
- ctx.prev_Mgrad = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32)
- ctx.beta = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.beta_num = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.beta_den = wp.empty(shape=(d.nworld,), dtype=wp.float32)
- ctx.done = wp.empty(shape=(d.nworld,), dtype=wp.int32)
-
- return ctx
-
-
-def _create_context(ctx: Context, m: types.Model, d: types.Data, grad: bool = True):
- # jaref = d.efc_J @ d.qacc - d.efc_aref
- ctx.Jaref.zero_()
-
- @wp.kernel
- def _jaref(ctx: Context, m: types.Model, d: types.Data):
- efcid, dofid = wp.tid()
-
- if efcid >= d.nefc_total[0]:
- return
-
- worldid = d.efc_worldid[efcid]
- wp.atomic_add(
- ctx.Jaref,
- efcid,
- d.efc_J[efcid, dofid] * d.qacc[worldid, dofid] - d.efc_aref[efcid] / float(m.nv),
- )
-
- wp.launch(_jaref, dim=(d.njmax, m.nv), inputs=[ctx, m, d])
-
- # Ma = qM @ qacc
- support.mul_m(m, d, ctx.Ma, d.qacc)
-
- ctx.cost.fill_(wp.inf)
- ctx.solver_niter.zero_()
- ctx.done.zero_()
-
- _update_constraint(m, d, ctx)
- if grad:
- _update_gradient(m, d, ctx)
-
- # search = -Mgrad
- ctx.search_dot.zero_()
-
- @wp.kernel
- def _search(ctx: Context):
- worldid, dofid = wp.tid()
- search = -1.0 * ctx.Mgrad[worldid, dofid]
- ctx.search[worldid, dofid] = search
- wp.atomic_add(ctx.search_dot, worldid, search * search)
-
- wp.launch(_search, dim=(d.nworld, m.nv), inputs=[ctx])
-
-
-@wp.struct
-class LSPoint:
- alpha: wp.array(dtype=wp.float32, ndim=1)
- cost: wp.array(dtype=wp.float32, ndim=1)
- deriv_0: wp.array(dtype=wp.float32, ndim=1)
- deriv_1: wp.array(dtype=wp.float32, ndim=1)
-
-
-def _lspoint(d: types.Data) -> LSPoint:
- ls_pnt = LSPoint()
- ls_pnt.alpha = wp.empty(shape=(d.nworld), dtype=wp.float32)
- ls_pnt.cost = wp.empty(shape=(d.nworld), dtype=wp.float32)
- ls_pnt.deriv_0 = wp.empty(shape=(d.nworld), dtype=wp.float32)
- ls_pnt.deriv_1 = wp.empty(shape=(d.nworld), dtype=wp.float32)
-
- return ls_pnt
-
-
-def _create_lspoint(ls_pnt: LSPoint, m: types.Model, d: types.Data, ctx: Context):
- wp.copy(ctx.quad_total, ctx.quad_gauss)
-
- @wp.kernel
- def _quad(ls_pnt: LSPoint, ctx: Context, d: types.Data):
- efcid = wp.tid()
-
- if efcid >= d.nefc_total[0]:
- return
-
- worldid = d.efc_worldid[efcid]
- x = ctx.Jaref[efcid] + ls_pnt.alpha[worldid] * ctx.jv[efcid]
- # TODO(team): active and conditionally active constraints
- if x < 0.0:
- wp.atomic_add(ctx.quad_total, worldid, ctx.quad[efcid])
-
- wp.launch(_quad, dim=(d.njmax,), inputs=[ls_pnt, ctx, d])
-
- @wp.kernel
- def _cost_deriv01(ls_pnt: LSPoint, ctx: Context):
- worldid = wp.tid()
- alpha = ls_pnt.alpha[worldid]
- alpha_sq = alpha * alpha
- quad_total0 = ctx.quad_total[worldid][0]
- quad_total1 = ctx.quad_total[worldid][1]
- quad_total2 = ctx.quad_total[worldid][2]
-
- ls_pnt.cost[worldid] = alpha_sq * quad_total2 + alpha * quad_total1 + quad_total0
- ls_pnt.deriv_0[worldid] = 2.0 * alpha * quad_total2 + quad_total1
- ls_pnt.deriv_1[worldid] = 2.0 * quad_total2 + float(quad_total2 == 0.0)
-
- wp.launch(_cost_deriv01, dim=(d.nworld,), inputs=[ls_pnt, ctx])
-
-
-@wp.struct
-class LSContext:
- p0: LSPoint
- lo: LSPoint
- lo_next: LSPoint
- hi: LSPoint
- hi_next: LSPoint
- mid: LSPoint
- swap: wp.array(ndim=1, dtype=wp.int32)
- ls_iter: wp.array(ndim=1, dtype=wp.int32)
- done: wp.array(ndim=1, dtype=wp.int32)
-
-
-def _create_lscontext(m: types.Model, d: types.Data, ctx: Context) -> LSContext:
- ls_ctx = LSContext()
-
- ls_ctx.p0 = _lspoint(d)
- ls_ctx.lo = _lspoint(d)
- ls_ctx.lo_next = _lspoint(d)
- ls_ctx.hi = _lspoint(d)
- ls_ctx.hi_next = _lspoint(d)
- ls_ctx.mid = _lspoint(d)
-
- ls_ctx.swap = wp.empty(shape=(d.nworld), dtype=wp.int32)
- ls_ctx.ls_iter = wp.empty(shape=(d.nworld), dtype=wp.int32)
- ls_ctx.done = wp.zeros((d.nworld), dtype=wp.int32)
-
- return ls_ctx
-
-
-def _update_constraint(m: types.Model, d: types.Data, ctx: Context):
- wp.copy(ctx.prev_cost, ctx.cost)
- ctx.cost.zero_()
-
- @wp.kernel
- def _efc_kernel(ctx: Context, d: types.Data):
- efcid = wp.tid()
-
- if efcid >= d.nefc_total[0]:
- return
-
- worldid = d.efc_worldid[efcid]
- Jaref = ctx.Jaref[efcid]
- efc_D = d.efc_D[efcid]
-
- # TODO(team): active and conditionally active constraints
- active = int(Jaref < 0.0)
- ctx.active[efcid] = active
-
- # efc_force = -efc_D * Jaref * active
- d.efc_force[efcid] = -1.0 * efc_D * Jaref * float(active)
-
- # cost = 0.5 * sum(efc_D * Jaref * Jaref * active)
- wp.atomic_add(ctx.cost, worldid, 0.5 * efc_D * Jaref * Jaref * float(active))
-
- wp.launch(_efc_kernel, dim=(d.njmax,), inputs=[ctx, d])
-
- # qfrc_constraint = efc_J.T @ efc_force
- d.qfrc_constraint.zero_()
-
- @wp.kernel
- def _qfrc_constraint(d: types.Data):
- dofid, efcid = wp.tid()
-
- if efcid >= d.nefc_total[0]:
- return
-
- worldid = d.efc_worldid[efcid]
- wp.atomic_add(
- d.qfrc_constraint[worldid],
- dofid,
- d.efc_J[efcid, dofid] * d.efc_force[efcid],
- )
-
- wp.launch(_qfrc_constraint, dim=(m.nv, d.njmax), inputs=[d])
-
- # gauss = 0.5 * (Ma - qfrc_smooth).T @ (qacc - qacc_smooth)
- ctx.gauss.zero_()
-
- @wp.kernel
- def _gauss(ctx: Context, d: types.Data):
- worldid, dofid = wp.tid()
- gauss_cost = (
- 0.5
- * (ctx.Ma[worldid, dofid] - d.qfrc_smooth[worldid, dofid])
- * (d.qacc[worldid, dofid] - d.qacc_smooth[worldid, dofid])
- )
- wp.atomic_add(ctx.gauss, worldid, gauss_cost)
- wp.atomic_add(ctx.cost, worldid, gauss_cost)
-
- wp.launch(_gauss, dim=(d.nworld, m.nv), inputs=[ctx, d])
-
-
-def _update_gradient(m: types.Model, d: types.Data, ctx: Context):
- # grad = Ma - qfrc_smooth - qfrc_constraint
- ctx.grad_dot.zero_()
-
- @wp.kernel
- def _grad(ctx: Context, d: types.Data):
- worldid, dofid = wp.tid()
- grad = (
- ctx.Ma[worldid, dofid]
- - d.qfrc_smooth[worldid, dofid]
- - d.qfrc_constraint[worldid, dofid]
- )
- ctx.grad[worldid, dofid] = grad
- wp.atomic_add(ctx.grad_dot, worldid, grad * grad)
-
- wp.launch(_grad, dim=(d.nworld, m.nv), inputs=[ctx, d])
-
- if m.opt.solver == 1: # CG
- smooth.solve_m(m, d, ctx.Mgrad, ctx.grad)
- elif m.opt.solver == 2: # Newton
- # TODO(team): sparse version
- # h = qM + (efc_J.T * efc_D * active) @ efc_J
- @wp.kernel
- def _copy_lower_triangle(m: types.Model, d: types.Data, ctx: Context):
- worldid, elementid = wp.tid()
- rowid = m.dof_tri_row[elementid]
- colid = m.dof_tri_col[elementid]
- ctx.h[worldid, rowid, colid] = d.qM[worldid, rowid, colid]
-
- wp.launch(
- _copy_lower_triangle, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d, ctx]
- )
-
- @wp.kernel
- def _JTDAJ(ctx: Context, m: types.Model, d: types.Data):
- efcid, elementid = wp.tid()
- dofi = m.dof_tri_row[elementid]
- dofj = m.dof_tri_col[elementid]
-
- if efcid >= d.nefc_total[0]:
- return
-
- efc_D = d.efc_D[efcid]
- active = ctx.active[efcid]
- if efc_D == 0.0 or active == 0:
- return
-
- worldid = d.efc_worldid[efcid]
- wp.atomic_add(
- ctx.h[worldid, dofi],
- dofj,
- d.efc_J[efcid, dofi] * d.efc_J[efcid, dofj] * efc_D * float(active),
- )
-
- wp.launch(_JTDAJ, dim=(d.njmax, m.dof_tri_row.size), inputs=[ctx, m, d])
-
- TILE = m.nv
-
- @wp.kernel
- def _cholesky(ctx: Context):
- worldid = wp.tid()
- mat_tile = wp.tile_load(ctx.h[worldid], shape=(TILE, TILE))
- fact_tile = wp.tile_cholesky(mat_tile)
- input_tile = wp.tile_load(ctx.grad[worldid], shape=TILE)
- output_tile = wp.tile_cholesky_solve(fact_tile, input_tile)
- wp.tile_store(ctx.Mgrad[worldid], output_tile)
-
- wp.launch_tiled(_cholesky, dim=(d.nworld,), inputs=[ctx], block_dim=32)
-
-
-@wp.func
-def _rescale(m: types.Model, value: float) -> float:
- return value / (m.stat.meaninertia * float(wp.max(1, m.nv)))
-
-
-@wp.func
-def _in_bracket(x: float, y: float) -> bool:
- return (x < y) and (y < 0.0) or (x > y) and (y > 0.0)
-
-
-def _linesearch(m: types.Model, d: types.Data, ctx: Context):
- @wp.kernel
- def _gtol(ctx: Context, m: types.Model):
- worldid = wp.tid()
- smag = (
- wp.math.sqrt(ctx.search_dot[worldid])
- * m.stat.meaninertia
- * float(wp.max(1, m.nv))
- )
- ctx.gtol[worldid] = m.opt.tolerance * m.opt.ls_tolerance * smag
-
- wp.launch(_gtol, dim=(d.nworld,), inputs=[ctx, m])
-
- # mv = qM @ search
- support.mul_m(m, d, ctx.mv, ctx.search)
-
- # jv = efc_J @ search
- ctx.jv.zero_()
-
- @wp.kernel
- def _jv(ctx: Context, d: types.Data):
- efcid, dofid = wp.tid()
-
- if efcid >= d.nefc_total[0]:
- return
-
- worldid = d.efc_worldid[efcid]
- wp.atomic_add(
- ctx.jv,
- efcid,
- d.efc_J[efcid, dofid] * ctx.search[worldid, dofid],
- )
-
- wp.launch(_jv, dim=(d.njmax, m.nv), inputs=[ctx, d])
-
- # prepare quadratics
- # quad_gauss = [gauss, search.T @ Ma - search.T @ qfrc_smooth, 0.5 * search.T @ mv]
- ctx.quad_gauss.zero_()
-
- @wp.kernel
- def _quad_gauss(ctx: Context, m: types.Model, d: types.Data):
- worldid, dofid = wp.tid()
- search = ctx.search[worldid, dofid]
- quad_gauss = wp.vec3(
- ctx.gauss[worldid] / float(m.nv),
- search * (ctx.Ma[worldid, dofid] - d.qfrc_smooth[worldid, dofid]),
- 0.5 * search * ctx.mv[worldid, dofid],
- )
- wp.atomic_add(ctx.quad_gauss, worldid, quad_gauss)
-
- wp.launch(_quad_gauss, dim=(d.nworld, m.nv), inputs=[ctx, m, d])
-
- # quad = [0.5 * Jaref * Jaref * efc_D, jv * Jaref * efc_D, 0.5 * jv * jv * efc_D]
- @wp.kernel
- def _quad(ctx: Context, d: types.Data):
- efcid = wp.tid()
-
- if efcid >= d.nefc_total[0]:
- return
-
- Jaref = ctx.Jaref[efcid]
- jv = ctx.jv[efcid]
- efc_D = d.efc_D[efcid]
- ctx.quad[efcid][0] = 0.5 * Jaref * Jaref * efc_D
- ctx.quad[efcid][1] = jv * Jaref * efc_D
- ctx.quad[efcid][2] = 0.5 * jv * jv * efc_D
-
- wp.launch(_quad, dim=(d.njmax), inputs=[ctx, d])
-
- # initialize interval
- ls_ctx = _create_lscontext(m, d, ctx)
-
- ls_ctx.p0.alpha.zero_()
- _create_lspoint(ls_ctx.p0, m, d, ctx)
-
- @wp.kernel
- def _lo_alpha(lo: LSPoint, p0: LSPoint, ctx: Context):
- worldid = wp.tid()
- lo.alpha[worldid] = p0.alpha[worldid] - p0.deriv_0[worldid] / p0.deriv_1[worldid]
-
- wp.launch(_lo_alpha, dim=(d.nworld,), inputs=[ls_ctx.lo, ls_ctx.p0, ctx])
-
- _create_lspoint(ls_ctx.lo, m, d, ctx)
-
- @wp.kernel
- def _tree_map(ls_ctx: LSContext):
- worldid = wp.tid()
-
- lesser = float(ls_ctx.lo.deriv_0[worldid] < ls_ctx.p0.deriv_0[worldid])
- not_lesser = 1.0 - lesser
-
- ls_ctx.hi.alpha[worldid] = (
- lesser * ls_ctx.p0.alpha[worldid] + not_lesser * ls_ctx.lo.alpha[worldid]
- )
- ls_ctx.hi.cost[worldid] = (
- lesser * ls_ctx.p0.cost[worldid] + not_lesser * ls_ctx.lo.cost[worldid]
- )
- ls_ctx.hi.deriv_0[worldid] = (
- lesser * ls_ctx.p0.deriv_0[worldid] + not_lesser * ls_ctx.lo.deriv_0[worldid]
- )
- ls_ctx.hi.deriv_1[worldid] = (
- lesser * ls_ctx.p0.deriv_1[worldid] + not_lesser * ls_ctx.lo.deriv_1[worldid]
- )
-
- ls_ctx.lo.alpha[worldid] = (
- lesser * ls_ctx.lo.alpha[worldid] + not_lesser * ls_ctx.p0.alpha[worldid]
- )
- ls_ctx.lo.cost[worldid] = (
- lesser * ls_ctx.lo.cost[worldid] + not_lesser * ls_ctx.p0.cost[worldid]
- )
- ls_ctx.lo.deriv_0[worldid] = (
- lesser * ls_ctx.lo.deriv_0[worldid] + not_lesser * ls_ctx.p0.deriv_0[worldid]
- )
- ls_ctx.lo.deriv_1[worldid] = (
- lesser * ls_ctx.lo.deriv_1[worldid] + not_lesser * ls_ctx.p0.deriv_1[worldid]
- )
-
- wp.launch(_tree_map, dim=(d.nworld,), inputs=[ls_ctx])
-
- ls_ctx.swap.fill_(1)
- ls_ctx.ls_iter.fill_(0)
-
- for i in range(m.opt.ls_iterations):
-
- @wp.kernel
- def _alpha_lo_next_hi_next_mid(ls_ctx: LSContext):
- worldid = wp.tid()
- ls_ctx.lo_next.alpha[worldid] = (
- ls_ctx.lo.alpha[worldid]
- - ls_ctx.lo.deriv_0[worldid] / ls_ctx.lo.deriv_1[worldid]
- )
- ls_ctx.hi_next.alpha[worldid] = (
- ls_ctx.hi.alpha[worldid]
- - ls_ctx.hi.deriv_0[worldid] / ls_ctx.hi.deriv_1[worldid]
- )
- ls_ctx.mid.alpha[worldid] = 0.5 * (
- ls_ctx.lo.alpha[worldid] + ls_ctx.hi.alpha[worldid]
- )
-
- wp.launch(_alpha_lo_next_hi_next_mid, dim=(d.nworld,), inputs=[ls_ctx])
-
- _create_lspoint(ls_ctx.lo_next, m, d, ctx)
- _create_lspoint(ls_ctx.hi_next, m, d, ctx)
- _create_lspoint(ls_ctx.mid, m, d, ctx)
-
- @wp.kernel
- def _swap_lo_hi(ls_ctx: LSContext):
- worldid = wp.tid()
-
- ls_ctx.ls_iter[worldid] += 1
-
- lo_alpha = ls_ctx.lo.alpha[worldid]
- lo_cost = ls_ctx.lo.cost[worldid]
- lo_deriv_0 = ls_ctx.lo.deriv_0[worldid]
- lo_deriv_1 = ls_ctx.lo.deriv_1[worldid]
-
- lo_next_alpha = ls_ctx.lo_next.alpha[worldid]
- lo_next_cost = ls_ctx.lo_next.cost[worldid]
- lo_next_deriv_0 = ls_ctx.lo_next.deriv_0[worldid]
- lo_next_deriv_1 = ls_ctx.lo_next.deriv_1[worldid]
-
- hi_alpha = ls_ctx.hi.alpha[worldid]
- hi_cost = ls_ctx.hi.cost[worldid]
- hi_deriv_0 = ls_ctx.hi.deriv_0[worldid]
- hi_deriv_1 = ls_ctx.hi.deriv_1[worldid]
-
- hi_next_alpha = ls_ctx.hi_next.alpha[worldid]
- hi_next_cost = ls_ctx.hi_next.cost[worldid]
- hi_next_deriv_0 = ls_ctx.hi_next.deriv_0[worldid]
- hi_next_deriv_1 = ls_ctx.hi_next.deriv_1[worldid]
-
- mid_alpha = ls_ctx.mid.alpha[worldid]
- mid_cost = ls_ctx.mid.cost[worldid]
- mid_deriv_0 = ls_ctx.mid.deriv_0[worldid]
- mid_deriv_1 = ls_ctx.mid.deriv_1[worldid]
-
- swap_lo_next = _in_bracket(lo_deriv_0, lo_next_deriv_0)
- lo_alpha = (
- float(swap_lo_next) * lo_next_alpha + (1.0 - float(swap_lo_next)) * lo_alpha
- )
- lo_cost = (
- float(swap_lo_next) * lo_next_cost + (1.0 - float(swap_lo_next)) * lo_cost
- )
- lo_deriv_0 = (
- float(swap_lo_next) * lo_next_deriv_0 + (1.0 - float(swap_lo_next)) * lo_deriv_0
- )
- lo_deriv_1 = (
- float(swap_lo_next) * lo_next_deriv_1 + (1.0 - float(swap_lo_next)) * lo_deriv_1
- )
-
- swap_lo_mid = _in_bracket(lo_deriv_0, mid_deriv_0)
- lo_alpha = float(swap_lo_mid) * mid_alpha + (1.0 - float(swap_lo_mid)) * lo_alpha
- lo_cost = float(swap_lo_mid) * mid_cost + (1.0 - float(swap_lo_mid)) * lo_cost
- lo_deriv_0 = (
- float(swap_lo_mid) * mid_deriv_0 + (1.0 - float(swap_lo_mid)) * lo_deriv_0
- )
- lo_deriv_1 = (
- float(swap_lo_mid) * mid_deriv_1 + (1.0 - float(swap_lo_mid)) * lo_deriv_1
- )
-
- swap_lo_hi_next = _in_bracket(lo_deriv_0, hi_next_deriv_0)
- lo_alpha = (
- float(swap_lo_hi_next) * hi_next_alpha
- + (1.0 - float(swap_lo_hi_next)) * lo_alpha
- )
- lo_cost = (
- float(swap_lo_hi_next) * hi_next_cost + (1.0 - float(swap_lo_hi_next)) * lo_cost
- )
- lo_deriv_0 = (
- float(swap_lo_hi_next) * hi_next_deriv_0
- + (1.0 - float(swap_lo_hi_next)) * lo_deriv_0
- )
- lo_deriv_1 = (
- float(swap_lo_hi_next) * hi_next_deriv_1
- + (1.0 - float(swap_lo_hi_next)) * lo_deriv_1
- )
-
- swap_hi_next = _in_bracket(hi_deriv_0, hi_next_deriv_0)
- hi_alpha = (
- float(swap_hi_next) * hi_next_alpha + (1.0 - float(swap_hi_next)) * hi_alpha
- )
- hi_cost = (
- float(swap_hi_next) * hi_next_cost + (1.0 - float(swap_hi_next)) * hi_cost
- )
- hi_deriv_0 = (
- float(swap_hi_next) * hi_next_deriv_0 + (1.0 - float(swap_hi_next)) * hi_deriv_0
- )
- hi_deriv_1 = (
- float(swap_hi_next) * hi_next_deriv_1 + (1.0 - float(swap_hi_next)) * hi_deriv_1
- )
-
- swap_hi_mid = _in_bracket(hi_deriv_0, mid_deriv_0)
- hi_alpha = float(swap_hi_mid) * mid_alpha + (1.0 - float(swap_hi_mid)) * hi_alpha
- hi_cost = float(swap_hi_mid) * mid_cost + (1.0 - float(swap_hi_mid)) * hi_cost
- hi_deriv_0 = (
- float(swap_hi_mid) * mid_deriv_0 + (1.0 - float(swap_hi_mid)) * hi_deriv_0
- )
- hi_deriv_1 = (
- float(swap_hi_mid) * mid_deriv_1 + (1.0 - float(swap_hi_mid)) * hi_deriv_1
- )
-
- swap_hi_lo_next = _in_bracket(hi_deriv_0, lo_next_deriv_0)
- hi_alpha = (
- float(swap_hi_lo_next) * lo_next_alpha
- + (1.0 - float(swap_hi_lo_next)) * hi_alpha
- )
- hi_cost = (
- float(swap_hi_lo_next) * lo_next_cost + (1.0 - float(swap_hi_lo_next)) * hi_cost
- )
- hi_deriv_0 = (
- float(swap_hi_lo_next) * lo_next_deriv_0
- + (1.0 - float(swap_hi_lo_next)) * hi_deriv_0
- )
- hi_deriv_1 = (
- float(swap_hi_lo_next) * lo_next_deriv_1
- + (1.0 - float(swap_hi_lo_next)) * hi_deriv_1
- )
-
- ls_ctx.lo.alpha[worldid] = lo_alpha
- ls_ctx.lo.cost[worldid] = lo_cost
- ls_ctx.lo.deriv_0[worldid] = lo_deriv_0
- ls_ctx.lo.deriv_1[worldid] = lo_deriv_1
-
- ls_ctx.hi.alpha[worldid] = hi_alpha
- ls_ctx.hi.cost[worldid] = hi_cost
- ls_ctx.hi.deriv_0[worldid] = hi_deriv_0
- ls_ctx.hi.deriv_1[worldid] = hi_deriv_1
-
- swap = swap_lo_next or swap_lo_mid or swap_lo_hi_next
- swap = swap or swap_hi_next or swap_hi_mid or swap_hi_lo_next
- ls_ctx.swap[worldid] = int(swap)
-
- wp.launch(_swap_lo_hi, dim=(d.nworld,), inputs=[ls_ctx])
-
- @wp.kernel
- def _done(ls_ctx: LSContext, ctx: Context, m: types.Model, ls_iter: int):
- worldid = wp.tid()
- done = ls_iter >= m.opt.ls_iterations
- done = done or (1 - ls_ctx.swap[worldid])
- done = done or (
- (ls_ctx.lo.deriv_0[worldid] < 0.0)
- and (ls_ctx.lo.deriv_0[worldid] > -ctx.gtol[worldid])
- )
- done = done or (
- (ls_ctx.hi.deriv_0[worldid] > 0.0)
- and (ls_ctx.hi.deriv_0[worldid] < ctx.gtol[worldid])
- )
- ls_ctx.done[worldid] = int(done)
-
- wp.launch(_done, dim=(d.nworld,), inputs=[ls_ctx, ctx, m, i])
- # TODO(team): return if all done
-
- @wp.kernel
- def _alpha(ctx: Context, ls_ctx: LSContext):
- worldid = wp.tid()
- p0_cost = ls_ctx.p0.cost[worldid]
- lo_cost = ls_ctx.lo.cost[worldid]
- hi_cost = ls_ctx.hi.cost[worldid]
-
- improvement = float((lo_cost < p0_cost) or (hi_cost < p0_cost))
- lo_hi_cost = float(lo_cost < hi_cost)
- ctx.alpha[worldid] = improvement * (
- lo_hi_cost * ls_ctx.lo.alpha[worldid]
- + (1.0 - lo_hi_cost) * ls_ctx.hi.alpha[worldid]
- )
-
- wp.launch(_alpha, dim=(d.nworld,), inputs=[ctx, ls_ctx])
-
- @wp.kernel
- def _qacc_ma(ctx: Context, d: types.Data):
- worldid, dofid = wp.tid()
- alpha = ctx.alpha[worldid]
- d.qacc[worldid, dofid] += alpha * ctx.search[worldid, dofid]
- ctx.Ma[worldid, dofid] += alpha * ctx.mv[worldid, dofid]
-
- wp.launch(_qacc_ma, dim=(d.nworld, m.nv), inputs=[ctx, d])
-
- @wp.kernel
- def _jaref(ctx: Context, d: types.Data):
- efcid = wp.tid()
-
- if efcid >= d.nefc_total[0]:
- return
-
- worldid = d.efc_worldid[efcid]
- ctx.Jaref[efcid] += ctx.alpha[worldid] * ctx.jv[efcid]
-
- wp.launch(_jaref, dim=(d.njmax,), inputs=[ctx, d])
-
-
-def solve(m: types.Model, d: types.Data):
- """Finds forces that satisfy constraints."""
-
- # warmstart
- wp.copy(d.qacc, d.qacc_warmstart)
-
- ctx = _context(m, d)
- _create_context(ctx, m, d, grad=True)
-
- for i in range(m.opt.iterations):
- _linesearch(m, d, ctx)
- wp.copy(ctx.prev_grad, ctx.grad)
- wp.copy(ctx.prev_Mgrad, ctx.Mgrad)
- _update_constraint(m, d, ctx)
- _update_gradient(m, d, ctx)
-
- if m.opt.solver == 2: # Newton
- ctx.search_dot.zero_()
-
- @wp.kernel
- def _search_newton(ctx: Context):
- worldid, dofid = wp.tid()
- search = -1.0 * ctx.Mgrad[worldid, dofid]
- ctx.search[worldid, dofid] = search
- wp.atomic_add(ctx.search_dot, worldid, search * search)
-
- wp.launch(_search_newton, dim=(d.nworld, m.nv), inputs=[ctx])
- else: # Polak-Ribiere
- ctx.beta_num.zero_()
- ctx.beta_den.zero_()
-
- @wp.kernel
- def _beta_num_den(ctx: Context):
- worldid, dofid = wp.tid()
- prev_Mgrad = ctx.prev_Mgrad[worldid][dofid]
- wp.atomic_add(
- ctx.beta_num,
- worldid,
- ctx.grad[worldid, dofid] * (ctx.Mgrad[worldid, dofid] - prev_Mgrad),
- )
- wp.atomic_add(ctx.beta_den, worldid, ctx.prev_grad[worldid, dofid] * prev_Mgrad)
-
- wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[ctx])
-
- @wp.kernel
- def _beta(ctx: Context):
- worldid = wp.tid()
- ctx.beta[worldid] = wp.max(
- 0.0, ctx.beta_num[worldid] / wp.max(mujoco.mjMINVAL, ctx.beta_den[worldid])
- )
-
- wp.launch(_beta, dim=(d.nworld,), inputs=[ctx])
-
- ctx.search_dot.zero_()
-
- @wp.kernel
- def _search_cg(ctx: Context):
- worldid, dofid = wp.tid()
- search = (
- -1.0 * ctx.Mgrad[worldid, dofid]
- + ctx.beta[worldid] * ctx.search[worldid, dofid]
- )
- ctx.search[worldid, dofid] = search
- wp.atomic_add(ctx.search_dot, worldid, search * search)
-
- wp.launch(_search_cg, dim=(d.nworld, m.nv), inputs=[ctx])
-
- @wp.kernel
- def _done(ctx: Context, m: types.Model, solver_niter: int):
- worldid = wp.tid()
- improvement = _rescale(m, ctx.prev_cost[worldid] - ctx.cost[worldid])
- gradient = _rescale(m, wp.math.sqrt(ctx.grad_dot[worldid]))
- done = solver_niter >= m.opt.iterations
- done = done or (improvement < m.opt.tolerance)
- done = done or (gradient < m.opt.tolerance)
- ctx.done[worldid] = int(done)
-
- wp.launch(_done, dim=(d.nworld,), inputs=[ctx, m, i])
- # TODO(team): return if all done
-
- wp.copy(d.qacc_warmstart, d.qacc)
diff --git a/mujoco/mjx/_src/support.py b/mujoco/mjx/_src/support.py
deleted file mode 100644
index c6888d15..00000000
--- a/mujoco/mjx/_src/support.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright 2025 The Physics-Next Project Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import mujoco
-import warp as wp
-from .types import Model
-from .types import Data
-from .types import array2df
-
-
-def is_sparse(m: mujoco.MjModel):
- if m.opt.jacobian == mujoco.mjtJacobian.mjJAC_AUTO:
- return m.nv >= 60
- return m.opt.jacobian == mujoco.mjtJacobian.mjJAC_SPARSE
-
-
-def mul_m(
- m: Model,
- d: Data,
- res: wp.array(ndim=2, dtype=wp.float32),
- vec: wp.array(ndim=2, dtype=wp.float32),
-):
- """Multiply vector by inertia matrix."""
-
- if not m.opt.is_sparse:
- # TODO(team): tile_matmul
- res.zero_()
-
- @wp.kernel
- def _mul_m_dense(
- d: Data,
- res: wp.array(ndim=2, dtype=wp.float32),
- vec: wp.array(ndim=2, dtype=wp.float32),
- ):
- worldid, rowid, colid = wp.tid()
- wp.atomic_add(
- res[worldid], rowid, d.qM[worldid, rowid, colid] * vec[worldid, colid]
- )
-
- wp.launch(_mul_m_dense, dim=(d.nworld, m.nv, m.nv), inputs=[d, res, vec])
- else:
-
- @wp.kernel
- def _mul_m_sparse_diag(
- m: Model,
- d: Data,
- res: wp.array(ndim=2, dtype=wp.float32),
- vec: wp.array(ndim=2, dtype=wp.float32),
- ):
- worldid, dofid = wp.tid()
- res[worldid, dofid] = d.qM[worldid, 0, m.dof_Madr[dofid]] * vec[worldid, dofid]
-
- wp.launch(_mul_m_sparse_diag, dim=(d.nworld, m.nv), inputs=[m, d, res, vec])
-
- @wp.kernel
- def _mul_m_sparse_ij(
- m: Model,
- d: Data,
- res: wp.array(ndim=2, dtype=wp.float32),
- vec: wp.array(ndim=2, dtype=wp.float32),
- ):
- worldid, elementid = wp.tid()
- i = m.qM_i[elementid]
- j = m.qM_j[elementid]
- madr_ij = m.qM_madr_ij[elementid]
-
- qM = d.qM[worldid, 0, madr_ij]
-
- wp.atomic_add(res[worldid], i, qM * vec[worldid, j])
- wp.atomic_add(res[worldid], j, qM * vec[worldid, i])
-
- wp.launch(
- _mul_m_sparse_ij, dim=(d.nworld, m.qM_madr_ij.size), inputs=[m, d, res, vec]
- )
-
-
-@wp.kernel
-def process_level(
- body_tree: wp.array(ndim=1, dtype=int),
- body_parentid: wp.array(ndim=1, dtype=int),
- dof_bodyid: wp.array(ndim=1, dtype=int),
- mask: wp.array2d(dtype=wp.bool),
- beg: int,
-):
- dofid, tid_y = wp.tid()
- j = beg + tid_y
- el = body_tree[j]
- parent_id = body_parentid[el]
- parent_val = mask[dofid, parent_id]
- mask[dofid, el] = parent_val or (dof_bodyid[dofid] == el)
-
-
-@wp.kernel
-def compute_qfrc(
- d: Data,
- m: Model,
- mask: wp.array2d(dtype=wp.bool),
- qfrc_total: array2df,
-):
- worldid, dofid = wp.tid()
- accumul = float(0.0)
- cdof_vec = d.cdof[worldid, dofid]
- rotational_cdof = wp.vec3(cdof_vec[0], cdof_vec[1], cdof_vec[2])
-
- jac = wp.spatial_vector(
- cdof_vec[3], cdof_vec[4], cdof_vec[5], cdof_vec[0], cdof_vec[1], cdof_vec[2]
- )
-
- for bodyid in range(m.nbody):
- if mask[dofid, bodyid]:
- offset = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]]
- cross_term = wp.cross(rotational_cdof, offset)
- accumul += wp.dot(jac, d.xfrc_applied[worldid, bodyid]) + wp.dot(
- cross_term,
- wp.vec3(
- d.xfrc_applied[worldid, bodyid][0],
- d.xfrc_applied[worldid, bodyid][1],
- d.xfrc_applied[worldid, bodyid][2],
- ),
- )
-
- qfrc_total[worldid, dofid] = accumul
-
-
-def xfrc_accumulate(m: Model, d: Data) -> array2df:
- body_treeadr_np = m.body_treeadr.numpy()
- mask = wp.zeros((m.nv, m.nbody), dtype=wp.bool)
-
- for i in range(len(body_treeadr_np)):
- beg = body_treeadr_np[i]
- end = m.nbody if i == len(body_treeadr_np) - 1 else body_treeadr_np[i + 1]
-
- if end > beg:
- wp.launch(
- kernel=process_level,
- dim=[m.nv, (end - beg)],
- inputs=[m.body_tree, m.body_parentid, m.dof_bodyid, mask, beg],
- )
-
- qfrc_total = wp.zeros((d.nworld, m.nv), dtype=float)
-
- wp.launch(kernel=compute_qfrc, dim=(d.nworld, m.nv), inputs=[d, m, mask, qfrc_total])
-
- return qfrc_total
diff --git a/mujoco/mjx/_src/types.py b/mujoco/mjx/_src/types.py
deleted file mode 100644
index dfd9a1b3..00000000
--- a/mujoco/mjx/_src/types.py
+++ /dev/null
@@ -1,352 +0,0 @@
-# Copyright 2025 The Physics-Next Project Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import warp as wp
-import enum
-import mujoco
-
-MJ_MINVAL = mujoco.mjMINVAL
-MJ_MINIMP = mujoco.mjMINIMP # minimum constraint impedance
-MJ_MAXIMP = mujoco.mjMAXIMP # maximum constraint impedance
-MJ_NREF = mujoco.mjNREF
-MJ_NIMP = mujoco.mjNIMP
-
-
-class DisableBit(enum.IntFlag):
- """Disable default feature bitflags.
-
- Members:
- CONSTRAINT: entire constraint solver
- EQUALITY: equality constraints
- FRICTIONLOSS: joint and tendon frictionloss constraints
- LIMIT: joint and tendon limit constraints
- CONTACT: contact constraints
- PASSIVE: passive forces
- GRAVITY: gravitational forces
- CLAMPCTRL: clamp control to specified range
- WARMSTART: warmstart constraint solver
- ACTUATION: apply actuation forces
- REFSAFE: integrator safety: make ref[0]>=2*timestep
- SENSOR: sensors
- """
-
- CONSTRAINT = mujoco.mjtDisableBit.mjDSBL_CONSTRAINT
- EQUALITY = mujoco.mjtDisableBit.mjDSBL_EQUALITY
- FRICTIONLOSS = mujoco.mjtDisableBit.mjDSBL_FRICTIONLOSS
- LIMIT = mujoco.mjtDisableBit.mjDSBL_LIMIT
- CONTACT = mujoco.mjtDisableBit.mjDSBL_CONTACT
- PASSIVE = mujoco.mjtDisableBit.mjDSBL_PASSIVE
- GRAVITY = mujoco.mjtDisableBit.mjDSBL_GRAVITY
- CLAMPCTRL = mujoco.mjtDisableBit.mjDSBL_CLAMPCTRL
- WARMSTART = mujoco.mjtDisableBit.mjDSBL_WARMSTART
- ACTUATION = mujoco.mjtDisableBit.mjDSBL_ACTUATION
- REFSAFE = mujoco.mjtDisableBit.mjDSBL_REFSAFE
- SENSOR = mujoco.mjtDisableBit.mjDSBL_SENSOR
- EULERDAMP = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
- FILTERPARENT = mujoco.mjtDisableBit.mjDSBL_FILTERPARENT
- # unsupported: MIDPHASE
-
-
-class TrnType(enum.IntEnum):
- """Type of actuator transmission.
-
- Members:
- JOINT: force on joint
- JOINTINPARENT: force on joint, expressed in parent frame
- TENDON: force on tendon (unsupported)
- SITE: force on site (unsupported)
- """
-
- JOINT = mujoco.mjtTrn.mjTRN_JOINT
- JOINTINPARENT = mujoco.mjtTrn.mjTRN_JOINTINPARENT
- # unsupported: SITE, TENDON, SLIDERCRANK, BODY
-
-
-class DynType(enum.IntEnum):
- """Type of actuator dynamics.
-
- Members:
- NONE: no internal dynamics; ctrl specifies force
- INTEGRATOR: integrator: da/dt = u
- FILTER: linear filter: da/dt = (u-a) / tau
- FILTEREXACT: linear filter: da/dt = (u-a) / tau, with exact integration
- MUSCLE: piece-wise linear filter with two time constants
- """
-
- NONE = mujoco.mjtDyn.mjDYN_NONE
- INTEGRATOR = mujoco.mjtDyn.mjDYN_INTEGRATOR
- FILTER = mujoco.mjtDyn.mjDYN_FILTER
- FILTEREXACT = mujoco.mjtDyn.mjDYN_FILTEREXACT
- MUSCLE = mujoco.mjtDyn.mjDYN_MUSCLE
- # unsupported: USER
-
-
-class JointType(enum.IntEnum):
- """Type of degree of freedom.
-
- Members:
- FREE: global position and orientation (quat) (7,)
- BALL: orientation (quat) relative to parent (4,)
- SLIDE: sliding distance along body-fixed axis (1,)
- HINGE: rotation angle (rad) around body-fixed axis (1,)
- """
-
- FREE = mujoco.mjtJoint.mjJNT_FREE
- BALL = mujoco.mjtJoint.mjJNT_BALL
- SLIDE = mujoco.mjtJoint.mjJNT_SLIDE
- HINGE = mujoco.mjtJoint.mjJNT_HINGE
-
- def dof_width(self) -> int:
- return {0: 6, 1: 3, 2: 1, 3: 1}[self.value]
-
- def qpos_width(self) -> int:
- return {0: 7, 1: 4, 2: 1, 3: 1}[self.value]
-
-
-class ConeType(enum.IntEnum):
- """Type of friction cone.
-
- Members:
- PYRAMIDAL: pyramidal
- ELLIPTIC: elliptic
- """
-
- PYRAMIDAL = mujoco.mjtCone.mjCONE_PYRAMIDAL
- ELLIPTIC = mujoco.mjtCone.mjCONE_ELLIPTIC
-
-
-class vec5f(wp.types.vector(length=5, dtype=wp.float32)):
- pass
-
-
-class vec10f(wp.types.vector(length=10, dtype=wp.float32)):
- pass
-
-
-vec5 = vec5f
-vec10 = vec10f
-array2df = wp.array2d(dtype=wp.float32)
-array3df = wp.array3d(dtype=wp.float32)
-
-
-@wp.struct
-class Option:
- timestep: float
- tolerance: float
- ls_tolerance: float
- gravity: wp.vec3
- cone: int # mjtCone
- solver: int # mjtSolver
- iterations: int
- ls_iterations: int
- disableflags: int
- integrator: int # mjtIntegrator
- impratio: wp.float32
- is_sparse: bool # warp only
-
-
-@wp.struct
-class Statistic:
- meaninertia: float
-
-
-@wp.struct
-class Model:
- nq: int
- nv: int
- na: int
- nu: int
- nbody: int
- njnt: int
- ngeom: int
- nsite: int
- nmocap: int
- nM: int
- opt: Option
- stat: Statistic
- qpos0: wp.array(dtype=wp.float32, ndim=1)
- qpos_spring: wp.array(dtype=wp.float32, ndim=1)
- body_tree: wp.array(dtype=wp.int32, ndim=1) # warp only
- body_treeadr: wp.array(dtype=wp.int32, ndim=1) # warp only
- qM_i: wp.array(dtype=wp.int32, ndim=1) # warp only
- qM_j: wp.array(dtype=wp.int32, ndim=1) # warp only
- qM_madr_ij: wp.array(dtype=wp.int32, ndim=1) # warp only
- qLD_update_tree: wp.array(dtype=wp.vec3i, ndim=1) # warp only
- qLD_update_treeadr: wp.array(dtype=wp.int32, ndim=1) # warp only
- qLD_tile: wp.array(dtype=wp.int32, ndim=1) # warp only
- qLD_tileadr: wp.array(dtype=wp.int32, ndim=1) # warp only
- qLD_tilesize: wp.array(dtype=wp.int32, ndim=1) # warp only
- body_dofadr: wp.array(dtype=wp.int32, ndim=1)
- body_dofnum: wp.array(dtype=wp.int32, ndim=1)
- body_jntadr: wp.array(dtype=wp.int32, ndim=1)
- body_jntnum: wp.array(dtype=wp.int32, ndim=1)
- body_parentid: wp.array(dtype=wp.int32, ndim=1)
- body_mocapid: wp.array(dtype=wp.int32, ndim=1)
- body_weldid: wp.array(dtype=wp.int32, ndim=1)
- body_pos: wp.array(dtype=wp.vec3, ndim=1)
- body_quat: wp.array(dtype=wp.quat, ndim=1)
- body_ipos: wp.array(dtype=wp.vec3, ndim=1)
- body_iquat: wp.array(dtype=wp.quat, ndim=1)
- body_rootid: wp.array(dtype=wp.int32, ndim=1)
- body_inertia: wp.array(dtype=wp.vec3, ndim=1)
- body_mass: wp.array(dtype=wp.float32, ndim=1)
- body_invweight0: wp.array(dtype=wp.float32, ndim=2)
- jnt_bodyid: wp.array(dtype=wp.int32, ndim=1)
- jnt_limited: wp.array(dtype=wp.int32, ndim=1)
- jnt_limited_slide_hinge_adr: wp.array(dtype=wp.int32, ndim=1) # warp only
- jnt_solref: wp.array(dtype=wp.vec2f, ndim=1)
- jnt_solimp: wp.array(dtype=vec5, ndim=1)
- jnt_type: wp.array(dtype=wp.int32, ndim=1)
- jnt_qposadr: wp.array(dtype=wp.int32, ndim=1)
- jnt_dofadr: wp.array(dtype=wp.int32, ndim=1)
- jnt_axis: wp.array(dtype=wp.vec3, ndim=1)
- jnt_pos: wp.array(dtype=wp.vec3, ndim=1)
- jnt_range: wp.array(dtype=wp.float32, ndim=2)
- jnt_margin: wp.array(dtype=wp.float32, ndim=1)
- jnt_stiffness: wp.array(dtype=wp.float32, ndim=1)
- jnt_actfrclimited: wp.array(dtype=wp.bool, ndim=1)
- jnt_actfrcrange: wp.array(dtype=wp.vec2, ndim=1)
- geom_bodyid: wp.array(dtype=wp.int32, ndim=1)
- geom_pos: wp.array(dtype=wp.vec3, ndim=1)
- geom_quat: wp.array(dtype=wp.quat, ndim=1)
- site_pos: wp.array(dtype=wp.vec3, ndim=1)
- site_quat: wp.array(dtype=wp.quat, ndim=1)
- site_bodyid: wp.array(dtype=wp.int32, ndim=1)
- dof_bodyid: wp.array(dtype=wp.int32, ndim=1)
- dof_jntid: wp.array(dtype=wp.int32, ndim=1)
- dof_parentid: wp.array(dtype=wp.int32, ndim=1)
- dof_Madr: wp.array(dtype=wp.int32, ndim=1)
- dof_armature: wp.array(dtype=wp.float32, ndim=1)
- dof_invweight0: wp.array(dtype=wp.float32, ndim=1)
- dof_damping: wp.array(dtype=wp.float32, ndim=1)
- dof_tri_row: wp.array(dtype=wp.int32, ndim=1) # warp only
- dof_tri_col: wp.array(dtype=wp.int32, ndim=1) # warp only
- actuator_trntype: wp.array(dtype=wp.int32, ndim=1)
- actuator_trnid: wp.array(dtype=wp.int32, ndim=2)
- actuator_ctrllimited: wp.array(dtype=wp.bool, ndim=1)
- actuator_ctrlrange: wp.array(dtype=wp.vec2, ndim=1)
- actuator_forcelimited: wp.array(dtype=wp.bool, ndim=1)
- actuator_forcerange: wp.array(dtype=wp.vec2, ndim=1)
- actuator_gainprm: wp.array(dtype=wp.float32, ndim=2)
- actuator_biasprm: wp.array(dtype=wp.float32, ndim=2)
- actuator_gear: wp.array(dtype=wp.spatial_vector, ndim=1)
- actuator_actlimited: wp.array(dtype=wp.bool, ndim=1)
- actuator_actrange: wp.array(dtype=wp.vec2, ndim=1)
- actuator_actadr: wp.array(dtype=wp.int32, ndim=1)
- actuator_dyntype: wp.array(dtype=wp.int32, ndim=1)
- actuator_dynprm: wp.array(dtype=vec10f, ndim=1)
- opt: Option
-
-
-@wp.struct
-class Contact:
- dist: wp.array(dtype=wp.float32, ndim=1)
- pos: wp.array(dtype=wp.vec3f, ndim=1)
- frame: wp.array(dtype=wp.mat33f, ndim=1)
- includemargin: wp.array(dtype=wp.float32, ndim=1)
- friction: wp.array(dtype=vec5, ndim=1)
- solref: wp.array(dtype=wp.vec2f, ndim=1)
- solreffriction: wp.array(dtype=wp.vec2f, ndim=1)
- solimp: wp.array(dtype=vec5, ndim=1)
- dim: wp.array(dtype=wp.int32, ndim=1)
- geom: wp.array(dtype=wp.vec2i, ndim=1)
- efc_address: wp.array(dtype=wp.int32, ndim=1)
- worldid: wp.array(dtype=wp.int32, ndim=1)
-
-
-@wp.struct
-class Data:
- nworld: int
- ncon_total: wp.array(dtype=wp.int32, ndim=1) # warp only
- nefc_total: wp.array(dtype=wp.int32, ndim=1) # warp only
- nconmax: int
- njmax: int
- time: float
- qpos: wp.array(dtype=wp.float32, ndim=2)
- qvel: wp.array(dtype=wp.float32, ndim=2)
- qacc_warmstart: wp.array(dtype=wp.float32, ndim=2)
- qfrc_applied: wp.array(dtype=wp.float32, ndim=2)
- ncon: int
- nl: int
- nefc: wp.array(dtype=wp.int32, ndim=1)
- ctrl: wp.array(dtype=wp.float32, ndim=2)
- mocap_pos: wp.array(dtype=wp.vec3, ndim=2)
- mocap_quat: wp.array(dtype=wp.quat, ndim=2)
- qacc: wp.array(dtype=wp.float32, ndim=2)
- xanchor: wp.array(dtype=wp.vec3, ndim=2)
- xaxis: wp.array(dtype=wp.vec3, ndim=2)
- xmat: wp.array(dtype=wp.mat33, ndim=2)
- xpos: wp.array(dtype=wp.vec3, ndim=2)
- xquat: wp.array(dtype=wp.quat, ndim=2)
- xipos: wp.array(dtype=wp.vec3, ndim=2)
- ximat: wp.array(dtype=wp.mat33, ndim=2)
- subtree_com: wp.array(dtype=wp.vec3, ndim=2)
- geom_xpos: wp.array(dtype=wp.vec3, ndim=2)
- geom_xmat: wp.array(dtype=wp.mat33, ndim=2)
- site_xpos: wp.array(dtype=wp.vec3, ndim=2)
- site_xmat: wp.array(dtype=wp.mat33, ndim=2)
- cinert: wp.array(dtype=vec10, ndim=2)
- cdof: wp.array(dtype=wp.spatial_vector, ndim=2)
- crb: wp.array(dtype=vec10, ndim=2)
- qM: wp.array(dtype=wp.float32, ndim=3)
- qLD: wp.array(dtype=wp.float32, ndim=3)
- act: wp.array(dtype=wp.float32, ndim=2)
- act_dot: wp.array(dtype=wp.float32, ndim=2)
- qLDiagInv: wp.array(dtype=wp.float32, ndim=2)
- actuator_velocity: wp.array(dtype=wp.float32, ndim=2)
- actuator_force: wp.array(dtype=wp.float32, ndim=2)
- actuator_length: wp.array(dtype=wp.float32, ndim=2)
- actuator_moment: wp.array(dtype=wp.float32, ndim=3)
- cvel: wp.array(dtype=wp.spatial_vector, ndim=2)
- cdof_dot: wp.array(dtype=wp.spatial_vector, ndim=2)
- qfrc_applied: wp.array(dtype=wp.float32, ndim=2)
- qfrc_bias: wp.array(dtype=wp.float32, ndim=2)
- qfrc_constraint: wp.array(dtype=wp.float32, ndim=2)
- qfrc_passive: wp.array(dtype=wp.float32, ndim=2)
- qfrc_spring: wp.array(dtype=wp.float32, ndim=2)
- qfrc_damper: wp.array(dtype=wp.float32, ndim=2)
- qfrc_actuator: wp.array(dtype=wp.float32, ndim=2)
- qfrc_smooth: wp.array(dtype=wp.float32, ndim=2)
- qacc_smooth: wp.array(dtype=wp.float32, ndim=2)
- qfrc_constraint: wp.array(dtype=wp.float32, ndim=2)
- efc_J: wp.array(dtype=wp.float32, ndim=2)
- efc_D: wp.array(dtype=wp.float32, ndim=1)
- efc_pos: wp.array(dtype=wp.float32, ndim=1)
- efc_aref: wp.array(dtype=wp.float32, ndim=1)
- efc_force: wp.array(dtype=wp.float32, ndim=1)
- efc_margin: wp.array(dtype=wp.float32, ndim=1)
- efc_worldid: wp.array(dtype=wp.int32, ndim=1) # warp only
- xfrc_applied: wp.array(dtype=wp.spatial_vector, ndim=2)
- contact: Contact
-
- # temp arrays
- qfrc_integration: wp.array(dtype=wp.float32, ndim=2)
- qacc_integration: wp.array(dtype=wp.float32, ndim=2)
-
- qM_integration: wp.array(dtype=wp.float32, ndim=3)
- qLD_integration: wp.array(dtype=wp.float32, ndim=3)
- qLDiagInv_integration: wp.array(dtype=wp.float32, ndim=2)
-
- # broadphase arrays
- max_num_overlaps_per_world: int
- broadphase_pairs: wp.array(dtype=wp.vec2i, ndim=2)
- result_count: wp.array(dtype=wp.int32, ndim=1)
- boxes_sorted: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=2)
- data_start: wp.array(dtype=wp.float32, ndim=2)
- data_end: wp.array(dtype=wp.float32, ndim=2)
- data_indexer: wp.array(dtype=wp.int32, ndim=2)
- ranges: wp.array(dtype=wp.int32, ndim=2)
- cumulative_sum: wp.array(dtype=wp.int32, ndim=1)
- segment_indices: wp.array(dtype=wp.int32, ndim=1)
- geom_aabb: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1)
diff --git a/mujoco/mjx/testspeed.py b/mujoco/mjx/testspeed.py
deleted file mode 100644
index 6e544b25..00000000
--- a/mujoco/mjx/testspeed.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright 2025 The Physics-Next Project Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Run benchmarks on various devices."""
-
-import inspect
-from typing import Sequence
-
-from absl import app
-from absl import flags
-from etils import epath
-import warp as wp
-
-import mujoco
-from mujoco import mjx
-
-_FUNCTION = flags.DEFINE_enum(
- "function",
- "kinematics",
- [n for n, _ in inspect.getmembers(mjx, inspect.isfunction)],
- "the function to run",
-)
-_MJCF = flags.DEFINE_string(
- "mjcf", None, "path to model `.xml` or `.mjb`", required=True
-)
-_BASE_PATH = flags.DEFINE_string(
- "base_path", None, "base path, defaults to mujoco.mjx resource path"
-)
-_NSTEP = flags.DEFINE_integer("nstep", 1000, "number of steps per rollout")
-_BATCH_SIZE = flags.DEFINE_integer("batch_size", 4096, "number of parallel rollouts")
-_UNROLL = flags.DEFINE_integer("unroll", 1, "loop unroll length")
-_SOLVER = flags.DEFINE_enum("solver", "cg", ["cg", "newton"], "constraint solver")
-_ITERATIONS = flags.DEFINE_integer("iterations", 1, "number of solver iterations")
-_LS_ITERATIONS = flags.DEFINE_integer(
- "ls_iterations", 4, "number of linesearch iterations"
-)
-_IS_SPARSE = flags.DEFINE_bool(
- "is_sparse", True, "if model should create sparse mass matrices"
-)
-_NEFC_TOTAL = flags.DEFINE_integer(
- "nefc_total", 0, "total number of efc for batch of worlds"
-)
-_OUTPUT = flags.DEFINE_enum(
- "output", "text", ["text", "tsv"], "format to print results"
-)
-
-
-def _main(argv: Sequence[str]):
- """Runs testpeed function."""
- wp.init()
-
- path = epath.resource_path("mujoco.mjx") / "test_data"
- path = _BASE_PATH.value or path
- f = epath.Path(path) / _MJCF.value
- if f.suffix == ".mjb":
- m = mujoco.MjModel.from_binary_path(f.as_posix())
- else:
- m = mujoco.MjModel.from_xml_path(f.as_posix())
-
- if _IS_SPARSE.value:
- m.opt.jacobian = mujoco.mjtJacobian.mjJAC_SPARSE
- else:
- m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
-
- print(
- f"Model nbody: {m.nbody} nv: {m.nv} ngeom: {m.ngeom} is_sparse: {_IS_SPARSE.value}"
- )
- print(f"Rolling out {_NSTEP.value} steps at dt = {m.opt.timestep:.3f}...")
- jit_time, run_time, steps = mjx.benchmark(
- mjx.__dict__[_FUNCTION.value],
- m,
- _NSTEP.value,
- _BATCH_SIZE.value,
- _UNROLL.value,
- _SOLVER.value,
- _ITERATIONS.value,
- _LS_ITERATIONS.value,
- _NEFC_TOTAL.value,
- )
-
- name = argv[0]
- if _OUTPUT.value == "text":
- print(f"""
-Summary for {_BATCH_SIZE.value} parallel rollouts
-
- Total JIT time: {jit_time:.2f} s
- Total simulation time: {run_time:.2f} s
- Total steps per second: {steps / run_time:,.0f}
- Total realtime factor: {steps * m.opt.timestep / run_time:,.2f} x
- Total time per step: {1e6 * run_time / steps:.2f} µs""")
- elif _OUTPUT.value == "tsv":
- name = name.split("/")[-1].replace("testspeed_", "")
- print(f"{name}\tjit: {jit_time:.2f}s\tsteps/second: {steps / run_time:.0f}")
-
-
-def main():
- app.run(_main)
-
-
-if __name__ == "__main__":
- main()
diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py
new file mode 100644
index 00000000..123859c1
--- /dev/null
+++ b/mujoco_warp/__init__.py
@@ -0,0 +1,57 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Public API for MJWarp."""
+
+from ._src.collision_driver import collision as collision
+from ._src.collision_driver import nxn_broadphase as nxn_broadphase
+from ._src.collision_driver import sap_broadphase as sap_broadphase
+from ._src.constraint import make_constraint as make_constraint
+from ._src.forward import euler as euler
+from ._src.forward import forward as forward
+from ._src.forward import fwd_acceleration as fwd_acceleration
+from ._src.forward import fwd_actuation as fwd_actuation
+from ._src.forward import fwd_position as fwd_position
+from ._src.forward import fwd_velocity as fwd_velocity
+from ._src.forward import implicit as implicit
+from ._src.forward import step as step
+from ._src.io import get_data_into as get_data_into
+from ._src.io import make_data as make_data
+from ._src.io import put_data as put_data
+from ._src.io import put_model as put_model
+from ._src.passive import passive as passive
+from ._src.smooth import com_pos as com_pos
+from ._src.smooth import com_vel as com_vel
+from ._src.smooth import crb as crb
+from ._src.smooth import factor_m as factor_m
+from ._src.smooth import kinematics as kinematics
+from ._src.smooth import rne as rne
+from ._src.smooth import solve_m as solve_m
+from ._src.smooth import transmission as transmission
+from ._src.solver import solve as solve
+from ._src.support import is_sparse as is_sparse
+from ._src.support import mul_m as mul_m
+from ._src.support import xfrc_accumulate as xfrc_accumulate
+from ._src.test_util import benchmark as benchmark
+from ._src.types import ConeType as ConeType
+from ._src.types import Contact as Contact
+from ._src.types import Data as Data
+from ._src.types import DisableBit as DisableBit
+from ._src.types import DynType as DynType
+from ._src.types import JointType as JointType
+from ._src.types import Model as Model
+from ._src.types import Option as Option
+from ._src.types import Statistic as Statistic
+from ._src.types import TrnType as TrnType
diff --git a/mujoco_warp/_src/__init__.py b/mujoco_warp/_src/__init__.py
new file mode 100644
index 00000000..2b276beb
--- /dev/null
+++ b/mujoco_warp/_src/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/mujoco_warp/_src/broad_phase_test.py b/mujoco_warp/_src/broad_phase_test.py
new file mode 100644
index 00000000..47a6616f
--- /dev/null
+++ b/mujoco_warp/_src/broad_phase_test.py
@@ -0,0 +1,250 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for broadphase functions."""
+
+import mujoco
+import numpy as np
+import warp as wp
+from absl.testing import absltest
+
+import mujoco_warp as mjwarp
+
+from . import collision_driver
+
+
+def _load_from_string(xml: str, keyframe: int = -1):
+ mjm = mujoco.MjModel.from_xml_string(xml)
+ mjd = mujoco.MjData(mjm)
+ if keyframe > -1:
+ mujoco.mj_resetDataKeyframe(mjm, mjd, keyframe)
+ mujoco.mj_forward(mjm, mjd)
+
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd)
+
+ return mjm, mjd, m, d
+
+
+class BroadphaseTest(absltest.TestCase):
+ def test_sap_broadphase(self):
+ """Tests sap_broadphase."""
+
+ _SAP_MODEL = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+
+ collision_pair = [
+ (0, 1),
+ (0, 2),
+ (0, 3),
+ (0, 4),
+ (0, 6),
+ (1, 2),
+ (1, 4),
+ (1, 5),
+ (1, 6),
+ (1, 7),
+ (2, 3),
+ (2, 4),
+ (2, 5),
+ (2, 6),
+ (2, 7),
+ (3, 4),
+ (4, 6),
+ (5, 7),
+ (6, 7),
+ ]
+
+ _, _, m, d = _load_from_string(_SAP_MODEL)
+
+ mjwarp.sap_broadphase(m, d)
+
+ ncollision = d.ncollision.numpy()[0]
+ np.testing.assert_equal(ncollision, len(collision_pair), "ncollision")
+
+ for i in range(ncollision):
+ pair = d.collision_pair.numpy()[i]
+ if pair[0] > pair[1]:
+ pair_tuple = (int(pair[1]), int(pair[0]))
+ else:
+ pair_tuple = (int(pair[0]), int(pair[1]))
+
+ np.testing.assert_equal(
+ pair_tuple in collision_pair,
+ True,
+ f"geom pair {pair_tuple} not found in brute force results",
+ )
+
+ # TODO(team): test DisableBit.FILTERPARENT
+
+ def test_nxn_broadphase(self):
+ """Tests nxn_broadphase."""
+
+ _NXN_MODEL = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+ # one world and zero collisions
+ mjm, _, m, d0 = _load_from_string(_NXN_MODEL, keyframe=0)
+ collision_driver.nxn_broadphase(m, d0)
+ np.testing.assert_allclose(d0.ncollision.numpy()[0], 0)
+
+ # one world and one collision
+ _, mjd1, _, d1 = _load_from_string(_NXN_MODEL, keyframe=1)
+ collision_driver.nxn_broadphase(m, d1)
+
+ np.testing.assert_allclose(d1.ncollision.numpy()[0], 1)
+ np.testing.assert_allclose(d1.collision_pair.numpy()[0][0], 0)
+ np.testing.assert_allclose(d1.collision_pair.numpy()[0][1], 1)
+
+ # one world and three collisions
+ _, mjd2, _, d2 = _load_from_string(_NXN_MODEL, keyframe=2)
+ collision_driver.nxn_broadphase(m, d2)
+ np.testing.assert_allclose(d2.ncollision.numpy()[0], 3)
+ np.testing.assert_allclose(d2.collision_pair.numpy()[0][0], 0)
+ np.testing.assert_allclose(d2.collision_pair.numpy()[0][1], 1)
+ np.testing.assert_allclose(d2.collision_pair.numpy()[1][0], 0)
+ np.testing.assert_allclose(d2.collision_pair.numpy()[1][1], 2)
+ np.testing.assert_allclose(d2.collision_pair.numpy()[2][0], 1)
+ np.testing.assert_allclose(d2.collision_pair.numpy()[2][1], 2)
+
+ # two worlds and four collisions
+ d3 = mjwarp.make_data(mjm, nworld=2)
+ d3.geom_xpos = wp.array(
+ np.vstack(
+ [np.expand_dims(mjd1.geom_xpos, axis=0), np.expand_dims(mjd2.geom_xpos, axis=0)]
+ ),
+ dtype=wp.vec3,
+ )
+
+ collision_driver.nxn_broadphase(m, d3)
+ np.testing.assert_allclose(d3.ncollision.numpy()[0], 4)
+ np.testing.assert_allclose(d3.collision_pair.numpy()[0][0], 0)
+ np.testing.assert_allclose(d3.collision_pair.numpy()[0][1], 1)
+ np.testing.assert_allclose(d3.collision_pair.numpy()[1][0], 0)
+ np.testing.assert_allclose(d3.collision_pair.numpy()[1][1], 1)
+ np.testing.assert_allclose(d3.collision_pair.numpy()[2][0], 0)
+ np.testing.assert_allclose(d3.collision_pair.numpy()[2][1], 2)
+ np.testing.assert_allclose(d3.collision_pair.numpy()[3][0], 1)
+ np.testing.assert_allclose(d3.collision_pair.numpy()[3][1], 2)
+
+ # one world and zero collisions: contype and conaffinity incompatibility
+ _, _, m4, d4 = _load_from_string(_NXN_MODEL, keyframe=1)
+ m4.geom_contype = wp.array(np.array([0, 0, 0]), dtype=wp.int32)
+ m4.geom_conaffinity = wp.array(np.array([1, 1, 1]), dtype=wp.int32)
+ collision_driver.nxn_broadphase(m4, d4)
+ np.testing.assert_allclose(d4.ncollision.numpy()[0], 0)
+
+ # one world and one collision: geomtype ordering
+ _, _, _, d5 = _load_from_string(_NXN_MODEL, keyframe=3)
+ collision_driver.nxn_broadphase(m, d5)
+ np.testing.assert_allclose(d5.ncollision.numpy()[0], 1)
+ np.testing.assert_allclose(d5.collision_pair.numpy()[0][0], 3)
+ np.testing.assert_allclose(d5.collision_pair.numpy()[0][1], 2)
+
+ # TODO(team): test margin
+ # TODO(team): test DisableBit.FILTERPARENT
+
+
+if __name__ == "__main__":
+ wp.init()
+ absltest.main()
diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py
new file mode 100644
index 00000000..28f300d6
--- /dev/null
+++ b/mujoco_warp/_src/collision_convex.py
@@ -0,0 +1,26 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import warp as wp
+
+from .types import Data
+from .types import Model
+
+# XXX disable backward pass codegen globally for now
+wp.config.enable_backward = False
+
+
+def narrowphase(m: Model, d: Data):
+ pass
diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py
new file mode 100644
index 00000000..4f5f7afe
--- /dev/null
+++ b/mujoco_warp/_src/collision_driver.py
@@ -0,0 +1,551 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import warp as wp
+
+from .types import MJ_MINVAL
+from .types import Data
+from .types import DisableBit
+from .types import Model
+from .types import vec5
+from .warp_util import event_scope
+
+
+@wp.func
+def _geom_filter(m: Model, geom1: int, geom2: int, filterparent: bool) -> bool:
+ bodyid1 = m.geom_bodyid[geom1]
+ bodyid2 = m.geom_bodyid[geom2]
+ contype1 = m.geom_contype[geom1]
+ contype2 = m.geom_contype[geom2]
+ conaffinity1 = m.geom_conaffinity[geom1]
+ conaffinity2 = m.geom_conaffinity[geom2]
+ weldid1 = m.body_weldid[bodyid1]
+ weldid2 = m.body_weldid[bodyid2]
+ weld_parentid1 = m.body_weldid[m.body_parentid[weldid1]]
+ weld_parentid2 = m.body_weldid[m.body_parentid[weldid2]]
+
+ self_collision = weldid1 == weldid2
+ parent_child_collision = (
+ filterparent
+ and (weldid1 != 0)
+ and (weldid2 != 0)
+ and ((weldid1 == weld_parentid2) or (weldid2 == weld_parentid1))
+ )
+ mask = (contype1 & conaffinity2) or (contype2 & conaffinity1)
+
+ return mask and (not self_collision) and (not parent_child_collision)
+
+
+@wp.func
+def _add_geom_pair(m: Model, d: Data, geom1: int, geom2: int, worldid: int):
+ pairid = wp.atomic_add(d.ncollision, 0, 1)
+
+ if pairid >= d.nconmax:
+ return
+
+ type1 = m.geom_type[geom1]
+ type2 = m.geom_type[geom2]
+
+ if type1 > type2:
+ pair = wp.vec2i(geom2, geom1)
+ else:
+ pair = wp.vec2i(geom1, geom2)
+
+ d.collision_pair[pairid] = pair
+ d.collision_worldid[pairid] = worldid
+
+
+@wp.kernel
+def broadphase_project_spheres_onto_sweep_direction_kernel(
+ m: Model,
+ d: Data,
+ direction: wp.vec3,
+):
+ worldId, i = wp.tid()
+
+ c = d.geom_xpos[worldId, i]
+ r = m.geom_rbound[i]
+ if r == 0.0:
+ # current geom is a plane
+ r = 1000000000.0
+ sphere_radius = r + m.geom_margin[i]
+
+ center = wp.dot(direction, c)
+ f = center - sphere_radius
+
+ # Store results in the data arrays
+ d.sap_projection_lower[worldId, i] = f
+ d.sap_projection_upper[worldId, i] = center + sphere_radius
+ d.sap_sort_index[worldId, i] = i
+
+
+# Define constants for plane types
+PLANE_ZERO_OFFSET = -1.0
+PLANE_NEGATIVE_OFFSET = -2.0
+PLANE_POSITIVE_OFFSET = -3.0
+
+
+@wp.func
+def encode_plane(normal: wp.vec3, point_on_plane: wp.vec3, margin: float) -> wp.vec4:
+ normal = wp.normalize(normal)
+ plane_offset = -wp.dot(normal, point_on_plane + normal * margin)
+
+ # Scale factor for the normal
+ scale = wp.abs(plane_offset)
+
+ # Handle special cases
+ if wp.abs(plane_offset) < 1e-6:
+ return wp.vec4(normal.x, normal.y, normal.z, PLANE_ZERO_OFFSET)
+ elif plane_offset < 0.0:
+ return wp.vec4(
+ scale * normal.x, scale * normal.y, scale * normal.z, PLANE_NEGATIVE_OFFSET
+ )
+ else:
+ return wp.vec4(
+ scale * normal.x, scale * normal.y, scale * normal.z, PLANE_POSITIVE_OFFSET
+ )
+
+
+@wp.func
+def decode_plane(encoded: wp.vec4) -> wp.vec4:
+ magnitude = wp.length(xyz(encoded))  # w holds only the plane-type tag
+ normal = wp.normalize(xyz(encoded))
+
+ if encoded.w == PLANE_ZERO_OFFSET:
+ return wp.vec4(normal.x, normal.y, normal.z, 0.0)
+ elif encoded.w == PLANE_NEGATIVE_OFFSET:
+ return wp.vec4(normal.x, normal.y, normal.z, -magnitude)
+ else:
+ return wp.vec4(normal.x, normal.y, normal.z, magnitude)
+
+
+@wp.kernel
+def reorder_bounding_spheres_kernel(
+ m: Model,
+ d: Data,
+):
+ worldId, i = wp.tid()
+
+ # Get the index from the data indexer
+ mapped = d.sap_sort_index[worldId, i]
+
+ # Get the bounding volume
+ c = d.geom_xpos[worldId, mapped]
+ r = m.geom_rbound[mapped]
+ margin = m.geom_margin[mapped]
+
+ # Reorder the box into the sorted array
+ if r == 0.0:
+ # store the plane equation
+ xmat = d.geom_xmat[worldId, mapped]
+ plane_normal = wp.vec3(xmat[0, 2], xmat[1, 2], xmat[2, 2])
+ d.sap_geom_sort[worldId, i] = encode_plane(
+ plane_normal, c, margin
+ ) # negative w component is used to distinguish planes from spheres
+ else:
+ d.sap_geom_sort[worldId, i] = wp.vec4(c.x, c.y, c.z, r + margin)
+
+
+@wp.func
+def xyz(v: wp.vec4) -> wp.vec3:
+ return wp.vec3(v.x, v.y, v.z)
+
+
+@wp.func
+def signed_distance_point_plane(point: wp.vec3, plane: wp.vec4) -> float:
+ return wp.dot(point, xyz(plane)) + plane.w
+
+
+@wp.func
+def overlap(
+ world_id: int,
+ a: int,
+ b: int,
+ spheres_or_planes: wp.array(dtype=wp.vec4, ndim=2),
+) -> bool:
+ # Extract centers and sizes
+ s_a = spheres_or_planes[world_id, a]
+ s_b = spheres_or_planes[world_id, b]
+
+ if s_a.w < 0.0 and s_b.w < 0.0:
+ # both are planes
+ return False
+ elif s_a.w < 0.0 or s_b.w < 0.0:
+ if s_b.w < 0.0: # swap if required such that s_a is always a plane
+ tmp = s_a
+ s_a = s_b
+ s_b = tmp
+ s_a = decode_plane(s_a)
+ dist = signed_distance_point_plane(xyz(s_b), s_a)
+ return dist <= s_b.w
+ else:
+ # geoms are spheres
+ delta = xyz(s_a) - xyz(s_b)
+ dist_sq = wp.dot(delta, delta)
+ radius_sum = s_a.w + s_b.w
+ return dist_sq <= radius_sum * radius_sum
+
+
+@wp.func
+def find_first_greater_than(
+ worldId: int,
+ starts: wp.array(dtype=wp.float32, ndim=2),
+ value: wp.float32,
+ low: int,
+ high: int,
+) -> int:
+ while low < high:
+ mid = (low + high) >> 1
+ if starts[worldId, mid] > value:
+ high = mid
+ else:
+ low = mid + 1
+ return low
+
+
+@wp.kernel
+def sap_broadphase_prepare_kernel(
+ m: Model,
+ d: Data,
+):
+ worldId, i = wp.tid() # Get the thread ID
+
+ # Get the index of the current bounding box
+ idx1 = d.sap_sort_index[worldId, i]
+
+ end = d.sap_projection_upper[worldId, idx1]
+ limit = find_first_greater_than(worldId, d.sap_projection_lower, end, i + 1, m.ngeom)
+ limit = wp.min(m.ngeom - 1, limit)
+
+ # Calculate the range of boxes for the sweep and prune process
+ count = limit - i
+
+ # Store the pair count for the current box; the cumulative sum is computed later by array_scan
+ d.sap_range[worldId, i] = count
+
+
+@wp.func
+def find_right_most_index_int(
+ starts: wp.array(dtype=wp.int32, ndim=1), value: wp.int32, low: int, high: int
+) -> int:
+ while low < high:
+ mid = (low + high) >> 1
+ if starts[mid] > value:
+ high = mid
+ else:
+ low = mid + 1
+ return high
+
+
+@wp.func
+def find_indices(
+ id: int, cumulative_sum: wp.array(dtype=wp.int32, ndim=1), length: int
+) -> wp.vec2i:
+ # Perform binary search to find the right most index
+ i = find_right_most_index_int(cumulative_sum, id, 0, length)
+
+ # Get the baseId, and compute the offset and j
+ if i > 0:
+ base_id = cumulative_sum[i - 1]
+ else:
+ base_id = 0
+ offset = id - base_id
+ j = i + offset + 1
+
+ return wp.vec2i(i, j)
+
+
+@wp.kernel
+def sap_broadphase_kernel(m: Model, d: Data, num_threads: int, filter_parent: bool):
+ threadId = wp.tid() # Get thread ID
+ if d.sap_cumulative_sum.shape[0] > 0:
+ total_num_work_packages = d.sap_cumulative_sum[d.sap_cumulative_sum.shape[0] - 1]
+ else:
+ total_num_work_packages = 0
+
+ while threadId < total_num_work_packages:
+ # Get indices for current and next box pair
+ ij = find_indices(threadId, d.sap_cumulative_sum, d.sap_cumulative_sum.shape[0])
+ i = ij.x
+ j = ij.y
+
+ worldId = i // m.ngeom
+ i = i % m.ngeom
+ j = j % m.ngeom
+
+ # geom index
+ idx1 = d.sap_sort_index[worldId, i]
+ idx2 = d.sap_sort_index[worldId, j]
+
+ if not _geom_filter(m, idx1, idx2, filter_parent):
+ threadId += num_threads
+ continue
+
+ # Check if the boxes overlap
+ if overlap(worldId, i, j, d.sap_geom_sort):
+ _add_geom_pair(m, d, idx1, idx2, worldId)
+
+ threadId += num_threads
+
+
+@wp.kernel
+def get_contact_solver_params_kernel(
+ m: Model,
+ d: Data,
+):
+ tid = wp.tid()
+
+ n_contact_pts = d.ncon[0]
+ if tid >= n_contact_pts:
+ return
+
+ geoms = d.contact.geom[tid]
+ g1 = geoms.x
+ g2 = geoms.y
+
+ margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
+ gap = wp.max(m.geom_gap[g1], m.geom_gap[g2])
+ solmix1 = m.geom_solmix[g1]
+ solmix2 = m.geom_solmix[g2]
+ mix = solmix1 / (solmix1 + solmix2)
+ mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix)
+ mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 >= MJ_MINVAL), 0.0, mix)
+ mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix)
+
+ p1 = m.geom_priority[g1]
+ p2 = m.geom_priority[g2]
+ mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0))
+
+ condim1 = m.geom_condim[g1]
+ condim2 = m.geom_condim[g2]
+ condim = wp.where(
+ p1 == p2, wp.max(condim1, condim2), wp.where(p1 > p2, condim1, condim2)
+ )
+ d.contact.dim[tid] = condim
+
+ if m.geom_solref[g1].x > 0.0 and m.geom_solref[g2].x > 0.0:
+ d.contact.solref[tid] = mix * m.geom_solref[g1] + (1.0 - mix) * m.geom_solref[g2]
+ else:
+ d.contact.solref[tid] = wp.min(m.geom_solref[g1], m.geom_solref[g2])
+ d.contact.includemargin[tid] = margin - gap
+ friction_ = wp.max(m.geom_friction[g1], m.geom_friction[g2])
+ friction5 = vec5(friction_[0], friction_[0], friction_[1], friction_[2], friction_[2])
+ d.contact.friction[tid] = friction5
+ d.contact.solimp[tid] = mix * m.geom_solimp[g1] + (1.0 - mix) * m.geom_solimp[g2]
+
+
+def sap_broadphase(m: Model, d: Data):
+ """Broadphase collision detection via sweep-and-prune."""
+
+ # Use an arbitrary fixed direction vector for now
+ direction = wp.vec3(0.5935, 0.7790, 0.1235)
+ direction = wp.normalize(direction)
+
+ wp.launch(
+ kernel=broadphase_project_spheres_onto_sweep_direction_kernel,
+ dim=(d.nworld, m.ngeom),
+ inputs=[m, d, direction],
+ )
+
+ tile_sort_available = False
+ segmented_sort_available = hasattr(wp.utils, "segmented_sort_pairs")
+
+ if tile_sort_available:
+ segmented_sort_kernel = create_segmented_sort_kernel(m.ngeom)
+ wp.launch_tiled(
+ kernel=segmented_sort_kernel, dim=(d.nworld), inputs=[m, d], block_dim=128
+ )
+ print("tile sort available")
+ elif segmented_sort_available:
+ wp.utils.segmented_sort_pairs(
+ d.sap_projection_lower,
+ d.sap_sort_index,
+ m.ngeom * d.nworld,
+ d.sap_segment_index,
+ )
+ else:
+ # Sort each world's segment separately
+ for world_id in range(d.nworld):
+ start_idx = world_id * m.ngeom
+
+ # Create temporary arrays for sorting
+ temp_box_projections_lower = wp.zeros(
+ m.ngeom * 2,
+ dtype=d.sap_projection_lower.dtype,
+ )
+ temp_box_sorting_indexer = wp.zeros(
+ m.ngeom * 2,
+ dtype=d.sap_sort_index.dtype,
+ )
+
+ # Copy data to temporary arrays
+ wp.copy(
+ temp_box_projections_lower,
+ d.sap_projection_lower,
+ 0,
+ start_idx,
+ m.ngeom,
+ )
+ wp.copy(
+ temp_box_sorting_indexer,
+ d.sap_sort_index,
+ 0,
+ start_idx,
+ m.ngeom,
+ )
+
+ # Sort the temporary arrays
+ wp.utils.radix_sort_pairs(
+ temp_box_projections_lower, temp_box_sorting_indexer, m.ngeom
+ )
+
+ # Copy sorted data back
+ wp.copy(
+ d.sap_projection_lower,
+ temp_box_projections_lower,
+ start_idx,
+ 0,
+ m.ngeom,
+ )
+ wp.copy(
+ d.sap_sort_index,
+ temp_box_sorting_indexer,
+ start_idx,
+ 0,
+ m.ngeom,
+ )
+
+ wp.launch(
+ kernel=reorder_bounding_spheres_kernel,
+ dim=(d.nworld, m.ngeom),
+ inputs=[m, d],
+ )
+
+ wp.launch(
+ kernel=sap_broadphase_prepare_kernel,
+ dim=(d.nworld, m.ngeom),
+ inputs=[m, d],
+ )
+
+ # The scan (scan = cumulative sum, either inclusive or exclusive depending on the last argument) is used for load balancing among the threads
+ wp.utils.array_scan(d.sap_range.reshape(-1), d.sap_cumulative_sum, True)
+
+ # Estimate how many overlap checks need to be done - assumes each box has to be compared to 5 other boxes (and batched over all worlds)
+ num_sweep_threads = 5 * d.nworld * m.ngeom
+ filter_parent = not (m.opt.disableflags & DisableBit.FILTERPARENT.value)
+ wp.launch(
+ kernel=sap_broadphase_kernel,
+ dim=num_sweep_threads,
+ inputs=[m, d, num_sweep_threads, filter_parent],
+ )
+
+ return d
+
+
+def nxn_broadphase(m: Model, d: Data):
+ """Broadphase collision detection via brute-force search."""
+ filterparent = not (m.opt.disableflags & DisableBit.FILTERPARENT.value)
+
+ @wp.kernel
+ def _nxn_broadphase(m: Model, d: Data):
+ worldid, elementid = wp.tid()
+ geom1 = (
+ m.ngeom
+ - 2
+ - int(
+ wp.sqrt(float(-8 * elementid + 4 * m.ngeom * (m.ngeom - 1) - 7)) / 2.0 - 0.5
+ )
+ )
+ geom2 = (
+ elementid
+ + geom1
+ + 1
+ - m.ngeom * (m.ngeom - 1) // 2
+ + (m.ngeom - geom1) * ((m.ngeom - geom1) - 1) // 2
+ )
+
+ margin1 = m.geom_margin[geom1]
+ margin2 = m.geom_margin[geom2]
+ pos1 = d.geom_xpos[worldid, geom1]
+ pos2 = d.geom_xpos[worldid, geom2]
+ size1 = m.geom_rbound[geom1]
+ size2 = m.geom_rbound[geom2]
+
+ bound = size1 + size2 + wp.max(margin1, margin2)
+ dif = pos2 - pos1
+
+ if size1 != 0.0 and size2 != 0.0:
+ # neither geom is a plane
+ dist_sq = wp.dot(dif, dif)
+ bounds_filter = dist_sq <= bound * bound
+ elif size1 == 0.0:
+ # geom1 is a plane
+ xmat1 = d.geom_xmat[worldid, geom1]
+ dist = wp.dot(dif, wp.vec3(xmat1[0, 2], xmat1[1, 2], xmat1[2, 2]))
+ bounds_filter = dist <= bound
+ else:
+ # geom2 is a plane
+ xmat2 = d.geom_xmat[worldid, geom2]
+ dist = wp.dot(-dif, wp.vec3(xmat2[0, 2], xmat2[1, 2], xmat2[2, 2]))
+ bounds_filter = dist <= bound
+
+ geom_filter = _geom_filter(m, geom1, geom2, filterparent)
+
+ if bounds_filter and geom_filter:
+ _add_geom_pair(m, d, geom1, geom2, worldid)
+
+ wp.launch(
+ _nxn_broadphase, dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2), inputs=[m, d]
+ )
+
+
+def get_contact_solver_params(m: Model, d: Data):
+ wp.launch(
+ get_contact_solver_params_kernel,
+ dim=[d.nconmax],
+ inputs=[m, d],
+ )
+
+ # TODO(team): do we need condim sorting, deepest penetrating contact here?
+
+
+@event_scope
+def collision(m: Model, d: Data):
+ """Collision detection."""
+
+ # AD: based on engine_collision_driver.py in Eric's warp fork/mjx-collisions-dev
+ # which is further based on the CUDA code here:
+ # https://github.com/btaba/mujoco/blob/warp-collisions/mjx/mujoco/mjx/_src/cuda/engine_collision_driver.cu.cc#L458-L583
+
+ d.ncollision.zero_()
+ d.ncon.zero_()
+
+ # TODO(team): determine ngeom to switch from n^2 to sap
+ if m.ngeom <= 100:
+ nxn_broadphase(m, d)
+ else:
+ sap_broadphase(m, d)
+
+ # XXX switch between collision functions and GJK/EPA here
+ if True:
+ from .collision_functions import narrowphase
+ else:
+ from .collision_convex import narrowphase
+
+ # TODO(team): should we limit per-world contact numbers?
+ # TODO(team): we should reject far-away contacts in the narrowphase instead of constraint
+ # partitioning because we can move some pressure off the atomics
+ narrowphase(m, d)
+ get_contact_solver_params(m, d)
diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py
new file mode 100644
index 00000000..5c1b3a62
--- /dev/null
+++ b/mujoco_warp/_src/collision_driver_test.py
@@ -0,0 +1,121 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests the collision driver."""
+
+import mujoco
+import numpy as np
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import mujoco_warp as mjwarp
+
+
+class PrimitiveTest(parameterized.TestCase):
+ """Tests the primitive contact functions."""
+
+ _MJCFS = {
+ "box_plane": """
+
+
+
+
+
+
+
+
+
+ """,
+ "plane_sphere": """
+
+
+
+
+
+
+
+
+
+ """,
+ "sphere_sphere": """
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ "capsule_capsule": """
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ "plane_capsule": """
+
+
+
+
+
+
+
+
+
+ """,
+ }
+
+ @parameterized.parameters(
+ "box_plane",
+ "plane_sphere",
+ "sphere_sphere",
+ "plane_capsule",
+ "capsule_capsule",
+ )
+ def test_contact(self, name):
+ """Tests contact calculation with different collision functions."""
+ m = mujoco.MjModel.from_xml_string(self._MJCFS[name])
+ d = mujoco.MjData(m)
+ mujoco.mj_forward(m, d)
+ mx = mjwarp.put_model(m)
+ dx = mjwarp.put_data(m, d)
+ mjwarp.collision(mx, dx)
+ mujoco.mj_collision(m, d)
+ self.assertEqual(d.ncon, dx.ncon.numpy()[0])
+ for i in range(d.ncon):
+ actual_dist = dx.contact.dist.numpy()[i]
+ actual_pos = dx.contact.pos.numpy()[i, :]
+ actual_frame = dx.contact.frame.numpy()[i].flatten()
+ np.testing.assert_array_almost_equal(actual_dist, d.contact.dist[i], 4)
+ np.testing.assert_array_almost_equal(actual_pos, d.contact.pos[i], 4)
+ np.testing.assert_array_almost_equal(actual_frame, d.contact.frame[i], 4)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/mujoco_warp/_src/collision_functions.py b/mujoco_warp/_src/collision_functions.py
new file mode 100644
index 00000000..23a713ed
--- /dev/null
+++ b/mujoco_warp/_src/collision_functions.py
@@ -0,0 +1,821 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import math
+from typing import Any
+
+import numpy as np
+import warp as wp
+
+from .math import closest_segment_to_segment_points
+from .math import make_frame
+from .math import normalize_with_norm
+from .types import Data
+from .types import GeomType
+from .types import Model
+
+BOX_BOX_BLOCK_DIM = 32
+
+@wp.struct
+class Geom:
+ pos: wp.vec3
+ rot: wp.mat33
+ normal: wp.vec3
+ size: wp.vec3
+ # TODO(team): mesh fields: vertadr, vertnum
+
+
+@wp.func
+def _geom(
+ gid: int,
+ m: Model,
+ geom_xpos: wp.array(dtype=wp.vec3),
+ geom_xmat: wp.array(dtype=wp.mat33),
+) -> Geom:
+ geom = Geom()
+ geom.pos = geom_xpos[gid]
+ rot = geom_xmat[gid]
+ geom.rot = rot
+ geom.size = m.geom_size[gid]
+ geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane
+
+ return geom
+
+
+@wp.func
+def write_contact(
+ d: Data,
+ dist: float,
+ pos: wp.vec3,
+ frame: wp.mat33,
+ margin: float,
+ geoms: wp.vec2i,
+ worldid: int,
+):
+ active = (dist - margin) < 0
+ if active:
+ index = wp.atomic_add(d.ncon, 0, 1)
+ if index < d.nconmax:
+ d.contact.dist[index] = dist
+ d.contact.pos[index] = pos
+ d.contact.frame[index] = frame
+ d.contact.geom[index] = geoms
+ d.contact.worldid[index] = worldid
+
+
+@wp.func
+def _plane_sphere(
+ plane_normal: wp.vec3, plane_pos: wp.vec3, sphere_pos: wp.vec3, sphere_radius: float
+):
+ dist = wp.dot(sphere_pos - plane_pos, plane_normal) - sphere_radius
+ pos = sphere_pos - plane_normal * (sphere_radius + 0.5 * dist)
+ return dist, pos
+
+
+@wp.func
+def plane_sphere(
+ plane: Geom,
+ sphere: Geom,
+ worldid: int,
+ d: Data,
+ margin: float,
+ geom_indices: wp.vec2i,
+):
+ dist, pos = _plane_sphere(plane.normal, plane.pos, sphere.pos, sphere.size[0])
+
+ write_contact(d, dist, pos, make_frame(plane.normal), margin, geom_indices, worldid)
+
+
+@wp.func
+def _sphere_sphere(
+ pos1: wp.vec3,
+ radius1: float,
+ pos2: wp.vec3,
+ radius2: float,
+ worldid: int,
+ d: Data,
+ margin: float,
+ geom_indices: wp.vec2i,
+):
+ dir = pos2 - pos1
+ dist = wp.length(dir)
+ if dist == 0.0:
+ n = wp.vec3(1.0, 0.0, 0.0)
+ else:
+ n = dir / dist
+ dist = dist - (radius1 + radius2)
+ pos = pos1 + n * (radius1 + 0.5 * dist)
+
+ write_contact(d, dist, pos, make_frame(n), margin, geom_indices, worldid)
+
+
+@wp.func
+def sphere_sphere(
+ sphere1: Geom,
+ sphere2: Geom,
+ worldid: int,
+ d: Data,
+ margin: float,
+ geom_indices: wp.vec2i,
+):
+ _sphere_sphere(
+ sphere1.pos,
+ sphere1.size[0],
+ sphere2.pos,
+ sphere2.size[0],
+ worldid,
+ d,
+ margin,
+ geom_indices,
+ )
+
+
+@wp.func
+def capsule_capsule(
+ cap1: Geom,
+ cap2: Geom,
+ worldid: int,
+ d: Data,
+ margin: float,
+ geom_indices: wp.vec2i,
+):
+ axis1 = wp.vec3(cap1.rot[0, 2], cap1.rot[1, 2], cap1.rot[2, 2])
+ axis2 = wp.vec3(cap2.rot[0, 2], cap2.rot[1, 2], cap2.rot[2, 2])
+ length1 = cap1.size[1]
+ length2 = cap2.size[1]
+ seg1 = axis1 * length1
+ seg2 = axis2 * length2
+
+ pt1, pt2 = closest_segment_to_segment_points(
+ cap1.pos - seg1,
+ cap1.pos + seg1,
+ cap2.pos - seg2,
+ cap2.pos + seg2,
+ )
+
+ _sphere_sphere(pt1, cap1.size[0], pt2, cap2.size[0], worldid, d, margin, geom_indices)
+
+
+@wp.func
+def plane_capsule(
+ plane: Geom,
+ cap: Geom,
+ worldid: int,
+ d: Data,
+ margin: float,
+ geom_indices: wp.vec2i,
+):
+ """Calculates two contacts between a capsule and a plane."""
+ n = plane.normal
+ axis = wp.vec3(cap.rot[0, 2], cap.rot[1, 2], cap.rot[2, 2])
+ # align contact frames with capsule axis
+ b, b_norm = normalize_with_norm(axis - n * wp.dot(n, axis))
+
+ if b_norm < 0.5:
+ if -0.5 < n[1] and n[1] < 0.5:
+ b = wp.vec3(0.0, 1.0, 0.0)
+ else:
+ b = wp.vec3(0.0, 0.0, 1.0)
+
+ c = wp.cross(n, b)
+ frame = wp.mat33(n[0], n[1], n[2], b[0], b[1], b[2], c[0], c[1], c[2])
+ segment = axis * cap.size[1]
+
+ dist1, pos1 = _plane_sphere(n, plane.pos, cap.pos + segment, cap.size[0])
+ write_contact(d, dist1, pos1, frame, margin, geom_indices, worldid)
+
+ dist2, pos2 = _plane_sphere(n, plane.pos, cap.pos - segment, cap.size[0])
+ write_contact(d, dist2, pos2, frame, margin, geom_indices, worldid)
+
+
+HUGE_VAL = 1e6
+TINY_VAL = 1e-6
+
+
+class vec16b(wp.types.vector(length=16, dtype=wp.int8)):
+ pass
+
+
+class mat43f(wp.types.matrix(shape=(4, 3), dtype=wp.float32)):
+ pass
+
+
+class mat83f(wp.types.matrix(shape=(8, 3), dtype=wp.float32)):
+ pass
+
+
+class mat16_3f(wp.types.matrix(shape=(16, 3), dtype=wp.float32)):
+ pass
+
+
+Box = mat83f
+
+
+@wp.func
+def _argmin(a: Any) -> wp.int32:
+ amin = wp.int32(0)
+ vmin = wp.float32(a[0])
+ for i in range(1, len(a)):
+ if a[i] < vmin:
+ amin = i
+ vmin = a[i]
+ return amin
+
+
+@wp.func
+def box_normals(i: int) -> wp.vec3:
+ direction = wp.where(i < 3, -1.0, 1.0)
+ mod = i % 3
+ if mod == 0:
+ return wp.vec3(0.0, direction, 0.0)
+ if mod == 1:
+ return wp.vec3(0.0, 0.0, direction)
+ return wp.vec3(-direction, 0.0, 0.0)
+
+
+@wp.func
+def box(R: wp.mat33, t: wp.vec3, geom_size: wp.vec3) -> Box:
+ """Get a transformed box"""
+ x = geom_size[0]
+ y = geom_size[1]
+ z = geom_size[2]
+ m = Box()
+ for i in range(8):
+ ix = wp.where(i & 4, x, -x)
+ iy = wp.where(i & 2, y, -y)
+ iz = wp.where(i & 1, z, -z)
+ m[i] = R @ wp.vec3(ix, iy, iz) + t
+ return m
+
+
+@wp.func
+def box_face_verts(box: Box, idx: wp.int32) -> mat43f:
+ """Get the quad corresponding to a box face"""
+ if idx == 0:
+ verts = wp.vec4i(0, 4, 5, 1)
+ if idx == 1:
+ verts = wp.vec4i(0, 2, 6, 4)
+ if idx == 2:
+ verts = wp.vec4i(6, 7, 5, 4)
+ if idx == 3:
+ verts = wp.vec4i(2, 3, 7, 6)
+ if idx == 4:
+ verts = wp.vec4i(1, 5, 7, 3)
+ if idx == 5:
+ verts = wp.vec4i(0, 1, 3, 2)
+
+ m = mat43f()
+ for i in range(4):
+ m[i] = box[verts[i]]
+ return m
+
+
+@wp.func
+def get_box_axis(
+ axis_idx: int,
+ R: wp.mat33,
+):
+ """Get the axis at index axis_idx.
+ R: rotation matrix from a to b
+ Axes 0-11 are face normals of boxes a & b
+ Axes 12-20 are edge cross products."""
+ if axis_idx < 6: # a faces
+ axis = R @ wp.vec3(box_normals(axis_idx))
+ is_degenerate = False
+ elif axis_idx < 12: # b faces
+ axis = wp.vec3(box_normals(axis_idx - 6))
+ is_degenerate = False
+ else: # edges cross products
+ assert axis_idx < 21
+ edges = axis_idx - 12
+ axis_a, axis_b = edges // 3, edges % 3
+ edge_a = wp.transpose(R)[axis_a]
+ if axis_b == 0:
+ axis = wp.vec3(0.0, -edge_a[2], edge_a[1])
+ elif axis_b == 1:
+ axis = wp.vec3(edge_a[2], 0.0, -edge_a[0])
+ else:
+ axis = wp.vec3(-edge_a[1], edge_a[0], 0.0)
+ is_degenerate = wp.length_sq(axis) < TINY_VAL
+ return wp.normalize(axis), is_degenerate
+
+
+@wp.func
+def get_box_axis_support(
+ axis: wp.vec3, degenerate_axis: bool, a: Box, b: Box
+):
+ """Get the overlap (or separating distance if negative) along `axis`, and the sign."""
+ axis_d = wp.vec3d(axis)
+ support_a_max, support_b_max = wp.float32(-HUGE_VAL), wp.float32(-HUGE_VAL)
+ support_a_min, support_b_min = wp.float32(HUGE_VAL), wp.float32(HUGE_VAL)
+ for i in range(8):
+ vert_a = wp.vec3d(a[i])
+ vert_b = wp.vec3d(b[i])
+ proj_a = wp.float32(wp.dot(vert_a, axis_d))
+ proj_b = wp.float32(wp.dot(vert_b, axis_d))
+ support_a_max = wp.max(support_a_max, proj_a)
+ support_b_max = wp.max(support_b_max, proj_b)
+ support_a_min = wp.min(support_a_min, proj_a)
+ support_b_min = wp.min(support_b_min, proj_b)
+ dist1 = support_a_max - support_b_min
+ dist2 = support_b_max - support_a_min
+ dist = wp.where(degenerate_axis, HUGE_VAL, wp.min(dist1, dist2))
+ sign = wp.where(dist1 > dist2, -1, 1)
+ return dist, sign
+
+
+@wp.struct
+class AxisSupport:
+ best_dist: wp.float32
+ best_sign: wp.int8
+ best_idx: wp.int8
+
+
+@wp.func
+def reduce_axis_support(a: AxisSupport, b: AxisSupport):
+ return wp.where(a.best_dist > b.best_dist, b, a)
+
+
+@wp.func
+def face_axis_alignment(a: wp.vec3, R: wp.mat33) -> wp.int32:
+ """Find the box face most aligned with the axis `a`"""
+ max_dot = wp.float32(0.0)
+ max_idx = wp.int32(0)
+ for i in range(6):
+ d = wp.dot(R @ box_normals(i), a)
+ if d > max_dot:
+ max_dot = d
+ max_idx = i
+ return max_idx
+
+
+@wp.kernel(enable_backward=False)
+def box_box_kernel(
+ m: Model,
+ d: Data,
+ num_kernels: int,
+):
+ """Calculates contacts between pairs of boxes."""
+ tid, axis_idx = wp.tid()
+
+ for bp_idx in range(tid, min(d.ncollision[0], d.nconmax), num_kernels):
+ geoms = d.collision_pair[bp_idx]
+
+ ga, gb = geoms[0], geoms[1]
+
+ if m.geom_type[ga] != int(GeomType.BOX.value) or m.geom_type[gb] != int(GeomType.BOX.value):
+ continue
+
+ worldid = d.collision_worldid[bp_idx]
+ # transformations
+ a_pos, b_pos = d.geom_xpos[worldid, ga], d.geom_xpos[worldid, gb]
+ a_mat, b_mat = d.geom_xmat[worldid, ga], d.geom_xmat[worldid, gb]
+ b_mat_inv = wp.transpose(b_mat)
+ trans_atob = b_mat_inv @ (a_pos - b_pos)
+ rot_atob = b_mat_inv @ a_mat
+
+ a_size = m.geom_size[ga]
+ b_size = m.geom_size[gb]
+ a = box(rot_atob, trans_atob, a_size)
+ b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size)
+
+ # box-box implementation
+
+ # Inlined def collision_axis_tiled( a: Box, b: Box, R: wp.mat33, axis_idx: wp.int32,):
+ # Finds the axis of minimum separation.
+ # a: Box a vertices, in frame b
+ # b: Box b vertices, in frame b
+ # R: rotation matrix from a to b
+ # Returns:
+ # best_axis: vec3
+ # best_sign: int32
+ # best_idx: int32
+ R = rot_atob
+
+ # launched tiled with block_dim=BOX_BOX_BLOCK_DIM; only the first 21 threads map to a candidate axis
+ if axis_idx > 20:
+ continue
+
+ axis, degenerate_axis = get_box_axis(axis_idx, R)
+ axis_dist, axis_sign = get_box_axis_support(axis, degenerate_axis, a, b)
+
+ supports = wp.tile(AxisSupport(axis_dist, wp.int8(axis_sign), wp.int8(axis_idx)))
+
+ face_supports = wp.tile_view(supports, offset=(0,), shape=(12,))
+ edge_supports = wp.tile_view(supports, offset=(12,), shape=(9,))
+
+ face_supports_red = wp.tile_reduce(reduce_axis_support, face_supports)
+ edge_supports_red = wp.tile_reduce(reduce_axis_support, edge_supports)
+
+ face = face_supports_red[0]
+ edge = edge_supports_red[0]
+
+ if axis_idx > 0: # single thread
+ continue
+
+ # choose the best separating axis
+ face_axis, _ = get_box_axis(wp.int32(face.best_idx), R)
+ best_axis = wp.vec3(face_axis)
+ best_sign = wp.int32(face.best_sign)
+ best_idx = wp.int32(face.best_idx)
+ best_dist = wp.float32(face.best_dist)
+
+ if edge.best_dist < face.best_dist:
+ edge_axis, _ = get_box_axis(wp.int32(edge.best_idx), R)
+ if wp.abs(wp.dot(face_axis, edge_axis)) < 0.99:
+ best_axis = edge_axis
+ best_sign = wp.int32(edge.best_sign)
+ best_idx = wp.int32(edge.best_idx)
+ best_dist = wp.float32(edge.best_dist)
+ # end inlined collision_axis_tiled
+
+ # if axis_idx != 0:
+ # continue
+ if best_dist < 0:
+ continue
+
+ # get the (reference) face most aligned with the separating axis
+ a_max = face_axis_alignment(best_axis, rot_atob)
+ b_max = face_axis_alignment(best_axis, wp.identity(3, wp.float32))
+
+ sep_axis = wp.float32(best_sign) * best_axis
+
+ if best_sign > 0:
+ b_min = (b_max + 3) % 6
+ dist, pos = _create_contact_manifold(
+ box_face_verts(a, a_max),
+ rot_atob @ box_normals(a_max),
+ box_face_verts(b, b_min),
+ box_normals(b_min),
+ )
+ else:
+ a_min = (a_max + 3) % 6
+ dist, pos = _create_contact_manifold(
+ box_face_verts(b, b_max),
+ box_normals(b_max),
+ box_face_verts(a, a_min),
+ rot_atob @ box_normals(a_min),
+ )
+
+ # For edge contacts, we use the clipped face point, mainly for performance
+ # reasons. For small penetration, the clipped face point is roughly the edge
+ # contact point.
+ if best_idx > 11: # is_edge_contact
+ idx = _argmin(dist)
+ dist = wp.vec4f(dist[idx], 1.0, 1.0, 1.0)
+ for i in range(4):
+ pos[i] = pos[idx]
+
+ margin = wp.max(m.geom_margin[ga], m.geom_margin[gb])
+ for i in range(4):
+ pos_glob = b_mat @ pos[i] + b_pos
+ n_glob = b_mat @ sep_axis
+ write_contact(
+ d, dist[i], pos_glob, make_frame(n_glob), margin, wp.vec2i(ga, gb), worldid
+ )
+
+
+def box_box(
+ m: Model,
+ d: Data,
+):
+ """Calculates contacts between pairs of boxes."""
+ kernel_ratio = 16
+ num_threads = math.ceil(
+ d.nconmax / kernel_ratio
+ ) # parallel threads excluding tile dim
+ wp.launch_tiled(
+ kernel=box_box_kernel,
+ dim=num_threads,
+ inputs=[m, d, num_threads],
+ block_dim=BOX_BOX_BLOCK_DIM,
+ )
+
+
+@wp.func
+def _closest_segment_point_plane(
+ a: wp.vec3, b: wp.vec3, p0: wp.vec3, plane_normal: wp.vec3
+) -> wp.vec3:
+ """Gets the closest point between a line segment and a plane.
+
+ Args:
+ a: first line segment point
+ b: second line segment point
+ p0: point on plane
+ plane_normal: plane normal
+
+ Returns:
+ closest point between the line segment and the plane
+ """
+ # Parametrize a line segment as S(t) = a + t * (b - a), plug it into the plane
+ # equation dot(n, S(t)) - d = 0, then solve for t to get the line-plane
+ # intersection. We then clip t to be in [0, 1] to be on the line segment.
+ n = plane_normal
+ d = wp.dot(p0, n) # shortest distance from origin to plane
+ denom = wp.dot(n, (b - a))
+ t = (d - wp.dot(n, a)) / (denom + wp.where(denom == 0.0, TINY_VAL, 0.0))
+ t = wp.clamp(t, 0.0, 1.0)
+ segment_point = a + t * (b - a)
+
+ return segment_point
+
+
+@wp.func
+def _project_poly_onto_plane(
+ poly: Any,
+ poly_n: wp.vec3,
+ plane_n: wp.vec3,
+ plane_pt: wp.vec3,
+):
+ """Projects poly onto the plane given by plane_n and plane_pt, moving points along poly_n."""
+ d = wp.dot(plane_pt, plane_n)
+ denom = wp.dot(poly_n, plane_n)
+ qn_scaled = poly_n / (denom + wp.where(denom == 0.0, TINY_VAL, 0.0))
+
+ for i in range(len(poly)):
+ poly[i] = poly[i] + (d - wp.dot(poly[i], plane_n)) * qn_scaled
+ return poly
+
+
+@wp.func
+def _clip_edge_to_quad(
+ subject_poly: mat43f,
+ clipping_poly: mat43f,
+ clipping_normal: wp.vec3,
+):
+ p0 = mat43f()
+ p1 = mat43f()
+ mask = wp.vec4b()
+ for edge_idx in range(4):
+ subject_p0 = subject_poly[(edge_idx + 3) % 4]
+ subject_p1 = subject_poly[edge_idx]
+
+ any_both_in_front = wp.int32(0)
+ clipped0_dist_max = wp.float32(-HUGE_VAL)
+ clipped1_dist_max = wp.float32(-HUGE_VAL)
+ clipped_p0_distmax = wp.vec3(0.0)
+ clipped_p1_distmax = wp.vec3(0.0)
+
+ for clipping_edge_idx in range(4):
+ clipping_p0 = clipping_poly[(clipping_edge_idx + 3) % 4]
+ clipping_p1 = clipping_poly[clipping_edge_idx]
+ edge_normal = wp.cross(clipping_p1 - clipping_p0, clipping_normal)
+
+ p0_in_front = wp.dot(subject_p0 - clipping_p0, edge_normal) > TINY_VAL
+ p1_in_front = wp.dot(subject_p1 - clipping_p0, edge_normal) > TINY_VAL
+ candidate_clipped_p = _closest_segment_point_plane(
+ subject_p0, subject_p1, clipping_p1, edge_normal
+ )
+ clipped_p0 = wp.where(p0_in_front, candidate_clipped_p, subject_p0)
+ clipped_p1 = wp.where(p1_in_front, candidate_clipped_p, subject_p1)
+ clipped_dist_p0 = wp.dot(clipped_p0 - subject_p0, subject_p1 - subject_p0)
+ clipped_dist_p1 = wp.dot(clipped_p1 - subject_p1, subject_p0 - subject_p1)
+ any_both_in_front |= wp.int32(p0_in_front and p1_in_front)
+
+ if clipped_dist_p0 > clipped0_dist_max:
+ clipped0_dist_max = clipped_dist_p0
+ clipped_p0_distmax = clipped_p0
+
+ if clipped_dist_p1 > clipped1_dist_max:
+ clipped1_dist_max = clipped_dist_p1
+ clipped_p1_distmax = clipped_p1
+ new_p0 = wp.where(any_both_in_front, subject_p0, clipped_p0_distmax)
+ new_p1 = wp.where(any_both_in_front, subject_p1, clipped_p1_distmax)
+
+ mask_val = wp.int8(
+ wp.where(
+ wp.dot(subject_p0 - subject_p1, new_p0 - new_p1) < 0,
+ 0,
+ wp.int32(not any_both_in_front),
+ )
+ )
+
+ p0[edge_idx] = new_p0
+ p1[edge_idx] = new_p1
+ mask[edge_idx] = mask_val
+ return p0, p1, mask
+
+
+@wp.func
+def _clip_quad(
+ subject_quad: mat43f,
+ subject_normal: wp.vec3,
+ clipping_quad: mat43f,
+ clipping_normal: wp.vec3,
+):
+ """Clips a subject quad against a clipping quad.
+ Serial implementation.
+ """
+
+ subject_clipped_p0, subject_clipped_p1, subject_mask = _clip_edge_to_quad(
+ subject_quad, clipping_quad, clipping_normal
+ )
+ clipping_proj = _project_poly_onto_plane(
+ clipping_quad, clipping_normal, subject_normal, subject_quad[0]
+ )
+ clipping_clipped_p0, clipping_clipped_p1, clipping_mask = _clip_edge_to_quad(
+ clipping_proj, subject_quad, subject_normal
+ )
+
+ clipped = mat16_3f()
+ mask = vec16b()
+ for i in range(4):
+ clipped[i] = subject_clipped_p0[i]
+ clipped[i + 4] = clipping_clipped_p0[i]
+ clipped[i + 8] = subject_clipped_p1[i]
+ clipped[i + 12] = clipping_clipped_p1[i]
+ mask[i] = subject_mask[i]
+ mask[i + 4] = clipping_mask[i]
+ mask[i + 8] = subject_mask[i]
+ mask[i + 8 + 4] = clipping_mask[i]
+
+ return clipped, mask
+
+
+# TODO(ca): sparse, tiling variant for large point counts
+@wp.func
+def _manifold_points(
+ poly: Any,
+ mask: Any,
+ clipping_norm: wp.vec3,
+) -> wp.vec4b:
+ """Chooses four points on the polygon with approximately maximal area and returns their indices."""
+ n = len(poly)
+
+ a_idx = wp.int32(0)
+ a_mask = wp.int8(mask[0])
+ for i in range(n):
+ if mask[i] >= a_mask:
+ a_idx = i
+ a_mask = mask[i]
+ a = poly[a_idx]
+
+ b_idx = wp.int32(0)
+ b_dist = wp.float32(-HUGE_VAL)
+ for i in range(n):
+ dist = wp.length_sq(poly[i] - a) + wp.where(mask[i], 0.0, -HUGE_VAL)
+ if dist >= b_dist:
+ b_idx = i
+ b_dist = dist
+ b = poly[b_idx]
+
+ ab = wp.cross(clipping_norm, a - b)
+
+ c_idx = wp.int32(0)
+ c_dist = wp.float32(-HUGE_VAL)
+ for i in range(n):
+ ap = a - poly[i]
+ dist = wp.abs(wp.dot(ap, ab)) + wp.where(mask[i], 0.0, -HUGE_VAL)
+ if dist >= c_dist:
+ c_idx = i
+ c_dist = dist
+ c = poly[c_idx]
+
+ ac = wp.cross(clipping_norm, a - c)
+ bc = wp.cross(clipping_norm, b - c)
+
+ d_idx = wp.int32(0)
+ d_dist = wp.float32(-2.0 * HUGE_VAL)
+ for i in range(n):
+ ap = a - poly[i]
+ dist_ap = wp.abs(wp.dot(ap, ac)) + wp.where(mask[i], 0.0, -HUGE_VAL)
+ bp = b - poly[i]
+ dist_bp = wp.abs(wp.dot(bp, bc)) + wp.where(mask[i], 0.0, -HUGE_VAL)
+ if dist_ap + dist_bp >= d_dist:
+ d_idx = i
+ d_dist = dist_ap + dist_bp
+ d = poly[d_idx]
+ return wp.vec4b(wp.int8(a_idx), wp.int8(b_idx), wp.int8(c_idx), wp.int8(d_idx))
+
+
+@wp.func
+def _create_contact_manifold(
+ clipping_quad: mat43f,
+ clipping_normal: wp.vec3,
+ subject_quad: mat43f,
+ subject_normal: wp.vec3,
+):
+ # Clip the subject (incident) face onto the clipping (reference) face.
+ # The incident points are clipped points on the subject polygon.
+ incident, mask = _clip_quad(
+ subject_quad, subject_normal, clipping_quad, clipping_normal
+ )
+
+ clipping_normal_neg = -clipping_normal
+ d = wp.dot(clipping_quad[0], clipping_normal_neg) + TINY_VAL
+
+ for i in range(16):
+ if wp.dot(incident[i], clipping_normal_neg) < d:
+ mask[i] = wp.int8(0)
+
+ ref = _project_poly_onto_plane(
+ incident, clipping_normal, clipping_normal, clipping_quad[0]
+ )
+
+ # Choose four contact points.
+ best = _manifold_points(ref, mask, clipping_normal)
+ contact_pts = mat43f()
+ dist = wp.vec4f()
+
+ for i in range(4):
+ idx = wp.int32(best[i])
+ contact_pt = ref[idx]
+ contact_pts[i] = contact_pt
+ penetration_dir = incident[idx] - contact_pt
+ penetration = wp.dot(penetration_dir, clipping_normal)
+ dist[i] = wp.where(mask[idx], penetration, 1.0)
+
+ return dist, contact_pts
+
+
+@wp.func
+def plane_box(
+ plane: Geom,
+ box: Geom,
+ worldid: int,
+ d: Data,
+ margin: float,
+ geom_indices: wp.vec2i,
+):
+ count = int(0)
+ corner = wp.vec3()
+ dist = wp.dot(box.pos - plane.pos, plane.normal)
+
+ # test all corners, pick bottom 4
+ for i in range(8):
+ # get corner in local coordinates
+ corner.x = wp.where(i & 1, box.size.x, -box.size.x)
+ corner.y = wp.where(i & 2, box.size.y, -box.size.y)
+ corner.z = wp.where(i & 4, box.size.z, -box.size.z)
+
+ # get corner in global coordinates relative to box center
+ corner = box.rot * corner
+
+ # compute distance to plane, skip if too far or pointing up
+ ldist = wp.dot(plane.normal, corner)
+ if dist + ldist > margin or ldist > 0:
+ continue
+
+ cdist = dist + ldist
+ frame = make_frame(plane.normal)
+ pos = corner + box.pos + (plane.normal * cdist / -2.0)
+ write_contact(d, cdist, pos, frame, margin, geom_indices, worldid)
+ count += 1
+ if count >= 4:
+ break
+
+
+@wp.kernel
+def _narrowphase(
+ m: Model,
+ d: Data,
+):
+ tid = wp.tid()
+
+ if tid >= d.ncollision[0]:
+ return
+
+ geoms = d.collision_pair[tid]
+ worldid = d.collision_worldid[tid]
+
+ g1 = geoms[0]
+ g2 = geoms[1]
+ type1 = m.geom_type[g1]
+ type2 = m.geom_type[g2]
+
+ geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
+ geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid])
+
+ margin = wp.max(m.geom_margin[g1], m.geom_margin[g2])
+
+ # TODO(team): static loop unrolling to remove unnecessary branching
+ if type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.SPHERE.value):
+ plane_sphere(geom1, geom2, worldid, d, margin, geoms)
+ elif type1 == int(GeomType.SPHERE.value) and type2 == int(GeomType.SPHERE.value):
+ sphere_sphere(geom1, geom2, worldid, d, margin, geoms)
+ elif type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.CAPSULE.value):
+ plane_capsule(geom1, geom2, worldid, d, margin, geoms)
+ elif type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.BOX.value):
+ plane_box(geom1, geom2, worldid, d, margin, geoms)
+ elif type1 == int(GeomType.CAPSULE.value) and type2 == int(GeomType.CAPSULE.value):
+ capsule_capsule(geom1, geom2, worldid, d, margin, geoms)
+
+
+def narrowphase(m: Model, d: Data):
+ # TODO(team): keep the overhead of this launch small: avoid launching anything for
+ # pair types without collisions, and update the launch dimensions accordingly.
+ wp.launch(_narrowphase, dim=d.nconmax, inputs=[m, d])
+ box_box(m, d)
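Review note: the corner test in `plane_box` above can be sketched in plain Python, outside Warp. This is only an illustration under stated assumptions, not the kernel itself. The function name `plane_box_contacts`, the list-based vectors, and the returned `(distance, position)` tuples are hypothetical; the kernel writes its results through `write_contact` instead.

```python
def plane_box_contacts(plane_pos, plane_normal, box_pos, box_rot, box_size, margin):
    """Hypothetical scalar sketch of the corner loop in `plane_box`.

    Vectors are 3-element lists, `box_rot` is a 3x3 row-major list of lists.
    Returns up to four (distance, position) tuples.
    """

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # signed distance from the plane to the box center
    center_dist = dot([b - p for b, p in zip(box_pos, plane_pos)], plane_normal)

    contacts = []
    for i in range(8):
        # bits 0, 1, 2 of i select the sign of each half-extent
        local = [
            box_size[0] if i & 1 else -box_size[0],
            box_size[1] if i & 2 else -box_size[1],
            box_size[2] if i & 4 else -box_size[2],
        ]
        # rotate the corner into the global frame, still relative to the box center
        corner = [dot(row, local) for row in box_rot]

        # skip corners that are farther than the margin or on the upper side of the box
        ldist = dot(plane_normal, corner)
        if center_dist + ldist > margin or ldist > 0:
            continue

        cdist = center_dist + ldist
        # contact position: the corner shifted by half the distance along the normal
        pos = [
            c + b - n * cdist / 2.0
            for c, b, n in zip(corner, box_pos, plane_normal)
        ]
        contacts.append((cdist, pos))
        if len(contacts) >= 4:
            break
    return contacts
```

With an axis-aligned unit box whose center sits 0.4 above a ground plane, the four lower corners are reported and the four upper corners are skipped by the `ldist > 0` test.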
diff --git a/mujoco/mjx/_src/constraint.py b/mujoco_warp/_src/constraint.py
similarity index 86%
rename from mujoco/mjx/_src/constraint.py
rename to mujoco_warp/_src/constraint.py
index 6cfd2c6b..8e783aa1 100644
--- a/mujoco/mjx/_src/constraint.py
+++ b/mujoco_warp/_src/constraint.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,9 @@
# ==============================================================================
import warp as wp
+
from . import types
+from .warp_util import event_scope
@wp.func
@@ -53,22 +55,22 @@ def _update_efc_row(
# See https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters
k = 1.0 / (dmax * dmax * timeconst * timeconst * dampratio * dampratio)
b = 2.0 / (dmax * timeconst)
- k = wp.select(solref[0] <= 0, k, -solref[0] / (dmax * dmax))
- b = wp.select(solref[1] <= 0, b, -solref[1] / dmax)
+ k = wp.where(solref[0] <= 0, -solref[0] / (dmax * dmax), k)
+ b = wp.where(solref[1] <= 0, -solref[1] / dmax, b)
imp_x = wp.abs(pos_imp) / width
imp_a = (1.0 / wp.pow(mid, power - 1.0)) * wp.pow(imp_x, power)
imp_b = 1.0 - (1.0 / wp.pow(1.0 - mid, power - 1.0)) * wp.pow(1.0 - imp_x, power)
- imp_y = wp.select(imp_x < mid, imp_b, imp_a)
+ imp_y = wp.where(imp_x < mid, imp_a, imp_b)
imp = dmin + imp_y * (dmax - dmin)
imp = wp.clamp(imp, dmin, dmax)
- imp = wp.select(imp_x > 1.0, imp, dmax)
+ imp = wp.where(imp_x > 1.0, dmax, imp)
# Update constraints
- d.efc_D[efcid] = 1.0 / wp.max(invweight * (1.0 - imp) / imp, types.MJ_MINVAL)
- d.efc_aref[efcid] = -k * imp * pos_aref - b * Jqvel
- d.efc_pos[efcid] = pos_aref + margin
- d.efc_margin[efcid] = margin
+ d.efc.D[efcid] = 1.0 / wp.max(invweight * (1.0 - imp) / imp, types.MJ_MINVAL)
+ d.efc.aref[efcid] = -k * imp * pos_aref - b * Jqvel
+ d.efc.pos[efcid] = pos_aref + margin
+ d.efc.margin[efcid] = margin
@wp.func
@@ -114,13 +116,13 @@ def _efc_limit_slide_hinge(
active = pos < 0
if active:
- efcid = wp.atomic_add(d.nefc_total, 0, 1)
- d.efc_worldid[efcid] = worldid
+ efcid = wp.atomic_add(d.nefc, 0, 1)
+ d.efc.worldid[efcid] = worldid
dofadr = m.jnt_dofadr[jntid]
J = float(dist_min < dist_max) * 2.0 - 1.0
- d.efc_J[efcid, dofadr] = J
+ d.efc.J[efcid, dofadr] = J
Jqvel = J * d.qvel[worldid, dofadr]
_update_efc_row(
@@ -147,7 +149,7 @@ def _efc_contact_pyramidal(
):
conid, dimid = wp.tid()
- if conid >= d.ncon_total[0]:
+ if conid >= d.ncon[0]:
return
if d.contact.dim[conid] != 3:
@@ -157,9 +159,9 @@ def _efc_contact_pyramidal(
active = pos < 0
if active:
- efcid = wp.atomic_add(d.nefc_total, 0, 1)
+ efcid = wp.atomic_add(d.nefc, 0, 1)
worldid = d.contact.worldid[conid]
- d.efc_worldid[efcid] = worldid
+ d.efc.worldid[efcid] = worldid
body1 = m.geom_bodyid[d.contact.geom[conid][0]]
body2 = m.geom_bodyid[d.contact.geom[conid][1]]
@@ -189,7 +191,7 @@ def _efc_contact_pyramidal(
else:
J = diff_0 - diff_i * d.contact.friction[conid][dimid2 - 1]
- d.efc_J[efcid, i] = J
+ d.efc.J[efcid, i] = J
Jqvel += J * d.qvel[worldid, i]
_update_efc_row(
@@ -208,12 +210,13 @@ def _efc_contact_pyramidal(
)
+@event_scope
def make_constraint(m: types.Model, d: types.Data):
"""Creates constraint jacobians and other supporting data."""
if not (m.opt.disableflags & types.DisableBit.CONSTRAINT.value):
- d.nefc_total.zero_()
- d.efc_J.zero_()
+ d.nefc.zero_()
+ d.efc.J.zero_()
refsafe = not m.opt.disableflags & types.DisableBit.REFSAFE.value
@@ -226,7 +229,10 @@ def make_constraint(m: types.Model, d: types.Data):
inputs=[m, d, refsafe],
)
- if m.opt.cone == types.ConeType.PYRAMIDAL.value:
+ if (
+ not (m.opt.disableflags & types.DisableBit.CONTACT.value)
+ and m.opt.cone == types.ConeType.PYRAMIDAL.value
+ ):
wp.launch(
_efc_contact_pyramidal,
dim=(d.nconmax, 4),
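Review note: the `wp.select` to `wp.where` migration in `_update_efc_row` swaps the two value arguments, which is consistent with `wp.where` taking the true branch first. The impedance computation those lines implement can be sketched as a scalar Python function. This is a minimal sketch for illustration only: the name `impedance`, the default values for `mid` and `power`, and the early return are assumptions of the example, and `mid` is assumed to lie strictly between 0 and 1.

```python
def impedance(pos_imp, dmin, dmax, width, mid=0.5, power=2.0):
    """Hypothetical scalar version of the impedance lines in `_update_efc_row`."""
    imp_x = abs(pos_imp) / width

    # Past the transition width the impedance saturates at dmax. The kernel
    # expresses this as a final wp.where(imp_x > 1.0, dmax, imp); returning
    # early here avoids raising a negative base to a fractional power.
    if imp_x > 1.0:
        return dmax

    # two power-law branches that meet at imp_x == mid
    imp_a = (1.0 / mid ** (power - 1.0)) * imp_x**power
    imp_b = 1.0 - (1.0 / (1.0 - mid) ** (power - 1.0)) * (1.0 - imp_x) ** power
    imp_y = imp_a if imp_x < mid else imp_b

    # interpolate between dmin and dmax, then clamp
    imp = dmin + imp_y * (dmax - dmin)
    return min(max(imp, dmin), dmax)
```

At zero penetration the function returns `dmin`, and it rises monotonically to `dmax` as `abs(pos_imp)` approaches `width`.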
diff --git a/mujoco/mjx/_src/constraint_test.py b/mujoco_warp/_src/constraint_test.py
similarity index 71%
rename from mujoco/mjx/_src/constraint_test.py
rename to mujoco_warp/_src/constraint_test.py
index 575bafbe..e36aa552 100644
--- a/mujoco/mjx/_src/constraint_test.py
+++ b/mujoco_warp/_src/constraint_test.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,14 +15,16 @@
"""Tests for constraint functions."""
+import mujoco
+import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
+
+import mujoco_warp as mjwarp
+
from . import test_util
-import mujoco
-from mujoco import mjx
-import numpy as np
-# tolerance for difference between MuJoCo and MJX constraint calculations,
+# tolerance for difference between MuJoCo and MJWarp constraint calculations,
# mostly due to float precision
_TOLERANCE = 5e-5
@@ -43,15 +45,15 @@ def test_constraints(self):
mujoco.mj_resetDataKeyframe(mjm, mjd, key)
mujoco.mj_forward(mjm, mjd)
- m = mjx.put_model(mjm)
- d = mjx.put_data(mjm, mjd)
- mjx.make_constraint(m, d)
-
- _assert_eq(d.efc_J.numpy()[: mjd.nefc, :].reshape(-1), mjd.efc_J, "efc_J")
- _assert_eq(d.efc_D.numpy()[: mjd.nefc], mjd.efc_D, "efc_D")
- _assert_eq(d.efc_aref.numpy()[: mjd.nefc], mjd.efc_aref, "efc_aref")
- _assert_eq(d.efc_pos.numpy()[: mjd.nefc], mjd.efc_pos, "efc_pos")
- _assert_eq(d.efc_margin.numpy()[: mjd.nefc], mjd.efc_margin, "efc_margin")
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd)
+ mjwarp.make_constraint(m, d)
+
+ _assert_eq(d.efc.J.numpy()[: mjd.nefc, :].reshape(-1), mjd.efc_J, "efc_J")
+ _assert_eq(d.efc.D.numpy()[: mjd.nefc], mjd.efc_D, "efc_D")
+ _assert_eq(d.efc.aref.numpy()[: mjd.nefc], mjd.efc_aref, "efc_aref")
+ _assert_eq(d.efc.pos.numpy()[: mjd.nefc], mjd.efc_pos, "efc_pos")
+ _assert_eq(d.efc.margin.numpy()[: mjd.nefc], mjd.efc_margin, "efc_margin")
if __name__ == "__main__":
diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py
new file mode 100644
index 00000000..f18aef0a
--- /dev/null
+++ b/mujoco_warp/_src/forward.py
@@ -0,0 +1,695 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Optional
+
+import mujoco
+import warp as wp
+
+from . import collision_driver
+from . import constraint
+from . import math
+from . import passive
+from . import smooth
+from . import solver
+from .support import xfrc_accumulate
+from .types import MJ_MINVAL
+from .types import BiasType
+from .types import Data
+from .types import DisableBit
+from .types import DynType
+from .types import GainType
+from .types import JointType
+from .types import Model
+from .types import array2df
+from .types import array3df
+from .warp_util import event_scope
+from .warp_util import kernel
+from .warp_util import kernel_copy
+
+
+def _advance(
+ m: Model, d: Data, act_dot: wp.array, qacc: wp.array, qvel: Optional[wp.array] = None
+):
+ """Advance state and time given activation derivatives and acceleration."""
+
+ # TODO(team): can we assume static timesteps?
+
+ @kernel
+ def next_activation(
+ m: Model,
+ d: Data,
+ act_dot_in: array2df,
+ ):
+ worldId, actid = wp.tid()
+
+ # get the high/low range for each actuator state
+ limited = m.actuator_actlimited[actid]
+ range_low = wp.where(limited, m.actuator_actrange[actid][0], -wp.inf)
+ range_high = wp.where(limited, m.actuator_actrange[actid][1], wp.inf)
+
+ # get the actual actuation - skip if -1 (means stateless actuator)
+ act_adr = m.actuator_actadr[actid]
+ if act_adr == -1:
+ return
+
+ acts = d.act[worldId]
+ acts_dot = act_dot_in[worldId]
+
+ act = acts[act_adr]
+ act_dot = acts_dot[act_adr]
+
+ # check dynType
+ dyn_type = m.actuator_dyntype[actid]
+ dyn_prm = m.actuator_dynprm[actid][0]
+
+ # advance the actuation
+ if dyn_type == wp.static(DynType.FILTEREXACT.value):
+ tau = wp.where(dyn_prm < MJ_MINVAL, MJ_MINVAL, dyn_prm)
+ act = act + act_dot * tau * (1.0 - wp.exp(-m.opt.timestep / tau))
+ else:
+ act = act + act_dot * m.opt.timestep
+
+ # apply limits
+ act = wp.clamp(act, range_low, range_high)
+
+ acts[act_adr] = act
+
+ @kernel
+ def advance_velocities(m: Model, d: Data, qacc: array2df):
+ worldId, tid = wp.tid()
+ d.qvel[worldId, tid] = d.qvel[worldId, tid] + qacc[worldId, tid] * m.opt.timestep
+
+ @kernel
+ def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df):
+ worldId, jntid = wp.tid()
+
+ jnt_type = m.jnt_type[jntid]
+ qpos_adr = m.jnt_qposadr[jntid]
+ dof_adr = m.jnt_dofadr[jntid]
+ qpos = d.qpos[worldId]
+ qvel = qvel_in[worldId]
+
+ if jnt_type == wp.static(JointType.FREE.value):
+ qpos_pos = wp.vec3(qpos[qpos_adr], qpos[qpos_adr + 1], qpos[qpos_adr + 2])
+ qvel_lin = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2])
+
+ qpos_new = qpos_pos + m.opt.timestep * qvel_lin
+
+ qpos_quat = wp.quat(
+ qpos[qpos_adr + 3],
+ qpos[qpos_adr + 4],
+ qpos[qpos_adr + 5],
+ qpos[qpos_adr + 6],
+ )
+ qvel_ang = wp.vec3(qvel[dof_adr + 3], qvel[dof_adr + 4], qvel[dof_adr + 5])
+
+ qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang, m.opt.timestep)
+
+ qpos[qpos_adr] = qpos_new[0]
+ qpos[qpos_adr + 1] = qpos_new[1]
+ qpos[qpos_adr + 2] = qpos_new[2]
+ qpos[qpos_adr + 3] = qpos_quat_new[0]
+ qpos[qpos_adr + 4] = qpos_quat_new[1]
+ qpos[qpos_adr + 5] = qpos_quat_new[2]
+ qpos[qpos_adr + 6] = qpos_quat_new[3]
+
+ elif jnt_type == wp.static(JointType.BALL.value): # ball joint
+ qpos_quat = wp.quat(
+ qpos[qpos_adr],
+ qpos[qpos_adr + 1],
+ qpos[qpos_adr + 2],
+ qpos[qpos_adr + 3],
+ )
+ qvel_ang = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2])
+
+ qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang, m.opt.timestep)
+
+ qpos[qpos_adr] = qpos_quat_new[0]
+ qpos[qpos_adr + 1] = qpos_quat_new[1]
+ qpos[qpos_adr + 2] = qpos_quat_new[2]
+ qpos[qpos_adr + 3] = qpos_quat_new[3]
+
+ else: # if jnt_type in (JointType.HINGE, JointType.SLIDE):
+ qpos[qpos_adr] = qpos[qpos_adr] + m.opt.timestep * qvel[dof_adr]
+
+ # skip if no stateful actuators.
+ if m.na:
+ wp.launch(next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot])
+
+ wp.launch(advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc])
+
+ # advance positions with qvel if given, d.qvel otherwise (semi-implicit)
+ if qvel is not None:
+ qvel_in = qvel
+ else:
+ qvel_in = d.qvel
+
+ wp.launch(integrate_joint_positions, dim=(d.nworld, m.njnt), inputs=[m, d, qvel_in])
+
+ d.time = d.time + m.opt.timestep
+
+
+@event_scope
+def euler(m: Model, d: Data):
+ """Euler integrator, semi-implicit in velocity."""
+
+ # integrate damping implicitly
+
+ def eulerdamp_sparse(m: Model, d: Data):
+ @kernel
+ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data):
+ worldId, tid = wp.tid()
+
+ dof_Madr = m.dof_Madr[tid]
+ d.qM_integration[worldId, 0, dof_Madr] += m.opt.timestep * m.dof_damping[tid]
+
+ d.qfrc_integration[worldId, tid] = (
+ d.qfrc_smooth[worldId, tid] + d.qfrc_constraint[worldId, tid]
+ )
+
+ kernel_copy(d.qM_integration, d.qM)
+ wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d])
+ smooth.factor_solve_i(
+ m,
+ d,
+ d.qM_integration,
+ d.qLD_integration,
+ d.qLDiagInv_integration,
+ d.qacc_integration,
+ d.qfrc_integration,
+ )
+
+ def eulerdamp_fused_dense(m: Model, d: Data):
+ def tile_eulerdamp(adr: int, size: int, tilesize: int):
+ @kernel
+ def eulerdamp(
+ m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int
+ ):
+ worldid, nodeid = wp.tid()
+ dofid = m.qLD_tile[leveladr + nodeid]
+ M_tile = wp.tile_load(
+ d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid)
+ )
+ damping_tile = wp.tile_load(damping, shape=(tilesize,), offset=(dofid,))
+ damping_scaled = damping_tile * m.opt.timestep
+ qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled)
+
+ qfrc_smooth_tile = wp.tile_load(
+ d.qfrc_smooth[worldid], shape=(tilesize,), offset=(dofid,)
+ )
+ qfrc_constraint_tile = wp.tile_load(
+ d.qfrc_constraint[worldid], shape=(tilesize,), offset=(dofid,)
+ )
+
+ qfrc_tile = qfrc_smooth_tile + qfrc_constraint_tile
+
+ L_tile = wp.tile_cholesky(qm_integration_tile)
+ qacc_tile = wp.tile_cholesky_solve(L_tile, qfrc_tile)
+ wp.tile_store(d.qacc_integration[worldid], qacc_tile, offset=(dofid,))
+
+ wp.launch_tiled(
+ eulerdamp, dim=(d.nworld, size), inputs=[m, d, m.dof_damping, adr], block_dim=32
+ )
+
+ qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy()
+
+ for i in range(len(qLD_tileadr)):
+ beg = qLD_tileadr[i]
+ end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1]
+ tile_eulerdamp(beg, end - beg, int(qLD_tilesize[i]))
+
+ if not m.opt.disableflags & DisableBit.EULERDAMP.value:
+ if m.opt.is_sparse:
+ eulerdamp_sparse(m, d)
+ else:
+ eulerdamp_fused_dense(m, d)
+
+ _advance(m, d, d.act_dot, d.qacc_integration)
+ else:
+ _advance(m, d, d.act_dot, d.qacc)
+
+
+@event_scope
+def implicit(m: Model, d: Data):
+ """Integrates fully implicit in velocity."""
+
+ # optimization comments (AD)
+ # I went from small kernels for every step to a relatively big single
+ # kernel using tile API because it kept improving performance -
+ # 30M to 50M FPS on an A6000.
+ #
+ # The main benefit is reduced global memory roundtrips, but I assume
+ # there is also some benefit to loading data as early as possible.
+ #
+ # I further tried fusing in the cholesky factor/solve but the high
+ # storage requirements led to low occupancy and thus worse performance.
+ #
+ # The actuator_bias_gain_vel kernel could theoretically be fused in as well,
+ # but it's pretty clean straight-line code that loads a lot of data but
+ # only stores one array, so I think the benefit of keeping that one on-chip
+ # is likely not worth it compared to the compromises we're making with tile API.
+ # It would also need a different data layout for the biasprm/gainprm arrays
+ # to be tileable.
+
+ # assumptions
+ assert not m.opt.is_sparse # unsupported
+ # TODO(team): add sparse version
+
+ # compile-time constants
+ passive_enabled = not m.opt.disableflags & DisableBit.PASSIVE.value
+ actuation_enabled = (
+ not m.opt.disableflags & DisableBit.ACTUATION.value
+ ) and m.actuator_affine_bias_gain
+
+ @kernel
+ def actuator_bias_gain_vel(m: Model, d: Data):
+ worldid, actid = wp.tid()
+
+ bias_vel = 0.0
+ gain_vel = 0.0
+
+ actuator_biastype = m.actuator_biastype[actid]
+ actuator_gaintype = m.actuator_gaintype[actid]
+ actuator_dyntype = m.actuator_dyntype[actid]
+
+ if actuator_biastype == wp.static(BiasType.AFFINE.value):
+ bias_vel = m.actuator_biasprm[actid, 2]
+
+ if actuator_gaintype == wp.static(GainType.AFFINE.value):
+ gain_vel = m.actuator_gainprm[actid, 2]
+
+ ctrl = d.ctrl[worldid, actid]
+
+ if actuator_dyntype != wp.static(DynType.NONE.value):
+ ctrl = d.act[worldid, actid]
+
+ d.act_vel_integration[worldid, actid] = bias_vel + gain_vel * ctrl
+
+ def qderiv_actuator_damping_fused(
+ m: Model, d: Data, damping: wp.array(dtype=wp.float32)
+ ):
+ if actuation_enabled:
+ block_dim = 64
+ else:
+ block_dim = 256
+
+ @wp.func
+ def subtract_multiply(x: wp.float32, y: wp.float32):
+ return x - y * wp.static(m.opt.timestep)
+
+ def qderiv_actuator_damping_tiled(
+ adr: int, size: int, tilesize_nv: int, tilesize_nu: int
+ ):
+ @kernel
+ def qderiv_actuator_fused_kernel(
+ m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int
+ ):
+ worldid, nodeid = wp.tid()
+ offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid]
+
+ # skip tree with no actuators.
+ if wp.static(actuation_enabled and tilesize_nu != 0):
+ offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid]
+ actuator_moment_tile = wp.tile_load(
+ d.actuator_moment[worldid],
+ shape=(tilesize_nu, tilesize_nv),
+ offset=(offset_nu, offset_nv),
+ )
+ zeros = wp.tile_zeros(shape=(tilesize_nu, tilesize_nu), dtype=wp.float32)
+ vel_tile = wp.tile_load(
+ d.act_vel_integration[worldid], shape=(tilesize_nu), offset=offset_nu
+ )
+ diag = wp.tile_diag_add(zeros, vel_tile)
+ actuator_moment_T = wp.tile_transpose(actuator_moment_tile)
+ amTVel = wp.tile_matmul(actuator_moment_T, diag)
+ qderiv_tile = wp.tile_matmul(amTVel, actuator_moment_tile)
+ else:
+ qderiv_tile = wp.tile_zeros(
+ shape=(tilesize_nv, tilesize_nv), dtype=wp.float32
+ )
+
+ if wp.static(passive_enabled):
+ dof_damping = wp.tile_load(damping, shape=tilesize_nv, offset=offset_nv)
+ negative = wp.neg(dof_damping)
+ qderiv_tile = wp.tile_diag_add(qderiv_tile, negative)
+
+ # add to qM
+ qM_tile = wp.tile_load(
+ d.qM[worldid], shape=(tilesize_nv, tilesize_nv), offset=(offset_nv, offset_nv)
+ )
+ qderiv_tile = wp.tile_map(subtract_multiply, qM_tile, qderiv_tile)
+ wp.tile_store(
+ d.qM_integration[worldid], qderiv_tile, offset=(offset_nv, offset_nv)
+ )
+
+ # sum qfrc
+ qfrc_smooth_tile = wp.tile_load(
+ d.qfrc_smooth[worldid], shape=tilesize_nv, offset=offset_nv
+ )
+ qfrc_constraint_tile = wp.tile_load(
+ d.qfrc_constraint[worldid], shape=tilesize_nv, offset=offset_nv
+ )
+ qfrc_combined = wp.add(qfrc_smooth_tile, qfrc_constraint_tile)
+ wp.tile_store(d.qfrc_integration[worldid], qfrc_combined, offset=offset_nv)
+
+ wp.launch_tiled(
+ qderiv_actuator_fused_kernel,
+ dim=(d.nworld, size),
+ inputs=[m, d, damping, adr],
+ block_dim=block_dim,
+ )
+
+ qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy()
+ qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy()
+ qderiv_tileadr = m.actuator_moment_tileadr.numpy()
+
+ for i in range(len(qderiv_tileadr)):
+ beg = qderiv_tileadr[i]
+ end = (
+ m.qLD_tile.shape[0] if i == len(qderiv_tileadr) - 1 else qderiv_tileadr[i + 1]
+ )
+ if qderiv_tilesize_nv[i] != 0:
+ qderiv_actuator_damping_tiled(
+ beg, end - beg, int(qderiv_tilesize_nv[i]), int(qderiv_tilesize_nu[i])
+ )
+
+ if passive_enabled or actuation_enabled:
+ if actuation_enabled:
+ wp.launch(
+ actuator_bias_gain_vel,
+ dim=(d.nworld, m.nu),
+ inputs=[m, d],
+ )
+
+ qderiv_actuator_damping_fused(m, d, m.dof_damping)
+
+ smooth._factor_solve_i_dense(
+ m, d, d.qM_integration, d.qacc_integration, d.qfrc_integration
+ )
+
+ _advance(m, d, d.act_dot, d.qacc_integration)
+ else:
+ _advance(m, d, d.act_dot, d.qacc)
+
+
+@event_scope
+def fwd_position(m: Model, d: Data):
+ """Position-dependent computations."""
+
+ smooth.kinematics(m, d)
+ smooth.com_pos(m, d)
+ # TODO(team): smooth.camlight
+ # TODO(team): smooth.tendon
+ smooth.crb(m, d)
+ smooth.factor_m(m, d)
+ collision_driver.collision(m, d)
+ constraint.make_constraint(m, d)
+ smooth.transmission(m, d)
+
+
+@event_scope
+def fwd_velocity(m: Model, d: Data):
+ """Velocity-dependent computations."""
+
+ if m.opt.is_sparse:
+ # TODO(team): sparse version
+ d.actuator_velocity.zero_()
+
+ @kernel
+ def _actuator_velocity(d: Data):
+ worldid, actid, dofid = wp.tid()
+ moment = d.actuator_moment[worldid, actid]
+ qvel = d.qvel[worldid]
+ wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid])
+
+ wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d])
+ else:
+
+ def actuator_velocity(
+ adr: int,
+ size: int,
+ tilesize_nu: int,
+ tilesize_nv: int,
+ ):
+ @kernel
+ def _actuator_velocity(
+ m: Model, d: Data, leveladr: int, velocity: array3df, qvel: array3df
+ ):
+ worldid, nodeid = wp.tid()
+ offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid]
+ offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid]
+ actuator_moment_tile = wp.tile_load(
+ d.actuator_moment[worldid],
+ shape=(tilesize_nu, tilesize_nv),
+ offset=(offset_nu, offset_nv),
+ )
+ qvel_tile = wp.tile_load(
+ qvel[worldid], shape=(tilesize_nv, 1), offset=(offset_nv, 0)
+ )
+ velocity_tile = wp.tile_matmul(actuator_moment_tile, qvel_tile)
+
+ wp.tile_store(velocity[worldid], velocity_tile, offset=(offset_nu, 0))
+
+ wp.launch_tiled(
+ _actuator_velocity,
+ dim=(d.nworld, size),
+ inputs=[
+ m,
+ d,
+ adr,
+ d.actuator_velocity.reshape(d.actuator_velocity.shape + (1,)),
+ d.qvel.reshape(d.qvel.shape + (1,)),
+ ],
+ block_dim=32,
+ )
+
+ actuator_moment_tilesize_nu = m.actuator_moment_tilesize_nu.numpy()
+ actuator_moment_tilesize_nv = m.actuator_moment_tilesize_nv.numpy()
+ actuator_moment_tileadr = m.actuator_moment_tileadr.numpy()
+
+ for i in range(len(actuator_moment_tileadr)):
+ beg = actuator_moment_tileadr[i]
+ end = (
+ m.actuator_moment_tileadr.shape[0]
+ if i == len(actuator_moment_tileadr) - 1
+ else actuator_moment_tileadr[i + 1]
+ )
+ if actuator_moment_tilesize_nu[i] != 0 and actuator_moment_tilesize_nv[i] != 0:
+ actuator_velocity(
+ beg,
+ end - beg,
+ int(actuator_moment_tilesize_nu[i]),
+ int(actuator_moment_tilesize_nv[i]),
+ )
+
+ smooth.com_vel(m, d)
+ passive.passive(m, d)
+ smooth.rne(m, d)
+
+
+@event_scope
+def fwd_actuation(m: Model, d: Data):
+ """Actuation-dependent computations."""
+ if not m.nu or m.opt.disableflags & DisableBit.ACTUATION.value:
+ d.act_dot.zero_()
+ d.qfrc_actuator.zero_()
+ return
+
+ # TODO support stateful actuators
+
+ @kernel
+ def _force(
+ m: Model,
+ d: Data,
+ # outputs
+ force: array2df,
+ ):
+ worldid, uid = wp.tid()
+
+ actuator_length = d.actuator_length[worldid, uid]
+ actuator_velocity = d.actuator_velocity[worldid, uid]
+
+ gain = m.actuator_gainprm[uid, 0]
+ gain += m.actuator_gainprm[uid, 1] * actuator_length
+ gain += m.actuator_gainprm[uid, 2] * actuator_velocity
+
+ bias = m.actuator_biasprm[uid, 0]
+ bias += m.actuator_biasprm[uid, 1] * actuator_length
+ bias += m.actuator_biasprm[uid, 2] * actuator_velocity
+
+ ctrl = d.ctrl[worldid, uid]
+ disable_clampctrl = m.opt.disableflags & wp.static(DisableBit.CLAMPCTRL.value)
+ if m.actuator_ctrllimited[uid] and not disable_clampctrl:
+ r = m.actuator_ctrlrange[uid]
+ ctrl = wp.clamp(ctrl, r[0], r[1])
+ f = gain * ctrl + bias
+ if m.actuator_forcelimited[uid]:
+ f = wp.clamp(f, m.actuator_forcerange[uid][0], m.actuator_forcerange[uid][1])
+ force[worldid, uid] = f
+
+ @kernel
+ def _qfrc_limited(m: Model, d: Data):
+ worldid, dofid = wp.tid()
+ jntid = m.dof_jntid[dofid]
+ if m.jnt_actfrclimited[jntid]:
+ d.qfrc_actuator[worldid, dofid] = wp.clamp(
+ d.qfrc_actuator[worldid, dofid],
+ m.jnt_actfrcrange[jntid][0],
+ m.jnt_actfrcrange[jntid][1],
+ )
+
+ if m.opt.is_sparse:
+ # TODO(team): sparse version
+ @kernel
+ def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df):
+ worldid, vid = wp.tid()
+
+ s = float(0.0)
+ for uid in range(m.nu):
+ # TODO consider using Tile API or transpose moment for better access pattern
+ s += moment[worldid, uid, vid] * force[worldid, uid]
+ jntid = m.dof_jntid[vid]
+ if m.jnt_actfrclimited[jntid]:
+ r = m.jnt_actfrcrange[jntid]
+ s = wp.clamp(s, r[0], r[1])
+ qfrc[worldid, vid] = s
+
+ wp.launch(_force, dim=[d.nworld, m.nu], inputs=[m, d], outputs=[d.actuator_force])
+
+ if m.opt.is_sparse:
+ # TODO(team): sparse version
+
+ wp.launch(
+ _qfrc,
+ dim=(d.nworld, m.nv),
+ inputs=[m, d.actuator_moment, d.actuator_force],
+ outputs=[d.qfrc_actuator],
+ )
+
+ else:
+
+ def qfrc_actuator(adr: int, size: int, tilesize_nu: int, tilesize_nv: int):
+ @kernel
+ def qfrc_actuator_kernel(
+ m: Model,
+ d: Data,
+ leveladr: int,
+ qfrc_actuator: array3df,
+ actuator_force: array3df,
+ ):
+ worldid, nodeid = wp.tid()
+ offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid]
+ offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid]
+
+ actuator_moment_tile = wp.tile_load(
+ d.actuator_moment[worldid],
+ shape=(tilesize_nu, tilesize_nv),
+ offset=(offset_nu, offset_nv),
+ )
+ actuator_moment_T_tile = wp.tile_transpose(actuator_moment_tile)
+
+ force_tile = wp.tile_load(
+ actuator_force[worldid], shape=(tilesize_nu, 1), offset=(offset_nu, 0)
+ )
+ qfrc_tile = wp.tile_matmul(actuator_moment_T_tile, force_tile)
+ wp.tile_store(qfrc_actuator[worldid], qfrc_tile, offset=(offset_nv, 0))
+
+ wp.launch_tiled(
+ qfrc_actuator_kernel,
+ dim=(d.nworld, size),
+ inputs=[
+ m,
+ d,
+ adr,
+ d.qfrc_actuator.reshape(d.qfrc_actuator.shape + (1,)),
+ d.actuator_force.reshape(d.actuator_force.shape + (1,)),
+ ],
+ block_dim=32,
+ )
+
+ qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy()
+ qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy()
+ qderiv_tileadr = m.actuator_moment_tileadr.numpy()
+
+ for i in range(len(qderiv_tileadr)):
+ beg = qderiv_tileadr[i]
+ end = (
+ m.actuator_moment_tileadr.shape[0]
+ if i == len(qderiv_tileadr) - 1
+ else qderiv_tileadr[i + 1]
+ )
+ if qderiv_tilesize_nu[i] != 0 and qderiv_tilesize_nv[i] != 0:
+ qfrc_actuator(
+ beg, end - beg, int(qderiv_tilesize_nu[i]), int(qderiv_tilesize_nv[i])
+ )
+
+ wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d])
+
+ # TODO actuator-level gravity compensation, skip if added as passive force
+
+
+@event_scope
+def fwd_acceleration(m: Model, d: Data):
+ """Add up all non-constraint forces, compute qacc_smooth."""
+
+ @kernel
+ def _qfrc_smooth(d: Data):
+ worldid, dofid = wp.tid()
+ d.qfrc_smooth[worldid, dofid] = (
+ d.qfrc_passive[worldid, dofid]
+ - d.qfrc_bias[worldid, dofid]
+ + d.qfrc_actuator[worldid, dofid]
+ + d.qfrc_applied[worldid, dofid]
+ )
+
+ wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d])
+ xfrc_accumulate(m, d, d.qfrc_smooth)
+
+ smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth)
+
+
+@event_scope
+def forward(m: Model, d: Data):
+ """Forward dynamics."""
+
+ fwd_position(m, d)
+ # TODO(team): sensor.sensor_pos
+ fwd_velocity(m, d)
+ # TODO(team): sensor.sensor_vel
+ fwd_actuation(m, d)
+ fwd_acceleration(m, d)
+ # TODO(team): sensor.sensor_acc
+
+ if d.njmax == 0:
+ kernel_copy(d.qacc, d.qacc_smooth)
+ else:
+ solver.solve(m, d)
+
+
+@event_scope
+def step(m: Model, d: Data):
+ """Advance simulation."""
+ forward(m, d)
+
+ if m.opt.integrator == mujoco.mjtIntegrator.mjINT_EULER:
+ euler(m, d)
+ elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_RK4:
+ # TODO(team): rungekutta4
+ raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.")
+ elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST:
+ implicit(m, d)
+ else:
+ raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.")
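Review note: the update rules in `_advance` and `euler` can be sketched for a single actuator and a single degree of freedom in plain Python. This is a minimal sketch under stated assumptions rather than the Warp implementation. The function names, the scalar `mass` standing in for the mass matrix `qM`, and the value chosen for `MJ_MINVAL` are assumptions made for the example.

```python
import math

MJ_MINVAL = 1e-15  # assumed value, standing in for types.MJ_MINVAL


def next_activation(act, act_dot, timestep, tau=None, act_range=None):
    """Hypothetical scalar version of the `next_activation` kernel.

    `tau` selects the exact filter update (DynType.FILTEREXACT in the kernel);
    otherwise a plain explicit Euler step is used.
    """
    if tau is not None:
        tau = max(tau, MJ_MINVAL)
        act = act + act_dot * tau * (1.0 - math.exp(-timestep / tau))
    else:
        act = act + act_dot * timestep

    # apply limits; the clamped value must be assigned back to take effect
    if act_range is not None:
        low, high = act_range
        act = min(max(act, low), high)
    return act


def euler_step_1dof(qpos, qvel, qfrc, mass, damping, timestep):
    """Hypothetical 1-DOF version of `euler` followed by `_advance`.

    Damping is integrated implicitly by solving
    (mass + timestep * damping) * qacc = qfrc, then velocity is advanced
    first and position is advanced with the new velocity (semi-implicit).
    """
    qacc = qfrc / (mass + timestep * damping)
    qvel_new = qvel + timestep * qacc
    qpos_new = qpos + timestep * qvel_new
    return qpos_new, qvel_new
```

With `damping=0.0` the step reduces to ordinary semi-implicit Euler, which corresponds to the `DisableBit.EULERDAMP` branch that passes `d.qacc` straight to `_advance`.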
diff --git a/mujoco/mjx/_src/forward_test.py b/mujoco_warp/_src/forward_test.py
similarity index 54%
rename from mujoco/mjx/_src/forward_test.py
rename to mujoco_warp/_src/forward_test.py
index 5f5dce6b..63d85ca0 100644
--- a/mujoco/mjx/_src/forward_test.py
+++ b/mujoco_warp/_src/forward_test.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,19 +15,18 @@
"""Tests for forward dynamics functions."""
-from absl.testing import absltest
-from etils import epath
+import mujoco
import numpy as np
import warp as wp
+from absl.testing import absltest
+from absl.testing import parameterized
+from etils import epath
-wp.config.verify_cuda = True
-
-import mujoco
-from mujoco import mjx
+import mujoco_warp as mjwarp
-from .types import DisableBit
+wp.config.verify_cuda = True
-# tolerance for difference between MuJoCo and MJX smooth calculations - mostly
+# tolerance for difference between MuJoCo and MJWarp smooth calculations - mostly
# due to float precision
_TOLERANCE = 5e-5
@@ -40,25 +39,23 @@ def _assert_eq(a, b, name):
class ForwardTest(absltest.TestCase):
def _load(self, fname: str, is_sparse: bool = True):
- path = epath.resource_path("mujoco.mjx") / "test_data" / fname
+ path = epath.resource_path("mujoco_warp") / "test_data" / fname
mjm = mujoco.MjModel.from_xml_path(path.as_posix())
mjm.opt.jacobian = is_sparse
mjd = mujoco.MjData(mjm)
mujoco.mj_resetDataKeyframe(mjm, mjd, 1) # reset to stand_on_left_leg
mjd.qvel = np.random.uniform(low=-0.01, high=0.01, size=mjd.qvel.shape)
- mjd.ctrl = np.random.normal(scale=10, size=mjd.ctrl.shape)
- mjd.act = np.random.normal(scale=10, size=mjd.act.shape)
+ mjd.ctrl = np.random.normal(scale=1, size=mjd.ctrl.shape)
mujoco.mj_forward(mjm, mjd)
- m = mjx.put_model(mjm)
- d = mjx.put_data(mjm, mjd)
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd)
return mjm, mjd, m, d
def test_fwd_velocity(self):
- """Tests MJX fwd_velocity."""
- _, mjd, m, d = self._load("humanoid/humanoid.xml")
+ _, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False)
d.actuator_velocity.zero_()
- mjx.fwd_velocity(m, d)
+ mjwarp.fwd_velocity(m, d)
_assert_eq(
d.actuator_velocity.numpy()[0], mjd.actuator_velocity, "actuator_velocity"
@@ -66,7 +63,6 @@ def test_fwd_velocity(self):
_assert_eq(d.qfrc_bias.numpy()[0], mjd.qfrc_bias, "qfrc_bias")
def test_fwd_actuation(self):
- """Tests MJX fwd_actuation."""
mjm, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False)
mujoco.mj_fwdActuation(mjm, mjd)
@@ -74,27 +70,30 @@ def test_fwd_actuation(self):
for arr in (d.actuator_force, d.qfrc_actuator):
arr.zero_()
- mjx.fwd_actuation(m, d)
+ mjwarp.fwd_actuation(m, d)
_assert_eq(d.ctrl.numpy()[0], mjd.ctrl, "ctrl")
_assert_eq(d.actuator_force.numpy()[0], mjd.actuator_force, "actuator_force")
_assert_eq(d.qfrc_actuator.numpy()[0], mjd.qfrc_actuator, "qfrc_actuator")
+ # TODO(team): test DisableBit.CLAMPCTRL
+ # TODO(team): test DisableBit.ACTUATION
+ # TODO(team): test actuator gain/bias (e.g. position control)
+
def test_fwd_acceleration(self):
- """Tests MJX fwd_acceleration."""
_, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False)
for arr in (d.qfrc_smooth, d.qacc_smooth):
arr.zero_()
- mjx.factor_m(m, d) # for dense, get tile cholesky factorization
- mjx.fwd_acceleration(m, d)
+ mjwarp.factor_m(m, d) # for dense, get tile cholesky factorization
+ mjwarp.fwd_acceleration(m, d)
_assert_eq(d.qfrc_smooth.numpy()[0], mjd.qfrc_smooth, "qfrc_smooth")
_assert_eq(d.qacc_smooth.numpy()[0], mjd.qacc_smooth, "qacc_smooth")
def test_eulerdamp(self):
- path = epath.resource_path("mujoco.mjx") / "test_data/pendula.xml"
+ path = epath.resource_path("mujoco_warp") / "test_data/pendula.xml"
mjm = mujoco.MjModel.from_xml_path(path.as_posix())
self.assertTrue((mjm.dof_damping > 0).any())
@@ -103,10 +102,10 @@ def test_eulerdamp(self):
mjd.qacc[:] = 1.0
mujoco.mj_forward(mjm, mjd)
- m = mjx.put_model(mjm)
- d = mjx.put_data(mjm, mjd)
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd)
- mjx.euler(m, d)
+ mjwarp.euler(m, d)
mujoco.mj_Euler(mjm, mjd)
_assert_eq(d.qpos.numpy()[0], mjd.qpos, "qpos")
@@ -119,33 +118,84 @@ def test_eulerdamp(self):
mjd.qacc[:] = 1.0
mujoco.mj_forward(mjm, mjd)
- m = mjx.put_model(mjm)
- d = mjx.put_data(mjm, mjd)
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd)
- mjx.euler(m, d)
+ mjwarp.euler(m, d)
mujoco.mj_Euler(mjm, mjd)
_assert_eq(d.qpos.numpy()[0], mjd.qpos, "qpos")
_assert_eq(d.act.numpy()[0], mjd.act, "act")
def test_disable_eulerdamp(self):
- path = epath.resource_path("mujoco.mjx") / "test_data/pendula.xml"
+ path = epath.resource_path("mujoco_warp") / "test_data/pendula.xml"
mjm = mujoco.MjModel.from_xml_path(path.as_posix())
- mjm.opt.disableflags = mjm.opt.disableflags | DisableBit.EULERDAMP.value
+ mjm.opt.disableflags = mjm.opt.disableflags | mujoco.mjtDisableBit.mjDSBL_EULERDAMP
mjd = mujoco.MjData(mjm)
mujoco.mj_forward(mjm, mjd)
mjd.qvel[:] = 1.0
mjd.qacc[:] = 1.0
- m = mjx.put_model(mjm)
- d = mjx.put_data(mjm, mjd)
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd)
- mjx.euler(m, d)
+ mjwarp.euler(m, d)
np.testing.assert_allclose(d.qvel.numpy()[0], 1 + mjm.opt.timestep)
+class ImplicitIntegratorTest(parameterized.TestCase):
+ def _load(self, fname: str, disableFlags: int):
+ path = epath.resource_path("mujoco_warp") / "test_data" / fname
+ mjm = mujoco.MjModel.from_xml_path(path.as_posix())
+ mjm.opt.jacobian = 0
+ mjm.opt.integrator = mujoco.mjtIntegrator.mjINT_IMPLICITFAST
+ mjm.opt.disableflags = mjm.opt.disableflags | disableFlags
+ mjm.actuator_gainprm[:, 2] = np.random.normal(
+ scale=10, size=mjm.actuator_gainprm[:, 2].shape
+ )
+
+ # change actuators to velocity/damper to cover all codepaths
+ mjm.actuator_gaintype[3] = mujoco.mjtGain.mjGAIN_AFFINE
+ mjm.actuator_gaintype[6] = mujoco.mjtGain.mjGAIN_AFFINE
+ mjm.actuator_biastype[0:3] = mujoco.mjtBias.mjBIAS_AFFINE
+ mjm.actuator_biastype[4:6] = mujoco.mjtBias.mjBIAS_AFFINE
+ mjm.actuator_biasprm[0:3, 2] = -1
+ mjm.actuator_biasprm[4:6, 2] = -1
+ mjm.actuator_ctrlrange[3:7] = 10.0
+ mjm.actuator_gear[:] = 1.0
+
+ mjd = mujoco.MjData(mjm)
+
+ mjd.qvel = np.random.uniform(low=-0.01, high=0.01, size=mjd.qvel.shape)
+ mjd.ctrl = np.random.normal(scale=10, size=mjd.ctrl.shape)
+ mjd.act = np.random.normal(scale=10, size=mjd.act.shape)
+ mujoco.mj_forward(mjm, mjd)
+
+ mjd.ctrl = np.random.normal(scale=10, size=mjd.ctrl.shape)
+ mjd.act = np.random.normal(scale=10, size=mjd.act.shape)
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd)
+ return mjm, mjd, m, d
+
+ @parameterized.parameters(
+ 0,
+ mjwarp.DisableBit.PASSIVE.value,
+ mjwarp.DisableBit.ACTUATION.value,
+ mjwarp.DisableBit.PASSIVE.value | mjwarp.DisableBit.ACTUATION.value,
+ )
+ def test_implicit(self, disableFlags):
+ np.random.seed(0)
+ mjm, mjd, m, d = self._load("pendula.xml", disableFlags)
+
+ mjwarp.implicit(m, d)
+ mujoco.mj_implicit(mjm, mjd)
+
+ _assert_eq(d.qpos.numpy()[0], mjd.qpos, "qpos")
+ _assert_eq(d.act.numpy()[0], mjd.act, "act")
+
+
if __name__ == "__main__":
wp.init()
absltest.main()
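Note on the parameterized disable flags in `ImplicitIntegratorTest`: disable flags are bitmasks, so several flags are combined with bitwise OR, while a bitwise AND of two distinct single-bit flags is always 0. A minimal sketch, assuming a hypothetical `DemoDisableBit` enum whose values are illustrative only and are not taken from `mjwarp.DisableBit`:

```python
import enum


class DemoDisableBit(enum.IntFlag):
  """Hypothetical stand-in for a disable-flag enum; the values are made up."""

  PASSIVE = 1 << 0
  ACTUATION = 1 << 1


def combine(*flags: int) -> int:
  """ORs single-bit flags together into one bitmask."""
  mask = 0
  for flag in flags:
    mask |= int(flag)
  return mask


# OR sets both bits; AND of two distinct bits clears everything
both = combine(DemoDisableBit.PASSIVE, DemoDisableBit.ACTUATION)
neither = int(DemoDisableBit.PASSIVE & DemoDisableBit.ACTUATION)
```

With these assumed values, `both` has the two bits set and `neither` is zero, which would make an AND-combined test case identical to the `0` case.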
diff --git a/mujoco/mjx/_src/io.py b/mujoco_warp/_src/io.py
similarity index 59%
rename from mujoco/mjx/_src/io.py
rename to mujoco_warp/_src/io.py
index 268e9fc5..c71c2600 100644
--- a/mujoco/mjx/_src/io.py
+++ b/mujoco_warp/_src/io.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,9 +13,9 @@
# limitations under the License.
# ==============================================================================
-import warp as wp
import mujoco
import numpy as np
+import warp as wp
from . import support
from . import types
@@ -33,6 +33,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
m.nsite = mjm.nsite
m.nmocap = mjm.nmocap
m.nM = mjm.nM
+ m.nexclude = mjm.nexclude
m.opt.timestep = mjm.opt.timestep
m.opt.tolerance = mjm.opt.tolerance
m.opt.ls_tolerance = mjm.opt.ls_tolerance
@@ -53,7 +54,18 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
# dof lower triangle row and column indices
dof_tri_row, dof_tri_col = np.tril_indices(mjm.nv)
- # indices for sparse qM
+ # indices for sparse qM full_m
+ is_, js = [], []
+ for i in range(mjm.nv):
+ j = i
+ while j > -1:
+ is_.append(i)
+ js.append(j)
+ j = mjm.dof_parentid[j]
+ qM_fullm_i = is_
+ qM_fullm_j = js
+
+ # indices for sparse qM mul_m
is_, js, madr_ijs = [], [], []
for i in range(mjm.nv):
madr_ij, j = mjm.dof_Madr[i], i
@@ -64,7 +76,9 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
break
is_, js, madr_ijs = is_ + [i], js + [j], madr_ijs + [madr_ij]
- qM_i, qM_j, qM_madr_ij = (np.array(x, dtype=np.int32) for x in (is_, js, madr_ijs))
+ qM_mulm_i, qM_mulm_j, qM_madr_ij = (
+ np.array(x, dtype=np.int32) for x in (is_, js, madr_ijs)
+ )
jnt_limited_slide_hinge_adr = np.nonzero(
mjm.jnt_limited
@@ -125,8 +139,52 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
qLD_tileadr = np.cumsum(tile_off)[:-1]
qLD_tilesize = np.array(sorted(tiles.keys()))
- m.qM_i = wp.array(qM_i, dtype=wp.int32, ndim=1)
- m.qM_j = wp.array(qM_j, dtype=wp.int32, ndim=1)
+ # tiles for actuator_moment - needs nu + nv tile size and offset
+ actuator_moment_offset_nv = np.empty(shape=(0,), dtype=int)
+ actuator_moment_offset_nu = np.empty(shape=(0,), dtype=int)
+ actuator_moment_tileadr = np.empty(shape=(0,), dtype=int)
+ actuator_moment_tilesize_nv = np.empty(shape=(0,), dtype=int)
+ actuator_moment_tilesize_nu = np.empty(shape=(0,), dtype=int)
+
+ if not support.is_sparse(mjm):
+ # how many actuators for each tree
+ tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1]
+ tree_id = mjm.dof_treeid[tile_corners]
+ num_trees = int(np.max(tree_id))
+ tree = mjm.body_treeid[mjm.jnt_bodyid[mjm.actuator_trnid[:, 0]]]
+ counts, ids = np.histogram(tree, bins=np.arange(0, num_trees + 2))
+ acts_per_tree = dict(zip([int(i) for i in ids], [int(i) for i in counts]))
+
+ tiles = {}
+ act_beg = 0
+ for i in range(len(tile_corners)):
+ tile_beg = tile_corners[i]
+ tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1]
+ tree = int(tree_id[i])
+ act_num = acts_per_tree[tree]
+ tiles.setdefault((tile_end - tile_beg, act_num), []).append((tile_beg, act_beg))
+ act_beg += act_num
+
+ sorted_keys = sorted(tiles.keys())
+ actuator_moment_offset_nv = [
+ t[0] for key in sorted_keys for t in tiles.get(key, [])
+ ]
+ actuator_moment_offset_nu = [
+ t[1] for key in sorted_keys for t in tiles.get(key, [])
+ ]
+ tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())]
+ actuator_moment_tileadr = np.cumsum(tile_off)[:-1] # offset
+ actuator_moment_tilesize_nv = np.array(
+ [a[0] for a in sorted_keys]
+ ) # for this level
+ actuator_moment_tilesize_nu = np.array(
+ [int(a[1]) for a in sorted_keys]
+ ) # for this level
+
+ m.qM_fullm_i = wp.array(qM_fullm_i, dtype=wp.int32, ndim=1)
+ m.qM_fullm_j = wp.array(qM_fullm_j, dtype=wp.int32, ndim=1)
+ m.qM_mulm_i = wp.array(qM_mulm_i, dtype=wp.int32, ndim=1)
+ m.qM_mulm_j = wp.array(qM_mulm_j, dtype=wp.int32, ndim=1)
m.qM_madr_ij = wp.array(qM_madr_ij, dtype=wp.int32, ndim=1)
m.qLD_update_tree = wp.array(qLD_update_tree, dtype=wp.vec3i, ndim=1)
m.qLD_update_treeadr = wp.array(
@@ -135,12 +193,28 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1)
m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu")
m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu")
+ m.actuator_moment_offset_nv = wp.array(
+ actuator_moment_offset_nv, dtype=wp.int32, ndim=1
+ )
+ m.actuator_moment_offset_nu = wp.array(
+ actuator_moment_offset_nu, dtype=wp.int32, ndim=1
+ )
+ m.actuator_moment_tileadr = wp.array(
+ actuator_moment_tileadr, dtype=wp.int32, ndim=1, device="cpu"
+ )
+ m.actuator_moment_tilesize_nv = wp.array(
+ actuator_moment_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu"
+ )
+ m.actuator_moment_tilesize_nu = wp.array(
+ actuator_moment_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu"
+ )
m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1)
m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1)
m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32, ndim=1)
m.body_jntnum = wp.array(mjm.body_jntnum, dtype=wp.int32, ndim=1)
m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32, ndim=1)
m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1)
+ m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1)
m.body_pos = wp.array(mjm.body_pos, dtype=wp.vec3, ndim=1)
m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1)
m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1)
@@ -148,7 +222,18 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1)
m.body_inertia = wp.array(mjm.body_inertia, dtype=wp.vec3, ndim=1)
m.body_mass = wp.array(mjm.body_mass, dtype=wp.float32, ndim=1)
+
+ subtree_mass = np.copy(mjm.body_mass)
+ # stop before body 0 (world), which is its own parent
+ for i in range(mjm.nbody - 1, 0, -1):
+ subtree_mass[mjm.body_parentid[i]] += subtree_mass[i]
+
+ m.subtree_mass = wp.array(subtree_mass, dtype=wp.float32, ndim=1)
m.body_invweight0 = wp.array(mjm.body_invweight0, dtype=wp.float32, ndim=2)
+ m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1)
+ m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1)
+ m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1)
+ m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1)
m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1)
m.jnt_limited = wp.array(mjm.jnt_limited, dtype=wp.int32, ndim=1)
m.jnt_limited_slide_hinge_adr = wp.array(
@@ -166,11 +251,30 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1)
m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1)
m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1)
+ m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1)
m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1)
+ m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1)
+ m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1)
+ m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1)
m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1)
m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1)
+ m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1)
+ m.geom_priority = wp.array(mjm.geom_priority, dtype=wp.int32, ndim=1)
+ m.geom_solmix = wp.array(mjm.geom_solmix, dtype=wp.float32, ndim=1)
+ m.geom_solref = wp.array(mjm.geom_solref, dtype=wp.vec2, ndim=1)
+ m.geom_solimp = wp.array(mjm.geom_solimp, dtype=types.vec5, ndim=1)
+ m.geom_friction = wp.array(mjm.geom_friction, dtype=wp.vec3, ndim=1)
+ m.geom_margin = wp.array(mjm.geom_margin, dtype=wp.float32, ndim=1)
+ m.geom_gap = wp.array(mjm.geom_gap, dtype=wp.float32, ndim=1)
+ m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3)
+ m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1)
+ m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1)
+ m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1)
+ m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1)
+ m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3, ndim=1)
m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1)
m.site_quat = wp.array(mjm.site_quat, dtype=wp.quat, ndim=1)
+ m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1)
m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32, ndim=1)
m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32, ndim=1)
m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1)
@@ -186,7 +290,9 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1)
m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1)
m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1)
+ m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1)
m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=wp.float32, ndim=2)
+ m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1)
m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=wp.float32, ndim=2)
m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1)
m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1)
@@ -194,17 +300,75 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1)
m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1)
m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1)
+ m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1)
+
+ # short-circuiting here allows us to skip a lot of code in implicit integration
+ m.actuator_affine_bias_gain = bool(
+ np.any(mjm.actuator_biastype == types.BiasType.AFFINE.value)
+ or np.any(mjm.actuator_gaintype == types.GainType.AFFINE.value)
+ )
return m
+def _constraint(nv: int, nworld: int, njmax: int) -> types.Constraint:
+ efc = types.Constraint()
+
+ efc.J = wp.zeros((njmax, nv), dtype=wp.float32)
+ efc.D = wp.zeros((njmax,), dtype=wp.float32)
+ efc.pos = wp.zeros((njmax,), dtype=wp.float32)
+ efc.aref = wp.zeros((njmax,), dtype=wp.float32)
+ efc.force = wp.zeros((njmax,), dtype=wp.float32)
+ efc.margin = wp.zeros((njmax,), dtype=wp.float32)
+ efc.worldid = wp.zeros((njmax,), dtype=wp.int32)
+
+ efc.Jaref = wp.empty(shape=(njmax,), dtype=wp.float32)
+ efc.Ma = wp.empty(shape=(nworld, nv), dtype=wp.float32)
+ efc.grad = wp.empty(shape=(nworld, nv), dtype=wp.float32)
+ efc.grad_dot = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.Mgrad = wp.empty(shape=(nworld, nv), dtype=wp.float32)
+ efc.search = wp.empty(shape=(nworld, nv), dtype=wp.float32)
+ efc.search_dot = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.gauss = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.cost = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.prev_cost = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.solver_niter = wp.empty(shape=(nworld,), dtype=wp.int32)
+ efc.active = wp.empty(shape=(njmax,), dtype=wp.int32)
+ efc.gtol = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.mv = wp.empty(shape=(nworld, nv), dtype=wp.float32)
+ efc.jv = wp.empty(shape=(njmax,), dtype=wp.float32)
+ efc.quad = wp.empty(shape=(njmax,), dtype=wp.vec3f)
+ efc.quad_gauss = wp.empty(shape=(nworld,), dtype=wp.vec3f)
+ efc.h = wp.empty(shape=(nworld, nv, nv), dtype=wp.float32)
+ efc.alpha = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.prev_grad = wp.empty(shape=(nworld, nv), dtype=wp.float32)
+ efc.prev_Mgrad = wp.empty(shape=(nworld, nv), dtype=wp.float32)
+ efc.beta = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.beta_num = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.beta_den = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.done = wp.empty(shape=(nworld,), dtype=bool)
+
+ efc.ls_done = wp.zeros(shape=(nworld,), dtype=bool)
+ efc.p0 = wp.empty(shape=(nworld,), dtype=wp.vec3)
+ efc.lo = wp.empty(shape=(nworld,), dtype=wp.vec3)
+ efc.lo_alpha = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.hi = wp.empty(shape=(nworld,), dtype=wp.vec3)
+ efc.hi_alpha = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.lo_next = wp.empty(shape=(nworld,), dtype=wp.vec3)
+ efc.lo_next_alpha = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.hi_next = wp.empty(shape=(nworld,), dtype=wp.vec3)
+ efc.hi_next_alpha = wp.empty(shape=(nworld,), dtype=wp.float32)
+ efc.mid = wp.empty(shape=(nworld,), dtype=wp.vec3)
+ efc.mid_alpha = wp.empty(shape=(nworld,), dtype=wp.float32)
+
+ return efc
+
+
def make_data(
mjm: mujoco.MjModel, nworld: int = 1, nconmax: int = -1, njmax: int = -1
) -> types.Data:
d = types.Data()
d.nworld = nworld
- d.ncon_total = wp.zeros((1,), dtype=wp.int32, ndim=1)
- d.nefc_total = wp.zeros((1,), dtype=wp.int32, ndim=1)
# TODO(team): move to Model?
if nconmax == -1:
@@ -216,8 +380,8 @@ def make_data(
njmax = 512
d.njmax = njmax
- d.ncon = 0
- d.nefc = wp.zeros(nworld, dtype=wp.int32)
+ d.ncon = wp.zeros(1, dtype=wp.int32)
+ d.nefc = wp.zeros(1, dtype=wp.int32, ndim=1)
d.nl = 0
d.time = 0.0
@@ -274,6 +438,7 @@ def make_data(
d.contact.geom = wp.zeros((nconmax,), dtype=wp.vec2i)
d.contact.efc_address = wp.zeros((nconmax,), dtype=wp.int32)
d.contact.worldid = wp.zeros((nconmax,), dtype=wp.int32)
+ d.efc = _constraint(mjm.nv, d.nworld, d.njmax)
d.qfrc_passive = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
d.qfrc_spring = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
d.qfrc_damper = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
@@ -281,44 +446,34 @@ def make_data(
d.qfrc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
d.qfrc_constraint = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
d.qacc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
- d.efc_J = wp.zeros((njmax, mjm.nv), dtype=wp.float32)
- d.efc_D = wp.zeros((njmax,), dtype=wp.float32)
- d.efc_pos = wp.zeros((njmax,), dtype=wp.float32)
- d.efc_aref = wp.zeros((njmax,), dtype=wp.float32)
- d.efc_force = wp.zeros((njmax,), dtype=wp.float32)
- d.efc_margin = wp.zeros((njmax,), dtype=wp.float32)
- d.efc_worldid = wp.zeros((njmax,), dtype=wp.int32)
+
+ d.rne_cacc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector)
+ d.rne_cfrc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector)
d.xfrc_applied = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector)
+
# internal tmp arrays
d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
d.qM_integration = wp.zeros_like(d.qM)
d.qLD_integration = wp.zeros_like(d.qLD)
d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv)
-
- # the result of the broadphase gets stored in this array
- d.max_num_overlaps_per_world = (
- mjm.ngeom * (mjm.ngeom - 1) // 2
- ) # TODO: this is a hack to estimate the maximum number of overlaps per world
- d.broadphase_pairs = wp.zeros((nworld, d.max_num_overlaps_per_world), dtype=wp.vec2i)
- d.result_count = wp.zeros(nworld, dtype=wp.int32)
-
- # internal broadphase tmp arrays
- d.boxes_sorted = wp.zeros(
- (nworld, mjm.ngeom), dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32)
- )
- d.data_start = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32)
- d.data_end = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32)
- d.data_indexer = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32)
- d.ranges = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32)
- d.cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32)
+ d.act_vel_integration = wp.zeros_like(d.ctrl)
+
+ # sweep-and-prune broadphase
+ d.sap_geom_sort = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec4)
+ d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32)
+ d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32)
+ d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32)
+ d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32)
+ d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32)
segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)]
- d.segment_indices = wp.array(segment_indices_list, dtype=int)
+ d.sap_segment_index = wp.array(segment_indices_list, dtype=int)
- d.geom_aabb = wp.array(
- mjm.geom_aabb, dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1
- )
+ # collision driver
+ d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1)
+ d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1)
+ d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1)
return d
@@ -332,25 +487,23 @@ def put_data(
) -> types.Data:
d = types.Data()
d.nworld = nworld
- d.ncon_total = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1)
- d.nefc_total = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1)
# TODO(team): move to Model?
if nconmax == -1:
# TODO(team): heuristic for nconmax
- nconmax = 512
+ nconmax = max(512, mjd.ncon * nworld)
d.nconmax = nconmax
if njmax == -1:
# TODO(team): heuristic for njmax
- njmax = 512
+ njmax = max(512, mjd.nefc * nworld)
d.njmax = njmax
if nworld * mjd.nefc > njmax:
raise ValueError("nworld * nefc > njmax")
- d.ncon = mjd.ncon
+ d.ncon = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1)
d.nl = mjd.nl
- d.nefc = wp.zeros(1, dtype=wp.int32)
+ d.nefc = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1)
d.time = mjd.time
# TODO(erikfrey): would it be better to tile on the gpu?
@@ -360,8 +513,10 @@ def tile(x):
if support.is_sparse(mjm):
qM = np.expand_dims(mjd.qM, axis=0)
qLD = np.expand_dims(mjd.qLD, axis=0)
- # TODO(taylorhowell): sparse efc_J
efc_J = np.zeros((mjd.nefc, mjm.nv))
+ mujoco.mju_sparse2dense(
+ efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind
+ )
else:
qM = np.zeros((mjm.nv, mjm.nv))
mujoco.mj_fullM(mjm, qM, mjd.qM)
@@ -418,7 +573,8 @@ def tile(x):
d.qfrc_smooth = wp.array(tile(mjd.qfrc_smooth), dtype=wp.float32, ndim=2)
d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2)
d.qacc_smooth = wp.array(tile(mjd.qacc_smooth), dtype=wp.float32, ndim=2)
- d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2)
+ d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2)
+ d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2)
nefc = mjd.nefc
efc_worldid = np.zeros(njmax, dtype=int)
@@ -447,17 +603,6 @@ def tile(x):
[np.repeat(mjd.efc_margin, nworld, axis=0), np.zeros(nefc_fill)]
)
- d.efc_J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2)
- d.efc_D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1)
- d.efc_pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1)
- d.efc_aref = wp.array(efc_aref_fill, dtype=wp.float32, ndim=1)
- d.efc_force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1)
- d.efc_margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1)
- d.efc_worldid = wp.from_numpy(efc_worldid, dtype=wp.int32)
-
- d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2)
- d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2)
-
ncon = mjd.ncon
con_efc_address = np.zeros(nconmax, dtype=int)
con_worldid = np.zeros(nconmax, dtype=int)
@@ -512,33 +657,147 @@ def tile(x):
d.contact.efc_address = wp.array(con_efc_address, dtype=wp.int32, ndim=1)
d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1)
+ d.rne_cacc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector)
+ d.rne_cfrc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector)
+
+ d.efc = _constraint(mjm.nv, d.nworld, d.njmax)
+ d.efc.J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2)
+ d.efc.D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1)
+ d.efc.pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1)
+ d.efc.aref = wp.array(efc_aref_fill, dtype=wp.float32, ndim=1)
+ d.efc.force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1)
+ d.efc.margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1)
+ d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32)
+
d.xfrc_applied = wp.array(tile(mjd.xfrc_applied), dtype=wp.spatial_vector, ndim=2)
+
# internal tmp arrays
d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32)
d.qM_integration = wp.zeros_like(d.qM)
d.qLD_integration = wp.zeros_like(d.qLD)
d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv)
+ d.act_vel_integration = wp.zeros_like(d.ctrl)
+
+ # sweep-and-prune broadphase
+ d.sap_geom_sort = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec4)
+ d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32)
+ d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32)
+ d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32)
+ d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32)
+ d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32)
+ segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)]
+ d.sap_segment_index = wp.array(segment_indices_list, dtype=int)
- # the result of the broadphase gets stored in this array
- d.max_num_overlaps_per_world = mjm.ngeom * (mjm.ngeom - 1) // 2
- d.broadphase_pairs = wp.zeros((nworld, d.max_num_overlaps_per_world), dtype=wp.vec2i)
- d.result_count = wp.zeros(nworld, dtype=wp.int32)
+ # collision driver
+ d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1)
+ d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1)
+ d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1)
+
+ return d
- # internal broadphase tmp arrays
- d.boxes_sorted = wp.zeros(
- (nworld, mjm.ngeom), dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32)
- )
- d.data_start = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32)
- d.data_end = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32)
- d.data_indexer = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32)
- d.ranges = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32)
- d.cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32)
- segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)]
- d.segment_indices = wp.array(segment_indices_list, dtype=int)
- d.geom_aabb = wp.array(
- mjm.geom_aabb, dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1
+def get_data_into(
+ result: mujoco.MjData,
+ mjm: mujoco.MjModel,
+ d: types.Data,
+):
+ """Gets Data from a device into an existing mujoco.MjData."""
+ if d.nworld > 1:
+ raise NotImplementedError("only nworld == 1 supported for now")
+
+ ncon = d.ncon.numpy()[0]
+ nefc = d.nefc.numpy()[0]
+
+ if ncon != result.ncon or nefc != result.nefc:
+ mujoco._functions._realloc_con_efc(result, ncon=ncon, nefc=nefc)
+
+ result.time = d.time
+
+ result.qpos[:] = d.qpos.numpy()[0]
+ result.qvel[:] = d.qvel.numpy()[0]
+ result.qacc_warmstart = d.qacc_warmstart.numpy()[0]
+ result.qfrc_applied = d.qfrc_applied.numpy()[0]
+ result.mocap_pos = d.mocap_pos.numpy()[0]
+ result.mocap_quat = d.mocap_quat.numpy()[0]
+ result.qacc = d.qacc.numpy()[0]
+ result.xanchor = d.xanchor.numpy()[0]
+ result.xaxis = d.xaxis.numpy()[0]
+ result.xmat = d.xmat.numpy().reshape((-1, 9))
+ result.xpos = d.xpos.numpy()[0]
+ result.xquat = d.xquat.numpy()[0]
+ result.xipos = d.xipos.numpy()[0]
+ result.ximat = d.ximat.numpy().reshape((-1, 9))
+ result.subtree_com = d.subtree_com.numpy()[0]
+ result.geom_xpos = d.geom_xpos.numpy()[0]
+ result.geom_xmat = d.geom_xmat.numpy().reshape((-1, 9))
+ result.site_xpos = d.site_xpos.numpy()[0]
+ result.site_xmat = d.site_xmat.numpy().reshape((-1, 9))
+ result.cinert = d.cinert.numpy()[0]
+ result.cdof = d.cdof.numpy()[0]
+ result.crb = d.crb.numpy()[0]
+ result.qLDiagInv = d.qLDiagInv.numpy()[0]
+ result.ctrl = d.ctrl.numpy()[0]
+ result.actuator_velocity = d.actuator_velocity.numpy()[0]
+ result.actuator_force = d.actuator_force.numpy()[0]
+ result.actuator_length = d.actuator_length.numpy()[0]
+ mujoco.mju_dense2sparse(
+ result.actuator_moment,
+ d.actuator_moment.numpy()[0],
+ result.moment_rownnz,
+ result.moment_rowadr,
+ result.moment_colind,
)
+ result.cvel = d.cvel.numpy()[0]
+ result.cdof_dot = d.cdof_dot.numpy()[0]
+ result.qfrc_bias = d.qfrc_bias.numpy()[0]
+ result.qfrc_passive = d.qfrc_passive.numpy()[0]
+ result.qfrc_spring = d.qfrc_spring.numpy()[0]
+ result.qfrc_damper = d.qfrc_damper.numpy()[0]
+ result.qfrc_actuator = d.qfrc_actuator.numpy()[0]
+ result.qfrc_smooth = d.qfrc_smooth.numpy()[0]
+ result.qfrc_constraint = d.qfrc_constraint.numpy()[0]
+ result.qacc_smooth = d.qacc_smooth.numpy()[0]
+ result.act = d.act.numpy()[0]
+ result.act_dot = d.act_dot.numpy()[0]
+
+ result.contact.dist[:] = d.contact.dist.numpy()[:ncon]
+ result.contact.pos[:] = d.contact.pos.numpy()[:ncon]
+ result.contact.frame[:] = d.contact.frame.numpy()[:ncon].reshape((-1, 9))
+ result.contact.includemargin[:] = d.contact.includemargin.numpy()[:ncon]
+ result.contact.friction[:] = d.contact.friction.numpy()[:ncon]
+ result.contact.solref[:] = d.contact.solref.numpy()[:ncon]
+ result.contact.solreffriction[:] = d.contact.solreffriction.numpy()[:ncon]
+ result.contact.solimp[:] = d.contact.solimp.numpy()[:ncon]
+ result.contact.dim[:] = d.contact.dim.numpy()[:ncon]
+ result.contact.efc_address[:] = d.contact.efc_address.numpy()[:ncon]
- return d
+ if support.is_sparse(mjm):
+ result.qM[:] = d.qM.numpy()[0, 0]
+ result.qLD[:] = d.qLD.numpy()[0, 0]
+ # TODO(team): set efc_J after fix to _realloc_con_efc lands
+ # efc_J = d.efc.J.numpy()[:nefc]
+ # mujoco.mju_dense2sparse(
+ # result.efc_J, efc_J, result.efc_J_rownnz, result.efc_J_rowadr, result.efc_J_colind
+ # )
+ else:
+ qM = d.qM.numpy()
+ qLD = d.qLD.numpy()
+ adr = 0
+ for i in range(mjm.nv):
+ j = i
+ while j >= 0:
+ result.qM[adr] = qM[0, i, j]
+ result.qLD[adr] = qLD[0, i, j]
+ j = mjm.dof_parentid[j]
+ adr += 1
+ # TODO(team): set efc_J after fix to _realloc_con_efc lands
+ # if nefc > 0:
+ # result.efc_J[:nefc * mjm.nv] = d.efc.J.numpy()[:nefc].flatten()
+ result.efc_D[:] = d.efc.D.numpy()[:nefc]
+ result.efc_pos[:] = d.efc.pos.numpy()[:nefc]
+ result.efc_aref[:] = d.efc.aref.numpy()[:nefc]
+ result.efc_force[:] = d.efc.force.numpy()[:nefc]
+ result.efc_margin[:] = d.efc.margin.numpy()[:nefc]
+
+ # TODO(team): other efc_ fields, anything else missing
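Note on the `dof_parentid` walk used in `put_model` (for `qM_fullm_i` / `qM_fullm_j`) and in the dense branch of `get_data_into`: each dof `i` contributes one entry for itself and one for every ancestor dof. A minimal NumPy sketch of that traversal, where `ancestor_pairs`, `dense_to_packed`, and the toy tree are hypothetical names and data used only for illustration:

```python
import numpy as np


def ancestor_pairs(dof_parentid):
  """Returns (i, j) index arrays for every dof i and each ancestor j, including i."""
  rows, cols = [], []
  for i in range(len(dof_parentid)):
    j = i
    while j > -1:
      rows.append(i)
      cols.append(j)
      j = int(dof_parentid[j])
  return np.array(rows, dtype=np.int32), np.array(cols, dtype=np.int32)


def dense_to_packed(dense, dof_parentid):
  """Packs the ancestor entries of a dense (nv, nv) matrix into a flat vector."""
  rows, cols = ancestor_pairs(dof_parentid)
  return dense[rows, cols]


# toy kinematic tree (assumed): a 3-dof chain plus a separate 2-dof chain
toy_parentid = np.array([-1, 0, 1, -1, 3])
toy_rows, toy_cols = ancestor_pairs(toy_parentid)
toy_dense = np.arange(25, dtype=np.float64).reshape(5, 5)
toy_packed = dense_to_packed(toy_dense, toy_parentid)
```

The packed vector follows the same traversal order that `get_data_into` uses when it writes `result.qM[adr]`, so its length equals the number of (dof, ancestor) pairs in the tree.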
diff --git a/mujoco/mjx/_src/math.py b/mujoco_warp/_src/math.py
similarity index 66%
rename from mujoco/mjx/_src/math.py
rename to mujoco_warp/_src/math.py
index 416e1c5c..8bbefb9a 100644
--- a/mujoco/mjx/_src/math.py
+++ b/mujoco_warp/_src/math.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================
+from typing import Tuple
+
import warp as wp
from . import types
@@ -148,3 +150,93 @@ def quat_integrate(q: wp.quat, v: wp.vec3, dt: wp.float32) -> wp.quat:
q_res = mul_quat(q, q_res)
return wp.normalize(q_res)
+
+
+@wp.func
+def orthogonals(a: wp.vec3):
+ y = wp.vec3(0.0, 1.0, 0.0)
+ z = wp.vec3(0.0, 0.0, 1.0)
+ b = wp.where((-0.5 < a[1]) and (a[1] < 0.5), y, z)
+ b = b - a * wp.dot(a, b)
+ b = wp.normalize(b)
+ if wp.length(a) == 0.0:
+ b = wp.vec3(0.0, 0.0, 0.0)
+ c = wp.cross(a, b)
+
+ return b, c
+
+
+@wp.func
+def make_frame(a: wp.vec3):
+ a = wp.normalize(a)
+ b, c = orthogonals(a)
+
+ # fmt: off
+ return wp.mat33(
+ a.x, a.y, a.z,
+ b.x, b.y, b.z,
+ c.x, c.y, c.z
+ )
+ # fmt: on
+
+
+@wp.func
+def normalize_with_norm(x: wp.vec3):
+ norm = wp.length(x)
+ if norm == 0.0:
+ return x, 0.0
+ return x / norm, norm
+
+
+@wp.func
+def closest_segment_point(a: wp.vec3, b: wp.vec3, pt: wp.vec3) -> wp.vec3:
+ """Returns the closest point on the a-b line segment to a point pt."""
+ ab = b - a
+ t = wp.dot(pt - a, ab) / (wp.dot(ab, ab) + 1e-6)
+ return a + wp.clamp(t, 0.0, 1.0) * ab
+
+
+@wp.func
+def closest_segment_point_and_dist(
+ a: wp.vec3, b: wp.vec3, pt: wp.vec3
+) -> Tuple[wp.vec3, wp.float32]:
+ """Returns the closest point on the a-b line segment to pt and the squared distance."""
+ closest = closest_segment_point(a, b, pt)
+ dist = wp.dot((pt - closest), (pt - closest))
+ return closest, dist
+
+
+@wp.func
+def closest_segment_to_segment_points(
+ a0: wp.vec3, a1: wp.vec3, b0: wp.vec3, b1: wp.vec3
+) -> Tuple[wp.vec3, wp.vec3]:
+ """Returns closest points between two line segments."""
+
+ dir_a, len_a = normalize_with_norm(a1 - a0)
+ dir_b, len_b = normalize_with_norm(b1 - b0)
+
+ half_len_a = len_a * 0.5
+ half_len_b = len_b * 0.5
+ a_mid = a0 + dir_a * half_len_a
+ b_mid = b0 + dir_b * half_len_b
+
+ trans = a_mid - b_mid
+
+ dira_dot_dirb = wp.dot(dir_a, dir_b)
+ dira_dot_trans = wp.dot(dir_a, trans)
+ dirb_dot_trans = wp.dot(dir_b, trans)
+ denom = 1.0 - dira_dot_dirb * dira_dot_dirb
+
+ orig_t_a = (-dira_dot_trans + dira_dot_dirb * dirb_dot_trans) / (denom + 1e-6)
+ orig_t_b = dirb_dot_trans + orig_t_a * dira_dot_dirb
+ t_a = wp.clamp(orig_t_a, -half_len_a, half_len_a)
+ t_b = wp.clamp(orig_t_b, -half_len_b, half_len_b)
+
+ best_a = a_mid + dir_a * t_a
+ best_b = b_mid + dir_b * t_b
+
+ new_a, d1 = closest_segment_point_and_dist(a0, a1, best_b)
+ new_b, d2 = closest_segment_point_and_dist(b0, b1, best_a)
+ if d1 < d2:
+ return new_a, best_b
+ return best_a, new_b
diff --git a/mujoco_warp/_src/math_test.py b/mujoco_warp/_src/math_test.py
new file mode 100644
index 00000000..db691533
--- /dev/null
+++ b/mujoco_warp/_src/math_test.py
@@ -0,0 +1,88 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import warp as wp
+from absl.testing import absltest
+
+from .math import closest_segment_to_segment_points
+
+
+class ClosestSegmentSegmentPointsTest(absltest.TestCase):
+ """Tests for closest segment-to-segment points."""
+
+ def test_closest_segments_points(self):
+ """Tests closest points between two segments."""
+ a0 = wp.vec3([0.73432405, 0.12372768, 0.20272314])
+ a1 = wp.vec3([1.10600128, 0.88555209, 0.65209485])
+ b0 = wp.vec3([0.85599262, 0.61736299, 0.9843583])
+ b1 = wp.vec3([1.84270939, 0.92891793, 1.36343326])
+
+ best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1)
+ self.assertSequenceAlmostEqual(best_a, [1.09063, 0.85404, 0.63351], 5)
+ self.assertSequenceAlmostEqual(best_b, [0.99596, 0.66156, 1.03813], 5)
+
+ def test_intersecting_segments(self):
+ """Tests segments that intersect."""
+ a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0])
+ b0, b1 = wp.vec3([-1.0, 0.0, 0.0]), wp.vec3([1.0, 0.0, 0.0])
+
+ best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1)
+ self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, 0.0], 5)
+ self.assertSequenceAlmostEqual(best_b, [0.0, 0.0, 0.0], 5)
+
+ def test_intersecting_lines(self):
+ """Tests that intersecting lines get clipped."""
+ a0, a1 = wp.vec3([0.2, 0.2, 0.0]), wp.vec3([1.0, 1.0, 0.0])
+ b0, b1 = wp.vec3([0.2, 0.4, 0.0]), wp.vec3([1.0, 2.0, 0.0])
+
+ best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1)
+ self.assertSequenceAlmostEqual(best_a, [0.3, 0.3, 0.0], 2)
+ self.assertSequenceAlmostEqual(best_b, [0.2, 0.4, 0.0], 2)
+
+ def test_parallel_segments(self):
+ """Tests that parallel segments have closest points at their midpoints."""
+ a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0])
+ b0, b1 = wp.vec3([1.0, 0.0, -1.0]), wp.vec3([1.0, 0.0, 1.0])
+
+ best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1)
+ self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, 0.0], 5)
+ self.assertSequenceAlmostEqual(best_b, [1.0, 0.0, 0.0], 5)
+
+ def test_parallel_offset_segments(self):
+ """Tests that offset parallel segments are close at segment endpoints."""
+ a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0])
+ b0, b1 = wp.vec3([1.0, 0.0, 1.0]), wp.vec3([1.0, 0.0, 3.0])
+
+ best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1)
+ self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, 1.0], 5)
+ self.assertSequenceAlmostEqual(best_b, [1.0, 0.0, 1.0], 5)
+
+ def test_zero_length_segments(self):
+ """Tests that zero-length segments don't return NaNs."""
+ a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, -1.0])
+ b0, b1 = wp.vec3([1.0, 0.0, 0.1]), wp.vec3([1.0, 0.0, 0.1])
+
+ best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1)
+ self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, -1.0], 5)
+ self.assertSequenceAlmostEqual(best_b, [1.0, 0.0, 0.1], 5)
+
+ def test_overlapping_segments(self):
+ """Tests that perfectly overlapping segments intersect at the midpoints."""
+ a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0])
+ b0, b1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0])
+
+ best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1)
+ self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, 0.0], 5)
+ self.assertSequenceAlmostEqual(best_b, [0.0, 0.0, 0.0], 5)
diff --git a/mujoco/mjx/_src/passive.py b/mujoco_warp/_src/passive.py
similarity index 75%
rename from mujoco/mjx/_src/passive.py
rename to mujoco_warp/_src/passive.py
index 0125331c..c8cbc30b 100644
--- a/mujoco/mjx/_src/passive.py
+++ b/mujoco_warp/_src/passive.py
@@ -1,15 +1,38 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
import warp as wp
from . import math
-from .types import Model
from .types import Data
+from .types import DisableBit
from .types import JointType
+from .types import Model
+from .warp_util import event_scope
+from .warp_util import kernel
+@event_scope
def passive(m: Model, d: Data):
"""Adds all passive forces."""
+ if m.opt.disableflags & DisableBit.PASSIVE:
+ d.qfrc_passive.zero_()
+ # TODO(team): qfrc_gravcomp
+ return
- @wp.kernel
+ @kernel
def _spring(m: Model, d: Data):
worldid, jntid = wp.tid()
stiffness = m.jnt_stiffness[jntid]
@@ -67,7 +90,7 @@ def _spring(m: Model, d: Data):
fdif = d.qpos[worldid, qposid] - m.qpos_spring[qposid]
d.qfrc_spring[worldid, dofid] = -stiffness * fdif
- @wp.kernel
+ @kernel
def _damper_passive(m: Model, d: Data):
worldid, dofid = wp.tid()
damping = m.dof_damping[dofid]
diff --git a/mujoco/mjx/_src/passive_test.py b/mujoco_warp/_src/passive_test.py
similarity index 85%
rename from mujoco/mjx/_src/passive_test.py
rename to mujoco_warp/_src/passive_test.py
index 4a8409bb..7186ab39 100644
--- a/mujoco/mjx/_src/passive_test.py
+++ b/mujoco_warp/_src/passive_test.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,15 +15,15 @@
"""Tests for passive force functions."""
-from absl.testing import absltest
import numpy as np
import warp as wp
+from absl.testing import absltest
-from mujoco import mjx
+import mujoco_warp as mjwarp
from . import test_util
-# tolerance for difference between MuJoCo and MJX smooth calculations - mostly
+# tolerance for difference between MuJoCo and MJWarp passive force calculations - mostly
# due to float precision
_TOLERANCE = 5e-5
@@ -36,18 +36,20 @@ def _assert_eq(a, b, name):
class PassiveTest(absltest.TestCase):
def test_passive(self):
- """Tests MJX passive."""
+ """Tests passive."""
_, mjd, m, d = test_util.fixture("pendula.xml")
for arr in (d.qfrc_spring, d.qfrc_damper, d.qfrc_passive):
arr.zero_()
- mjx.passive(m, d)
+ mjwarp.passive(m, d)
_assert_eq(d.qfrc_spring.numpy()[0], mjd.qfrc_spring, "qfrc_spring")
_assert_eq(d.qfrc_damper.numpy()[0], mjd.qfrc_damper, "qfrc_damper")
_assert_eq(d.qfrc_passive.numpy()[0], mjd.qfrc_passive, "qfrc_passive")
+ # TODO(team): test DisableBit.PASSIVE
+
if __name__ == "__main__":
wp.init()
diff --git a/mujoco/mjx/_src/smooth.py b/mujoco_warp/_src/smooth.py
similarity index 81%
rename from mujoco/mjx/_src/smooth.py
rename to mujoco_warp/_src/smooth.py
index 8f61bb9c..19f70d4c 100644
--- a/mujoco/mjx/_src/smooth.py
+++ b/mujoco_warp/_src/smooth.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,20 +14,26 @@
# ==============================================================================
import warp as wp
-from . import math
-from .types import Model
+from . import math
from .types import Data
+from .types import DisableBit
+from .types import JointType
+from .types import Model
+from .types import TrnType
from .types import array2df
from .types import array3df
from .types import vec10
-from .types import JointType, TrnType
+from .warp_util import event_scope
+from .warp_util import kernel
+from .warp_util import kernel_copy
+@event_scope
def kinematics(m: Model, d: Data):
"""Forward kinematics."""
- @wp.kernel
+ @kernel
def _root(m: Model, d: Data):
worldid = wp.tid()
d.xpos[worldid, 0] = wp.vec3(0.0)
@@ -36,7 +42,7 @@ def _root(m: Model, d: Data):
d.xmat[worldid, 0] = wp.identity(n=3, dtype=wp.float32)
d.ximat[worldid, 0] = wp.identity(n=3, dtype=wp.float32)
- @wp.kernel
+ @kernel
def _level(m: Model, d: Data, leveladr: int):
worldid, nodeid = wp.tid()
bodyid = m.body_tree[leveladr + nodeid]
@@ -52,7 +58,6 @@ def _level(m: Model, d: Data, leveladr: int):
elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value):
# free joint
qadr = m.jnt_qposadr[jntadr]
- # TODO(erikfrey): would it be better to use some kind of wp.copy here?
xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2])
xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6])
d.xanchor[worldid, jntadr] = xpos
@@ -95,8 +100,35 @@ def _level(m: Model, d: Data, leveladr: int):
jntadr += 1
d.xpos[worldid, bodyid] = xpos
- d.xquat[worldid, bodyid] = wp.normalize(xquat)
+ xquat = wp.normalize(xquat)
+ d.xquat[worldid, bodyid] = xquat
d.xmat[worldid, bodyid] = math.quat_to_mat(xquat)
+ d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat(m.body_ipos[bodyid], xquat)
+ d.ximat[worldid, bodyid] = math.quat_to_mat(
+ math.mul_quat(xquat, m.body_iquat[bodyid])
+ )
+
+ @kernel
+ def geom_local_to_global(m: Model, d: Data):
+ worldid, geomid = wp.tid()
+ bodyid = m.geom_bodyid[geomid]
+ xpos = d.xpos[worldid, bodyid]
+ xquat = d.xquat[worldid, bodyid]
+ d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat(m.geom_pos[geomid], xquat)
+ d.geom_xmat[worldid, geomid] = math.quat_to_mat(
+ math.mul_quat(xquat, m.geom_quat[geomid])
+ )
+
+ @kernel
+ def site_local_to_global(m: Model, d: Data):
+ worldid, siteid = wp.tid()
+ bodyid = m.site_bodyid[siteid]
+ xpos = d.xpos[worldid, bodyid]
+ xquat = d.xquat[worldid, bodyid]
+ d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat(m.site_pos[siteid], xquat)
+ d.site_xmat[worldid, siteid] = math.quat_to_mat(
+ math.mul_quat(xquat, m.site_quat[siteid])
+ )
wp.launch(_root, dim=(d.nworld), inputs=[m, d])
@@ -106,35 +138,35 @@ def _level(m: Model, d: Data, leveladr: int):
end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1]
wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg])
+ if m.ngeom:
+ wp.launch(geom_local_to_global, dim=(d.nworld, m.ngeom), inputs=[m, d])
+
+ if m.nsite:
+ wp.launch(site_local_to_global, dim=(d.nworld, m.nsite), inputs=[m, d])
+
+@event_scope
def com_pos(m: Model, d: Data):
"""Map inertias and motion dofs to global frame centered at subtree-CoM."""
- @wp.kernel
- def mass_subtree_acc(m: Model, mass_subtree: wp.array(dtype=float), leveladr: int):
- nodeid = wp.tid()
- bodyid = m.body_tree[leveladr + nodeid]
- pid = m.body_parentid[bodyid]
- wp.atomic_add(mass_subtree, pid, mass_subtree[bodyid])
-
- @wp.kernel
+ @kernel
def subtree_com_init(m: Model, d: Data):
worldid, bodyid = wp.tid()
d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * m.body_mass[bodyid]
- @wp.kernel
+ @kernel
def subtree_com_acc(m: Model, d: Data, leveladr: int):
worldid, nodeid = wp.tid()
bodyid = m.body_tree[leveladr + nodeid]
pid = m.body_parentid[bodyid]
wp.atomic_add(d.subtree_com, worldid, pid, d.subtree_com[worldid, bodyid])
- @wp.kernel
- def subtree_div(mass_subtree: wp.array(dtype=float), d: Data):
+ @kernel
+ def subtree_div(m: Model, d: Data):
worldid, bodyid = wp.tid()
- d.subtree_com[worldid, bodyid] /= mass_subtree[bodyid]
+ d.subtree_com[worldid, bodyid] /= m.subtree_mass[bodyid]
- @wp.kernel
+ @kernel
def cinert(m: Model, d: Data):
worldid, bodyid = wp.tid()
mat = d.ximat[worldid, bodyid]
@@ -168,7 +200,7 @@ def cinert(m: Model, d: Data):
d.cinert[worldid, bodyid] = res
- @wp.kernel
+ @kernel
def cdof(m: Model, d: Data):
worldid, jntid = wp.tid()
bodyid = m.jnt_bodyid[jntid]
@@ -199,31 +231,27 @@ def cdof(m: Model, d: Data):
elif jnt_type == wp.static(JointType.HINGE.value): # hinge
res[dofid] = wp.spatial_vector(xaxis, wp.cross(xaxis, offset))
- body_treeadr = m.body_treeadr.numpy()
- mass_subtree = wp.clone(m.body_mass)
- for i in reversed(range(len(body_treeadr))):
- beg = body_treeadr[i]
- end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1]
- wp.launch(mass_subtree_acc, dim=(end - beg,), inputs=[m, mass_subtree, beg])
-
wp.launch(subtree_com_init, dim=(d.nworld, m.nbody), inputs=[m, d])
+ body_treeadr = m.body_treeadr.numpy()
+
for i in reversed(range(len(body_treeadr))):
beg = body_treeadr[i]
end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1]
wp.launch(subtree_com_acc, dim=(d.nworld, end - beg), inputs=[m, d, beg])
- wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[mass_subtree, d])
+ wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[m, d])
wp.launch(cinert, dim=(d.nworld, m.nbody), inputs=[m, d])
wp.launch(cdof, dim=(d.nworld, m.njnt), inputs=[m, d])
+@event_scope
def crb(m: Model, d: Data):
"""Composite rigid body inertia algorithm."""
- wp.copy(d.crb, d.cinert)
+ kernel_copy(d.crb, d.cinert)
- @wp.kernel
+ @kernel
def crb_accumulate(m: Model, d: Data, leveladr: int):
worldid, nodeid = wp.tid()
bodyid = m.body_tree[leveladr + nodeid]
@@ -232,7 +260,7 @@ def crb_accumulate(m: Model, d: Data, leveladr: int):
return
wp.atomic_add(d.crb, worldid, pid, d.crb[worldid, bodyid])
- @wp.kernel
+ @kernel
def qM_sparse(m: Model, d: Data):
worldid, dofid = wp.tid()
madr_ij = m.dof_Madr[dofid]
@@ -250,21 +278,27 @@ def qM_sparse(m: Model, d: Data):
madr_ij += 1
dofid = m.dof_parentid[dofid]
- @wp.kernel
+ @kernel
def qM_dense(m: Model, d: Data):
worldid, dofid = wp.tid()
bodyid = m.dof_bodyid[dofid]
# init M(i,i) with armature inertia
- d.qM[worldid, dofid, dofid] = m.dof_armature[dofid]
+ M = m.dof_armature[dofid]
# precompute buf = crb_body_i * cdof_i
buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid])
+ M += wp.dot(d.cdof[worldid, dofid], buf)
+
+ d.qM[worldid, dofid, dofid] = M
# sparse backward pass over ancestors
dofidi = dofid
+ dofid = m.dof_parentid[dofid]
while dofid >= 0:
- d.qM[worldid, dofidi, dofid] += wp.dot(d.cdof[worldid, dofid], buf)
+ qMij = wp.dot(d.cdof[worldid, dofid], buf)
+ d.qM[worldid, dofidi, dofid] += qMij
+ d.qM[worldid, dofid, dofidi] += qMij
dofid = m.dof_parentid[dofid]
body_treeadr = m.body_treeadr.numpy()
@@ -283,7 +317,7 @@ def qM_dense(m: Model, d: Data):
def _factor_i_sparse(m: Model, d: Data, M: array3df, L: array3df, D: array2df):
"""Sparse L'*D*L factorizaton of inertia-like matrix M, assumed spd."""
- @wp.kernel
+ @kernel
def qLD_acc(m: Model, leveladr: int, L: array3df):
worldid, nodeid = wp.tid()
update = m.qLD_update_tree[leveladr + nodeid]
@@ -297,12 +331,12 @@ def qLD_acc(m: Model, leveladr: int, L: array3df):
# M(k,i) = tmp
L[worldid, 0, Madr_ki] = tmp
- @wp.kernel
+ @kernel
def qLDiag_div(m: Model, L: array3df, D: array2df):
worldid, dofid = wp.tid()
D[worldid, dofid] = 1.0 / L[worldid, 0, m.dof_Madr[dofid]]
- wp.copy(L, M)
+ kernel_copy(L, M)
qLD_update_treeadr = m.qLD_update_treeadr.numpy()
@@ -323,7 +357,7 @@ def _factor_i_dense(m: Model, d: Data, M: wp.array, L: wp.array):
block_dim = 32
def tile_cholesky(adr: int, size: int, tilesize: int):
- @wp.kernel
+ @kernel
def cholesky(m: Model, leveladr: int, M: array3df, L: array3df):
worldid, nodeid = wp.tid()
dofid = m.qLD_tile[leveladr + nodeid]
@@ -355,27 +389,25 @@ def factor_i(m: Model, d: Data, M, L, D=None):
_factor_i_dense(m, d, M, L)
+@event_scope
def factor_m(m: Model, d: Data):
"""Factorizaton of inertia-like matrix M, assumed spd."""
factor_i(m, d, d.qM, d.qLD, d.qLDiagInv)
+@event_scope
def rne(m: Model, d: Data):
"""Computes inverse dynamics using Newton-Euler algorithm."""
- cacc = wp.zeros(shape=(d.nworld, m.nbody), dtype=wp.spatial_vector)
- cfrc = wp.zeros(shape=(d.nworld, m.nbody), dtype=wp.spatial_vector)
-
- @wp.kernel
- def cacc_gravity(m: Model, cacc: wp.array(dtype=wp.spatial_vector, ndim=2)):
+ @kernel
+ def cacc_gravity(m: Model, d: Data):
worldid = wp.tid()
- cacc[worldid, 0] = wp.spatial_vector(wp.vec3(0.0), -m.opt.gravity)
+ d.rne_cacc[worldid, 0] = wp.spatial_vector(wp.vec3(0.0), -m.opt.gravity)
- @wp.kernel
+ @kernel
def cacc_level(
m: Model,
d: Data,
- cacc: wp.array(dtype=wp.spatial_vector, ndim=2),
leveladr: int,
):
worldid, nodeid = wp.tid()
@@ -383,62 +415,64 @@ def cacc_level(
dofnum = m.body_dofnum[bodyid]
pid = m.body_parentid[bodyid]
dofadr = m.body_dofadr[bodyid]
- local_cacc = cacc[worldid, pid]
+ local_cacc = d.rne_cacc[worldid, pid]
for i in range(dofnum):
local_cacc += d.cdof_dot[worldid, dofadr + i] * d.qvel[worldid, dofadr + i]
- cacc[worldid, bodyid] = local_cacc
+ d.rne_cacc[worldid, bodyid] = local_cacc
- @wp.kernel
- def frc_fn(
- d: Data,
- cfrc: wp.array(dtype=wp.spatial_vector, ndim=2),
- cacc: wp.array(dtype=wp.spatial_vector, ndim=2),
- ):
+ @kernel
+ def frc_fn(d: Data):
worldid, bodyid = wp.tid()
- frc = math.inert_vec(d.cinert[worldid, bodyid], cacc[worldid, bodyid])
+ frc = math.inert_vec(d.cinert[worldid, bodyid], d.rne_cacc[worldid, bodyid])
frc += math.motion_cross_force(
d.cvel[worldid, bodyid],
math.inert_vec(d.cinert[worldid, bodyid], d.cvel[worldid, bodyid]),
)
- cfrc[worldid, bodyid] += frc
+ d.rne_cfrc[worldid, bodyid] = frc
- @wp.kernel
- def cfrc_fn(m: Model, cfrc: wp.array(dtype=wp.spatial_vector, ndim=2), leveladr: int):
+ @kernel
+ def cfrc_fn(m: Model, d: Data, leveladr: int):
worldid, nodeid = wp.tid()
bodyid = m.body_tree[leveladr + nodeid]
pid = m.body_parentid[bodyid]
- wp.atomic_add(cfrc[worldid], pid, cfrc[worldid, bodyid])
+ wp.atomic_add(d.rne_cfrc[worldid], pid, d.rne_cfrc[worldid, bodyid])
- @wp.kernel
- def qfrc_bias(m: Model, d: Data, cfrc: wp.array(dtype=wp.spatial_vector, ndim=2)):
+ @kernel
+ def qfrc_bias(m: Model, d: Data):
worldid, dofid = wp.tid()
bodyid = m.dof_bodyid[dofid]
- d.qfrc_bias[worldid, dofid] = wp.dot(d.cdof[worldid, dofid], cfrc[worldid, bodyid])
+ d.qfrc_bias[worldid, dofid] = wp.dot(
+ d.cdof[worldid, dofid], d.rne_cfrc[worldid, bodyid]
+ )
- wp.launch(cacc_gravity, dim=[d.nworld], inputs=[m, cacc])
+ if m.opt.disableflags & DisableBit.GRAVITY:
+ d.rne_cacc.zero_()
+ else:
+ wp.launch(cacc_gravity, dim=[d.nworld], inputs=[m, d])
body_treeadr = m.body_treeadr.numpy()
for i in range(len(body_treeadr)):
beg = body_treeadr[i]
end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1]
- wp.launch(cacc_level, dim=(d.nworld, end - beg), inputs=[m, d, cacc, beg])
+ wp.launch(cacc_level, dim=(d.nworld, end - beg), inputs=[m, d, beg])
- wp.launch(frc_fn, dim=[d.nworld, m.nbody], inputs=[d, cfrc, cacc])
+ wp.launch(frc_fn, dim=[d.nworld, m.nbody], inputs=[d])
for i in reversed(range(len(body_treeadr))):
beg = body_treeadr[i]
end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1]
- wp.launch(cfrc_fn, dim=[d.nworld, end - beg], inputs=[m, cfrc, beg])
+ wp.launch(cfrc_fn, dim=[d.nworld, end - beg], inputs=[m, d, beg])
- wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d, cfrc])
+ wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d])
+@event_scope
def transmission(m: Model, d: Data):
"""Computes actuator/transmission lengths and moments."""
if not m.nu:
return d
- @wp.kernel
+ @kernel
def _transmission(
m: Model,
d: Data,
@@ -502,15 +536,16 @@ def _transmission(
)
+@event_scope
def com_vel(m: Model, d: Data):
"""Computes cvel, cdof_dot."""
- @wp.kernel
+ @kernel
def _root(d: Data):
worldid, elementid = wp.tid()
d.cvel[worldid, 0][elementid] = 0.0
- @wp.kernel
+ @kernel
def _level(m: Model, d: Data, leveladr: int):
worldid, nodeid = wp.tid()
bodyid = m.body_tree[leveladr + nodeid]
@@ -576,26 +611,26 @@ def _solve_LD_sparse(
):
"""Computes sparse backsubstitution: x = inv(L'*D*L)*y"""
- @wp.kernel
+ @kernel
def x_acc_up(m: Model, L: array3df, x: array2df, leveladr: int):
worldid, nodeid = wp.tid()
update = m.qLD_update_tree[leveladr + nodeid]
i, k, Madr_ki = update[0], update[1], update[2]
wp.atomic_sub(x[worldid], i, L[worldid, 0, Madr_ki] * x[worldid, k])
- @wp.kernel
+ @kernel
def qLDiag_mul(D: array2df, x: array2df):
worldid, dofid = wp.tid()
x[worldid, dofid] *= D[worldid, dofid]
- @wp.kernel
+ @kernel
def x_acc_down(m: Model, L: array3df, x: array2df, leveladr: int):
worldid, nodeid = wp.tid()
update = m.qLD_update_tree[leveladr + nodeid]
i, k, Madr_ki = update[0], update[1], update[2]
wp.atomic_sub(x[worldid], k, L[worldid, 0, Madr_ki] * x[worldid, i])
- wp.copy(x, y)
+ kernel_copy(x, y)
qLD_update_treeadr = m.qLD_update_treeadr.numpy()
@@ -623,7 +658,7 @@ def _solve_LD_dense(m: Model, d: Data, L: array3df, x: array2df, y: array2df):
block_dim = 32
def tile_cho_solve(adr: int, size: int, tilesize: int):
- @wp.kernel
+ @kernel
def cho_solve(m: Model, L: array3df, x: array2df, y: array2df, leveladr: int):
worldid, nodeid = wp.tid()
dofid = m.qLD_tile[leveladr + nodeid]
@@ -655,6 +690,45 @@ def solve_LD(m: Model, d: Data, L: array3df, D: array2df, x: array2df, y: array2
_solve_LD_dense(m, d, L, x, y)
+@event_scope
def solve_m(m: Model, d: Data, x: array2df, y: array2df):
"""Computes backsubstitution: x = qLD * y."""
solve_LD(m, d, d.qLD, d.qLDiagInv, x, y)
+
+
+def _factor_solve_i_dense(m: Model, d: Data, M: array3df, x: array2df, y: array2df):
+ # TODO(team): develop heuristic for block dim, or make configurable
+ block_dim = 32
+
+ def tile_cholesky(adr: int, size: int, tilesize: int):
+ @kernel(module="unique")
+ def cholesky(m: Model, leveladr: int, M: array3df, x: array2df, y: array2df):
+ worldid, nodeid = wp.tid()
+ dofid = m.qLD_tile[leveladr + nodeid]
+ M_tile = wp.tile_load(
+ M[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid)
+ )
+ y_slice = wp.tile_load(y[worldid], shape=(tilesize,), offset=(dofid,))
+
+ L_tile = wp.tile_cholesky(M_tile)
+ x_slice = wp.tile_cholesky_solve(L_tile, y_slice)
+ wp.tile_store(x[worldid], x_slice, offset=(dofid,))
+
+ wp.launch_tiled(
+ cholesky, dim=(d.nworld, size), inputs=[m, adr, M, x, y], block_dim=block_dim
+ )
+
+ qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy()
+
+ for i in range(len(qLD_tileadr)):
+ beg = qLD_tileadr[i]
+ end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1]
+ tile_cholesky(beg, end - beg, int(qLD_tilesize[i]))
+
+
+def factor_solve_i(m, d, M, L, D, x, y):
+ if m.opt.is_sparse:
+ _factor_i_sparse(m, d, M, L, D)
+ _solve_LD_sparse(m, d, L, D, x, y)
+ else:
+ _factor_solve_i_dense(m, d, M, x, y)
diff --git a/mujoco/mjx/_src/smooth_test.py b/mujoco_warp/_src/smooth_test.py
similarity index 67%
rename from mujoco/mjx/_src/smooth_test.py
rename to mujoco_warp/_src/smooth_test.py
index e0fe3122..ed40e91d 100644
--- a/mujoco/mjx/_src/smooth_test.py
+++ b/mujoco_warp/_src/smooth_test.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,18 +15,19 @@
"""Tests for smooth dynamics functions."""
-from absl.testing import absltest
-from absl.testing import parameterized
import mujoco
-from mujoco import mjx
import numpy as np
import warp as wp
+from absl.testing import absltest
+from absl.testing import parameterized
-wp.config.verify_cuda = True
+import mujoco_warp as mjwarp
from . import test_util
-# tolerance for difference between MuJoCo and mjWarp smooth calculations - mostly
+wp.config.verify_cuda = True
+
+# tolerance for difference between MuJoCo and MJWarp smooth calculations - mostly
# due to float precision
_TOLERANCE = 5e-5
@@ -45,12 +46,19 @@ def test_kinematics(self):
for arr in (d.xanchor, d.xaxis, d.xquat, d.xpos):
arr.zero_()
- mjx.kinematics(m, d)
+ mjwarp.kinematics(m, d)
_assert_eq(d.xanchor.numpy()[0], mjd.xanchor, "xanchor")
_assert_eq(d.xaxis.numpy()[0], mjd.xaxis, "xaxis")
- _assert_eq(d.xquat.numpy()[0], mjd.xquat, "xquat")
_assert_eq(d.xpos.numpy()[0], mjd.xpos, "xpos")
+ _assert_eq(d.xquat.numpy()[0], mjd.xquat, "xquat")
+ _assert_eq(d.xmat.numpy()[0], mjd.xmat.reshape((-1, 3, 3)), "xmat")
+ _assert_eq(d.xipos.numpy()[0], mjd.xipos, "xipos")
+ _assert_eq(d.ximat.numpy()[0], mjd.ximat.reshape((-1, 3, 3)), "ximat")
+ _assert_eq(d.geom_xpos.numpy()[0], mjd.geom_xpos, "geom_xpos")
+ _assert_eq(d.geom_xmat.numpy()[0], mjd.geom_xmat.reshape((-1, 3, 3)), "geom_xmat")
+ _assert_eq(d.site_xpos.numpy()[0], mjd.site_xpos, "site_xpos")
+ _assert_eq(d.site_xmat.numpy()[0], mjd.site_xmat.reshape((-1, 3, 3)), "site_xmat")
def test_com_pos(self):
"""Tests com_pos."""
@@ -59,51 +67,49 @@ def test_com_pos(self):
for arr in (d.subtree_com, d.cinert, d.cdof):
arr.zero_()
- mjx.com_pos(m, d)
+ mjwarp.com_pos(m, d)
_assert_eq(d.subtree_com.numpy()[0], mjd.subtree_com, "subtree_com")
_assert_eq(d.cinert.numpy()[0], mjd.cinert, "cinert")
_assert_eq(d.cdof.numpy()[0], mjd.cdof, "cdof")
- def test_crb(self):
+ @parameterized.parameters(True, False)
+ def test_crb(self, sparse: bool):
"""Tests crb."""
- _, mjd, m, d = test_util.fixture("pendula.xml")
+ mjm, mjd, m, d = test_util.fixture("pendula.xml", sparse=sparse)
d.crb.zero_()
- mjx.crb(m, d)
+ mjwarp.crb(m, d)
_assert_eq(d.crb.numpy()[0], mjd.crb, "crb")
- _assert_eq(d.qM.numpy()[0, 0], mjd.qM, "qM")
- def test_factor_m_sparse(self):
- """Tests factor_m (sparse)."""
- _, mjd, m, d = test_util.fixture("pendula.xml", sparse=True)
+ if sparse:
+ _assert_eq(d.qM.numpy()[0, 0], mjd.qM, "qM")
+ else:
+ qM = np.zeros((mjm.nv, mjm.nv))
+ mujoco.mj_fullM(mjm, qM, mjd.qM)
+ _assert_eq(d.qM.numpy()[0], qM, "qM")
+ @parameterized.parameters(True, False)
+ def test_factor_m(self, sparse: bool):
+ """Tests factor_m."""
+ _, mjd, m, d = test_util.fixture("pendula.xml", sparse=sparse)
+
+ qLD = d.qLD.numpy()[0].copy()
for arr in (d.qLD, d.qLDiagInv):
arr.zero_()
- mjx.factor_m(m, d)
- _assert_eq(d.qLD.numpy()[0, 0], mjd.qLD, "qLD (sparse)")
- _assert_eq(d.qLDiagInv.numpy()[0], mjd.qLDiagInv, "qLDiagInv")
+ mjwarp.factor_m(m, d)
- def test_factor_m_dense(self):
- """Tests MJX factor_m (dense)."""
- # TODO(team): switch this to pendula.xml and merge with above test
- # after mmacklin's tile_cholesky fixes are in
- _, mjd, m, d = test_util.fixture("humanoid/humanoid.xml", sparse=False)
-
- qLD = d.qLD.numpy()[0].copy()
- d.qLD.zero_()
-
- mjx.factor_m(m, d)
- _assert_eq(d.qLD.numpy()[0], qLD, "qLD (dense)")
+ if sparse:
+ _assert_eq(d.qLD.numpy()[0, 0], mjd.qLD, "qLD (sparse)")
+ _assert_eq(d.qLDiagInv.numpy()[0], mjd.qLDiagInv, "qLDiagInv")
+ else:
+ _assert_eq(d.qLD.numpy()[0], qLD, "qLD (dense)")
@parameterized.parameters(True, False)
def test_solve_m(self, sparse: bool):
"""Tests solve_m."""
- # TODO(team): switch this to pendula.xml and merge with above test
- # after mmacklin's tile_cholesky fixes are in
- fname = "pendula.xml" if sparse else "humanoid/humanoid.xml"
- mjm, mjd, m, d = test_util.fixture(fname, sparse=sparse)
+ mjm, mjd, m, d = test_util.fixture("pendula.xml", sparse=sparse)
qfrc_smooth = np.tile(mjd.qfrc_smooth, (1, 1))
qacc_smooth = np.zeros(
@@ -117,7 +123,7 @@ def test_solve_m(self, sparse: bool):
d.qacc_smooth.zero_()
- mjx.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth)
+ mjwarp.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth)
_assert_eq(d.qacc_smooth.numpy()[0], qacc_smooth[0], "qacc_smooth")
def test_rne(self):
@@ -126,9 +132,11 @@ def test_rne(self):
d.qfrc_bias.zero_()
- mjx.rne(m, d)
+ mjwarp.rne(m, d)
_assert_eq(d.qfrc_bias.numpy()[0], mjd.qfrc_bias, "qfrc_bias")
+ # TODO(team): test DisableBit.GRAVITY
+
def test_com_vel(self):
"""Tests com_vel."""
_, mjd, m, d = test_util.fixture("pendula.xml")
@@ -136,7 +144,7 @@ def test_com_vel(self):
for arr in (d.cvel, d.cdof_dot):
arr.zero_()
- mjx.com_vel(m, d)
+ mjwarp.com_vel(m, d)
_assert_eq(d.cvel.numpy()[0], mjd.cvel, "cvel")
_assert_eq(d.cdof_dot.numpy()[0], mjd.cdof_dot, "cdof_dot")
@@ -156,7 +164,7 @@ def test_transmission(self):
mjd.moment_colind,
)
- mjx._src.smooth.transmission(m, d)
+ mjwarp._src.smooth.transmission(m, d)
_assert_eq(d.actuator_length.numpy()[0], mjd.actuator_length, "actuator_length")
_assert_eq(d.actuator_moment.numpy()[0], actuator_moment, "actuator_moment")
diff --git a/mujoco_warp/_src/solver.py b/mujoco_warp/_src/solver.py
new file mode 100644
index 00000000..9adb2770
--- /dev/null
+++ b/mujoco_warp/_src/solver.py
@@ -0,0 +1,919 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import warp as wp
+
+from . import smooth
+from . import support
+from . import types
+from .warp_util import event_scope
+from .warp_util import kernel
+from .warp_util import kernel_copy
+
+
+def _create_context(m: types.Model, d: types.Data, grad: bool = True):
+ @kernel
+ def _init_context(d: types.Data):
+ worldid = wp.tid()
+ d.efc.cost[worldid] = wp.inf
+ d.efc.solver_niter[worldid] = 0
+ d.efc.done[worldid] = False
+ if grad:
+ d.efc.search_dot[worldid] = 0.0
+
+ @kernel
+ def _jaref(m: types.Model, d: types.Data):
+ efcid, dofid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+ wp.atomic_add(
+ d.efc.Jaref,
+ efcid,
+ d.efc.J[efcid, dofid] * d.qacc[worldid, dofid] - d.efc.aref[efcid] / float(m.nv),
+ )
+
+ @kernel
+ def _search(d: types.Data):
+ worldid, dofid = wp.tid()
+ search = -1.0 * d.efc.Mgrad[worldid, dofid]
+ d.efc.search[worldid, dofid] = search
+ wp.atomic_add(d.efc.search_dot, worldid, search * search)
+
+ wp.launch(_init_context, dim=(d.nworld), inputs=[d])
+
+ # jaref = d.efc_J @ d.qacc - d.efc_aref
+ d.efc.Jaref.zero_()
+
+ wp.launch(_jaref, dim=(d.njmax, m.nv), inputs=[m, d])
+
+ # Ma = qM @ qacc
+ support.mul_m(m, d, d.efc.Ma, d.qacc, d.efc.done)
+
+ _update_constraint(m, d)
+ if grad:
+ _update_gradient(m, d)
+
+ # search = -Mgrad
+ wp.launch(_search, dim=(d.nworld, m.nv), inputs=[d])
+
+
+def _update_constraint(m: types.Model, d: types.Data):
+ @kernel
+ def _init_cost(d: types.Data):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.efc.prev_cost[worldid] = d.efc.cost[worldid]
+ d.efc.cost[worldid] = 0.0
+ d.efc.gauss[worldid] = 0.0
+
+ @kernel
+ def _efc_kernel(d: types.Data):
+ efcid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ Jaref = d.efc.Jaref[efcid]
+ efc_D = d.efc.D[efcid]
+
+ # TODO(team): active and conditionally active constraints
+ active = int(Jaref < 0.0)
+ d.efc.active[efcid] = active
+
+ if active:
+ # efc_force = -efc_D * Jaref * active
+ d.efc.force[efcid] = -1.0 * efc_D * Jaref
+
+ # cost = 0.5 * sum(efc_D * Jaref * Jaref * active)
+ wp.atomic_add(d.efc.cost, worldid, 0.5 * efc_D * Jaref * Jaref)
+ else:
+ d.efc.force[efcid] = 0.0
+
+ @kernel
+ def _zero_qfrc_constraint(d: types.Data):
+ worldid, dofid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.qfrc_constraint[worldid, dofid] = 0.0
+
+ @kernel
+ def _qfrc_constraint(d: types.Data):
+ dofid, efcid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ wp.atomic_add(
+ d.qfrc_constraint[worldid],
+ dofid,
+ d.efc.J[efcid, dofid] * d.efc.force[efcid],
+ )
+
+ @kernel
+ def _gauss(d: types.Data):
+ worldid, dofid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ gauss_cost = (
+ 0.5
+ * (d.efc.Ma[worldid, dofid] - d.qfrc_smooth[worldid, dofid])
+ * (d.qacc[worldid, dofid] - d.qacc_smooth[worldid, dofid])
+ )
+ wp.atomic_add(d.efc.gauss, worldid, gauss_cost)
+ wp.atomic_add(d.efc.cost, worldid, gauss_cost)
+
+ wp.launch(_init_cost, dim=(d.nworld), inputs=[d])
+
+ wp.launch(_efc_kernel, dim=(d.njmax,), inputs=[d])
+
+ # qfrc_constraint = efc_J.T @ efc_force
+ wp.launch(_zero_qfrc_constraint, dim=(d.nworld, m.nv), inputs=[d])
+
+ wp.launch(_qfrc_constraint, dim=(m.nv, d.njmax), inputs=[d])
+
+ # gauss = 0.5 * (Ma - qfrc_smooth).T @ (qacc - qacc_smooth)
+
+ wp.launch(_gauss, dim=(d.nworld, m.nv), inputs=[d])
+
+
+def _update_gradient(m: types.Model, d: types.Data):
+ TILE = m.nv
+ ITERATIONS = m.opt.iterations
+
+ @kernel
+ def _zero_grad_dot(d: types.Data):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.efc.grad_dot[worldid] = 0.0
+
+ @kernel
+ def _grad(d: types.Data):
+ worldid, dofid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ grad = (
+ d.efc.Ma[worldid, dofid]
+ - d.qfrc_smooth[worldid, dofid]
+ - d.qfrc_constraint[worldid, dofid]
+ )
+ d.efc.grad[worldid, dofid] = grad
+ wp.atomic_add(d.efc.grad_dot, worldid, grad * grad)
+
+ if m.opt.is_sparse:
+
+ @kernel
+ def _zero_h_lower(m: types.Model, d: types.Data):
+ # TODO(team): static m?
+ worldid, elementid = wp.tid()
+
+ if ITERATIONS > 1:
+ if d.efc.done[worldid]:
+ return
+
+ rowid = m.dof_tri_row[elementid]
+ colid = m.dof_tri_col[elementid]
+ d.efc.h[worldid, rowid, colid] = 0.0
+
+ @kernel
+ def _set_h_qM_lower_sparse(m: types.Model, d: types.Data):
+ # TODO(team): static m?
+ worldid, elementid = wp.tid()
+
+ if ITERATIONS > 1:
+ if d.efc.done[worldid]:
+ return
+
+ i = m.qM_fullm_i[elementid]
+ j = m.qM_fullm_j[elementid]
+ d.efc.h[worldid, i, j] = d.qM[worldid, 0, elementid]
+
+ else:
+
+ @kernel
+ def _copy_lower_triangle(m: types.Model, d: types.Data):
+ # TODO(team): static m?
+ worldid, elementid = wp.tid()
+
+ if ITERATIONS > 1:
+ if d.efc.done[worldid]:
+ return
+
+ rowid = m.dof_tri_row[elementid]
+ colid = m.dof_tri_col[elementid]
+ d.efc.h[worldid, rowid, colid] = d.qM[worldid, rowid, colid]
+
+ @kernel
+ def _JTDAJ(m: types.Model, d: types.Data):
+ # TODO(team): static m?
+ efcid, elementid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if ITERATIONS > 1:
+ if d.efc.done[worldid]:
+ return
+
+ dofi = m.dof_tri_row[elementid]
+ dofj = m.dof_tri_col[elementid]
+
+ efc_D = d.efc.D[efcid]
+ active = d.efc.active[efcid]
+ if efc_D == 0.0 or active == 0:
+ return
+
+ # TODO(team): sparse efc_J
+ wp.atomic_add(
+ d.efc.h[worldid, dofi],
+ dofj,
+ d.efc.J[efcid, dofi] * d.efc.J[efcid, dofj] * efc_D,
+ )
+
+ @kernel
+ def _cholesky(d: types.Data):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ mat_tile = wp.tile_load(d.efc.h[worldid], shape=(TILE, TILE))
+ fact_tile = wp.tile_cholesky(mat_tile)
+ input_tile = wp.tile_load(d.efc.grad[worldid], shape=TILE)
+ output_tile = wp.tile_cholesky_solve(fact_tile, input_tile)
+ wp.tile_store(d.efc.Mgrad[worldid], output_tile)
+
+ # grad = Ma - qfrc_smooth - qfrc_constraint
+ wp.launch(_zero_grad_dot, dim=(d.nworld), inputs=[d])
+
+ wp.launch(_grad, dim=(d.nworld, m.nv), inputs=[d])
+
+ if m.opt.solver == types.SolverType.CG:
+ smooth.solve_m(m, d, d.efc.Mgrad, d.efc.grad)
+ elif m.opt.solver == types.SolverType.NEWTON:
+ # h = qM + (efc_J.T * efc_D * active) @ efc_J
+ if m.opt.is_sparse:
+ wp.launch(_zero_h_lower, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d])
+
+ wp.launch(
+ _set_h_qM_lower_sparse, dim=(d.nworld, m.qM_fullm_i.size), inputs=[m, d]
+ )
+ else:
+ wp.launch(_copy_lower_triangle, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d])
+
+ wp.launch(_JTDAJ, dim=(d.njmax, m.dof_tri_row.size), inputs=[m, d])
+
+ wp.launch_tiled(_cholesky, dim=(d.nworld,), inputs=[d], block_dim=32)
+
+
+@wp.func
+def _rescale(m: types.Model, value: float) -> float:
+ return value / (m.stat.meaninertia * float(wp.max(1, m.nv)))
+
+
+@wp.func
+def _in_bracket(x: wp.vec3, y: wp.vec3) -> bool:
+ return (x[1] < y[1] and y[1] < 0.0) or (x[1] > y[1] and y[1] > 0.0)
+
+
+@wp.func
+def _eval_pt(quad: wp.vec3, alpha: wp.float32) -> wp.vec3:
+ return wp.vec3(
+ alpha * alpha * quad[2] + alpha * quad[1] + quad[0],
+ 2.0 * alpha * quad[2] + quad[1],
+ 2.0 * quad[2],
+ )
+
+
+@wp.func
+def _safe_div(x: wp.float32, y: wp.float32) -> wp.float32:
+ return x / wp.where(y != 0.0, y, types.MJ_MINVAL)
+
+
+def _linesearch_iterative(m: types.Model, d: types.Data):
+ ITERATIONS = m.opt.iterations
+
+ @kernel
+ def _gtol(m: types.Model, d: types.Data):
+ # TODO(team): static m?
+ worldid = wp.tid()
+
+ if m.opt.iterations > 1:
+ if d.efc.done[worldid]:
+ return
+
+ snorm = wp.math.sqrt(d.efc.search_dot[worldid])
+ scale = m.stat.meaninertia * wp.float(wp.max(1, m.nv))
+ d.efc.gtol[worldid] = m.opt.tolerance * m.opt.ls_tolerance * snorm * scale
+
+ @kernel
+ def _zero_jv(d: types.Data):
+ efcid = wp.tid()
+
+ if d.efc.done[d.efc.worldid[efcid]]:
+ return
+
+ d.efc.jv[efcid] = 0.0
+
+ @kernel
+ def _jv(d: types.Data):
+ efcid, dofid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ j = d.efc.J[efcid, dofid]
+ search = d.efc.search[d.efc.worldid[efcid], dofid]
+ wp.atomic_add(d.efc.jv, efcid, j * search)
+
+ @kernel
+ def _zero_quad_gauss(d: types.Data):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.efc.quad_gauss[worldid] = wp.vec3(0.0)
+
+ @kernel
+ def _init_quad_gauss(m: types.Model, d: types.Data):
+ # TODO(team): static m?
+ worldid, dofid = wp.tid()
+
+ if ITERATIONS > 1:
+ if d.efc.done[worldid]:
+ return
+
+ search = d.efc.search[worldid, dofid]
+ quad_gauss = wp.vec3()
+ quad_gauss[0] = d.efc.gauss[worldid] / float(m.nv)
+ quad_gauss[1] = search * (d.efc.Ma[worldid, dofid] - d.qfrc_smooth[worldid, dofid])
+ quad_gauss[2] = 0.5 * search * d.efc.mv[worldid, dofid]
+ wp.atomic_add(d.efc.quad_gauss, worldid, quad_gauss)
+
+ @kernel
+ def _init_quad(d: types.Data):
+ efcid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ Jaref = d.efc.Jaref[efcid]
+ jv = d.efc.jv[efcid]
+ efc_D = d.efc.D[efcid]
+ quad = wp.vec3()
+ quad[0] = 0.5 * Jaref * Jaref * efc_D
+ quad[1] = jv * Jaref * efc_D
+ quad[2] = 0.5 * jv * jv * efc_D
+ d.efc.quad[efcid] = quad
+
+ @kernel
+ def _init_p0_gauss(p0: wp.array(dtype=wp.vec3), d: types.Data):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ quad = d.efc.quad_gauss[worldid]
+ p0[worldid] = wp.vec3(quad[0], quad[1], 2.0 * quad[2])
+
+ @kernel
+ def _init_p0(p0: wp.array(dtype=wp.vec3), d: types.Data):
+ efcid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ # TODO(team): active and conditionally active constraints
+ if d.efc.Jaref[efcid] >= 0.0:
+ return
+
+ quad = d.efc.quad[efcid]
+ wp.atomic_add(p0, worldid, wp.vec3(quad[0], quad[1], 2.0 * quad[2]))
+
+ @kernel
+ def _init_lo_gauss(
+ p0: wp.array(dtype=wp.vec3),
+ lo: wp.array(dtype=wp.vec3),
+ lo_alpha: wp.array(dtype=wp.float32),
+ d: types.Data,
+ ):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ pp0 = p0[worldid]
+ alpha = -_safe_div(pp0[1], pp0[2])
+ lo[worldid] = _eval_pt(d.efc.quad_gauss[worldid], alpha)
+ lo_alpha[worldid] = alpha
+
+ @kernel
+ def _init_lo(
+ lo: wp.array(dtype=wp.vec3),
+ lo_alpha: wp.array(dtype=wp.float32),
+ d: types.Data,
+ ):
+ efcid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ alpha = lo_alpha[worldid]
+
+ # TODO(team): active and conditionally active constraints
+ if d.efc.Jaref[efcid] + alpha * d.efc.jv[efcid] < 0.0:
+ wp.atomic_add(lo, worldid, _eval_pt(d.efc.quad[efcid], alpha))
+
+ @kernel
+ def _init_bounds(
+ p0: wp.array(dtype=wp.vec3),
+ lo: wp.array(dtype=wp.vec3),
+ lo_alpha: wp.array(dtype=wp.float32),
+ hi: wp.array(dtype=wp.vec3),
+ hi_alpha: wp.array(dtype=wp.float32),
+ d: types.Data,
+ ):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ pp0 = p0[worldid]
+ plo = lo[worldid]
+ plo_alpha = lo_alpha[worldid]
+ lo_less = plo[1] < pp0[1]
+ lo[worldid] = wp.where(lo_less, plo, pp0)
+ lo_alpha[worldid] = wp.where(lo_less, plo_alpha, 0.0)
+ hi[worldid] = wp.where(lo_less, pp0, plo)
+ hi_alpha[worldid] = wp.where(lo_less, 0.0, plo_alpha)
+
+ @kernel
+ def _next_alpha_gauss(
+ done: wp.array(dtype=bool),
+ lo: wp.array(dtype=wp.vec3),
+ lo_alpha: wp.array(dtype=wp.float32),
+ hi: wp.array(dtype=wp.vec3),
+ hi_alpha: wp.array(dtype=wp.float32),
+ lo_next: wp.array(dtype=wp.vec3),
+ lo_next_alpha: wp.array(dtype=wp.float32),
+ hi_next: wp.array(dtype=wp.vec3),
+ hi_next_alpha: wp.array(dtype=wp.float32),
+ mid: wp.array(dtype=wp.vec3),
+ mid_alpha: wp.array(dtype=wp.float32),
+ d: types.Data,
+ ):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ if done[worldid]:
+ return
+
+ quad = d.efc.quad_gauss[worldid]
+
+ plo = lo[worldid]
+ plo_alpha = lo_alpha[worldid]
+ plo_next_alpha = plo_alpha - _safe_div(plo[1], plo[2])
+ lo_next[worldid] = _eval_pt(quad, plo_next_alpha)
+ lo_next_alpha[worldid] = plo_next_alpha
+
+ phi = hi[worldid]
+ phi_alpha = hi_alpha[worldid]
+ phi_next_alpha = phi_alpha - _safe_div(phi[1], phi[2])
+ hi_next[worldid] = _eval_pt(quad, phi_next_alpha)
+ hi_next_alpha[worldid] = phi_next_alpha
+
+ pmid_alpha = 0.5 * (plo_alpha + phi_alpha)
+ mid[worldid] = _eval_pt(quad, pmid_alpha)
+ mid_alpha[worldid] = pmid_alpha
+
+ @kernel
+ def _next_quad(
+ done: wp.array(dtype=bool),
+ lo_next: wp.array(dtype=wp.vec3),
+ lo_next_alpha: wp.array(dtype=wp.float32),
+ hi_next: wp.array(dtype=wp.vec3),
+ hi_next_alpha: wp.array(dtype=wp.float32),
+ mid: wp.array(dtype=wp.vec3),
+ mid_alpha: wp.array(dtype=wp.float32),
+ d: types.Data,
+ ):
+ efcid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ if done[worldid]:
+ return
+
+ quad = d.efc.quad[efcid]
+ jaref = d.efc.Jaref[efcid]
+ jv = d.efc.jv[efcid]
+
+ alpha = lo_next_alpha[worldid]
+ # TODO(team): active and conditionally active constraints
+ if jaref + alpha * jv < 0.0:
+ wp.atomic_add(lo_next, worldid, _eval_pt(quad, alpha))
+
+ alpha = hi_next_alpha[worldid]
+ # TODO(team): active and conditionally active constraints
+ if jaref + alpha * jv < 0.0:
+ wp.atomic_add(hi_next, worldid, _eval_pt(quad, alpha))
+
+ alpha = mid_alpha[worldid]
+ # TODO(team): active and conditionally active constraints
+ if jaref + alpha * jv < 0.0:
+ wp.atomic_add(mid, worldid, _eval_pt(quad, alpha))
+
+ @kernel
+ def _swap(
+ done: wp.array(dtype=bool),
+ p0: wp.array(dtype=wp.vec3),
+ lo: wp.array(dtype=wp.vec3),
+ lo_alpha: wp.array(dtype=wp.float32),
+ hi: wp.array(dtype=wp.vec3),
+ hi_alpha: wp.array(dtype=wp.float32),
+ lo_next: wp.array(dtype=wp.vec3),
+ lo_next_alpha: wp.array(dtype=wp.float32),
+ hi_next: wp.array(dtype=wp.vec3),
+ hi_next_alpha: wp.array(dtype=wp.float32),
+ mid: wp.array(dtype=wp.vec3),
+ mid_alpha: wp.array(dtype=wp.float32),
+ d: types.Data,
+ ):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ if done[worldid]:
+ return
+
+ plo = lo[worldid]
+ plo_alpha = lo_alpha[worldid]
+ phi = hi[worldid]
+ phi_alpha = hi_alpha[worldid]
+ plo_next = lo_next[worldid]
+ plo_next_alpha = lo_next_alpha[worldid]
+ phi_next = hi_next[worldid]
+ phi_next_alpha = hi_next_alpha[worldid]
+ pmid = mid[worldid]
+ pmid_alpha = mid_alpha[worldid]
+
+ # swap lo:
+ swap_lo_lo_next = _in_bracket(plo, plo_next)
+ plo = wp.where(swap_lo_lo_next, plo_next, plo)
+ plo_alpha = wp.where(swap_lo_lo_next, plo_next_alpha, plo_alpha)
+ swap_lo_mid = _in_bracket(plo, pmid)
+ plo = wp.where(swap_lo_mid, pmid, plo)
+ plo_alpha = wp.where(swap_lo_mid, pmid_alpha, plo_alpha)
+ swap_lo_hi_next = _in_bracket(plo, phi_next)
+ plo = wp.where(swap_lo_hi_next, phi_next, plo)
+ plo_alpha = wp.where(swap_lo_hi_next, phi_next_alpha, plo_alpha)
+ lo[worldid] = plo
+ lo_alpha[worldid] = plo_alpha
+ swap_lo = swap_lo_lo_next or swap_lo_mid or swap_lo_hi_next
+
+ # swap hi:
+ swap_hi_hi_next = _in_bracket(phi, phi_next)
+ phi = wp.where(swap_hi_hi_next, phi_next, phi)
+ phi_alpha = wp.where(swap_hi_hi_next, phi_next_alpha, phi_alpha)
+ swap_hi_mid = _in_bracket(phi, pmid)
+ phi = wp.where(swap_hi_mid, pmid, phi)
+ phi_alpha = wp.where(swap_hi_mid, pmid_alpha, phi_alpha)
+ swap_hi_lo_next = _in_bracket(phi, plo_next)
+ phi = wp.where(swap_hi_lo_next, plo_next, phi)
+ phi_alpha = wp.where(swap_hi_lo_next, plo_next_alpha, phi_alpha)
+ hi[worldid] = phi
+ hi_alpha[worldid] = phi_alpha
+ swap_hi = swap_hi_hi_next or swap_hi_mid or swap_hi_lo_next
+
+ # if we did not adjust the interval, we are done
+ # also done if either the lo or hi slope is nearly flat
+ gtol = d.efc.gtol[worldid]
+ done[worldid] = (
+ (not swap_lo and not swap_hi)
+ or (plo[1] < 0 and plo[1] > -gtol)
+ or (phi[1] > 0 and phi[1] < gtol)
+ )
+
+ # update alpha if we have an improvement
+ pp0 = p0[worldid]
+ alpha = 0.0
+ improved = plo[0] < pp0[0] or phi[0] < pp0[0]
+ plo_better = plo[0] < phi[0]
+ alpha = wp.where(improved and plo_better, plo_alpha, alpha)
+ alpha = wp.where(improved and not plo_better, phi_alpha, alpha)
+ d.efc.alpha[worldid] = alpha
+
+ @kernel
+ def _qacc_ma(d: types.Data):
+ worldid, dofid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ alpha = d.efc.alpha[worldid]
+ d.qacc[worldid, dofid] += alpha * d.efc.search[worldid, dofid]
+ d.efc.Ma[worldid, dofid] += alpha * d.efc.mv[worldid, dofid]
+
+ @kernel
+ def _jaref(d: types.Data):
+ efcid = wp.tid()
+
+ if efcid >= min(d.nefc[0], d.njmax):
+ return
+
+ worldid = d.efc.worldid[efcid]
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.efc.Jaref[efcid] += d.efc.alpha[worldid] * d.efc.jv[efcid]
+
+ wp.launch(_gtol, dim=(d.nworld,), inputs=[m, d])
+
+ # mv = qM @ search
+ support.mul_m(m, d, d.efc.mv, d.efc.search, d.efc.done)
+
+ # jv = efc_J @ search
+ # TODO(team): is there a better way of doing batched matmuls with dynamic array sizes?
+ wp.launch(_zero_jv, dim=(d.njmax), inputs=[d])
+
+ wp.launch(_jv, dim=(d.njmax, m.nv), inputs=[d])
+
+ # prepare quadratics
+ # quad_gauss = [gauss, search.T @ Ma - search.T @ qfrc_smooth, 0.5 * search.T @ mv]
+ wp.launch(_zero_quad_gauss, dim=(d.nworld), inputs=[d])
+
+ wp.launch(_init_quad_gauss, dim=(d.nworld, m.nv), inputs=[m, d])
+
+ # quad = [0.5 * Jaref * Jaref * efc_D, jv * Jaref * efc_D, 0.5 * jv * jv * efc_D]
+
+ wp.launch(_init_quad, dim=(d.njmax), inputs=[d])
+
+ # linesearch points
+ done = d.efc.ls_done
+ done.zero_()
+ p0 = d.efc.p0
+ lo = d.efc.lo
+ lo_alpha = d.efc.lo_alpha
+ hi = d.efc.hi
+ hi_alpha = d.efc.hi_alpha
+ lo_next = d.efc.lo_next
+ lo_next_alpha = d.efc.lo_next_alpha
+ hi_next = d.efc.hi_next
+ hi_next_alpha = d.efc.hi_next_alpha
+ mid = d.efc.mid
+ mid_alpha = d.efc.mid_alpha
+
+ # initialize interval
+
+ wp.launch(_init_p0_gauss, dim=(d.nworld,), inputs=[p0, d])
+
+ wp.launch(_init_p0, dim=(d.njmax,), inputs=[p0, d])
+
+ wp.launch(_init_lo_gauss, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, d])
+
+ wp.launch(_init_lo, dim=(d.njmax,), inputs=[lo, lo_alpha, d])
+
+ # set the lo/hi interval bounds
+
+ wp.launch(_init_bounds, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, hi, hi_alpha, d])
+
+ for _ in range(m.opt.ls_iterations):
+ # note: we always launch ls_iterations kernels, but each kernel exits early once done is true
+ # this satisfies CUDA graph capture requirements (no dynamic kernel launches) at the expense
+ # of extra launches
+ inputs = [done, lo, lo_alpha, hi, hi_alpha, lo_next, lo_next_alpha, hi_next]
+ inputs += [hi_next_alpha, mid, mid_alpha, d]
+ wp.launch(_next_alpha_gauss, dim=(d.nworld,), inputs=inputs)
+
+ inputs = [done, lo_next, lo_next_alpha, hi_next, hi_next_alpha, mid, mid_alpha]
+ inputs += [d]
+ wp.launch(_next_quad, dim=(d.njmax,), inputs=inputs)
+
+ inputs = [done, p0, lo, lo_alpha, hi, hi_alpha, lo_next, lo_next_alpha, hi_next]
+ inputs += [hi_next_alpha, mid, mid_alpha, d]
+ wp.launch(_swap, dim=(d.nworld,), inputs=inputs)
+
+ wp.launch(_qacc_ma, dim=(d.nworld, m.nv), inputs=[d])
+
+ wp.launch(_jaref, dim=(d.njmax,), inputs=[d])
+
+
+@event_scope
+def solve(m: types.Model, d: types.Data):
+ """Finds forces that satisfy constraints."""
+ ITERATIONS = m.opt.iterations
+
+ @kernel
+ def _zero_search_dot(d: types.Data):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.efc.search_dot[worldid] = 0.0
+
+ @kernel
+ def _search_update(d: types.Data):
+ worldid, dofid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ search = -1.0 * d.efc.Mgrad[worldid, dofid]
+
+ if wp.static(m.opt.solver == types.SolverType.CG):
+ search += d.efc.beta[worldid] * d.efc.search[worldid, dofid]
+
+ d.efc.search[worldid, dofid] = search
+ wp.atomic_add(d.efc.search_dot, worldid, search * search)
+
+ @kernel
+ def _done(m: types.Model, d: types.Data, solver_niter: int):
+ # TODO(team): static m?
+ worldid = wp.tid()
+
+ if ITERATIONS > 1:
+ if d.efc.done[worldid]:
+ return
+
+ improvement = _rescale(m, d.efc.prev_cost[worldid] - d.efc.cost[worldid])
+ gradient = _rescale(m, wp.math.sqrt(d.efc.grad_dot[worldid]))
+ d.efc.done[worldid] = (improvement < m.opt.tolerance) or (
+ gradient < m.opt.tolerance
+ )
+
+ if m.opt.solver == types.SolverType.CG:
+
+ @kernel
+ def _prev_grad_Mgrad(d: types.Data):
+ worldid, dofid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.efc.prev_grad[worldid, dofid] = d.efc.grad[worldid, dofid]
+ d.efc.prev_Mgrad[worldid, dofid] = d.efc.Mgrad[worldid, dofid]
+
+ @kernel
+ def _zero_beta_num_den(d: types.Data):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.efc.beta_num[worldid] = 0.0
+ d.efc.beta_den[worldid] = 0.0
+
+ @kernel
+ def _beta_num_den(d: types.Data):
+ worldid, dofid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ prev_Mgrad = d.efc.prev_Mgrad[worldid][dofid]
+ wp.atomic_add(
+ d.efc.beta_num,
+ worldid,
+ d.efc.grad[worldid, dofid] * (d.efc.Mgrad[worldid, dofid] - prev_Mgrad),
+ )
+ wp.atomic_add(
+ d.efc.beta_den, worldid, d.efc.prev_grad[worldid, dofid] * prev_Mgrad
+ )
+
+ @kernel
+ def _beta(d: types.Data):
+ worldid = wp.tid()
+
+ if wp.static(m.opt.iterations) > 1:
+ if d.efc.done[worldid]:
+ return
+
+ d.efc.beta[worldid] = wp.max(
+ 0.0, d.efc.beta_num[worldid] / wp.max(types.MJ_MINVAL, d.efc.beta_den[worldid])
+ )
+
+ # warmstart
+ kernel_copy(d.qacc, d.qacc_warmstart)
+
+ _create_context(m, d, grad=True)
+
+ for i in range(m.opt.iterations):
+ _linesearch_iterative(m, d)
+
+ if m.opt.solver == types.SolverType.CG:
+ wp.launch(_prev_grad_Mgrad, dim=(d.nworld, m.nv), inputs=[d])
+
+ _update_constraint(m, d)
+ _update_gradient(m, d)
+
+ # polak-ribiere
+ if m.opt.solver == types.SolverType.CG:
+ wp.launch(_zero_beta_num_den, dim=(d.nworld), inputs=[d])
+
+ wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[d])
+
+ wp.launch(_beta, dim=(d.nworld,), inputs=[d])
+
+ wp.launch(_zero_search_dot, dim=(d.nworld), inputs=[d])
+
+ wp.launch(_search_update, dim=(d.nworld, m.nv), inputs=[d])
+
+ wp.launch(_done, dim=(d.nworld,), inputs=[m, d, i])
+
+ kernel_copy(d.qacc_warmstart, d.qacc)
diff --git a/mujoco/mjx/_src/solver_test.py b/mujoco_warp/_src/solver_test.py
similarity index 71%
rename from mujoco/mjx/_src/solver_test.py
rename to mujoco_warp/_src/solver_test.py
index e3402858..dd29a2b6 100644
--- a/mujoco/mjx/_src/solver_test.py
+++ b/mujoco_warp/_src/solver_test.py
@@ -1,16 +1,32 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
"""Tests for solver functions."""
+import mujoco
+import numpy as np
+import warp as wp
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
-import mujoco
-from . import io
-from . import smooth
+
+import mujoco_warp as mjwarp
+
from . import solver
-import numpy as np
-import warp as wp
-# tolerance for difference between MuJoCo and MJX smooth calculations - mostly
+# tolerance for difference between MuJoCo and MJWarp solver calculations - mostly
# due to float precision
_TOLERANCE = 5e-3
@@ -34,7 +50,7 @@ def _load(
njmax: int = 512,
keyframe: int = 0,
):
- path = epath.resource_path("mujoco.mjx") / "test_data" / fname
+ path = epath.resource_path("mujoco_warp") / "test_data" / fname
mjm = mujoco.MjModel.from_xml_path(path.as_posix())
mjm.opt.jacobian = is_sparse
mjm.opt.iterations = iterations
@@ -45,20 +61,21 @@ def _load(
mjd = mujoco.MjData(mjm)
mujoco.mj_resetDataKeyframe(mjm, mjd, keyframe)
mujoco.mj_step(mjm, mjd)
- m = io.put_model(mjm)
- d = io.put_data(mjm, mjd, nworld=nworld, njmax=njmax)
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd, nworld=nworld, njmax=njmax)
return mjm, mjd, m, d
@parameterized.parameters(
- (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_CG, 25, 5),
- (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_NEWTON, 2, 4),
+ (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_CG, 25, 5, False),
+ (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_NEWTON, 2, 4, False),
+ (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_NEWTON, 2, 4, True),
)
- def test_solve(self, cone, solver_, iterations, ls_iterations):
- """Tests MJX solve."""
+ def test_solve(self, cone, solver_, iterations, ls_iterations, sparse):
+ """Tests solve."""
for keyframe in range(3):
mjm, mjd, m, d = self._load(
"humanoid/humanoid.xml",
- is_sparse=False,
+ is_sparse=sparse,
cone=cone,
solver_=solver_,
iterations=iterations,
@@ -75,42 +92,41 @@ def cost(qacc):
mj_cost = cost(mjd.qacc)
- ctx = solver._context(m, d)
- solver._create_context(ctx, m, d)
+ solver._create_context(m, d)
- mjx_cost = ctx.cost.numpy()[0] - ctx.gauss.numpy()[0]
+ mjwarp_cost = d.efc.cost.numpy()[0] - d.efc.gauss.numpy()[0]
- _assert_eq(mjx_cost, mj_cost, name="cost")
+ _assert_eq(mjwarp_cost, mj_cost, name="cost")
qacc_warmstart = mjd.qacc_warmstart.copy()
mujoco.mj_forward(mjm, mjd)
mjd.qacc_warmstart = qacc_warmstart
- m = io.put_model(mjm)
- d = io.put_data(mjm, mjd, njmax=mjd.nefc)
+ m = mjwarp.put_model(mjm)
+ d = mjwarp.put_data(mjm, mjd, njmax=mjd.nefc)
d.qacc.zero_()
d.qfrc_constraint.zero_()
- d.efc_force.zero_()
+ d.efc.force.zero_()
if solver_ == mujoco.mjtSolver.mjSOL_CG:
- smooth.factor_m(m, d)
- solver.solve(m, d)
+ mjwarp.factor_m(m, d)
+ mjwarp.solve(m, d)
mj_cost = cost(mjd.qacc)
- mjx_cost = cost(d.qacc.numpy()[0])
- self.assertLessEqual(mjx_cost, mj_cost * 1.025)
+ mjwarp_cost = cost(d.qacc.numpy()[0])
+ self.assertLessEqual(mjwarp_cost, mj_cost * 1.025)
if m.opt.solver == mujoco.mjtSolver.mjSOL_NEWTON:
_assert_eq(d.qacc.numpy()[0], mjd.qacc, "qacc")
_assert_eq(d.qfrc_constraint.numpy()[0], mjd.qfrc_constraint, "qfrc_constraint")
- _assert_eq(d.efc_force.numpy()[: mjd.nefc], mjd.efc_force, "efc_force")
+ _assert_eq(d.efc.force.numpy()[: mjd.nefc], mjd.efc_force, "efc_force")
@parameterized.parameters(
(mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_CG, 25, 5),
(mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_NEWTON, 2, 4),
)
def test_solve_batch(self, cone, solver_, iterations, ls_iterations):
- """Tests MJX solve."""
+ """Tests solve (batch)."""
mjm0, mjd0, _, _ = self._load(
"humanoid/humanoid.xml",
is_sparse=False,
@@ -163,7 +179,7 @@ def test_solve_batch(self, cone, solver_, iterations, ls_iterations):
njmax=2 * nefc_active,
)
- d.nefc_total = wp.array([nefc_active], dtype=wp.int32, ndim=1)
+ d.nefc = wp.array([nefc_active], dtype=wp.int32, ndim=1)
nefc_fill = d.njmax - nefc_active
@@ -226,25 +242,25 @@ def test_solve_batch(self, cone, solver_, iterations, ls_iterations):
d.qM = wp.from_numpy(qM, dtype=wp.float32)
d.qacc_smooth = wp.from_numpy(qacc_smooth, dtype=wp.float32)
d.qfrc_smooth = wp.from_numpy(qfrc_smooth, dtype=wp.float32)
- d.efc_J = wp.from_numpy(efc_J_fill, dtype=wp.float32)
- d.efc_D = wp.from_numpy(efc_D_fill, dtype=wp.float32)
- d.efc_aref = wp.from_numpy(efc_aref_fill, dtype=wp.float32)
- d.efc_worldid = wp.from_numpy(efc_worldid, dtype=wp.int32)
+ d.efc.J = wp.from_numpy(efc_J_fill, dtype=wp.float32)
+ d.efc.D = wp.from_numpy(efc_D_fill, dtype=wp.float32)
+ d.efc.aref = wp.from_numpy(efc_aref_fill, dtype=wp.float32)
+ d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32)
if solver_ == mujoco.mjtSolver.mjSOL_CG:
- m0 = io.put_model(mjm0)
- d0 = io.put_data(mjm0, mjd0)
- smooth.factor_m(m0, d0)
+ m0 = mjwarp.put_model(mjm0)
+ d0 = mjwarp.put_data(mjm0, mjd0)
+ mjwarp.factor_m(m0, d0)
qLD0 = d0.qLD.numpy()
- m1 = io.put_model(mjm1)
- d1 = io.put_data(mjm1, mjd1)
- smooth.factor_m(m1, d1)
+ m1 = mjwarp.put_model(mjm1)
+ d1 = mjwarp.put_data(mjm1, mjd1)
+ mjwarp.factor_m(m1, d1)
qLD1 = d1.qLD.numpy()
- m2 = io.put_model(mjm2)
- d2 = io.put_data(mjm2, mjd2)
- smooth.factor_m(m2, d2)
+ m2 = mjwarp.put_model(mjm2)
+ d2 = mjwarp.put_data(mjm2, mjd2)
+ mjwarp.factor_m(m2, d2)
qLD2 = d2.qLD.numpy()
qLD = np.vstack([qLD0, qLD1, qLD2])
@@ -252,7 +268,7 @@ def test_solve_batch(self, cone, solver_, iterations, ls_iterations):
d.qacc.zero_()
d.qfrc_constraint.zero_()
- d.efc_force.zero_()
+ d.efc.force.zero_()
solver.solve(m, d)
def cost(m, d, qacc):
@@ -263,16 +279,16 @@ def cost(m, d, qacc):
return cost
mj_cost0 = cost(mjm0, mjd0, mjd0.qacc)
- mjx_cost0 = cost(mjm0, mjd0, d.qacc.numpy()[0])
- self.assertLessEqual(mjx_cost0, mj_cost0 * 1.025)
+ mjwarp_cost0 = cost(mjm0, mjd0, d.qacc.numpy()[0])
+ self.assertLessEqual(mjwarp_cost0, mj_cost0 * 1.025)
mj_cost1 = cost(mjm1, mjd1, mjd1.qacc)
- mjx_cost1 = cost(mjm1, mjd1, d.qacc.numpy()[1])
- self.assertLessEqual(mjx_cost1, mj_cost1 * 1.025)
+ mjwarp_cost1 = cost(mjm1, mjd1, d.qacc.numpy()[1])
+ self.assertLessEqual(mjwarp_cost1, mj_cost1 * 1.025)
mj_cost2 = cost(mjm2, mjd2, mjd2.qacc)
- mjx_cost2 = cost(mjm2, mjd2, d.qacc.numpy()[2])
- self.assertLessEqual(mjx_cost2, mj_cost2 * 1.025)
+ mjwarp_cost2 = cost(mjm2, mjd2, d.qacc.numpy()[2])
+ self.assertLessEqual(mjwarp_cost2, mj_cost2 * 1.025)
if m.opt.solver == mujoco.mjtSolver.mjSOL_NEWTON:
_assert_eq(d.qacc.numpy()[0], mjd0.qacc, "qacc0")
@@ -284,17 +300,17 @@ def cost(m, d, qacc):
_assert_eq(d.qfrc_constraint.numpy()[2], mjd2.qfrc_constraint, "qfrc_constraint2")
_assert_eq(
- d.efc_force.numpy()[: mjd0.nefc],
+ d.efc.force.numpy()[: mjd0.nefc],
mjd0.efc_force,
"efc_force0",
)
_assert_eq(
- d.efc_force.numpy()[mjd0.nefc : mjd0.nefc + mjd1.nefc],
+ d.efc.force.numpy()[mjd0.nefc : mjd0.nefc + mjd1.nefc],
mjd1.efc_force,
"efc_force1",
)
_assert_eq(
- d.efc_force.numpy()[mjd0.nefc + mjd1.nefc : mjd0.nefc + mjd1.nefc + mjd2.nefc],
+ d.efc.force.numpy()[mjd0.nefc + mjd1.nefc : mjd0.nefc + mjd1.nefc + mjd2.nefc],
mjd2.efc_force,
"efc_force2",
)
diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py
new file mode 100644
index 00000000..80e4cab9
--- /dev/null
+++ b/mujoco_warp/_src/support.py
@@ -0,0 +1,202 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any
+
+import mujoco
+import warp as wp
+
+from .types import Data
+from .types import Model
+from .types import array2df
+from .types import array3df
+from .warp_util import event_scope
+from .warp_util import kernel
+
+
+def is_sparse(m: mujoco.MjModel):
+ if m.opt.jacobian == mujoco.mjtJacobian.mjJAC_AUTO:
+ return m.nv >= 60
+ return m.opt.jacobian == mujoco.mjtJacobian.mjJAC_SPARSE
+
+
+@event_scope
+def mul_m(
+ m: Model,
+ d: Data,
+ res: wp.array(ndim=2, dtype=wp.float32),
+ vec: wp.array(ndim=2, dtype=wp.float32),
+ skip: wp.array(ndim=1, dtype=bool),
+):
+ """Multiply vector by inertia matrix."""
+
+ if not m.opt.is_sparse:
+
+ def tile_mul(adr: int, size: int, tilesize: int):
+ # TODO(team): speed up kernel compile time (14s on 2023 MacBook Pro)
+ @kernel
+ def mul(
+ m: Model,
+ d: Data,
+ leveladr: int,
+ res: array3df,
+ vec: array3df,
+ skip: wp.array(ndim=1, dtype=bool),
+ ):
+ worldid, nodeid = wp.tid()
+
+ if skip[worldid]:
+ return
+
+ dofid = m.qLD_tile[leveladr + nodeid]
+ qM_tile = wp.tile_load(
+ d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid)
+ )
+ vec_tile = wp.tile_load(vec[worldid], shape=(tilesize, 1), offset=(dofid, 0))
+ res_tile = wp.tile_zeros(shape=(tilesize, 1), dtype=wp.float32)
+ wp.tile_matmul(qM_tile, vec_tile, res_tile)
+ wp.tile_store(res[worldid], res_tile, offset=(dofid, 0))
+
+ wp.launch_tiled(
+ mul,
+ dim=(d.nworld, size),
+ inputs=[
+ m,
+ d,
+ adr,
+ res.reshape(res.shape + (1,)),
+ vec.reshape(vec.shape + (1,)),
+ skip,
+ ],
+ # TODO(team): develop heuristic for block dim, or make configurable
+ block_dim=32,
+ )
+
+ qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy()
+
+ for i in range(len(qLD_tileadr)):
+ beg = qLD_tileadr[i]
+ end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1]
+ tile_mul(beg, end - beg, int(qLD_tilesize[i]))
+
+ else:
+
+ @kernel
+ def _mul_m_sparse_diag(
+ m: Model,
+ d: Data,
+ res: wp.array(ndim=2, dtype=wp.float32),
+ vec: wp.array(ndim=2, dtype=wp.float32),
+ skip: wp.array(ndim=1, dtype=bool),
+ ):
+ worldid, dofid = wp.tid()
+
+ if skip[worldid]:
+ return
+
+ res[worldid, dofid] = d.qM[worldid, 0, m.dof_Madr[dofid]] * vec[worldid, dofid]
+
+ @kernel
+ def _mul_m_sparse_ij(
+ m: Model,
+ d: Data,
+ res: wp.array(ndim=2, dtype=wp.float32),
+ vec: wp.array(ndim=2, dtype=wp.float32),
+ skip: wp.array(ndim=1, dtype=bool),
+ ):
+ worldid, elementid = wp.tid()
+
+ if skip[worldid]:
+ return
+
+ i = m.qM_mulm_i[elementid]
+ j = m.qM_mulm_j[elementid]
+ madr_ij = m.qM_madr_ij[elementid]
+
+ qM = d.qM[worldid, 0, madr_ij]
+
+ wp.atomic_add(res[worldid], i, qM * vec[worldid, j])
+ wp.atomic_add(res[worldid], j, qM * vec[worldid, i])
+
+ wp.launch(_mul_m_sparse_diag, dim=(d.nworld, m.nv), inputs=[m, d, res, vec, skip])
+
+ wp.launch(
+ _mul_m_sparse_ij, dim=(d.nworld, m.qM_madr_ij.size), inputs=[m, d, res, vec, skip]
+ )
+
+
+@event_scope
+def xfrc_accumulate(m: Model, d: Data, qfrc: array2df):
+ @wp.kernel
+ def _accumulate(m: Model, d: Data, qfrc: array2df):
+ worldid, dofid = wp.tid()
+ cdof = d.cdof[worldid, dofid]
+ rotational_cdof = wp.vec3(cdof[0], cdof[1], cdof[2])
+ jac = wp.spatial_vector(cdof[3], cdof[4], cdof[5], cdof[0], cdof[1], cdof[2])
+
+ dof_bodyid = m.dof_bodyid[dofid]
+ accumul = float(0.0)
+
+ for bodyid in range(dof_bodyid, m.nbody):
+ # any body that is in the subtree of dof_bodyid is part of the jacobian
+ parentid = bodyid
+ while parentid != 0 and parentid != dof_bodyid:
+ parentid = m.body_parentid[parentid]
+ if parentid == 0:
+ continue # body is not part of the subtree
+ offset = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]]
+ cross_term = wp.cross(rotational_cdof, offset)
+ accumul += wp.dot(jac, d.xfrc_applied[worldid, bodyid]) + wp.dot(
+ cross_term,
+ wp.vec3(
+ d.xfrc_applied[worldid, bodyid][0],
+ d.xfrc_applied[worldid, bodyid][1],
+ d.xfrc_applied[worldid, bodyid][2],
+ ),
+ )
+
+ qfrc[worldid, dofid] += accumul
+
+ wp.launch(kernel=_accumulate, dim=(d.nworld, m.nv), inputs=[m, d, qfrc])
+
+
+@wp.func
+def bisection(x: wp.array(dtype=int), v: int, a_: int, b_: int) -> int:
+ # Binary search for the largest index i in [a_, b_] such that x[i] <= v
+ # x is a sorted array; returns a_ if no element in the range satisfies x[i] <= v
+ # a_ and b_ are the inclusive start and end indices within x to search
+ a = int(a_)
+ b = int(b_)
+ c = int(0)
+ while b - a > 1:
+ c = (a + b) // 2
+ if x[c] <= v:
+ a = c
+ else:
+ b = c
+ c = a
+ if c != b and x[b] <= v:
+ c = b
+ return c
+
+
+@wp.func
+def mat33_from_rows(a: wp.vec3, b: wp.vec3, c: wp.vec3):
+ return wp.mat33(a, b, c)
+
+
+@wp.func
+def mat33_from_cols(a: wp.vec3, b: wp.vec3, c: wp.vec3):
+ return wp.mat33(a.x, b.x, c.x, a.y, b.y, c.y, a.z, b.z, c.z)
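Review note on `bisection` in `support.py`: its search contract can be checked on the host without Warp. The following is a minimal pure-Python sketch that mirrors the loop in the hunk above, assuming `a` and `b` are inclusive bounds. The name `bisection_py` is hypothetical and not part of this change.

```python
def bisection_py(x, v, a, b):
    """Host-side mirror of the Warp `bisection` helper (illustrative only).

    Returns the largest index i in [a, b] with x[i] <= v for sorted x,
    or a when no element in the range satisfies the condition.
    """
    while b - a > 1:
        c = (a + b) // 2
        if x[c] <= v:
            a = c  # x[c] qualifies, so the answer is at c or to its right
        else:
            b = c  # x[c] is too large, so the answer is to the left of c
    c = a
    # the loop never tests the upper bound itself, so check it here
    if c != b and x[b] <= v:
        c = b
    return c


xs = [0, 3, 5, 9]
print(bisection_py(xs, 4, 0, 3))   # 1, since xs[1] = 3 <= 4 < xs[2]
print(bisection_py(xs, 9, 0, 3))   # 3, the upper bound qualifies
print(bisection_py(xs, -1, 0, 3))  # 0, falls back to the lower bound
```

The fallback in the last call means a caller cannot distinguish "x[a] <= v" from "nothing qualifies" by the return value alone, so callers that care would need to test `x[a]` themselves.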
diff --git a/mujoco/mjx/_src/support_test.py b/mujoco_warp/_src/support_test.py
similarity index 71%
rename from mujoco/mjx/_src/support_test.py
rename to mujoco_warp/_src/support_test.py
index 034f0e6d..799892c8 100644
--- a/mujoco/mjx/_src/support_test.py
+++ b/mujoco_warp/_src/support_test.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,18 +15,19 @@
"""Tests for support functions."""
-from absl.testing import absltest
-from absl.testing import parameterized
import mujoco
-from mujoco import mjx
import numpy as np
import warp as wp
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import mujoco_warp as mjwarp
+
from . import test_util
-from .support import xfrc_accumulate
wp.config.verify_cuda = True
-# tolerance for difference between MuJoCo and mjWarp support calculations - mostly
+# tolerance for difference between MuJoCo and MJWarp support calculations - mostly
# due to float precision
_TOLERANCE = 5e-5
@@ -49,7 +50,8 @@ def test_mul_m(self, sparse):
res = wp.zeros((1, mjm.nv), dtype=wp.float32)
vec = wp.from_numpy(np.expand_dims(mj_vec, axis=0), dtype=wp.float32)
- mjx.mul_m(m, d, res, vec)
+ skip = wp.zeros((d.nworld), dtype=bool)
+ mjwarp.mul_m(m, d, res, vec, skip)
_assert_eq(res.numpy()[0], mj_res, f"mul_m ({'sparse' if sparse else 'dense'})")
@@ -59,23 +61,30 @@ def test_xfrc_accumulated(self):
mjm, mjd, m, d = test_util.fixture("pendula.xml")
xfrc = np.random.randn(*d.xfrc_applied.numpy().shape)
d.xfrc_applied = wp.from_numpy(xfrc, dtype=wp.spatial_vector)
- qfrc = xfrc_accumulate(m, d)
+ qfrc = wp.zeros((1, mjm.nv), dtype=wp.float32)
+ mjwarp.xfrc_accumulate(m, d, qfrc)
qfrc_expected = np.zeros(m.nv)
xfrc = xfrc[0]
- mjd.xfrc_applied[:] = xfrc
for i in range(1, m.nbody):
mujoco.mj_applyFT(
- mjm,
- mjd,
- mjd.xfrc_applied[i, :3],
- mjd.xfrc_applied[i, 3:],
- mjd.xipos[i],
- i,
- qfrc_expected,
+ mjm, mjd, xfrc[i, :3], xfrc[i, 3:], mjd.xipos[i], i, qfrc_expected
)
np.testing.assert_almost_equal(qfrc.numpy()[0], qfrc_expected, 6)
+ def test_make_put_data(self):
+ """Tests that make_data and put_data produce the same shapes for all Warp arrays."""
+ mjm, mjd, m, d = test_util.fixture("pendula.xml")
+ md = mjwarp.make_data(mjm)
+
+ # same number of fields
+ self.assertEqual(len(d.__dict__), len(md.__dict__))
+
+ # test shapes for all arrays
+ for attr, val in md.__dict__.items():
+ if isinstance(val, wp.array):
+ self.assertEqual(val.shape, getattr(d, attr).shape)
+
if __name__ == "__main__":
wp.init()
diff --git a/mujoco_warp/_src/test_collision_axis_tiled.py b/mujoco_warp/_src/test_collision_axis_tiled.py
new file mode 100644
index 00000000..ed9182fe
--- /dev/null
+++ b/mujoco_warp/_src/test_collision_axis_tiled.py
@@ -0,0 +1,142 @@
+import itertools
+
+import numpy as np
+import warp as wp
+from absl.testing import absltest
+
+from .collision_functions import collision_axis_tiled, Box
+
+BOX_BLOCK_DIM = 32
+
+@wp.kernel
+def test_collision_axis_tiled_kernel(
+ a: Box,
+ b: Box,
+ R: wp.mat33,
+ best_axis: wp.array(dtype=wp.vec3),
+ best_sign: wp.array(dtype=wp.int32),
+ best_idx: wp.array(dtype=wp.int32),
+):
+ bid, axis_idx = wp.tid()
+ axis_out, sign_out, idx_out = collision_axis_tiled(a, b, R, axis_idx)
+ if axis_idx > 0:
+ return
+ best_axis[bid] = axis_out
+ best_sign[bid] = sign_out
+ best_idx[bid] = idx_out
+
+
+class TestCollisionAxisTiled(absltest.TestCase):
+ """Tests the collision_axis_tiled function."""
+
+ def test_collision_axis_tiled_vf(self):
+ """Tests collision_axis_tiled for a vertex-face configuration."""
+ vert = np.array(list(itertools.product((-1, 1), (-1, 1), (-1, 1))), dtype=float)
+
+ dims_a = np.array([1.1, 0.8, 1.8])
+ dims_b = np.array([0.8, 0.3, 1.1])
+
+ # shift in yz plane
+ shift_b = np.array([0, 0.2, -0.4])
+
+ separation = 0.13
+
+ s = 0.5 * 2**0.5
+ rx = np.array([1, 0, 0, 0, s, s, 0, -s, s]).reshape((3, 3))
+ ry = np.array([s, 0, s, 0, 1, 0, -s, 0, s]).reshape((3, 3))
+ # Rotate vert 0 towards negative x at origin
+ R_a = ry @ rx.T
+ t_a = (-dims_a * vert[0]) @ R_a.T
+
+ R_b = np.eye(3)
+ t_b = -np.array([dims_b[0] + separation, 0, 0]) + shift_b
+
+ vert_a = (dims_a * vert) @ R_a.T + t_a
+ vert_b = dims_b * vert + t_b
+
+ R_atob = R_b.T @ R_a
+ t_atob = R_b.T @ (t_a - t_b)
+ a = Box(vert_a.ravel())
+ b = Box(vert_b.ravel())
+
+ R = wp.mat33(R_atob)
+ best_axis = wp.empty(1, dtype=wp.vec3)
+ best_sign = wp.empty(1, dtype=wp.int32)
+ best_idx = wp.empty(1, dtype=wp.int32)
+
+ wp.launch_tiled(
+ kernel=test_collision_axis_tiled_kernel,
+ dim=1,
+ inputs=[a, b, R],
+ outputs=[best_axis, best_sign, best_idx],
+ block_dim=BOX_BLOCK_DIM,
+ )
+ expected_axis = np.array([-1, 0, 0])
+ expected_sign = np.sign(best_axis.numpy().dot(expected_axis))
+
+ np.testing.assert_array_equal(
+ np.cross(best_axis.numpy()[0], expected_axis), np.zeros(3)
+ )
+ # np.testing.assert_array_equal(best_idx.numpy()[0], 8)
+ np.testing.assert_array_equal(best_sign.numpy()[0], expected_sign)
+
+ def test_collision_axis_tiled_ee(self):
+ """Tests collision_axis_tiled for an edge-edge configuration."""
+ # edge on edge
+ vert = np.array(list(itertools.product((-1, 1), (-1, 1), (-1, 1))), dtype=float)
+
+ dims_a = np.array([1.1, 0.8, 1.8])
+ dims_b = np.array([0.8, 0.3, 1.1])
+
+ dims_a = dims_b = np.ones(3)
+
+ # shift
+
+ separation = 0.13
+
+ s = 0.5 * 2**0.5
+ rx = np.array([1, 0, 0, 0, s, s, 0, -s, s]).reshape((3, 3))
+ ry = np.array([s, 0, s, 0, 1, 0, -s, 0, s]).reshape((3, 3))
+
+ # Rotate box a and translate it so the midpoint of the edge (vert 2, vert 6) sits at the origin
+ R_a = ry @ rx
+ t_a = (-dims_a * 0.5 * (vert[2] + vert[6])) @ R_a.T
+ # t_a = np.array([0, 0, 2**0.5])
+
+ R_b = np.eye(3)
+ t_b = (-dims_b * vert[0]) @ R_b.T
+ t_b = np.array([-1, 0.0, -1])
+ t_b -= np.array([s, 0, s]) * 0.2
+
+ vert_a = (dims_a * vert) @ R_a.T + t_a
+ vert_b = dims_b * vert @ R_b.T + t_b
+
+ R_atob = R_b.T @ R_a
+ t_atob = R_b.T @ (t_a - t_b)
+ a = Box(vert_a.ravel())
+ b = Box(vert_b.ravel())
+
+ R = wp.mat33(R_atob)
+ best_axis = wp.empty(1, dtype=wp.vec3)
+ best_sign = wp.empty(1, dtype=wp.int32)
+ best_idx = wp.empty(1, dtype=wp.int32)
+
+ wp.launch_tiled(
+ kernel=test_collision_axis_tiled_kernel,
+ dim=1,
+ inputs=[a, b, R],
+ outputs=[best_axis, best_sign, best_idx],
+ block_dim=BOX_BLOCK_DIM,
+ )
+ expected_axis = np.array([-1, 0, -1])
+ expected_sign = np.sign(best_axis.numpy().dot(expected_axis))
+
+ np.testing.assert_array_equal(
+ np.cross(best_axis.numpy()[0], expected_axis), np.zeros(3)
+ )
+ # np.testing.assert_array_equal(best_idx.numpy()[0], 8)
+ np.testing.assert_array_equal(best_sign.numpy()[0], expected_sign)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/mujoco/mjx/_src/test_util.py b/mujoco_warp/_src/test_util.py
similarity index 51%
rename from mujoco/mjx/_src/test_util.py
rename to mujoco_warp/_src/test_util.py
index aba4653e..b65de415 100644
--- a/mujoco/mjx/_src/test_util.py
+++ b/mujoco_warp/_src/test_util.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,19 +18,20 @@
import time
from typing import Callable, Tuple
-from etils import epath
+import mujoco
import numpy as np
import warp as wp
-
-import mujoco
+from etils import epath
from . import io
from . import types
+from . import warp_util
def fixture(fname: str, keyframe: int = -1, sparse: bool = True):
- path = epath.resource_path("mujoco.mjx") / "test_data" / fname
+ path = epath.resource_path("mujoco_warp") / "test_data" / fname
mjm = mujoco.MjModel.from_xml_path(path.as_posix())
+ mjm.opt.jacobian = sparse
mjd = mujoco.MjData(mjm)
if keyframe > -1:
mujoco.mj_resetDataKeyframe(mjm, mjd, keyframe)
@@ -38,55 +39,80 @@ def fixture(fname: str, keyframe: int = -1, sparse: bool = True):
mjd.qvel = np.random.uniform(-0.01, 0.01, mjm.nv)
mujoco.mj_step(mjm, mjd, 3) # let dynamics get state significantly non-zero
mujoco.mj_forward(mjm, mjd)
- mjm.opt.jacobian = sparse
m = io.put_model(mjm)
d = io.put_data(mjm, mjd)
return mjm, mjd, m, d
+def _sum(stack1, stack2):
+ ret = {}
+ for k in stack1:
+ times1, sub_stack1 = stack1[k]
+ times2, sub_stack2 = stack2[k]
+ times = [t1 + t2 for t1, t2 in zip(times1, times2)]
+ ret[k] = (times, _sum(sub_stack1, sub_stack2))
+ return ret
+
+
def benchmark(
fn: Callable[[types.Model, types.Data], None],
- m: mujoco.MjModel,
+ mjm: mujoco.MjModel,
+ mjd: mujoco.MjData,
nstep: int = 1000,
batch_size: int = 1024,
- unroll_steps: int = 1,
solver: str = "cg",
iterations: int = 1,
ls_iterations: int = 4,
- nefc_total: int = 0,
-) -> Tuple[float, float, int]:
+ nconmax: int = -1,
+ njmax: int = -1,
+ event_trace: bool = False,
+ measure_alloc: bool = False,
+) -> Tuple[float, float, dict, int, list, list]:
"""Benchmark a model."""
if solver == "cg":
- m.opt.solver = mujoco.mjtSolver.mjSOL_CG
+ mjm.opt.solver = mujoco.mjtSolver.mjSOL_CG
elif solver == "newton":
- m.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
+ mjm.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
- m.opt.iterations = iterations
- m.opt.ls_iterations = ls_iterations
+ mjm.opt.iterations = iterations
+ mjm.opt.ls_iterations = ls_iterations
- mx = io.put_model(m)
- dx = io.make_data(m, nworld=batch_size, njmax=nefc_total)
- dx.nefc_total = wp.array([nefc_total], dtype=wp.int32, ndim=1)
+ m = io.put_model(mjm)
+ d = io.put_data(mjm, mjd, nworld=batch_size, nconmax=nconmax, njmax=njmax)
- wp.clear_kernel_cache()
jit_beg = time.perf_counter()
- fn(mx, dx)
- fn(mx, dx) # double warmup to work around issues with compilation during graph capture
- jit_end = time.perf_counter()
- jit_duration = jit_end - jit_beg
- wp.synchronize()
- # capture the whole smooth.kinematic() function as a CUDA graph
- with wp.ScopedCapture() as capture:
- fn(mx, dx)
- graph = capture.graph
+ fn(m, d)
+ # double warmup to work around issues with compilation during graph capture:
+ fn(m, d)
- run_beg = time.perf_counter()
- for _ in range(nstep):
- wp.capture_launch(graph)
+ jit_end = time.perf_counter()
+ jit_duration = jit_end - jit_beg
wp.synchronize()
- run_end = time.perf_counter()
- run_duration = run_end - run_beg
-
- return jit_duration, run_duration, batch_size * nstep
+ trace = {}
+ ncon = []
+ nefc = []
+
+ with warp_util.EventTracer(enabled=event_trace) as tracer:
+ # capture the whole function as a CUDA graph
+ with wp.ScopedCapture() as capture:
+ fn(m, d)
+ graph = capture.graph
+
+ run_beg = time.perf_counter()
+ for _ in range(nstep):
+ wp.capture_launch(graph)
+ if trace:
+ trace = _sum(trace, tracer.trace())
+ else:
+ trace = tracer.trace()
+ if measure_alloc:
+ wp.synchronize()
+ ncon.append(d.ncon.numpy()[0])
+ nefc.append(d.nefc.numpy()[0])
+ wp.synchronize()
+ run_end = time.perf_counter()
+ run_duration = run_end - run_beg
+
+ return jit_duration, run_duration, trace, batch_size * nstep, ncon, nefc
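Review note on `_sum` in `test_util.py`: the helper assumes both traces share the same nested keys and adds the per-event times level by level. The following is a minimal sketch of that merge on plain dictionaries, assuming each trace entry has the shape `name -> (times, sub_trace)` as returned by `tracer.trace()`. The name `sum_traces` and the sample values are hypothetical.

```python
def sum_traces(stack1, stack2):
    """Element-wise sum of two traces with identical structure (illustrative)."""
    ret = {}
    for k in stack1:
        times1, sub_stack1 = stack1[k]
        times2, sub_stack2 = stack2[k]
        # add the timings of matching events, then recurse into nested scopes
        times = [t1 + t2 for t1, t2 in zip(times1, times2)]
        ret[k] = (times, sum_traces(sub_stack1, sub_stack2))
    return ret


# two made-up traces: one outer scope with a single nested scope
first = {"step": ((1.0,), {"fwd": ((0.5, 0.25), {})})}
second = {"step": ((2.0,), {"fwd": ((0.5, 0.75), {})})}

total = sum_traces(first, second)
print(total)  # {'step': ([3.0], {'fwd': ([1.0, 1.0], {})})}
```

A key present in only one of the two traces would raise `KeyError` (missing from the second) or be silently dropped (missing from the first), which is acceptable only if every benchmark step records the same scopes.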
diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py
new file mode 100644
index 00000000..9ab19a83
--- /dev/null
+++ b/mujoco_warp/_src/types.py
@@ -0,0 +1,757 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import enum
+
+import mujoco
+import warp as wp
+
+MJ_MINVAL = mujoco.mjMINVAL
+MJ_MINIMP = mujoco.mjMINIMP # minimum constraint impedance
+MJ_MAXIMP = mujoco.mjMAXIMP # maximum constraint impedance
+MJ_NREF = mujoco.mjNREF
+MJ_NIMP = mujoco.mjNIMP
+
+
+class DisableBit(enum.IntFlag):
+ """Disable default feature bitflags.
+
+ Members:
+ CONSTRAINT: entire constraint solver
+ LIMIT: joint and tendon limit constraints
+ CONTACT: contact constraints
+ PASSIVE: passive forces
+ GRAVITY: gravitational forces
+ CLAMPCTRL: clamp control to specified range
+ ACTUATION: apply actuation forces
+ REFSAFE: integrator safety: make ref[0]>=2*timestep
+ EULERDAMP: implicit damping for Euler integration
+ FILTERPARENT: disable collisions between parent and child bodies
+ """
+
+ CONSTRAINT = mujoco.mjtDisableBit.mjDSBL_CONSTRAINT
+ LIMIT = mujoco.mjtDisableBit.mjDSBL_LIMIT
+ CONTACT = mujoco.mjtDisableBit.mjDSBL_CONTACT
+ PASSIVE = mujoco.mjtDisableBit.mjDSBL_PASSIVE
+ GRAVITY = mujoco.mjtDisableBit.mjDSBL_GRAVITY
+ CLAMPCTRL = mujoco.mjtDisableBit.mjDSBL_CLAMPCTRL
+ ACTUATION = mujoco.mjtDisableBit.mjDSBL_ACTUATION
+ REFSAFE = mujoco.mjtDisableBit.mjDSBL_REFSAFE
+ EULERDAMP = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
+ FILTERPARENT = mujoco.mjtDisableBit.mjDSBL_FILTERPARENT
+ # unsupported: EQUALITY, FRICTIONLOSS, MIDPHASE, WARMSTART, SENSOR
+
+
+class TrnType(enum.IntEnum):
+ """Type of actuator transmission.
+
+ Members:
+ JOINT: force on joint
+ JOINTINPARENT: force on joint, expressed in parent frame
+ """
+
+ JOINT = mujoco.mjtTrn.mjTRN_JOINT
+ JOINTINPARENT = mujoco.mjtTrn.mjTRN_JOINTINPARENT
+ # unsupported: SITE, TENDON, SLIDERCRANK, BODY
+
+
+class DynType(enum.IntEnum):
+ """Type of actuator dynamics.
+
+ Members:
+ NONE: no internal dynamics; ctrl specifies force
+ FILTEREXACT: linear filter: da/dt = (u-a) / tau, with exact integration
+ """
+
+ NONE = mujoco.mjtDyn.mjDYN_NONE
+ FILTEREXACT = mujoco.mjtDyn.mjDYN_FILTEREXACT
+ # unsupported: INTEGRATOR, FILTER, MUSCLE, USER
+
+
+class GainType(enum.IntEnum):
+ """Type of actuator gain.
+
+ Members:
+ FIXED: fixed gain
+ AFFINE: const + kp*length + kv*velocity
+ """
+
+ FIXED = mujoco.mjtGain.mjGAIN_FIXED
+ AFFINE = mujoco.mjtGain.mjGAIN_AFFINE
+ # unsupported: MUSCLE, USER
+
+
+class BiasType(enum.IntEnum):
+ """Type of actuator bias.
+
+ Members:
+ NONE: no bias
+ AFFINE: const + kp*length + kv*velocity
+ """
+
+ NONE = mujoco.mjtBias.mjBIAS_NONE
+ AFFINE = mujoco.mjtBias.mjBIAS_AFFINE
+ # unsupported: MUSCLE, USER
+
+
+class JointType(enum.IntEnum):
+ """Type of degree of freedom.
+
+ Members:
+ FREE: global position and orientation (quat) (7,)
+ BALL: orientation (quat) relative to parent (4,)
+ SLIDE: sliding distance along body-fixed axis (1,)
+ HINGE: rotation angle (rad) around body-fixed axis (1,)
+ """
+
+ FREE = mujoco.mjtJoint.mjJNT_FREE
+ BALL = mujoco.mjtJoint.mjJNT_BALL
+ SLIDE = mujoco.mjtJoint.mjJNT_SLIDE
+ HINGE = mujoco.mjtJoint.mjJNT_HINGE
+
+ def dof_width(self) -> int:
+ return {0: 6, 1: 3, 2: 1, 3: 1}[self.value]
+
+ def qpos_width(self) -> int:
+ return {0: 7, 1: 4, 2: 1, 3: 1}[self.value]
+
+
+class ConeType(enum.IntEnum):
+ """Type of friction cone.
+
+ Members:
+ PYRAMIDAL: pyramidal
+ """
+
+ PYRAMIDAL = mujoco.mjtCone.mjCONE_PYRAMIDAL
+ # unsupported: ELLIPTIC
+
+
+class GeomType(enum.IntEnum):
+ """Type of geometry.
+
+ Members:
+ PLANE: plane
+ SPHERE: sphere
+ CAPSULE: capsule
+ BOX: box
+ """
+
+ PLANE = mujoco.mjtGeom.mjGEOM_PLANE
+ SPHERE = mujoco.mjtGeom.mjGEOM_SPHERE
+ CAPSULE = mujoco.mjtGeom.mjGEOM_CAPSULE
+ BOX = mujoco.mjtGeom.mjGEOM_BOX
+ # unsupported: HFIELD, ELLIPSOID, CYLINDER, MESH,
+ # NGEOMTYPES, ARROW*, LINE, SKIN, LABEL, NONE
+
+
+class SolverType(enum.IntEnum):
+ """Constraint solver algorithm.
+
+ Members:
+ CG: Conjugate gradient (primal)
+ NEWTON: Newton (primal)
+ """
+
+ CG = mujoco.mjtSolver.mjSOL_CG
+ NEWTON = mujoco.mjtSolver.mjSOL_NEWTON
+
+
+class vec5f(wp.types.vector(length=5, dtype=wp.float32)):
+ pass
+
+
+class vec10f(wp.types.vector(length=10, dtype=wp.float32)):
+ pass
+
+
+vec5 = vec5f
+vec10 = vec10f
+array2df = wp.array2d(dtype=wp.float32)
+array3df = wp.array3d(dtype=wp.float32)
+
+
+@wp.struct
+class Option:
+ """Physics options.
+
+ Attributes:
+ timestep: simulation timestep
+ impratio: ratio of friction-to-normal contact impedance
+ tolerance: main solver tolerance
+ ls_tolerance: CG/Newton linesearch tolerance
+ gravity: gravitational acceleration
+ integrator: integration mode (mjtIntegrator)
+ cone: type of friction cone (mjtCone)
+ solver: solver algorithm (mjtSolver)
+ iterations: number of main solver iterations
+ ls_iterations: maximum number of CG/Newton linesearch iterations
+ disableflags: bit flags for disabling standard features
+ is_sparse: whether to use sparse representations
+ """
+
+ timestep: float
+ impratio: wp.float32
+ tolerance: float
+ ls_tolerance: float
+ gravity: wp.vec3
+ integrator: int
+ cone: int
+ solver: int
+ iterations: int
+ ls_iterations: int
+ disableflags: int
+ is_sparse: bool
+
+
+@wp.struct
+class Statistic:
+ """Model statistics (in qpos0).
+
+ Attributes:
+ meaninertia: mean diagonal inertia
+ """
+
+ meaninertia: float
+
+
+@wp.struct
+class Constraint:
+ """Constraint data.
+
+ Attributes:
+ worldid: world id (njmax,)
+ J: constraint Jacobian (njmax, nv)
+ pos: constraint position (equality, contact) (njmax,)
+ margin: inclusion margin (contact) (njmax,)
+ D: constraint mass (njmax,)
+ aref: reference pseudo-acceleration (njmax,)
+ force: constraint force in constraint space (njmax,)
+ Jaref: Jac*qacc - aref (njmax,)
+ Ma: M*qacc (nworld, nv)
+ grad: gradient of master cost (nworld, nv)
+ grad_dot: dot(grad, grad) (nworld,)
+ Mgrad: inv(M) @ grad (CG) or inv(H) @ grad (Newton) (nworld, nv)
+ search: linesearch vector (nworld, nv)
+ search_dot: dot(search, search) (nworld,)
+ gauss: Gauss cost (nworld,)
+ cost: constraint + Gauss cost (nworld,)
+ prev_cost: cost from previous iter (nworld,)
+ solver_niter: number of solver iterations (nworld,)
+ active: active (quadratic) constraints (njmax,)
+ gtol: linesearch termination tolerance (nworld,)
+ mv: qM @ search (nworld, nv)
+ jv: efc_J @ search (njmax,)
+ quad: quadratic cost coefficients (njmax, 3)
+ quad_gauss: quadratic cost gauss coefficients (nworld, 3)
+ h: cone hessian (nworld, nv, nv)
+ alpha: line search step size (nworld,)
+ prev_grad: previous grad (nworld, nv)
+ prev_Mgrad: previous Mgrad (nworld, nv)
+ beta: Polak-Ribiere beta (nworld,)
+ beta_num: numerator of beta (nworld,)
+ beta_den: denominator of beta (nworld,)
+ done: solver done (nworld,)
+ ls_done: linesearch done (nworld,)
+ p0: initial point (nworld, 3)
+ lo: low point bounding the line search interval (nworld, 3)
+ lo_alpha: alpha for low point (nworld,)
+ hi: high point bounding the line search interval (nworld, 3)
+ hi_alpha: alpha for high point (nworld,)
+ lo_next: next low point (nworld, 3)
+ lo_next_alpha: alpha for next low point (nworld,)
+ hi_next: next high point (nworld, 3)
+ hi_next_alpha: alpha for next high point (nworld,)
+ mid: loss at mid_alpha (nworld, 3)
+ mid_alpha: midpoint between lo_alpha and hi_alpha (nworld,)
+ """
+
+ worldid: wp.array(dtype=wp.int32, ndim=1)
+ J: wp.array(dtype=wp.float32, ndim=2)
+ pos: wp.array(dtype=wp.float32, ndim=1)
+ margin: wp.array(dtype=wp.float32, ndim=1)
+ D: wp.array(dtype=wp.float32, ndim=1)
+ aref: wp.array(dtype=wp.float32, ndim=1)
+ force: wp.array(dtype=wp.float32, ndim=1)
+ Jaref: wp.array(dtype=wp.float32, ndim=1)
+ Ma: wp.array(dtype=wp.float32, ndim=2)
+ grad: wp.array(dtype=wp.float32, ndim=2)
+ grad_dot: wp.array(dtype=wp.float32, ndim=1)
+ Mgrad: wp.array(dtype=wp.float32, ndim=2)
+ search: wp.array(dtype=wp.float32, ndim=2)
+ search_dot: wp.array(dtype=wp.float32, ndim=1)
+ gauss: wp.array(dtype=wp.float32, ndim=1)
+ cost: wp.array(dtype=wp.float32, ndim=1)
+ prev_cost: wp.array(dtype=wp.float32, ndim=1)
+ solver_niter: wp.array(dtype=wp.int32, ndim=1)
+ active: wp.array(dtype=wp.int32, ndim=1)
+ gtol: wp.array(dtype=wp.float32, ndim=1)
+ mv: wp.array(dtype=wp.float32, ndim=2)
+ jv: wp.array(dtype=wp.float32, ndim=1)
+ quad: wp.array(dtype=wp.vec3f, ndim=1)
+ quad_gauss: wp.array(dtype=wp.vec3f, ndim=1)
+ h: wp.array(dtype=wp.float32, ndim=3)
+ alpha: wp.array(dtype=wp.float32, ndim=1)
+ prev_grad: wp.array(dtype=wp.float32, ndim=2)
+ prev_Mgrad: wp.array(dtype=wp.float32, ndim=2)
+ beta: wp.array(dtype=wp.float32, ndim=1)
+ beta_num: wp.array(dtype=wp.float32, ndim=1)
+ beta_den: wp.array(dtype=wp.float32, ndim=1)
+ done: wp.array(dtype=bool, ndim=1)
+ # linesearch
+ ls_done: wp.array(dtype=bool, ndim=1)
+ p0: wp.array(dtype=wp.vec3, ndim=1)
+ lo: wp.array(dtype=wp.vec3, ndim=1)
+ lo_alpha: wp.array(dtype=wp.float32, ndim=1)
+ hi: wp.array(dtype=wp.vec3, ndim=1)
+ hi_alpha: wp.array(dtype=wp.float32, ndim=1)
+ lo_next: wp.array(dtype=wp.vec3, ndim=1)
+ lo_next_alpha: wp.array(dtype=wp.float32, ndim=1)
+ hi_next: wp.array(dtype=wp.vec3, ndim=1)
+ hi_next_alpha: wp.array(dtype=wp.float32, ndim=1)
+ mid: wp.array(dtype=wp.vec3, ndim=1)
+ mid_alpha: wp.array(dtype=wp.float32, ndim=1)
+
+
+@wp.struct
+class Model:
+ """Model definition and parameters.
+
+ Attributes:
+ nq: number of generalized coordinates = dim ()
+ nv: number of degrees of freedom = dim ()
+ nu: number of actuators/controls = dim ()
+ na: number of activation states = dim ()
+ nbody: number of bodies ()
+ njnt: number of joints ()
+ ngeom: number of geoms ()
+ nsite: number of sites ()
+ nexclude: number of excluded geom pairs ()
+ nmocap: number of mocap bodies ()
+ nM: number of non-zeros in sparse inertia matrix ()
+ opt: physics options
+ stat: model statistics
+ qpos0: qpos values at default pose (nq,)
+ qpos_spring: reference pose for springs (nq,)
+ body_tree: BFS ordering of body ids
+ body_treeadr: starting index of each body tree level
+ actuator_moment_offset_nv: tiling configuration
+ actuator_moment_offset_nu: tiling configuration
+ actuator_moment_tileadr: tiling configuration
+ actuator_moment_tilesize_nv: tiling configuration
+ actuator_moment_tilesize_nu: tiling configuration
+ qM_fullm_i: sparse mass matrix addressing
+ qM_fullm_j: sparse mass matrix addressing
+ qM_mulm_i: sparse mass matrix addressing
+ qM_mulm_j: sparse mass matrix addressing
+ qM_madr_ij: sparse mass matrix addressing
+ qLD_update_tree: dof tree ordering for qLD updates
+ qLD_update_treeadr: index of each dof tree level
+ qLD_tile: tiling configuration
+ qLD_tileadr: tiling configuration
+ qLD_tilesize: tiling configuration
+ body_parentid: id of body's parent (nbody,)
+ body_rootid: id of root above body (nbody,)
+ body_weldid: id of body that this body is welded to (nbody,)
+ body_mocapid: id of mocap data; -1: none (nbody,)
+ body_jntnum: number of joints for this body (nbody,)
+ body_jntadr: start addr of joints; -1: no joints (nbody,)
+ body_dofnum: number of motion degrees of freedom (nbody,)
+ body_dofadr: start addr of dofs; -1: no dofs (nbody,)
+ body_geomnum: number of geoms (nbody,)
+ body_geomadr: start addr of geoms; -1: no geoms (nbody,)
+ body_pos: position offset rel. to parent body (nbody, 3)
+ body_quat: orientation offset rel. to parent body (nbody, 4)
+ body_ipos: local position of center of mass (nbody, 3)
+ body_iquat: local orientation of inertia ellipsoid (nbody, 4)
+ body_mass: mass (nbody,)
+ subtree_mass: mass of subtree (nbody,)
+ body_inertia: diagonal inertia in ipos/iquat frame (nbody, 3)
+ body_invweight0: mean inv inert in qpos0 (trn, rot) (nbody, 2)
+ body_contype: OR over all geom contypes (nbody,)
+ body_conaffinity: OR over all geom conaffinities (nbody,)
+ jnt_type: type of joint (mjtJoint) (njnt,)
+ jnt_qposadr: start addr in 'qpos' for joint's data (njnt,)
+ jnt_dofadr: start addr in 'qvel' for joint's data (njnt,)
+ jnt_bodyid: id of joint's body (njnt,)
+ jnt_limited: does joint have limits (njnt,)
+ jnt_actfrclimited: does joint have actuator force limits (njnt,)
+ jnt_solref: constraint solver reference: limit (njnt, mjNREF)
+ jnt_solimp: constraint solver impedance: limit (njnt, mjNIMP)
+ jnt_pos: local anchor position (njnt, 3)
+ jnt_axis: local joint axis (njnt, 3)
+ jnt_stiffness: stiffness coefficient (njnt,)
+ jnt_range: joint limits (njnt, 2)
+ jnt_actfrcrange: range of total actuator force (njnt, 2)
+ jnt_margin: min distance for limit detection (njnt,)
+ jnt_limited_slide_hinge_adr: limited/slide/hinge jntadr
+ dof_bodyid: id of dof's body (nv,)
+ dof_jntid: id of dof's joint (nv,)
+ dof_parentid: id of dof's parent; -1: none (nv,)
+ dof_Madr: dof address in M-diagonal (nv,)
+ dof_armature: dof armature inertia/mass (nv,)
+ dof_damping: damping coefficient (nv,)
+ dof_invweight0: diag. inverse inertia in qpos0 (nv,)
+ dof_tri_row: np.tril_indices (mjm.nv)[0]
+ dof_tri_col: np.tril_indices (mjm.nv)[1]
+ geom_type: geometric type (mjtGeom) (ngeom,)
+ geom_contype: geom contact type (ngeom,)
+ geom_conaffinity: geom contact affinity (ngeom,)
+ geom_condim: contact dimensionality (1, 3, 4, 6) (ngeom,)
+ geom_bodyid: id of geom's body (ngeom,)
+ geom_dataid: id of geom's mesh/hfield; -1: none (ngeom,)
+ geom_priority: geom contact priority (ngeom,)
+ geom_solmix: mixing coef for solref/imp in geom pair (ngeom,)
+ geom_solref: constraint solver reference: contact (ngeom, mjNREF)
+ geom_solimp: constraint solver impedance: contact (ngeom, mjNIMP)
+ geom_size: geom-specific size parameters (ngeom, 3)
+ geom_aabb: bounding box, (center, size) (ngeom, 6)
+ geom_rbound: radius of bounding sphere (ngeom,)
+ geom_pos: local position offset rel. to body (ngeom, 3)
+ geom_quat: local orientation offset rel. to body (ngeom, 4)
+ geom_friction: friction for (slide, spin, roll) (ngeom, 3)
+ geom_margin: detect contact if dist dict:
+ global _STACK
+
+ if _STACK is None:
+ return {}
+
+ ret = {}
+
+ for k, v in _STACK.items():
+ events, sub_stack = v
+ # push into next level of stack
+ saved_stack, _STACK = _STACK, sub_stack
+ sub_trace = self.trace()
+ # pop!
+ _STACK = saved_stack
+ events = tuple(wp.get_event_elapsed_time(beg, end) for beg, end in events)
+ ret[k] = (events, sub_trace)
+
+ return ret
+
+ def __exit__(self, type, value, traceback):
+ global _STACK
+ _STACK = None
+
+
+def event_scope(fn, name: str = ""):
+ name = name or getattr(fn, "__name__")
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ global _STACK
+ if _STACK is None:
+ return fn(*args, **kwargs)
+ # push into next level of stack
+ saved_stack, _STACK = _STACK, {}
+ beg = wp.Event(enable_timing=True)
+ end = wp.Event(enable_timing=True)
+ wp.record_event(beg)
+ res = fn(*args, **kwargs)
+ wp.record_event(end)
+ # pop back up to current level
+ sub_stack, _STACK = _STACK, saved_stack
+ # append events
+ events, _ = _STACK.get(name, ((), None))
+ _STACK[name] = (events + ((beg, end),), sub_stack)
+ return res
+
+ return wrapper
+
+
+# @kernel decorator to automatically set up modules based on nested
+# function names
+def kernel(
+ f: Optional[Callable] = None,
+ *,
+ enable_backward: Optional[bool] = None,
+ module: Optional[Module] = None,
+):
+ """
+ Decorator to register a Warp kernel from a Python function.
+ The function must be defined with type annotations for all arguments.
+ The function must not return anything.
+
+ Example::
+
+ @kernel
+ def my_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
+ tid = wp.tid()
+ b[tid] = a[tid] + 1.0
+
+
+ @kernel(enable_backward=False)
+ def my_kernel_no_backward(a: wp.array(dtype=float, ndim=2), x: float):
+ # the backward pass will not be generated
+ i, j = wp.tid()
+ a[i, j] = x
+
+
+ @kernel(module="unique")
+ def my_kernel_unique_module(a: wp.array(dtype=float), b: wp.array(dtype=float)):
+ # the kernel will be registered in a new unique module created just for this
+ # kernel and its dependent functions and structs
+ tid = wp.tid()
+ b[tid] = a[tid] + 1.0
+
+ Args:
+ f: The function to be registered as a kernel.
+ enable_backward: If False, the backward pass will not be generated.
+ module: The :class:`warp.context.Module` to which the kernel belongs. Alternatively, if a string `"unique"` is provided, the kernel is assigned to a new module named after the kernel name and hash. If None, the module is inferred from the function's module.
+
+ Returns:
+ The registered kernel.
+ """
+ if module is None:
+ # create a module name based on the name of the nested function
+ # get the qualified name, e.g. "main..nested_kernel"
+ qualname = f.__qualname__
+ parts = [part for part in qualname.split(".") if part != ""]
+ outer_functions = parts[:-1]
+ module = get_module(".".join([f.__module__] + outer_functions))
+
+ return wp.kernel(f, enable_backward=enable_backward, module=module)
+
+
+@wp.kernel
+def _copy_2df(dest: types.array2df, src: types.array2df):
+  i, j = wp.tid()
+  dest[i, j] = src[i, j]
+
+
+@wp.kernel
+def _copy_3df(dest: types.array3df, src: types.array3df):
+  i, j, k = wp.tid()
+  dest[i, j, k] = src[i, j, k]
+
+
+@wp.kernel
+def _copy_2dvec10f(
+  dest: wp.array2d(dtype=types.vec10f), src: wp.array2d(dtype=types.vec10f)
+):
+  i, j = wp.tid()
+  dest[i, j] = src[i, j]
+
+
+@wp.kernel
+def _copy_2dvec3f(dest: wp.array2d(dtype=wp.vec3f), src: wp.array2d(dtype=wp.vec3f)):
+  i, j = wp.tid()
+  dest[i, j] = src[i, j]
+
+
+@wp.kernel
+def _copy_2dmat33f(dest: wp.array2d(dtype=wp.mat33f), src: wp.array2d(dtype=wp.mat33f)):
+  i, j = wp.tid()
+  dest[i, j] = src[i, j]
+
+
+@wp.kernel
+def _copy_2dspatialf(
+  dest: wp.array2d(dtype=wp.spatial_vector), src: wp.array2d(dtype=wp.spatial_vector)
+):
+  i, j = wp.tid()
+  dest[i, j] = src[i, j]
+
+
+# TODO(team): remove kernel_copy once wp.copy is supported in cuda subgraphs
+
+
+def kernel_copy(dest: wp.array, src: wp.array):
+  if src.shape != dest.shape:
+    raise ValueError("only same shape copying allowed")
+
+  if src.dtype != dest.dtype:
+    if (src.dtype, dest.dtype) not in (
+      (wp.float32, np.float32),
+      (np.float32, wp.float32),
+      (wp.int32, np.int32),
+      (np.int32, wp.int32),
+    ):
+      raise ValueError(f"only same dtype copying allowed: {src.dtype} != {dest.dtype}")
+
+  if src.ndim == 2 and src.dtype == wp.float32:
+    kernel = _copy_2df
+  elif src.ndim == 3 and src.dtype == wp.float32:
+    kernel = _copy_3df
+  elif src.ndim == 2 and src.dtype == wp.vec3f:
+    kernel = _copy_2dvec3f
+  elif src.ndim == 2 and src.dtype == wp.mat33f:
+    kernel = _copy_2dmat33f
+  elif src.ndim == 2 and src.dtype == types.vec10f:
+    kernel = _copy_2dvec10f
+  elif src.ndim == 2 and src.dtype == wp.spatial_vector:
+    kernel = _copy_2dspatialf
+  else:
+    raise NotImplementedError("copy not supported for these array types")
+
+  wp.launch(kernel=kernel, dim=src.shape, inputs=[dest, src])
diff --git a/mujoco_warp/test_data/collision.xml b/mujoco_warp/test_data/collision.xml
new file mode 100644
index 00000000..5cc041e2
--- /dev/null
+++ b/mujoco_warp/test_data/collision.xml
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mujoco/mjx/test_data/constraints.xml b/mujoco_warp/test_data/constraints.xml
similarity index 100%
rename from mujoco/mjx/test_data/constraints.xml
rename to mujoco_warp/test_data/constraints.xml
diff --git a/mujoco/mjx/test_data/humanoid/humanoid.xml b/mujoco_warp/test_data/humanoid/humanoid.xml
similarity index 96%
rename from mujoco/mjx/test_data/humanoid/humanoid.xml
rename to mujoco_warp/test_data/humanoid/humanoid.xml
index 168310bc..f4152998 100644
--- a/mujoco/mjx/test_data/humanoid/humanoid.xml
+++ b/mujoco_warp/test_data/humanoid/humanoid.xml
@@ -14,7 +14,7 @@
-->
-
@@ -40,7 +40,7 @@
-
+
@@ -186,15 +186,6 @@
-
-
-
-
-
-
-
-
-