diff --git a/.gitignore b/.gitignore index 10f31fb1..60db3b91 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.egg-info/ env/ .history +.vscode # Python byte-compiled / optimized / DLL files __pycache__/ diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 00000000..aff29ad2 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,10 @@ +# This is the list of significant contributors to mjWarp. +# +# This does not necessarily list everyone who has contributed code, +# especially since many employees of one corporation may be contributing. +# To see the full list of contributors, see the revision history in +# source control. + +Google LLC +NVIDIA Corporation + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..e85d9451 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,25 @@ +# How to Contribute + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows [Google's Open Source Community +Guidelines](https://opensource.google/conduct/). diff --git a/README.md b/README.md index b7b4f65f..5345dced 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# mjWarp +# MJWarp -MuJoCo implemented in Warp. +MuJoCo implemented in NVIDIA Warp. 
# Installing for development @@ -32,47 +32,62 @@ Should print out something like `XX passed in XX.XXs` at the end! Benchmark as follows: ```bash -mjx-testspeed --function=forward --is_sparse=True --mjcf=humanoid/humanoid.xml --batch_size=8192 +mjwarp-testspeed --function=step --mjcf=humanoid/humanoid.xml --batch_size=8192 ``` -Some relevant benchmarks on an NVIDIA GeForce RTX 4090: +To get a full trace of the physics steps (e.g. timings of the subcomponents) run the following: -## forward steps / sec (smooth dynamics only) - -27 dofs per humanoid, 8k batch size. - -| Num Humanoids | MJX | mjWarp dense | mjWarp sparse | -| ----------------| -------| ------------- | -------------- | -| 1 | 7.9M | 15.6M | 13.7M | -| 2 | 2.6M | 7.4M | 7.8M | -| 3 | 2.2M | 4.6M | 5.3M | -| 4 | 1.5M | 3.3M | 4.1M | -| 5 | 1.1M | ❌ | 3.2M | - -# Ideas for what to try next - -## 1. Unroll steps - -In the Pure JAX benchmark, we can tell JAX to unroll some number of FK steps (in the benchmarks above, `unroll=4`). This has a big impact on performance. If we change `unroll` from 4 to 1, pure JAX performance at 8k batch drops from 50M to 33M steps/sec. - -Is there some way that we can improve Warp performance in the same way? If I know ahead of time that I am going to call FK in a loop 1000 times, can I somehow inject unroll primitives? - -## 2. Different levels of parallelism - -The current approach parallelizes over body kinematic tree depth. We could go either direction: remove body parallism (fewer kernel launches), or parallelize over joints instead (more launches, more parallelism). - -## 3. Tiling +```bash +mjwarp-testspeed --function=step --mjcf=humanoid/humanoid.xml --batch_size=8192 --event_trace=True +``` -It looks like a thing! Should we use it? Will it help? +`humanoid.xml` has been carefully optimized for MJX in the following ways: -## 4. 
Quaternions +* Newton solver iterations are capped at 1, linesearch iterations capped at 4 +* Only foot<>floor collisions are turned on, producing at most 8 contact points +* The damping update in the Euler integrator (which invokes another `factor_m` and `solve_m`) is disabled -Why oh why did Warp make quaternions x,y,z,w? In order to be obstinate I wrote my own quaternion math. Is this slower than using the Warp quaternion primitives? +By comparing MJWarp to MJX on this model, we are comparing MJWarp to the very best that MJX can do. -## 5. `wp.static` +For many (if not most) MuJoCo models, particularly ones that haven't been carefully tuned, MJX will +do much worse. -Haven't tried this at all - curious to see if it helps. +## physics steps / sec -## 6. Other stuff? +NVIDIA GeForce RTX 4090, 27 dofs, ncon=8, 8k batch size. -Should I be playing with `block_dim`? Is my method for timing OK or did I misunderstand how `wp.synchronize` works? Is there something about allocating that I should be aware of? What am I not thinking of? 
+``` +Summary for 8192 parallel rollouts + + Total JIT time: 0.82 s + Total simulation time: 2.98 s + Total steps per second: 2,753,173 + Total realtime factor: 13,765.87 x + Total time per step: 363.22 ns + +Event trace: + +step: 361.41 (MJX: 316.58 ns) + forward: 359.15 + fwd_position: 52.58 + kinematics: 19.36 (MJX: 16.45 ns) + com_pos: 7.80 (MJX: 12.37 ns) + crb: 12.44 (MJX: 27.91 ns) + factor_m: 6.40 (MJX: 27.48 ns) + collision: 4.07 (MJX: 1.23 ns) + make_constraint: 6.32 (MJX: 42.39 ns) + transmission: 1.30 (MJX: 3.54 ns) + fwd_velocity: 26.52 + com_vel: 8.44 (MJX: 9.38 ns) + passive: 1.06 (MJX: 3.22 ns) + rne: 10.96 (MJX: 16.75 ns) + fwd_actuation: 2.74 (MJX: 3.93 ns) + fwd_acceleration: 11.90 + xfrc_accumulate: 3.83 (MJX: 6.81 ns) + solve_m: 6.92 (MJX: 8.88 ns) + solve: 264.38 (MJX: 153.29 ns) + mul_m: 5.93 + _linesearch_iterative: 43.15 + mul_m: 3.66 + euler: 1.74 (MJX: 3.78 ns) +``` diff --git a/contrib/README.md b/contrib/README.md new file mode 100644 index 00000000..9f273766 --- /dev/null +++ b/contrib/README.md @@ -0,0 +1,5 @@ +# Contrib + +Contrib is a home for experiments, helper scripts and other tools that are not officially part of MJWarp. + +The contents of this directory are subject to change with no notice. 
\ No newline at end of file diff --git a/contrib/apptronik_apollo_locomotion.ipynb b/contrib/apptronik_apollo_locomotion.ipynb new file mode 100644 index 00000000..79b16f4f --- /dev/null +++ b/contrib/apptronik_apollo_locomotion.ipynb @@ -0,0 +1,791 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "j2cz906V7d0X" + }, + "source": [ + "# Training Apptronik Apollo using MuJoCo Warp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XpAzX-oCJcSJ" + }, + "outputs": [], + "source": [ + "# you may need some extra deps for this colab:\n", + "# pip install \"jax[cuda12_local]\"\n", + "# pip install playground\n", + "# pip install matplotlib\n", + "# if you want to run this on your local machine you can do like so:\n", + "# pip install jupyter\n", + "# jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0 --no-browser\n", + "\n", + "import dataclasses\n", + "import datetime\n", + "import functools\n", + "import os\n", + "import time\n", + "from typing import Any, Dict, Optional, Union\n", + "\n", + "import jax\n", + "import mediapy as media\n", + "import mujoco\n", + "import numpy as np\n", + "import warp as wp\n", + "from etils import epath\n", + "from jax import numpy as jp\n", + "from ml_collections import config_dict\n", + "from mujoco import mjx\n", + "from mujoco_playground._src import mjx_env\n", + "from mujoco_playground._src import reward\n", + "from mujoco_playground._src.dm_control_suite import common\n", + "from warp.jax_experimental.ffi import jax_callable\n", + "\n", + "import mujoco_warp as mjwarp\n", + "from mujoco_warp._src.warp_util import kernel_copy\n", + "\n", + "# this ensures JAX embeds Warp kernels into its own computation graph:\n", + "os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_graph_min_graph_size=1\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-BVYPgjk7xcm" + }, + "outputs": [], 
+ "source": [ + "# We'll grab the Apptronik model from MuJoCo Menagerie, then remove some\n", + "# MJX-specific changes that exist in the XML that MJWarp doesn't need\n", + "# (such as explicit contacts, really tight ls_iterations etc.)\n", + "\n", + "mjx_env.ensure_menagerie_exists()\n", + "\n", + "contrib_xml_dir = epath.resource_path(\"mujoco_warp\").parent / \"contrib/xml\"\n", + "apptronik_dir = mjx_env.EXTERNAL_DEPS_PATH / \"mujoco_menagerie/apptronik_apollo/\"\n", + "\n", + "! cp {contrib_xml_dir / 'apptronik_apollo.xml'} {apptronik_dir}\n", + "! cp {contrib_xml_dir / 'scene.xml'} {apptronik_dir}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ixAOK05EkQNg" + }, + "outputs": [], + "source": [ + "# MJWarp is not yet fully connected to JAX. For now we will use the function\n", + "# `mjwarp_step` below, which we wrap in Warp's handy `jax_callable` function\n", + "# that converts Warp kernels to JAX operations.\n", + "#\n", + "# After we build a proper JAX wrapper for MJWarp, this code will disappear.\n", + "\n", + "NWORLD = 8192\n", + "NCONMAX = 81920\n", + "NJMAX = NCONMAX * 4\n", + "\n", + "xml_path = apptronik_dir / \"scene.xml\"\n", + "mjm = mujoco.MjModel.from_xml_path(xml_path.as_posix())\n", + "mjm.opt.iterations = 5\n", + "mjm.opt.ls_iterations = 10\n", + "mjd = mujoco.MjData(mjm)\n", + "mujoco.mj_resetDataKeyframe(mjm, mjd, 0)\n", + "mujoco.mj_forward(mjm, mjd)\n", + "m = mjwarp.put_model(mjm)\n", + "d = mjwarp.put_data(mjm, mjd, nworld=NWORLD, nconmax=NCONMAX, njmax=NJMAX)\n", + "\n", + "\n", + "def mjwarp_step(\n", + " ctrl: wp.array(dtype=wp.float32, ndim=2),\n", + " qpos_in: wp.array(dtype=wp.float32, ndim=2),\n", + " qvel_in: wp.array(dtype=wp.float32, ndim=2),\n", + " qacc_warmstart_in: wp.array(dtype=wp.float32, ndim=2),\n", + " qpos_out: wp.array(dtype=wp.float32, ndim=2),\n", + " qvel_out: wp.array(dtype=wp.float32, ndim=2),\n", + " xpos_out: wp.array(dtype=wp.vec3, ndim=2),\n", + " xmat_out:
wp.array(dtype=wp.mat33, ndim=2),\n", + " qacc_warmstart_out: wp.array(dtype=wp.float32, ndim=2),\n", + " subtree_com_out: wp.array(dtype=wp.vec3, ndim=2),\n", + " cvel_out: wp.array(dtype=wp.spatial_vector, ndim=2),\n", + " site_xpos_out: wp.array(dtype=wp.vec3, ndim=2),\n", + "):\n", + " kernel_copy(d.ctrl, ctrl)\n", + " kernel_copy(d.qpos, qpos_in)\n", + " kernel_copy(d.qvel, qvel_in)\n", + " kernel_copy(d.qacc_warmstart, qacc_warmstart_in)\n", + "\n", + " # TODO(team): remove this hard coding substeps\n", + " # ctrl_dt / sim_dt == 4\n", + " for i in range(4):\n", + " mjwarp.step(m, d)\n", + " kernel_copy(qpos_out, d.qpos)\n", + " kernel_copy(qvel_out, d.qvel)\n", + " kernel_copy(xpos_out, d.xpos)\n", + " kernel_copy(xmat_out, d.xmat)\n", + " kernel_copy(qacc_warmstart_out, d.qacc_warmstart)\n", + " kernel_copy(subtree_com_out, d.subtree_com)\n", + " kernel_copy(cvel_out, d.cvel)\n", + " kernel_copy(site_xpos_out, d.site_xpos)\n", + "\n", + "\n", + "jax_mjwarp_step = jax_callable(\n", + " mjwarp_step,\n", + " num_outputs=8,\n", + " output_dims={\n", + " \"qpos_out\": (NWORLD, mjm.nq),\n", + " \"qvel_out\": (NWORLD, mjm.nv),\n", + " \"xpos_out\": (NWORLD, mjm.nbody, 3),\n", + " \"xmat_out\": (NWORLD, mjm.nbody, 3, 3),\n", + " \"qacc_warmstart_out\": (NWORLD, mjm.nv),\n", + " \"subtree_com_out\": (NWORLD, mjm.nbody, 3),\n", + " \"cvel_out\": (NWORLD, mjm.nbody, 6),\n", + " \"site_xpos_out\": (NWORLD, mjm.nsite, 3),\n", + " },\n", + ")\n", + "\n", + "# the functions below allow us to call MJWarp step inside jax vmap:\n", + "\n", + "\n", + "@jax.custom_batching.custom_vmap\n", + "def step(d: mjx.Data):\n", + " return d\n", + "\n", + "\n", + "@step.def_vmap\n", + "def step_vmap_rule(axis_size, in_batched, d: mjx.Data):\n", + " if in_batched[0].ctrl:\n", + " assert d.ctrl.shape[0] == axis_size\n", + " else:\n", + " d = d.replace(ctrl=jp.tile(d.ctrl, (axis_size, 1)))\n", + " params = {f.name: None for f in dataclasses.fields(mjx.Data)}\n", + " params[\"ctrl\"] = 
True\n", + " params[\"qpos\"] = True\n", + " params[\"qvel\"] = True\n", + " params[\"xpos\"] = True\n", + " params[\"xmat\"] = True\n", + " params[\"qacc_warmstart\"] = True\n", + " params[\"subtree_com\"] = True\n", + " params[\"cvel\"] = True\n", + " params[\"site_xpos\"] = True\n", + " out_batched = mjx.Data(**params)\n", + "\n", + " qpos, qvel, xpos, xmat, qacc_warmstart, subtree_com, cvel, site_xpos = (\n", + " jax_mjwarp_step(d.ctrl, d.qpos, d.qvel, d.qacc_warmstart)\n", + " )\n", + " d = d.replace(\n", + " qpos=qpos,\n", + " qvel=qvel,\n", + " xpos=xpos,\n", + " xmat=xmat,\n", + " qacc_warmstart=qacc_warmstart,\n", + " subtree_com=subtree_com,\n", + " cvel=cvel,\n", + " site_xpos=site_xpos,\n", + " )\n", + " return d, out_batched\n", + "\n", + "\n", + "def init(qpos, ctrl) -> mjx.Data:\n", + " init_params = {f.name: None for f in dataclasses.fields(mjx.Data)}\n", + " init_params[\"qpos\"] = qpos\n", + " init_params[\"ctrl\"] = ctrl\n", + " init_params[\"qvel\"] = jp.zeros(m.nv)\n", + " init_params[\"xpos\"] = jp.zeros((m.nbody, 3))\n", + " init_params[\"xmat\"] = jp.tile(jp.zeros((3, 3)), (m.nbody, 1, 1))\n", + " init_params[\"qacc_warmstart\"] = jp.array(mjd.qacc_warmstart)\n", + " init_params[\"subtree_com\"] = jp.zeros((m.nbody, 3))\n", + " init_params[\"cvel\"] = jp.zeros((m.nbody, 6))\n", + " init_params[\"site_xpos\"] = jp.zeros((m.nsite, 3))\n", + " return mjx.Data(**init_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AaYQk2lu7Bus" + }, + "outputs": [], + "source": [ + "# An environment for training an Apptronik Apollo to walk.\n", + "# This is the same format as environments in MuJoCo Playground.\n", + "\n", + "\n", + "def default_config() -> config_dict.ConfigDict:\n", + " return config_dict.create(\n", + " ctrl_dt=0.02,\n", + " sim_dt=0.005,\n", + " episode_length=1000,\n", + " action_repeat=1,\n", + " action_scale=0.5,\n", + " soft_joint_pos_limit_factor=0.95,\n", + " 
reward_config=config_dict.create(\n", + " scales=config_dict.create(\n", + " # Tracking related rewards.\n", + " tracking_lin_vel=1.0,\n", + " tracking_ang_vel=0.75,\n", + " # Base related rewards.\n", + " ang_vel_xy=-0.15,\n", + " orientation=-2.0,\n", + " # Energy related rewards.\n", + " action_rate=0.0,\n", + " # Feet related rewards.\n", + " feet_air_time=2.0,\n", + " feet_slip=-0.25,\n", + " feet_phase=1.0,\n", + " # Other rewards.\n", + " termination=-5.0,\n", + " # Pose related rewards.\n", + " joint_deviation_knee=-0.1,\n", + " joint_deviation_hip=-0.25,\n", + " dof_pos_limits=-1.0,\n", + " pose=-0.1,\n", + " ),\n", + " tracking_sigma=0.25,\n", + " max_foot_height=0.15,\n", + " base_height_target=0.5,\n", + " ),\n", + " lin_vel_x=[1.0, 1.0],\n", + " lin_vel_y=[0.0, 0.0],\n", + " ang_vel_yaw=[0.0, 0.0],\n", + " )\n", + "\n", + "\n", + "class Joystick(mjx_env.MjxEnv):\n", + " \"\"\"Track a joystick command.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " config: config_dict.ConfigDict = default_config(),\n", + " config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,\n", + " ):\n", + " super().__init__(config, config_overrides)\n", + " self._post_init()\n", + "\n", + " def _post_init(self) -> None:\n", + " self._init_q = jp.array(mjm.keyframe(\"stand\").qpos)\n", + " self._default_pose = self._init_q[7:]\n", + "\n", + " # Note: First joint is freejoint.\n", + " self._lowers, self._uppers = self.mj_model.jnt_range[1:].T\n", + " c = (self._lowers + self._uppers) / 2\n", + " r = self._uppers - self._lowers\n", + " self._soft_lowers = c - 0.5 * r * self._config.soft_joint_pos_limit_factor\n", + " self._soft_uppers = c + 0.5 * r * self._config.soft_joint_pos_limit_factor\n", + "\n", + " hip_joints = [\"l_hip_ie\", \"l_hip_aa\", \"r_hip_ie\", \"r_hip_aa\"]\n", + " hip_indices = [mjm.joint(j).qposadr - 7 for j in hip_joints]\n", + " self._hip_indices = jp.array(hip_indices)\n", + "\n", + " knee_joints = [\"l_knee_fe\", \"r_knee_fe\"]\n", 
+ " knee_indices = [mjm.joint(j).qposadr - 7 for j in knee_joints]\n", + " self._knee_indices = jp.array(knee_indices)\n", + "\n", + " self._head_body_id = mjm.body(\"neck_pitch_link\").id\n", + " self._torso_id = mjm.body(\"torso_link\").id\n", + "\n", + " feet_sites = [\"l_foot_fr\", \"l_foot_br\", \"l_foot_fl\", \"l_foot_bl\"]\n", + " feet_sites += [\"r_foot_fr\", \"r_foot_br\", \"r_foot_fl\", \"r_foot_bl\"]\n", + " feet_site_ids = [mjm.site(s).id for s in feet_sites]\n", + " self._feet_site_id = jp.array(feet_site_ids)\n", + " self._feet_contact_z = 0.003\n", + "\n", + " self._floor_geom_id = mjm.geom(\"floor\").id\n", + "\n", + " def reset(self, rng: jax.Array) -> mjx_env.State:\n", + " qpos = self._init_q\n", + "\n", + " data = init(qpos=qpos, ctrl=qpos[7:])\n", + "\n", + " # Phase, freq=U(1.0, 1.5)\n", + " rng, key = jax.random.split(rng)\n", + " gait_freq = jax.random.uniform(key, (1,), minval=1.25, maxval=1.5)\n", + " phase_dt = 2 * jp.pi * self.dt * gait_freq\n", + " phase = jp.array([0, jp.pi])\n", + "\n", + " rng, cmd_rng = jax.random.split(rng)\n", + " cmd = self.sample_command(cmd_rng)\n", + "\n", + " info = {\n", + " \"rng\": rng,\n", + " \"step\": 0,\n", + " \"command\": cmd,\n", + " \"last_act\": jp.zeros(mjm.nu),\n", + " \"last_last_act\": jp.zeros(mjm.nu),\n", + " \"motor_targets\": jp.zeros(mjm.nu),\n", + " \"feet_air_time\": jp.zeros(2),\n", + " \"last_contact\": jp.zeros(2, dtype=bool),\n", + " # Phase related.\n", + " \"phase_dt\": phase_dt,\n", + " \"phase\": phase,\n", + " }\n", + "\n", + " metrics = {}\n", + " for k in self._config.reward_config.scales.keys():\n", + " metrics[f\"reward/{k}\"] = jp.zeros(())\n", + "\n", + " contact = data.site_xpos[self._feet_site_id] < self._feet_contact_z\n", + " contact = jp.array([contact[0:4].any(), contact[4:8].any()])\n", + " obs = self._get_obs(data, info, contact)\n", + " reward, done = jp.zeros(2)\n", + " return mjx_env.State(data, obs, reward, done, metrics, info)\n", + "\n", + " def step(self, 
state: mjx_env.State, action: jax.Array) -> mjx_env.State:\n", + " state.info[\"rng\"], _ = jax.random.split(state.info[\"rng\"], 2)\n", + "\n", + " ctrl = self._default_pose + action * self._config.action_scale\n", + " data = state.data\n", + " data = data.replace(ctrl=ctrl)\n", + " data = step(data)\n", + " state.info[\"motor_targets\"] = ctrl\n", + "\n", + " contact = data.site_xpos[self._feet_site_id, 2] < self._feet_contact_z\n", + " contact = jp.array([contact[0:4].any(), contact[4:8].any()])\n", + " contact_filt = contact | state.info[\"last_contact\"]\n", + " first_contact = (state.info[\"feet_air_time\"] > 0.0) * contact_filt\n", + " state.info[\"feet_air_time\"] += self.dt\n", + "\n", + " obs = self._get_obs(data, state.info, contact)\n", + " done = self._get_termination(data)\n", + "\n", + " rewards = self._get_reward(\n", + " data, action, state.info, state.metrics, done, first_contact, contact\n", + " )\n", + " rewards = {k: v * self._config.reward_config.scales[k] for k, v in rewards.items()}\n", + " reward = sum(rewards.values()) * self.dt\n", + "\n", + " state.info[\"step\"] += 1\n", + " phase_tp1 = state.info[\"phase\"] + state.info[\"phase_dt\"]\n", + " state.info[\"phase\"] = jp.fmod(phase_tp1 + jp.pi, 2 * jp.pi) - jp.pi\n", + " state.info[\"last_last_act\"] = state.info[\"last_act\"]\n", + " state.info[\"last_act\"] = action\n", + " state.info[\"rng\"], cmd_rng = jax.random.split(state.info[\"rng\"])\n", + " state.info[\"command\"] = jp.where(\n", + " state.info[\"step\"] > 500,\n", + " self.sample_command(cmd_rng),\n", + " state.info[\"command\"],\n", + " )\n", + " state.info[\"step\"] = jp.where(\n", + " done | (state.info[\"step\"] > 500),\n", + " 0,\n", + " state.info[\"step\"],\n", + " )\n", + " state.info[\"feet_air_time\"] *= ~contact\n", + " state.info[\"last_contact\"] = contact\n", + " for k, v in rewards.items():\n", + " state.metrics[f\"reward/{k}\"] = v\n", + "\n", + " done = done.astype(reward.dtype)\n", + " state = 
state.replace(data=data, obs=obs, reward=reward, done=done)\n", + " return state\n", + "\n", + " def _get_termination(self, data: mjx.Data) -> jax.Array:\n", + " fall_termination = data.xpos[self._head_body_id, 2] < 1.0\n", + " return fall_termination\n", + "\n", + " def _get_obs(\n", + " self, data: mjx.Data, info: dict[str, Any], contact: jax.Array\n", + " ) -> mjx_env.Observation:\n", + " cos = jp.cos(info[\"phase\"])\n", + " sin = jp.sin(info[\"phase\"])\n", + " phase = jp.concatenate([cos, sin])\n", + "\n", + " return jp.hstack(\n", + " [\n", + " data.qpos,\n", + " data.qvel,\n", + " data.cvel.ravel(),\n", + " data.xpos.ravel(),\n", + " data.xmat.ravel(),\n", + " phase,\n", + " info[\"command\"],\n", + " info[\"last_act\"],\n", + " info[\"feet_air_time\"],\n", + " ]\n", + " )\n", + "\n", + " def _get_reward(\n", + " self,\n", + " data: mjx.Data,\n", + " action: jax.Array,\n", + " info: dict[str, Any],\n", + " metrics: dict[str, Any],\n", + " done: jax.Array,\n", + " first_contact: jax.Array,\n", + " contact: jax.Array,\n", + " ) -> dict[str, jax.Array]:\n", + " del metrics # Unused.\n", + " return {\n", + " # Tracking rewards.\n", + " \"tracking_lin_vel\": self._reward_tracking_lin_vel(\n", + " info[\"command\"], self._get_global_linvel(data, self._torso_id)\n", + " ),\n", + " \"tracking_ang_vel\": self._reward_tracking_ang_vel(\n", + " info[\"command\"], self._get_global_angvel(data, self._torso_id)\n", + " ),\n", + " # Base-related rewards.\n", + " \"ang_vel_xy\": self._cost_ang_vel_xy(\n", + " self._get_global_angvel(data, self._torso_id)\n", + " ),\n", + " \"orientation\": self._cost_orientation(self._get_z_frame(data, self._torso_id)),\n", + " # Energy related rewards.\n", + " \"action_rate\": self._cost_action_rate(\n", + " action, info[\"last_act\"], info[\"last_last_act\"]\n", + " ),\n", + " # Feet related rewards.\n", + " \"feet_slip\": self._cost_feet_slip(data, contact, info),\n", + " \"feet_air_time\": self._reward_feet_air_time(\n", + " 
info[\"feet_air_time\"], first_contact, info[\"command\"]\n", + " ),\n", + " \"feet_phase\": self._reward_feet_phase(\n", + " data,\n", + " info[\"phase\"],\n", + " self._config.reward_config.max_foot_height,\n", + " info[\"command\"],\n", + " ),\n", + " # Pose related rewards.\n", + " \"joint_deviation_hip\": self._cost_joint_deviation_hip(\n", + " data.qpos[7:], info[\"command\"]\n", + " ),\n", + " \"joint_deviation_knee\": self._cost_joint_deviation_knee(data.qpos[7:]),\n", + " \"dof_pos_limits\": self._cost_joint_pos_limits(data.qpos[7:]),\n", + " \"pose\": self._cost_pose(data.qpos[7:]),\n", + " # Other rewards.\n", + " \"termination\": self._cost_termination(done),\n", + " }\n", + "\n", + " def _get_global_angvel(self, data: mjx.Data, bodyid: int):\n", + " return data.cvel[bodyid, :3]\n", + "\n", + " def _get_global_linvel(self, data: mjx.Data, bodyid: int):\n", + " offset = data.xpos[bodyid] - data.subtree_com[mjm.body_rootid[bodyid]]\n", + " xang = data.cvel[bodyid, :3]\n", + " xvel = data.cvel[bodyid, 3:] + jp.cross(offset, xang)\n", + " return xvel\n", + "\n", + " def _get_z_frame(self, data: mjx.Data, bodyid: int):\n", + " return data.xmat[bodyid, :, 2]\n", + "\n", + " # Tracking rewards.\n", + "\n", + " def _reward_tracking_lin_vel(\n", + " self,\n", + " commands: jax.Array,\n", + " local_vel: jax.Array,\n", + " ) -> jax.Array:\n", + " lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))\n", + " return jp.exp(-lin_vel_error / self._config.reward_config.tracking_sigma)\n", + "\n", + " def _reward_tracking_ang_vel(\n", + " self,\n", + " commands: jax.Array,\n", + " ang_vel: jax.Array,\n", + " ) -> jax.Array:\n", + " ang_vel_error = jp.square(commands[2] - ang_vel[2])\n", + " return jp.exp(-ang_vel_error / self._config.reward_config.tracking_sigma)\n", + "\n", + " # Base-related rewards.\n", + "\n", + " def _cost_ang_vel_xy(self, global_angvel_torso: jax.Array) -> jax.Array:\n", + " return jp.sum(jp.square(global_angvel_torso[:2]))\n", + "\n", + 
" def _cost_orientation(self, torso_zaxis: jax.Array) -> jax.Array:\n", + " return jp.sum(jp.square(torso_zaxis - jp.array([0.0, 0.0, 1.0])))\n", + "\n", + " def _cost_base_height(self, base_height: jax.Array) -> jax.Array:\n", + " return jp.square(base_height - self._config.reward_config.base_height_target)\n", + "\n", + " # Energy related rewards.\n", + "\n", + " def _cost_action_rate(\n", + " self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array\n", + " ) -> jax.Array:\n", + " del last_last_act # Unused.\n", + " return jp.sum(jp.square(act - last_act))\n", + "\n", + " # Feet related rewards.\n", + "\n", + " def _cost_feet_slip(\n", + " self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]\n", + " ) -> jax.Array:\n", + " del info # Unused.\n", + " body_vel = self._get_global_linvel(data, self._torso_id)[:2]\n", + " reward = jp.sum(jp.linalg.norm(body_vel, axis=-1) * contact)\n", + " return reward\n", + "\n", + " def _reward_feet_air_time(\n", + " self,\n", + " air_time: jax.Array,\n", + " first_contact: jax.Array,\n", + " commands: jax.Array,\n", + " threshold_min: float = 0.2,\n", + " threshold_max: float = 0.5,\n", + " ) -> jax.Array:\n", + " del commands # Unused.\n", + " air_time = (air_time - threshold_min) * first_contact\n", + " air_time = jp.clip(air_time, max=threshold_max - threshold_min)\n", + " reward = jp.sum(air_time)\n", + " return reward\n", + "\n", + " def get_rz(\n", + " phi: Union[jax.Array, float], swing_height: Union[jax.Array, float] = 0.08\n", + " ) -> jax.Array:\n", + " def cubic_bezier_interpolation(y_start, y_end, x):\n", + " y_diff = y_end - y_start\n", + " bezier = x**3 + 3 * (x**2 * (1 - x))\n", + " return y_start + y_diff * bezier\n", + "\n", + " x = (phi + jp.pi) / (2 * jp.pi)\n", + " stance = cubic_bezier_interpolation(0, swing_height, 2 * x)\n", + " swing = cubic_bezier_interpolation(swing_height, 0, 2 * x - 1)\n", + " return jp.where(x <= 0.5, stance, swing)\n", + "\n", + " def _reward_feet_phase(\n", + " 
self,\n", + " data: mjx.Data,\n", + " phase: jax.Array,\n", + " foot_height: jax.Array,\n", + " command: jax.Array,\n", + " ) -> jax.Array:\n", + " # Reward for tracking the desired foot height.\n", + " foot_pos = data.site_xpos[self._feet_site_id]\n", + " foot_pos = jp.array(\n", + " [jp.mean(foot_pos[0:4], axis=0), jp.mean(foot_pos[4:8], axis=0)]\n", + " )\n", + " foot_z = foot_pos[..., -1]\n", + " rz = Joystick.get_rz(phase, swing_height=foot_height)\n", + " error = jp.sum(jp.square(foot_z - rz))\n", + " reward = jp.exp(-error / 0.01)\n", + " body_linvel = self._get_global_linvel(data, self._torso_id)[:2]\n", + " body_angvel = self._get_global_angvel(data, self._torso_id)[2]\n", + " linvel_mask = jp.logical_or(\n", + " jp.linalg.norm(body_linvel) > 0.1,\n", + " jp.abs(body_angvel) > 0.1,\n", + " )\n", + " mask = jp.logical_or(linvel_mask, jp.linalg.norm(command) > 0.01)\n", + " reward *= mask\n", + " return reward\n", + "\n", + " # Pose-related rewards.\n", + "\n", + " def _cost_joint_deviation_hip(self, qpos: jax.Array, cmd: jax.Array) -> jax.Array:\n", + " error = qpos[self._hip_indices] - self._default_pose[self._hip_indices]\n", + " # Allow roll deviation when lateral velocity is high.\n", + " weight = jp.where(\n", + " cmd[1] > 0.1,\n", + " jp.array([0.0, 1.0, 0.0, 1.0]),\n", + " jp.array([1.0, 1.0, 1.0, 1.0]),\n", + " )\n", + " cost = jp.sum(jp.abs(error) * weight)\n", + " return cost\n", + "\n", + " def _cost_joint_deviation_knee(self, qpos: jax.Array) -> jax.Array:\n", + " error = qpos[self._knee_indices] - self._default_pose[self._knee_indices]\n", + " return jp.sum(jp.abs(error))\n", + "\n", + " def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array:\n", + " out_of_limits = -jp.clip(qpos - self._soft_lowers, None, 0.0)\n", + " out_of_limits += jp.clip(qpos - self._soft_uppers, 0.0, None)\n", + " return jp.sum(out_of_limits)\n", + "\n", + " def _cost_pose(self, qpos: jax.Array) -> jax.Array:\n", + " return jp.sum(jp.square(qpos - 
self._default_pose))\n", + "\n", + " # Other rewards.\n", + "\n", + " def _cost_termination(self, done: jax.Array) -> jax.Array:\n", + " return done\n", + "\n", + " def sample_command(self, rng: jax.Array) -> jax.Array:\n", + " rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)\n", + "\n", + " lin_vel_x = jax.random.uniform(\n", + " rng1, minval=self._config.lin_vel_x[0], maxval=self._config.lin_vel_x[1]\n", + " )\n", + " lin_vel_y = jax.random.uniform(\n", + " rng2, minval=self._config.lin_vel_y[0], maxval=self._config.lin_vel_y[1]\n", + " )\n", + " ang_vel_yaw = jax.random.uniform(\n", + " rng3,\n", + " minval=self._config.ang_vel_yaw[0],\n", + " maxval=self._config.ang_vel_yaw[1],\n", + " )\n", + "\n", + " return jp.hstack([lin_vel_x, lin_vel_y, ang_vel_yaw])\n", + "\n", + " @property\n", + " def xml_path(self) -> str:\n", + " return xml_path.as_posix()\n", + "\n", + " @property\n", + " def action_size(self) -> int:\n", + " return mjm.nu\n", + "\n", + " @property\n", + " def mj_model(self) -> mujoco.MjModel:\n", + " return mjm\n", + "\n", + " @property\n", + " def mjx_model(self) -> mjx.Model:\n", + " return None # unused" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aoWmBHXiU-HW" + }, + "outputs": [], + "source": [ + "# Train the environment using Brax PPO\n", + "\n", + "from IPython.display import HTML, clear_output\n", + "from brax.training.agents.ppo import networks as ppo_networks\n", + "from brax.training.agents.ppo import train as ppo\n", + "from matplotlib import pyplot as plt\n", + "from mujoco_playground import wrapper\n", + "from mujoco_playground.config import locomotion_params\n", + "\n", + "ppo_params = config_dict.create(\n", + " num_timesteps=50_000_000,\n", + " num_evals=5,\n", + " reward_scaling=1.0,\n", + " clipping_epsilon=0.2,\n", + " episode_length=1000,\n", + " normalize_observations=True,\n", + " action_repeat=1,\n", + " unroll_length=20,\n", + " num_minibatches=32,\n", + " 
num_updates_per_batch=4,\n", + " discounting=0.97,\n", + " learning_rate=3e-4,\n", + " entropy_cost=0.005,\n", + " num_envs=NWORLD,\n", + " num_eval_envs=NWORLD,\n", + " batch_size=256,\n", + " max_grad_norm=1.0,\n", + " # network_factory=config_dict.create(\n", + " # policy_hidden_layer_sizes=(512, 256, 128),\n", + " # value_hidden_layer_sizes=(512, 256, 128),\n", + " # ),\n", + ")\n", + "\n", + "x_data, y_data, y_dataerr = [], [], []\n", + "times = [datetime.datetime.now()]\n", + "\n", + "\n", + "def progress(num_steps, metrics):\n", + " times.append(datetime.datetime.now())\n", + " x_data.append(num_steps)\n", + " y_data.append(metrics[\"eval/episode_reward\"])\n", + " y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n", + "\n", + " plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n", + " plt.ylim([0, 30])\n", + " plt.xlabel(\"# environment steps\")\n", + " plt.ylabel(\"reward per episode\")\n", + " plt.title(f\"y={y_data[-1]:.3f}\")\n", + " plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n", + "\n", + " display(plt.gcf())\n", + " clear_output(wait=True)\n", + "\n", + "\n", + "ppo_training_params = dict(ppo_params)\n", + "network_factory = ppo_networks.make_ppo_networks\n", + "if \"network_factory\" in ppo_params:\n", + " del ppo_training_params[\"network_factory\"]\n", + " network_factory = functools.partial(\n", + " ppo_networks.make_ppo_networks, **ppo_params.network_factory\n", + " )\n", + "\n", + "train_fn = functools.partial(\n", + " ppo.train,\n", + " **dict(ppo_training_params),\n", + " network_factory=network_factory,\n", + " progress_fn=progress,\n", + ")\n", + "\n", + "env = Joystick()\n", + "\n", + "make_inference_fn, params, metrics = train_fn(\n", + " environment=env,\n", + " wrap_env_fn=wrapper.wrap_for_brax_training,\n", + ")\n", + "\n", + "print(f\"time to jit: {times[1] - times[0]}\")\n", + "print(f\"time to train: {times[-1] - times[1]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + 
"id": "gGRhELoRjtQs" + }, + "outputs": [], + "source": [ + "rng = jax.random.PRNGKey(0)\n", + "\n", + "jit_reset = jax.jit(env.reset)\n", + "\n", + "\n", + "def unroll(state):\n", + " inference_fn = make_inference_fn(params, deterministic=True)\n", + " rng = jax.random.PRNGKey(0)\n", + "\n", + " def single_step(state, _):\n", + " action, _ = inference_fn(state.obs, rng)\n", + " action = jp.tile(action, (NWORLD, 1))\n", + " state = jax.tree.map(lambda x: jp.tile(x, (NWORLD,) + (1,) * len(x.shape)), state)\n", + " state = jax.vmap(env.step)(state, action)\n", + " state = jax.tree.map(lambda x: x[0], state)\n", + "\n", + " return state, state\n", + "\n", + " _, states = jax.lax.scan(single_step, state, length=1000)\n", + "\n", + " return states\n", + "\n", + "\n", + "rollout = jax.jit(unroll)(jit_reset(rng))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_kj9fCiDn6j6" + }, + "outputs": [], + "source": [ + "rollout_arr = [jax.tree.map(lambda x, i=i: x[i], rollout) for i in range(400)]\n", + "frames = env.render(rollout_arr, camera=\"track\", width=640, height=480)\n", + "media.show_video(frames, fps=1.0 / env.dt)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/contrib/jax_unroll.py b/contrib/jax_unroll.py new file mode 100644 index 00000000..ec6db438 --- /dev/null +++ b/contrib/jax_unroll.py @@ -0,0 +1,86 @@ +# before running this, you may need to install JAX, most likely +# backed by your local cuda install: +# +# pip install --upgrade "jax[cuda12_local]" + +import os +import time + +import jax +import mujoco +import numpy as np +import warp as wp +from etils import epath +from jax import numpy as jp +from warp.jax_experimental.ffi import jax_callable + +import mujoco_warp as mjwarp +from 
mujoco_warp._src.warp_util import kernel_copy + +os.environ["XLA_FLAGS"] = "--xla_gpu_graph_min_graph_size=1" + +NWORLDS = 8192 +UNROLL_LENGTH = 1000 + +wp.clear_kernel_cache() + +path = epath.resource_path("mujoco_warp") / "test_data" / "humanoid/humanoid.xml" +mjm = mujoco.MjModel.from_xml_path(path.as_posix()) +mjm.opt.iterations = 1 +mjm.opt.ls_iterations = 4 +mjd = mujoco.MjData(mjm) +# give the system a little kick to ensure we have non-identity rotations +mjd.qvel = np.random.uniform(-0.01, 0.01, mjm.nv) +mujoco.mj_step(mjm, mjd, 3) # let dynamics get state significantly non-zero +mujoco.mj_forward(mjm, mjd) +m = mjwarp.put_model(mjm) +d = mjwarp.put_data(mjm, mjd, nworld=NWORLDS, nconmax=131012, njmax=131012 * 4) + + +def warp_step( + qpos_in: wp.array(dtype=wp.float32, ndim=2), + qvel_in: wp.array(dtype=wp.float32, ndim=2), + qpos_out: wp.array(dtype=wp.float32, ndim=2), + qvel_out: wp.array(dtype=wp.float32, ndim=2), +): + kernel_copy(d.qpos, qpos_in) + kernel_copy(d.qvel, qvel_in) + mjwarp.step(m, d) + kernel_copy(qpos_out, d.qpos) + kernel_copy(qvel_out, d.qvel) + + +warp_step_fn = jax_callable( + warp_step, + num_outputs=2, + output_dims={"qpos_out": (NWORLDS, mjm.nq), "qvel_out": (NWORLDS, mjm.nv)}, +) + +# initial state, tiled across all worlds +jax_qpos = jp.tile(jp.array(m.qpos0), (NWORLDS, 1)) +jax_qvel = jp.zeros((NWORLDS, m.nv)) + + +def unroll(qpos, qvel): + def step(carry, _): + qpos, qvel = carry + qpos, qvel = warp_step_fn(qpos, qvel) + return (qpos, qvel), None + + (qpos, qvel), _ = jax.lax.scan(step, (qpos, qvel), length=UNROLL_LENGTH) + + return qpos, qvel + + +jax_unroll_fn = jax.jit(unroll).lower(jax_qpos, jax_qvel).compile() + +# warm up: +jax.block_until_ready(jax_unroll_fn(jax_qpos, jax_qvel)) + +beg = time.perf_counter() +final_qpos, final_qvel = jax_unroll_fn(jax_qpos, jax_qvel) +jax.block_until_ready((final_qpos, final_qvel)) +end = time.perf_counter() + +run_time = end - beg + +print(f"Total steps per second: {NWORLDS * UNROLL_LENGTH / run_time:,.0f}") diff --git 
a/contrib/xml/apptronik_apollo.xml b/contrib/xml/apptronik_apollo.xml new file mode 100644 index 00000000..31acc2c4 --- /dev/null +++ b/contrib/xml/apptronik_apollo.xml @@ -0,0 +1,397 @@ + + + + diff --git a/contrib/xml/scene.xml b/contrib/xml/scene.xml new file mode 100644 index 00000000..d8e44e1f --- /dev/null +++ b/contrib/xml/scene.xml @@ -0,0 +1,35 @@ + + + + + diff --git a/mujoco/mjx/__init__.py b/mujoco/mjx/__init__.py deleted file mode 100644 index 54db1ce0..00000000 --- a/mujoco/mjx/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025 The Physics-Next Project Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -"""Public API for MJX.""" - -from ._src.collision_driver import broad_phase -from ._src.constraint import make_constraint -from ._src.forward import euler -from ._src.forward import forward -from ._src.forward import fwd_actuation -from ._src.forward import fwd_acceleration -from ._src.forward import fwd_position -from ._src.forward import fwd_velocity -from ._src.forward import step -from ._src.io import make_data -from ._src.io import put_data -from ._src.io import put_model -from ._src.passive import passive -from ._src.smooth import com_pos -from ._src.smooth import com_vel -from ._src.smooth import crb -from ._src.smooth import factor_m -from ._src.smooth import kinematics -from ._src.smooth import rne -from ._src.smooth import solve_m -from ._src.smooth import transmission -from ._src.solver import solve -from ._src.support import is_sparse -from ._src.support import mul_m -from ._src.support import xfrc_accumulate -from ._src.test_util import benchmark -from ._src.types import * diff --git a/mujoco/mjx/_src/__init__.py b/mujoco/mjx/_src/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mujoco/mjx/_src/broad_phase_test.py b/mujoco/mjx/_src/broad_phase_test.py deleted file mode 100644 index 59e2150c..00000000 --- a/mujoco/mjx/_src/broad_phase_test.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright 2025 The Physics-Next Project Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for broad phase functions.""" - -from absl.testing import absltest -from absl.testing import parameterized -import mujoco -from mujoco import mjx -import numpy as np -import warp as wp - -from . import test_util - -BoxType = wp.types.matrix(shape=(2, 3), dtype=wp.float32) - - -# Helper function to initialize a box -def init_box(min_x, min_y, min_z, max_x, max_y, max_z): - center = wp.vec3((min_x + max_x) / 2, (min_y + max_y) / 2, (min_z + max_z) / 2) - size = wp.vec3(max_x - min_x, max_y - min_y, max_z - min_z) - box = wp.types.matrix(shape=(2, 3), dtype=wp.float32)( - [center.x, center.y, center.z, size.x, size.y, size.z] - ) - return box - - -def overlap( - a: wp.types.matrix(shape=(2, 3), dtype=wp.float32), - b: wp.types.matrix(shape=(2, 3), dtype=wp.float32), -) -> bool: - # Extract centers and sizes - a_center = a[0] - a_size = a[1] - b_center = b[0] - b_size = b[1] - - # Calculate min/max from center and size - a_min = a_center - 0.5 * a_size - a_max = a_center + 0.5 * a_size - b_min = b_center - 0.5 * b_size - b_max = b_center + 0.5 * b_size - - return not ( - a_min.x > b_max.x - or b_min.x > a_max.x - or a_min.y > b_max.y - or b_min.y > a_max.y - or a_min.z > b_max.z - or b_min.z > a_max.z - ) - - -def transform_aabb( - aabb: wp.types.matrix(shape=(2, 3), dtype=wp.float32), - pos: wp.vec3, - rot: wp.mat33, -) -> wp.types.matrix(shape=(2, 3), dtype=wp.float32): - # Extract center and half-extents from AABB - center = aabb[0] - half_extents = aabb[1] * 0.5 - - # Get absolute values of rotation matrix columns - right = wp.vec3(wp.abs(rot[0, 0]), wp.abs(rot[0, 1]), wp.abs(rot[0, 2])) - up = wp.vec3(wp.abs(rot[1, 0]), wp.abs(rot[1, 1]), wp.abs(rot[1, 2])) - forward = wp.vec3(wp.abs(rot[2, 0]), wp.abs(rot[2, 1]), wp.abs(rot[2, 2])) - - # Compute world space 
half-extents - world_extents = ( - right * half_extents.x + up * half_extents.y + forward * half_extents.z - ) - - # Transform center - new_center = rot @ center + pos - - # Return new AABB as matrix with center and full size - result = BoxType() - result[0] = wp.vec3(new_center.x, new_center.y, new_center.z) - result[1] = wp.vec3( - world_extents.x * 2.0, world_extents.y * 2.0, world_extents.z * 2.0 - ) - return result - - -def find_overlaps_brute_force(worldId: int, num_boxes_per_world: int, boxes, pos, rot): - """ - Finds overlapping bounding boxes using the brute-force O(n^2) algorithm. - - Returns: - List of tuples [(idx1, idx2)] where idx1 and idx2 are indices of overlapping boxes. - """ - overlaps = [] - - for i in range(num_boxes_per_world): - box_a = boxes[i] - box_a = transform_aabb(box_a, pos[worldId, i], rot[worldId, i]) - - for j in range(i + 1, num_boxes_per_world): - box_b = boxes[j] - box_b = transform_aabb(box_b, pos[worldId, j], rot[worldId, j]) - - # Use the overlap function to check for overlap - if overlap(box_a, box_b): - overlaps.append((i, j)) # Store indices of overlapping boxes - - return overlaps - - -def find_overlaps_brute_force_batched( - num_worlds: int, num_boxes_per_world: int, boxes, pos, rot -): - """ - Finds overlapping bounding boxes using the brute-force O(n^2) algorithm. - - Returns: - List of tuples [(idx1, idx2)] where idx1 and idx2 are indices of overlapping boxes. 
- """ - - overlaps = [] - - for worldId in range(num_worlds): - overlaps.append( - find_overlaps_brute_force(worldId, num_boxes_per_world, boxes, pos, rot) - ) - - # Show progress bar for brute force computation - # from tqdm import tqdm - - # for worldId in tqdm(range(num_worlds), desc="Computing overlaps"): - # overlaps.append(find_overlaps_brute_force(worldId, num_boxes_per_world, boxes)) - - return overlaps - - -class MultiIndexList: - def __init__(self): - self.data = {} - - def __setitem__(self, key, value): - worldId, i = key - if worldId not in self.data: - self.data[worldId] = [] - if i >= len(self.data[worldId]): - self.data[worldId].extend([None] * (i - len(self.data[worldId]) + 1)) - self.data[worldId][i] = value - - def __getitem__(self, key): - worldId, i = key - return self.data[worldId][i] # Raises KeyError if not found - - -class BroadPhaseTest(parameterized.TestCase): - def test_broad_phase(self): - """Tests broad phase.""" - _, mjd, m, d = test_util.fixture("humanoid/humanoid.xml") - - # Create some test boxes - num_worlds = d.nworld - num_boxes_per_world = m.ngeom - # print(f"num_worlds: {num_worlds}, num_boxes_per_world: {num_boxes_per_world}") - - # Parameters for random box generation - sample_space_origin = wp.vec3(-10.0, -10.0, -10.0) # Origin of the bounding volume - sample_space_size = wp.vec3(20.0, 20.0, 20.0) # Size of the bounding volume - min_edge_length = 0.5 # Minimum edge length of random boxes - max_edge_length = 5.0 # Maximum edge length of random boxes - - boxes_list = [] - - # Set random seed for reproducibility - import random - - random.seed(11) - - # Generate random boxes for each world - for _ in range(num_boxes_per_world): - # Generate random position within bounding volume - pos_x = sample_space_origin.x + random.random() * sample_space_size.x - pos_y = sample_space_origin.y + random.random() * sample_space_size.y - pos_z = sample_space_origin.z + random.random() * sample_space_size.z - - # Generate random box dimensions 
between min and max edge lengths - size_x = min_edge_length + random.random() * (max_edge_length - min_edge_length) - size_y = min_edge_length + random.random() * (max_edge_length - min_edge_length) - size_z = min_edge_length + random.random() * (max_edge_length - min_edge_length) - - # Create box with random position and size - boxes_list.append( - init_box(pos_x, pos_y, pos_z, pos_x + size_x, pos_y + size_y, pos_z + size_z) - ) - - # Generate random positions and orientations for each box - pos = [] - rot = [] - for _ in range(num_worlds * num_boxes_per_world): - # Random position within bounding volume - pos_x = sample_space_origin.x + random.random() * sample_space_size.x - pos_y = sample_space_origin.y + random.random() * sample_space_size.y - pos_z = sample_space_origin.z + random.random() * sample_space_size.z - pos.append(wp.vec3(pos_x, pos_y, pos_z)) - # pos.append(wp.vec3(0, 0, 0)) - - # Random rotation matrix - rx = random.random() * 6.28318530718 # 2*pi - ry = random.random() * 6.28318530718 - rz = random.random() * 6.28318530718 - axis = wp.vec3(rx, ry, rz) - axis = axis / wp.length(axis) # normalize axis - angle = random.random() * 6.28318530718 # random angle between 0 and 2*pi - rot.append(wp.quat_to_matrix(wp.quat_from_axis_angle(axis, angle))) - # rot.append(wp.quat_to_matrix(wp.quat_from_axis_angle(wp.vec3(1, 0, 0), float(0)))) - - # Convert pos and rot to MultiIndexList format - pos_multi = MultiIndexList() - rot_multi = MultiIndexList() - - # Populate the MultiIndexLists using pos and rot data - idx = 0 - for world_idx in range(num_worlds): - for i in range(num_boxes_per_world): - pos_multi[world_idx, i] = pos[idx] - rot_multi[world_idx, i] = rot[idx] - idx += 1 - - brute_force_overlaps = find_overlaps_brute_force_batched( - num_worlds, num_boxes_per_world, boxes_list, pos_multi, rot_multi - ) - - # Test the broad phase by setting custom aabb data - d.geom_aabb = wp.array( - boxes_list, dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32) - ) 
- d.geom_aabb = d.geom_aabb.reshape((num_boxes_per_world)) - d.geom_xpos = wp.array(pos, dtype=wp.vec3) - d.geom_xpos = d.geom_xpos.reshape((num_worlds, num_boxes_per_world)) - d.geom_xmat = wp.array(rot, dtype=wp.mat33) - d.geom_xmat = d.geom_xmat.reshape((num_worlds, num_boxes_per_world)) - - mjx.broad_phase(m, d) - - result = d.broadphase_pairs - result_count = d.result_count - - # Get numpy arrays from result and result_count - result_np = result.numpy() - result_count_np = result_count.numpy() - - # Iterate over each world - for world_idx in range(num_worlds): - # Get number of collisions for this world - num_collisions = result_count_np[world_idx] - print(f"Number of collisions for world {world_idx}: {num_collisions}") - - list = brute_force_overlaps[world_idx] - assert len(list) == num_collisions, "Number of collisions does not match" - - # Print each collision pair - for i in range(num_collisions): - pair = result_np[world_idx][i] - - # Convert pair to tuple for comparison - pair_tuple = (int(pair[0]), int(pair[1])) - assert pair_tuple in list, ( - f"Collision pair {pair_tuple} not found in brute force results" - ) - - -if __name__ == "__main__": - wp.init() - absltest.main() diff --git a/mujoco/mjx/_src/collision_driver.py b/mujoco/mjx/_src/collision_driver.py deleted file mode 100644 index 4567898d..00000000 --- a/mujoco/mjx/_src/collision_driver.py +++ /dev/null @@ -1,388 +0,0 @@ -# Copyright 2025 The Physics-Next Project Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import warp as wp - -from .types import Model -from .types import Data - -BoxType = wp.types.matrix(shape=(2, 3), dtype=wp.float32) - - -# TODO: Verify that this is corect -@wp.func -def transform_aabb( - aabb: wp.types.matrix(shape=(2, 3), dtype=wp.float32), - pos: wp.vec3, - rot: wp.mat33, -) -> wp.types.matrix(shape=(2, 3), dtype=wp.float32): - # Extract center and extents from AABB - center = aabb[0] - extents = aabb[1] - - absRot = rot - absRot[0, 0] = wp.abs(rot[0, 0]) - absRot[0, 1] = wp.abs(rot[0, 1]) - absRot[0, 2] = wp.abs(rot[0, 2]) - absRot[1, 0] = wp.abs(rot[1, 0]) - absRot[1, 1] = wp.abs(rot[1, 1]) - absRot[1, 2] = wp.abs(rot[1, 2]) - absRot[2, 0] = wp.abs(rot[2, 0]) - absRot[2, 1] = wp.abs(rot[2, 1]) - absRot[2, 2] = wp.abs(rot[2, 2]) - world_extents = extents * absRot - - # Transform center - new_center = rot @ center + pos - - # Return new AABB as matrix with center and full size - result = BoxType() - result[0] = wp.vec3(new_center.x, new_center.y, new_center.z) - result[1] = wp.vec3(world_extents.x, world_extents.y, world_extents.z) - return result - - -@wp.func -def overlap( - a: wp.types.matrix(shape=(2, 3), dtype=wp.float32), - b: wp.types.matrix(shape=(2, 3), dtype=wp.float32), -) -> bool: - # Extract centers and sizes - a_center = a[0] - a_size = a[1] - b_center = b[0] - b_size = b[1] - - # Calculate min/max from center and size - a_min = a_center - 0.5 * a_size - a_max = a_center + 0.5 * a_size - b_min = b_center - 0.5 * b_size - b_max = b_center + 0.5 * b_size - - return not ( - a_min.x > b_max.x - or b_min.x > a_max.x - or a_min.y > b_max.y - or b_min.y > a_max.y - or a_min.z > b_max.z - or b_min.z > a_max.z - ) - - -@wp.kernel -def broad_phase_project_boxes_onto_sweep_direction_kernel( - boxes: wp.array(dtype=wp.types.matrix(shape=(2, 3), 
dtype=wp.float32), ndim=1), - box_translations: wp.array(dtype=wp.vec3, ndim=2), - box_rotations: wp.array(dtype=wp.mat33, ndim=2), - data_start: wp.array(dtype=wp.float32, ndim=2), - data_end: wp.array(dtype=wp.float32, ndim=2), - data_indexer: wp.array(dtype=wp.int32, ndim=2), - direction: wp.vec3, - abs_dir: wp.vec3, - result_count: wp.array(dtype=wp.int32, ndim=1), -): - worldId, i = wp.tid() - - box = boxes[i] # box is a vector6 - box = transform_aabb(box, box_translations[worldId, i], box_rotations[worldId, i]) - box_center = box[0] - box_size = box[1] - center = wp.dot(direction, box_center) - d = wp.dot(box_size, abs_dir) - f = center - d - - # Store results in the data arrays - data_start[worldId, i] = f - data_end[worldId, i] = center + d - data_indexer[worldId, i] = i - - if i == 0: - result_count[worldId] = 0 # Initialize result count to 0 - - -@wp.kernel -def reorder_bounding_boxes_kernel( - boxes: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1), - box_translations: wp.array(dtype=wp.vec3, ndim=2), - box_rotations: wp.array(dtype=wp.mat33, ndim=2), - boxes_sorted: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=2), - data_indexer: wp.array(dtype=wp.int32, ndim=2), -): - worldId, i = wp.tid() - - # Get the index from the data indexer - mapped = data_indexer[worldId, i] - - # Get the box from the original boxes array - box = boxes[mapped] - box = transform_aabb( - box, box_translations[worldId, mapped], box_rotations[worldId, mapped] - ) - - # Reorder the box into the sorted array - boxes_sorted[worldId, i] = box - - -@wp.func -def find_first_greater_than( - worldId: int, - starts: wp.array(dtype=wp.float32, ndim=2), - value: wp.float32, - low: int, - high: int, -) -> int: - while low < high: - mid = (low + high) >> 1 - if starts[worldId, mid] > value: - high = mid - else: - low = mid + 1 - return low - - -@wp.kernel -def broad_phase_sweep_and_prune_prepare_kernel( - num_boxes_per_world: int, - data_start: 
wp.array(dtype=wp.float32, ndim=2), - data_end: wp.array(dtype=wp.float32, ndim=2), - indexer: wp.array(dtype=wp.int32, ndim=2), - cumulative_sum: wp.array(dtype=wp.int32, ndim=2), -): - worldId, i = wp.tid() # Get the thread ID - - # Get the index of the current bounding box - idx1 = indexer[worldId, i] - - end = data_end[worldId, idx1] - limit = find_first_greater_than(worldId, data_start, end, i + 1, num_boxes_per_world) - limit = wp.min(num_boxes_per_world - 1, limit) - - # Calculate the range of boxes for the sweep and prune process - count = limit - i - - # Store the cumulative sum for the current box - cumulative_sum[worldId, i] = count - - -@wp.func -def find_right_most_index_int( - starts: wp.array(dtype=wp.int32, ndim=1), value: wp.int32, low: int, high: int -) -> int: - while low < high: - mid = (low + high) >> 1 - if starts[mid] > value: - high = mid - else: - low = mid + 1 - return high - - -@wp.func -def find_indices( - id: int, cumulative_sum: wp.array(dtype=wp.int32, ndim=1), length: int -) -> wp.vec2i: - # Perform binary search to find the right most index - i = find_right_most_index_int(cumulative_sum, id, 0, length) - - # Get the baseId, and compute the offset and j - if i > 0: - base_id = cumulative_sum[i - 1] - else: - base_id = 0 - offset = id - base_id - j = i + offset + 1 - - return wp.vec2i(i, j) - - -@wp.kernel -def broad_phase_sweep_and_prune_kernel( - num_threads: int, - length: int, - num_boxes_per_world: int, - max_num_overlaps_per_world: int, - cumulative_sum: wp.array(dtype=wp.int32, ndim=1), - data_indexer: wp.array(dtype=wp.int32, ndim=2), - data_result: wp.array(dtype=wp.vec2i, ndim=2), - result_count: wp.array(dtype=wp.int32, ndim=1), - boxes_sorted: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=2), -): - threadId = wp.tid() # Get thread ID - if length > 0: - total_num_work_packages = cumulative_sum[length - 1] - else: - total_num_work_packages = 0 - - while threadId < total_num_work_packages: - # Get 
indices for current and next box pair - ij = find_indices(threadId, cumulative_sum, length) - i = ij.x - j = ij.y - - worldId = i // num_boxes_per_world - i = i % num_boxes_per_world - - # world_id_j = j // num_boxes_per_world - j = j % num_boxes_per_world - - # assert worldId == world_id_j, "Only boxes in the same world can be compared" - # TODO: Remove print if debugging is done - # if worldId != world_id_j: - # print("Only boxes in the same world can be compared") - - idx1 = data_indexer[worldId, i] - - box1 = boxes_sorted[worldId, i] - - idx2 = data_indexer[worldId, j] - - # Check if the boxes overlap - if idx1 != idx2 and overlap(box1, boxes_sorted[worldId, j]): - pair = wp.vec2i(wp.min(idx1, idx2), wp.max(idx1, idx2)) - - id = wp.atomic_add(result_count, worldId, 1) - - if id < max_num_overlaps_per_world: - data_result[worldId, id] = pair - - threadId += num_threads - - -def broad_phase(m: Model, d: Data) -> Data: - """Broad phase collision detection.""" - - # Directional vectors for sweep - # TODO: Improve picking of direction - direction = wp.vec3(0.5935, 0.7790, 0.1235) - direction = wp.normalize(direction) - abs_dir = wp.vec3(abs(direction.x), abs(direction.y), abs(direction.z)) - - wp.launch( - kernel=broad_phase_project_boxes_onto_sweep_direction_kernel, - dim=(d.nworld, m.ngeom), - inputs=[ - d.geom_aabb, - d.geom_xpos, - d.geom_xmat, - d.data_start, - d.data_end, - d.data_indexer, - direction, - abs_dir, - d.result_count, - ], - ) - - segmented_sort_available = hasattr(wp.utils, "segmented_sort_pairs") - - if segmented_sort_available: - # print("Using segmented sort") - wp.utils.segmented_sort_pairs( - d.data_start, - d.data_indexer, - m.ngeom * d.nworld, - d.segment_indices, - d.nworld, - ) - else: - # Sort each world's segment separately - for world_id in range(d.nworld): - start_idx = world_id * m.ngeom - - # Create temporary arrays for sorting - temp_data_start = wp.zeros( - m.ngeom * 2, - dtype=d.data_start.dtype, - ) - temp_data_indexer = 
wp.zeros( - m.ngeom * 2, - dtype=d.data_indexer.dtype, - ) - - # Copy data to temporary arrays - wp.copy( - temp_data_start, - d.data_start, - 0, - start_idx, - m.ngeom, - ) - wp.copy( - temp_data_indexer, - d.data_indexer, - 0, - start_idx, - m.ngeom, - ) - - # Sort the temporary arrays - wp.utils.radix_sort_pairs(temp_data_start, temp_data_indexer, m.ngeom) - - # Copy sorted data back - wp.copy( - d.data_start, - temp_data_start, - start_idx, - 0, - m.ngeom, - ) - wp.copy( - d.data_indexer, - temp_data_indexer, - start_idx, - 0, - m.ngeom, - ) - - wp.launch( - kernel=reorder_bounding_boxes_kernel, - dim=(d.nworld, m.ngeom), - inputs=[d.geom_aabb, d.geom_xpos, d.geom_xmat, d.boxes_sorted, d.data_indexer], - ) - - wp.launch( - kernel=broad_phase_sweep_and_prune_prepare_kernel, - dim=(d.nworld, m.ngeom), - inputs=[ - m.ngeom, - d.data_start, - d.data_end, - d.data_indexer, - d.ranges, - ], - ) - - # The scan (scan = cumulative sum, either inclusive or exclusive depending on the last argument) is used for load balancing among the threads - wp.utils.array_scan(d.ranges.reshape(-1), d.cumulative_sum, True) - - # Estimate how many overlap checks need to be done - assumes each box has to be compared to 5 other boxes (and batched over all worlds) - num_sweep_threads = 5 * d.nworld * m.ngeom - wp.launch( - kernel=broad_phase_sweep_and_prune_kernel, - dim=num_sweep_threads, - inputs=[ - num_sweep_threads, - d.nworld * m.ngeom, - m.ngeom, - d.max_num_overlaps_per_world, - d.cumulative_sum, - d.data_indexer, - d.broadphase_pairs, - d.result_count, - d.boxes_sorted, - ], - ) - - return d diff --git a/mujoco/mjx/_src/forward.py b/mujoco/mjx/_src/forward.py deleted file mode 100644 index 875b7f06..00000000 --- a/mujoco/mjx/_src/forward.py +++ /dev/null @@ -1,366 +0,0 @@ -# Copyright 2025 The Physics-Next Project Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Optional - -import warp as wp -import mujoco - -from . import constraint -from . import math -from . import passive -from . import smooth -from . import solver - -from .types import array2df, array3df -from .types import Model -from .types import Data -from .types import MJ_MINVAL -from .types import DisableBit -from .types import JointType -from .types import DynType -from .support import xfrc_accumulate - - -def _advance( - m: Model, - d: Data, - act_dot: wp.array, - qacc: wp.array, - qvel: Optional[wp.array] = None, -) -> Data: - """Advance state and time given activation derivatives and acceleration.""" - - # TODO(team): can we assume static timesteps? 
- - @wp.kernel - def next_activation( - m: Model, - d: Data, - act_dot_in: array2df, - ): - worldId, actid = wp.tid() - - # get the high/low range for each actuator state - limited = m.actuator_actlimited[actid] - range_low = wp.select(limited, -wp.inf, m.actuator_actrange[actid][0]) - range_high = wp.select(limited, wp.inf, m.actuator_actrange[actid][1]) - - # get the actual actuation - skip if -1 (means stateless actuator) - act_adr = m.actuator_actadr[actid] - if act_adr == -1: - return - - acts = d.act[worldId] - acts_dot = act_dot_in[worldId] - - act = acts[act_adr] - act_dot = acts_dot[act_adr] - - # check dynType - dyn_type = m.actuator_dyntype[actid] - dyn_prm = m.actuator_dynprm[actid][0] - - # advance the actuation - if dyn_type == wp.static(DynType.FILTEREXACT.value): - tau = wp.select(dyn_prm < MJ_MINVAL, dyn_prm, MJ_MINVAL) - act = act + act_dot * tau * (1.0 - wp.exp(-m.opt.timestep / tau)) - else: - act = act + act_dot * m.opt.timestep - - # apply limits - wp.clamp(act, range_low, range_high) - - acts[act_adr] = act - - @wp.kernel - def advance_velocities(m: Model, d: Data, qacc: array2df): - worldId, tid = wp.tid() - d.qvel[worldId, tid] = d.qvel[worldId, tid] + qacc[worldId, tid] * m.opt.timestep - - @wp.kernel - def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df): - worldId, jntid = wp.tid() - - jnt_type = m.jnt_type[jntid] - qpos_adr = m.jnt_qposadr[jntid] - dof_adr = m.jnt_dofadr[jntid] - qpos = d.qpos[worldId] - qvel = qvel_in[worldId] - - if jnt_type == wp.static(JointType.FREE.value): - qpos_pos = wp.vec3(qpos[qpos_adr], qpos[qpos_adr + 1], qpos[qpos_adr + 2]) - qvel_lin = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2]) - - qpos_new = qpos_pos + m.opt.timestep * qvel_lin - - qpos_quat = wp.quat( - qpos[qpos_adr + 3], - qpos[qpos_adr + 4], - qpos[qpos_adr + 5], - qpos[qpos_adr + 6], - ) - qvel_ang = wp.vec3(qvel[dof_adr + 3], qvel[dof_adr + 4], qvel[dof_adr + 5]) - - qpos_quat_new = math.quat_integrate(qpos_quat, 
qvel_ang, m.opt.timestep) - - qpos[qpos_adr] = qpos_new[0] - qpos[qpos_adr + 1] = qpos_new[1] - qpos[qpos_adr + 2] = qpos_new[2] - qpos[qpos_adr + 3] = qpos_quat_new[0] - qpos[qpos_adr + 4] = qpos_quat_new[1] - qpos[qpos_adr + 5] = qpos_quat_new[2] - qpos[qpos_adr + 6] = qpos_quat_new[3] - - elif jnt_type == wp.static(JointType.BALL.value): # ball joint - qpos_quat = wp.quat( - qpos[qpos_adr], - qpos[qpos_adr + 1], - qpos[qpos_adr + 2], - qpos[qpos_adr + 3], - ) - qvel_ang = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2]) - - qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang, m.opt.timestep) - - qpos[qpos_adr] = qpos_quat_new[0] - qpos[qpos_adr + 1] = qpos_quat_new[1] - qpos[qpos_adr + 2] = qpos_quat_new[2] - qpos[qpos_adr + 3] = qpos_quat_new[3] - - else: # if jnt_type in (JointType.HINGE, JointType.SLIDE): - qpos[qpos_adr] = qpos[qpos_adr] + m.opt.timestep * qvel[dof_adr] - - # skip if no stateful actuators. - if m.na: - wp.launch(next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot]) - - wp.launch(advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc]) - - # advance positions with qvel if given, d.qvel otherwise (semi-implicit) - if qvel is not None: - qvel_in = qvel - else: - qvel_in = d.qvel - - wp.launch(integrate_joint_positions, dim=(d.nworld, m.njnt), inputs=[m, d, qvel_in]) - - d.time = d.time + m.opt.timestep - return d - - -def euler(m: Model, d: Data) -> Data: - """Euler integrator, semi-implicit in velocity.""" - # integrate damping implicitly - - def add_damping_sum_qfrc(m: Model, d: Data, is_sparse: bool): - @wp.kernel - def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): - worldId, tid = wp.tid() - - dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldId, 0, dof_Madr] += m.opt.timestep * m.dof_damping[dof_Madr] - - d.qfrc_integration[worldId, tid] = ( - d.qfrc_smooth[worldId, tid] + d.qfrc_constraint[worldId, tid] - ) - - @wp.kernel - def add_damping_sum_qfrc_kernel_dense(m: Model, d: Data): - worldid, i, j = 
wp.tid() - - damping = wp.select(i == j, 0.0, m.opt.timestep * m.dof_damping[i]) - d.qM_integration[worldid, i, j] = d.qM[worldid, i, j] + damping - - if i == 0: - d.qfrc_integration[worldid, j] = ( - d.qfrc_smooth[worldid, j] + d.qfrc_constraint[worldid, j] - ) - - if is_sparse: - wp.copy(d.qM_integration, d.qM) - wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) - else: - wp.launch( - add_damping_sum_qfrc_kernel_dense, dim=(d.nworld, m.nv, m.nv), inputs=[m, d] - ) - - if not m.opt.disableflags & DisableBit.EULERDAMP.value: - add_damping_sum_qfrc(m, d, m.opt.is_sparse) - smooth.factor_i(m, d, d.qM_integration, d.qLD_integration, d.qLDiagInv_integration) - smooth.solve_LD( - m, - d, - d.qLD_integration, - d.qLDiagInv_integration, - d.qacc_integration, - d.qfrc_integration, - ) - return _advance(m, d, d.act_dot, d.qacc_integration) - - return _advance(m, d, d.act_dot, d.qacc) - - -def fwd_position(m: Model, d: Data): - """Position-dependent computations.""" - - smooth.kinematics(m, d) - smooth.com_pos(m, d) - # TODO(team): smooth.camlight - # TODO(team): smooth.tendon - smooth.crb(m, d) - smooth.factor_m(m, d) - # TODO(team): collision_driver.collision - constraint.make_constraint(m, d) - smooth.transmission(m, d) - - -def fwd_velocity(m: Model, d: Data): - """Velocity-dependent computations.""" - - # TODO(team): tile operations? 
- d.actuator_velocity.zero_() - - @wp.kernel - def _actuator_velocity(d: Data): - worldid, actid, dofid = wp.tid() - moment = d.actuator_moment[worldid, actid] - qvel = d.qvel[worldid] - wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid]) - - wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d]) - - smooth.com_vel(m, d) - passive.passive(m, d) - smooth.rne(m, d) - - -def fwd_actuation(m: Model, d: Data): - """Actuation-dependent computations.""" - if not m.nu: - return - - # TODO support stateful actuators - - @wp.kernel - def _force( - m: Model, - ctrl: array2df, - # outputs - force: array2df, - ): - worldid, dofid = wp.tid() - gain = m.actuator_gainprm[dofid, 0] - bias = m.actuator_biasprm[dofid, 0] - # TODO support gain types other than FIXED - c = ctrl[worldid, dofid] - if m.actuator_ctrllimited[dofid]: - r = m.actuator_ctrlrange[dofid] - c = wp.clamp(c, r[0], r[1]) - f = gain * c + bias - if m.actuator_forcelimited[dofid]: - r = m.actuator_forcerange[dofid] - f = wp.clamp(f, r[0], r[1]) - force[worldid, dofid] = f - - wp.launch( - _force, dim=[d.nworld, m.nu], inputs=[m, d.ctrl], outputs=[d.actuator_force] - ) - - @wp.kernel - def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): - worldid, vid = wp.tid() - - s = float(0.0) - for uid in range(m.nu): - # TODO consider using Tile API or transpose moment for better access pattern - s += moment[worldid, uid, vid] * force[worldid, uid] - jntid = m.dof_jntid[vid] - if m.jnt_actfrclimited[jntid]: - r = m.jnt_actfrcrange[jntid] - s = wp.clamp(s, r[0], r[1]) - qfrc[worldid, vid] = s - - wp.launch( - _qfrc, - dim=(d.nworld, m.nv), - inputs=[m, d.actuator_moment, d.actuator_force], - outputs=[d.qfrc_actuator], - ) - - # TODO actuator-level gravity compensation, skip if added as passive force - - return d - - -def fwd_acceleration(m: Model, d: Data): - """Add up all non-constraint forces, compute qacc_smooth.""" - - qfrc_applied = d.qfrc_applied - 
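For the FIXED gain type handled in `_force` above, actuation reduces to an affine map from control to force with two clamps. A minimal NumPy sketch of that per-actuator computation (hypothetical helper, scalar gain/bias only — actual gain types and stateful actuators are out of scope here):

```python
import numpy as np

def actuator_force(ctrl, gain, bias, ctrlrange=None, forcerange=None):
    """force = gain * clamp(ctrl) + bias, then clamp force.
    Mirrors the _force kernel above for the FIXED gain type (NumPy sketch)."""
    c = np.asarray(ctrl, dtype=float)
    if ctrlrange is not None:                 # actuator_ctrllimited
        c = np.clip(c, ctrlrange[0], ctrlrange[1])
    f = gain * c + bias
    if forcerange is not None:                # actuator_forcelimited
        f = np.clip(f, forcerange[0], forcerange[1])
    return f
```

The resulting forces are then mapped into joint space through the moment arm matrix (`qfrc = moment.T @ force`), which is what the `_qfrc` kernel above accumulates per dof.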
qfrc_accumulated = xfrc_accumulate(m, d) - - @wp.kernel - def _qfrc_smooth( - d: Data, - qfrc_applied: wp.array(ndim=2, dtype=wp.float32), - qfrc_accumulated: wp.array(ndim=2, dtype=wp.float32), - ): - worldid, dofid = wp.tid() - d.qfrc_smooth[worldid, dofid] = ( - d.qfrc_passive[worldid, dofid] - - d.qfrc_bias[worldid, dofid] - + d.qfrc_actuator[worldid, dofid] - + qfrc_applied[worldid, dofid] - + qfrc_accumulated[worldid, dofid] - ) - - wp.launch( - _qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d, qfrc_applied, qfrc_accumulated] - ) - - smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth) - - -def forward(m: Model, d: Data): - """Forward dynamics.""" - - fwd_position(m, d) - # TODO(team): sensor.sensor_pos - fwd_velocity(m, d) - # TODO(team): sensor.sensor_vel - fwd_actuation(m, d) - fwd_acceleration(m, d) - # TODO(team): sensor.sensor_acc - - if d.njmax == 0: - wp.copy(d.qacc, d.qacc_smooth) - else: - solver.solve(m, d) - - -def step(m: Model, d: Data): - """Advance simulation.""" - forward(m, d) - - if m.opt.integrator == mujoco.mjtIntegrator.mjINT_EULER: - euler(m, d) - elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_RK4: - # TODO(team): rungekutta4 - raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.") - elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST: - # TODO(team): implicit - raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.") - else: - raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.") diff --git a/mujoco/mjx/_src/solver.py b/mujoco/mjx/_src/solver.py deleted file mode 100644 index ee0eea28..00000000 --- a/mujoco/mjx/_src/solver.py +++ /dev/null @@ -1,749 +0,0 @@ -import warp as wp -import mujoco -from . import smooth -from . import support -from . 
import types - - -@wp.struct -class Context: - Jaref: wp.array(dtype=wp.float32, ndim=1) - Ma: wp.array(dtype=wp.float32, ndim=2) - grad: wp.array(dtype=wp.float32, ndim=2) - grad_dot: wp.array(dtype=wp.float32, ndim=1) - Mgrad: wp.array(dtype=wp.float32, ndim=2) - search: wp.array(dtype=wp.float32, ndim=2) - search_dot: wp.array(dtype=wp.float32, ndim=1) - gauss: wp.array(dtype=wp.float32, ndim=1) - cost: wp.array(dtype=wp.float32, ndim=1) - prev_cost: wp.array(dtype=wp.float32, ndim=1) - solver_niter: wp.array(dtype=wp.int32, ndim=1) - active: wp.array(dtype=wp.int32, ndim=1) - gtol: wp.array(dtype=wp.float32, ndim=1) - mv: wp.array(dtype=wp.float32, ndim=2) - jv: wp.array(dtype=wp.float32, ndim=1) - quad: wp.array(dtype=wp.vec3f, ndim=1) - quad_gauss: wp.array(dtype=wp.vec3f, ndim=1) - quad_total: wp.array(dtype=wp.vec3f, ndim=1) - h: wp.array(dtype=wp.float32, ndim=3) - alpha: wp.array(dtype=wp.float32, ndim=1) - prev_grad: wp.array(dtype=wp.float32, ndim=2) - prev_Mgrad: wp.array(dtype=wp.float32, ndim=2) - beta: wp.array(dtype=wp.float32, ndim=1) - beta_num: wp.array(dtype=wp.float32, ndim=1) - beta_den: wp.array(dtype=wp.float32, ndim=1) - done: wp.array(dtype=wp.int32, ndim=1) - - -def _context(m: types.Model, d: types.Data) -> Context: - ctx = Context() - ctx.Jaref = wp.empty(shape=(d.njmax,), dtype=wp.float32) - ctx.Ma = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32) - ctx.grad = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32) - ctx.grad_dot = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.Mgrad = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32) - ctx.search = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32) - ctx.search_dot = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.gauss = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.cost = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.prev_cost = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.solver_niter = wp.empty(shape=(d.nworld,), dtype=wp.int32) - ctx.active = 
wp.empty(shape=(d.njmax,), dtype=wp.int32) - ctx.gtol = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.mv = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32) - ctx.jv = wp.empty(shape=(d.njmax,), dtype=wp.float32) - ctx.quad = wp.empty(shape=(d.njmax,), dtype=wp.vec3f) - ctx.quad_gauss = wp.empty(shape=(d.nworld,), dtype=wp.vec3f) - ctx.quad_total = wp.empty(shape=(d.nworld,), dtype=wp.vec3f) - ctx.h = wp.empty(shape=(d.nworld, m.nv, m.nv), dtype=wp.float32) - ctx.alpha = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.prev_grad = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32) - ctx.prev_Mgrad = wp.empty(shape=(d.nworld, m.nv), dtype=wp.float32) - ctx.beta = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.beta_num = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.beta_den = wp.empty(shape=(d.nworld,), dtype=wp.float32) - ctx.done = wp.empty(shape=(d.nworld,), dtype=wp.int32) - - return ctx - - -def _create_context(ctx: Context, m: types.Model, d: types.Data, grad: bool = True): - # jaref = d.efc_J @ d.qacc - d.efc_aref - ctx.Jaref.zero_() - - @wp.kernel - def _jaref(ctx: Context, m: types.Model, d: types.Data): - efcid, dofid = wp.tid() - - if efcid >= d.nefc_total[0]: - return - - worldid = d.efc_worldid[efcid] - wp.atomic_add( - ctx.Jaref, - efcid, - d.efc_J[efcid, dofid] * d.qacc[worldid, dofid] - d.efc_aref[efcid] / float(m.nv), - ) - - wp.launch(_jaref, dim=(d.njmax, m.nv), inputs=[ctx, m, d]) - - # Ma = qM @ qacc - support.mul_m(m, d, ctx.Ma, d.qacc) - - ctx.cost.fill_(wp.inf) - ctx.solver_niter.zero_() - ctx.done.zero_() - - _update_constraint(m, d, ctx) - if grad: - _update_gradient(m, d, ctx) - - # search = -Mgrad - ctx.search_dot.zero_() - - @wp.kernel - def _search(ctx: Context): - worldid, dofid = wp.tid() - search = -1.0 * ctx.Mgrad[worldid, dofid] - ctx.search[worldid, dofid] = search - wp.atomic_add(ctx.search_dot, worldid, search * search) - - wp.launch(_search, dim=(d.nworld, m.nv), inputs=[ctx]) - - -@wp.struct -class 
LSPoint: - alpha: wp.array(dtype=wp.float32, ndim=1) - cost: wp.array(dtype=wp.float32, ndim=1) - deriv_0: wp.array(dtype=wp.float32, ndim=1) - deriv_1: wp.array(dtype=wp.float32, ndim=1) - - -def _lspoint(d: types.Data) -> LSPoint: - ls_pnt = LSPoint() - ls_pnt.alpha = wp.empty(shape=(d.nworld), dtype=wp.float32) - ls_pnt.cost = wp.empty(shape=(d.nworld), dtype=wp.float32) - ls_pnt.deriv_0 = wp.empty(shape=(d.nworld), dtype=wp.float32) - ls_pnt.deriv_1 = wp.empty(shape=(d.nworld), dtype=wp.float32) - - return ls_pnt - - -def _create_lspoint(ls_pnt: LSPoint, m: types.Model, d: types.Data, ctx: Context): - wp.copy(ctx.quad_total, ctx.quad_gauss) - - @wp.kernel - def _quad(ls_pnt: LSPoint, ctx: Context, d: types.Data): - efcid = wp.tid() - - if efcid >= d.nefc_total[0]: - return - - worldid = d.efc_worldid[efcid] - x = ctx.Jaref[efcid] + ls_pnt.alpha[worldid] * ctx.jv[efcid] - # TODO(team): active and conditionally active constraints - if x < 0.0: - wp.atomic_add(ctx.quad_total, worldid, ctx.quad[efcid]) - - wp.launch(_quad, dim=(d.njmax,), inputs=[ls_pnt, ctx, d]) - - @wp.kernel - def _cost_deriv01(ls_pnt: LSPoint, ctx: Context): - worldid = wp.tid() - alpha = ls_pnt.alpha[worldid] - alpha_sq = alpha * alpha - quad_total0 = ctx.quad_total[worldid][0] - quad_total1 = ctx.quad_total[worldid][1] - quad_total2 = ctx.quad_total[worldid][2] - - ls_pnt.cost[worldid] = alpha_sq * quad_total2 + alpha * quad_total1 + quad_total0 - ls_pnt.deriv_0[worldid] = 2.0 * alpha * quad_total2 + quad_total1 - ls_pnt.deriv_1[worldid] = 2.0 * quad_total2 + float(quad_total2 == 0.0) - - wp.launch(_cost_deriv01, dim=(d.nworld,), inputs=[ls_pnt, ctx]) - - -@wp.struct -class LSContext: - p0: LSPoint - lo: LSPoint - lo_next: LSPoint - hi: LSPoint - hi_next: LSPoint - mid: LSPoint - swap: wp.array(ndim=1, dtype=wp.int32) - ls_iter: wp.array(ndim=1, dtype=wp.int32) - done: wp.array(ndim=1, dtype=wp.int32) - - -def _create_lscontext(m: types.Model, d: types.Data, ctx: Context) -> LSContext: - 
ls_ctx = LSContext() - - ls_ctx.p0 = _lspoint(d) - ls_ctx.lo = _lspoint(d) - ls_ctx.lo_next = _lspoint(d) - ls_ctx.hi = _lspoint(d) - ls_ctx.hi_next = _lspoint(d) - ls_ctx.mid = _lspoint(d) - - ls_ctx.swap = wp.empty(shape=(d.nworld), dtype=wp.int32) - ls_ctx.ls_iter = wp.empty(shape=(d.nworld), dtype=wp.int32) - ls_ctx.done = wp.zeros((d.nworld), dtype=wp.int32) - - return ls_ctx - - -def _update_constraint(m: types.Model, d: types.Data, ctx: Context): - wp.copy(ctx.prev_cost, ctx.cost) - ctx.cost.zero_() - - @wp.kernel - def _efc_kernel(ctx: Context, d: types.Data): - efcid = wp.tid() - - if efcid >= d.nefc_total[0]: - return - - worldid = d.efc_worldid[efcid] - Jaref = ctx.Jaref[efcid] - efc_D = d.efc_D[efcid] - - # TODO(team): active and conditionally active constraints - active = int(Jaref < 0.0) - ctx.active[efcid] = active - - # efc_force = -efc_D * Jaref * active - d.efc_force[efcid] = -1.0 * efc_D * Jaref * float(active) - - # cost = 0.5 * sum(efc_D * Jaref * Jaref * active)) - wp.atomic_add(ctx.cost, worldid, 0.5 * efc_D * Jaref * Jaref * float(active)) - - wp.launch(_efc_kernel, dim=(d.njmax,), inputs=[ctx, d]) - - # qfrc_constraint = efc_J.T @ efc_force - d.qfrc_constraint.zero_() - - @wp.kernel - def _qfrc_constraint(d: types.Data): - dofid, efcid = wp.tid() - - if efcid >= d.nefc_total[0]: - return - - worldid = d.efc_worldid[efcid] - wp.atomic_add( - d.qfrc_constraint[worldid], - dofid, - d.efc_J[efcid, dofid] * d.efc_force[efcid], - ) - - wp.launch(_qfrc_constraint, dim=(m.nv, d.njmax), inputs=[d]) - - # gauss = 0.5 * (Ma - qfrc_smooth).T @ (qacc - qacc_smooth) - ctx.gauss.zero_() - - @wp.kernel - def _gauss(ctx: Context, d: types.Data): - worldid, dofid = wp.tid() - gauss_cost = ( - 0.5 - * (ctx.Ma[worldid, dofid] - d.qfrc_smooth[worldid, dofid]) - * (d.qacc[worldid, dofid] - d.qacc_smooth[worldid, dofid]) - ) - wp.atomic_add(ctx.gauss, worldid, gauss_cost) - wp.atomic_add(ctx.cost, worldid, gauss_cost) - - wp.launch(_gauss, dim=(d.nworld, m.nv), 
inputs=[ctx, d]) - - -def _update_gradient(m: types.Model, d: types.Data, ctx: Context): - # grad = Ma - qfrc_smooth - qfrc_constraint - ctx.grad_dot.zero_() - - @wp.kernel - def _grad(ctx: Context, d: types.Data): - worldid, dofid = wp.tid() - grad = ( - ctx.Ma[worldid, dofid] - - d.qfrc_smooth[worldid, dofid] - - d.qfrc_constraint[worldid, dofid] - ) - ctx.grad[worldid, dofid] = grad - wp.atomic_add(ctx.grad_dot, worldid, grad * grad) - - wp.launch(_grad, dim=(d.nworld, m.nv), inputs=[ctx, d]) - - if m.opt.solver == 1: # CG - smooth.solve_m(m, d, ctx.grad, ctx.Mgrad) - elif m.opt.solver == 2: # Newton - # TODO(team): sparse version - # h = qM + (efc_J.T * efc_D * active) @ efc_J - @wp.kernel - def _copy_lower_triangle(m: types.Model, d: types.Data, ctx: Context): - worldid, elementid = wp.tid() - rowid = m.dof_tri_row[elementid] - colid = m.dof_tri_col[elementid] - ctx.h[worldid, rowid, colid] = d.qM[worldid, rowid, colid] - - wp.launch( - _copy_lower_triangle, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d, ctx] - ) - - @wp.kernel - def _JTDAJ(ctx: Context, m: types.Model, d: types.Data): - efcid, elementid = wp.tid() - dofi = m.dof_tri_row[elementid] - dofj = m.dof_tri_col[elementid] - - if efcid >= d.nefc_total[0]: - return - - efc_D = d.efc_D[efcid] - active = ctx.active[efcid] - if efc_D == 0.0 or active == 0: - return - - worldid = d.efc_worldid[efcid] - wp.atomic_add( - ctx.h[worldid, dofi], - dofj, - d.efc_J[efcid, dofi] * d.efc_J[efcid, dofj] * efc_D * float(active), - ) - - wp.launch(_JTDAJ, dim=(d.njmax, m.dof_tri_row.size), inputs=[ctx, m, d]) - - TILE = m.nv - - @wp.kernel - def _cholesky(ctx: Context): - worldid = wp.tid() - mat_tile = wp.tile_load(ctx.h[worldid], shape=(TILE, TILE)) - fact_tile = wp.tile_cholesky(mat_tile) - input_tile = wp.tile_load(ctx.grad[worldid], shape=TILE) - output_tile = wp.tile_cholesky_solve(fact_tile, input_tile) - wp.tile_store(ctx.Mgrad[worldid], output_tile) - - wp.launch_tiled(_cholesky, dim=(d.nworld,), 
inputs=[ctx], block_dim=32) - - -@wp.func -def _rescale(m: types.Model, value: float) -> float: - return value / (m.stat.meaninertia * float(wp.max(1, m.nv))) - - -@wp.func -def _in_bracket(x: float, y: float) -> bool: - return (x < y) and (y < 0.0) or (x > y) and (y > 0.0) - - -def _linesearch(m: types.Model, d: types.Data, ctx: Context): - @wp.kernel - def _gtol(ctx: Context, m: types.Model): - worldid = wp.tid() - smag = ( - wp.math.sqrt(ctx.search_dot[worldid]) - * m.stat.meaninertia - * float(wp.max(1, m.nv)) - ) - ctx.gtol[worldid] = m.opt.tolerance * m.opt.ls_tolerance * smag - - wp.launch(_gtol, dim=(d.nworld,), inputs=[ctx, m]) - - # mv = qM @ search - support.mul_m(m, d, ctx.mv, ctx.search) - - # jv = efc_J @ search - ctx.jv.zero_() - - @wp.kernel - def _jv(ctx: Context, d: types.Data): - efcid, dofid = wp.tid() - - if efcid >= d.nefc_total[0]: - return - - worldid = d.efc_worldid[efcid] - wp.atomic_add( - ctx.jv, - efcid, - d.efc_J[efcid, dofid] * ctx.search[worldid, dofid], - ) - - wp.launch(_jv, dim=(d.njmax, m.nv), inputs=[ctx, d]) - - # prepare quadratics - # quad_gauss = [gauss, search.T @ Ma - search.T @ qfrc_smooth, 0.5 * search.T @ mv] - ctx.quad_gauss.zero_() - - @wp.kernel - def _quad_gauss(ctx: Context, m: types.Model, d: types.Data): - worldid, dofid = wp.tid() - search = ctx.search[worldid, dofid] - quad_gauss = wp.vec3( - ctx.gauss[worldid] / float(m.nv), - search * (ctx.Ma[worldid, dofid] - d.qfrc_smooth[worldid, dofid]), - 0.5 * search * ctx.mv[worldid, dofid], - ) - wp.atomic_add(ctx.quad_gauss, worldid, quad_gauss) - - wp.launch(_quad_gauss, dim=(d.nworld, m.nv), inputs=[ctx, m, d]) - - # quad = [0.5 * Jaref * Jaref * efc_D, jv * Jaref * efc_D, 0.5 * jv * jv * efc_D] - @wp.kernel - def _quad(ctx: Context, d: types.Data): - efcid = wp.tid() - - if efcid >= d.nefc_total[0]: - return - - Jaref = ctx.Jaref[efcid] - jv = ctx.jv[efcid] - efc_D = d.efc_D[efcid] - ctx.quad[efcid][0] = 0.5 * Jaref * Jaref * efc_D - ctx.quad[efcid][1] = jv * 
Jaref * efc_D - ctx.quad[efcid][2] = 0.5 * jv * jv * efc_D - - wp.launch(_quad, dim=(d.njmax), inputs=[ctx, d]) - - # initialize interval - ls_ctx = _create_lscontext(m, d, ctx) - - ls_ctx.p0.alpha.zero_() - _create_lspoint(ls_ctx.p0, m, d, ctx) - - @wp.kernel - def _lo_alpha(lo: LSPoint, p0: LSPoint, ctx: Context): - worldid = wp.tid() - lo.alpha[worldid] = p0.alpha[worldid] - p0.deriv_0[worldid] / p0.deriv_1[worldid] - - wp.launch(_lo_alpha, dim=(d.nworld,), inputs=[ls_ctx.lo, ls_ctx.p0, ctx]) - - _create_lspoint(ls_ctx.lo, m, d, ctx) - - @wp.kernel - def _tree_map(ls_ctx: LSContext): - worldid = wp.tid() - - lesser = float(ls_ctx.lo.deriv_0[worldid] < ls_ctx.p0.deriv_0[worldid]) - not_lesser = 1.0 - lesser - - ls_ctx.hi.alpha[worldid] = ( - lesser * ls_ctx.p0.alpha[worldid] + not_lesser * ls_ctx.lo.alpha[worldid] - ) - ls_ctx.hi.cost[worldid] = ( - lesser * ls_ctx.p0.cost[worldid] + not_lesser * ls_ctx.lo.cost[worldid] - ) - ls_ctx.hi.deriv_0[worldid] = ( - lesser * ls_ctx.p0.deriv_0[worldid] + not_lesser * ls_ctx.lo.deriv_0[worldid] - ) - ls_ctx.hi.deriv_1[worldid] = ( - lesser * ls_ctx.p0.deriv_1[worldid] + not_lesser * ls_ctx.lo.deriv_1[worldid] - ) - - ls_ctx.lo.alpha[worldid] = ( - lesser * ls_ctx.lo.alpha[worldid] + not_lesser * ls_ctx.p0.alpha[worldid] - ) - ls_ctx.lo.cost[worldid] = ( - lesser * ls_ctx.lo.cost[worldid] + not_lesser * ls_ctx.p0.cost[worldid] - ) - ls_ctx.lo.deriv_0[worldid] = ( - lesser * ls_ctx.lo.deriv_0[worldid] + not_lesser * ls_ctx.p0.deriv_0[worldid] - ) - ls_ctx.lo.deriv_1[worldid] = ( - lesser * ls_ctx.lo.deriv_1[worldid] + not_lesser * ls_ctx.p0.deriv_1[worldid] - ) - - wp.launch(_tree_map, dim=(d.nworld,), inputs=[ls_ctx]) - - ls_ctx.swap.fill_(1) - ls_ctx.ls_iter.fill_(0) - - for i in range(m.opt.ls_iterations): - - @wp.kernel - def _alpha_lo_next_hi_next_mid(ls_ctx: LSContext): - worldid = wp.tid() - ls_ctx.lo_next.alpha[worldid] = ( - ls_ctx.lo.alpha[worldid] - - ls_ctx.lo.deriv_0[worldid] / ls_ctx.lo.deriv_1[worldid] - ) - 
ls_ctx.hi_next.alpha[worldid] = ( - ls_ctx.hi.alpha[worldid] - - ls_ctx.hi.deriv_0[worldid] / ls_ctx.hi.deriv_1[worldid] - ) - ls_ctx.mid.alpha[worldid] = 0.5 * ( - ls_ctx.lo.alpha[worldid] + ls_ctx.hi.alpha[worldid] - ) - - wp.launch(_alpha_lo_next_hi_next_mid, dim=(d.nworld,), inputs=[ls_ctx]) - - _create_lspoint(ls_ctx.lo_next, m, d, ctx) - _create_lspoint(ls_ctx.hi_next, m, d, ctx) - _create_lspoint(ls_ctx.mid, m, d, ctx) - - @wp.kernel - def _swap_lo_hi(ls_ctx: LSContext): - worldid = wp.tid() - - ls_ctx.ls_iter[worldid] += 1 - - lo_alpha = ls_ctx.lo.alpha[worldid] - lo_cost = ls_ctx.lo.cost[worldid] - lo_deriv_0 = ls_ctx.lo.deriv_0[worldid] - lo_deriv_1 = ls_ctx.lo.deriv_1[worldid] - - lo_next_alpha = ls_ctx.lo_next.alpha[worldid] - lo_next_cost = ls_ctx.lo_next.cost[worldid] - lo_next_deriv_0 = ls_ctx.lo_next.deriv_0[worldid] - lo_next_deriv_1 = ls_ctx.lo_next.deriv_1[worldid] - - hi_alpha = ls_ctx.hi.alpha[worldid] - hi_cost = ls_ctx.hi.cost[worldid] - hi_deriv_0 = ls_ctx.hi.deriv_0[worldid] - hi_deriv_1 = ls_ctx.hi.deriv_1[worldid] - - hi_next_alpha = ls_ctx.hi_next.alpha[worldid] - hi_next_cost = ls_ctx.hi_next.cost[worldid] - hi_next_deriv_0 = ls_ctx.hi_next.deriv_0[worldid] - hi_next_deriv_1 = ls_ctx.hi_next.deriv_1[worldid] - - mid_alpha = ls_ctx.mid.alpha[worldid] - mid_cost = ls_ctx.mid.cost[worldid] - mid_deriv_0 = ls_ctx.mid.deriv_0[worldid] - mid_deriv_1 = ls_ctx.mid.deriv_1[worldid] - - swap_lo_next = _in_bracket(lo_deriv_0, lo_next_deriv_0) - lo_alpha = ( - float(swap_lo_next) * lo_next_alpha + (1.0 - float(swap_lo_next)) * lo_alpha - ) - lo_cost = ( - float(swap_lo_next) * lo_next_cost + (1.0 - float(swap_lo_next)) * lo_cost - ) - lo_deriv_0 = ( - float(swap_lo_next) * lo_next_deriv_0 + (1.0 - float(swap_lo_next)) * lo_deriv_0 - ) - lo_deriv_1 = ( - float(swap_lo_next) * lo_next_deriv_1 + (1.0 - float(swap_lo_next)) * lo_deriv_1 - ) - - swap_lo_mid = _in_bracket(lo_deriv_0, mid_deriv_0) - lo_alpha = float(swap_lo_mid) * mid_alpha + (1.0 - 
float(swap_lo_mid)) * lo_alpha - lo_cost = float(swap_lo_mid) * mid_cost + (1.0 - float(swap_lo_mid)) * lo_cost - lo_deriv_0 = ( - float(swap_lo_mid) * mid_deriv_0 + (1.0 - float(swap_lo_mid)) * lo_deriv_0 - ) - lo_deriv_1 = ( - float(swap_lo_mid) * mid_deriv_1 + (1.0 - float(swap_lo_mid)) * lo_deriv_1 - ) - - swap_lo_hi_next = _in_bracket(lo_deriv_0, hi_next_deriv_0) - lo_alpha = ( - float(swap_lo_hi_next) * hi_next_alpha - + (1.0 - float(swap_lo_hi_next)) * lo_alpha - ) - lo_cost = ( - float(swap_lo_hi_next) * hi_next_cost + (1.0 - float(swap_lo_hi_next)) * lo_cost - ) - lo_deriv_0 = ( - float(swap_lo_hi_next) * hi_next_deriv_0 - + (1.0 - float(swap_lo_hi_next)) * lo_deriv_0 - ) - lo_deriv_1 = ( - float(swap_lo_hi_next) * hi_next_deriv_1 - + (1.0 - float(swap_lo_hi_next)) * lo_deriv_1 - ) - - swap_hi_next = _in_bracket(hi_deriv_0, hi_next_deriv_0) - hi_alpha = ( - float(swap_hi_next) * hi_next_alpha + (1.0 - float(swap_hi_next)) * hi_alpha - ) - hi_cost = ( - float(swap_hi_next) * hi_next_cost + (1.0 - float(swap_hi_next)) * hi_cost - ) - hi_deriv_0 = ( - float(swap_hi_next) * hi_next_deriv_0 + (1.0 - float(swap_hi_next)) * hi_deriv_0 - ) - hi_deriv_1 = ( - float(swap_hi_next) * hi_next_deriv_1 + (1.0 - float(swap_hi_next)) * hi_deriv_1 - ) - - swap_hi_mid = _in_bracket(hi_deriv_0, mid_deriv_0) - hi_alpha = float(swap_hi_mid) * mid_alpha + (1.0 - float(swap_hi_mid)) * hi_alpha - hi_cost = float(swap_hi_mid) * mid_cost + (1.0 - float(swap_hi_mid)) * hi_cost - hi_deriv_0 = ( - float(swap_hi_mid) * mid_deriv_0 + (1.0 - float(swap_hi_mid)) * hi_deriv_0 - ) - hi_deriv_1 = ( - float(swap_hi_mid) * mid_deriv_1 + (1.0 - float(swap_hi_mid)) * hi_deriv_1 - ) - - swap_hi_lo_next = _in_bracket(hi_deriv_0, lo_next_deriv_0) - hi_alpha = ( - float(swap_hi_lo_next) * lo_next_alpha - + (1.0 - float(swap_hi_lo_next)) * hi_alpha - ) - hi_cost = ( - float(swap_hi_lo_next) * lo_next_cost + (1.0 - float(swap_hi_lo_next)) * hi_cost - ) - hi_deriv_0 = ( - float(swap_hi_lo_next) * 
lo_next_deriv_0 - + (1.0 - float(swap_hi_lo_next)) * hi_deriv_0 - ) - hi_deriv_1 = ( - float(swap_hi_lo_next) * lo_next_deriv_1 - + (1.0 - float(swap_hi_lo_next)) * hi_deriv_1 - ) - - ls_ctx.lo.alpha[worldid] = lo_alpha - ls_ctx.lo.cost[worldid] = lo_cost - ls_ctx.lo.deriv_0[worldid] = lo_deriv_0 - ls_ctx.lo.deriv_1[worldid] = lo_deriv_1 - - ls_ctx.hi.alpha[worldid] = hi_alpha - ls_ctx.hi.cost[worldid] = hi_cost - ls_ctx.hi.deriv_0[worldid] = hi_deriv_0 - ls_ctx.hi.deriv_1[worldid] = hi_deriv_1 - - swap = swap_lo_next or swap_lo_mid or swap_lo_hi_next - swap = swap or swap_hi_next or swap_hi_mid or swap_hi_lo_next - ls_ctx.swap[worldid] = int(swap) - - wp.launch(_swap_lo_hi, dim=(d.nworld,), inputs=[ls_ctx]) - - @wp.kernel - def _done(ls_ctx: LSContext, ctx: Context, m: types.Model, ls_iter: int): - worldid = wp.tid() - done = ls_iter >= m.opt.ls_iterations - done = done or (1 - ls_ctx.swap[worldid]) - done = done or ( - (ls_ctx.lo.deriv_0[worldid] < 0.0) - and (ls_ctx.lo.deriv_0[worldid] > -ctx.gtol[worldid]) - ) - done = done or ( - (ls_ctx.hi.deriv_0[worldid] > 0.0) - and (ls_ctx.hi.deriv_0[worldid] < ctx.gtol[worldid]) - ) - ls_ctx.done[worldid] = int(done) - - wp.launch(_done, dim=(d.nworld,), inputs=[ls_ctx, ctx, m, i]) - # TODO(team): return if all done - - @wp.kernel - def _alpha(ctx: Context, ls_ctx: LSContext): - worldid = wp.tid() - p0_cost = ls_ctx.p0.cost[worldid] - lo_cost = ls_ctx.lo.cost[worldid] - hi_cost = ls_ctx.hi.cost[worldid] - - improvement = float((lo_cost < p0_cost) or (hi_cost < p0_cost)) - lo_hi_cost = float(lo_cost < hi_cost) - ctx.alpha[worldid] = improvement * ( - lo_hi_cost * ls_ctx.lo.alpha[worldid] - + (1.0 - lo_hi_cost) * ls_ctx.hi.alpha[worldid] - ) - - wp.launch(_alpha, dim=(d.nworld,), inputs=[ctx, ls_ctx]) - - @wp.kernel - def _qacc_ma(ctx: Context, d: types.Data): - worldid, dofid = wp.tid() - alpha = ctx.alpha[worldid] - d.qacc[worldid, dofid] += alpha * ctx.search[worldid, dofid] - ctx.Ma[worldid, dofid] += alpha * 
ctx.mv[worldid, dofid] - - wp.launch(_qacc_ma, dim=(d.nworld, m.nv), inputs=[ctx, d]) - - @wp.kernel - def _jaref(ctx: Context, d: types.Data): - efcid = wp.tid() - - if efcid >= d.nefc_total[0]: - return - - worldid = d.efc_worldid[efcid] - ctx.Jaref[efcid] += ctx.alpha[worldid] * ctx.jv[efcid] - - wp.launch(_jaref, dim=(d.njmax,), inputs=[ctx, d]) - - -def solve(m: types.Model, d: types.Data): - """Finds forces that satisfy constraints.""" - - # warmstart - wp.copy(d.qacc, d.qacc_warmstart) - - ctx = _context(m, d) - _create_context(ctx, m, d, grad=True) - - for i in range(m.opt.iterations): - _linesearch(m, d, ctx) - wp.copy(ctx.prev_grad, ctx.grad) - wp.copy(ctx.prev_Mgrad, ctx.Mgrad) - _update_constraint(m, d, ctx) - _update_gradient(m, d, ctx) - - if m.opt.solver == 2: # Newton - ctx.search_dot.zero_() - - @wp.kernel - def _search_newton(ctx: Context): - worldid, dofid = wp.tid() - search = -1.0 * ctx.Mgrad[worldid, dofid] - ctx.search[worldid, dofid] = search - wp.atomic_add(ctx.search_dot, worldid, search * search) - - wp.launch(_search_newton, dim=(d.nworld, m.nv), inputs=[ctx]) - else: # polak-ribiere - ctx.beta_num.zero_() - ctx.beta_den.zero_() - - @wp.kernel - def _beta_num_den(ctx: Context): - worldid, dofid = wp.tid() - prev_Mgrad = ctx.prev_Mgrad[worldid][dofid] - wp.atomic_add( - ctx.beta_num, - worldid, - ctx.grad[worldid, dofid] * (ctx.Mgrad[worldid, dofid] - prev_Mgrad), - ) - wp.atomic_add(ctx.beta_den, worldid, ctx.prev_grad[worldid, dofid] * prev_Mgrad) - - wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[ctx]) - - @wp.kernel - def _beta(ctx: Context): - worldid = wp.tid() - ctx.beta[worldid] = wp.max( - 0.0, ctx.beta_num[worldid] / wp.max(mujoco.mjMINVAL, ctx.beta_den[worldid]) - ) - - wp.launch(_beta, dim=(d.nworld,), inputs=[ctx]) - - ctx.search_dot.zero_() - - @wp.kernel - def _search_cg(ctx: Context): - worldid, dofid = wp.tid() - search = ( - -1.0 * ctx.Mgrad[worldid, dofid] - + ctx.beta[worldid] * ctx.search[worldid, dofid] - ) - 
ctx.search[worldid, dofid] = search - wp.atomic_add(ctx.search_dot, worldid, search * search) - - wp.launch(_search_cg, dim=(d.nworld, m.nv), inputs=[ctx]) - - @wp.kernel - def _done(ctx: Context, m: types.Model, solver_niter: int): - worldid = wp.tid() - improvement = _rescale(m, ctx.prev_cost[worldid] - ctx.cost[worldid]) - gradient = _rescale(m, wp.math.sqrt(ctx.grad_dot[worldid])) - done = solver_niter >= m.opt.iterations - done = done or (improvement < m.opt.tolerance) - done = done or (gradient < m.opt.tolerance) - ctx.done[worldid] = int(done) - - wp.launch(_done, dim=(d.nworld,), inputs=[ctx, m, i]) - # TODO(team): return if all done - - wp.copy(d.qacc_warmstart, d.qacc) diff --git a/mujoco/mjx/_src/support.py b/mujoco/mjx/_src/support.py deleted file mode 100644 index c6888d15..00000000 --- a/mujoco/mjx/_src/support.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2025 The Physics-Next Project Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
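The conjugate-gradient branch of `solve` above refreshes the search direction with a Polak-Ribiere beta. A standalone sketch of that update for a single world (NumPy, illustrative names; the solver preconditions the gradient with the inverse inertia matrix, abstracted here as a given `Mgrad`):

```python
import numpy as np

def pr_search_direction(grad, Mgrad, prev_grad, prev_Mgrad, search, minval=1e-15):
    """Polak-Ribiere CG direction update, mirroring the _beta_num_den /
    _beta / _search_cg kernels above (single world, NumPy sketch)."""
    beta_num = grad @ (Mgrad - prev_Mgrad)
    beta_den = max(prev_grad @ prev_Mgrad, minval)  # guard against division by zero
    beta = max(0.0, beta_num / beta_den)            # restart (beta = 0) if negative
    return -Mgrad + beta * search
```

Clamping beta at zero restarts the method with a plain (preconditioned) gradient step when conjugacy is lost, which is the standard Polak-Ribiere-plus safeguard.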
-# ============================================================================== - -import mujoco -import warp as wp -from .types import Model -from .types import Data -from .types import array2df - - -def is_sparse(m: mujoco.MjModel): - if m.opt.jacobian == mujoco.mjtJacobian.mjJAC_AUTO: - return m.nv >= 60 - return m.opt.jacobian == mujoco.mjtJacobian.mjJAC_SPARSE - - -def mul_m( - m: Model, - d: Data, - res: wp.array(ndim=2, dtype=wp.float32), - vec: wp.array(ndim=2, dtype=wp.float32), -): - """Multiply vector by inertia matrix.""" - - if not m.opt.is_sparse: - # TODO(team): tile_matmul - res.zero_() - - @wp.kernel - def _mul_m_dense( - d: Data, - res: wp.array(ndim=2, dtype=wp.float32), - vec: wp.array(ndim=2, dtype=wp.float32), - ): - worldid, rowid, colid = wp.tid() - wp.atomic_add( - res[worldid], rowid, d.qM[worldid, rowid, colid] * vec[worldid, colid] - ) - - wp.launch(_mul_m_dense, dim=(d.nworld, m.nv, m.nv), inputs=[d, res, vec]) - else: - - @wp.kernel - def _mul_m_sparse_diag( - m: Model, - d: Data, - res: wp.array(ndim=2, dtype=wp.float32), - vec: wp.array(ndim=2, dtype=wp.float32), - ): - worldid, dofid = wp.tid() - res[worldid, dofid] = d.qM[worldid, 0, m.dof_Madr[dofid]] * vec[worldid, dofid] - - wp.launch(_mul_m_sparse_diag, dim=(d.nworld, m.nv), inputs=[m, d, res, vec]) - - @wp.kernel - def _mul_m_sparse_ij( - m: Model, - d: Data, - res: wp.array(ndim=2, dtype=wp.float32), - vec: wp.array(ndim=2, dtype=wp.float32), - ): - worldid, elementid = wp.tid() - i = m.qM_i[elementid] - j = m.qM_j[elementid] - madr_ij = m.qM_madr_ij[elementid] - - qM = d.qM[worldid, 0, madr_ij] - - wp.atomic_add(res[worldid], i, qM * vec[worldid, j]) - wp.atomic_add(res[worldid], j, qM * vec[worldid, i]) - - wp.launch( - _mul_m_sparse_ij, dim=(d.nworld, m.qM_madr_ij.size), inputs=[m, d, res, vec] - ) - - -@wp.kernel -def process_level( - body_tree: wp.array(ndim=1, dtype=int), - body_parentid: wp.array(ndim=1, dtype=int), - dof_bodyid: wp.array(ndim=1, dtype=int), - mask: 
wp.array2d(dtype=wp.bool), - beg: int, -): - dofid, tid_y = wp.tid() - j = beg + tid_y - el = body_tree[j] - parent_id = body_parentid[el] - parent_val = mask[dofid, parent_id] - mask[dofid, el] = parent_val or (dof_bodyid[dofid] == el) - - -@wp.kernel -def compute_qfrc( - d: Data, - m: Model, - mask: wp.array2d(dtype=wp.bool), - qfrc_total: array2df, -): - worldid, dofid = wp.tid() - accumul = float(0.0) - cdof_vec = d.cdof[worldid, dofid] - rotational_cdof = wp.vec3(cdof_vec[0], cdof_vec[1], cdof_vec[2]) - - jac = wp.spatial_vector( - cdof_vec[3], cdof_vec[4], cdof_vec[5], cdof_vec[0], cdof_vec[1], cdof_vec[2] - ) - - for bodyid in range(m.nbody): - if mask[dofid, bodyid]: - offset = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]] - cross_term = wp.cross(rotational_cdof, offset) - accumul += wp.dot(jac, d.xfrc_applied[worldid, bodyid]) + wp.dot( - cross_term, - wp.vec3( - d.xfrc_applied[worldid, bodyid][0], - d.xfrc_applied[worldid, bodyid][1], - d.xfrc_applied[worldid, bodyid][2], - ), - ) - - qfrc_total[worldid, dofid] = accumul - - -def xfrc_accumulate(m: Model, d: Data) -> array2df: - body_treeadr_np = m.body_treeadr.numpy() - mask = wp.zeros((m.nv, m.nbody), dtype=wp.bool) - - for i in range(len(body_treeadr_np)): - beg = body_treeadr_np[i] - end = m.nbody if i == len(body_treeadr_np) - 1 else body_treeadr_np[i + 1] - - if end > beg: - wp.launch( - kernel=process_level, - dim=[m.nv, (end - beg)], - inputs=[m.body_tree, m.body_parentid, m.dof_bodyid, mask, beg], - ) - - qfrc_total = wp.zeros((d.nworld, m.nv), dtype=float) - - wp.launch(kernel=compute_qfrc, dim=(d.nworld, m.nv), inputs=[d, m, mask, qfrc_total]) - - return qfrc_total diff --git a/mujoco/mjx/_src/types.py b/mujoco/mjx/_src/types.py deleted file mode 100644 index dfd9a1b3..00000000 --- a/mujoco/mjx/_src/types.py +++ /dev/null @@ -1,352 +0,0 @@ -# Copyright 2025 The Physics-Next Project Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you 
may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import warp as wp -import enum -import mujoco - -MJ_MINVAL = mujoco.mjMINVAL -MJ_MINIMP = mujoco.mjMINIMP # minimum constraint impedance -MJ_MAXIMP = mujoco.mjMAXIMP # maximum constraint impedance -MJ_NREF = mujoco.mjNREF -MJ_NIMP = mujoco.mjNIMP - - -class DisableBit(enum.IntFlag): - """Disable default feature bitflags. - - Members: - CONSTRAINT: entire constraint solver - EQUALITY: equality constraints - FRICTIONLOSS: joint and tendon frictionloss constraints - LIMIT: joint and tendon limit constraints - CONTACT: contact constraints - PASSIVE: passive forces - GRAVITY: gravitational forces - CLAMPCTRL: clamp control to specified range - WARMSTART: warmstart constraint solver - ACTUATION: apply actuation forces - REFSAFE: integrator safety: make ref[0]>=2*timestep - SENSOR: sensors - """ - - CONSTRAINT = mujoco.mjtDisableBit.mjDSBL_CONSTRAINT - EQUALITY = mujoco.mjtDisableBit.mjDSBL_EQUALITY - FRICTIONLOSS = mujoco.mjtDisableBit.mjDSBL_FRICTIONLOSS - LIMIT = mujoco.mjtDisableBit.mjDSBL_LIMIT - CONTACT = mujoco.mjtDisableBit.mjDSBL_CONTACT - PASSIVE = mujoco.mjtDisableBit.mjDSBL_PASSIVE - GRAVITY = mujoco.mjtDisableBit.mjDSBL_GRAVITY - CLAMPCTRL = mujoco.mjtDisableBit.mjDSBL_CLAMPCTRL - WARMSTART = mujoco.mjtDisableBit.mjDSBL_WARMSTART - ACTUATION = mujoco.mjtDisableBit.mjDSBL_ACTUATION - REFSAFE = mujoco.mjtDisableBit.mjDSBL_REFSAFE - SENSOR = mujoco.mjtDisableBit.mjDSBL_SENSOR - EULERDAMP = 
mujoco.mjtDisableBit.mjDSBL_EULERDAMP - FILTERPARENT = mujoco.mjtDisableBit.mjDSBL_FILTERPARENT - # unsupported: MIDPHASE - - -class TrnType(enum.IntEnum): - """Type of actuator transmission. - - Members: - JOINT: force on joint - JOINTINPARENT: force on joint, expressed in parent frame - TENDON: force on tendon (unsupported) - SITE: force on site (unsupported) - """ - - JOINT = mujoco.mjtTrn.mjTRN_JOINT - JOINTINPARENT = mujoco.mjtTrn.mjTRN_JOINTINPARENT - # unsupported: SITE, TENDON, SLIDERCRANK, BODY - - -class DynType(enum.IntEnum): - """Type of actuator dynamics. - - Members: - NONE: no internal dynamics; ctrl specifies force - INTEGRATOR: integrator: da/dt = u - FILTER: linear filter: da/dt = (u-a) / tau - FILTEREXACT: linear filter: da/dt = (u-a) / tau, with exact integration - MUSCLE: piece-wise linear filter with two time constants - """ - - NONE = mujoco.mjtDyn.mjDYN_NONE - INTEGRATOR = mujoco.mjtDyn.mjDYN_INTEGRATOR - FILTER = mujoco.mjtDyn.mjDYN_FILTER - FILTEREXACT = mujoco.mjtDyn.mjDYN_FILTEREXACT - MUSCLE = mujoco.mjtDyn.mjDYN_MUSCLE - # unsupported: USER - - -class JointType(enum.IntEnum): - """Type of degree of freedom. - - Members: - FREE: global position and orientation (quat) (7,) - BALL: orientation (quat) relative to parent (4,) - SLIDE: sliding distance along body-fixed axis (1,) - HINGE: rotation angle (rad) around body-fixed axis (1,) - """ - - FREE = mujoco.mjtJoint.mjJNT_FREE - BALL = mujoco.mjtJoint.mjJNT_BALL - SLIDE = mujoco.mjtJoint.mjJNT_SLIDE - HINGE = mujoco.mjtJoint.mjJNT_HINGE - - def dof_width(self) -> int: - return {0: 6, 1: 3, 2: 1, 3: 1}[self.value] - - def qpos_width(self) -> int: - return {0: 7, 1: 4, 2: 1, 3: 1}[self.value] - - -class ConeType(enum.IntEnum): - """Type of friction cone. 
- - Members: - PYRAMIDAL: pyramidal - ELLIPTIC: elliptic - """ - - PYRAMIDAL = mujoco.mjtCone.mjCONE_PYRAMIDAL - ELLIPTIC = mujoco.mjtCone.mjCONE_ELLIPTIC - - -class vec5f(wp.types.vector(length=5, dtype=wp.float32)): - pass - - -class vec10f(wp.types.vector(length=10, dtype=wp.float32)): - pass - - -vec5 = vec5f -vec10 = vec10f -array2df = wp.array2d(dtype=wp.float32) -array3df = wp.array3d(dtype=wp.float32) - - -@wp.struct -class Option: - timestep: float - tolerance: float - ls_tolerance: float - gravity: wp.vec3 - cone: int # mjtCone - solver: int # mjtSolver - iterations: int - ls_iterations: int - disableflags: int - integrator: int # mjtIntegrator - impratio: wp.float32 - is_sparse: bool # warp only - - -@wp.struct -class Statistic: - meaninertia: float - - -@wp.struct -class Model: - nq: int - nv: int - na: int - nu: int - nbody: int - njnt: int - ngeom: int - nsite: int - nmocap: int - nM: int - opt: Option - stat: Statistic - qpos0: wp.array(dtype=wp.float32, ndim=1) - qpos_spring: wp.array(dtype=wp.float32, ndim=1) - body_tree: wp.array(dtype=wp.int32, ndim=1) # warp only - body_treeadr: wp.array(dtype=wp.int32, ndim=1) # warp only - qM_i: wp.array(dtype=wp.int32, ndim=1) # warp only - qM_j: wp.array(dtype=wp.int32, ndim=1) # warp only - qM_madr_ij: wp.array(dtype=wp.int32, ndim=1) # warp only - qLD_update_tree: wp.array(dtype=wp.vec3i, ndim=1) # warp only - qLD_update_treeadr: wp.array(dtype=wp.int32, ndim=1) # warp only - qLD_tile: wp.array(dtype=wp.int32, ndim=1) # warp only - qLD_tileadr: wp.array(dtype=wp.int32, ndim=1) # warp only - qLD_tilesize: wp.array(dtype=wp.int32, ndim=1) # warp only - body_dofadr: wp.array(dtype=wp.int32, ndim=1) - body_dofnum: wp.array(dtype=wp.int32, ndim=1) - body_jntadr: wp.array(dtype=wp.int32, ndim=1) - body_jntnum: wp.array(dtype=wp.int32, ndim=1) - body_parentid: wp.array(dtype=wp.int32, ndim=1) - body_mocapid: wp.array(dtype=wp.int32, ndim=1) - body_weldid: wp.array(dtype=wp.int32, ndim=1) - body_pos: 
wp.array(dtype=wp.vec3, ndim=1) - body_quat: wp.array(dtype=wp.quat, ndim=1) - body_ipos: wp.array(dtype=wp.vec3, ndim=1) - body_iquat: wp.array(dtype=wp.quat, ndim=1) - body_rootid: wp.array(dtype=wp.int32, ndim=1) - body_inertia: wp.array(dtype=wp.vec3, ndim=1) - body_mass: wp.array(dtype=wp.float32, ndim=1) - body_invweight0: wp.array(dtype=wp.float32, ndim=2) - jnt_bodyid: wp.array(dtype=wp.int32, ndim=1) - jnt_limited: wp.array(dtype=wp.int32, ndim=1) - jnt_limited_slide_hinge_adr: wp.array(dtype=wp.int32, ndim=1) # warp only - jnt_solref: wp.array(dtype=wp.vec2f, ndim=1) - jnt_solimp: wp.array(dtype=vec5, ndim=1) - jnt_type: wp.array(dtype=wp.int32, ndim=1) - jnt_qposadr: wp.array(dtype=wp.int32, ndim=1) - jnt_dofadr: wp.array(dtype=wp.int32, ndim=1) - jnt_axis: wp.array(dtype=wp.vec3, ndim=1) - jnt_pos: wp.array(dtype=wp.vec3, ndim=1) - jnt_range: wp.array(dtype=wp.float32, ndim=2) - jnt_margin: wp.array(dtype=wp.float32, ndim=1) - jnt_stiffness: wp.array(dtype=wp.float32, ndim=1) - jnt_actfrclimited: wp.array(dtype=wp.bool, ndim=1) - jnt_actfrcrange: wp.array(dtype=wp.vec2, ndim=1) - geom_bodyid: wp.array(dtype=wp.int32, ndim=1) - geom_pos: wp.array(dtype=wp.vec3, ndim=1) - geom_quat: wp.array(dtype=wp.quat, ndim=1) - site_pos: wp.array(dtype=wp.vec3, ndim=1) - site_quat: wp.array(dtype=wp.quat, ndim=1) - site_bodyid: wp.array(dtype=wp.int32, ndim=1) - dof_bodyid: wp.array(dtype=wp.int32, ndim=1) - dof_jntid: wp.array(dtype=wp.int32, ndim=1) - dof_parentid: wp.array(dtype=wp.int32, ndim=1) - dof_Madr: wp.array(dtype=wp.int32, ndim=1) - dof_armature: wp.array(dtype=wp.float32, ndim=1) - dof_invweight0: wp.array(dtype=wp.float32, ndim=1) - dof_damping: wp.array(dtype=wp.float32, ndim=1) - dof_tri_row: wp.array(dtype=wp.int32, ndim=1) # warp only - dof_tri_col: wp.array(dtype=wp.int32, ndim=1) # warp only - actuator_trntype: wp.array(dtype=wp.int32, ndim=1) - actuator_trnid: wp.array(dtype=wp.int32, ndim=2) - actuator_ctrllimited: wp.array(dtype=wp.bool, 
ndim=1) - actuator_ctrlrange: wp.array(dtype=wp.vec2, ndim=1) - actuator_forcelimited: wp.array(dtype=wp.bool, ndim=1) - actuator_forcerange: wp.array(dtype=wp.vec2, ndim=1) - actuator_gainprm: wp.array(dtype=wp.float32, ndim=2) - actuator_biasprm: wp.array(dtype=wp.float32, ndim=2) - actuator_gear: wp.array(dtype=wp.spatial_vector, ndim=1) - actuator_actlimited: wp.array(dtype=wp.bool, ndim=1) - actuator_actrange: wp.array(dtype=wp.vec2, ndim=1) - actuator_actadr: wp.array(dtype=wp.int32, ndim=1) - actuator_dyntype: wp.array(dtype=wp.int32, ndim=1) - actuator_dynprm: wp.array(dtype=vec10f, ndim=1) - opt: Option - - -@wp.struct -class Contact: - dist: wp.array(dtype=wp.float32, ndim=1) - pos: wp.array(dtype=wp.vec3f, ndim=1) - frame: wp.array(dtype=wp.mat33f, ndim=1) - includemargin: wp.array(dtype=wp.float32, ndim=1) - friction: wp.array(dtype=vec5, ndim=1) - solref: wp.array(dtype=wp.vec2f, ndim=1) - solreffriction: wp.array(dtype=wp.vec2f, ndim=1) - solimp: wp.array(dtype=vec5, ndim=1) - dim: wp.array(dtype=wp.int32, ndim=1) - geom: wp.array(dtype=wp.vec2i, ndim=1) - efc_address: wp.array(dtype=wp.int32, ndim=1) - worldid: wp.array(dtype=wp.int32, ndim=1) - - -@wp.struct -class Data: - nworld: int - ncon_total: wp.array(dtype=wp.int32, ndim=1) # warp only - nefc_total: wp.array(dtype=wp.int32, ndim=1) # warp only - nconmax: int - njmax: int - time: float - qpos: wp.array(dtype=wp.float32, ndim=2) - qvel: wp.array(dtype=wp.float32, ndim=2) - qacc_warmstart: wp.array(dtype=wp.float32, ndim=2) - qfrc_applied: wp.array(dtype=wp.float32, ndim=2) - ncon: int - nl: int - nefc: wp.array(dtype=wp.int32, ndim=1) - ctrl: wp.array(dtype=wp.float32, ndim=2) - mocap_pos: wp.array(dtype=wp.vec3, ndim=2) - mocap_quat: wp.array(dtype=wp.quat, ndim=2) - qacc: wp.array(dtype=wp.float32, ndim=2) - xanchor: wp.array(dtype=wp.vec3, ndim=2) - xaxis: wp.array(dtype=wp.vec3, ndim=2) - xmat: wp.array(dtype=wp.mat33, ndim=2) - xpos: wp.array(dtype=wp.vec3, ndim=2) - xquat: 
wp.array(dtype=wp.quat, ndim=2) - xipos: wp.array(dtype=wp.vec3, ndim=2) - ximat: wp.array(dtype=wp.mat33, ndim=2) - subtree_com: wp.array(dtype=wp.vec3, ndim=2) - geom_xpos: wp.array(dtype=wp.vec3, ndim=2) - geom_xmat: wp.array(dtype=wp.mat33, ndim=2) - site_xpos: wp.array(dtype=wp.vec3, ndim=2) - site_xmat: wp.array(dtype=wp.mat33, ndim=2) - cinert: wp.array(dtype=vec10, ndim=2) - cdof: wp.array(dtype=wp.spatial_vector, ndim=2) - crb: wp.array(dtype=vec10, ndim=2) - qM: wp.array(dtype=wp.float32, ndim=3) - qLD: wp.array(dtype=wp.float32, ndim=3) - act: wp.array(dtype=wp.float32, ndim=2) - act_dot: wp.array(dtype=wp.float32, ndim=2) - qLDiagInv: wp.array(dtype=wp.float32, ndim=2) - actuator_velocity: wp.array(dtype=wp.float32, ndim=2) - actuator_force: wp.array(dtype=wp.float32, ndim=2) - actuator_length: wp.array(dtype=wp.float32, ndim=2) - actuator_moment: wp.array(dtype=wp.float32, ndim=3) - cvel: wp.array(dtype=wp.spatial_vector, ndim=2) - cdof_dot: wp.array(dtype=wp.spatial_vector, ndim=2) - qfrc_applied: wp.array(dtype=wp.float32, ndim=2) - qfrc_bias: wp.array(dtype=wp.float32, ndim=2) - qfrc_constraint: wp.array(dtype=wp.float32, ndim=2) - qfrc_passive: wp.array(dtype=wp.float32, ndim=2) - qfrc_spring: wp.array(dtype=wp.float32, ndim=2) - qfrc_damper: wp.array(dtype=wp.float32, ndim=2) - qfrc_actuator: wp.array(dtype=wp.float32, ndim=2) - qfrc_smooth: wp.array(dtype=wp.float32, ndim=2) - qacc_smooth: wp.array(dtype=wp.float32, ndim=2) - qfrc_constraint: wp.array(dtype=wp.float32, ndim=2) - efc_J: wp.array(dtype=wp.float32, ndim=2) - efc_D: wp.array(dtype=wp.float32, ndim=1) - efc_pos: wp.array(dtype=wp.float32, ndim=1) - efc_aref: wp.array(dtype=wp.float32, ndim=1) - efc_force: wp.array(dtype=wp.float32, ndim=1) - efc_margin: wp.array(dtype=wp.float32, ndim=1) - efc_worldid: wp.array(dtype=wp.int32, ndim=1) # warp only - xfrc_applied: wp.array(dtype=wp.spatial_vector, ndim=2) - contact: Contact - - # temp arrays - qfrc_integration: 
wp.array(dtype=wp.float32, ndim=2) - qacc_integration: wp.array(dtype=wp.float32, ndim=2) - - qM_integration: wp.array(dtype=wp.float32, ndim=3) - qLD_integration: wp.array(dtype=wp.float32, ndim=3) - qLDiagInv_integration: wp.array(dtype=wp.float32, ndim=2) - - # broadphase arrays - max_num_overlaps_per_world: int - broadphase_pairs: wp.array(dtype=wp.vec2i, ndim=2) - result_count: wp.array(dtype=wp.int32, ndim=1) - boxes_sorted: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=2) - data_start: wp.array(dtype=wp.float32, ndim=2) - data_end: wp.array(dtype=wp.float32, ndim=2) - data_indexer: wp.array(dtype=wp.int32, ndim=2) - ranges: wp.array(dtype=wp.int32, ndim=2) - cumulative_sum: wp.array(dtype=wp.int32, ndim=1) - segment_indices: wp.array(dtype=wp.int32, ndim=1) - geom_aabb: wp.array(dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1) diff --git a/mujoco/mjx/testspeed.py b/mujoco/mjx/testspeed.py deleted file mode 100644 index 6e544b25..00000000 --- a/mujoco/mjx/testspeed.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2025 The Physics-Next Project Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -"""Run benchmarks on various devices.""" - -import inspect -from typing import Sequence - -from absl import app -from absl import flags -from etils import epath -import warp as wp - -import mujoco -from mujoco import mjx - -_FUNCTION = flags.DEFINE_enum( - "function", - "kinematics", - [n for n, _ in inspect.getmembers(mjx, inspect.isfunction)], - "the function to run", -) -_MJCF = flags.DEFINE_string( - "mjcf", None, "path to model `.xml` or `.mjb`", required=True -) -_BASE_PATH = flags.DEFINE_string( - "base_path", None, "base path, defaults to mujoco.mjx resource path" -) -_NSTEP = flags.DEFINE_integer("nstep", 1000, "number of steps per rollout") -_BATCH_SIZE = flags.DEFINE_integer("batch_size", 4096, "number of parallel rollouts") -_UNROLL = flags.DEFINE_integer("unroll", 1, "loop unroll length") -_SOLVER = flags.DEFINE_enum("solver", "cg", ["cg", "newton"], "constraint solver") -_ITERATIONS = flags.DEFINE_integer("iterations", 1, "number of solver iterations") -_LS_ITERATIONS = flags.DEFINE_integer( - "ls_iterations", 4, "number of linesearch iterations" -) -_IS_SPARSE = flags.DEFINE_bool( - "is_sparse", True, "if model should create sparse mass matrices" -) -_NEFC_TOTAL = flags.DEFINE_integer( - "nefc_total", 0, "total number of efc for batch of worlds" -) -_OUTPUT = flags.DEFINE_enum( - "output", "text", ["text", "tsv"], "format to print results" -) - - -def _main(argv: Sequence[str]): - """Runs testpeed function.""" - wp.init() - - path = epath.resource_path("mujoco.mjx") / "test_data" - path = _BASE_PATH.value or path - f = epath.Path(path) / _MJCF.value - if f.suffix == ".mjb": - m = mujoco.MjModel.from_binary_path(f.as_posix()) - else: - m = mujoco.MjModel.from_xml_path(f.as_posix()) - - if _IS_SPARSE.value: - m.opt.jacobian = mujoco.mjtJacobian.mjJAC_SPARSE - else: - m.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE - - print( - f"Model nbody: {m.nbody} nv: {m.nv} ngeom: 
{m.ngeom} is_sparse: {_IS_SPARSE.value}" - ) - print(f"Rolling out {_NSTEP.value} steps at dt = {m.opt.timestep:.3f}...") - jit_time, run_time, steps = mjx.benchmark( - mjx.__dict__[_FUNCTION.value], - m, - _NSTEP.value, - _BATCH_SIZE.value, - _UNROLL.value, - _SOLVER.value, - _ITERATIONS.value, - _LS_ITERATIONS.value, - _NEFC_TOTAL.value, - ) - - name = argv[0] - if _OUTPUT.value == "text": - print(f""" -Summary for {_BATCH_SIZE.value} parallel rollouts - - Total JIT time: {jit_time:.2f} s - Total simulation time: {run_time:.2f} s - Total steps per second: {steps / run_time:,.0f} - Total realtime factor: {steps * m.opt.timestep / run_time:,.2f} x - Total time per step: {1e6 * run_time / steps:.2f} µs""") - elif _OUTPUT.value == "tsv": - name = name.split("/")[-1].replace("testspeed_", "") - print(f"{name}\tjit: {jit_time:.2f}s\tsteps/second: {steps / run_time:.0f}") - - -def main(): - app.run(_main) - - -if __name__ == "__main__": - main() diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py new file mode 100644 index 00000000..123859c1 --- /dev/null +++ b/mujoco_warp/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Public API for MJWarp.""" + +from ._src.collision_driver import collision as collision +from ._src.collision_driver import nxn_broadphase as nxn_broadphase +from ._src.collision_driver import sap_broadphase as sap_broadphase +from ._src.constraint import make_constraint as make_constraint +from ._src.forward import euler as euler +from ._src.forward import forward as forward +from ._src.forward import fwd_acceleration as fwd_acceleration +from ._src.forward import fwd_actuation as fwd_actuation +from ._src.forward import fwd_position as fwd_position +from ._src.forward import fwd_velocity as fwd_velocity +from ._src.forward import implicit as implicit +from ._src.forward import step as step +from ._src.io import get_data_into as get_data_into +from ._src.io import make_data as make_data +from ._src.io import put_data as put_data +from ._src.io import put_model as put_model +from ._src.passive import passive as passive +from ._src.smooth import com_pos as com_pos +from ._src.smooth import com_vel as com_vel +from ._src.smooth import crb as crb +from ._src.smooth import factor_m as factor_m +from ._src.smooth import kinematics as kinematics +from ._src.smooth import rne as rne +from ._src.smooth import solve_m as solve_m +from ._src.smooth import transmission as transmission +from ._src.solver import solve as solve +from ._src.support import is_sparse as is_sparse +from ._src.support import mul_m as mul_m +from ._src.support import xfrc_accumulate as xfrc_accumulate +from ._src.test_util import benchmark as benchmark +from ._src.types import ConeType as ConeType +from ._src.types import Contact as Contact +from ._src.types import Data as Data +from ._src.types import DisableBit as DisableBit +from ._src.types import DynType as DynType +from ._src.types import JointType as JointType +from ._src.types import Model as Model +from ._src.types import Option as Option +from ._src.types 
import Statistic as Statistic +from ._src.types import TrnType as TrnType diff --git a/mujoco_warp/_src/__init__.py b/mujoco_warp/_src/__init__.py new file mode 100644 index 00000000..2b276beb --- /dev/null +++ b/mujoco_warp/_src/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/mujoco_warp/_src/broad_phase_test.py b/mujoco_warp/_src/broad_phase_test.py new file mode 100644 index 00000000..47a6616f --- /dev/null +++ b/mujoco_warp/_src/broad_phase_test.py @@ -0,0 +1,250 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for broadphase functions.""" + +import mujoco +import numpy as np +import warp as wp +from absl.testing import absltest + +import mujoco_warp as mjwarp + +from . 
import collision_driver + + +def _load_from_string(xml: str, keyframe: int = -1): + mjm = mujoco.MjModel.from_xml_string(xml) + mjd = mujoco.MjData(mjm) + if keyframe > -1: + mujoco.mj_resetDataKeyframe(mjm, mjd, keyframe) + mujoco.mj_forward(mjm, mjd) + + m = mjwarp.put_model(mjm) + d = mjwarp.put_data(mjm, mjd) + + return mjm, mjd, m, d + + +class BroadphaseTest(absltest.TestCase): + def test_sap_broadphase(self): + """Tests sap_broadphase.""" + + _SAP_MODEL = """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + + collision_pair = [ + (0, 1), + (0, 2), + (0, 3), + (0, 4), + (0, 6), + (1, 2), + (1, 4), + (1, 5), + (1, 6), + (1, 7), + (2, 3), + (2, 4), + (2, 5), + (2, 6), + (2, 7), + (3, 4), + (4, 6), + (5, 7), + (6, 7), + ] + + _, _, m, d = _load_from_string(_SAP_MODEL) + + mjwarp.sap_broadphase(m, d) + + ncollision = d.ncollision.numpy()[0] + np.testing.assert_equal(ncollision, len(collision_pair), "ncollision") + + for i in range(ncollision): + pair = d.collision_pair.numpy()[i] + if pair[0] > pair[1]: + pair_tuple = (int(pair[1]), int(pair[0])) + else: + pair_tuple = (int(pair[0]), int(pair[1])) + + np.testing.assert_equal( + pair_tuple in collision_pair, + True, + f"geom pair {pair_tuple} not found in brute force results", + ) + + # TODO(team): test DisableBit.FILTERPARENT + + # TODO(team): test DisableBit.FILTERPARENT + + def test_nxn_broadphase(self): + """Tests nxn_broadphase.""" + + _NXN_MODEL = """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + # one world and zero collisions + mjm, _, m, d0 = _load_from_string(_NXN_MODEL, keyframe=0) + collision_driver.nxn_broadphase(m, d0) + np.testing.assert_allclose(d0.ncollision.numpy()[0], 0) + + # one world and one collision + _, mjd1, _, d1 = _load_from_string(_NXN_MODEL, keyframe=1) + collision_driver.nxn_broadphase(m, d1) + + np.testing.assert_allclose(d1.ncollision.numpy()[0], 1) + np.testing.assert_allclose(d1.collision_pair.numpy()[0][0], 0) + 
np.testing.assert_allclose(d1.collision_pair.numpy()[0][1], 1) + + # one world and three collisions + _, mjd2, _, d2 = _load_from_string(_NXN_MODEL, keyframe=2) + collision_driver.nxn_broadphase(m, d2) + np.testing.assert_allclose(d2.ncollision.numpy()[0], 3) + np.testing.assert_allclose(d2.collision_pair.numpy()[0][0], 0) + np.testing.assert_allclose(d2.collision_pair.numpy()[0][1], 1) + np.testing.assert_allclose(d2.collision_pair.numpy()[1][0], 0) + np.testing.assert_allclose(d2.collision_pair.numpy()[1][1], 2) + np.testing.assert_allclose(d2.collision_pair.numpy()[2][0], 1) + np.testing.assert_allclose(d2.collision_pair.numpy()[2][1], 2) + + # two worlds and four collisions + d3 = mjwarp.make_data(mjm, nworld=2) + d3.geom_xpos = wp.array( + np.vstack( + [np.expand_dims(mjd1.geom_xpos, axis=0), np.expand_dims(mjd2.geom_xpos, axis=0)] + ), + dtype=wp.vec3, + ) + + collision_driver.nxn_broadphase(m, d3) + np.testing.assert_allclose(d3.ncollision.numpy()[0], 4) + np.testing.assert_allclose(d3.collision_pair.numpy()[0][0], 0) + np.testing.assert_allclose(d3.collision_pair.numpy()[0][1], 1) + np.testing.assert_allclose(d3.collision_pair.numpy()[1][0], 0) + np.testing.assert_allclose(d3.collision_pair.numpy()[1][1], 1) + np.testing.assert_allclose(d3.collision_pair.numpy()[2][0], 0) + np.testing.assert_allclose(d3.collision_pair.numpy()[2][1], 2) + np.testing.assert_allclose(d3.collision_pair.numpy()[3][0], 1) + np.testing.assert_allclose(d3.collision_pair.numpy()[3][1], 2) + + # one world and zero collisions: contype and conaffinity incompatibility + _, _, m4, d4 = _load_from_string(_NXN_MODEL, keyframe=1) + m4.geom_contype = wp.array(np.array([0, 0, 0]), dtype=wp.int32) + m4.geom_conaffinity = wp.array(np.array([1, 1, 1]), dtype=wp.int32) + collision_driver.nxn_broadphase(m4, d4) + np.testing.assert_allclose(d4.ncollision.numpy()[0], 0) + + # one world and one collision: geomtype ordering + _, _, _, d5 = _load_from_string(_NXN_MODEL, keyframe=3) + 
collision_driver.nxn_broadphase(m, d5) + np.testing.assert_allclose(d5.ncollision.numpy()[0], 1) + np.testing.assert_allclose(d5.collision_pair.numpy()[0][0], 3) + np.testing.assert_allclose(d5.collision_pair.numpy()[0][1], 2) + + # TODO(team): test margin + # TODO(team): test DisableBit.FILTERPARENT + + +if __name__ == "__main__": + wp.init() + absltest.main() diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py new file mode 100644 index 00000000..28f300d6 --- /dev/null +++ b/mujoco_warp/_src/collision_convex.py @@ -0,0 +1,26 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import warp as wp + +from .types import Data +from .types import Model + +# XXX disable backward pass codegen globally for now +wp.config.enable_backward = False + + +def narrowphase(m: Model, d: Data): + pass diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py new file mode 100644 index 00000000..4f5f7afe --- /dev/null +++ b/mujoco_warp/_src/collision_driver.py @@ -0,0 +1,551 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import warp as wp + +from .types import MJ_MINVAL +from .types import Data +from .types import DisableBit +from .types import Model +from .types import vec5 +from .warp_util import event_scope + + +@wp.func +def _geom_filter(m: Model, geom1: int, geom2: int, filterparent: bool) -> bool: + bodyid1 = m.geom_bodyid[geom1] + bodyid2 = m.geom_bodyid[geom2] + contype1 = m.geom_contype[geom1] + contype2 = m.geom_contype[geom2] + conaffinity1 = m.geom_conaffinity[geom1] + conaffinity2 = m.geom_conaffinity[geom2] + weldid1 = m.body_weldid[bodyid1] + weldid2 = m.body_weldid[bodyid2] + weld_parentid1 = m.body_weldid[m.body_parentid[weldid1]] + weld_parentid2 = m.body_weldid[m.body_parentid[weldid2]] + + self_collision = weldid1 == weldid2 + parent_child_collision = ( + filterparent + and (weldid1 != 0) + and (weldid2 != 0) + and ((weldid1 == weld_parentid2) or (weldid2 == weld_parentid1)) + ) + mask = (contype1 & conaffinity2) or (contype2 & conaffinity1) + + return mask and (not self_collision) and (not parent_child_collision) + + +@wp.func +def _add_geom_pair(m: Model, d: Data, geom1: int, geom2: int, worldid: int): + pairid = wp.atomic_add(d.ncollision, 0, 1) + + if pairid >= d.nconmax: + return + + type1 = m.geom_type[geom1] + type2 = m.geom_type[geom2] + + if type1 > type2: + pair = wp.vec2i(geom2, geom1) + else: + pair = wp.vec2i(geom1, geom2) + + d.collision_pair[pairid] = pair + d.collision_worldid[pairid] = worldid + + +@wp.kernel +def 
broadphase_project_spheres_onto_sweep_direction_kernel( + m: Model, + d: Data, + direction: wp.vec3, +): + worldId, i = wp.tid() + + c = d.geom_xpos[worldId, i] + r = m.geom_rbound[i] + if r == 0.0: + # current geom is a plane + r = 1000000000.0 + sphere_radius = r + m.geom_margin[i] + + center = wp.dot(direction, c) + f = center - sphere_radius + + # Store results in the data arrays + d.sap_projection_lower[worldId, i] = f + d.sap_projection_upper[worldId, i] = center + sphere_radius + d.sap_sort_index[worldId, i] = i + + +# Define constants for plane types +PLANE_ZERO_OFFSET = -1.0 +PLANE_NEGATIVE_OFFSET = -2.0 +PLANE_POSITIVE_OFFSET = -3.0 + + +@wp.func +def encode_plane(normal: wp.vec3, point_on_plane: wp.vec3, margin: float) -> wp.vec4: + normal = wp.normalize(normal) + plane_offset = -wp.dot(normal, point_on_plane + normal * margin) + + # Scale factor for the normal + scale = wp.abs(plane_offset) + + # Handle special cases + if wp.abs(plane_offset) < 1e-6: + return wp.vec4(normal.x, normal.y, normal.z, PLANE_ZERO_OFFSET) + elif plane_offset < 0.0: + return wp.vec4( + scale * normal.x, scale * normal.y, scale * normal.z, PLANE_NEGATIVE_OFFSET + ) + else: + return wp.vec4( + scale * normal.x, scale * normal.y, scale * normal.z, PLANE_POSITIVE_OFFSET + ) + + +@wp.func +def decode_plane(encoded: wp.vec4) -> wp.vec4: + magnitude = wp.length(encoded) + normal = wp.normalize(xyz(encoded)) + + if encoded.w == PLANE_ZERO_OFFSET: + return wp.vec4(normal.x, normal.y, normal.z, 0.0) + elif encoded.w == PLANE_NEGATIVE_OFFSET: + return wp.vec4(normal.x, normal.y, normal.z, -magnitude) + else: + return wp.vec4(normal.x, normal.y, normal.z, magnitude) + + +@wp.kernel +def reorder_bounding_spheres_kernel( + m: Model, + d: Data, +): + worldId, i = wp.tid() + + # Get the index from the data indexer + mapped = d.sap_sort_index[worldId, i] + + # Get the bounding volume + c = d.geom_xpos[worldId, mapped] + r = m.geom_rbound[mapped] + margin = m.geom_margin[mapped] + + # Reorder 
the box into the sorted array + if r == 0.0: + # store the plane equation + xmat = d.geom_xmat[worldId, mapped] + plane_normal = wp.vec3(xmat[0, 2], xmat[1, 2], xmat[2, 2]) + d.sap_geom_sort[worldId, i] = encode_plane( + plane_normal, c, margin + ) # negative w component is used to distinguish planes from spheres + else: + d.sap_geom_sort[worldId, i] = wp.vec4(c.x, c.y, c.z, r + margin) + + +@wp.func +def xyz(v: wp.vec4) -> wp.vec3: + return wp.vec3(v.x, v.y, v.z) + + +@wp.func +def signed_distance_point_plane(point: wp.vec3, plane: wp.vec4) -> float: + return wp.dot(point, xyz(plane)) + plane.w + + +@wp.func +def overlap( + world_id: int, + a: int, + b: int, + spheres_or_planes: wp.array(dtype=wp.vec4, ndim=2), +) -> bool: + # Extract centers and sizes + s_a = spheres_or_planes[world_id, a] + s_b = spheres_or_planes[world_id, b] + + if s_a.w < 0.0 and s_b.w < 0.0: + # both are planes + return False + elif s_a.w < 0.0 or s_b.w < 0.0: + if s_b.w < 0.0: # swap if required such that s_a is always a plane + tmp = s_a + s_a = s_b + s_b = tmp + s_a = decode_plane(s_a) + dist = signed_distance_point_plane(xyz(s_b), s_a) + return dist <= s_b.w + else: + # geoms are spheres + delta = xyz(s_a) - xyz(s_b) + dist_sq = wp.dot(delta, delta) + radius_sum = s_a.w + s_b.w + return dist_sq <= radius_sum * radius_sum + + +@wp.func +def find_first_greater_than( + worldId: int, + starts: wp.array(dtype=wp.float32, ndim=2), + value: wp.float32, + low: int, + high: int, +) -> int: + while low < high: + mid = (low + high) >> 1 + if starts[worldId, mid] > value: + high = mid + else: + low = mid + 1 + return low + + +@wp.kernel +def sap_broadphase_prepare_kernel( + m: Model, + d: Data, +): + worldId, i = wp.tid() # Get the thread ID + + # Get the index of the current bounding box + idx1 = d.sap_sort_index[worldId, i] + + end = d.sap_projection_upper[worldId, idx1] + limit = find_first_greater_than(worldId, d.sap_projection_lower, end, i + 1, m.ngeom) + limit = wp.min(m.ngeom - 1, limit) + + 
# Calculate the range of boxes for the sweep and prune process + count = limit - i + + # Store the cumulative sum for the current box + d.sap_range[worldId, i] = count + + +@wp.func +def find_right_most_index_int( + starts: wp.array(dtype=wp.int32, ndim=1), value: wp.int32, low: int, high: int +) -> int: + while low < high: + mid = (low + high) >> 1 + if starts[mid] > value: + high = mid + else: + low = mid + 1 + return high + + +@wp.func +def find_indices( + id: int, cumulative_sum: wp.array(dtype=wp.int32, ndim=1), length: int +) -> wp.vec2i: + # Perform binary search to find the right most index + i = find_right_most_index_int(cumulative_sum, id, 0, length) + + # Get the baseId, and compute the offset and j + if i > 0: + base_id = cumulative_sum[i - 1] + else: + base_id = 0 + offset = id - base_id + j = i + offset + 1 + + return wp.vec2i(i, j) + + +@wp.kernel +def sap_broadphase_kernel(m: Model, d: Data, num_threads: int, filter_parent: bool): + threadId = wp.tid() # Get thread ID + if d.sap_cumulative_sum.shape[0] > 0: + total_num_work_packages = d.sap_cumulative_sum[d.sap_cumulative_sum.shape[0] - 1] + else: + total_num_work_packages = 0 + + while threadId < total_num_work_packages: + # Get indices for current and next box pair + ij = find_indices(threadId, d.sap_cumulative_sum, d.sap_cumulative_sum.shape[0]) + i = ij.x + j = ij.y + + worldId = i // m.ngeom + i = i % m.ngeom + j = j % m.ngeom + + # geom index + idx1 = d.sap_sort_index[worldId, i] + idx2 = d.sap_sort_index[worldId, j] + + if not _geom_filter(m, idx1, idx2, filter_parent): + threadId += num_threads + continue + + # Check if the boxes overlap + if overlap(worldId, i, j, d.sap_geom_sort): + _add_geom_pair(m, d, idx1, idx2, worldId) + + threadId += num_threads + + +@wp.kernel +def get_contact_solver_params_kernel( + m: Model, + d: Data, +): + tid = wp.tid() + + n_contact_pts = d.ncon[0] + if tid >= n_contact_pts: + return + + geoms = d.contact.geom[tid] + g1 = geoms.x + g2 = geoms.y + + margin = 
wp.max(m.geom_margin[g1], m.geom_margin[g2]) + gap = wp.max(m.geom_gap[g1], m.geom_gap[g2]) + solmix1 = m.geom_solmix[g1] + solmix2 = m.geom_solmix[g2] + mix = solmix1 / (solmix1 + solmix2) + mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix) + mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 >= MJ_MINVAL), 0.0, mix) + mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix) + + p1 = m.geom_priority[g1] + p2 = m.geom_priority[g2] + mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0)) + + condim1 = m.geom_condim[g1] + condim2 = m.geom_condim[g2] + condim = wp.where( + p1 == p2, wp.max(condim1, condim2), wp.where(p1 > p2, condim1, condim2) + ) + d.contact.dim[tid] = condim + + if m.geom_solref[g1].x > 0.0 and m.geom_solref[g2].x > 0.0: + d.contact.solref[tid] = mix * m.geom_solref[g1] + (1.0 - mix) * m.geom_solref[g2] + else: + d.contact.solref[tid] = wp.min(m.geom_solref[g1], m.geom_solref[g2]) + d.contact.includemargin[tid] = margin - gap + friction_ = wp.max(m.geom_friction[g1], m.geom_friction[g2]) + friction5 = vec5(friction_[0], friction_[0], friction_[1], friction_[2], friction_[2]) + d.contact.friction[tid] = friction5 + d.contact.solimp[tid] = mix * m.geom_solimp[g1] + (1.0 - mix) * m.geom_solimp[g2] + + +def sap_broadphase(m: Model, d: Data): + """Broadphase collision detection via sweep-and-prune.""" + + # Use random fixed direction vector for now + direction = wp.vec3(0.5935, 0.7790, 0.1235) + direction = wp.normalize(direction) + + wp.launch( + kernel=broadphase_project_spheres_onto_sweep_direction_kernel, + dim=(d.nworld, m.ngeom), + inputs=[m, d, direction], + ) + + tile_sort_available = False + segmented_sort_available = hasattr(wp.utils, "segmented_sort_pairs") + + if tile_sort_available: + segmented_sort_kernel = create_segmented_sort_kernel(m.ngeom) + wp.launch_tiled( + kernel=segmented_sort_kernel, dim=(d.nworld), inputs=[m, d], block_dim=128 + ) + print("tile sort available") + elif 
segmented_sort_available: + wp.utils.segmented_sort_pairs( + d.sap_projection_lower, + d.sap_sort_index, + m.ngeom * d.nworld, + d.sap_segment_index, + ) + else: + # Sort each world's segment separately + for world_id in range(d.nworld): + start_idx = world_id * m.ngeom + + # Create temporary arrays for sorting + temp_box_projections_lower = wp.zeros( + m.ngeom * 2, + dtype=d.sap_projection_lower.dtype, + ) + temp_box_sorting_indexer = wp.zeros( + m.ngeom * 2, + dtype=d.sap_sort_index.dtype, + ) + + # Copy data to temporary arrays + wp.copy( + temp_box_projections_lower, + d.sap_projection_lower, + 0, + start_idx, + m.ngeom, + ) + wp.copy( + temp_box_sorting_indexer, + d.sap_sort_index, + 0, + start_idx, + m.ngeom, + ) + + # Sort the temporary arrays + wp.utils.radix_sort_pairs( + temp_box_projections_lower, temp_box_sorting_indexer, m.ngeom + ) + + # Copy sorted data back + wp.copy( + d.sap_projection_lower, + temp_box_projections_lower, + start_idx, + 0, + m.ngeom, + ) + wp.copy( + d.sap_sort_index, + temp_box_sorting_indexer, + start_idx, + 0, + m.ngeom, + ) + + wp.launch( + kernel=reorder_bounding_spheres_kernel, + dim=(d.nworld, m.ngeom), + inputs=[m, d], + ) + + wp.launch( + kernel=sap_broadphase_prepare_kernel, + dim=(d.nworld, m.ngeom), + inputs=[m, d], + ) + + # The scan (scan = cumulative sum, either inclusive or exclusive depending on the last argument) is used for load balancing among the threads + wp.utils.array_scan(d.sap_range.reshape(-1), d.sap_cumulative_sum, True) + + # Estimate how many overlap checks need to be done - assumes each box has to be compared to 5 other boxes (and batched over all worlds) + num_sweep_threads = 5 * d.nworld * m.ngeom + filter_parent = not m.opt.disableflags & DisableBit.FILTERPARENT.value + wp.launch( + kernel=sap_broadphase_kernel, + dim=num_sweep_threads, + inputs=[m, d, num_sweep_threads, filter_parent], + ) + + return d + + +def nxn_broadphase(m: Model, d: Data): + """Broadphase collision detection via brute-force 
search.""" + filterparent = not (m.opt.disableflags & DisableBit.FILTERPARENT.value) + + @wp.kernel + def _nxn_broadphase(m: Model, d: Data): + worldid, elementid = wp.tid() + geom1 = ( + m.ngeom + - 2 + - int( + wp.sqrt(float(-8 * elementid + 4 * m.ngeom * (m.ngeom - 1) - 7)) / 2.0 - 0.5 + ) + ) + geom2 = ( + elementid + + geom1 + + 1 + - m.ngeom * (m.ngeom - 1) // 2 + + (m.ngeom - geom1) * ((m.ngeom - geom1) - 1) // 2 + ) + + margin1 = m.geom_margin[geom1] + margin2 = m.geom_margin[geom2] + pos1 = d.geom_xpos[worldid, geom1] + pos2 = d.geom_xpos[worldid, geom2] + size1 = m.geom_rbound[geom1] + size2 = m.geom_rbound[geom2] + + bound = size1 + size2 + wp.max(margin1, margin2) + dif = pos2 - pos1 + + if size1 != 0.0 and size2 != 0.0: + # neither geom is a plane + dist_sq = wp.dot(dif, dif) + bounds_filter = dist_sq <= bound * bound + elif size1 == 0.0: + # geom1 is a plane + xmat1 = d.geom_xmat[worldid, geom1] + dist = wp.dot(dif, wp.vec3(xmat1[0, 2], xmat1[1, 2], xmat1[2, 2])) + bounds_filter = dist <= bound + else: + # geom2 is a plane + xmat2 = d.geom_xmat[worldid, geom2] + dist = wp.dot(-dif, wp.vec3(xmat2[0, 2], xmat2[1, 2], xmat2[2, 2])) + bounds_filter = dist <= bound + + geom_filter = _geom_filter(m, geom1, geom2, filterparent) + + if bounds_filter and geom_filter: + _add_geom_pair(m, d, geom1, geom2, worldid) + + wp.launch( + _nxn_broadphase, dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2), inputs=[m, d] + ) + + +def get_contact_solver_params(m: Model, d: Data): + wp.launch( + get_contact_solver_params_kernel, + dim=[d.nconmax], + inputs=[m, d], + ) + + # TODO(team): do we need condim sorting, deepest penetrating contact here? 
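The closed-form index inversion in `_nxn_broadphase` maps a flat thread id directly to an unordered geom pair without any search or per-pair bookkeeping. A plain-Python sketch of that arithmetic (the helper name `nxn_pair` is illustrative; the real kernel runs this per Warp thread id) makes it easy to verify that the mapping enumerates every upper-triangular pair exactly once:

```python
import math


def nxn_pair(elementid: int, ngeom: int) -> tuple:
    """Invert the flattened upper-triangular index: elementid -> (geom1, geom2).

    Mirrors the arithmetic in _nxn_broadphase; elementid ranges over the
    ngeom * (ngeom - 1) // 2 unordered pairs with geom1 < geom2.
    """
    # Row index: solve the quadratic for the triangular row containing elementid.
    geom1 = ngeom - 2 - int(
        math.sqrt(float(-8 * elementid + 4 * ngeom * (ngeom - 1) - 7)) / 2.0 - 0.5
    )
    # Column index: offset within the row, shifted past the diagonal.
    geom2 = (
        elementid
        + geom1
        + 1
        - ngeom * (ngeom - 1) // 2
        + (ngeom - geom1) * ((ngeom - geom1) - 1) // 2
    )
    return geom1, geom2


# Exhaustive check for a small geom count: every pair appears exactly once.
ngeom = 5
pairs = {nxn_pair(e, ngeom) for e in range(ngeom * (ngeom - 1) // 2)}
assert pairs == {(i, j) for i in range(ngeom) for j in range(i + 1, ngeom)}
```

This is why the kernel can launch with `dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2)` and assign one thread per candidate pair with no atomics on the enumeration side.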
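For the sweep-and-prune path, `find_indices` load-balances the variable-length per-box candidate ranges by binary-searching the inclusive prefix sum `sap_cumulative_sum`. The same logic in plain Python, simplified to a single world (the kernel additionally folds the world index into `i`; array names here are illustrative, not the real `Data` layout):

```python
import bisect


def find_indices(work_id: int, cumulative_sum: list) -> tuple:
    """Map a flat work id to a sorted-box pair (i, j), as in sap_broadphase_kernel.

    cumulative_sum is the inclusive prefix sum of sap_range: entry i is the
    total number of candidate pairs contributed by sorted boxes 0..i.
    """
    # Rightmost insertion point: first i with cumulative_sum[i] > work_id,
    # matching the find_right_most_index_int binary search.
    i = bisect.bisect_right(cumulative_sum, work_id)
    base = cumulative_sum[i - 1] if i > 0 else 0
    offset = work_id - base
    j = i + offset + 1  # partner box follows box i in the sorted order
    return i, j


# Box 0 overlaps the next 3 boxes, box 1 the next 2, box 2 the next 1:
cumsum = [3, 5, 6]
assert [find_indices(w, cumsum) for w in range(6)] == [
    (0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)
]
```

Each of the `num_sweep_threads` persistent threads then strides through work ids, so a few boxes with many overlaps cannot serialize onto a single thread.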
 + + +@event_scope +def collision(m: Model, d: Data): + """Collision detection.""" + + # AD: based on engine_collision_driver.py in Eric's warp fork/mjx-collisions-dev + # which is further based on the CUDA code here: + # https://github.com/btaba/mujoco/blob/warp-collisions/mjx/mujoco/mjx/_src/cuda/engine_collision_driver.cu.cc#L458-L583 + + d.ncollision.zero_() + d.ncon.zero_() + + # TODO(team): determine ngeom to switch from n^2 to sap + if m.ngeom <= 100: + nxn_broadphase(m, d) + else: + sap_broadphase(m, d) + + # XXX switch between collision functions and GJK/EPA here + if True: + from .collision_functions import narrowphase + else: + from .collision_convex import narrowphase + + # TODO(team): should we limit per-world contact numbers? + # TODO(team): we should reject far-away contacts in the narrowphase instead of constraint + # partitioning because we can move some pressure off the atomics + narrowphase(m, d) + get_contact_solver_params(m, d) diff --git a/mujoco_warp/_src/collision_driver_test.py b/mujoco_warp/_src/collision_driver_test.py new file mode 100644 index 00000000..5c1b3a62 --- /dev/null +++ b/mujoco_warp/_src/collision_driver_test.py @@ -0,0 +1,121 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests the collision driver.""" + +import mujoco +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import mujoco_warp as mjwarp + + +class PrimitiveTest(parameterized.TestCase): + """Tests the primitive contact functions.""" + + _MJCFS = { + "box_plane": """ + + + + + + + + + + """, + "plane_sphere": """ + + + + + + + + + + """, + "sphere_sphere": """ + + + + + + + + + + + + + """, + "capsule_capsule": """ + + + + + + + + + + + + + """, + "plane_capsule": """ + + + + + + + + + + """, + } + + @parameterized.parameters( + "box_plane", + "plane_sphere", + "sphere_sphere", + "plane_capsule", + "capsule_capsule", + ) + def test_contact(self, name): + """Tests contact calculation with different collision functions.""" + m = mujoco.MjModel.from_xml_string(self._MJCFS[name]) + d = mujoco.MjData(m) + mujoco.mj_forward(m, d) + mx = mjwarp.put_model(m) + dx = mjwarp.put_data(m, d) + mjwarp.collision(mx, dx) + mujoco.mj_collision(m, d) + self.assertEqual(d.ncon, dx.ncon.numpy()[0]) + for i in range(d.ncon): + actual_dist = dx.contact.dist.numpy()[i] + actual_pos = dx.contact.pos.numpy()[i, :] + actual_frame = dx.contact.frame.numpy()[i].flatten() + np.testing.assert_array_almost_equal(actual_dist, d.contact.dist[i], 4) + np.testing.assert_array_almost_equal(actual_pos, d.contact.pos[i], 4) + np.testing.assert_array_almost_equal(actual_frame, d.contact.frame[i], 4) + + +if __name__ == "__main__": + absltest.main() diff --git a/mujoco_warp/_src/collision_functions.py b/mujoco_warp/_src/collision_functions.py new file mode 100644 index 00000000..23a713ed --- /dev/null +++ b/mujoco_warp/_src/collision_functions.py @@ -0,0 +1,821 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import math +from typing import Any + +import numpy as np +import warp as wp + +from .math import closest_segment_to_segment_points +from .math import make_frame +from .math import normalize_with_norm +from .types import Data +from .types import GeomType +from .types import Model + +BOX_BOX_BLOCK_DIM = 32 + +@wp.struct +class Geom: + pos: wp.vec3 + rot: wp.mat33 + normal: wp.vec3 + size: wp.vec3 + # TODO(team): mesh fields: vertadr, vertnum + + +@wp.func +def _geom( + gid: int, + m: Model, + geom_xpos: wp.array(dtype=wp.vec3), + geom_xmat: wp.array(dtype=wp.mat33), +) -> Geom: + geom = Geom() + geom.pos = geom_xpos[gid] + rot = geom_xmat[gid] + geom.rot = rot + geom.size = m.geom_size[gid] + geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane + + return geom + + +@wp.func +def write_contact( + d: Data, + dist: float, + pos: wp.vec3, + frame: wp.mat33, + margin: float, + geoms: wp.vec2i, + worldid: int, +): + active = (dist - margin) < 0 + if active: + index = wp.atomic_add(d.ncon, 0, 1) + if index < d.nconmax: + d.contact.dist[index] = dist + d.contact.pos[index] = pos + d.contact.frame[index] = frame + d.contact.geom[index] = geoms + d.contact.worldid[index] = worldid + + +@wp.func +def _plane_sphere( + plane_normal: wp.vec3, plane_pos: wp.vec3, sphere_pos: wp.vec3, sphere_radius: float +): + dist = wp.dot(sphere_pos - plane_pos, plane_normal) - sphere_radius + pos = sphere_pos - plane_normal * (sphere_radius + 0.5 * dist) + return dist, pos + + +@wp.func +def 
plane_sphere( + plane: Geom, + sphere: Geom, + worldid: int, + d: Data, + margin: float, + geom_indices: wp.vec2i, +): + dist, pos = _plane_sphere(plane.normal, plane.pos, sphere.pos, sphere.size[0]) + + write_contact(d, dist, pos, make_frame(plane.normal), margin, geom_indices, worldid) + + +@wp.func +def _sphere_sphere( + pos1: wp.vec3, + radius1: float, + pos2: wp.vec3, + radius2: float, + worldid: int, + d: Data, + margin: float, + geom_indices: wp.vec2i, +): + dir = pos2 - pos1 + dist = wp.length(dir) + if dist == 0.0: + n = wp.vec3(1.0, 0.0, 0.0) + else: + n = dir / dist + dist = dist - (radius1 + radius2) + pos = pos1 + n * (radius1 + 0.5 * dist) + + write_contact(d, dist, pos, make_frame(n), margin, geom_indices, worldid) + + +@wp.func +def sphere_sphere( + sphere1: Geom, + sphere2: Geom, + worldid: int, + d: Data, + margin: float, + geom_indices: wp.vec2i, +): + _sphere_sphere( + sphere1.pos, + sphere1.size[0], + sphere2.pos, + sphere2.size[0], + worldid, + d, + margin, + geom_indices, + ) + + +@wp.func +def capsule_capsule( + cap1: Geom, + cap2: Geom, + worldid: int, + d: Data, + margin: float, + geom_indices: wp.vec2i, +): + axis1 = wp.vec3(cap1.rot[0, 2], cap1.rot[1, 2], cap1.rot[2, 2]) + axis2 = wp.vec3(cap2.rot[0, 2], cap2.rot[1, 2], cap2.rot[2, 2]) + length1 = cap1.size[1] + length2 = cap2.size[1] + seg1 = axis1 * length1 + seg2 = axis2 * length2 + + pt1, pt2 = closest_segment_to_segment_points( + cap1.pos - seg1, + cap1.pos + seg1, + cap2.pos - seg2, + cap2.pos + seg2, + ) + + _sphere_sphere(pt1, cap1.size[0], pt2, cap2.size[0], worldid, d, margin, geom_indices) + + +@wp.func +def plane_capsule( + plane: Geom, + cap: Geom, + worldid: int, + d: Data, + margin: float, + geom_indices: wp.vec2i, +): + """Calculates two contacts between a capsule and a plane.""" + n = plane.normal + axis = wp.vec3(cap.rot[0, 2], cap.rot[1, 2], cap.rot[2, 2]) + # align contact frames with capsule axis + b, b_norm = normalize_with_norm(axis - n * wp.dot(n, axis)) + + if 
b_norm < 0.5: + if -0.5 < n[1] and n[1] < 0.5: + b = wp.vec3(0.0, 1.0, 0.0) + else: + b = wp.vec3(0.0, 0.0, 1.0) + + c = wp.cross(n, b) + frame = wp.mat33(n[0], n[1], n[2], b[0], b[1], b[2], c[0], c[1], c[2]) + segment = axis * cap.size[1] + + dist1, pos1 = _plane_sphere(n, plane.pos, cap.pos + segment, cap.size[0]) + write_contact(d, dist1, pos1, frame, margin, geom_indices, worldid) + + dist2, pos2 = _plane_sphere(n, plane.pos, cap.pos - segment, cap.size[0]) + write_contact(d, dist2, pos2, frame, margin, geom_indices, worldid) + + +HUGE_VAL = 1e6 +TINY_VAL = 1e-6 + + +class vec16b(wp.types.vector(length=16, dtype=wp.int8)): + pass + + +class mat43f(wp.types.matrix(shape=(4, 3), dtype=wp.float32)): + pass + + +class mat83f(wp.types.matrix(shape=(8, 3), dtype=wp.float32)): + pass + + +class mat16_3f(wp.types.matrix(shape=(16, 3), dtype=wp.float32)): + pass + + +Box = mat83f + + +@wp.func +def _argmin(a: Any) -> wp.int32: + amin = wp.int32(0) + vmin = wp.float32(a[0]) + for i in range(1, len(a)): + if a[i] < vmin: + amin = i + vmin = a[i] + return amin + + +@wp.func +def box_normals(i: int) -> wp.vec3: + direction = wp.where(i < 3, -1.0, 1.0) + mod = i % 3 + if mod == 0: + return wp.vec3(0.0, direction, 0.0) + if mod == 1: + return wp.vec3(0.0, 0.0, direction) + return wp.vec3(-direction, 0.0, 0.0) + + +@wp.func +def box(R: wp.mat33, t: wp.vec3, geom_size: wp.vec3) -> Box: + """Get a transformed box""" + x = geom_size[0] + y = geom_size[1] + z = geom_size[2] + m = Box() + for i in range(8): + ix = wp.where(i & 4, x, -x) + iy = wp.where(i & 2, y, -y) + iz = wp.where(i & 1, z, -z) + m[i] = R @ wp.vec3(ix, iy, iz) + t + return m + + +@wp.func +def box_face_verts(box: Box, idx: wp.int32) -> mat43f: + """Get the quad corresponding to a box face""" + if idx == 0: + verts = wp.vec4i(0, 4, 5, 1) + if idx == 1: + verts = wp.vec4i(0, 2, 6, 4) + if idx == 2: + verts = wp.vec4i(6, 7, 5, 4) + if idx == 3: + verts = wp.vec4i(2, 3, 7, 6) + if idx == 4: + verts = wp.vec4i(1, 5, 7, 
3) + if idx == 5: + verts = wp.vec4i(0, 1, 3, 2) + + m = mat43f() + for i in range(4): + m[i] = box[verts[i]] + return m + + +@wp.func +def get_box_axis( + axis_idx: int, + R: wp.mat33, +): + """Get the axis at index axis_idx. + R: rotation matrix from a to b + Axes 0-12 are face normals of boxes a & b + Axes 12-21 are edge cross products.""" + if axis_idx < 6: # a faces + axis = R @ wp.vec3(box_normals(axis_idx)) + is_degenerate = False + elif axis_idx < 12: # b faces + axis = wp.vec3(box_normals(axis_idx - 6)) + is_degenerate = False + else: # edges cross products + assert axis_idx < 21 + edges = axis_idx - 12 + axis_a, axis_b = edges / 3, edges % 3 + edge_a = wp.transpose(R)[axis_a] + if axis_b == 0: + axis = wp.vec3(0.0, -edge_a[2], edge_a[1]) + elif axis_b == 1: + axis = wp.vec3(edge_a[2], 0.0, -edge_a[0]) + else: + axis = wp.vec3(-edge_a[1], edge_a[0], 0.0) + is_degenerate = wp.length_sq(axis) < TINY_VAL + return wp.normalize(axis), is_degenerate + + +@wp.func +def get_box_axis_support( + axis: wp.vec3, degenerate_axis: bool, a: Box, b: Box +): + """Get the overlap (or separating distance if negative) along `axis`, and the sign.""" + axis_d = wp.vec3d(axis) + support_a_max, support_b_max = wp.float32(-HUGE_VAL), wp.float32(-HUGE_VAL) + support_a_min, support_b_min = wp.float32(HUGE_VAL), wp.float32(HUGE_VAL) + for i in range(8): + vert_a = wp.vec3d(a[i]) + vert_b = wp.vec3d(b[i]) + proj_a = wp.float32(wp.dot(vert_a, axis_d)) + proj_b = wp.float32(wp.dot(vert_b, axis_d)) + support_a_max = wp.max(support_a_max, proj_a) + support_b_max = wp.max(support_b_max, proj_b) + support_a_min = wp.min(support_a_min, proj_a) + support_b_min = wp.min(support_b_min, proj_b) + dist1 = support_a_max - support_b_min + dist2 = support_b_max - support_a_min + dist = wp.where(degenerate_axis, HUGE_VAL, wp.min(dist1, dist2)) + sign = wp.where(dist1 > dist2, -1, 1) + return dist, sign + + +@wp.struct +class AxisSupport: + best_dist: wp.float32 + best_sign: wp.int8 + best_idx: 
wp.int8 + + +@wp.func +def reduce_axis_support(a: AxisSupport, b: AxisSupport): + return wp.where(a.best_dist > b.best_dist, b, a) + + +@wp.func +def face_axis_alignment(a: wp.vec3, R: wp.mat33) -> wp.int32: + """Find the box faces most aligned with the axis `a`""" + max_dot = wp.float32(0.0) + max_idx = wp.int32(0) + for i in range(6): + d = wp.dot(R @ box_normals(i), a) + if d > max_dot: + max_dot = d + max_idx = i + return max_idx + + +@wp.kernel(enable_backward=False) +def box_box_kernel( + m: Model, + d: Data, + num_kernels: int, +): + """Calculates contacts between pairs of boxes.""" + tid, axis_idx = wp.tid() + + for bp_idx in range(tid, min(d.ncollision[0], d.nconmax), num_kernels): + geoms = d.collision_pair[bp_idx] + + ga, gb = geoms[0], geoms[1] + + if m.geom_type[ga] != int(GeomType.BOX.value) or m.geom_type[gb] != int(GeomType.BOX.value): + continue + + worldid = d.collision_worldid[bp_idx] + # transformations + a_pos, b_pos = d.geom_xpos[worldid, ga], d.geom_xpos[worldid, gb] + a_mat, b_mat = d.geom_xmat[worldid, ga], d.geom_xmat[worldid, gb] + b_mat_inv = wp.transpose(b_mat) + trans_atob = b_mat_inv @ (a_pos - b_pos) + rot_atob = b_mat_inv @ a_mat + + a_size = m.geom_size[ga] + b_size = m.geom_size[gb] + a = box(rot_atob, trans_atob, a_size) + b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size) + + # box-box implementation + + # Inlined def collision_axis_tiled( a: Box, b: Box, R: wp.mat33, axis_idx: wp.int32,): + # Finds the axis of minimum separation. 
+ # a: Box a vertices, in frame b + # b: Box b vertices, in frame b + # R: rotation matrix from a to b + # Returns: + # best_axis: vec3 + # best_sign: int32 + # best_idx: int32 + R = rot_atob + + # launch tiled with block_dim=21 + if axis_idx > 20: + continue + + axis, degenerate_axis = get_box_axis(axis_idx, R) + axis_dist, axis_sign = get_box_axis_support(axis, degenerate_axis, a, b) + + supports = wp.tile(AxisSupport(axis_dist, wp.int8(axis_sign), wp.int8(axis_idx))) + + face_supports = wp.tile_view(supports, offset=(0,), shape=(12,)) + edge_supports = wp.tile_view(supports, offset=(12,), shape=(9,)) + + face_supports_red = wp.tile_reduce(reduce_axis_support, face_supports) + edge_supports_red = wp.tile_reduce(reduce_axis_support, edge_supports) + + face = face_supports_red[0] + edge = edge_supports_red[0] + + if axis_idx > 0: # single thread + continue + + # choose the best separating axis + face_axis, _ = get_box_axis(wp.int32(face.best_idx), R) + best_axis = wp.vec3(face_axis) + best_sign = wp.int32(face.best_sign) + best_idx = wp.int32(face.best_idx) + best_dist = wp.float32(face.best_dist) + + if edge.best_dist < face.best_dist: + edge_axis, _ = get_box_axis(wp.int32(edge.best_idx), R) + if wp.abs(wp.dot(face_axis, edge_axis)) < 0.99: + best_axis = edge_axis + best_sign = wp.int32(edge.best_sign) + best_idx = wp.int32(edge.best_idx) + best_dist = wp.float32(edge.best_dist) + # end inlined collision_axis_tiled + + # if axis_idx != 0: + # continue + if best_dist < 0: + continue + + # get the (reference) face most aligned with the separating axis + a_max = face_axis_alignment(best_axis, rot_atob) + b_max = face_axis_alignment(best_axis, wp.identity(3, wp.float32)) + + sep_axis = wp.float32(best_sign) * best_axis + + if best_sign > 0: + b_min = (b_max + 3) % 6 + dist, pos = _create_contact_manifold( + box_face_verts(a, a_max), + rot_atob @ box_normals(a_max), + box_face_verts(b, b_min), + box_normals(b_min), + ) + else: + a_min = (a_max + 3) % 6 + dist, pos = 
_create_contact_manifold( + box_face_verts(b, b_max), + box_normals(b_max), + box_face_verts(a, a_min), + rot_atob @ box_normals(a_min), + ) + + # For edge contacts, we use the clipped face point, mainly for performance + # reasons. For small penetration, the clipped face point is roughly the edge + # contact point. + if best_idx > 11: # is_edge_contact + idx = _argmin(dist) + dist = wp.vec4f(dist[idx], 1.0, 1.0, 1.0) + for i in range(4): + pos[i] = pos[idx] + + margin = wp.max(m.geom_margin[ga], m.geom_margin[gb]) + for i in range(4): + pos_glob = b_mat @ pos[i] + b_pos + n_glob = b_mat @ sep_axis + write_contact( + d, dist[i], pos_glob, make_frame(n_glob), margin, wp.vec2i(ga, gb), worldid + ) + + +def box_box( + m: Model, + d: Data, +): + """Calculates contacts between pairs of boxes.""" + kernel_ratio = 16 + num_threads = math.ceil( + d.nconmax / kernel_ratio + ) # parallel threads excluding tile dim + wp.launch_tiled( + kernel=box_box_kernel, + dim=num_threads, + inputs=[m, d, num_threads], + block_dim=BOX_BOX_BLOCK_DIM, + ) + + +@wp.func +def _closest_segment_point_plane( + a: wp.vec3, b: wp.vec3, p0: wp.vec3, plane_normal: wp.vec3 +) -> wp.vec3: + """Gets the closest point between a line segment and a plane. + + Args: + a: first line segment point + b: second line segment point + p0: point on plane + plane_normal: plane normal + + Returns: + closest point between the line segment and the plane + """ + # Parametrize a line segment as S(t) = a + t * (b - a), plug it into the plane + # equation dot(n, S(t)) - d = 0, then solve for t to get the line-plane + # intersection. We then clip t to be in [0, 1] to be on the line segment. 
+ n = plane_normal + d = wp.dot(p0, n) # shortest distance from origin to plane + denom = wp.dot(n, (b - a)) + t = (d - wp.dot(n, a)) / (denom + wp.where(denom == 0.0, TINY_VAL, 0.0)) + t = wp.clamp(t, 0.0, 1.0) + segment_point = a + t * (b - a) + + return segment_point + + +@wp.func +def _project_poly_onto_plane( + poly: Any, + poly_n: wp.vec3, + plane_n: wp.vec3, + plane_pt: wp.vec3, +): + """Projects poly1 onto the poly2 plane along poly2's normal.""" + d = wp.dot(plane_pt, plane_n) + denom = wp.dot(poly_n, plane_n) + qn_scaled = poly_n / (denom + wp.where(denom == 0.0, TINY_VAL, 0.0)) + + for i in range(len(poly)): + poly[i] = poly[i] + (d - wp.dot(poly[i], plane_n)) * qn_scaled + return poly + + +@wp.func +def _clip_edge_to_quad( + subject_poly: mat43f, + clipping_poly: mat43f, + clipping_normal: wp.vec3, +): + p0 = mat43f() + p1 = mat43f() + mask = wp.vec4b() + for edge_idx in range(4): + subject_p0 = subject_poly[(edge_idx + 3) % 4] + subject_p1 = subject_poly[edge_idx] + + any_both_in_front = wp.int32(0) + clipped0_dist_max = wp.float32(-HUGE_VAL) + clipped1_dist_max = wp.float32(-HUGE_VAL) + clipped_p0_distmax = wp.vec3(0.0) + clipped_p1_distmax = wp.vec3(0.0) + + for clipping_edge_idx in range(4): + clipping_p0 = clipping_poly[(clipping_edge_idx + 3) % 4] + clipping_p1 = clipping_poly[clipping_edge_idx] + edge_normal = wp.cross(clipping_p1 - clipping_p0, clipping_normal) + + p0_in_front = wp.dot(subject_p0 - clipping_p0, edge_normal) > TINY_VAL + p1_in_front = wp.dot(subject_p1 - clipping_p0, edge_normal) > TINY_VAL + candidate_clipped_p = _closest_segment_point_plane( + subject_p0, subject_p1, clipping_p1, edge_normal + ) + clipped_p0 = wp.where(p0_in_front, candidate_clipped_p, subject_p0) + clipped_p1 = wp.where(p1_in_front, candidate_clipped_p, subject_p1) + clipped_dist_p0 = wp.dot(clipped_p0 - subject_p0, subject_p1 - subject_p0) + clipped_dist_p1 = wp.dot(clipped_p1 - subject_p1, subject_p0 - subject_p1) + any_both_in_front |= wp.int32(p0_in_front 
and p1_in_front) + + if clipped_dist_p0 > clipped0_dist_max: + clipped0_dist_max = clipped_dist_p0 + clipped_p0_distmax = clipped_p0 + + if clipped_dist_p1 > clipped1_dist_max: + clipped1_dist_max = clipped_dist_p1 + clipped_p1_distmax = clipped_p1 + new_p0 = wp.where(any_both_in_front, subject_p0, clipped_p0_distmax) + new_p1 = wp.where(any_both_in_front, subject_p1, clipped_p1_distmax) + + mask_val = wp.int8( + wp.where( + wp.dot(subject_p0 - subject_p1, new_p0 - new_p1) < 0, + 0, + wp.int32(not any_both_in_front), + ) + ) + + p0[edge_idx] = new_p0 + p1[edge_idx] = new_p1 + mask[edge_idx] = mask_val + return p0, p1, mask + + +@wp.func +def _clip_quad( + subject_quad: mat43f, + subject_normal: wp.vec3, + clipping_quad: mat43f, + clipping_normal: wp.vec3, +): + """Clips a subject quad against a clipping quad. + Serial implementation. + """ + + subject_clipped_p0, subject_clipped_p1, subject_mask = _clip_edge_to_quad( + subject_quad, clipping_quad, clipping_normal + ) + clipping_proj = _project_poly_onto_plane( + clipping_quad, clipping_normal, subject_normal, subject_quad[0] + ) + clipping_clipped_p0, clipping_clipped_p1, clipping_mask = _clip_edge_to_quad( + clipping_proj, subject_quad, subject_normal + ) + + clipped = mat16_3f() + mask = vec16b() + for i in range(4): + clipped[i] = subject_clipped_p0[i] + clipped[i + 4] = clipping_clipped_p0[i] + clipped[i + 8] = subject_clipped_p1[i] + clipped[i + 12] = clipping_clipped_p1[i] + mask[i] = subject_mask[i] + mask[i + 4] = clipping_mask[i] + mask[i + 8] = subject_mask[i] + mask[i + 8 + 4] = clipping_mask[i] + + return clipped, mask + + +# TODO(ca): sparse, tiling variant for large point counts +@wp.func +def _manifold_points( + poly: Any, + mask: Any, + clipping_norm: wp.vec3, +) -> wp.vec4b: + """Chooses four points on the polygon with approximately maximal area. 
Return the indices""" + n = len(poly) + + a_idx = wp.int32(0) + a_mask = wp.int8(mask[0]) + for i in range(n): + if mask[i] >= a_mask: + a_idx = i + a_mask = mask[i] + a = poly[a_idx] + + b_idx = wp.int32(0) + b_dist = wp.float32(-HUGE_VAL) + for i in range(n): + dist = wp.length_sq(poly[i] - a) + wp.where(mask[i], 0.0, -HUGE_VAL) + if dist >= b_dist: + b_idx = i + b_dist = dist + b = poly[b_idx] + + ab = wp.cross(clipping_norm, a - b) + + c_idx = wp.int32(0) + c_dist = wp.float32(-HUGE_VAL) + for i in range(n): + ap = a - poly[i] + dist = wp.abs(wp.dot(ap, ab)) + wp.where(mask[i], 0.0, -HUGE_VAL) + if dist >= c_dist: + c_idx = i + c_dist = dist + c = poly[c_idx] + + ac = wp.cross(clipping_norm, a - c) + bc = wp.cross(clipping_norm, b - c) + + d_idx = wp.int32(0) + d_dist = wp.float32(-2.0 * HUGE_VAL) + for i in range(n): + ap = a - poly[i] + dist_ap = wp.abs(wp.dot(ap, ac)) + wp.where(mask[i], 0.0, -HUGE_VAL) + bp = b - poly[i] + dist_bp = wp.abs(wp.dot(bp, bc)) + wp.where(mask[i], 0.0, -HUGE_VAL) + if dist_ap + dist_bp >= d_dist: + d_idx = i + d_dist = dist_ap + dist_bp + d = poly[d_idx] + return wp.vec4b(wp.int8(a_idx), wp.int8(b_idx), wp.int8(c_idx), wp.int8(d_idx)) + + +@wp.func +def _create_contact_manifold( + clipping_quad: mat43f, + clipping_normal: wp.vec3, + subject_quad: mat43f, + subject_normal: wp.vec3, +): + # Clip the subject (incident) face onto the clipping (reference) face. + # The incident points are clipped points on the subject polygon. + incident, mask = _clip_quad( + subject_quad, subject_normal, clipping_quad, clipping_normal + ) + + clipping_normal_neg = -clipping_normal + d = wp.dot(clipping_quad[0], clipping_normal_neg) + TINY_VAL + + for i in range(16): + if wp.dot(incident[i], clipping_normal_neg) < d: + mask[i] = wp.int8(0) + + ref = _project_poly_onto_plane( + incident, clipping_normal, clipping_normal, clipping_quad[0] + ) + + # Choose four contact points. 
+ best = _manifold_points(ref, mask, clipping_normal) + contact_pts = mat43f() + dist = wp.vec4f() + + for i in range(4): + idx = wp.int32(best[i]) + contact_pt = ref[idx] + contact_pts[i] = contact_pt + penetration_dir = incident[idx] - contact_pt + penetration = wp.dot(penetration_dir, clipping_normal) + dist[i] = wp.where(mask[idx], penetration, 1.0) + + return dist, contact_pts + + +@wp.func +def plane_box( + plane: Geom, + box: Geom, + worldid: int, + d: Data, + margin: float, + geom_indices: wp.vec2i, +): + count = int(0) + corner = wp.vec3() + dist = wp.dot(box.pos - plane.pos, plane.normal) + + # test all corners, pick bottom 4 + for i in range(8): + # get corner in local coordinates + corner.x = wp.where(i & 1, box.size.x, -box.size.x) + corner.y = wp.where(i & 2, box.size.y, -box.size.y) + corner.z = wp.where(i & 4, box.size.z, -box.size.z) + + # get corner in global coordinates relative to box center + corner = box.rot * corner + + # compute distance to plane, skip if too far or pointing up + ldist = wp.dot(plane.normal, corner) + if dist + ldist > margin or ldist > 0: + continue + + cdist = dist + ldist + frame = make_frame(plane.normal) + pos = corner + box.pos + (plane.normal * cdist / -2.0) + write_contact(d, cdist, pos, frame, margin, geom_indices, worldid) + count += 1 + if count >= 4: + break + + +@wp.kernel +def _narrowphase( + m: Model, + d: Data, +): + tid = wp.tid() + + if tid >= d.ncollision[0]: + return + + geoms = d.collision_pair[tid] + worldid = d.collision_worldid[tid] + + g1 = geoms[0] + g2 = geoms[1] + type1 = m.geom_type[g1] + type2 = m.geom_type[g2] + + geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid]) + geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid]) + + margin = wp.max(m.geom_margin[g1], m.geom_margin[g2]) + + # TODO(team): static loop unrolling to remove unnecessary branching + if type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.SPHERE.value): + plane_sphere(geom1, geom2, worldid, d, 
margin, geoms) + elif type1 == int(GeomType.SPHERE.value) and type2 == int(GeomType.SPHERE.value): + sphere_sphere(geom1, geom2, worldid, d, margin, geoms) + elif type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.CAPSULE.value): + plane_capsule(geom1, geom2, worldid, d, margin, geoms) + elif type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.BOX.value): + plane_box(geom1, geom2, worldid, d, margin, geoms) + elif type1 == int(GeomType.CAPSULE.value) and type2 == int(GeomType.CAPSULE.value): + capsule_capsule(geom1, geom2, worldid, d, margin, geoms) + + +def narrowphase(m: Model, d: Data): + # we need to figure out how to keep the overhead of this small - not launching anything + # for pair types without collisions, as well as updating the launch dimensions. + wp.launch(_narrowphase, dim=d.nconmax, inputs=[m, d]) + box_box(m, d) diff --git a/mujoco/mjx/_src/constraint.py b/mujoco_warp/_src/constraint.py similarity index 86% rename from mujoco/mjx/_src/constraint.py rename to mujoco_warp/_src/constraint.py index 6cfd2c6b..8e783aa1 100644 --- a/mujoco/mjx/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,9 @@ # ============================================================================== import warp as wp + from . 
import types +from .warp_util import event_scope @wp.func @@ -53,22 +55,22 @@ def _update_efc_row( # See https://mujoco.readthedocs.io/en/latest/modeling.html#solver-parameters k = 1.0 / (dmax * dmax * timeconst * timeconst * dampratio * dampratio) b = 2.0 / (dmax * timeconst) - k = wp.select(solref[0] <= 0, k, -solref[0] / (dmax * dmax)) - b = wp.select(solref[1] <= 0, b, -solref[1] / dmax) + k = wp.where(solref[0] <= 0, -solref[0] / (dmax * dmax), k) + b = wp.where(solref[1] <= 0, -solref[1] / dmax, b) imp_x = wp.abs(pos_imp) / width imp_a = (1.0 / wp.pow(mid, power - 1.0)) * wp.pow(imp_x, power) imp_b = 1.0 - (1.0 / wp.pow(1.0 - mid, power - 1.0)) * wp.pow(1.0 - imp_x, power) - imp_y = wp.select(imp_x < mid, imp_b, imp_a) + imp_y = wp.where(imp_x < mid, imp_a, imp_b) imp = dmin + imp_y * (dmax - dmin) imp = wp.clamp(imp, dmin, dmax) - imp = wp.select(imp_x > 1.0, imp, dmax) + imp = wp.where(imp_x > 1.0, dmax, imp) # Update constraints - d.efc_D[efcid] = 1.0 / wp.max(invweight * (1.0 - imp) / imp, types.MJ_MINVAL) - d.efc_aref[efcid] = -k * imp * pos_aref - b * Jqvel - d.efc_pos[efcid] = pos_aref + margin - d.efc_margin[efcid] = margin + d.efc.D[efcid] = 1.0 / wp.max(invweight * (1.0 - imp) / imp, types.MJ_MINVAL) + d.efc.aref[efcid] = -k * imp * pos_aref - b * Jqvel + d.efc.pos[efcid] = pos_aref + margin + d.efc.margin[efcid] = margin @wp.func @@ -114,13 +116,13 @@ def _efc_limit_slide_hinge( active = pos < 0 if active: - efcid = wp.atomic_add(d.nefc_total, 0, 1) - d.efc_worldid[efcid] = worldid + efcid = wp.atomic_add(d.nefc, 0, 1) + d.efc.worldid[efcid] = worldid dofadr = m.jnt_dofadr[jntid] J = float(dist_min < dist_max) * 2.0 - 1.0 - d.efc_J[efcid, dofadr] = J + d.efc.J[efcid, dofadr] = J Jqvel = J * d.qvel[worldid, dofadr] _update_efc_row( @@ -147,7 +149,7 @@ def _efc_contact_pyramidal( ): conid, dimid = wp.tid() - if conid >= d.ncon_total[0]: + if conid >= d.ncon[0]: return if d.contact.dim[conid] != 3: @@ -157,9 +159,9 @@ def _efc_contact_pyramidal( 
active = pos < 0 if active: - efcid = wp.atomic_add(d.nefc_total, 0, 1) + efcid = wp.atomic_add(d.nefc, 0, 1) worldid = d.contact.worldid[conid] - d.efc_worldid[efcid] = worldid + d.efc.worldid[efcid] = worldid body1 = m.geom_bodyid[d.contact.geom[conid][0]] body2 = m.geom_bodyid[d.contact.geom[conid][1]] @@ -189,7 +191,7 @@ def _efc_contact_pyramidal( else: J = diff_0 - diff_i * d.contact.friction[conid][dimid2 - 1] - d.efc_J[efcid, i] = J + d.efc.J[efcid, i] = J Jqvel += J * d.qvel[worldid, i] _update_efc_row( @@ -208,12 +210,13 @@ def _efc_contact_pyramidal( ) +@event_scope def make_constraint(m: types.Model, d: types.Data): """Creates constraint jacobians and other supporting data.""" if not (m.opt.disableflags & types.DisableBit.CONSTRAINT.value): - d.nefc_total.zero_() - d.efc_J.zero_() + d.nefc.zero_() + d.efc.J.zero_() refsafe = not m.opt.disableflags & types.DisableBit.REFSAFE.value @@ -226,7 +229,10 @@ def make_constraint(m: types.Model, d: types.Data): inputs=[m, d, refsafe], ) - if m.opt.cone == types.ConeType.PYRAMIDAL.value: + if ( + not (m.opt.disableflags & types.DisableBit.CONTACT.value) + and m.opt.cone == types.ConeType.PYRAMIDAL.value + ): wp.launch( _efc_contact_pyramidal, dim=(d.nconmax, 4), diff --git a/mujoco/mjx/_src/constraint_test.py b/mujoco_warp/_src/constraint_test.py similarity index 71% rename from mujoco/mjx/_src/constraint_test.py rename to mujoco_warp/_src/constraint_test.py index 575bafbe..e36aa552 100644 --- a/mujoco/mjx/_src/constraint_test.py +++ b/mujoco_warp/_src/constraint_test.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,14 +15,16 @@ """Tests for constraint functions.""" +import mujoco +import numpy as np from absl.testing import absltest from absl.testing import parameterized + +import mujoco_warp as mjwarp + from . 
import test_util -import mujoco -from mujoco import mjx -import numpy as np -# tolerance for difference between MuJoCo and MJX constraint calculations, +# tolerance for difference between MuJoCo and MJWarp constraint calculations, # mostly due to float precision _TOLERANCE = 5e-5 @@ -43,15 +45,15 @@ def test_constraints(self): mujoco.mj_resetDataKeyframe(mjm, mjd, key) mujoco.mj_forward(mjm, mjd) - m = mjx.put_model(mjm) - d = mjx.put_data(mjm, mjd) - mjx.make_constraint(m, d) - - _assert_eq(d.efc_J.numpy()[: mjd.nefc, :].reshape(-1), mjd.efc_J, "efc_J") - _assert_eq(d.efc_D.numpy()[: mjd.nefc], mjd.efc_D, "efc_D") - _assert_eq(d.efc_aref.numpy()[: mjd.nefc], mjd.efc_aref, "efc_aref") - _assert_eq(d.efc_pos.numpy()[: mjd.nefc], mjd.efc_pos, "efc_pos") - _assert_eq(d.efc_margin.numpy()[: mjd.nefc], mjd.efc_margin, "efc_margin") + m = mjwarp.put_model(mjm) + d = mjwarp.put_data(mjm, mjd) + mjwarp.make_constraint(m, d) + + _assert_eq(d.efc.J.numpy()[: mjd.nefc, :].reshape(-1), mjd.efc_J, "efc_J") + _assert_eq(d.efc.D.numpy()[: mjd.nefc], mjd.efc_D, "efc_D") + _assert_eq(d.efc.aref.numpy()[: mjd.nefc], mjd.efc_aref, "efc_aref") + _assert_eq(d.efc.pos.numpy()[: mjd.nefc], mjd.efc_pos, "efc_pos") + _assert_eq(d.efc.margin.numpy()[: mjd.nefc], mjd.efc_margin, "efc_margin") if __name__ == "__main__": diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py new file mode 100644 index 00000000..f18aef0a --- /dev/null +++ b/mujoco_warp/_src/forward.py @@ -0,0 +1,695 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Optional + +import mujoco +import warp as wp + +from . import collision_driver +from . import constraint +from . import math +from . import passive +from . import smooth +from . import solver +from .support import xfrc_accumulate +from .types import MJ_MINVAL +from .types import BiasType +from .types import Data +from .types import DisableBit +from .types import DynType +from .types import GainType +from .types import JointType +from .types import Model +from .types import array2df +from .types import array3df +from .warp_util import event_scope +from .warp_util import kernel +from .warp_util import kernel_copy + + +def _advance( + m: Model, d: Data, act_dot: wp.array, qacc: wp.array, qvel: Optional[wp.array] = None +): + """Advance state and time given activation derivatives and acceleration.""" + + # TODO(team): can we assume static timesteps? 
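The `_advance` helper introduced here updates velocities with `qacc` first, then integrates positions using the freshly updated velocities (semi-implicit Euler). A minimal NumPy sketch of that ordering for scalar hinge/slide dofs, with no stateful actuators (names here are illustrative, not the mjWarp API):

```python
import numpy as np

def advance(qpos, qvel, qacc, timestep):
    # velocities are advanced first...
    qvel = qvel + qacc * timestep
    # ...then positions are integrated with the *updated* velocity,
    # which is what makes the scheme semi-implicit
    qpos = qpos + qvel * timestep
    return qpos, qvel

qpos, qvel = advance(np.zeros(3), np.ones(3), np.ones(3), 0.01)
```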
+ + @kernel + def next_activation( + m: Model, + d: Data, + act_dot_in: array2df, + ): + worldId, actid = wp.tid() + + # get the high/low range for each actuator state + limited = m.actuator_actlimited[actid] + range_low = wp.where(limited, m.actuator_actrange[actid][0], -wp.inf) + range_high = wp.where(limited, m.actuator_actrange[actid][1], wp.inf) + + # get the actual actuation - skip if -1 (means stateless actuator) + act_adr = m.actuator_actadr[actid] + if act_adr == -1: + return + + acts = d.act[worldId] + acts_dot = act_dot_in[worldId] + + act = acts[act_adr] + act_dot = acts_dot[act_adr] + + # check dynType + dyn_type = m.actuator_dyntype[actid] + dyn_prm = m.actuator_dynprm[actid][0] + + # advance the actuation + if dyn_type == wp.static(DynType.FILTEREXACT.value): + tau = wp.where(dyn_prm < MJ_MINVAL, MJ_MINVAL, dyn_prm) + act = act + act_dot * tau * (1.0 - wp.exp(-m.opt.timestep / tau)) + else: + act = act + act_dot * m.opt.timestep + + # apply limits (wp.clamp returns the clamped value, so assign it back) + act = wp.clamp(act, range_low, range_high) + + acts[act_adr] = act + + @kernel + def advance_velocities(m: Model, d: Data, qacc: array2df): + worldId, tid = wp.tid() + d.qvel[worldId, tid] = d.qvel[worldId, tid] + qacc[worldId, tid] * m.opt.timestep + + @kernel + def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df): + worldId, jntid = wp.tid() + + jnt_type = m.jnt_type[jntid] + qpos_adr = m.jnt_qposadr[jntid] + dof_adr = m.jnt_dofadr[jntid] + qpos = d.qpos[worldId] + qvel = qvel_in[worldId] + + if jnt_type == wp.static(JointType.FREE.value): + qpos_pos = wp.vec3(qpos[qpos_adr], qpos[qpos_adr + 1], qpos[qpos_adr + 2]) + qvel_lin = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2]) + + qpos_new = qpos_pos + m.opt.timestep * qvel_lin + + qpos_quat = wp.quat( + qpos[qpos_adr + 3], + qpos[qpos_adr + 4], + qpos[qpos_adr + 5], + qpos[qpos_adr + 6], + ) + qvel_ang = wp.vec3(qvel[dof_adr + 3], qvel[dof_adr + 4], qvel[dof_adr + 5]) + + qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang,
m.opt.timestep) + + qpos[qpos_adr] = qpos_new[0] + qpos[qpos_adr + 1] = qpos_new[1] + qpos[qpos_adr + 2] = qpos_new[2] + qpos[qpos_adr + 3] = qpos_quat_new[0] + qpos[qpos_adr + 4] = qpos_quat_new[1] + qpos[qpos_adr + 5] = qpos_quat_new[2] + qpos[qpos_adr + 6] = qpos_quat_new[3] + + elif jnt_type == wp.static(JointType.BALL.value): # ball joint + qpos_quat = wp.quat( + qpos[qpos_adr], + qpos[qpos_adr + 1], + qpos[qpos_adr + 2], + qpos[qpos_adr + 3], + ) + qvel_ang = wp.vec3(qvel[dof_adr], qvel[dof_adr + 1], qvel[dof_adr + 2]) + + qpos_quat_new = math.quat_integrate(qpos_quat, qvel_ang, m.opt.timestep) + + qpos[qpos_adr] = qpos_quat_new[0] + qpos[qpos_adr + 1] = qpos_quat_new[1] + qpos[qpos_adr + 2] = qpos_quat_new[2] + qpos[qpos_adr + 3] = qpos_quat_new[3] + + else: # if jnt_type in (JointType.HINGE, JointType.SLIDE): + qpos[qpos_adr] = qpos[qpos_adr] + m.opt.timestep * qvel[dof_adr] + + # skip if no stateful actuators. + if m.na: + wp.launch(next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot]) + + wp.launch(advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc]) + + # advance positions with qvel if given, d.qvel otherwise (semi-implicit) + if qvel is not None: + qvel_in = qvel + else: + qvel_in = d.qvel + + wp.launch(integrate_joint_positions, dim=(d.nworld, m.njnt), inputs=[m, d, qvel_in]) + + d.time = d.time + m.opt.timestep + + +@event_scope +def euler(m: Model, d: Data): + """Euler integrator, semi-implicit in velocity.""" + + # integrate damping implicitly + + def eulerdamp_sparse(m: Model, d: Data): + @kernel + def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): + worldId, tid = wp.tid() + + dof_Madr = m.dof_Madr[tid] + d.qM_integration[worldId, 0, dof_Madr] += m.opt.timestep * m.dof_damping[tid] + + d.qfrc_integration[worldId, tid] = ( + d.qfrc_smooth[worldId, tid] + d.qfrc_constraint[worldId, tid] + ) + + kernel_copy(d.qM_integration, d.qM) + wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) + 
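The sparse path above folds joint damping into the mass matrix before solving, i.e. it solves `(qM + h * diag(dof_damping)) @ qacc = qfrc_smooth + qfrc_constraint`. A dense NumPy sketch of that linear system (all values illustrative):

```python
import numpy as np

h = 0.002                                # m.opt.timestep
M = np.array([[2.0, 0.1], [0.1, 3.0]])  # joint-space inertia (qM)
dof_damping = np.array([1.0, 1.0])      # per-dof damping coefficients
qfrc = np.array([1.0, 1.0])             # qfrc_smooth + qfrc_constraint

# implicit damping: add h * diag(damping) to M before the solve
qacc = np.linalg.solve(M + h * np.diag(dof_damping), qfrc)
```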
smooth.factor_solve_i( + m, + d, + d.qM_integration, + d.qLD_integration, + d.qLDiagInv_integration, + d.qacc_integration, + d.qfrc_integration, + ) + + def eulerdamp_fused_dense(m: Model, d: Data): + def tile_eulerdamp(adr: int, size: int, tilesize: int): + @kernel + def eulerdamp( + m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int + ): + worldid, nodeid = wp.tid() + dofid = m.qLD_tile[leveladr + nodeid] + M_tile = wp.tile_load( + d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) + ) + damping_tile = wp.tile_load(damping, shape=(tilesize,), offset=(dofid,)) + damping_scaled = damping_tile * m.opt.timestep + qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) + + qfrc_smooth_tile = wp.tile_load( + d.qfrc_smooth[worldid], shape=(tilesize,), offset=(dofid,) + ) + qfrc_constraint_tile = wp.tile_load( + d.qfrc_constraint[worldid], shape=(tilesize,), offset=(dofid,) + ) + + qfrc_tile = qfrc_smooth_tile + qfrc_constraint_tile + + L_tile = wp.tile_cholesky(qm_integration_tile) + qacc_tile = wp.tile_cholesky_solve(L_tile, qfrc_tile) + wp.tile_store(d.qacc_integration[worldid], qacc_tile, offset=(dofid)) + + wp.launch_tiled( + eulerdamp, dim=(d.nworld, size), inputs=[m, d, m.dof_damping, adr], block_dim=32 + ) + + qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() + + for i in range(len(qLD_tileadr)): + beg = qLD_tileadr[i] + end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1] + tile_eulerdamp(beg, end - beg, int(qLD_tilesize[i])) + + if not m.opt.disableflags & DisableBit.EULERDAMP.value: + if m.opt.is_sparse: + eulerdamp_sparse(m, d) + else: + eulerdamp_fused_dense(m, d) + + _advance(m, d, d.act_dot, d.qacc_integration) + else: + _advance(m, d, d.act_dot, d.qacc) + + +@event_scope +def implicit(m: Model, d: Data): + """Integrates fully implicit in velocity.""" + + # optimization comments (AD) + # I went from small kernels for every step to a relatively big single + # kernel 
using tile API because it kept improving performance - + # 30M to 50M FPS on an A6000. + # + # The main benefit is reduced global memory roundtrips, but I assume + # there is also some benefit to loading data as early as possible. + # + # I further tried fusing in the cholesky factor/solve but the high + # storage requirements led to low occupancy and thus worse performance. + # + # The actuator_bias_gain_vel kernel could theoretically be fused in as well, + # but it's pretty clean straight-line code that loads a lot of data but + # only stores one array, so I think the benefit of keeping that one on-chip + # is likely not worth it compared to the compromises we're making with tile API. + # It would also need a different data layout for the biasprm/gainprm arrays + # to be tileable. + + # assumptions + assert not m.opt.is_sparse # unsupported + # TODO(team): add sparse version + + # compile-time constants + passive_enabled = not m.opt.disableflags & DisableBit.PASSIVE.value + actuation_enabled = ( + not m.opt.disableflags & DisableBit.ACTUATION.value + ) and m.actuator_affine_bias_gain + + @kernel + def actuator_bias_gain_vel(m: Model, d: Data): + worldid, actid = wp.tid() + + bias_vel = 0.0 + gain_vel = 0.0 + + actuator_biastype = m.actuator_biastype[actid] + actuator_gaintype = m.actuator_gaintype[actid] + actuator_dyntype = m.actuator_dyntype[actid] + + if actuator_biastype == wp.static(BiasType.AFFINE.value): + bias_vel = m.actuator_biasprm[actid, 2] + + if actuator_gaintype == wp.static(GainType.AFFINE.value): + gain_vel = m.actuator_gainprm[actid, 2] + + ctrl = d.ctrl[worldid, actid] + + if actuator_dyntype != wp.static(DynType.NONE.value): + ctrl = d.act[worldid, actid] + + d.act_vel_integration[worldid, actid] = bias_vel + gain_vel * ctrl + + def qderiv_actuator_damping_fused( + m: Model, d: Data, damping: wp.array(dtype=wp.float32) + ): + if actuation_enabled: + block_dim = 64 + else: + block_dim = 256 + + @wp.func + def subtract_multiply(x: wp.float32, y: 
wp.float32): + return x - y * wp.static(m.opt.timestep) + + def qderiv_actuator_damping_tiled( + adr: int, size: int, tilesize_nv: int, tilesize_nu: int + ): + @kernel + def qderiv_actuator_fused_kernel( + m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int + ): + worldid, nodeid = wp.tid() + offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] + + # skip tree with no actuators. + if wp.static(actuation_enabled and tilesize_nu != 0): + offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] + actuator_moment_tile = wp.tile_load( + d.actuator_moment[worldid], + shape=(tilesize_nu, tilesize_nv), + offset=(offset_nu, offset_nv), + ) + zeros = wp.tile_zeros(shape=(tilesize_nu, tilesize_nu), dtype=wp.float32) + vel_tile = wp.tile_load( + d.act_vel_integration[worldid], shape=(tilesize_nu), offset=offset_nu + ) + diag = wp.tile_diag_add(zeros, vel_tile) + actuator_moment_T = wp.tile_transpose(actuator_moment_tile) + amTVel = wp.tile_matmul(actuator_moment_T, diag) + qderiv_tile = wp.tile_matmul(amTVel, actuator_moment_tile) + else: + qderiv_tile = wp.tile_zeros( + shape=(tilesize_nv, tilesize_nv), dtype=wp.float32 + ) + + if wp.static(passive_enabled): + dof_damping = wp.tile_load(damping, shape=tilesize_nv, offset=offset_nv) + negative = wp.neg(dof_damping) + qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) + + # add to qM + qM_tile = wp.tile_load( + d.qM[worldid], shape=(tilesize_nv, tilesize_nv), offset=(offset_nv, offset_nv) + ) + qderiv_tile = wp.tile_map(subtract_multiply, qM_tile, qderiv_tile) + wp.tile_store( + d.qM_integration[worldid], qderiv_tile, offset=(offset_nv, offset_nv) + ) + + # sum qfrc + qfrc_smooth_tile = wp.tile_load( + d.qfrc_smooth[worldid], shape=tilesize_nv, offset=offset_nv + ) + qfrc_constraint_tile = wp.tile_load( + d.qfrc_constraint[worldid], shape=tilesize_nv, offset=offset_nv + ) + qfrc_combined = wp.add(qfrc_smooth_tile, qfrc_constraint_tile) + wp.tile_store(d.qfrc_integration[worldid], qfrc_combined, 
offset=offset_nv) + + wp.launch_tiled( + qderiv_actuator_fused_kernel, + dim=(d.nworld, size), + inputs=[m, d, damping, adr], + block_dim=block_dim, + ) + + qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() + qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() + qderiv_tileadr = m.actuator_moment_tileadr.numpy() + + for i in range(len(qderiv_tileadr)): + beg = qderiv_tileadr[i] + end = ( + m.qLD_tile.shape[0] if i == len(qderiv_tileadr) - 1 else qderiv_tileadr[i + 1] + ) + if qderiv_tilesize_nv[i] != 0: + qderiv_actuator_damping_tiled( + beg, end - beg, int(qderiv_tilesize_nv[i]), int(qderiv_tilesize_nu[i]) + ) + + if passive_enabled or actuation_enabled: + if actuation_enabled: + wp.launch( + actuator_bias_gain_vel, + dim=(d.nworld, m.nu), + inputs=[m, d], + ) + + qderiv_actuator_damping_fused(m, d, m.dof_damping) + + smooth._factor_solve_i_dense( + m, d, d.qM_integration, d.qacc_integration, d.qfrc_integration + ) + + _advance(m, d, d.act_dot, d.qacc_integration) + else: + _advance(m, d, d.act_dot, d.qacc) + + +@event_scope +def fwd_position(m: Model, d: Data): + """Position-dependent computations.""" + + smooth.kinematics(m, d) + smooth.com_pos(m, d) + # TODO(team): smooth.camlight + # TODO(team): smooth.tendon + smooth.crb(m, d) + smooth.factor_m(m, d) + collision_driver.collision(m, d) + constraint.make_constraint(m, d) + smooth.transmission(m, d) + + +@event_scope +def fwd_velocity(m: Model, d: Data): + """Velocity-dependent computations.""" + + if m.opt.is_sparse: + # TODO(team): sparse version + d.actuator_velocity.zero_() + + @kernel + def _actuator_velocity(d: Data): + worldid, actid, dofid = wp.tid() + moment = d.actuator_moment[worldid, actid] + qvel = d.qvel[worldid] + wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid]) + + wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d]) + else: + + def actuator_velocity( + adr: int, + size: int, + tilesize_nu: int, + tilesize_nv: int, + ): + @kernel + 
def _actuator_velocity( + m: Model, d: Data, leveladr: int, velocity: array3df, qvel: array3df + ): + worldid, nodeid = wp.tid() + offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] + offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] + actuator_moment_tile = wp.tile_load( + d.actuator_moment[worldid], + shape=(tilesize_nu, tilesize_nv), + offset=(offset_nu, offset_nv), + ) + qvel_tile = wp.tile_load( + qvel[worldid], shape=(tilesize_nv, 1), offset=(offset_nv, 0) + ) + velocity_tile = wp.tile_matmul(actuator_moment_tile, qvel_tile) + + wp.tile_store(velocity[worldid], velocity_tile, offset=(offset_nu, 0)) + + wp.launch_tiled( + _actuator_velocity, + dim=(d.nworld, size), + inputs=[ + m, + d, + adr, + d.actuator_velocity.reshape(d.actuator_velocity.shape + (1,)), + d.qvel.reshape(d.qvel.shape + (1,)), + ], + block_dim=32, + ) + + actuator_moment_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() + actuator_moment_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() + actuator_moment_tileadr = m.actuator_moment_tileadr.numpy() + + for i in range(len(actuator_moment_tileadr)): + beg = actuator_moment_tileadr[i] + end = ( + m.actuator_moment_tileadr.shape[0] + if i == len(actuator_moment_tileadr) - 1 + else actuator_moment_tileadr[i + 1] + ) + if actuator_moment_tilesize_nu[i] != 0 and actuator_moment_tilesize_nv[i] != 0: + actuator_velocity( + beg, + end - beg, + int(actuator_moment_tilesize_nu[i]), + int(actuator_moment_tilesize_nv[i]), + ) + + smooth.com_vel(m, d) + passive.passive(m, d) + smooth.rne(m, d) + + +@event_scope +def fwd_actuation(m: Model, d: Data): + """Actuation-dependent computations.""" + if not m.nu or m.opt.disableflags & DisableBit.ACTUATION: + d.act_dot.zero_() + d.qfrc_actuator.zero_() + return + + # TODO support stateful actuators + + @kernel + def _force( + m: Model, + d: Data, + # outputs + force: array2df, + ): + worldid, uid = wp.tid() + + actuator_length = d.actuator_length[worldid, uid] + actuator_velocity = 
d.actuator_velocity[worldid, uid] + + gain = m.actuator_gainprm[uid, 0] + gain += m.actuator_gainprm[uid, 1] * actuator_length + gain += m.actuator_gainprm[uid, 2] * actuator_velocity + + bias = m.actuator_biasprm[uid, 0] + bias += m.actuator_biasprm[uid, 1] * actuator_length + bias += m.actuator_biasprm[uid, 2] * actuator_velocity + + ctrl = d.ctrl[worldid, uid] + disable_clampctrl = m.opt.disableflags & wp.static(DisableBit.CLAMPCTRL.value) + if m.actuator_ctrllimited[uid] and not disable_clampctrl: + r = m.actuator_ctrlrange[uid] + ctrl = wp.clamp(ctrl, r[0], r[1]) + f = gain * ctrl + bias + if m.actuator_forcelimited[uid]: + r = m.actuator_forcerange[uid] + f = wp.clamp(f, r[0], r[1]) + force[worldid, uid] = f + + @kernel + def _qfrc_limited(m: Model, d: Data): + worldid, dofid = wp.tid() + jntid = m.dof_jntid[dofid] + if m.jnt_actfrclimited[jntid]: + d.qfrc_actuator[worldid, dofid] = wp.clamp( + d.qfrc_actuator[worldid, dofid], + m.jnt_actfrcrange[jntid][0], + m.jnt_actfrcrange[jntid][1], + ) + + if m.opt.is_sparse: + # TODO(team): sparse version + @kernel + def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): + worldid, vid = wp.tid() + + s = float(0.0) + for uid in range(m.nu): + # TODO consider using Tile API or transpose moment for better access pattern + s += moment[worldid, uid, vid] * force[worldid, uid] + jntid = m.dof_jntid[vid] + if m.jnt_actfrclimited[jntid]: + r = m.jnt_actfrcrange[jntid] + s = wp.clamp(s, r[0], r[1]) + qfrc[worldid, vid] = s + + wp.launch(_force, dim=[d.nworld, m.nu], inputs=[m, d], outputs=[d.actuator_force]) + + if m.opt.is_sparse: + # TODO(team): sparse version + + wp.launch( + _qfrc, + dim=(d.nworld, m.nv), + inputs=[m, d.actuator_moment, d.actuator_force], + outputs=[d.qfrc_actuator], + ) + + else: + + def qfrc_actuator(adr: int, size: int, tilesize_nu: int, tilesize_nv: int): + @kernel + def qfrc_actuator_kernel( + m: Model, + d: Data, + leveladr: int, + qfrc_actuator: array3df, + actuator_force: array3df, + ): + worldid, nodeid = 
wp.tid() + offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] + offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] + + actuator_moment_tile = wp.tile_load( + d.actuator_moment[worldid], + shape=(tilesize_nu, tilesize_nv), + offset=(offset_nu, offset_nv), + ) + actuator_moment_T_tile = wp.tile_transpose(actuator_moment_tile) + + force_tile = wp.tile_load( + actuator_force[worldid], shape=(tilesize_nu, 1), offset=(offset_nu, 0) + ) + qfrc_tile = wp.tile_matmul(actuator_moment_T_tile, force_tile) + wp.tile_store(qfrc_actuator[worldid], qfrc_tile, offset=(offset_nv, 0)) + + wp.launch_tiled( + qfrc_actuator_kernel, + dim=(d.nworld, size), + inputs=[ + m, + d, + adr, + d.qfrc_actuator.reshape(d.qfrc_actuator.shape + (1,)), + d.actuator_force.reshape(d.actuator_force.shape + (1,)), + ], + block_dim=32, + ) + + qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() + qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() + qderiv_tileadr = m.actuator_moment_tileadr.numpy() + + for i in range(len(qderiv_tileadr)): + beg = qderiv_tileadr[i] + end = ( + m.actuator_moment_tileadr.shape[0] + if i == len(qderiv_tileadr) - 1 + else qderiv_tileadr[i + 1] + ) + if qderiv_tilesize_nu[i] != 0 and qderiv_tilesize_nv[i] != 0: + qfrc_actuator( + beg, end - beg, int(qderiv_tilesize_nu[i]), int(qderiv_tilesize_nv[i]) + ) + + wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d]) + + # TODO actuator-level gravity compensation, skip if added as passive force + + +@event_scope +def fwd_acceleration(m: Model, d: Data): + """Add up all non-constraint forces, compute qacc_smooth.""" + + @kernel + def _qfrc_smooth(d: Data): + worldid, dofid = wp.tid() + d.qfrc_smooth[worldid, dofid] = ( + d.qfrc_passive[worldid, dofid] + - d.qfrc_bias[worldid, dofid] + + d.qfrc_actuator[worldid, dofid] + + d.qfrc_applied[worldid, dofid] + ) + + wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d]) + xfrc_accumulate(m, d, d.qfrc_smooth) + + smooth.solve_m(m, d, d.qacc_smooth, 
d.qfrc_smooth) + + +@event_scope +def forward(m: Model, d: Data): + """Forward dynamics.""" + + fwd_position(m, d) + # TODO(team): sensor.sensor_pos + fwd_velocity(m, d) + # TODO(team): sensor.sensor_vel + fwd_actuation(m, d) + fwd_acceleration(m, d) + # TODO(team): sensor.sensor_acc + + if d.njmax == 0: + kernel_copy(d.qacc, d.qacc_smooth) + else: + solver.solve(m, d) + + +@event_scope +def step(m: Model, d: Data): + """Advance simulation.""" + forward(m, d) + + if m.opt.integrator == mujoco.mjtIntegrator.mjINT_EULER: + euler(m, d) + elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_RK4: + # TODO(team): rungekutta4 + raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.") + elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST: + implicit(m, d) + else: + raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.") diff --git a/mujoco/mjx/_src/forward_test.py b/mujoco_warp/_src/forward_test.py similarity index 54% rename from mujoco/mjx/_src/forward_test.py rename to mujoco_warp/_src/forward_test.py index 5f5dce6b..63d85ca0 100644 --- a/mujoco/mjx/_src/forward_test.py +++ b/mujoco_warp/_src/forward_test.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,19 +15,18 @@ """Tests for forward dynamics functions.""" -from absl.testing import absltest -from etils import epath +import mujoco import numpy as np import warp as wp +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath -wp.config.verify_cuda = True - -import mujoco -from mujoco import mjx +import mujoco_warp as mjwarp -from .types import DisableBit +wp.config.verify_cuda = True -# tolerance for difference between MuJoCo and MJX smooth calculations - mostly +# tolerance for difference between MuJoCo and mjwarp smooth calculations - mostly # due to float precision _TOLERANCE = 5e-5 @@ -40,25 +39,23 @@ def _assert_eq(a, b, name): class ForwardTest(absltest.TestCase): def _load(self, fname: str, is_sparse: bool = True): - path = epath.resource_path("mujoco.mjx") / "test_data" / fname + path = epath.resource_path("mujoco_warp") / "test_data" / fname mjm = mujoco.MjModel.from_xml_path(path.as_posix()) mjm.opt.jacobian = is_sparse mjd = mujoco.MjData(mjm) mujoco.mj_resetDataKeyframe(mjm, mjd, 1) # reset to stand_on_left_leg mjd.qvel = np.random.uniform(low=-0.01, high=0.01, size=mjd.qvel.shape) - mjd.ctrl = np.random.normal(scale=10, size=mjd.ctrl.shape) - mjd.act = np.random.normal(scale=10, size=mjd.act.shape) + mjd.ctrl = np.random.normal(scale=1, size=mjd.ctrl.shape) mujoco.mj_forward(mjm, mjd) - m = mjx.put_model(mjm) - d = mjx.put_data(mjm, mjd) + m = mjwarp.put_model(mjm) + d = mjwarp.put_data(mjm, mjd) return mjm, mjd, m, d def test_fwd_velocity(self): - """Tests MJX fwd_velocity.""" - _, mjd, m, d = self._load("humanoid/humanoid.xml") + _, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False) d.actuator_velocity.zero_() - mjx.fwd_velocity(m, d) + mjwarp.fwd_velocity(m, d) _assert_eq( d.actuator_velocity.numpy()[0], mjd.actuator_velocity, "actuator_velocity" @@ -66,7 +63,6 @@ def test_fwd_velocity(self): _assert_eq(d.qfrc_bias.numpy()[0], mjd.qfrc_bias, "qfrc_bias") def test_fwd_actuation(self): - 
"""Tests MJX fwd_actuation.""" mjm, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False) mujoco.mj_fwdActuation(mjm, mjd) @@ -74,27 +70,30 @@ def test_fwd_actuation(self): for arr in (d.actuator_force, d.qfrc_actuator): arr.zero_() - mjx.fwd_actuation(m, d) + mjwarp.fwd_actuation(m, d) _assert_eq(d.ctrl.numpy()[0], mjd.ctrl, "ctrl") _assert_eq(d.actuator_force.numpy()[0], mjd.actuator_force, "actuator_force") _assert_eq(d.qfrc_actuator.numpy()[0], mjd.qfrc_actuator, "qfrc_actuator") + # TODO(team): test DisableBit.CLAMPCTRL + # TODO(team): test DisableBit.ACTUATION + # TODO(team): test actuator gain/bias (e.g. position control) + def test_fwd_acceleration(self): - """Tests MJX fwd_acceleration.""" _, mjd, m, d = self._load("humanoid/humanoid.xml", is_sparse=False) for arr in (d.qfrc_smooth, d.qacc_smooth): arr.zero_() - mjx.factor_m(m, d) # for dense, get tile cholesky factorization - mjx.fwd_acceleration(m, d) + mjwarp.factor_m(m, d) # for dense, get tile cholesky factorization + mjwarp.fwd_acceleration(m, d) _assert_eq(d.qfrc_smooth.numpy()[0], mjd.qfrc_smooth, "qfrc_smooth") _assert_eq(d.qacc_smooth.numpy()[0], mjd.qacc_smooth, "qacc_smooth") def test_eulerdamp(self): - path = epath.resource_path("mujoco.mjx") / "test_data/pendula.xml" + path = epath.resource_path("mujoco_warp") / "test_data/pendula.xml" mjm = mujoco.MjModel.from_xml_path(path.as_posix()) self.assertTrue((mjm.dof_damping > 0).any()) @@ -103,10 +102,10 @@ def test_eulerdamp(self): mjd.qacc[:] = 1.0 mujoco.mj_forward(mjm, mjd) - m = mjx.put_model(mjm) - d = mjx.put_data(mjm, mjd) + m = mjwarp.put_model(mjm) + d = mjwarp.put_data(mjm, mjd) - mjx.euler(m, d) + mjwarp.euler(m, d) mujoco.mj_Euler(mjm, mjd) _assert_eq(d.qpos.numpy()[0], mjd.qpos, "qpos") @@ -119,33 +118,84 @@ def test_eulerdamp(self): mjd.qacc[:] = 1.0 mujoco.mj_forward(mjm, mjd) - m = mjx.put_model(mjm) - d = mjx.put_data(mjm, mjd) + m = mjwarp.put_model(mjm) + d = mjwarp.put_data(mjm, mjd) - mjx.euler(m, d) + 
mjwarp.euler(m, d) mujoco.mj_Euler(mjm, mjd) _assert_eq(d.qpos.numpy()[0], mjd.qpos, "qpos") _assert_eq(d.act.numpy()[0], mjd.act, "act") def test_disable_eulerdamp(self): - path = epath.resource_path("mujoco.mjx") / "test_data/pendula.xml" + path = epath.resource_path("mujoco_warp") / "test_data/pendula.xml" mjm = mujoco.MjModel.from_xml_path(path.as_posix()) - mjm.opt.disableflags = mjm.opt.disableflags | DisableBit.EULERDAMP.value + mjm.opt.disableflags = mjm.opt.disableflags | mujoco.mjtDisableBit.mjDSBL_EULERDAMP mjd = mujoco.MjData(mjm) mujoco.mj_forward(mjm, mjd) mjd.qvel[:] = 1.0 mjd.qacc[:] = 1.0 - m = mjx.put_model(mjm) - d = mjx.put_data(mjm, mjd) + m = mjwarp.put_model(mjm) + d = mjwarp.put_data(mjm, mjd) - mjx.euler(m, d) + mjwarp.euler(m, d) np.testing.assert_allclose(d.qvel.numpy()[0], 1 + mjm.opt.timestep) +class ImplicitIntegratorTest(parameterized.TestCase): + def _load(self, fname: str, disableFlags: int): + path = epath.resource_path("mujoco_warp") / "test_data" / fname + mjm = mujoco.MjModel.from_xml_path(path.as_posix()) + mjm.opt.jacobian = 0 + mjm.opt.integrator = mujoco.mjtIntegrator.mjINT_IMPLICITFAST + mjm.opt.disableflags = mjm.opt.disableflags | disableFlags + mjm.actuator_gainprm[:, 2] = np.random.normal( + scale=10, size=mjm.actuator_gainprm[:, 2].shape + ) + + # change actuators to velocity/damper to cover all codepaths + mjm.actuator_gaintype[3] = mujoco.mjtGain.mjGAIN_AFFINE + mjm.actuator_gaintype[6] = mujoco.mjtGain.mjGAIN_AFFINE + mjm.actuator_biastype[0:3] = mujoco.mjtBias.mjBIAS_AFFINE + mjm.actuator_biastype[4:6] = mujoco.mjtBias.mjBIAS_AFFINE + mjm.actuator_biasprm[0:3, 2] = -1 + mjm.actuator_biasprm[4:6, 2] = -1 + mjm.actuator_ctrlrange[3:7] = 10.0 + mjm.actuator_gear[:] = 1.0 + + mjd = mujoco.MjData(mjm) + + mjd.qvel = np.random.uniform(low=-0.01, high=0.01, size=mjd.qvel.shape) + mjd.ctrl = np.random.normal(scale=10, size=mjd.ctrl.shape) + mjd.act = np.random.normal(scale=10, size=mjd.act.shape) + mujoco.mj_forward(mjm, 
mjd) + + mjd.ctrl = np.random.normal(scale=10, size=mjd.ctrl.shape) + mjd.act = np.random.normal(scale=10, size=mjd.act.shape) + m = mjwarp.put_model(mjm) + d = mjwarp.put_data(mjm, mjd) + return mjm, mjd, m, d + + @parameterized.parameters( + 0, + mjwarp.DisableBit.PASSIVE.value, + mjwarp.DisableBit.ACTUATION.value, + mjwarp.DisableBit.PASSIVE.value | mjwarp.DisableBit.ACTUATION.value, + ) + def test_implicit(self, disableFlags): + np.random.seed(0) + mjm, mjd, m, d = self._load("pendula.xml", disableFlags) + + mjwarp.implicit(m, d) + mujoco.mj_implicit(mjm, mjd) + + _assert_eq(d.qpos.numpy()[0], mjd.qpos, "qpos") + _assert_eq(d.act.numpy()[0], mjd.act, "act") + + if __name__ == "__main__": + wp.init() + absltest.main() diff --git a/mujoco/mjx/_src/io.py b/mujoco_warp/_src/io.py similarity index 59% rename from mujoco/mjx/_src/io.py rename to mujoco_warp/_src/io.py index 268e9fc5..c71c2600 100644 --- a/mujoco/mjx/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,9 +13,9 @@ # limitations under the License. # ============================================================================== -import warp as wp import mujoco import numpy as np +import warp as wp from . import support from . 
import types @@ -33,6 +33,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.nsite = mjm.nsite m.nmocap = mjm.nmocap m.nM = mjm.nM + m.nexclude = mjm.nexclude m.opt.timestep = mjm.opt.timestep m.opt.tolerance = mjm.opt.tolerance m.opt.ls_tolerance = mjm.opt.ls_tolerance @@ -53,7 +54,18 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: # dof lower triangle row and column indices dof_tri_row, dof_tri_col = np.tril_indices(mjm.nv) - # indices for sparse qM + # indices for sparse qM full_m + is_, js = [], [] + for i in range(mjm.nv): + j = i + while j > -1: + is_.append(i) + js.append(j) + j = mjm.dof_parentid[j] + qM_fullm_i = is_ + qM_fullm_j = js + + # indices for sparse qM mul_m is_, js, madr_ijs = [], [], [] for i in range(mjm.nv): madr_ij, j = mjm.dof_Madr[i], i @@ -64,7 +76,9 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: break is_, js, madr_ijs = is_ + [i], js + [j], madr_ijs + [madr_ij] - qM_i, qM_j, qM_madr_ij = (np.array(x, dtype=np.int32) for x in (is_, js, madr_ijs)) + qM_mulm_i, qM_mulm_j, qM_madr_ij = ( + np.array(x, dtype=np.int32) for x in (is_, js, madr_ijs) + ) jnt_limited_slide_hinge_adr = np.nonzero( mjm.jnt_limited @@ -125,8 +139,52 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: qLD_tileadr = np.cumsum(tile_off)[:-1] qLD_tilesize = np.array(sorted(tiles.keys())) - m.qM_i = wp.array(qM_i, dtype=wp.int32, ndim=1) - m.qM_j = wp.array(qM_j, dtype=wp.int32, ndim=1) + # tiles for actuator_moment - needs nu + nv tile size and offset + actuator_moment_offset_nv = np.empty(shape=(0,), dtype=int) + actuator_moment_offset_nu = np.empty(shape=(0,), dtype=int) + actuator_moment_tileadr = np.empty(shape=(0,), dtype=int) + actuator_moment_tilesize_nv = np.empty(shape=(0,), dtype=int) + actuator_moment_tilesize_nu = np.empty(shape=(0,), dtype=int) + + if not support.is_sparse(mjm): + # how many actuators for each tree + tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] + tree_id = mjm.dof_treeid[tile_corners] + 
num_trees = int(np.max(tree_id)) + tree = mjm.body_treeid[mjm.jnt_bodyid[mjm.actuator_trnid[:, 0]]] + counts, ids = np.histogram(tree, bins=np.arange(0, num_trees + 2)) + acts_per_tree = dict(zip([int(i) for i in ids], [int(i) for i in counts])) + + tiles = {} + act_beg = 0 + for i in range(len(tile_corners)): + tile_beg = tile_corners[i] + tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] + tree = int(tree_id[i]) + act_num = acts_per_tree[tree] + tiles.setdefault((tile_end - tile_beg, act_num), []).append((tile_beg, act_beg)) + act_beg += act_num + + sorted_keys = sorted(tiles.keys()) + actuator_moment_offset_nv = [ + t[0] for key in sorted_keys for t in tiles.get(key, []) + ] + actuator_moment_offset_nu = [ + t[1] for key in sorted_keys for t in tiles.get(key, []) + ] + tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] + actuator_moment_tileadr = np.cumsum(tile_off)[:-1] # offset + actuator_moment_tilesize_nv = np.array( + [a[0] for a in sorted_keys] + ) # for this level + actuator_moment_tilesize_nu = np.array( + [int(a[1]) for a in sorted_keys] + ) # for this level + + m.qM_fullm_i = wp.array(qM_fullm_i, dtype=wp.int32, ndim=1) + m.qM_fullm_j = wp.array(qM_fullm_j, dtype=wp.int32, ndim=1) + m.qM_mulm_i = wp.array(qM_mulm_i, dtype=wp.int32, ndim=1) + m.qM_mulm_j = wp.array(qM_mulm_j, dtype=wp.int32, ndim=1) m.qM_madr_ij = wp.array(qM_madr_ij, dtype=wp.int32, ndim=1) m.qLD_update_tree = wp.array(qLD_update_tree, dtype=wp.vec3i, ndim=1) m.qLD_update_treeadr = wp.array( @@ -135,12 +193,28 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1) m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu") m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu") + m.actuator_moment_offset_nv = wp.array( + actuator_moment_offset_nv, dtype=wp.int32, ndim=1 + ) + m.actuator_moment_offset_nu = wp.array( + actuator_moment_offset_nu, 
dtype=wp.int32, ndim=1 + ) + m.actuator_moment_tileadr = wp.array( + actuator_moment_tileadr, dtype=wp.int32, ndim=1, device="cpu" + ) + m.actuator_moment_tilesize_nv = wp.array( + actuator_moment_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu" + ) + m.actuator_moment_tilesize_nu = wp.array( + actuator_moment_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu" + ) m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1) m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1) m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32, ndim=1) m.body_jntnum = wp.array(mjm.body_jntnum, dtype=wp.int32, ndim=1) m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32, ndim=1) m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1) + m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1) m.body_pos = wp.array(mjm.body_pos, dtype=wp.vec3, ndim=1) m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1) m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1) @@ -148,7 +222,18 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) m.body_inertia = wp.array(mjm.body_inertia, dtype=wp.vec3, ndim=1) m.body_mass = wp.array(mjm.body_mass, dtype=wp.float32, ndim=1) + + subtree_mass = np.copy(mjm.body_mass) + # TODO(team): should this be [mjm.nbody - 1, 0) ? 
+ for i in range(mjm.nbody - 1, -1, -1): + subtree_mass[mjm.body_parentid[i]] += subtree_mass[i] + + m.subtree_mass = wp.array(subtree_mass, dtype=wp.float32, ndim=1) m.body_invweight0 = wp.array(mjm.body_invweight0, dtype=wp.float32, ndim=2) + m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) + m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) + m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) + m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1) m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1) m.jnt_limited = wp.array(mjm.jnt_limited, dtype=wp.int32, ndim=1) m.jnt_limited_slide_hinge_adr = wp.array( @@ -166,11 +251,30 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1) m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1) m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1) + m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1) m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1) + m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1) + m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1) + m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1) m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1) m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1) + m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1) + m.geom_priority = wp.array(mjm.geom_priority, dtype=wp.int32, ndim=1) + m.geom_solmix = wp.array(mjm.geom_solmix, dtype=wp.float32, ndim=1) + m.geom_solref = wp.array(mjm.geom_solref, dtype=wp.vec2, ndim=1) + m.geom_solimp = wp.array(mjm.geom_solimp, dtype=types.vec5, ndim=1) + m.geom_friction = wp.array(mjm.geom_friction, dtype=wp.vec3, ndim=1) + m.geom_margin = wp.array(mjm.geom_margin, dtype=wp.float32, ndim=1) + m.geom_gap = wp.array(mjm.geom_gap, 
dtype=wp.float32, ndim=1) + m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) + m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) + m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1) + m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1) + m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1) + m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3, ndim=1) m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1) m.site_quat = wp.array(mjm.site_quat, dtype=wp.quat, ndim=1) + m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1) m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32, ndim=1) m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32, ndim=1) m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1) @@ -186,7 +290,9 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1) m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1) m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1) + m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1) m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=wp.float32, ndim=2) + m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1) m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=wp.float32, ndim=2) m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1) m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1) @@ -194,17 +300,75 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1) m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1) m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1) + m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1) + + # 
short-circuiting here allows us to skip a lot of code in implicit integration + m.actuator_affine_bias_gain = bool( + np.any(mjm.actuator_biastype == types.BiasType.AFFINE.value) + or np.any(mjm.actuator_gaintype == types.GainType.AFFINE.value) + ) return m +def _constraint(nv: int, nworld: int, njmax: int) -> types.Constraint: + efc = types.Constraint() + + efc.J = wp.zeros((njmax, nv), dtype=wp.float32) + efc.D = wp.zeros((njmax,), dtype=wp.float32) + efc.pos = wp.zeros((njmax,), dtype=wp.float32) + efc.aref = wp.zeros((njmax,), dtype=wp.float32) + efc.force = wp.zeros((njmax,), dtype=wp.float32) + efc.margin = wp.zeros((njmax,), dtype=wp.float32) + efc.worldid = wp.zeros((njmax,), dtype=wp.int32) + + efc.Jaref = wp.empty(shape=(njmax,), dtype=wp.float32) + efc.Ma = wp.empty(shape=(nworld, nv), dtype=wp.float32) + efc.grad = wp.empty(shape=(nworld, nv), dtype=wp.float32) + efc.grad_dot = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.Mgrad = wp.empty(shape=(nworld, nv), dtype=wp.float32) + efc.search = wp.empty(shape=(nworld, nv), dtype=wp.float32) + efc.search_dot = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.gauss = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.cost = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.prev_cost = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.solver_niter = wp.empty(shape=(nworld,), dtype=wp.int32) + efc.active = wp.empty(shape=(njmax,), dtype=wp.int32) + efc.gtol = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.mv = wp.empty(shape=(nworld, nv), dtype=wp.float32) + efc.jv = wp.empty(shape=(njmax,), dtype=wp.float32) + efc.quad = wp.empty(shape=(njmax,), dtype=wp.vec3f) + efc.quad_gauss = wp.empty(shape=(nworld,), dtype=wp.vec3f) + efc.h = wp.empty(shape=(nworld, nv, nv), dtype=wp.float32) + efc.alpha = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.prev_grad = wp.empty(shape=(nworld, nv), dtype=wp.float32) + efc.prev_Mgrad = wp.empty(shape=(nworld, nv), dtype=wp.float32) + efc.beta = 
wp.empty(shape=(nworld,), dtype=wp.float32) + efc.beta_num = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.beta_den = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.done = wp.empty(shape=(nworld,), dtype=bool) + + efc.ls_done = wp.zeros(shape=(nworld,), dtype=bool) + efc.p0 = wp.empty(shape=(nworld,), dtype=wp.vec3) + efc.lo = wp.empty(shape=(nworld,), dtype=wp.vec3) + efc.lo_alpha = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.hi = wp.empty(shape=(nworld,), dtype=wp.vec3) + efc.hi_alpha = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.lo_next = wp.empty(shape=(nworld,), dtype=wp.vec3) + efc.lo_next_alpha = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.hi_next = wp.empty(shape=(nworld,), dtype=wp.vec3) + efc.hi_next_alpha = wp.empty(shape=(nworld,), dtype=wp.float32) + efc.mid = wp.empty(shape=(nworld,), dtype=wp.vec3) + efc.mid_alpha = wp.empty(shape=(nworld,), dtype=wp.float32) + + return efc + + def make_data( mjm: mujoco.MjModel, nworld: int = 1, nconmax: int = -1, njmax: int = -1 ) -> types.Data: d = types.Data() d.nworld = nworld - d.ncon_total = wp.zeros((1,), dtype=wp.int32, ndim=1) - d.nefc_total = wp.zeros((1,), dtype=wp.int32, ndim=1) # TODO(team): move to Model? 
if nconmax == -1: @@ -216,8 +380,8 @@ def make_data( njmax = 512 d.njmax = njmax - d.ncon = 0 - d.nefc = wp.zeros(nworld, dtype=wp.int32) + d.ncon = wp.zeros(1, dtype=wp.int32) + d.nefc = wp.zeros(1, dtype=wp.int32, ndim=1) d.nl = 0 d.time = 0.0 @@ -274,6 +438,7 @@ def make_data( d.contact.geom = wp.zeros((nconmax,), dtype=wp.vec2i) d.contact.efc_address = wp.zeros((nconmax,), dtype=wp.int32) d.contact.worldid = wp.zeros((nconmax,), dtype=wp.int32) + d.efc = _constraint(mjm.nv, d.nworld, d.njmax) d.qfrc_passive = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qfrc_spring = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qfrc_damper = wp.zeros((nworld, mjm.nv), dtype=wp.float32) @@ -281,44 +446,34 @@ def make_data( d.qfrc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qfrc_constraint = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qacc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.efc_J = wp.zeros((njmax, mjm.nv), dtype=wp.float32) - d.efc_D = wp.zeros((njmax,), dtype=wp.float32) - d.efc_pos = wp.zeros((njmax,), dtype=wp.float32) - d.efc_aref = wp.zeros((njmax,), dtype=wp.float32) - d.efc_force = wp.zeros((njmax,), dtype=wp.float32) - d.efc_margin = wp.zeros((njmax,), dtype=wp.float32) - d.efc_worldid = wp.zeros((njmax,), dtype=wp.int32) + + d.rne_cacc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector) + d.rne_cfrc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector) d.xfrc_applied = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) + # internal tmp arrays d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qM_integration = wp.zeros_like(d.qM) d.qLD_integration = wp.zeros_like(d.qLD) d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) - - # the result of the broadphase gets stored in this array - d.max_num_overlaps_per_world = ( - mjm.ngeom * (mjm.ngeom - 1) // 2 - ) # TODO: this is a hack to estimate the maximum number of overlaps 
per world - d.broadphase_pairs = wp.zeros((nworld, d.max_num_overlaps_per_world), dtype=wp.vec2i) - d.result_count = wp.zeros(nworld, dtype=wp.int32) - - # internal broadphase tmp arrays - d.boxes_sorted = wp.zeros( - (nworld, mjm.ngeom), dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32) - ) - d.data_start = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) - d.data_end = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) - d.data_indexer = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) - d.ranges = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) - d.cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) + d.act_vel_integration = wp.zeros_like(d.ctrl) + + # sweep-and-prune broadphase + d.sap_geom_sort = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec4) + d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) + d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) + d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) + d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) + d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] - d.segment_indices = wp.array(segment_indices_list, dtype=int) + d.sap_segment_index = wp.array(segment_indices_list, dtype=int) - d.geom_aabb = wp.array( - mjm.geom_aabb, dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1 - ) + # collision driver + d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) + d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) + d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) return d @@ -332,25 +487,23 @@ def put_data( ) -> types.Data: d = types.Data() d.nworld = nworld - d.ncon_total = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1) - d.nefc_total = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1) # TODO(team): move to Model? 
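`put_data` broadcasts each field of a single `mujoco.MjData` across `nworld` worlds by adding a leading batch dimension. A minimal NumPy sketch of that tiling (this `tile` helper is a hypothetical stand-in for illustration, not the one defined in this file):

```python
import numpy as np

def tile(x: np.ndarray, nworld: int) -> np.ndarray:
    # Repeat a single-world array along a new leading world dimension,
    # so shape (n, ...) becomes (nworld, n, ...).
    return np.tile(x, (nworld,) + (1,) * x.ndim)

qpos = np.array([0.1, 0.2, 0.3])
batched = tile(qpos, 4)  # shape (4, 3); every world starts identical
```

Since all worlds begin from the same `MjData`, the batched arrays are identical along the world axis until the simulation diverges them.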
if nconmax == -1: # TODO(team): heuristic for nconmax - nconmax = 512 + nconmax = max(512, mjd.ncon * nworld) d.nconmax = nconmax if njmax == -1: # TODO(team): heuristic for njmax - njmax = 512 + njmax = max(512, mjd.nefc * nworld) d.njmax = njmax if nworld * mjd.nefc > njmax: raise ValueError("nworld * nefc > njmax") - d.ncon = mjd.ncon + d.ncon = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1) d.nl = mjd.nl - d.nefc = wp.zeros(1, dtype=wp.int32) + d.nefc = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1) d.time = mjd.time # TODO(erikfrey): would it be better to tile on the gpu? @@ -360,8 +513,10 @@ def tile(x): if support.is_sparse(mjm): qM = np.expand_dims(mjd.qM, axis=0) qLD = np.expand_dims(mjd.qLD, axis=0) - # TODO(taylorhowell): sparse efc_J efc_J = np.zeros((mjd.nefc, mjm.nv)) + mujoco.mju_sparse2dense( + efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind + ) else: qM = np.zeros((mjm.nv, mjm.nv)) mujoco.mj_fullM(mjm, qM, mjd.qM) @@ -418,7 +573,8 @@ def tile(x): d.qfrc_smooth = wp.array(tile(mjd.qfrc_smooth), dtype=wp.float32, ndim=2) d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2) d.qacc_smooth = wp.array(tile(mjd.qacc_smooth), dtype=wp.float32, ndim=2) - d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2) + d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2) + d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2) nefc = mjd.nefc efc_worldid = np.zeros(njmax, dtype=int) @@ -447,17 +603,6 @@ def tile(x): [np.repeat(mjd.efc_margin, nworld, axis=0), np.zeros(nefc_fill)] ) - d.efc_J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2) - d.efc_D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1) - d.efc_pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1) - d.efc_aref = wp.array(efc_aref_fill, dtype=wp.float32, ndim=1) - d.efc_force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1) - d.efc_margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1) - 
d.efc_worldid = wp.from_numpy(efc_worldid, dtype=wp.int32) - - d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2) - d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2) - ncon = mjd.ncon con_efc_address = np.zeros(nconmax, dtype=int) con_worldid = np.zeros(nconmax, dtype=int) @@ -512,33 +657,147 @@ def tile(x): d.contact.efc_address = wp.array(con_efc_address, dtype=wp.int32, ndim=1) d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1) + d.rne_cacc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector) + d.rne_cfrc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector) + + d.efc = _constraint(mjm.nv, d.nworld, d.njmax) + d.efc.J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2) + d.efc.D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1) + d.efc.pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1) + d.efc.aref = wp.array(efc_aref_fill, dtype=wp.float32, ndim=1) + d.efc.force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1) + d.efc.margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1) + d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32) + d.xfrc_applied = wp.array(tile(mjd.xfrc_applied), dtype=wp.spatial_vector, ndim=2) + # internal tmp arrays d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qM_integration = wp.zeros_like(d.qM) d.qLD_integration = wp.zeros_like(d.qLD) d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) + d.act_vel_integration = wp.zeros_like(d.ctrl) + + # broadphase sweep and prune + d.sap_geom_sort = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec4) + d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) + d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) + d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) + d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) + d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) 
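The `sap_*` buffers above back a sweep-and-prune broadphase: each geom's AABB is projected onto a sweep axis, geoms are sorted by projected lower bound, and only geoms whose projected intervals overlap become candidate pairs. A toy single-axis version of the idea (plain NumPy, not the Warp kernels themselves):

```python
import numpy as np

# Each row is (lower, upper) of a geom's AABB projected onto the sweep axis.
bounds = np.array([[0.0, 2.0], [1.5, 3.0], [5.0, 6.0], [2.8, 5.2]])

order = np.argsort(bounds[:, 0])  # sort geoms by projected lower bound
pairs = []
for k, i in enumerate(order):
    for j in order[k + 1:]:
        if bounds[j, 0] > bounds[i, 1]:
            break  # sorted by lower bound: no later geom can overlap i
        pairs.append(tuple(sorted((int(i), int(j)))))
# pairs -> [(0, 1), (1, 3), (2, 3)]
```

The early `break` is what makes the sweep cheap: after sorting, each geom only compares against the short run of neighbors whose intervals can still overlap it, rather than all other geoms.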
+ segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] + d.sap_segment_index = wp.array(segment_indices_list, dtype=int) - # the result of the broadphase gets stored in this array - d.max_num_overlaps_per_world = mjm.ngeom * (mjm.ngeom - 1) // 2 - d.broadphase_pairs = wp.zeros((nworld, d.max_num_overlaps_per_world), dtype=wp.vec2i) - d.result_count = wp.zeros(nworld, dtype=wp.int32) + # collision driver + d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) + d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) + d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) + + return d - # internal broadphase tmp arrays - d.boxes_sorted = wp.zeros( - (nworld, mjm.ngeom), dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32) - ) - d.data_start = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) - d.data_end = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) - d.data_indexer = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) - d.ranges = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) - d.cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) - segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] - d.segment_indices = wp.array(segment_indices_list, dtype=int) - d.geom_aabb = wp.array( - mjm.geom_aabb, dtype=wp.types.matrix(shape=(2, 3), dtype=wp.float32), ndim=1 +def get_data_into( + result: mujoco.MjData, + mjm: mujoco.MjModel, + d: types.Data, +): + """Gets Data from a device into an existing mujoco.MjData.""" + if d.nworld > 1: + raise NotImplementedError("only nworld == 1 supported for now") + + ncon = d.ncon.numpy()[0] + nefc = d.nefc.numpy()[0] + + if ncon != result.ncon or nefc != result.nefc: + mujoco._functions._realloc_con_efc(result, ncon=ncon, nefc=nefc) + + result.time = d.time + + result.qpos[:] = d.qpos.numpy()[0] + result.qvel[:] = d.qvel.numpy()[0] + result.qacc_warmstart = d.qacc_warmstart.numpy()[0] + result.qfrc_applied = d.qfrc_applied.numpy()[0] + result.mocap_pos = d.mocap_pos.numpy()[0] + 
result.mocap_quat = d.mocap_quat.numpy()[0] + result.qacc = d.qacc.numpy()[0] + result.xanchor = d.xanchor.numpy()[0] + result.xaxis = d.xaxis.numpy()[0] + result.xmat = d.xmat.numpy().reshape((-1, 9)) + result.xpos = d.xpos.numpy()[0] + result.xquat = d.xquat.numpy()[0] + result.xipos = d.xipos.numpy()[0] + result.ximat = d.ximat.numpy().reshape((-1, 9)) + result.subtree_com = d.subtree_com.numpy()[0] + result.geom_xpos = d.geom_xpos.numpy()[0] + result.geom_xmat = d.geom_xmat.numpy().reshape((-1, 9)) + result.site_xpos = d.site_xpos.numpy()[0] + result.site_xmat = d.site_xmat.numpy().reshape((-1, 9)) + result.cinert = d.cinert.numpy()[0] + result.cdof = d.cdof.numpy()[0] + result.crb = d.crb.numpy()[0] + result.qLDiagInv = d.qLDiagInv.numpy()[0] + result.ctrl = d.ctrl.numpy()[0] + result.actuator_velocity = d.actuator_velocity.numpy()[0] + result.actuator_force = d.actuator_force.numpy()[0] + result.actuator_length = d.actuator_length.numpy()[0] + mujoco.mju_dense2sparse( + result.actuator_moment, + d.actuator_moment.numpy()[0], + result.moment_rownnz, + result.moment_rowadr, + result.moment_colind, ) + result.cvel = d.cvel.numpy()[0] + result.cdof_dot = d.cdof_dot.numpy()[0] + result.qfrc_bias = d.qfrc_bias.numpy()[0] + result.qfrc_passive = d.qfrc_passive.numpy()[0] + result.qfrc_spring = d.qfrc_spring.numpy()[0] + result.qfrc_damper = d.qfrc_damper.numpy()[0] + result.qfrc_actuator = d.qfrc_actuator.numpy()[0] + result.qfrc_smooth = d.qfrc_smooth.numpy()[0] + result.qfrc_constraint = d.qfrc_constraint.numpy()[0] + result.qacc_smooth = d.qacc_smooth.numpy()[0] + result.act = d.act.numpy()[0] + result.act_dot = d.act_dot.numpy()[0] + + result.contact.dist[:] = d.contact.dist.numpy()[:ncon] + result.contact.pos[:] = d.contact.pos.numpy()[:ncon] + result.contact.frame[:] = d.contact.frame.numpy()[:ncon].reshape((-1, 9)) + result.contact.includemargin[:] = d.contact.includemargin.numpy()[:ncon] + result.contact.friction[:] = d.contact.friction.numpy()[:ncon] + 
result.contact.solref[:] = d.contact.solref.numpy()[:ncon] + result.contact.solreffriction[:] = d.contact.solreffriction.numpy()[:ncon] + result.contact.solimp[:] = d.contact.solimp.numpy()[:ncon] + result.contact.dim[:] = d.contact.dim.numpy()[:ncon] + result.contact.efc_address[:] = d.contact.efc_address.numpy()[:ncon] - return d + if support.is_sparse(mjm): + result.qM[:] = d.qM.numpy()[0, 0] + result.qLD[:] = d.qLD.numpy()[0, 0] + # TODO(team): set efc_J after fix to _realloc_con_efc lands + # efc_J = d.efc_J.numpy()[0, :nefc] + # mujoco.mju_dense2sparse( + # result.efc_J, efc_J, result.efc_J_rownnz, result.efc_J_rowadr, result.efc_J_colind + # ) + else: + qM = d.qM.numpy() + qLD = d.qLD.numpy() + adr = 0 + for i in range(mjm.nv): + j = i + while j >= 0: + result.qM[adr] = qM[0, i, j] + result.qLD[adr] = qLD[0, i, j] + j = mjm.dof_parentid[j] + adr += 1 + # TODO(team): set efc_J after fix to _realloc_con_efc lands + # if nefc > 0: + # result.efc_J[:nefc * mjm.nv] = d.efc_J.numpy()[:nefc].flatten() + result.efc_D[:] = d.efc.D.numpy()[:nefc] + result.efc_pos[:] = d.efc.pos.numpy()[:nefc] + result.efc_aref[:] = d.efc.aref.numpy()[:nefc] + result.efc_force[:] = d.efc.force.numpy()[:nefc] + result.efc_margin[:] = d.efc.margin.numpy()[:nefc] + + # TODO: other efc_ fields, anything else missing diff --git a/mujoco/mjx/_src/math.py b/mujoco_warp/_src/math.py similarity index 66% rename from mujoco/mjx/_src/math.py rename to mujoco_warp/_src/math.py index 416e1c5c..8bbefb9a 100644 --- a/mujoco/mjx/_src/math.py +++ b/mujoco_warp/_src/math.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +from typing import Tuple + import warp as wp from . 
import types @@ -148,3 +150,93 @@ def quat_integrate(q: wp.quat, v: wp.vec3, dt: wp.float32) -> wp.quat: q_res = mul_quat(q, q_res) return wp.normalize(q_res) + + +@wp.func +def orthogonals(a: wp.vec3): + y = wp.vec3(0.0, 1.0, 0.0) + z = wp.vec3(0.0, 0.0, 1.0) + b = wp.where((-0.5 < a[1]) and (a[1] < 0.5), y, z) + b = b - a * wp.dot(a, b) + b = wp.normalize(b) + if wp.length(a) == 0.0: + b = wp.vec3(0.0, 0.0, 0.0) + c = wp.cross(a, b) + + return b, c + + +@wp.func +def make_frame(a: wp.vec3): + a = wp.normalize(a) + b, c = orthogonals(a) + + # fmt: off + return wp.mat33( + a.x, a.y, a.z, + b.x, b.y, b.z, + c.x, c.y, c.z + ) + # fmt: on + + +@wp.func +def normalize_with_norm(x: wp.vec3): + norm = wp.length(x) + if norm == 0.0: + return x, 0.0 + return x / norm, norm + + +@wp.func +def closest_segment_point(a: wp.vec3, b: wp.vec3, pt: wp.vec3) -> wp.vec3: + """Returns the closest point on the a-b line segment to a point pt.""" + ab = b - a + t = wp.dot(pt - a, ab) / (wp.dot(ab, ab) + 1e-6) + return a + wp.clamp(t, 0.0, 1.0) * ab + + +@wp.func +def closest_segment_point_and_dist( + a: wp.vec3, b: wp.vec3, pt: wp.vec3 +) -> Tuple[wp.vec3, wp.float32]: + """Returns closest point on the line segment and the distance squared.""" + closest = closest_segment_point(a, b, pt) + dist = wp.dot((pt - closest), (pt - closest)) + return closest, dist + + +@wp.func +def closest_segment_to_segment_points( + a0: wp.vec3, a1: wp.vec3, b0: wp.vec3, b1: wp.vec3 +) -> Tuple[wp.vec3, wp.vec3]: + """Returns closest points between two line segments.""" + + dir_a, len_a = normalize_with_norm(a1 - a0) + dir_b, len_b = normalize_with_norm(b1 - b0) + + half_len_a = len_a * 0.5 + half_len_b = len_b * 0.5 + a_mid = a0 + dir_a * half_len_a + b_mid = b0 + dir_b * half_len_b + + trans = a_mid - b_mid + + dira_dot_dirb = wp.dot(dir_a, dir_b) + dira_dot_trans = wp.dot(dir_a, trans) + dirb_dot_trans = wp.dot(dir_b, trans) + denom = 1.0 - dira_dot_dirb * dira_dot_dirb + + orig_t_a = (-dira_dot_trans + 
dira_dot_dirb * dirb_dot_trans) / (denom + 1e-6) + orig_t_b = dirb_dot_trans + orig_t_a * dira_dot_dirb + t_a = wp.clamp(orig_t_a, -half_len_a, half_len_a) + t_b = wp.clamp(orig_t_b, -half_len_b, half_len_b) + + best_a = a_mid + dir_a * t_a + best_b = b_mid + dir_b * t_b + + new_a, d1 = closest_segment_point_and_dist(a0, a1, best_b) + new_b, d2 = closest_segment_point_and_dist(b0, b1, best_a) + if d1 < d2: + return new_a, best_b + return best_a, new_b diff --git a/mujoco_warp/_src/math_test.py b/mujoco_warp/_src/math_test.py new file mode 100644 index 00000000..db691533 --- /dev/null +++ b/mujoco_warp/_src/math_test.py @@ -0,0 +1,88 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import warp as wp +from absl.testing import absltest + +from .math import closest_segment_to_segment_points + + +class ClosestSegmentSegmentPointsTest(absltest.TestCase): + """Tests for closest segment-to-segment points.""" + + def test_closest_segments_points(self): + """Test closest points between two segments.""" + a0 = wp.vec3([0.73432405, 0.12372768, 0.20272314]) + a1 = wp.vec3([1.10600128, 0.88555209, 0.65209485]) + b0 = wp.vec3([0.85599262, 0.61736299, 0.9843583]) + b1 = wp.vec3([1.84270939, 0.92891793, 1.36343326]) + + best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1) + self.assertSequenceAlmostEqual(best_a, [1.09063, 0.85404, 0.63351], 5) + self.assertSequenceAlmostEqual(best_b, [0.99596, 0.66156, 1.03813], 5) + + def test_intersecting_segments(self): + """Tests segments that intersect.""" + a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0]) + b0, b1 = wp.vec3([-1.0, 0.0, 0.0]), wp.vec3([1.0, 0.0, 0.0]) + + best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1) + self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, 0.0], 5) + self.assertSequenceAlmostEqual(best_b, [0.0, 0.0, 0.0], 5) + + def test_intersecting_lines(self): + """Tests that intersecting lines get clipped.""" + a0, a1 = wp.vec3([0.2, 0.2, 0.0]), wp.vec3([1.0, 1.0, 0.0]) + b0, b1 = wp.vec3([0.2, 0.4, 0.0]), wp.vec3([1.0, 2.0, 0.0]) + + best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1) + self.assertSequenceAlmostEqual(best_a, [0.3, 0.3, 0.0], 2) + self.assertSequenceAlmostEqual(best_b, [0.2, 0.4, 0.0], 2) + + def test_parallel_segments(self): + """Tests that parallel segments have closest points at the midpoint.""" + a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0]) + b0, b1 = wp.vec3([1.0, 0.0, -1.0]), wp.vec3([1.0, 0.0, 1.0]) + + best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1) + self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, 
0.0], 5) + self.assertSequenceAlmostEqual(best_b, [1.0, 0.0, 0.0], 5) + + def test_parallel_offset_segments(self): + """Tests that offset parallel segments are close at segment endpoints.""" + a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0]) + b0, b1 = wp.vec3([1.0, 0.0, 1.0]), wp.vec3([1.0, 0.0, 3.0]) + + best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1) + self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, 1.0], 5) + self.assertSequenceAlmostEqual(best_b, [1.0, 0.0, 1.0], 5) + + def test_zero_length_segments(self): + """Test that zero length segments don't return NaNs.""" + a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, -1.0]) + b0, b1 = wp.vec3([1.0, 0.0, 0.1]), wp.vec3([1.0, 0.0, 0.1]) + + best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1) + self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, -1.0], 5) + self.assertSequenceAlmostEqual(best_b, [1.0, 0.0, 0.1], 5) + + def test_overlapping_segments(self): + """Tests that perfectly overlapping segments intersect at the midpoints.""" + a0, a1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0]) + b0, b1 = wp.vec3([0.0, 0.0, -1.0]), wp.vec3([0.0, 0.0, 1.0]) + + best_a, best_b = closest_segment_to_segment_points(a0, a1, b0, b1) + self.assertSequenceAlmostEqual(best_a, [0.0, 0.0, 0.0], 5) + self.assertSequenceAlmostEqual(best_b, [0.0, 0.0, 0.0], 5) diff --git a/mujoco/mjx/_src/passive.py b/mujoco_warp/_src/passive.py similarity index 75% rename from mujoco/mjx/_src/passive.py rename to mujoco_warp/_src/passive.py index 0125331c..c8cbc30b 100644 --- a/mujoco/mjx/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -1,15 +1,38 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import warp as wp from . import math -from .types import Model from .types import Data +from .types import DisableBit from .types import JointType +from .types import Model +from .warp_util import event_scope +from .warp_util import kernel +@event_scope def passive(m: Model, d: Data): """Adds all passive forces.""" + if m.opt.disableflags & DisableBit.PASSIVE: + d.qfrc_passive.zero_() + # TODO(team): qfrc_gravcomp + return - @wp.kernel + @kernel def _spring(m: Model, d: Data): worldid, jntid = wp.tid() stiffness = m.jnt_stiffness[jntid] @@ -67,7 +90,7 @@ def _spring(m: Model, d: Data): fdif = d.qpos[worldid, qposid] - m.qpos_spring[qposid] d.qfrc_spring[worldid, dofid] = -stiffness * fdif - @wp.kernel + @kernel def _damper_passive(m: Model, d: Data): worldid, dofid = wp.tid() damping = m.dof_damping[dofid] diff --git a/mujoco/mjx/_src/passive_test.py b/mujoco_warp/_src/passive_test.py similarity index 85% rename from mujoco/mjx/_src/passive_test.py rename to mujoco_warp/_src/passive_test.py index 4a8409bb..7186ab39 100644 --- a/mujoco/mjx/_src/passive_test.py +++ b/mujoco_warp/_src/passive_test.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,15 +15,15 @@ """Tests for passive force functions.""" -from absl.testing import absltest import numpy as np import warp as wp +from absl.testing import absltest -from mujoco import mjx +import mujoco_warp as mjwarp from . import test_util -# tolerance for difference between MuJoCo and MJX smooth calculations - mostly +# tolerance for difference between MuJoCo and MJWarp passive force calculations - mostly # due to float precision _TOLERANCE = 5e-5 @@ -36,18 +36,20 @@ def _assert_eq(a, b, name): class PassiveTest(absltest.TestCase): def test_passive(self): - """Tests MJX passive.""" + """Tests passive.""" _, mjd, m, d = test_util.fixture("pendula.xml") for arr in (d.qfrc_spring, d.qfrc_damper, d.qfrc_passive): arr.zero_() - mjx.passive(m, d) + mjwarp.passive(m, d) _assert_eq(d.qfrc_spring.numpy()[0], mjd.qfrc_spring, "qfrc_spring") _assert_eq(d.qfrc_damper.numpy()[0], mjd.qfrc_damper, "qfrc_damper") _assert_eq(d.qfrc_passive.numpy()[0], mjd.qfrc_passive, "qfrc_passive") + # TODO(team): test DisableBit.PASSIVE + if __name__ == "__main__": wp.init() diff --git a/mujoco/mjx/_src/smooth.py b/mujoco_warp/_src/smooth.py similarity index 81% rename from mujoco/mjx/_src/smooth.py rename to mujoco_warp/_src/smooth.py index 8f61bb9c..19f70d4c 100644 --- a/mujoco/mjx/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,20 +14,26 @@ # ============================================================================== import warp as wp -from . import math -from .types import Model +from . 
import math from .types import Data +from .types import DisableBit +from .types import JointType +from .types import Model +from .types import TrnType from .types import array2df from .types import array3df from .types import vec10 -from .types import JointType, TrnType +from .warp_util import event_scope +from .warp_util import kernel +from .warp_util import kernel_copy +@event_scope def kinematics(m: Model, d: Data): """Forward kinematics.""" - @wp.kernel + @kernel def _root(m: Model, d: Data): worldid = wp.tid() d.xpos[worldid, 0] = wp.vec3(0.0) @@ -36,7 +42,7 @@ def _root(m: Model, d: Data): d.xmat[worldid, 0] = wp.identity(n=3, dtype=wp.float32) d.ximat[worldid, 0] = wp.identity(n=3, dtype=wp.float32) - @wp.kernel + @kernel def _level(m: Model, d: Data, leveladr: int): worldid, nodeid = wp.tid() bodyid = m.body_tree[leveladr + nodeid] @@ -52,7 +58,6 @@ def _level(m: Model, d: Data, leveladr: int): elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint qadr = m.jnt_qposadr[jntadr] - # TODO(erikfrey): would it be better to use some kind of wp.copy here? 
xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2]) xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6]) d.xanchor[worldid, jntadr] = xpos @@ -95,8 +100,35 @@ def _level(m: Model, d: Data, leveladr: int): jntadr += 1 d.xpos[worldid, bodyid] = xpos - d.xquat[worldid, bodyid] = wp.normalize(xquat) + xquat = wp.normalize(xquat) + d.xquat[worldid, bodyid] = xquat d.xmat[worldid, bodyid] = math.quat_to_mat(xquat) + d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat(m.body_ipos[bodyid], xquat) + d.ximat[worldid, bodyid] = math.quat_to_mat( + math.mul_quat(xquat, m.body_iquat[bodyid]) + ) + + @kernel + def geom_local_to_global(m: Model, d: Data): + worldid, geomid = wp.tid() + bodyid = m.geom_bodyid[geomid] + xpos = d.xpos[worldid, bodyid] + xquat = d.xquat[worldid, bodyid] + d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat(m.geom_pos[geomid], xquat) + d.geom_xmat[worldid, geomid] = math.quat_to_mat( + math.mul_quat(xquat, m.geom_quat[geomid]) + ) + + @kernel + def site_local_to_global(m: Model, d: Data): + worldid, siteid = wp.tid() + bodyid = m.site_bodyid[siteid] + xpos = d.xpos[worldid, bodyid] + xquat = d.xquat[worldid, bodyid] + d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat(m.site_pos[siteid], xquat) + d.site_xmat[worldid, siteid] = math.quat_to_mat( + math.mul_quat(xquat, m.site_quat[siteid]) + ) wp.launch(_root, dim=(d.nworld), inputs=[m, d]) @@ -106,35 +138,35 @@ def _level(m: Model, d: Data, leveladr: int): end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + if m.ngeom: + wp.launch(geom_local_to_global, dim=(d.nworld, m.ngeom), inputs=[m, d]) + + if m.nsite: + wp.launch(site_local_to_global, dim=(d.nworld, m.nsite), inputs=[m, d]) + +@event_scope def com_pos(m: Model, d: Data): """Map inertias and motion dofs to global frame centered at subtree-CoM.""" - @wp.kernel - def mass_subtree_acc(m: Model, mass_subtree: 
wp.array(dtype=float), leveladr: int): - nodeid = wp.tid() - bodyid = m.body_tree[leveladr + nodeid] - pid = m.body_parentid[bodyid] - wp.atomic_add(mass_subtree, pid, mass_subtree[bodyid]) - - @wp.kernel + @kernel def subtree_com_init(m: Model, d: Data): worldid, bodyid = wp.tid() d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * m.body_mass[bodyid] - @wp.kernel + @kernel def subtree_com_acc(m: Model, d: Data, leveladr: int): worldid, nodeid = wp.tid() bodyid = m.body_tree[leveladr + nodeid] pid = m.body_parentid[bodyid] wp.atomic_add(d.subtree_com, worldid, pid, d.subtree_com[worldid, bodyid]) - @wp.kernel - def subtree_div(mass_subtree: wp.array(dtype=float), d: Data): + @kernel + def subtree_div(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] /= mass_subtree[bodyid] + d.subtree_com[worldid, bodyid] /= m.subtree_mass[bodyid] - @wp.kernel + @kernel def cinert(m: Model, d: Data): worldid, bodyid = wp.tid() mat = d.ximat[worldid, bodyid] @@ -168,7 +200,7 @@ def cinert(m: Model, d: Data): d.cinert[worldid, bodyid] = res - @wp.kernel + @kernel def cdof(m: Model, d: Data): worldid, jntid = wp.tid() bodyid = m.jnt_bodyid[jntid] @@ -199,31 +231,27 @@ def cdof(m: Model, d: Data): elif jnt_type == wp.static(JointType.HINGE.value): # hinge res[dofid] = wp.spatial_vector(xaxis, wp.cross(xaxis, offset)) - body_treeadr = m.body_treeadr.numpy() - mass_subtree = wp.clone(m.body_mass) - for i in reversed(range(len(body_treeadr))): - beg = body_treeadr[i] - end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(mass_subtree_acc, dim=(end - beg,), inputs=[m, mass_subtree, beg]) - wp.launch(subtree_com_init, dim=(d.nworld, m.nbody), inputs=[m, d]) + body_treeadr = m.body_treeadr.numpy() + for i in reversed(range(len(body_treeadr))): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] wp.launch(subtree_com_acc, dim=(d.nworld, end - beg), inputs=[m, d, beg]) - 
wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[mass_subtree, d]) + wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[m, d]) wp.launch(cinert, dim=(d.nworld, m.nbody), inputs=[m, d]) wp.launch(cdof, dim=(d.nworld, m.njnt), inputs=[m, d]) +@event_scope def crb(m: Model, d: Data): """Composite rigid body inertia algorithm.""" - wp.copy(d.crb, d.cinert) + kernel_copy(d.crb, d.cinert) - @wp.kernel + @kernel def crb_accumulate(m: Model, d: Data, leveladr: int): worldid, nodeid = wp.tid() bodyid = m.body_tree[leveladr + nodeid] @@ -232,7 +260,7 @@ def crb_accumulate(m: Model, d: Data, leveladr: int): return wp.atomic_add(d.crb, worldid, pid, d.crb[worldid, bodyid]) - @wp.kernel + @kernel def qM_sparse(m: Model, d: Data): worldid, dofid = wp.tid() madr_ij = m.dof_Madr[dofid] @@ -250,21 +278,27 @@ def qM_sparse(m: Model, d: Data): madr_ij += 1 dofid = m.dof_parentid[dofid] - @wp.kernel + @kernel def qM_dense(m: Model, d: Data): worldid, dofid = wp.tid() bodyid = m.dof_bodyid[dofid] # init M(i,i) with armature inertia - d.qM[worldid, dofid, dofid] = m.dof_armature[dofid] + M = m.dof_armature[dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) + M += wp.dot(d.cdof[worldid, dofid], buf) + + d.qM[worldid, dofid, dofid] = M # sparse backward pass over ancestors dofidi = dofid + dofid = m.dof_parentid[dofid] while dofid >= 0: - d.qM[worldid, dofidi, dofid] += wp.dot(d.cdof[worldid, dofid], buf) + qMij = wp.dot(d.cdof[worldid, dofid], buf) + d.qM[worldid, dofidi, dofid] += qMij + d.qM[worldid, dofid, dofidi] += qMij dofid = m.dof_parentid[dofid] body_treeadr = m.body_treeadr.numpy() @@ -283,7 +317,7 @@ def qM_dense(m: Model, d: Data): def _factor_i_sparse(m: Model, d: Data, M: array3df, L: array3df, D: array2df): """Sparse L'*D*L factorizaton of inertia-like matrix M, assumed spd.""" - @wp.kernel + @kernel def qLD_acc(m: Model, leveladr: int, L: array3df): worldid, nodeid = wp.tid() update = 
m.qLD_update_tree[leveladr + nodeid] @@ -297,12 +331,12 @@ def qLD_acc(m: Model, leveladr: int, L: array3df): # M(k,i) = tmp L[worldid, 0, Madr_ki] = tmp - @wp.kernel + @kernel def qLDiag_div(m: Model, L: array3df, D: array2df): worldid, dofid = wp.tid() D[worldid, dofid] = 1.0 / L[worldid, 0, m.dof_Madr[dofid]] - wp.copy(L, M) + kernel_copy(L, M) qLD_update_treeadr = m.qLD_update_treeadr.numpy() @@ -323,7 +357,7 @@ def _factor_i_dense(m: Model, d: Data, M: wp.array, L: wp.array): block_dim = 32 def tile_cholesky(adr: int, size: int, tilesize: int): - @wp.kernel + @kernel def cholesky(m: Model, leveladr: int, M: array3df, L: array3df): worldid, nodeid = wp.tid() dofid = m.qLD_tile[leveladr + nodeid] @@ -355,27 +389,25 @@ def factor_i(m: Model, d: Data, M, L, D=None): _factor_i_dense(m, d, M, L) +@event_scope def factor_m(m: Model, d: Data): """Factorizaton of inertia-like matrix M, assumed spd.""" factor_i(m, d, d.qM, d.qLD, d.qLDiagInv) +@event_scope def rne(m: Model, d: Data): """Computes inverse dynamics using Newton-Euler algorithm.""" - cacc = wp.zeros(shape=(d.nworld, m.nbody), dtype=wp.spatial_vector) - cfrc = wp.zeros(shape=(d.nworld, m.nbody), dtype=wp.spatial_vector) - - @wp.kernel - def cacc_gravity(m: Model, cacc: wp.array(dtype=wp.spatial_vector, ndim=2)): + @kernel + def cacc_gravity(m: Model, d: Data): worldid = wp.tid() - cacc[worldid, 0] = wp.spatial_vector(wp.vec3(0.0), -m.opt.gravity) + d.rne_cacc[worldid, 0] = wp.spatial_vector(wp.vec3(0.0), -m.opt.gravity) - @wp.kernel + @kernel def cacc_level( m: Model, d: Data, - cacc: wp.array(dtype=wp.spatial_vector, ndim=2), leveladr: int, ): worldid, nodeid = wp.tid() @@ -383,62 +415,64 @@ def cacc_level( dofnum = m.body_dofnum[bodyid] pid = m.body_parentid[bodyid] dofadr = m.body_dofadr[bodyid] - local_cacc = cacc[worldid, pid] + local_cacc = d.rne_cacc[worldid, pid] for i in range(dofnum): local_cacc += d.cdof_dot[worldid, dofadr + i] * d.qvel[worldid, dofadr + i] - cacc[worldid, bodyid] = local_cacc + 
d.rne_cacc[worldid, bodyid] = local_cacc - @wp.kernel - def frc_fn( - d: Data, - cfrc: wp.array(dtype=wp.spatial_vector, ndim=2), - cacc: wp.array(dtype=wp.spatial_vector, ndim=2), - ): + @kernel + def frc_fn(d: Data): worldid, bodyid = wp.tid() - frc = math.inert_vec(d.cinert[worldid, bodyid], cacc[worldid, bodyid]) + frc = math.inert_vec(d.cinert[worldid, bodyid], d.rne_cacc[worldid, bodyid]) frc += math.motion_cross_force( d.cvel[worldid, bodyid], math.inert_vec(d.cinert[worldid, bodyid], d.cvel[worldid, bodyid]), ) - cfrc[worldid, bodyid] += frc + d.rne_cfrc[worldid, bodyid] = frc - @wp.kernel - def cfrc_fn(m: Model, cfrc: wp.array(dtype=wp.spatial_vector, ndim=2), leveladr: int): + @kernel + def cfrc_fn(m: Model, d: Data, leveladr: int): worldid, nodeid = wp.tid() bodyid = m.body_tree[leveladr + nodeid] pid = m.body_parentid[bodyid] - wp.atomic_add(cfrc[worldid], pid, cfrc[worldid, bodyid]) + wp.atomic_add(d.rne_cfrc[worldid], pid, d.rne_cfrc[worldid, bodyid]) - @wp.kernel - def qfrc_bias(m: Model, d: Data, cfrc: wp.array(dtype=wp.spatial_vector, ndim=2)): + @kernel + def qfrc_bias(m: Model, d: Data): worldid, dofid = wp.tid() bodyid = m.dof_bodyid[dofid] - d.qfrc_bias[worldid, dofid] = wp.dot(d.cdof[worldid, dofid], cfrc[worldid, bodyid]) + d.qfrc_bias[worldid, dofid] = wp.dot( + d.cdof[worldid, dofid], d.rne_cfrc[worldid, bodyid] + ) - wp.launch(cacc_gravity, dim=[d.nworld], inputs=[m, cacc]) + if m.opt.disableflags & DisableBit.GRAVITY: + d.rne_cacc.zero_() + else: + wp.launch(cacc_gravity, dim=[d.nworld], inputs=[m, d]) body_treeadr = m.body_treeadr.numpy() for i in range(len(body_treeadr)): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(cacc_level, dim=(d.nworld, end - beg), inputs=[m, d, cacc, beg]) + wp.launch(cacc_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) - wp.launch(frc_fn, dim=[d.nworld, m.nbody], inputs=[d, cfrc, cacc]) + wp.launch(frc_fn, dim=[d.nworld, m.nbody], inputs=[d]) for i 
in reversed(range(len(body_treeadr))): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(cfrc_fn, dim=[d.nworld, end - beg], inputs=[m, cfrc, beg]) + wp.launch(cfrc_fn, dim=[d.nworld, end - beg], inputs=[m, d, beg]) - wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d, cfrc]) + wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d]) +@event_scope def transmission(m: Model, d: Data): """Computes actuator/transmission lengths and moments.""" if not m.nu: return d - @wp.kernel + @kernel def _transmission( m: Model, d: Data, @@ -502,15 +536,16 @@ def _transmission( ) +@event_scope def com_vel(m: Model, d: Data): """Computes cvel, cdof_dot.""" - @wp.kernel + @kernel def _root(d: Data): worldid, elementid = wp.tid() d.cvel[worldid, 0][elementid] = 0.0 - @wp.kernel + @kernel def _level(m: Model, d: Data, leveladr: int): worldid, nodeid = wp.tid() bodyid = m.body_tree[leveladr + nodeid] @@ -576,26 +611,26 @@ def _solve_LD_sparse( ): """Computes sparse backsubstitution: x = inv(L'*D*L)*y""" - @wp.kernel + @kernel def x_acc_up(m: Model, L: array3df, x: array2df, leveladr: int): worldid, nodeid = wp.tid() update = m.qLD_update_tree[leveladr + nodeid] i, k, Madr_ki = update[0], update[1], update[2] wp.atomic_sub(x[worldid], i, L[worldid, 0, Madr_ki] * x[worldid, k]) - @wp.kernel + @kernel def qLDiag_mul(D: array2df, x: array2df): worldid, dofid = wp.tid() x[worldid, dofid] *= D[worldid, dofid] - @wp.kernel + @kernel def x_acc_down(m: Model, L: array3df, x: array2df, leveladr: int): worldid, nodeid = wp.tid() update = m.qLD_update_tree[leveladr + nodeid] i, k, Madr_ki = update[0], update[1], update[2] wp.atomic_sub(x[worldid], k, L[worldid, 0, Madr_ki] * x[worldid, i]) - wp.copy(x, y) + kernel_copy(x, y) qLD_update_treeadr = m.qLD_update_treeadr.numpy() @@ -623,7 +658,7 @@ def _solve_LD_dense(m: Model, d: Data, L: array3df, x: array2df, y: array2df): block_dim = 32 def tile_cho_solve(adr: int, size: int, tilesize: 
int): - @wp.kernel + @kernel def cho_solve(m: Model, L: array3df, x: array2df, y: array2df, leveladr: int): worldid, nodeid = wp.tid() dofid = m.qLD_tile[leveladr + nodeid] @@ -655,6 +690,45 @@ def solve_LD(m: Model, d: Data, L: array3df, D: array2df, x: array2df, y: array2 _solve_LD_dense(m, d, L, x, y) +@event_scope def solve_m(m: Model, d: Data, x: array2df, y: array2df): """Computes backsubstitution: x = qLD * y.""" solve_LD(m, d, d.qLD, d.qLDiagInv, x, y) + + +def _factor_solve_i_dense(m: Model, d: Data, M: array3df, x: array2df, y: array2df): + # TODO(team): develop heuristic for block dim, or make configurable + block_dim = 32 + + def tile_cholesky(adr: int, size: int, tilesize: int): + @kernel(module="unique") + def cholesky(m: Model, leveladr: int, M: array3df, x: array2df, y: array2df): + worldid, nodeid = wp.tid() + dofid = m.qLD_tile[leveladr + nodeid] + M_tile = wp.tile_load( + M[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) + ) + y_slice = wp.tile_load(y[worldid], shape=(tilesize,), offset=(dofid,)) + + L_tile = wp.tile_cholesky(M_tile) + x_slice = wp.tile_cholesky_solve(L_tile, y_slice) + wp.tile_store(x[worldid], x_slice, offset=(dofid,)) + + wp.launch_tiled( + cholesky, dim=(d.nworld, size), inputs=[m, adr, M, x, y], block_dim=block_dim + ) + + qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() + + for i in range(len(qLD_tileadr)): + beg = qLD_tileadr[i] + end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1] + tile_cholesky(beg, end - beg, int(qLD_tilesize[i])) + + +def factor_solve_i(m, d, M, L, D, x, y): + if m.opt.is_sparse: + _factor_i_sparse(m, d, M, L, D) + _solve_LD_sparse(m, d, L, D, x, y) + else: + _factor_solve_i_dense(m, d, M, x, y) diff --git a/mujoco/mjx/_src/smooth_test.py b/mujoco_warp/_src/smooth_test.py similarity index 67% rename from mujoco/mjx/_src/smooth_test.py rename to mujoco_warp/_src/smooth_test.py index e0fe3122..ed40e91d 100644 --- 
a/mujoco/mjx/_src/smooth_test.py +++ b/mujoco_warp/_src/smooth_test.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,18 +15,19 @@ """Tests for smooth dynamics functions.""" -from absl.testing import absltest -from absl.testing import parameterized import mujoco -from mujoco import mjx import numpy as np import warp as wp +from absl.testing import absltest +from absl.testing import parameterized -wp.config.verify_cuda = True +import mujoco_warp as mjwarp from . import test_util -# tolerance for difference between MuJoCo and mjWarp smooth calculations - mostly +wp.config.verify_cuda = True + +# tolerance for difference between MuJoCo and MJWarp smooth calculations - mostly # due to float precision _TOLERANCE = 5e-5 @@ -45,12 +46,19 @@ def test_kinematics(self): for arr in (d.xanchor, d.xaxis, d.xquat, d.xpos): arr.zero_() - mjx.kinematics(m, d) + mjwarp.kinematics(m, d) _assert_eq(d.xanchor.numpy()[0], mjd.xanchor, "xanchor") _assert_eq(d.xaxis.numpy()[0], mjd.xaxis, "xaxis") - _assert_eq(d.xquat.numpy()[0], mjd.xquat, "xquat") _assert_eq(d.xpos.numpy()[0], mjd.xpos, "xpos") + _assert_eq(d.xquat.numpy()[0], mjd.xquat, "xquat") + _assert_eq(d.xmat.numpy()[0], mjd.xmat.reshape((-1, 3, 3)), "xmat") + _assert_eq(d.xipos.numpy()[0], mjd.xipos, "xipos") + _assert_eq(d.ximat.numpy()[0], mjd.ximat.reshape((-1, 3, 3)), "ximat") + _assert_eq(d.geom_xpos.numpy()[0], mjd.geom_xpos, "geom_xpos") + _assert_eq(d.geom_xmat.numpy()[0], mjd.geom_xmat.reshape((-1, 3, 3)), "geom_xmat") + _assert_eq(d.site_xpos.numpy()[0], mjd.site_xpos, "site_xpos") + _assert_eq(d.site_xmat.numpy()[0], mjd.site_xmat.reshape((-1, 3, 3)), "site_xmat") def test_com_pos(self): """Tests com_pos.""" @@ -59,51 +67,49 @@ def test_com_pos(self): for arr in (d.subtree_com, d.cinert, d.cdof): arr.zero_() - 
mjx.com_pos(m, d) + mjwarp.com_pos(m, d) _assert_eq(d.subtree_com.numpy()[0], mjd.subtree_com, "subtree_com") _assert_eq(d.cinert.numpy()[0], mjd.cinert, "cinert") _assert_eq(d.cdof.numpy()[0], mjd.cdof, "cdof") - def test_crb(self): + @parameterized.parameters(True, False) + def test_crb(self, sparse: bool): """Tests crb.""" - _, mjd, m, d = test_util.fixture("pendula.xml") + mjm, mjd, m, d = test_util.fixture("pendula.xml", sparse=sparse) d.crb.zero_() - mjx.crb(m, d) + mjwarp.crb(m, d) _assert_eq(d.crb.numpy()[0], mjd.crb, "crb") - _assert_eq(d.qM.numpy()[0, 0], mjd.qM, "qM") - def test_factor_m_sparse(self): - """Tests factor_m (sparse).""" - _, mjd, m, d = test_util.fixture("pendula.xml", sparse=True) + if sparse: + _assert_eq(d.qM.numpy()[0, 0], mjd.qM, "qM") + else: + qM = np.zeros((mjm.nv, mjm.nv)) + mujoco.mj_fullM(mjm, qM, mjd.qM) + _assert_eq(d.qM.numpy()[0], qM, "qM") + @parameterized.parameters(True, False) + def test_factor_m(self, sparse: bool): + """Tests factor_m.""" + _, mjd, m, d = test_util.fixture("pendula.xml", sparse=sparse) + + qLD = d.qLD.numpy()[0].copy() for arr in (d.qLD, d.qLDiagInv): arr.zero_() - mjx.factor_m(m, d) - _assert_eq(d.qLD.numpy()[0, 0], mjd.qLD, "qLD (sparse)") - _assert_eq(d.qLDiagInv.numpy()[0], mjd.qLDiagInv, "qLDiagInv") + mjwarp.factor_m(m, d) - def test_factor_m_dense(self): - """Tests MJX factor_m (dense).""" - # TODO(team): switch this to pendula.xml and merge with above test - # after mmacklin's tile_cholesky fixes are in - _, mjd, m, d = test_util.fixture("humanoid/humanoid.xml", sparse=False) - - qLD = d.qLD.numpy()[0].copy() - d.qLD.zero_() - - mjx.factor_m(m, d) - _assert_eq(d.qLD.numpy()[0], qLD, "qLD (dense)") + if sparse: + _assert_eq(d.qLD.numpy()[0, 0], mjd.qLD, "qLD (sparse)") + _assert_eq(d.qLDiagInv.numpy()[0], mjd.qLDiagInv, "qLDiagInv") + else: + _assert_eq(d.qLD.numpy()[0], qLD, "qLD (dense)") @parameterized.parameters(True, False) def test_solve_m(self, sparse: bool): """Tests solve_m.""" - # 
TODO(team): switch this to pendula.xml and merge with above test - # after mmacklin's tile_cholesky fixes are in - fname = "pendula.xml" if sparse else "humanoid/humanoid.xml" - mjm, mjd, m, d = test_util.fixture(fname, sparse=sparse) + mjm, mjd, m, d = test_util.fixture("pendula.xml", sparse=sparse) qfrc_smooth = np.tile(mjd.qfrc_smooth, (1, 1)) qacc_smooth = np.zeros( @@ -117,7 +123,7 @@ def test_solve_m(self, sparse: bool): d.qacc_smooth.zero_() - mjx.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth) + mjwarp.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth) _assert_eq(d.qacc_smooth.numpy()[0], qacc_smooth[0], "qacc_smooth") def test_rne(self): @@ -126,9 +132,11 @@ def test_rne(self): d.qfrc_bias.zero_() - mjx.rne(m, d) + mjwarp.rne(m, d) _assert_eq(d.qfrc_bias.numpy()[0], mjd.qfrc_bias, "qfrc_bias") + # TODO(team): test DisableBit.GRAVITY + def test_com_vel(self): """Tests com_vel.""" _, mjd, m, d = test_util.fixture("pendula.xml") @@ -136,7 +144,7 @@ def test_com_vel(self): for arr in (d.cvel, d.cdof_dot): arr.zero_() - mjx.com_vel(m, d) + mjwarp.com_vel(m, d) _assert_eq(d.cvel.numpy()[0], mjd.cvel, "cvel") _assert_eq(d.cdof_dot.numpy()[0], mjd.cdof_dot, "cdof_dot") @@ -156,7 +164,7 @@ def test_transmission(self): mjd.moment_colind, ) - mjx._src.smooth.transmission(m, d) + mjwarp._src.smooth.transmission(m, d) _assert_eq(d.actuator_length.numpy()[0], mjd.actuator_length, "actuator_length") _assert_eq(d.actuator_moment.numpy()[0], actuator_moment, "actuator_moment") diff --git a/mujoco_warp/_src/solver.py b/mujoco_warp/_src/solver.py new file mode 100644 index 00000000..9adb2770 --- /dev/null +++ b/mujoco_warp/_src/solver.py @@ -0,0 +1,919 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import warp as wp + +from . import smooth +from . import support +from . import types +from .warp_util import event_scope +from .warp_util import kernel +from .warp_util import kernel_copy + + +def _create_context(m: types.Model, d: types.Data, grad: bool = True): + @kernel + def _init_context(d: types.Data): + worldid = wp.tid() + d.efc.cost[worldid] = wp.inf + d.efc.solver_niter[worldid] = 0 + d.efc.done[worldid] = False + if grad: + d.efc.search_dot[worldid] = 0.0 + + @kernel + def _jaref(m: types.Model, d: types.Data): + efcid, dofid = wp.tid() + + if efcid >= min(d.nefc[0], d.njmax): + return + + worldid = d.efc.worldid[efcid] + wp.atomic_add( + d.efc.Jaref, + efcid, + d.efc.J[efcid, dofid] * d.qacc[worldid, dofid] - d.efc.aref[efcid] / float(m.nv), + ) + + @kernel + def _search(d: types.Data): + worldid, dofid = wp.tid() + search = -1.0 * d.efc.Mgrad[worldid, dofid] + d.efc.search[worldid, dofid] = search + wp.atomic_add(d.efc.search_dot, worldid, search * search) + + wp.launch(_init_context, dim=(d.nworld), inputs=[d]) + + # jaref = d.efc_J @ d.qacc - d.efc_aref + d.efc.Jaref.zero_() + + wp.launch(_jaref, dim=(d.njmax, m.nv), inputs=[m, d]) + + # Ma = qM @ qacc + support.mul_m(m, d, d.efc.Ma, d.qacc, d.efc.done) + + _update_constraint(m, d) + if grad: + _update_gradient(m, d) + + # search = -Mgrad + wp.launch(_search, dim=(d.nworld, m.nv), inputs=[d]) + + +def _update_constraint(m: types.Model, d: types.Data): + @kernel + def _init_cost(d: types.Data): + worldid = 
wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + d.efc.prev_cost[worldid] = d.efc.cost[worldid] + d.efc.cost[worldid] = 0.0 + d.efc.gauss[worldid] = 0.0 + + @kernel + def _efc_kernel(d: types.Data): + efcid = wp.tid() + + if efcid >= min(d.nefc[0], d.njmax): + return + + worldid = d.efc.worldid[efcid] + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + Jaref = d.efc.Jaref[efcid] + efc_D = d.efc.D[efcid] + + # TODO(team): active and conditionally active constraints + active = int(Jaref < 0.0) + d.efc.active[efcid] = active + + if active: + # efc_force = -efc_D * Jaref * active + d.efc.force[efcid] = -1.0 * efc_D * Jaref + + # cost = 0.5 * sum(efc_D * Jaref * Jaref * active) + wp.atomic_add(d.efc.cost, worldid, 0.5 * efc_D * Jaref * Jaref) + else: + d.efc.force[efcid] = 0.0 + + @kernel + def _zero_qfrc_constraint(d: types.Data): + worldid, dofid = wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + d.qfrc_constraint[worldid, dofid] = 0.0 + + @kernel + def _qfrc_constraint(d: types.Data): + dofid, efcid = wp.tid() + + if efcid >= min(d.nefc[0], d.njmax): + return + + worldid = d.efc.worldid[efcid] + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + wp.atomic_add( + d.qfrc_constraint[worldid], + dofid, + d.efc.J[efcid, dofid] * d.efc.force[efcid], + ) + + @kernel + def _gauss(d: types.Data): + worldid, dofid = wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + gauss_cost = ( + 0.5 + * (d.efc.Ma[worldid, dofid] - d.qfrc_smooth[worldid, dofid]) + * (d.qacc[worldid, dofid] - d.qacc_smooth[worldid, dofid]) + ) + wp.atomic_add(d.efc.gauss, worldid, gauss_cost) + wp.atomic_add(d.efc.cost, worldid, gauss_cost) + + wp.launch(_init_cost, dim=(d.nworld), inputs=[d]) + + wp.launch(_efc_kernel, dim=(d.njmax,), inputs=[d]) + + # qfrc_constraint = efc_J.T @ efc_force + wp.launch(_zero_qfrc_constraint, dim=(d.nworld, 
m.nv), inputs=[d]) + + wp.launch(_qfrc_constraint, dim=(m.nv, d.njmax), inputs=[d]) + + # gauss = 0.5 * (Ma - qfrc_smooth).T @ (qacc - qacc_smooth) + + wp.launch(_gauss, dim=(d.nworld, m.nv), inputs=[d]) + + +def _update_gradient(m: types.Model, d: types.Data): + TILE = m.nv + ITERATIONS = m.opt.iterations + + @kernel + def _zero_grad_dot(d: types.Data): + worldid = wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + d.efc.grad_dot[worldid] = 0.0 + + @kernel + def _grad(d: types.Data): + worldid, dofid = wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + grad = ( + d.efc.Ma[worldid, dofid] + - d.qfrc_smooth[worldid, dofid] + - d.qfrc_constraint[worldid, dofid] + ) + d.efc.grad[worldid, dofid] = grad + wp.atomic_add(d.efc.grad_dot, worldid, grad * grad) + + if m.opt.is_sparse: + + @kernel + def _zero_h_lower(m: types.Model, d: types.Data): + # TODO(team): static m? + worldid, elementid = wp.tid() + + if ITERATIONS > 1: + if d.efc.done[worldid]: + return + + rowid = m.dof_tri_row[elementid] + colid = m.dof_tri_col[elementid] + d.efc.h[worldid, rowid, colid] = 0.0 + + @kernel + def _set_h_qM_lower_sparse(m: types.Model, d: types.Data): + # TODO(team): static m? + worldid, elementid = wp.tid() + + if ITERATIONS > 1: + if d.efc.done[worldid]: + return + + i = m.qM_fullm_i[elementid] + j = m.qM_fullm_j[elementid] + d.efc.h[worldid, i, j] = d.qM[worldid, 0, elementid] + + else: + + @kernel + def _copy_lower_triangle(m: types.Model, d: types.Data): + # TODO(team): static m? + worldid, elementid = wp.tid() + + if ITERATIONS > 1: + if d.efc.done[worldid]: + return + + rowid = m.dof_tri_row[elementid] + colid = m.dof_tri_col[elementid] + d.efc.h[worldid, rowid, colid] = d.qM[worldid, rowid, colid] + + @kernel + def _JTDAJ(m: types.Model, d: types.Data): + # TODO(team): static m? 
+ efcid, elementid = wp.tid() + + if efcid >= min(d.nefc[0], d.njmax): + return + + worldid = d.efc.worldid[efcid] + + if ITERATIONS > 1: + if d.efc.done[worldid]: + return + + dofi = m.dof_tri_row[elementid] + dofj = m.dof_tri_col[elementid] + + efc_D = d.efc.D[efcid] + active = d.efc.active[efcid] + if efc_D == 0.0 or active == 0: + return + + # TODO(team): sparse efc_J + wp.atomic_add( + d.efc.h[worldid, dofi], + dofj, + d.efc.J[efcid, dofi] * d.efc.J[efcid, dofj] * efc_D, + ) + + @kernel + def _cholesky(d: types.Data): + worldid = wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + mat_tile = wp.tile_load(d.efc.h[worldid], shape=(TILE, TILE)) + fact_tile = wp.tile_cholesky(mat_tile) + input_tile = wp.tile_load(d.efc.grad[worldid], shape=TILE) + output_tile = wp.tile_cholesky_solve(fact_tile, input_tile) + wp.tile_store(d.efc.Mgrad[worldid], output_tile) + + # grad = Ma - qfrc_smooth - qfrc_constraint + wp.launch(_zero_grad_dot, dim=(d.nworld), inputs=[d]) + + wp.launch(_grad, dim=(d.nworld, m.nv), inputs=[d]) + + if m.opt.solver == types.SolverType.CG: + smooth.solve_m(m, d, d.efc.Mgrad, d.efc.grad) + elif m.opt.solver == types.SolverType.NEWTON: + # h = qM + (efc_J.T * efc_D * active) @ efc_J + if m.opt.is_sparse: + wp.launch(_zero_h_lower, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d]) + + wp.launch( + _set_h_qM_lower_sparse, dim=(d.nworld, m.qM_fullm_i.size), inputs=[m, d] + ) + else: + wp.launch(_copy_lower_triangle, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d]) + + wp.launch(_JTDAJ, dim=(d.njmax, m.dof_tri_row.size), inputs=[m, d]) + + wp.launch_tiled(_cholesky, dim=(d.nworld,), inputs=[d], block_dim=32) + + +@wp.func +def _rescale(m: types.Model, value: float) -> float: + return value / (m.stat.meaninertia * float(wp.max(1, m.nv))) + + +@wp.func +def _in_bracket(x: wp.vec3, y: wp.vec3) -> bool: + return (x[1] < y[1] and y[1] < 0.0) or (x[1] > y[1] and y[1] > 0.0) + + +@wp.func +def _eval_pt(quad: wp.vec3, 
alpha: wp.float32) -> wp.vec3: + return wp.vec3( + alpha * alpha * quad[2] + alpha * quad[1] + quad[0], + 2.0 * alpha * quad[2] + quad[1], + 2.0 * quad[2], + ) + + +@wp.func +def _safe_div(x: wp.float32, y: wp.float32) -> wp.float32: + return x / wp.where(y != 0.0, y, types.MJ_MINVAL) + + +def _linesearch_iterative(m: types.Model, d: types.Data): + ITERATIONS = m.opt.iterations + + @kernel + def _gtol(m: types.Model, d: types.Data): + # TODO(team): static m? + worldid = wp.tid() + + if m.opt.iterations > 1: + if d.efc.done[worldid]: + return + + snorm = wp.math.sqrt(d.efc.search_dot[worldid]) + scale = m.stat.meaninertia * float(wp.max(1, m.nv)) + d.efc.gtol[worldid] = m.opt.tolerance * m.opt.ls_tolerance * snorm * scale + + @kernel + def _zero_jv(d: types.Data): + efcid = wp.tid() + + if d.efc.done[d.efc.worldid[efcid]]: + return + + d.efc.jv[efcid] = 0.0 + + @kernel + def _jv(d: types.Data): + efcid, dofid = wp.tid() + + if efcid >= min(d.nefc[0], d.njmax): + return + + worldid = d.efc.worldid[efcid] + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + j = d.efc.J[efcid, dofid] + search = d.efc.search[d.efc.worldid[efcid], dofid] + wp.atomic_add(d.efc.jv, efcid, j * search) + + @kernel + def _zero_quad_gauss(d: types.Data): + worldid = wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + d.efc.quad_gauss[worldid] = wp.vec3(0.0) + + @kernel + def _init_quad_gauss(m: types.Model, d: types.Data): + # TODO(team): static m?
+    worldid, dofid = wp.tid()
+
+    if ITERATIONS > 1:
+      if d.efc.done[worldid]:
+        return
+
+    search = d.efc.search[worldid, dofid]
+    quad_gauss = wp.vec3()
+    quad_gauss[0] = d.efc.gauss[worldid] / float(m.nv)
+    quad_gauss[1] = search * (d.efc.Ma[worldid, dofid] - d.qfrc_smooth[worldid, dofid])
+    quad_gauss[2] = 0.5 * search * d.efc.mv[worldid, dofid]
+    wp.atomic_add(d.efc.quad_gauss, worldid, quad_gauss)
+
+  @kernel
+  def _init_quad(d: types.Data):
+    efcid = wp.tid()
+
+    if efcid >= min(d.nefc[0], d.njmax):
+      return
+
+    worldid = d.efc.worldid[efcid]
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    Jaref = d.efc.Jaref[efcid]
+    jv = d.efc.jv[efcid]
+    efc_D = d.efc.D[efcid]
+    quad = wp.vec3()
+    quad[0] = 0.5 * Jaref * Jaref * efc_D
+    quad[1] = jv * Jaref * efc_D
+    quad[2] = 0.5 * jv * jv * efc_D
+    d.efc.quad[efcid] = quad
+
+  @kernel
+  def _init_p0_gauss(p0: wp.array(dtype=wp.vec3), d: types.Data):
+    worldid = wp.tid()
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    quad = d.efc.quad_gauss[worldid]
+    p0[worldid] = wp.vec3(quad[0], quad[1], 2.0 * quad[2])
+
+  @kernel
+  def _init_p0(p0: wp.array(dtype=wp.vec3), d: types.Data):
+    efcid = wp.tid()
+
+    if efcid >= min(d.nefc[0], d.njmax):
+      return
+
+    worldid = d.efc.worldid[efcid]
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    # TODO(team): active and conditionally active constraints:
+    if d.efc.Jaref[efcid] >= 0.0:
+      return
+
+    quad = d.efc.quad[efcid]
+    wp.atomic_add(p0, worldid, wp.vec3(quad[0], quad[1], 2.0 * quad[2]))
+
+  @kernel
+  def _init_lo_gauss(
+    p0: wp.array(dtype=wp.vec3),
+    lo: wp.array(dtype=wp.vec3),
+    lo_alpha: wp.array(dtype=wp.float32),
+    d: types.Data,
+  ):
+    worldid = wp.tid()
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    pp0 = p0[worldid]
+    alpha = -_safe_div(pp0[1], pp0[2])
+    lo[worldid] = _eval_pt(d.efc.quad_gauss[worldid], alpha)
+    lo_alpha[worldid] = alpha
+
+  @kernel
+  def _init_lo(
+    lo: wp.array(dtype=wp.vec3),
+    lo_alpha: wp.array(dtype=wp.float32),
+    d: types.Data,
+  ):
+    efcid = wp.tid()
+
+    if efcid >= min(d.nefc[0], d.njmax):
+      return
+
+    worldid = d.efc.worldid[efcid]
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    alpha = lo_alpha[worldid]
+
+    # TODO(team): active and conditionally active constraints
+    if d.efc.Jaref[efcid] + alpha * d.efc.jv[efcid] < 0.0:
+      wp.atomic_add(lo, worldid, _eval_pt(d.efc.quad[efcid], alpha))
+
+  @kernel
+  def _init_bounds(
+    p0: wp.array(dtype=wp.vec3),
+    lo: wp.array(dtype=wp.vec3),
+    lo_alpha: wp.array(dtype=wp.float32),
+    hi: wp.array(dtype=wp.vec3),
+    hi_alpha: wp.array(dtype=wp.float32),
+    d: types.Data,
+  ):
+    worldid = wp.tid()
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    pp0 = p0[worldid]
+    plo = lo[worldid]
+    plo_alpha = lo_alpha[worldid]
+    lo_less = plo[1] < pp0[1]
+    lo[worldid] = wp.where(lo_less, plo, pp0)
+    lo_alpha[worldid] = wp.where(lo_less, plo_alpha, 0.0)
+    hi[worldid] = wp.where(lo_less, pp0, plo)
+    hi_alpha[worldid] = wp.where(lo_less, 0.0, plo_alpha)
+
+  @kernel
+  def _next_alpha_gauss(
+    done: wp.array(dtype=bool),
+    lo: wp.array(dtype=wp.vec3),
+    lo_alpha: wp.array(dtype=wp.float32),
+    hi: wp.array(dtype=wp.vec3),
+    hi_alpha: wp.array(dtype=wp.float32),
+    lo_next: wp.array(dtype=wp.vec3),
+    lo_next_alpha: wp.array(dtype=wp.float32),
+    hi_next: wp.array(dtype=wp.vec3),
+    hi_next_alpha: wp.array(dtype=wp.float32),
+    mid: wp.array(dtype=wp.vec3),
+    mid_alpha: wp.array(dtype=wp.float32),
+    d: types.Data,
+  ):
+    worldid = wp.tid()
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    if done[worldid]:
+      return
+
+    quad = d.efc.quad_gauss[worldid]
+
+    plo = lo[worldid]
+    plo_alpha = lo_alpha[worldid]
+    plo_next_alpha = plo_alpha - _safe_div(plo[1], plo[2])
+    lo_next[worldid] = _eval_pt(quad, plo_next_alpha)
+    lo_next_alpha[worldid] = plo_next_alpha
+
+    phi = hi[worldid]
+    phi_alpha = hi_alpha[worldid]
+    phi_next_alpha = phi_alpha - _safe_div(phi[1], phi[2])
+    hi_next[worldid] = _eval_pt(quad, phi_next_alpha)
+    hi_next_alpha[worldid] = phi_next_alpha
+
+    pmid_alpha = 0.5 * (plo_alpha + phi_alpha)
+    mid[worldid] = _eval_pt(quad, pmid_alpha)
+    mid_alpha[worldid] = pmid_alpha
+
+  @kernel
+  def _next_quad(
+    done: wp.array(dtype=bool),
+    lo_next: wp.array(dtype=wp.vec3),
+    lo_next_alpha: wp.array(dtype=wp.float32),
+    hi_next: wp.array(dtype=wp.vec3),
+    hi_next_alpha: wp.array(dtype=wp.float32),
+    mid: wp.array(dtype=wp.vec3),
+    mid_alpha: wp.array(dtype=wp.float32),
+    d: types.Data,
+  ):
+    efcid = wp.tid()
+
+    if efcid >= min(d.nefc[0], d.njmax):
+      return
+
+    worldid = d.efc.worldid[efcid]
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    if done[worldid]:
+      return
+
+    quad = d.efc.quad[efcid]
+    jaref = d.efc.Jaref[efcid]
+    jv = d.efc.jv[efcid]
+
+    alpha = lo_next_alpha[worldid]
+    # TODO(team): active and conditionally active constraints
+    if jaref + alpha * jv < 0.0:
+      wp.atomic_add(lo_next, worldid, _eval_pt(quad, alpha))
+
+    alpha = hi_next_alpha[worldid]
+    # TODO(team): active and conditionally active constraints
+    if jaref + alpha * jv < 0.0:
+      wp.atomic_add(hi_next, worldid, _eval_pt(quad, alpha))
+
+    alpha = mid_alpha[worldid]
+    # TODO(team): active and conditionally active constraints
+    if jaref + alpha * jv < 0.0:
+      wp.atomic_add(mid, worldid, _eval_pt(quad, alpha))
+
+  @kernel
+  def _swap(
+    done: wp.array(dtype=bool),
+    p0: wp.array(dtype=wp.vec3),
+    lo: wp.array(dtype=wp.vec3),
+    lo_alpha: wp.array(dtype=wp.float32),
+    hi: wp.array(dtype=wp.vec3),
+    hi_alpha: wp.array(dtype=wp.float32),
+    lo_next: wp.array(dtype=wp.vec3),
+    lo_next_alpha: wp.array(dtype=wp.float32),
+    hi_next: wp.array(dtype=wp.vec3),
+    hi_next_alpha: wp.array(dtype=wp.float32),
+    mid: wp.array(dtype=wp.vec3),
+    mid_alpha: wp.array(dtype=wp.float32),
+    d: types.Data,
+  ):
+    worldid = wp.tid()
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    if done[worldid]:
+      return
+
+    plo = lo[worldid]
+    plo_alpha = lo_alpha[worldid]
+    phi = hi[worldid]
+    phi_alpha = hi_alpha[worldid]
+    plo_next = lo_next[worldid]
+    plo_next_alpha = lo_next_alpha[worldid]
+    phi_next = hi_next[worldid]
+    phi_next_alpha = hi_next_alpha[worldid]
+    pmid = mid[worldid]
+    pmid_alpha = mid_alpha[worldid]
+
+    # swap lo:
+    swap_lo_lo_next = _in_bracket(plo, plo_next)
+    plo = wp.where(swap_lo_lo_next, plo_next, plo)
+    plo_alpha = wp.where(swap_lo_lo_next, plo_next_alpha, plo_alpha)
+    swap_lo_mid = _in_bracket(plo, pmid)
+    plo = wp.where(swap_lo_mid, pmid, plo)
+    plo_alpha = wp.where(swap_lo_mid, pmid_alpha, plo_alpha)
+    swap_lo_hi_next = _in_bracket(plo, phi_next)
+    plo = wp.where(swap_lo_hi_next, phi_next, plo)
+    plo_alpha = wp.where(swap_lo_hi_next, phi_next_alpha, plo_alpha)
+    lo[worldid] = plo
+    lo_alpha[worldid] = plo_alpha
+    swap_lo = swap_lo_lo_next or swap_lo_mid or swap_lo_hi_next
+
+    # swap hi:
+    swap_hi_hi_next = _in_bracket(phi, phi_next)
+    phi = wp.where(swap_hi_hi_next, phi_next, phi)
+    phi_alpha = wp.where(swap_hi_hi_next, phi_next_alpha, phi_alpha)
+    swap_hi_mid = _in_bracket(phi, pmid)
+    phi = wp.where(swap_hi_mid, pmid, phi)
+    phi_alpha = wp.where(swap_hi_mid, pmid_alpha, phi_alpha)
+    swap_hi_lo_next = _in_bracket(phi, plo_next)
+    phi = wp.where(swap_hi_lo_next, plo_next, phi)
+    phi_alpha = wp.where(swap_hi_lo_next, plo_next_alpha, phi_alpha)
+    hi[worldid] = phi
+    hi_alpha[worldid] = phi_alpha
+    swap_hi = swap_hi_hi_next or swap_hi_mid or swap_hi_lo_next
+
+    # if we did not adjust the interval, we are done
+    # also done if either low or hi slope is nearly flat
+    gtol = d.efc.gtol[worldid]
+    done[worldid] = (
+      (not swap_lo and not swap_hi)
+      or (plo[1] < 0 and plo[1] > -gtol)
+      or (phi[1] > 0 and phi[1] < gtol)
+    )
+
+    # update alpha if we have an improvement
+    pp0 = p0[worldid]
+    alpha = 0.0
+    improved = plo[0] < pp0[0] or phi[0] < pp0[0]
+    plo_better = plo[0] < phi[0]
+    alpha = wp.where(improved and plo_better, plo_alpha, alpha)
+    alpha = wp.where(improved and not plo_better, phi_alpha, alpha)
+    d.efc.alpha[worldid] = alpha
+
+  @kernel
+  def _qacc_ma(d: types.Data):
+    worldid, dofid = wp.tid()
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    alpha = d.efc.alpha[worldid]
+    d.qacc[worldid, dofid] += alpha * d.efc.search[worldid, dofid]
+    d.efc.Ma[worldid, dofid] += alpha * d.efc.mv[worldid, dofid]
+
+  @kernel
+  def _jaref(d: types.Data):
+    efcid = wp.tid()
+
+    if efcid >= min(d.nefc[0], d.njmax):
+      return
+
+    worldid = d.efc.worldid[efcid]
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    d.efc.Jaref[efcid] += d.efc.alpha[worldid] * d.efc.jv[efcid]
+
+  wp.launch(_gtol, dim=(d.nworld,), inputs=[m, d])
+
+  # mv = qM @ search
+  support.mul_m(m, d, d.efc.mv, d.efc.search, d.efc.done)
+
+  # jv = efc_J @ search
+  # TODO(team): is there a better way of doing batched matmuls with dynamic array sizes?
+  wp.launch(_zero_jv, dim=(d.njmax), inputs=[d])
+
+  wp.launch(_jv, dim=(d.njmax, m.nv), inputs=[d])
+
+  # prepare quadratics
+  # quad_gauss = [gauss, search.T @ Ma - search.T @ qfrc_smooth, 0.5 * search.T @ mv]
+  wp.launch(_zero_quad_gauss, dim=(d.nworld), inputs=[d])
+
+  wp.launch(_init_quad_gauss, dim=(d.nworld, m.nv), inputs=[m, d])
+
+  # quad = [0.5 * Jaref * Jaref * efc_D, jv * Jaref * efc_D, 0.5 * jv * jv * efc_D]
+
+  wp.launch(_init_quad, dim=(d.njmax), inputs=[d])
+
+  # linesearch points
+  done = d.efc.ls_done
+  done.zero_()
+  p0 = d.efc.p0
+  lo = d.efc.lo
+  lo_alpha = d.efc.lo_alpha
+  hi = d.efc.hi
+  hi_alpha = d.efc.hi_alpha
+  lo_next = d.efc.lo_next
+  lo_next_alpha = d.efc.lo_next_alpha
+  hi_next = d.efc.hi_next
+  hi_next_alpha = d.efc.hi_next_alpha
+  mid = d.efc.mid
+  mid_alpha = d.efc.mid_alpha
+
+  # initialize interval
+
+  wp.launch(_init_p0_gauss, dim=(d.nworld,), inputs=[p0, d])
+
+  wp.launch(_init_p0, dim=(d.njmax,), inputs=[p0, d])
+
+  wp.launch(_init_lo_gauss, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, d])
+
+  wp.launch(_init_lo, dim=(d.njmax,), inputs=[lo, lo_alpha, d])
+
+  # set the lo/hi interval bounds
+
+  wp.launch(_init_bounds, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, hi, hi_alpha, d])
+
+  for _ in range(m.opt.ls_iterations):
+    # note: we always launch ls_iterations kernels, but the kernels may early exit if done is true
+    # this allows us to preserve cudagraph requirements (no dynamic kernel launching) at the expense
+    # of extra launches
+    inputs = [done, lo, lo_alpha, hi, hi_alpha, lo_next, lo_next_alpha, hi_next]
+    inputs += [hi_next_alpha, mid, mid_alpha, d]
+    wp.launch(_next_alpha_gauss, dim=(d.nworld,), inputs=inputs)
+
+    inputs = [done, lo_next, lo_next_alpha, hi_next, hi_next_alpha, mid, mid_alpha]
+    inputs += [d]
+    wp.launch(_next_quad, dim=(d.njmax,), inputs=inputs)
+
+    inputs = [done, p0, lo, lo_alpha, hi, hi_alpha, lo_next, lo_next_alpha, hi_next]
+    inputs += [hi_next_alpha, mid, mid_alpha, d]
+    wp.launch(_swap, dim=(d.nworld,), inputs=inputs)
+
+  wp.launch(_qacc_ma, dim=(d.nworld, m.nv), inputs=[d])
+
+  wp.launch(_jaref, dim=(d.njmax,), inputs=[d])
+
+
+@event_scope
+def solve(m: types.Model, d: types.Data):
+  """Finds forces that satisfy constraints."""
+  ITERATIONS = m.opt.iterations
+
+  @kernel
+  def _zero_search_dot(d: types.Data):
+    worldid = wp.tid()
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    d.efc.search_dot[worldid] = 0.0
+
+  @kernel
+  def _search_update(d: types.Data):
+    worldid, dofid = wp.tid()
+
+    if wp.static(m.opt.iterations) > 1:
+      if d.efc.done[worldid]:
+        return
+
+    search = -1.0 * d.efc.Mgrad[worldid, dofid]
+
+    if wp.static(m.opt.solver == types.SolverType.CG):
+      search += d.efc.beta[worldid] * d.efc.search[worldid, dofid]
+
+    d.efc.search[worldid, dofid] = search
+    wp.atomic_add(d.efc.search_dot, worldid, search * search)
+
+  @kernel
+  def _done(m: types.Model, d: types.Data, solver_niter: int):
+    # TODO(team): static m?
+    worldid = wp.tid()
+
+    if ITERATIONS > 1:
+      if d.efc.done[worldid]:
+        return
+
+    improvement = _rescale(m, d.efc.prev_cost[worldid] - d.efc.cost[worldid])
+    gradient = _rescale(m, wp.math.sqrt(d.efc.grad_dot[worldid]))
+    d.efc.done[worldid] = (improvement < m.opt.tolerance) or (
+      gradient < m.opt.tolerance
+    )
+
+  if m.opt.solver == types.SolverType.CG:
+
+    @kernel
+    def _prev_grad_Mgrad(d: types.Data):
+      worldid, dofid = wp.tid()
+
+      if wp.static(m.opt.iterations) > 1:
+        if d.efc.done[worldid]:
+          return
+
+      d.efc.prev_grad[worldid, dofid] = d.efc.grad[worldid, dofid]
+      d.efc.prev_Mgrad[worldid, dofid] = d.efc.Mgrad[worldid, dofid]
+
+    @kernel
+    def _zero_beta_num_den(d: types.Data):
+      worldid = wp.tid()
+
+      if wp.static(m.opt.iterations) > 1:
+        if d.efc.done[worldid]:
+          return
+
+      d.efc.beta_num[worldid] = 0.0
+      d.efc.beta_den[worldid] = 0.0
+
+    @kernel
+    def _beta_num_den(d: types.Data):
+      worldid, dofid = wp.tid()
+
+      if wp.static(m.opt.iterations) > 1:
+        if d.efc.done[worldid]:
+          return
+
+      prev_Mgrad = d.efc.prev_Mgrad[worldid][dofid]
+      wp.atomic_add(
+        d.efc.beta_num,
+        worldid,
+        d.efc.grad[worldid, dofid] * (d.efc.Mgrad[worldid, dofid] - prev_Mgrad),
+      )
+      wp.atomic_add(
+        d.efc.beta_den, worldid, d.efc.prev_grad[worldid, dofid] * prev_Mgrad
+      )
+
+    @kernel
+    def _beta(d: types.Data):
+      worldid = wp.tid()
+
+      if wp.static(m.opt.iterations) > 1:
+        if d.efc.done[worldid]:
+          return
+
+      d.efc.beta[worldid] = wp.max(
+        0.0, d.efc.beta_num[worldid] / wp.max(types.MJ_MINVAL, d.efc.beta_den[worldid])
+      )
+
+  # warmstart
+  kernel_copy(d.qacc, d.qacc_warmstart)
+
+  _create_context(m, d, grad=True)
+
+  for i in range(m.opt.iterations):
+    _linesearch_iterative(m, d)
+
+    if m.opt.solver == types.SolverType.CG:
+      wp.launch(_prev_grad_Mgrad, dim=(d.nworld, m.nv), inputs=[d])
+
+    _update_constraint(m, d)
+    _update_gradient(m, d)
+
+    # polak-ribiere
+    if m.opt.solver == types.SolverType.CG:
+      wp.launch(_zero_beta_num_den, dim=(d.nworld), inputs=[d])
+
+      wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[d])
+
+      wp.launch(_beta, dim=(d.nworld,), inputs=[d])
+
+    wp.launch(_zero_search_dot, dim=(d.nworld), inputs=[d])
+
+    wp.launch(_search_update, dim=(d.nworld, m.nv), inputs=[d])
+
+    wp.launch(_done, dim=(d.nworld,), inputs=[m, d, i])
+
+  kernel_copy(d.qacc_warmstart, d.qacc)
diff --git a/mujoco/mjx/_src/solver_test.py b/mujoco_warp/_src/solver_test.py
similarity index 71%
rename from mujoco/mjx/_src/solver_test.py
rename to mujoco_warp/_src/solver_test.py
index e3402858..dd29a2b6 100644
--- a/mujoco/mjx/_src/solver_test.py
+++ b/mujoco_warp/_src/solver_test.py
@@ -1,16 +1,32 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for solver functions."""
+
+import mujoco
+import numpy as np
+import warp as wp
 from absl.testing import absltest
 from absl.testing import parameterized
 from etils import epath
-import mujoco
-from . import io
-from . import smooth
+
+import mujoco_warp as mjwarp
+
 from . import solver
-import numpy as np
-import warp as wp
 
-# tolerance for difference between MuJoCo and MJX smooth calculations - mostly
+# tolerance for difference between MuJoCo and MJWarp solver calculations - mostly
 # due to float precision
 _TOLERANCE = 5e-3
 
@@ -34,7 +50,7 @@ def _load(
     njmax: int = 512,
     keyframe: int = 0,
   ):
-    path = epath.resource_path("mujoco.mjx") / "test_data" / fname
+    path = epath.resource_path("mujoco_warp") / "test_data" / fname
     mjm = mujoco.MjModel.from_xml_path(path.as_posix())
     mjm.opt.jacobian = is_sparse
    mjm.opt.iterations = iterations
@@ -45,20 +61,21 @@ def _load(
    mjd = mujoco.MjData(mjm)
    mujoco.mj_resetDataKeyframe(mjm, mjd, keyframe)
    mujoco.mj_step(mjm, mjd)
-    m = io.put_model(mjm)
-    d = io.put_data(mjm, mjd, nworld=nworld, njmax=njmax)
+    m = mjwarp.put_model(mjm)
+    d = mjwarp.put_data(mjm, mjd, nworld=nworld, njmax=njmax)
    return mjm, mjd, m, d
 
  @parameterized.parameters(
-    (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_CG, 25, 5),
-    (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_NEWTON, 2, 4),
+    (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_CG, 25, 5, False),
+    (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_NEWTON, 2, 4, False),
+    (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_NEWTON, 2, 4, True),
  )
-  def test_solve(self, cone, solver_, iterations, ls_iterations):
-    """Tests MJX solve."""
+  def test_solve(self, cone, solver_, iterations, ls_iterations, sparse):
+    """Tests solve."""
    for keyframe in range(3):
      mjm, mjd, m, d = self._load(
        "humanoid/humanoid.xml",
-        is_sparse=False,
+        is_sparse=sparse,
        cone=cone,
        solver_=solver_,
        iterations=iterations,
@@ -75,42 +92,41 @@ def cost(qacc):
 
      mj_cost = cost(mjd.qacc)
 
-      ctx = solver._context(m, d)
-      solver._create_context(ctx, m, d)
+      solver._create_context(m, d)
 
-      mjx_cost = ctx.cost.numpy()[0] - ctx.gauss.numpy()[0]
+      mjwarp_cost = d.efc.cost.numpy()[0] - d.efc.gauss.numpy()[0]
 
-      _assert_eq(mjx_cost, mj_cost, name="cost")
+      _assert_eq(mjwarp_cost, mj_cost, name="cost")
 
      qacc_warmstart = mjd.qacc_warmstart.copy()
      mujoco.mj_forward(mjm, mjd)
      mjd.qacc_warmstart = qacc_warmstart
 
-      m = io.put_model(mjm)
-      d = io.put_data(mjm, mjd, njmax=mjd.nefc)
+      m = mjwarp.put_model(mjm)
+      d = mjwarp.put_data(mjm, mjd, njmax=mjd.nefc)
 
      d.qacc.zero_()
      d.qfrc_constraint.zero_()
-      d.efc_force.zero_()
+      d.efc.force.zero_()
 
      if solver_ == mujoco.mjtSolver.mjSOL_CG:
-        smooth.factor_m(m, d)
-      solver.solve(m, d)
+        mjwarp.factor_m(m, d)
+      mjwarp.solve(m, d)
 
      mj_cost = cost(mjd.qacc)
-      mjx_cost = cost(d.qacc.numpy()[0])
-      self.assertLessEqual(mjx_cost, mj_cost * 1.025)
+      mjwarp_cost = cost(d.qacc.numpy()[0])
+      self.assertLessEqual(mjwarp_cost, mj_cost * 1.025)
 
      if m.opt.solver == mujoco.mjtSolver.mjSOL_NEWTON:
        _assert_eq(d.qacc.numpy()[0], mjd.qacc, "qacc")
        _assert_eq(d.qfrc_constraint.numpy()[0], mjd.qfrc_constraint, "qfrc_constraint")
-        _assert_eq(d.efc_force.numpy()[: mjd.nefc], mjd.efc_force, "efc_force")
+        _assert_eq(d.efc.force.numpy()[: mjd.nefc], mjd.efc_force, "efc_force")
 
  @parameterized.parameters(
    (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_CG, 25, 5),
    (mujoco.mjtCone.mjCONE_PYRAMIDAL, mujoco.mjtSolver.mjSOL_NEWTON, 2, 4),
  )
  def test_solve_batch(self, cone, solver_, iterations, ls_iterations):
-    """Tests MJX solve."""
+    """Tests solve (batch)."""
    mjm0, mjd0, _, _ = self._load(
      "humanoid/humanoid.xml",
      is_sparse=False,
@@ -163,7 +179,7 @@ def test_solve_batch(self, cone, solver_, iterations, ls_iterations):
      njmax=2 * nefc_active,
    )
 
-    d.nefc_total = wp.array([nefc_active], dtype=wp.int32, ndim=1)
+    d.nefc = wp.array([nefc_active], dtype=wp.int32, ndim=1)
 
    nefc_fill = d.njmax - nefc_active
 
@@ -226,25 +242,25 @@ def test_solve_batch(self, cone, solver_, iterations, ls_iterations):
    d.qM = wp.from_numpy(qM, dtype=wp.float32)
    d.qacc_smooth = wp.from_numpy(qacc_smooth, dtype=wp.float32)
    d.qfrc_smooth = wp.from_numpy(qfrc_smooth, dtype=wp.float32)
-    d.efc_J = wp.from_numpy(efc_J_fill, dtype=wp.float32)
-    d.efc_D = wp.from_numpy(efc_D_fill, dtype=wp.float32)
-    d.efc_aref = wp.from_numpy(efc_aref_fill, dtype=wp.float32)
-    d.efc_worldid = wp.from_numpy(efc_worldid, dtype=wp.int32)
+    d.efc.J = wp.from_numpy(efc_J_fill, dtype=wp.float32)
+    d.efc.D = wp.from_numpy(efc_D_fill, dtype=wp.float32)
+    d.efc.aref = wp.from_numpy(efc_aref_fill, dtype=wp.float32)
+    d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32)
 
    if solver_ == mujoco.mjtSolver.mjSOL_CG:
-      m0 = io.put_model(mjm0)
-      d0 = io.put_data(mjm0, mjd0)
-      smooth.factor_m(m0, d0)
+      m0 = mjwarp.put_model(mjm0)
+      d0 = mjwarp.put_data(mjm0, mjd0)
+      mjwarp.factor_m(m0, d0)
      qLD0 = d0.qLD.numpy()
-      m1 = io.put_model(mjm1)
-      d1 = io.put_data(mjm1, mjd1)
-      smooth.factor_m(m1, d1)
+      m1 = mjwarp.put_model(mjm1)
+      d1 = mjwarp.put_data(mjm1, mjd1)
+      mjwarp.factor_m(m1, d1)
      qLD1 = d1.qLD.numpy()
-      m2 = io.put_model(mjm2)
-      d2 = io.put_data(mjm2, mjd2)
-      smooth.factor_m(m2, d2)
+      m2 = mjwarp.put_model(mjm2)
+      d2 = mjwarp.put_data(mjm2, mjd2)
+      mjwarp.factor_m(m2, d2)
      qLD2 = d2.qLD.numpy()
 
      qLD = np.vstack([qLD0, qLD1, qLD2])
@@ -252,7 +268,7 @@ def test_solve_batch(self, cone, solver_, iterations, ls_iterations):
    d.qacc.zero_()
    d.qfrc_constraint.zero_()
-    d.efc_force.zero_()
+    d.efc.force.zero_()
    solver.solve(m, d)
 
    def cost(m, d, qacc):
@@ -263,16 +279,16 @@ def cost(m, d, qacc):
      return cost
 
    mj_cost0 = cost(mjm0, mjd0, mjd0.qacc)
-    mjx_cost0 = cost(mjm0, mjd0, d.qacc.numpy()[0])
-    self.assertLessEqual(mjx_cost0, mj_cost0 * 1.025)
+    mjwarp_cost0 = cost(mjm0, mjd0, d.qacc.numpy()[0])
+    self.assertLessEqual(mjwarp_cost0, mj_cost0 * 1.025)
 
    mj_cost1 = cost(mjm1, mjd1, mjd1.qacc)
-    mjx_cost1 = cost(mjm1, mjd1, d.qacc.numpy()[1])
-    self.assertLessEqual(mjx_cost1, mj_cost1 * 1.025)
+    mjwarp_cost1 = cost(mjm1, mjd1, d.qacc.numpy()[1])
+    self.assertLessEqual(mjwarp_cost1, mj_cost1 * 1.025)
 
    mj_cost2 = cost(mjm2, mjd2, mjd2.qacc)
-    mjx_cost2 = cost(mjm2, mjd2, d.qacc.numpy()[2])
-    self.assertLessEqual(mjx_cost2, mj_cost2 * 1.025)
+    mjwarp_cost2 = cost(mjm2, mjd2, d.qacc.numpy()[2])
+    self.assertLessEqual(mjwarp_cost2, mj_cost2 * 1.025)
 
    if m.opt.solver == mujoco.mjtSolver.mjSOL_NEWTON:
      _assert_eq(d.qacc.numpy()[0], mjd0.qacc, "qacc0")
@@ -284,17 +300,17 @@ def cost(m, d, qacc):
      _assert_eq(d.qfrc_constraint.numpy()[2], mjd2.qfrc_constraint, "qfrc_constraint2")
 
      _assert_eq(
-        d.efc_force.numpy()[: mjd0.nefc],
+        d.efc.force.numpy()[: mjd0.nefc],
        mjd0.efc_force,
        "efc_force0",
      )
      _assert_eq(
-        d.efc_force.numpy()[mjd0.nefc : mjd0.nefc + mjd1.nefc],
+        d.efc.force.numpy()[mjd0.nefc : mjd0.nefc + mjd1.nefc],
        mjd1.efc_force,
        "efc_force1",
      )
      _assert_eq(
-        d.efc_force.numpy()[mjd0.nefc + mjd1.nefc : mjd0.nefc + mjd1.nefc + mjd2.nefc],
+        d.efc.force.numpy()[mjd0.nefc + mjd1.nefc : mjd0.nefc + mjd1.nefc + mjd2.nefc],
        mjd2.efc_force,
        "efc_force2",
      )
diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py
new file mode 100644
index 00000000..80e4cab9
--- /dev/null
+++ b/mujoco_warp/_src/support.py
@@ -0,0 +1,202 @@
+# Copyright 2025 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any
+
+import mujoco
+import warp as wp
+
+from .types import Data
+from .types import Model
+from .types import array2df
+from .types import array3df
+from .warp_util import event_scope
+from .warp_util import kernel
+
+
+def is_sparse(m: mujoco.MjModel):
+  if m.opt.jacobian == mujoco.mjtJacobian.mjJAC_AUTO:
+    return m.nv >= 60
+  return m.opt.jacobian == mujoco.mjtJacobian.mjJAC_SPARSE
+
+
+@event_scope
+def mul_m(
+  m: Model,
+  d: Data,
+  res: wp.array(ndim=2, dtype=wp.float32),
+  vec: wp.array(ndim=2, dtype=wp.float32),
+  skip: wp.array(ndim=1, dtype=bool),
+):
+  """Multiply vector by inertia matrix."""
+
+  if not m.opt.is_sparse:
+
+    def tile_mul(adr: int, size: int, tilesize: int):
+      # TODO(team): speed up kernel compile time (14s on 2023 Macbook Pro)
+      @kernel
+      def mul(
+        m: Model,
+        d: Data,
+        leveladr: int,
+        res: array3df,
+        vec: array3df,
+        skip: wp.array(ndim=1, dtype=bool),
+      ):
+        worldid, nodeid = wp.tid()
+
+        if skip[worldid]:
+          return
+
+        dofid = m.qLD_tile[leveladr + nodeid]
+        qM_tile = wp.tile_load(
+          d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid)
+        )
+        vec_tile = wp.tile_load(vec[worldid], shape=(tilesize, 1), offset=(dofid, 0))
+        res_tile = wp.tile_zeros(shape=(tilesize, 1), dtype=wp.float32)
+        wp.tile_matmul(qM_tile, vec_tile, res_tile)
+        wp.tile_store(res[worldid], res_tile, offset=(dofid, 0))
+
+      wp.launch_tiled(
+        mul,
+        dim=(d.nworld, size),
+        inputs=[
+          m,
+          d,
+          adr,
+          res.reshape(res.shape + (1,)),
+          vec.reshape(vec.shape + (1,)),
+          skip,
+        ],
+        # TODO(team): develop heuristic for block dim, or make configurable
+        block_dim=32,
+      )
+
+    qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy()
+
+    for i in range(len(qLD_tileadr)):
+      beg = qLD_tileadr[i]
+      end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1]
+      tile_mul(beg, end - beg, int(qLD_tilesize[i]))
+
+  else:
+
+    @kernel
+    def _mul_m_sparse_diag(
+      m: Model,
+      d: Data,
+      res: wp.array(ndim=2, dtype=wp.float32),
+      vec: wp.array(ndim=2, dtype=wp.float32),
+      skip: wp.array(ndim=1, dtype=bool),
+    ):
+      worldid, dofid = wp.tid()
+
+      if skip[worldid]:
+        return
+
+      res[worldid, dofid] = d.qM[worldid, 0, m.dof_Madr[dofid]] * vec[worldid, dofid]
+
+    @kernel
+    def _mul_m_sparse_ij(
+      m: Model,
+      d: Data,
+      res: wp.array(ndim=2, dtype=wp.float32),
+      vec: wp.array(ndim=2, dtype=wp.float32),
+      skip: wp.array(ndim=1, dtype=bool),
+    ):
+      worldid, elementid = wp.tid()
+
+      if skip[worldid]:
+        return
+
+      i = m.qM_mulm_i[elementid]
+      j = m.qM_mulm_j[elementid]
+      madr_ij = m.qM_madr_ij[elementid]
+
+      qM = d.qM[worldid, 0, madr_ij]
+
+      wp.atomic_add(res[worldid], i, qM * vec[worldid, j])
+      wp.atomic_add(res[worldid], j, qM * vec[worldid, i])
+
+    wp.launch(_mul_m_sparse_diag, dim=(d.nworld, m.nv), inputs=[m, d, res, vec, skip])
+
+    wp.launch(
+      _mul_m_sparse_ij, dim=(d.nworld, m.qM_madr_ij.size), inputs=[m, d, res, vec, skip]
+    )
+
+
+@event_scope
+def xfrc_accumulate(m: Model, d: Data, qfrc: array2df):
+  @wp.kernel
+  def _accumulate(m: Model, d: Data, qfrc: array2df):
+    worldid, dofid = wp.tid()
+    cdof = d.cdof[worldid, dofid]
+    rotational_cdof = wp.vec3(cdof[0], cdof[1], cdof[2])
+    jac = wp.spatial_vector(cdof[3], cdof[4], cdof[5], cdof[0], cdof[1], cdof[2])
+
+    dof_bodyid = m.dof_bodyid[dofid]
+    accumul = float(0.0)
+
+    for bodyid in range(dof_bodyid, m.nbody):
+      # any body that is in the subtree of dof_bodyid is part of the jacobian
+      parentid = bodyid
+      while parentid != 0 and parentid != dof_bodyid:
+        parentid = m.body_parentid[parentid]
+      if parentid == 0:
+        continue  # body is not part of the subtree
+      offset = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]]
+      cross_term = wp.cross(rotational_cdof, offset)
+      accumul += wp.dot(jac, d.xfrc_applied[worldid, bodyid]) + wp.dot(
+        cross_term,
+        wp.vec3(
+          d.xfrc_applied[worldid, bodyid][0],
+          d.xfrc_applied[worldid, bodyid][1],
+          d.xfrc_applied[worldid, bodyid][2],
+        ),
+      )
+
+    qfrc[worldid, dofid] += accumul
+
+  wp.launch(kernel=_accumulate, dim=(d.nworld, m.nv), inputs=[m, d, qfrc])
+
+
+@wp.func
+def bisection(x: wp.array(dtype=int), v: int, a_: int, b_: int) -> int:
+  # Binary search for the largest index i such that x[i] <= v
+  # x is a sorted array
+  # a and b are the start and end indices within x to search
+  a = int(a_)
+  b = int(b_)
+  c = int(0)
+  while b - a > 1:
+    c = (a + b) // 2
+    if x[c] <= v:
+      a = c
+    else:
+      b = c
+  c = a
+  if c != b and x[b] <= v:
+    c = b
+  return c
+
+
+@wp.func
+def mat33_from_rows(a: wp.vec3, b: wp.vec3, c: wp.vec3):
+  return wp.mat33(a, b, c)
+
+
+@wp.func
+def mat33_from_cols(a: wp.vec3, b: wp.vec3, c: wp.vec3):
+  return wp.mat33(a.x, b.x, c.x, a.y, b.y, c.y, a.z, b.z, c.z)
diff --git a/mujoco/mjx/_src/support_test.py b/mujoco_warp/_src/support_test.py
similarity index 71%
rename from mujoco/mjx/_src/support_test.py
rename to mujoco_warp/_src/support_test.py
index 034f0e6d..799892c8 100644
--- a/mujoco/mjx/_src/support_test.py
+++ b/mujoco_warp/_src/support_test.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Physics-Next Project Developers
+# Copyright 2025 The Newton Developers
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,18 +15,19 @@
 
 """Tests for support functions."""
 
-from absl.testing import absltest
-from absl.testing import parameterized
 import mujoco
-from mujoco import mjx
 import numpy as np
 import warp as wp
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import mujoco_warp as mjwarp
+
 from . import test_util
-from .support import xfrc_accumulate
 
 wp.config.verify_cuda = True
 
-# tolerance for difference between MuJoCo and mjWarp support calculations - mostly
+# tolerance for difference between MuJoCo and MJWarp support calculations - mostly
 # due to float precision
 _TOLERANCE = 5e-5
 
@@ -49,7 +50,8 @@ def test_mul_m(self, sparse):
    res = wp.zeros((1, mjm.nv), dtype=wp.float32)
    vec = wp.from_numpy(np.expand_dims(mj_vec, axis=0), dtype=wp.float32)
 
-    mjx.mul_m(m, d, res, vec)
+    skip = wp.zeros((d.nworld), dtype=bool)
+    mjwarp.mul_m(m, d, res, vec, skip)
 
    _assert_eq(res.numpy()[0], mj_res, f"mul_m ({'sparse' if sparse else 'dense'})")
 
@@ -59,23 +61,30 @@ def test_xfrc_accumulated(self):
    mjm, mjd, m, d = test_util.fixture("pendula.xml")
    xfrc = np.random.randn(*d.xfrc_applied.numpy().shape)
    d.xfrc_applied = wp.from_numpy(xfrc, dtype=wp.spatial_vector)
-    qfrc = xfrc_accumulate(m, d)
+    qfrc = wp.zeros((1, mjm.nv), dtype=wp.float32)
+    mjwarp.xfrc_accumulate(m, d, qfrc)
    qfrc_expected = np.zeros(m.nv)
    xfrc = xfrc[0]
-    mjd.xfrc_applied[:] = xfrc
    for i in range(1, m.nbody):
      mujoco.mj_applyFT(
-        mjm,
-        mjd,
-        mjd.xfrc_applied[i, :3],
-        mjd.xfrc_applied[i, 3:],
-        mjd.xipos[i],
-        i,
-        qfrc_expected,
+        mjm, mjd, xfrc[i, :3], xfrc[i, 3:], mjd.xipos[i], i, qfrc_expected
      )
    np.testing.assert_almost_equal(qfrc.numpy()[0], qfrc_expected, 6)
 
+  def test_make_put_data(self):
+    """Tests that make_put_data and put_data are producing the same shapes for all warp arrays."""
+    mjm, mjd, m, d = test_util.fixture("pendula.xml")
+    md = mjwarp.make_data(mjm)
+
+    # same number of fields
+    self.assertEqual(len(d.__dict__), len(md.__dict__))
+
+    # test shapes for all arrays
+    for attr, val in md.__dict__.items():
+      if isinstance(val, wp.array):
+        self.assertEqual(val.shape, getattr(d, attr).shape)
+
 
 if __name__ == "__main__":
  wp.init()
diff --git a/mujoco_warp/_src/test_collision_axis_tiled.py b/mujoco_warp/_src/test_collision_axis_tiled.py
new file mode 100644
index 00000000..ed9182fe
--- /dev/null
+++ b/mujoco_warp/_src/test_collision_axis_tiled.py
@@ -0,0 +1,142 @@
+import itertools
+
+import numpy as np
+import warp as wp
+from absl.testing import absltest
+
+from .collision_functions import collision_axis_tiled, Box
+
+BOX_BLOCK_DIM = 32
+
+@wp.kernel
+def test_collision_axis_tiled_kernel(
+  a: Box,
+  b: Box,
+  R: wp.mat33,
+  best_axis: wp.array(dtype=wp.vec3),
+  best_sign: wp.array(dtype=wp.int32),
+  best_idx: wp.array(dtype=wp.int32),
+):
+  bid, axis_idx = wp.tid()
+  axis_out, sign_out, idx_out = collision_axis_tiled(a, b, R, axis_idx)
+  if axis_idx > 0:
+    return
+  best_axis[bid] = axis_out
+  best_sign[bid] = sign_out
+  best_idx[bid] = idx_out
+
+
+class TestCollisionAxisTiled(absltest.TestCase):
+  """Tests the collision_axis_tiled function."""
+
+  def test_collision_axis_tiled_vf(self):
+    """Tests the collision_axis_tiled function."""
+    vert = np.array(list(itertools.product((-1, 1), (-1, 1), (-1, 1))), dtype=float)
+
+    dims_a = np.array([1.1, 0.8, 1.8])
+    dims_b = np.array([0.8, 0.3, 1.1])
+
+    # shift in yz plane
+    shift_b = np.array([0, 0.2, -0.4])
+
+    separation = 0.13
+
+    s = 0.5 * 2**0.5
+    rx = np.array([1, 0, 0, 0, s, s, 0, -s, s]).reshape((3, 3))
+    ry = np.array([s, 0, s, 0, 1, 0, -s, 0, s]).reshape((3, 3))
+    # Rotate vert 0 towards negative x at origin
+    R_a = ry @ rx.T
+    t_a = (-dims_a * vert[0]) @ R_a.T
+
+    R_b = np.eye(3)
+    t_b = -np.array([dims_b[0] + separation, 0, 0]) + shift_b
+
+    vert_a = (dims_a * vert) @ R_a.T + t_a
+    vert_b = dims_b * vert + t_b
+
+    R_atob = R_b.T @ R_a
+    t_atob = R_b.T @ (t_a - t_b)
+    a = Box(vert_a.ravel())
+    b = Box(vert_b.ravel())
+
+    R = wp.mat33(R_atob)
+    best_axis = wp.empty(1, dtype=wp.vec3)
+    best_sign = wp.empty(1, dtype=wp.int32)
+    best_idx = wp.empty(1, dtype=wp.int32)
+
+    wp.launch_tiled(
+      kernel=test_collision_axis_tiled_kernel,
+      dim=1,
+      inputs=[a, b, R],
+      outputs=[best_axis, best_sign, best_idx],
+      block_dim=BOX_BLOCK_DIM,
+    )
+    expected_axis = np.array([-1, 0, 0])
+    expected_sign = np.sign(best_axis.numpy().dot(expected_axis))
+
+    np.testing.assert_array_equal(
+      np.cross(best_axis.numpy()[0], expected_axis), np.zeros(3)
+    )
+    # np.testing.assert_array_equal(best_idx.numpy()[0], 8)
+    np.testing.assert_array_equal(best_sign.numpy()[0], expected_sign)
+
+  def test_collision_axis_tiled_ee(self):
+    """Tests the collision_axis_tiled function."""
+    # edge on edge
+    vert = np.array(list(itertools.product((-1, 1), (-1, 1), (-1, 1))), dtype=float)
+
+    dims_a = np.array([1.1, 0.8, 1.8])
+    dims_b = np.array([0.8, 0.3, 1.1])
+
+    dims_a = dims_b = np.ones(3)
+
+    # shift
+
+    separation = 0.13
+
+    s = 0.5 * 2**0.5
+    rx = np.array([1, 0, 0, 0, s, s, 0, -s, s]).reshape((3, 3))
+    ry = np.array([s, 0, s, 0, 1, 0, -s, 0, s]).reshape((3, 3))
+
+    # Rotate vert 0 towards negative x at origin
+    R_a = ry @ rx
+    t_a = (-dims_a * 0.5 * (vert[2] + vert[6])) @ R_a.T
+    # t_a = np.array([0, 0, 2**0.5])
+
+    R_b = np.eye(3)
+    t_b = (-dims_b * vert[0]) @ R_b.T
+    t_b = np.array([-1, 0.0, -1])
+    t_b -= np.array([s, 0, s]) * 0.2
+
+    vert_a = (dims_a * vert) @ R_a.T + t_a
+    vert_b = dims_b * vert @ R_b.T + t_b
+
+    R_atob = R_b.T @ R_a
+    t_atob = R_b.T @ (t_a - t_b)
+    a = Box(vert_a.ravel())
+    b = Box(vert_b.ravel())
+
+    R = wp.mat33(R_atob)
+    best_axis = wp.empty(1, dtype=wp.vec3)
+    best_sign = wp.empty(1, dtype=wp.int32)
+    best_idx = wp.empty(1, dtype=wp.int32)
+
+    wp.launch_tiled(
+      kernel=test_collision_axis_tiled_kernel,
+      dim=1,
+      inputs=[a, b, R],
+      outputs=[best_axis, best_sign, best_idx],
+      block_dim=BOX_BLOCK_DIM,
+    )
+    expected_axis = np.array([-1, 0, -1])
+    expected_sign = np.sign(best_axis.numpy().dot(expected_axis))
+
+    np.testing.assert_array_equal(
+      np.cross(best_axis.numpy()[0], expected_axis), np.zeros(3)
+    )
+    # np.testing.assert_array_equal(best_idx.numpy()[0], 8)
+    np.testing.assert_array_equal(best_sign.numpy()[0], expected_sign)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/mujoco/mjx/_src/test_util.py b/mujoco_warp/_src/test_util.py
similarity index 51% rename from mujoco/mjx/_src/test_util.py rename to mujoco_warp/_src/test_util.py index aba4653e..b65de415 100644 --- a/mujoco/mjx/_src/test_util.py +++ b/mujoco_warp/_src/test_util.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Physics-Next Project Developers +# Copyright 2025 The Newton Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,19 +18,20 @@ import time from typing import Callable, Tuple -from etils import epath +import mujoco import numpy as np import warp as wp - -import mujoco +from etils import epath from . import io from . import types +from . import warp_util def fixture(fname: str, keyframe: int = -1, sparse: bool = True): - path = epath.resource_path("mujoco.mjx") / "test_data" / fname + path = epath.resource_path("mujoco_warp") / "test_data" / fname mjm = mujoco.MjModel.from_xml_path(path.as_posix()) + mjm.opt.jacobian = sparse mjd = mujoco.MjData(mjm) if keyframe > -1: mujoco.mj_resetDataKeyframe(mjm, mjd, keyframe) @@ -38,55 +39,80 @@ def fixture(fname: str, keyframe: int = -1, sparse: bool = True): mjd.qvel = np.random.uniform(-0.01, 0.01, mjm.nv) mujoco.mj_step(mjm, mjd, 3) # let dynamics get state significantly non-zero mujoco.mj_forward(mjm, mjd) - mjm.opt.jacobian = sparse m = io.put_model(mjm) d = io.put_data(mjm, mjd) return mjm, mjd, m, d +def _sum(stack1, stack2): + ret = {} + for k in stack1: + times1, sub_stack1 = stack1[k] + times2, sub_stack2 = stack2[k] + times = [t1 + t2 for t1, t2 in zip(times1, times2)] + ret[k] = (times, _sum(sub_stack1, sub_stack2)) + return ret + + def benchmark( fn: Callable[[types.Model, types.Data], None], - m: mujoco.MjModel, + mjm: mujoco.MjModel, + mjd: mujoco.MjData, nstep: int = 1000, batch_size: int = 1024, - unroll_steps: int = 1, solver: str = "cg", iterations: int = 1, ls_iterations: int = 4, - nefc_total: int = 0, -) -> Tuple[float, float, int]: + nconmax: int = -1, + njmax: int = 
-1, + event_trace: bool = False, + measure_alloc: bool = False, +) -> Tuple[float, float, dict, int, list, list]: """Benchmark a model.""" if solver == "cg": - m.opt.solver = mujoco.mjtSolver.mjSOL_CG + mjm.opt.solver = mujoco.mjtSolver.mjSOL_CG elif solver == "newton": - m.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON + mjm.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON - m.opt.iterations = iterations - m.opt.ls_iterations = ls_iterations + mjm.opt.iterations = iterations + mjm.opt.ls_iterations = ls_iterations - mx = io.put_model(m) - dx = io.make_data(m, nworld=batch_size, njmax=nefc_total) - dx.nefc_total = wp.array([nefc_total], dtype=wp.int32, ndim=1) + m = io.put_model(mjm) + d = io.put_data(mjm, mjd, nworld=batch_size, nconmax=nconmax, njmax=njmax) - wp.clear_kernel_cache() jit_beg = time.perf_counter() - fn(mx, dx) - fn(mx, dx) # double warmup to work around issues with compilation during graph capture - jit_end = time.perf_counter() - jit_duration = jit_end - jit_beg - wp.synchronize() - # capture the whole smooth.kinematic() function as a CUDA graph - with wp.ScopedCapture() as capture: - fn(mx, dx) - graph = capture.graph + fn(m, d) + # double warmup to work around issues with compilation during graph capture: + fn(m, d) - run_beg = time.perf_counter() - for _ in range(nstep): - wp.capture_launch(graph) + jit_end = time.perf_counter() + jit_duration = jit_end - jit_beg wp.synchronize() - run_end = time.perf_counter() - run_duration = run_end - run_beg - - return jit_duration, run_duration, batch_size * nstep + trace = {} + ncon = [] + nefc = [] + + with warp_util.EventTracer(enabled=event_trace) as tracer: + # capture the whole function as a CUDA graph + with wp.ScopedCapture() as capture: + fn(m, d) + graph = capture.graph + + run_beg = time.perf_counter() + for _ in range(nstep): + wp.capture_launch(graph) + if trace: + trace = _sum(trace, tracer.trace()) + else: + trace = tracer.trace() + if measure_alloc: + wp.synchronize() + ncon.append(d.ncon.numpy()[0]) 
+ nefc.append(d.nefc.numpy()[0]) + wp.synchronize() + run_end = time.perf_counter() + run_duration = run_end - run_beg + + return jit_duration, run_duration, trace, batch_size * nstep, ncon, nefc diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py new file mode 100644 index 00000000..9ab19a83 --- /dev/null +++ b/mujoco_warp/_src/types.py @@ -0,0 +1,757 @@ +# Copyright 2025 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import enum + +import mujoco +import warp as wp + +MJ_MINVAL = mujoco.mjMINVAL +MJ_MINIMP = mujoco.mjMINIMP # minimum constraint impedance +MJ_MAXIMP = mujoco.mjMAXIMP # maximum constraint impedance +MJ_NREF = mujoco.mjNREF +MJ_NIMP = mujoco.mjNIMP + + +class DisableBit(enum.IntFlag): + """Disable default feature bitflags. 
+ + Members: + CONSTRAINT: entire constraint solver + LIMIT: joint and tendon limit constraints + CONTACT: contact constraints + PASSIVE: passive forces + GRAVITY: gravitational forces + CLAMPCTRL: clamp control to specified range + ACTUATION: apply actuation forces + REFSAFE: integrator safety: make ref[0]>=2*timestep + EULERDAMP: implicit damping for Euler integration + FILTERPARENT: disable collisions between parent and child bodies + """ + + CONSTRAINT = mujoco.mjtDisableBit.mjDSBL_CONSTRAINT + LIMIT = mujoco.mjtDisableBit.mjDSBL_LIMIT + CONTACT = mujoco.mjtDisableBit.mjDSBL_CONTACT + PASSIVE = mujoco.mjtDisableBit.mjDSBL_PASSIVE + GRAVITY = mujoco.mjtDisableBit.mjDSBL_GRAVITY + CLAMPCTRL = mujoco.mjtDisableBit.mjDSBL_CLAMPCTRL + ACTUATION = mujoco.mjtDisableBit.mjDSBL_ACTUATION + REFSAFE = mujoco.mjtDisableBit.mjDSBL_REFSAFE + EULERDAMP = mujoco.mjtDisableBit.mjDSBL_EULERDAMP + FILTERPARENT = mujoco.mjtDisableBit.mjDSBL_FILTERPARENT + # unsupported: EQUALITY, FRICTIONLOSS, MIDPHASE, WARMSTART, SENSOR + + +class TrnType(enum.IntEnum): + """Type of actuator transmission. + + Members: + JOINT: force on joint + JOINTINPARENT: force on joint, expressed in parent frame + """ + + JOINT = mujoco.mjtTrn.mjTRN_JOINT + JOINTINPARENT = mujoco.mjtTrn.mjTRN_JOINTINPARENT + # unsupported: SITE, TENDON, SLIDERCRANK, BODY + + +class DynType(enum.IntEnum): + """Type of actuator dynamics. + + Members: + NONE: no internal dynamics; ctrl specifies force + FILTEREXACT: linear filter: da/dt = (u-a) / tau, with exact integration + """ + + NONE = mujoco.mjtDyn.mjDYN_NONE + FILTEREXACT = mujoco.mjtDyn.mjDYN_FILTEREXACT + # unsupported: INTEGRATOR, FILTER, MUSCLE, USER + + +class GainType(enum.IntEnum): + """Type of actuator gain. + + Members: + FIXED: fixed gain + AFFINE: const + kp*length + kv*velocity + """ + + FIXED = mujoco.mjtGain.mjGAIN_FIXED + AFFINE = mujoco.mjtGain.mjGAIN_AFFINE + # unsupported: MUSCLE, USER + + +class BiasType(enum.IntEnum): + """Type of actuator bias. 
+ + Members: + NONE: no bias + AFFINE: const + kp*length + kv*velocity + """ + + NONE = mujoco.mjtBias.mjBIAS_NONE + AFFINE = mujoco.mjtBias.mjBIAS_AFFINE + # unsupported: MUSCLE, USER + + +class JointType(enum.IntEnum): + """Type of degree of freedom. + + Members: + FREE: global position and orientation (quat) (7,) + BALL: orientation (quat) relative to parent (4,) + SLIDE: sliding distance along body-fixed axis (1,) + HINGE: rotation angle (rad) around body-fixed axis (1,) + """ + + FREE = mujoco.mjtJoint.mjJNT_FREE + BALL = mujoco.mjtJoint.mjJNT_BALL + SLIDE = mujoco.mjtJoint.mjJNT_SLIDE + HINGE = mujoco.mjtJoint.mjJNT_HINGE + + def dof_width(self) -> int: + return {0: 6, 1: 3, 2: 1, 3: 1}[self.value] + + def qpos_width(self) -> int: + return {0: 7, 1: 4, 2: 1, 3: 1}[self.value] + + +class ConeType(enum.IntEnum): + """Type of friction cone. + + Members: + PYRAMIDAL: pyramidal + """ + + PYRAMIDAL = mujoco.mjtCone.mjCONE_PYRAMIDAL + # unsupported: ELLIPTIC + + +class GeomType(enum.IntEnum): + """Type of geometry. + + Members: + PLANE: plane + SPHERE: sphere + CAPSULE: capsule + BOX: box + """ + + PLANE = mujoco.mjtGeom.mjGEOM_PLANE + SPHERE = mujoco.mjtGeom.mjGEOM_SPHERE + CAPSULE = mujoco.mjtGeom.mjGEOM_CAPSULE + BOX = mujoco.mjtGeom.mjGEOM_BOX + # unsupported: HFIELD, ELLIPSOID, CYLINDER, MESH, + # NGEOMTYPES, ARROW*, LINE, SKIN, LABEL, NONE + + +class SolverType(enum.IntEnum): + """Constraint solver algorithm. + + Members: + CG: Conjugate gradient (primal) + NEWTON: Newton (primal) + """ + + CG = mujoco.mjtSolver.mjSOL_CG + NEWTON = mujoco.mjtSolver.mjSOL_NEWTON + + +class vec5f(wp.types.vector(length=5, dtype=wp.float32)): + pass + + +class vec10f(wp.types.vector(length=10, dtype=wp.float32)): + pass + + +vec5 = vec5f +vec10 = vec10f +array2df = wp.array2d(dtype=wp.float32) +array3df = wp.array3d(dtype=wp.float32) + + +@wp.struct +class Option: + """Physics options. 
+ + Attributes: + timestep: simulation timestep + impratio: ratio of friction-to-normal contact impedance + tolerance: main solver tolerance + ls_tolerance: CG/Newton linesearch tolerance + gravity: gravitational acceleration + integrator: integration mode (mjtIntegrator) + cone: type of friction cone (mjtCone) + solver: solver algorithm (mjtSolver) + iterations: number of main solver iterations + ls_iterations: maximum number of CG/Newton linesearch iterations + disableflags: bit flags for disabling standard features + is_sparse: whether to use sparse representations + """ + + timestep: float + impratio: wp.float32 + tolerance: float + ls_tolerance: float + gravity: wp.vec3 + integrator: int + cone: int + solver: int + iterations: int + ls_iterations: int + disableflags: int + is_sparse: bool + + +@wp.struct +class Statistic: + """Model statistics (in qpos0). + + Attributes: + meaninertia: mean diagonal inertia + """ + + meaninertia: float + + +@wp.struct +class Constraint: + """Constraint data. 
+ + Attributes: + worldid: world id (njmax,) + J: constraint Jacobian (njmax, nv) + pos: constraint position (equality, contact) (njmax,) + margin: inclusion margin (contact) (njmax,) + D: constraint mass (njmax,) + aref: reference pseudo-acceleration (njmax,) + force: constraint force in constraint space (njmax,) + Jaref: Jac*qacc - aref (njmax,) + Ma: M*qacc (nworld, nv) + grad: gradient of master cost (nworld, nv) + grad_dot: dot(grad, grad) (nworld,) + Mgrad: M / grad (nworld, nv) + search: linesearch vector (nworld, nv) + search_dot: dot(search, search) (nworld,) + gauss: gauss Cost (nworld,) + cost: constraint + Gauss cost (nworld,) + prev_cost: cost from previous iter (nworld,) + solver_niter: number of solver iterations (nworld,) + active: active (quadratic) constraints (njmax,) + gtol: linesearch termination tolerance (nworld,) + mv: qM @ search (nworld, nv) + jv: efc_J @ search (njmax,) + quad: quadratic cost coefficients (njmax, 3) + quad_gauss: quadratic cost gauss coefficients (nworld, 3) + h: cone hessian (nworld, nv, nv) + alpha: line search step size (nworld,) + prev_grad: previous grad (nworld, nv) + prev_Mgrad: previous Mgrad (nworld, nv) + beta: polak-ribiere beta (nworld,) + beta_num: numerator of beta (nworld,) + beta_den: denominator of beta (nworld,) + done: solver done (nworld,) + ls_done: linesearch done (nworld,) + p0: initial point (nworld, 3) + lo: low point bounding the line search interval (nworld, 3) + lo_alpha: alpha for low point (nworld,) + hi: high point bounding the line search interval (nworld, 3) + hi_alpha: alpha for high point (nworld,) + lo_next: next low point (nworld, 3) + lo_next_alpha: alpha for next low point (nworld,) + hi_next: next high point (nworld, 3) + hi_next_alpha: alpha for next high point (nworld,) + mid: loss at mid_alpha (nworld, 3) + mid_alpha: midpoint between lo_alpha and hi_alpha (nworld,) + """ + + worldid: wp.array(dtype=wp.int32, ndim=1) + J: wp.array(dtype=wp.float32, ndim=2) + pos: 
wp.array(dtype=wp.float32, ndim=1) + margin: wp.array(dtype=wp.float32, ndim=1) + D: wp.array(dtype=wp.float32, ndim=1) + aref: wp.array(dtype=wp.float32, ndim=1) + force: wp.array(dtype=wp.float32, ndim=1) + Jaref: wp.array(dtype=wp.float32, ndim=1) + Ma: wp.array(dtype=wp.float32, ndim=2) + grad: wp.array(dtype=wp.float32, ndim=2) + grad_dot: wp.array(dtype=wp.float32, ndim=1) + Mgrad: wp.array(dtype=wp.float32, ndim=2) + search: wp.array(dtype=wp.float32, ndim=2) + search_dot: wp.array(dtype=wp.float32, ndim=1) + gauss: wp.array(dtype=wp.float32, ndim=1) + cost: wp.array(dtype=wp.float32, ndim=1) + prev_cost: wp.array(dtype=wp.float32, ndim=1) + solver_niter: wp.array(dtype=wp.int32, ndim=1) + active: wp.array(dtype=wp.int32, ndim=1) + gtol: wp.array(dtype=wp.float32, ndim=1) + mv: wp.array(dtype=wp.float32, ndim=2) + jv: wp.array(dtype=wp.float32, ndim=1) + quad: wp.array(dtype=wp.vec3f, ndim=1) + quad_gauss: wp.array(dtype=wp.vec3f, ndim=1) + h: wp.array(dtype=wp.float32, ndim=3) + alpha: wp.array(dtype=wp.float32, ndim=1) + prev_grad: wp.array(dtype=wp.float32, ndim=2) + prev_Mgrad: wp.array(dtype=wp.float32, ndim=2) + beta: wp.array(dtype=wp.float32, ndim=1) + beta_num: wp.array(dtype=wp.float32, ndim=1) + beta_den: wp.array(dtype=wp.float32, ndim=1) + done: wp.array(dtype=bool, ndim=1) + # linesearch + ls_done: wp.array(dtype=bool, ndim=1) + p0: wp.array(dtype=wp.vec3, ndim=1) + lo: wp.array(dtype=wp.vec3, ndim=1) + lo_alpha: wp.array(dtype=wp.float32, ndim=1) + hi: wp.array(dtype=wp.vec3, ndim=1) + hi_alpha: wp.array(dtype=wp.float32, ndim=1) + lo_next: wp.array(dtype=wp.vec3, ndim=1) + lo_next_alpha: wp.array(dtype=wp.float32, ndim=1) + hi_next: wp.array(dtype=wp.vec3, ndim=1) + hi_next_alpha: wp.array(dtype=wp.float32, ndim=1) + mid: wp.array(dtype=wp.vec3, ndim=1) + mid_alpha: wp.array(dtype=wp.float32, ndim=1) + + +@wp.struct +class Model: + """Model definition and parameters. 
+ + Attributes: + nq: number of generalized coordinates = dim () + nv: number of degrees of freedom = dim () + nu: number of actuators/controls = dim () + na: number of activation states = dim () + nbody: number of bodies () + njnt: number of joints () + ngeom: number of geoms () + nsite: number of sites () + nexclude: number of excluded geom pairs () + nmocap: number of mocap bodies () + nM: number of non-zeros in sparse inertia matrix () + opt: physics options + stat: model statistics + qpos0: qpos values at default pose (nq,) + qpos_spring: reference pose for springs (nq,) + body_tree: BFS ordering of body ids + body_treeadr: starting index of each body tree level + actuator_moment_offset_nv: tiling configuration + actuator_moment_offset_nu: tiling configuration + actuator_moment_tileadr: tiling configuration + actuator_moment_tilesize_nv: tiling configuration + actuator_moment_tilesize_nu: tiling configuration + qM_fullm_i: sparse mass matrix addressing + qM_fullm_j: sparse mass matrix addressing + qM_mulm_i: sparse mass matrix addressing + qM_mulm_j: sparse mass matrix addressing + qM_madr_ij: sparse mass matrix addressing + qLD_update_tree: dof tree ordering for qLD updates + qLD_update_treeadr: index of each dof tree level + qLD_tile: tiling configuration + qLD_tileadr: tiling configuration + qLD_tilesize: tiling configuration + body_parentid: id of body's parent (nbody,) + body_rootid: id of root above body (nbody,) + body_weldid: id of body that this body is welded to (nbody,) + body_mocapid: id of mocap data; -1: none (nbody,) + body_jntnum: number of joints for this body (nbody,) + body_jntadr: start addr of joints; -1: no joints (nbody,) + body_dofnum: number of motion degrees of freedom (nbody,) + body_dofadr: start addr of dofs; -1: no dofs (nbody,) + body_geomnum: number of geoms (nbody,) + body_geomadr: start addr of geoms; -1: no geoms (nbody,) + body_pos: position offset rel. to parent body (nbody, 3) + body_quat: orientation offset rel. 
to parent body (nbody, 4) + body_ipos: local position of center of mass (nbody, 3) + body_iquat: local orientation of inertia ellipsoid (nbody, 4) + body_mass: mass (nbody,) + subtree_mass: mass of subtree (nbody,) + body_inertia: diagonal inertia in ipos/iquat frame (nbody, 3) + body_invweight0: mean inv inert in qpos0 (trn, rot) (nbody, 2) + body_contype: OR over all geom contypes (nbody,) + body_conaffinity: OR over all geom conaffinities (nbody,) + jnt_type: type of joint (mjtJoint) (njnt,) + jnt_qposadr: start addr in 'qpos' for joint's data (njnt,) + jnt_dofadr: start addr in 'qvel' for joint's data (njnt,) + jnt_bodyid: id of joint's body (njnt,) + jnt_limited: does joint have limits (njnt,) + jnt_actfrclimited: does joint have actuator force limits (njnt,) + jnt_solref: constraint solver reference: limit (njnt, mjNREF) + jnt_solimp: constraint solver impedance: limit (njnt, mjNIMP) + jnt_pos: local anchor position (njnt, 3) + jnt_axis: local joint axis (njnt, 3) + jnt_stiffness: stiffness coefficient (njnt,) + jnt_range: joint limits (njnt, 2) + jnt_actfrcrange: range of total actuator force (njnt, 2) + jnt_margin: min distance for limit detection (njnt,) + jnt_limited_slide_hinge_adr: limited/slide/hinge jntadr + dof_bodyid: id of dof's body (nv,) + dof_jntid: id of dof's joint (nv,) + dof_parentid: id of dof's parent; -1: none (nv,) + dof_Madr: dof address in M-diagonal (nv,) + dof_armature: dof armature inertia/mass (nv,) + dof_damping: damping coefficient (nv,) + dof_invweight0: diag. 
inverse inertia in qpos0 (nv,) + dof_tri_row: np.tril_indices (mjm.nv)[0] + dof_tri_col: np.tril_indices (mjm.nv)[1] + geom_type: geometric type (mjtGeom) (ngeom,) + geom_contype: geom contact type (ngeom,) + geom_conaffinity: geom contact affinity (ngeom,) + geom_condim: contact dimensionality (1, 3, 4, 6) (ngeom,) + geom_bodyid: id of geom's body (ngeom,) + geom_dataid: id of geom's mesh/hfield; -1: none (ngeom,) + geom_priority: geom contact priority (ngeom,) + geom_solmix: mixing coef for solref/imp in geom pair (ngeom,) + geom_solref: constraint solver reference: contact (ngeom, mjNREF) + geom_solimp: constraint solver impedance: contact (ngeom, mjNIMP) + geom_size: geom-specific size parameters (ngeom, 3) + geom_aabb: bounding box, (center, size) (ngeom, 6) + geom_rbound: radius of bounding sphere (ngeom,) + geom_pos: local position offset rel. to body (ngeom, 3) + geom_quat: local orientation offset rel. to body (ngeom, 4) + geom_friction: friction for (slide, spin, roll) (ngeom, 3) + geom_margin: detect contact if dist dict: + global _STACK + + if _STACK is None: + return {} + + ret = {} + + for k, v in _STACK.items(): + events, sub_stack = v + # push into next level of stack + saved_stack, _STACK = _STACK, sub_stack + sub_trace = self.trace() + # pop! 
+ _STACK = saved_stack + events = tuple(wp.get_event_elapsed_time(beg, end) for beg, end in events) + ret[k] = (events, sub_trace) + + return ret + + def __exit__(self, type, value, traceback): + global _STACK + _STACK = None + + +def event_scope(fn, name: str = ""): + name = name or getattr(fn, "__name__") + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + global _STACK + if _STACK is None: + return fn(*args, **kwargs) + # push into next level of stack + saved_stack, _STACK = _STACK, {} + beg = wp.Event(enable_timing=True) + end = wp.Event(enable_timing=True) + wp.record_event(beg) + res = fn(*args, **kwargs) + wp.record_event(end) + # pop back up to current level + sub_stack, _STACK = _STACK, saved_stack + # append events + events, _ = _STACK.get(name, ((), None)) + _STACK[name] = (events + ((beg, end),), sub_stack) + return res + + return wrapper + + +# @kernel decorator to automatically set up modules based on nested +# function names +def kernel( + f: Optional[Callable] = None, + *, + enable_backward: Optional[bool] = None, + module: Optional[Module] = None, +): + """ + Decorator to register a Warp kernel from a Python function. + The function must be defined with type annotations for all arguments. + The function must not return anything. + + Example:: + + @kernel + def my_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)): + tid = wp.tid() + b[tid] = a[tid] + 1.0 + + + @kernel(enable_backward=False) + def my_kernel_no_backward(a: wp.array(dtype=float, ndim=2), x: float): + # the backward pass will not be generated + i, j = wp.tid() + a[i, j] = x + + + @kernel(module="unique") + def my_kernel_unique_module(a: wp.array(dtype=float), b: wp.array(dtype=float)): + # the kernel will be registered in new unique module created just for this + # kernel and its dependent functions and structs + tid = wp.tid() + b[tid] = a[tid] + 1.0 + + Args: + f: The function to be registered as a kernel. 
enable_backward: If False, the backward pass will not be generated. + module: The :class:`warp.context.Module` to which the kernel belongs. Alternatively, if a string `"unique"` is provided, the kernel is assigned to a new module named after the kernel name and hash. If None, the module is inferred from the function's module. + + Returns: + The registered kernel. + """ + if module is None: + # create a module name based on the name of the nested function + # get the qualified name, e.g. "main.<locals>.nested_kernel" + qualname = f.__qualname__ + parts = [part for part in qualname.split(".") if part != "<locals>"] + outer_functions = parts[:-1] + module = get_module(".".join([f.__module__] + outer_functions)) + + return wp.kernel(f, enable_backward=enable_backward, module=module) + + +@wp.kernel +def _copy_2df(dest: types.array2df, src: types.array2df): + i, j = wp.tid() + dest[i, j] = src[i, j] + + +@wp.kernel +def _copy_3df(dest: types.array3df, src: types.array3df): + i, j, k = wp.tid() + dest[i, j, k] = src[i, j, k] + + +@wp.kernel +def _copy_2dvec10f( + dest: wp.array2d(dtype=types.vec10f), src: wp.array2d(dtype=types.vec10f) +): + i, j = wp.tid() + dest[i, j] = src[i, j] + + +@wp.kernel +def _copy_2dvec3f(dest: wp.array2d(dtype=wp.vec3f), src: wp.array2d(dtype=wp.vec3f)): + i, j = wp.tid() + dest[i, j] = src[i, j] + + +@wp.kernel +def _copy_2dmat33f(dest: wp.array2d(dtype=wp.mat33f), src: wp.array2d(dtype=wp.mat33f)): + i, j = wp.tid() + dest[i, j] = src[i, j] + + +@wp.kernel +def _copy_2dspatialf( + dest: wp.array2d(dtype=wp.spatial_vector), src: wp.array2d(dtype=wp.spatial_vector) +): + i, j = wp.tid() + dest[i, j] = src[i, j] + + +# TODO(team): remove kernel_copy once wp.copy is supported in cuda subgraphs + + +def kernel_copy(dest: wp.array, src: wp.array): + if src.shape != dest.shape: + raise ValueError("only same shape copying allowed") + + if src.dtype != dest.dtype: + if (src.dtype, dest.dtype) not in ( + (wp.float32, np.float32), + (np.float32, wp.float32), +
(wp.int32, np.int32), + (np.int32, wp.int32), + ): + raise ValueError(f"only same dtype copying allowed: {src.dtype} != {dest.dtype}") + + if src.ndim == 2 and src.dtype == wp.float32: + kernel = _copy_2df + elif src.ndim == 3 and src.dtype == wp.float32: + kernel = _copy_3df + elif src.ndim == 2 and src.dtype == wp.vec3f: + kernel = _copy_2dvec3f + elif src.ndim == 2 and src.dtype == wp.mat33f: + kernel = _copy_2dmat33f + elif src.ndim == 2 and src.dtype == types.vec10f: + kernel = _copy_2dvec10f + elif src.ndim == 2 and src.dtype == wp.spatial_vector: + kernel = _copy_2dspatialf + else: + raise NotImplementedError("copy not supported for these array types") + + wp.launch(kernel=kernel, dim=src.shape, inputs=[dest, src]) diff --git a/mujoco_warp/test_data/collision.xml b/mujoco_warp/test_data/collision.xml new file mode 100644 index 00000000..5cc041e2 --- /dev/null +++ b/mujoco_warp/test_data/collision.xml @@ -0,0 +1,43 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mujoco/mjx/test_data/constraints.xml b/mujoco_warp/test_data/constraints.xml similarity index 100% rename from mujoco/mjx/test_data/constraints.xml rename to mujoco_warp/test_data/constraints.xml diff --git a/mujoco/mjx/test_data/humanoid/humanoid.xml b/mujoco_warp/test_data/humanoid/humanoid.xml similarity index 96% rename from mujoco/mjx/test_data/humanoid/humanoid.xml rename to mujoco_warp/test_data/humanoid/humanoid.xml index 168310bc..f4152998 100644 --- a/mujoco/mjx/test_data/humanoid/humanoid.xml +++ b/mujoco_warp/test_data/humanoid/humanoid.xml @@ -14,7 +14,7 @@ --> - @@ -40,7 +40,7 @@ - + @@ -186,15 +186,6 @@ - - - - - - - - -
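The `_sum` helper added to `test_util.py` in this patch merges two nested trace dictionaries of shape `{name: (times, sub_stack)}` by summing per-event timings level by level, so timings accumulate across the `nstep` graph launches. A standalone sketch of that merge logic (the function name and sample data here are illustrative, not from the benchmark itself):

```python
def sum_traces(stack1, stack2):
    """Merge two nested trace dicts {name: (times, sub_stack)} by summing timings."""
    merged = {}
    for name in stack1:
        times1, sub1 = stack1[name]
        times2, sub2 = stack2[name]
        # element-wise sum of the timing lists, then recurse into child scopes
        times = [t1 + t2 for t1, t2 in zip(times1, times2)]
        merged[name] = (times, sum_traces(sub1, sub2))
    return merged

# Two traces from consecutive steps: "step" contains a nested "collision" scope.
a = {"step": ([1.0], {"collision": ([0.4], {})})}
b = {"step": ([2.0], {"collision": ([0.6], {})})}
print(sum_traces(a, b))  # → {'step': ([3.0], {'collision': ([1.0], {})})}
```

This is why the benchmark loop can call `trace = _sum(trace, tracer.trace())` each iteration: the merged dict keeps the same nesting while the timing entries grow into per-scope totals.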
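`warp_util.event_scope` in this patch builds its nested trace by pushing a fresh dict onto a module-level `_STACK` around each wrapped call and recording CUDA timing events. The same push/pop bookkeeping can be sketched with wall-clock timers in plain Python; this is a simplified stand-in for the CUDA-event version, and `timed_scope` is an illustrative name, not part of the library:

```python
import functools
import time

_STACK = None  # None: tracing disabled; dict: children of the current scope


def timed_scope(fn):
    """Record elapsed-time tuples for fn in a nested trace stack."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        global _STACK
        if _STACK is None:  # tracing disabled
            return fn(*args, **kwargs)
        saved, _STACK = _STACK, {}  # push a fresh child scope
        beg = time.perf_counter()
        res = fn(*args, **kwargs)
        end = time.perf_counter()
        sub, _STACK = _STACK, saved  # pop back to the parent scope
        times, _ = _STACK.get(fn.__name__, ((), None))
        _STACK[fn.__name__] = (times + (end - beg,), sub)
        return res
    return wrapper


@timed_scope
def inner():
    time.sleep(0.001)


@timed_scope
def outer():
    inner()


_STACK = {}  # enable tracing (EventTracer.__enter__ does this in the real code)
outer()
trace = _STACK  # {'outer': ((dt,), {'inner': ((dt2,), {})})}
```

Because `inner`'s interval is contained in `outer`'s, the parent timing always bounds the child timing, which is what makes the per-subcomponent breakdown from `--event_trace=True` interpretable.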