diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index 36c484f4e..2f2be3afa 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -35,6 +35,10 @@ from mujoco_warp._src.collision_primitive import primitive_narrowphase as primitive_narrowphase from mujoco_warp._src.collision_sdf import sdf_narrowphase as sdf_narrowphase from mujoco_warp._src.constraint import make_constraint as make_constraint +from mujoco_warp._src.delay import init_ctrl_history as init_ctrl_history +from mujoco_warp._src.delay import init_sensor_history as init_sensor_history +from mujoco_warp._src.delay import read_ctrl as read_ctrl +from mujoco_warp._src.delay import read_sensor as read_sensor from mujoco_warp._src.derivative import deriv_smooth_vel as deriv_smooth_vel from mujoco_warp._src.forward import euler as euler from mujoco_warp._src.forward import forward as forward diff --git a/mujoco_warp/_src/delay.py b/mujoco_warp/_src/delay.py new file mode 100644 index 000000000..8aa520b09 --- /dev/null +++ b/mujoco_warp/_src/delay.py @@ -0,0 +1,837 @@ +# Copyright 2026 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import warp as wp + +from mujoco_warp._src.types import MJ_MINVAL +from mujoco_warp._src.types import Data +from mujoco_warp._src.types import Model + +wp.set_module_options({"enable_backward": False}) + + +@wp.func +def _history_physical_index(cursor: int, n: int, logical: int) -> int: + """Convert logical index (0=oldest, n-1=newest) to physical index.""" + return (cursor + 1 + logical) % n + + +@wp.func +def _history_find_index( + # In: + buf: wp.array2d(dtype=float), + worldid: int, + buf_offset: int, + n: int, + cursor: int, + t: float, +) -> int: + """Find logical index i such that times[i-1] < t <= times[i]. + + Returns 0 if t <= times[oldest], n if t > times[newest]. + Uses linear search on the circular buffer (buffers are small, typically 2-10). + """ + times_offset = buf_offset + 2 + + oldest_phys = _history_physical_index(cursor, n, 0) + newest_phys = _history_physical_index(cursor, n, n - 1) + t_oldest = buf[worldid, times_offset + oldest_phys] + t_newest = buf[worldid, times_offset + newest_phys] + + # before or at first element + if t <= t_oldest: + return 0 + + # after last element + if t > t_newest: + return n + + # linear search: find smallest logical i such that times[phys(i)] >= t + result = n + for k in range(n): + phys_k = _history_physical_index(cursor, n, k) + if buf[worldid, times_offset + phys_k] >= t: + result = k + break + + return result + + +@wp.func +def _history_read_scalar( + # In: + buf: wp.array2d(dtype=float), + worldid: int, + buf_offset: int, + n: int, + t: float, + interp: int, +) -> float: + """Read a scalar value from history buffer at time t. + + interp: 0=zero-order-hold, 1=linear interpolation + """ + cursor = int(buf[worldid, buf_offset + 1]) + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + n + + oldest_phys = _history_physical_index(cursor, n, 0) + newest_phys = _history_physical_index(cursor, n, n - 1) + t_oldest = buf[worldid, times_offset + oldest_phys] + t_newest = buf[worldid, times_offset + newest_phys] + + # extrapolate before oldest + if t <= t_oldest + MJ_MINVAL: + return buf[worldid, values_offset + oldest_phys] + + # extrapolate after newest + if t >= t_newest - MJ_MINVAL: + return buf[worldid, values_offset + newest_phys] + + # find bracketing index + i = _history_find_index(buf, worldid, buf_offset, n, cursor, t) + phys_i = _history_physical_index(cursor, n, i) + + # exact match + if wp.abs(t - buf[worldid, times_offset + phys_i]) < MJ_MINVAL: + return buf[worldid, values_offset + phys_i] + + phys_lo = _history_physical_index(cursor, n, i - 1) + phys_hi = phys_i + + # zero-order hold + if interp == 0: + return buf[worldid, values_offset + phys_lo] + + # linear interpolation + dt = buf[worldid, times_offset + phys_hi] - buf[worldid, times_offset + phys_lo] + alpha = (t - buf[worldid, times_offset + phys_lo]) / dt + v_lo = buf[worldid, values_offset + phys_lo] + v_hi = buf[worldid, values_offset + phys_hi] + return v_lo + alpha * (v_hi - v_lo) + + +@wp.func +def _history_read_vector( + # In: + adr: int, + buf: wp.array2d(dtype=float), + worldid: int, + buf_offset: int, + n: int, + dim: int, + t: float, + interp: int, + # Data out: + sensordata_out: wp.array2d(dtype=float), +) -> int: + """Read a vector value from history buffer at time t into sensordata. + + Returns 1 on success (value written to sensordata). + interp: 0=zero-order-hold, 1=linear interpolation + """ + cursor = int(buf[worldid, buf_offset + 1]) + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + n + + oldest_phys = _history_physical_index(cursor, n, 0) + newest_phys = _history_physical_index(cursor, n, n - 1) + t_oldest = buf[worldid, times_offset + oldest_phys] + t_newest = buf[worldid, times_offset + newest_phys] + + # extrapolate before oldest: copy oldest + if t <= t_oldest + MJ_MINVAL: + for d in range(dim): + sensordata_out[worldid, adr + d] = buf[worldid, values_offset + oldest_phys * dim + d] + return 1 + + # extrapolate after newest: copy newest + if t >= t_newest - MJ_MINVAL: + for d in range(dim): + sensordata_out[worldid, adr + d] = buf[worldid, values_offset + newest_phys * dim + d] + return 1 + + # find bracketing index + i = _history_find_index(buf, worldid, buf_offset, n, cursor, t) + phys_i = _history_physical_index(cursor, n, i) + + # exact match + if wp.abs(t - buf[worldid, times_offset + phys_i]) < MJ_MINVAL: + for d in range(dim): + sensordata_out[worldid, adr + d] = buf[worldid, values_offset + phys_i * dim + d] + return 1 + + phys_lo = _history_physical_index(cursor, n, i - 1) + phys_hi = phys_i + + # zero-order hold + if interp == 0: + for d in range(dim): + sensordata_out[worldid, adr + d] = buf[worldid, values_offset + phys_lo * dim + d] + return 1 + + # linear interpolation + dt = buf[worldid, times_offset + phys_hi] - buf[worldid, times_offset + phys_lo] + alpha = (t - buf[worldid, times_offset + phys_lo]) / dt + for d in range(dim): + v_lo = buf[worldid, values_offset + phys_lo * dim + d] + v_hi = buf[worldid, values_offset + phys_hi * dim + d] + sensordata_out[worldid, adr + d] = v_lo + alpha * (v_hi - v_lo) + return 1 + + +@wp.func +def _history_insert_scalar( + # In: + worldid: int, + buf_offset: int, + n: int, + t: float, + value: float, + # Out: + buf_out: wp.array2d(dtype=float), +): + """Insert a scalar value into history buffer at time t.""" + cursor = int(buf_out[worldid, buf_offset + 1]) + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + n + + i = _history_find_index(buf_out, worldid, buf_offset, n, cursor, t) + + # exact match + if i < n: + phys_i = _history_physical_index(cursor, n, i) + if wp.abs(t - buf_out[worldid, times_offset + phys_i]) < MJ_MINVAL: + buf_out[worldid, values_offset + phys_i] = value + return + + # older than oldest: replace oldest + if i == 0: + oldest_phys = _history_physical_index(cursor, n, 0) + buf_out[worldid, times_offset + oldest_phys] = t + buf_out[worldid, values_offset + oldest_phys] = value + return + + # newer than newest: advance cursor + if i == n: + cursor = (cursor + 1) % n + buf_out[worldid, buf_offset + 1] = float(cursor) + buf_out[worldid, times_offset + cursor] = t + buf_out[worldid, values_offset + cursor] = value + return + + # out-of-order: shift [1, i-1] left, insert at i-1 + for j in range(i - 1): + src_phys = _history_physical_index(cursor, n, j + 1) + dst_phys = _history_physical_index(cursor, n, j) + buf_out[worldid, times_offset + dst_phys] = buf_out[worldid, times_offset + src_phys] + buf_out[worldid, values_offset + dst_phys] = buf_out[worldid, values_offset + src_phys] + insert_phys = _history_physical_index(cursor, n, i - 1) + buf_out[worldid, times_offset + insert_phys] = t + buf_out[worldid, values_offset + insert_phys] = value + + +@wp.func +def _history_insert_vector( + # In: + worldid: int, + buf_offset: int, + n: int, + dim: int, + t: float, + src: wp.array2d(dtype=float), + src_adr: int, + # Out: + buf_out: wp.array2d(dtype=float), +): + """Insert a vector value from src[worldid, src_adr:src_adr+dim] into history buffer at time t.""" + cursor = int(buf_out[worldid, buf_offset + 1]) + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + n + + i = _history_find_index(buf_out, worldid, buf_offset, n, cursor, t) + + slot_phys = -1 + + # exact match + if i < n: + phys_i = _history_physical_index(cursor, n, i) + if wp.abs(t - buf_out[worldid, times_offset + phys_i]) < MJ_MINVAL: + slot_phys = phys_i + + if slot_phys < 0: + if i == 0: + # older than oldest: replace oldest + slot_phys = _history_physical_index(cursor, n, 0) + buf_out[worldid, times_offset + slot_phys] = t + elif i == n: + # newer than newest: advance cursor + cursor = (cursor + 1) % n + buf_out[worldid, buf_offset + 1] = float(cursor) + slot_phys = cursor + buf_out[worldid, times_offset + slot_phys] = t + else: + # out-of-order: shift [1, i-1] left, insert at i-1 + for j in range(i - 1): + src_phys = _history_physical_index(cursor, n, j + 1) + dst_phys = _history_physical_index(cursor, n, j) + buf_out[worldid, times_offset + dst_phys] = buf_out[worldid, times_offset + src_phys] + for d in range(dim): + buf_out[worldid, values_offset + dst_phys * dim + d] = buf_out[worldid, values_offset + src_phys * dim + d] + slot_phys = _history_physical_index(cursor, n, i - 1) + buf_out[worldid, times_offset + slot_phys] = t + + # copy values + for d in range(dim): + buf_out[worldid, values_offset + slot_phys * dim + d] = src[worldid, src_adr + d] + + +@wp.kernel +def _read_ctrl_delayed_kernel( + # Model: + actuator_history: wp.array(dtype=wp.vec2i), + actuator_historyadr: wp.array(dtype=int), + actuator_delay: wp.array(dtype=float), + # Data in: + time_in: wp.array(dtype=float), + history_in: wp.array2d(dtype=float), + ctrl_in: wp.array2d(dtype=float), + # Data out: + ctrl_out: wp.array2d(dtype=float), +): + """Read delayed ctrl for each actuator.""" + worldid, uid = wp.tid() + + hist = actuator_history[uid] + nsample = hist[0] + + if nsample == 0: + # no delay: direct copy + ctrl_out[worldid, uid] = ctrl_in[worldid, uid] + else: + interp = hist[1] + delay = actuator_delay[uid] + buf_offset = actuator_historyadr[uid] + t = time_in[worldid] - delay + ctrl_out[worldid, uid] = _history_read_scalar(history_in, worldid, buf_offset, nsample, t, interp) + + +@wp.kernel +def _insert_ctrl_history_kernel( + # Model: + actuator_history: wp.array(dtype=wp.vec2i), + actuator_historyadr: wp.array(dtype=int), + # Data in: + time_in: wp.array(dtype=float), + ctrl_in: wp.array2d(dtype=float), + # Data out: + history_out: wp.array2d(dtype=float), +): + """Insert current ctrl into history buffers.""" + worldid, uid = wp.tid() + + hist = actuator_history[uid] + nsample = hist[0] + if nsample == 0: + return + + buf_offset = actuator_historyadr[uid] + t = time_in[worldid] + value = ctrl_in[worldid, uid] + _history_insert_scalar(worldid, buf_offset, nsample, t, value, history_out) + + +@wp.kernel +def _insert_sensor_history_stage( + # Model: + sensor_dim: wp.array(dtype=int), + sensor_adr: wp.array(dtype=int), + sensor_history: wp.array(dtype=wp.vec2i), + sensor_historyadr: wp.array(dtype=int), + sensor_delay: wp.array(dtype=float), + sensor_interval: wp.array(dtype=wp.vec2), + # Data in: + time_in: wp.array(dtype=float), + sensordata_in: wp.array2d(dtype=float), + # In: + sensor_ids: wp.array(dtype=int), + # Data out: + history_out: wp.array2d(dtype=float), +): + """Insert current sensor values into history buffers for specific sensor IDs.""" + worldid, idx = wp.tid() + sid = sensor_ids[idx] + + hist = sensor_history[sid] + nsample = hist[0] + if nsample == 0: + return + + buf_offset = sensor_historyadr[sid] + dim = sensor_dim[sid] + interval_val = sensor_interval[sid] + period = interval_val[0] + t = time_in[worldid] + + if period > 0.0: + # interval mode: check if condition is satisfied + time_prev = history_out[worldid, buf_offset] # user slot stores time_prev + if time_prev + period <= t: + # advance time_prev by exact period + history_out[worldid, buf_offset] = time_prev + period + # insert sensor value + _history_insert_vector(worldid, buf_offset, nsample, dim, t, sensordata_in, sensor_adr[sid], history_out) + else: + _history_insert_vector(worldid, buf_offset, nsample, dim, t, sensordata_in, sensor_adr[sid], history_out) + + +@wp.kernel +def _apply_sensor_delay_kernel( + # Model: + sensor_dim: wp.array(dtype=int), + sensor_adr: wp.array(dtype=int), + sensor_history: wp.array(dtype=wp.vec2i), + sensor_historyadr: wp.array(dtype=int), + sensor_delay: wp.array(dtype=float), + sensor_interval: wp.array(dtype=wp.vec2), + # Data in: + time_in: wp.array(dtype=float), + history_in: wp.array2d(dtype=float), + # In: + sensor_ids: wp.array(dtype=int), + # Data out: + sensordata_out: wp.array2d(dtype=float), +): + """Apply delay/interval logic for sensors after computation.""" + worldid, idx = wp.tid() + sid = sensor_ids[idx] + + hist = sensor_history[sid] + nsample = hist[0] + if nsample <= 0: + return + + delay = sensor_delay[sid] + dim = sensor_dim[sid] + interp = hist[1] + buf_offset = sensor_historyadr[sid] + t = time_in[worldid] + + if delay > 0.0: + # delay > 0: read delayed value from buffer + _history_read_vector(sensor_adr[sid], history_in, worldid, buf_offset, nsample, dim, t - delay, interp, sensordata_out) + else: + # interval-only (delay == 0, interval > 0): check interval condition + interval_val = sensor_interval[sid] + period = interval_val[0] + if period > 0.0: + time_prev = history_in[worldid, buf_offset] # user slot + if time_prev + period > t: + # interval condition not satisfied: read from buffer + _history_read_vector(sensor_adr[sid], history_in, worldid, buf_offset, nsample, dim, t, interp, sensordata_out) + # else: interval condition satisfied, keep computed value + + +def read_ctrl_delayed(m: Model, d: Data, ctrl: wp.array2d(dtype=float)): + """Read delayed ctrl values for all actuators.""" + if m.nhistory == 0: + wp.copy(ctrl, d.ctrl) + return + + wp.launch( + _read_ctrl_delayed_kernel, + dim=(d.nworld, m.nu), + inputs=[ + m.actuator_history, + m.actuator_historyadr, + m.actuator_delay, + d.time, + d.history, + d.ctrl, + ], + outputs=[ctrl], + ) + + +def insert_ctrl_history(m: Model, d: Data): + """Insert current ctrl values into history buffers.""" + if m.nhistory == 0 or m.nu == 0: + return + + wp.launch( + _insert_ctrl_history_kernel, + dim=(d.nworld, m.nu), + inputs=[ + m.actuator_history, + m.actuator_historyadr, + d.time, + d.ctrl, + ], + outputs=[d.history], + ) + + +def apply_sensor_delay(m: Model, d: Data, sensorid: wp.array(dtype=int)): + """Apply delay/interval logic for given sensors after computation. + + Also inserts current (undelayed) sensor values into history buffers first. + This must be called AFTER sensor computation and BEFORE time advance. + """ + if m.nhistory == 0 or sensorid.shape[0] == 0: + return + + # First, insert current (fresh) sensor values into history buffers + wp.launch( + _insert_sensor_history_stage, + dim=(d.nworld, sensorid.shape[0]), + inputs=[ + m.sensor_dim, + m.sensor_adr, + m.sensor_history, + m.sensor_historyadr, + m.sensor_delay, + m.sensor_interval, + d.time, + d.sensordata, + sensorid, + ], + outputs=[d.history], + ) + + # Then, overwrite sensordata with delayed values + wp.launch( + _apply_sensor_delay_kernel, + dim=(d.nworld, sensorid.shape[0]), + inputs=[ + m.sensor_dim, + m.sensor_adr, + m.sensor_history, + m.sensor_historyadr, + m.sensor_delay, + m.sensor_interval, + d.time, + d.history, + sensorid, + ], + outputs=[d.sensordata], + ) + + +@wp.kernel +def _read_ctrl_kernel( + # Model: + actuator_history: wp.array(dtype=wp.vec2i), + actuator_historyadr: wp.array(dtype=int), + actuator_delay: wp.array(dtype=float), + # Data in: + time_in: wp.array(dtype=float), + history_in: wp.array2d(dtype=float), + ctrl_in: wp.array2d(dtype=float), + # In: + uid: int, + interp: int, + # Out: + result_out: wp.array(dtype=float), +): + """Read delayed ctrl for 1 actuator across all worlds.""" + worldid = wp.tid() + + hist = actuator_history[uid] + nsample = hist[0] + + if nsample == 0: + result_out[worldid] = ctrl_in[worldid, uid] + else: + interp_val = interp + if interp_val < 0: + interp_val = hist[1] + delay = actuator_delay[uid] + buf_offset = actuator_historyadr[uid] + t = time_in[worldid] - delay + result_out[worldid] = _history_read_scalar(history_in, worldid, buf_offset, nsample, t, interp_val) + + +def read_ctrl( + m: Model, + d: Data, + ctrlid: int, + time: wp.array(dtype=float), + interp: int, + result: wp.array2d(dtype=float), +): + """Read delayed ctrl for 1 actuator across all worlds. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + ctrlid: actuator index. + time: query time per world (nworld,). + interp: interpolation order (-1=model default, 0=ZOH, 1=linear). + result: output buffer (nworld,). + """ + wp.launch( + _read_ctrl_kernel, + dim=(d.nworld,), + inputs=[ + m.actuator_history, + m.actuator_historyadr, + m.actuator_delay, + time, + d.history, + d.ctrl, + ctrlid, + interp, + ], + outputs=[result], + ) + + +@wp.kernel +def _read_sensor_kernel( + # Model: + sensor_dim: wp.array(dtype=int), + sensor_adr: wp.array(dtype=int), + sensor_history: wp.array(dtype=wp.vec2i), + sensor_historyadr: wp.array(dtype=int), + sensor_delay: wp.array(dtype=float), + # Data in: + time_in: wp.array(dtype=float), + history_in: wp.array2d(dtype=float), + sensordata_in: wp.array2d(dtype=float), + # In: + sid: int, + interp: int, + # Out: + result_out: wp.array2d(dtype=float), +): + """Read delayed sensor for 1 sensor across all worlds.""" + worldid = wp.tid() + + hist = sensor_history[sid] + nsample = hist[0] + dim = sensor_dim[sid] + adr = sensor_adr[sid] + + if nsample == 0: + for i in range(dim): + result_out[worldid, i] = sensordata_in[worldid, adr + i] + else: + interp_val = interp + if interp_val < 0: + interp_val = hist[1] + delay = sensor_delay[sid] + buf_offset = sensor_historyadr[sid] + t = time_in[worldid] - delay + _history_read_vector( + adr, + history_in, + worldid, + buf_offset, + nsample, + dim, + t, + interp_val, + result_out, + ) + + +def read_sensor( + m: Model, + d: Data, + sensorid: int, + time: wp.array(dtype=float), + interp: int, + result: wp.array2d(dtype=float), +): + """Read delayed sensor for 1 sensor across all worlds. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + sensorid: sensor index. + time: query time per world (nworld,). + interp: interpolation order (-1=model default, 0=ZOH, 1=linear). + result: output buffer (nworld, dim). + """ + wp.launch( + _read_sensor_kernel, + dim=(d.nworld,), + inputs=[ + m.sensor_dim, + m.sensor_adr, + m.sensor_history, + m.sensor_historyadr, + m.sensor_delay, + time, + d.history, + d.sensordata, + sensorid, + interp, + ], + outputs=[result], + ) + + +@wp.kernel +def _init_ctrl_history_kernel( + # kernel_analyzer: off + # Model: + actuator_history: wp.array(dtype=wp.vec2i), + actuator_historyadr: wp.array(dtype=int), + # In: + ctrlid: int, + times: wp.array(dtype=float), + values: wp.array2d(dtype=float), + has_times: int, + # Data out: + history_out: wp.array2d(dtype=float), + # kernel_analyzer: on +): + """Initialize history buffer for 1 actuator across all worlds.""" + worldid = wp.tid() + + nsample = actuator_history[ctrlid][0] + buf_offset = actuator_historyadr[ctrlid] + + # preserve user slot + user = history_out[worldid, buf_offset] + + # cursor = 0 (samples in order, newest at index nsample-1) + history_out[worldid, buf_offset + 1] = float(nsample - 1) + + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + nsample + + for i in range(nsample): + if has_times != 0: + history_out[worldid, times_offset + i] = times[i] + history_out[worldid, values_offset + i] = values[worldid, i] + + # restore user slot + history_out[worldid, buf_offset] = user + + +def init_ctrl_history( + m: Model, + d: Data, + ctrlid: int, + times: wp.array(dtype=float), + values: wp.array2d(dtype=float), +): + """Initialize history buffer for 1 actuator across all worlds. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + ctrlid: actuator index. + times: timestamps or None (nworld,). + values: ctrl values (nworld, nsample). + """ + has_times = 0 if times is None else 1 + if times is None: + times = wp.empty(0, dtype=float) + + wp.launch( + _init_ctrl_history_kernel, + dim=(d.nworld,), + inputs=[ + m.actuator_history, + m.actuator_historyadr, + ctrlid, + times, + values, + has_times, + ], + outputs=[d.history], + ) + + +# kernel_analyzer: off +@wp.kernel +def _init_sensor_history_kernel( + # Model: + sensor_history: wp.array(dtype=wp.vec2i), + sensor_historyadr: wp.array(dtype=int), + sensor_dim_arr: wp.array(dtype=int), + # In: + sensorid: int, + times: wp.array(dtype=float), + values: wp.array2d(dtype=float), + phase: wp.array(dtype=float), + has_times: int, + # Data out: + history_out: wp.array2d(dtype=float), +): + # kernel_analyzer: on + """Initialize history buffer for 1 sensor across all worlds.""" + worldid = wp.tid() + + nsample = sensor_history[sensorid][0] + dim = sensor_dim_arr[sensorid] + buf_offset = sensor_historyadr[sensorid] + + # set user slot (phase = last computation time for interval sensors) + history_out[worldid, buf_offset] = phase[worldid] + + # cursor = 0 (samples in order, newest at index nsample-1) + history_out[worldid, buf_offset + 1] = float(nsample - 1) + + times_offset = buf_offset + 2 + values_offset = buf_offset + 2 + nsample + + for i in range(nsample): + if has_times != 0: + history_out[worldid, times_offset + i] = times[i] + for j in range(dim): + history_out[worldid, values_offset + i * dim + j] = values[worldid, i * dim + j] + + +def init_sensor_history( + m: Model, + d: Data, + sensorid: int, + times: wp.array(dtype=float), + values: wp.array2d(dtype=float), + phase: wp.array(dtype=float), +): + """Initialize history buffer for 1 sensor across all worlds. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + sensorid: sensor index. + times: timestamps or None (nworld,). + values: sensor values (nworld, nsample * dim). + phase: user slot value per world (nworld,). + """ + has_times = 0 if times is None else 1 + if times is None: + times = wp.empty(0, dtype=float) + + wp.launch( + _init_sensor_history_kernel, + dim=(d.nworld,), + inputs=[ + m.sensor_history, + m.sensor_historyadr, + m.sensor_dim, + sensorid, + times, + values, + phase, + has_times, + ], + outputs=[d.history], + ) diff --git a/mujoco_warp/_src/delay_test.py b/mujoco_warp/_src/delay_test.py new file mode 100644 index 000000000..2f7a210c1 --- /dev/null +++ b/mujoco_warp/_src/delay_test.py @@ -0,0 +1,471 @@ +# Copyright 2026 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for actuator and sensor delay.""" + +import mujoco +import numpy as np +import warp as wp +from absl.testing import absltest + +from mujoco_warp import test_data +from mujoco_warp._src import delay +from mujoco_warp._src import forward + +_TOLERANCE = 1e-8 + + +class ActuatorDelayTest(absltest.TestCase): + """Test actuator delay, mirroring MuJoCo C tests.""" + + def test_actuator_delay(self): + """Test basic actuator delay with ZOH (default interp). + + delay=0.02, timestep=0.01, nsample=2. + Ctrl set to 10 should arrive after 2 timesteps. + """ + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + # delay = 0.02, timestep = 0.01 => 2-step delay + self.assertEqual(mjm.actuator_history[0, 0], 2) + + # set ctrl=10 + d.ctrl.numpy()[:] = 10.0 + wp.copy(d.ctrl, wp.array(np.full((1, 1), 10.0), dtype=float)) + + # step 1: delayed ctrl still 0 + forward.step(m, d) + act_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose(act_force, 0.0, atol=_TOLERANCE, err_msg="step 0") + + # step 2: still 0 + forward.step(m, d) + act_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose(act_force, 0.0, atol=_TOLERANCE, err_msg="step 1") + + # step 3: now ctrl=10 arrives + forward.step(m, d) + act_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose(act_force, 10.0, atol=_TOLERANCE, err_msg="step 2") + + def test_actuator_delay_linear_interp(self): + """Test actuator delay with linear interpolation. + + delay=0.015, timestep=0.01, nsample=3, interp=linear. + """ + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertEqual(mjm.actuator_history[0, 0], 3) + self.assertEqual(mjm.actuator_history[0, 1], 1) # interp=linear + + # step 0: ctrl=10, expected force=0 + wp.copy(d.ctrl, wp.array(np.full((1, 1), 10.0), dtype=float)) + forward.step(m, d) + act_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose(act_force, 0.0, atol=_TOLERANCE, err_msg="step 0") + + # step 1: ctrl=20, expected force=5 + wp.copy(d.ctrl, wp.array(np.full((1, 1), 20.0), dtype=float)) + forward.step(m, d) + act_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose(act_force, 5.0, atol=_TOLERANCE, err_msg="step 1") + + # step 2: ctrl=30, expected force=15 + wp.copy(d.ctrl, wp.array(np.full((1, 1), 30.0), dtype=float)) + forward.step(m, d) + act_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose(act_force, 15.0, atol=_TOLERANCE, err_msg="step 2") + + +class SensorDelayTest(absltest.TestCase): + """Test sensor delay, mirroring MuJoCo C tests.""" + + def test_sensor_delay(self): + """Test basic sensor delay with ZOH. + + delay=0.02, timestep=0.01, nsample=3. + """ + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertEqual(mjm.sensor_history[0, 0], 3) + np.testing.assert_allclose(mjm.sensor_delay[0], 0.02, atol=1e-10) + + # step 0: qpos=10 + wp.copy(d.qpos, wp.array(np.full((1, 1), 10.0), dtype=float)) + forward.step(m, d) + sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(sdata, 0.0, atol=_TOLERANCE, err_msg="step 0") + + # step 1: qpos=20 + wp.copy(d.qpos, wp.array(np.full((1, 1), 20.0), dtype=float)) + forward.step(m, d) + sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(sdata, 0.0, atol=_TOLERANCE, err_msg="step 1") + + # step 2: qpos=30, expecting value from step 0 (delay=2 steps) + wp.copy(d.qpos, wp.array(np.full((1, 1), 30.0), dtype=float)) + forward.step(m, d) + sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(sdata, 10.0, atol=_TOLERANCE, err_msg="step 2") + + # step 3: qpos=40, expecting value from step 1 + wp.copy(d.qpos, wp.array(np.full((1, 1), 40.0), dtype=float)) + forward.step(m, d) + sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(sdata, 20.0, atol=_TOLERANCE, err_msg="step 3") + + def test_sensor_delay_linear_interp(self): + """Test sensor delay with linear interpolation. + + delay=0.015, timestep=0.01, nsample=3, interp=linear. + """ + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + self.assertEqual(mjm.sensor_history[0, 0], 3) + self.assertEqual(mjm.sensor_history[0, 1], 1) + + # step 0: qpos=10, expected=0 + wp.copy(d.qpos, wp.array(np.full((1, 1), 10.0), dtype=float)) + forward.step(m, d) + sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(sdata, 0.0, atol=_TOLERANCE, err_msg="step 0") + + # step 1: qpos=20, expected=5 + wp.copy(d.qpos, wp.array(np.full((1, 1), 20.0), dtype=float)) + forward.step(m, d) + sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(sdata, 5.0, atol=_TOLERANCE, err_msg="step 1") + + # step 2: qpos=30, expected=15 + wp.copy(d.qpos, wp.array(np.full((1, 1), 30.0), dtype=float)) + forward.step(m, d) + sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(sdata, 15.0, atol=_TOLERANCE, err_msg="step 2") + + +class MujocoReferenceTest(absltest.TestCase): + """Test that MuJoCo Warp matches MuJoCo C reference for delays. + + Uses MuJoCo C as ground truth and compares outputs. + """ + + def test_actuator_delay_reference(self): + """Compare actuator delay against MuJoCo C reference.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + # set ctrl=10 in both + mjd.ctrl[0] = 10.0 + wp.copy(d.ctrl, wp.array(np.full((1, 1), 10.0), dtype=float)) + + # step both 4 times and compare + for i in range(4): + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + mj_force = mjd.actuator_force[0] + warp_force = d.actuator_force.numpy()[0, 0] + np.testing.assert_allclose(warp_force, mj_force, atol=_TOLERANCE, err_msg=f"actuator_force mismatch at step {i}") + + def test_sensor_delay_reference(self): + """Compare sensor delay against MuJoCo C reference.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + for i in range(5): + qpos_val = float((i + 1) * 10) + mjd.qpos[0] = qpos_val + wp.copy(d.qpos, wp.array(np.full((1, 1), qpos_val), dtype=float)) + + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + mj_sdata = mjd.sensordata[0] + warp_sdata = d.sensordata.numpy()[0, 0] + np.testing.assert_allclose(warp_sdata, mj_sdata, atol=_TOLERANCE, err_msg=f"sensordata mismatch at step {i}") + + +class PublicAPITest(absltest.TestCase): + """Test public delay API functions against MuJoCo C reference.""" + + def test_read_ctrl(self): + """Test read_ctrl matches mj_readCtrl.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + # step both with ctrl=10, then ctrl=20 + for ctrl_val in [10.0, 20.0, 30.0]: + mjd.ctrl[0] = ctrl_val + wp.copy(d.ctrl, wp.array(np.full((1, 1), ctrl_val), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + # compare read_ctrl at current time + time_arr = d.time + warp_result = wp.empty(d.nworld, dtype=float) + delay.read_ctrl(m, d, 0, time_arr, interp=-1, result=warp_result) + mj_result = mujoco.mj_readCtrl(mjm, mjd, 0, mjd.time, -1) + np.testing.assert_allclose( + warp_result.numpy()[0], + mj_result, + atol=_TOLERANCE, + err_msg="read_ctrl mismatch", + ) + + # compare with explicit interp=0 (ZOH) + warp_result_zoh = wp.empty(d.nworld, dtype=float) + delay.read_ctrl(m, d, 0, time_arr, interp=0, result=warp_result_zoh) + mj_result_zoh = mujoco.mj_readCtrl(mjm, mjd, 0, mjd.time, 0) + np.testing.assert_allclose( + warp_result_zoh.numpy()[0], + mj_result_zoh, + atol=_TOLERANCE, + err_msg="read_ctrl ZOH mismatch", + ) + + def test_read_sensor(self): + """Test read_sensor matches mj_readSensor.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + for i in range(4): + qpos_val = float((i + 1) * 10) + mjd.qpos[0] = qpos_val + wp.copy(d.qpos, wp.array(np.full((1, 1), qpos_val), dtype=float)) + mujoco.mj_step(mjm, mjd) + forward.step(m, d) + + # compare read_sensor at current time + dim = mjm.sensor_dim[0] + time_arr = d.time + result = wp.empty((d.nworld, dim), dtype=float) + delay.read_sensor(m, d, 0, time_arr, interp=-1, result=result) + + mj_result_buf = np.zeros(dim) + ptr = mujoco.mj_readSensor(mjm, mjd, 0, mjd.time, mj_result_buf, -1) + mj_val = ptr if ptr is not None else mj_result_buf + + np.testing.assert_allclose( + result.numpy()[0], + mj_val, + atol=_TOLERANCE, + err_msg="read_sensor mismatch", + ) + + def test_init_ctrl_history(self): + """Test init_ctrl_history sets buffer correctly.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + nsample = mjm.actuator_history[0, 0] + + # initialize with custom values + custom_times = np.array([0.1, 0.2, 0.3]) + custom_values = np.array([100.0, 200.0, 300.0]) + times_wp = wp.array(custom_times, dtype=float) + values_wp = wp.array(custom_values.reshape(1, -1), dtype=float) + delay.init_ctrl_history(m, d, 0, times_wp, values_wp) + + # also init MuJoCo C side + mujoco.mj_initCtrlHistory(mjm, mjd, 0, custom_times, custom_values) + + # read at a time in the buffer + query_time = 0.23 # between samples → ZOH should return value at t=0.2 + time_arr = wp.array([query_time], dtype=float) + warp_result = wp.empty(d.nworld, dtype=float) + delay.read_ctrl(m, d, 0, time_arr, interp=0, result=warp_result) + mj_result = mujoco.mj_readCtrl(mjm, mjd, 0, query_time, 0) + np.testing.assert_allclose( + warp_result.numpy()[0], + mj_result, + atol=_TOLERANCE, + err_msg="init_ctrl_history read mismatch", + ) + + def test_init_sensor_history(self): + """Test init_sensor_history sets buffer correctly.""" + xml = """ + + + """ + mjm, mjd, m, d = test_data.fixture(xml=xml) + + nsample = mjm.sensor_history[0, 0] + dim = mjm.sensor_dim[0] + + # initialize with custom values + custom_times = np.array([0.1, 0.2, 0.3]) + custom_values = np.array([100.0, 200.0, 300.0]) + phase = 0.05 + + times_wp = wp.array(custom_times, dtype=float) + values_wp = wp.array(custom_values.reshape(1, -1), dtype=float) + phase_wp = wp.array([phase], dtype=float) + delay.init_sensor_history(m, d, 0, times_wp, values_wp, phase=phase_wp) + + # also init MuJoCo C side + mujoco.mj_initSensorHistory(mjm, mjd, 0, custom_times, custom_values, phase) + + # read at a time in the buffer + query_time = 0.23 + time_arr = wp.array([query_time], dtype=float) + result = wp.empty((1, dim), dtype=float) + delay.read_sensor(m, d, 0, time_arr, interp=0, result=result) + + mj_result_buf = np.zeros(dim) + ptr = mujoco.mj_readSensor(mjm, mjd, 0, query_time, mj_result_buf, 0) + mj_val = ptr if ptr is not None else mj_result_buf + + np.testing.assert_allclose( + result.numpy()[0], + mj_val, + atol=_TOLERANCE, + err_msg="init_sensor_history read mismatch", + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 032070be8..99952558a 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -19,6 +19,7 @@ from mujoco_warp._src import collision_driver from mujoco_warp._src import constraint +from mujoco_warp._src import delay from mujoco_warp._src import derivative from mujoco_warp._src import math from mujoco_warp._src import passive @@ -262,6 +263,9 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None) outputs=[d.qpos], ) + # advance history buffers before time advance + delay.insert_ctrl_history(m, d) + wp.launch( _next_time, dim=d.nworld, @@ -793,6 +797,13 @@ def fwd_actuation(m: Model, d: Data): d.qfrc_actuator.zero_() return + # read delayed ctrl (or direct copy if no delay) + if m.nhistory > 0: + ctrl = wp.empty((d.nworld, m.nu), dtype=float) + delay.read_ctrl_delayed(m, d, ctrl) + else: + ctrl = d.ctrl + wp.launch( _actuator_force, dim=(d.nworld, m.nu), @@ -817,7 +828,7 @@ def fwd_actuation(m: Model, d: Data): m.actuator_acc0, m.actuator_lengthrange, d.act, - d.ctrl, + ctrl, d.actuator_length, d.actuator_velocity, m.opt.disableflags & DisableBit.CLAMPCTRL, diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 85e9b3498..6ab7a8015 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -17,6 +17,7 @@ import warp as wp +from mujoco_warp._src import delay from mujoco_warp._src import math from mujoco_warp._src import ray from mujoco_warp._src import smooth @@ -898,6 +899,10 @@ def sensor_pos(m: Model, d: Data): ], ) + # apply sensor delay/interval for position sensors + delay.apply_sensor_delay(m, d, m.sensor_pos_adr) + delay.apply_sensor_delay(m, d, m.sensor_limitpos_adr) + @wp.func def _velocimeter( @@ -1437,6 +1442,10 @@ def sensor_vel(m: Model, d: Data): ], ) + # apply sensor delay/interval for velocity sensors + delay.apply_sensor_delay(m, d, m.sensor_vel_adr) + delay.apply_sensor_delay(m, d, m.sensor_limitvel_adr) + @wp.func def _accelerometer( @@ -2616,6 +2625,10 @@ def sensor_acc(m: Model, d: Data): ], ) + # apply sensor delay/interval for acceleration sensors + delay.apply_sensor_delay(m, d, m.sensor_acc_adr) + delay.apply_sensor_delay(m, d, m.sensor_limitfrc_adr) + @wp.kernel def _energy_pos_zero( diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 08b673be7..5d09b8789 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -814,6 +814,7 @@ class Model: nplugin: number of plugin instances ngravcomp: number of bodies with nonzero gravcomp nsensordata: number of elements in sensor data vector + nhistory: number of history buffer entries opt: physics options stat: model statistics qpos0: qpos values at default pose (*, nq) @@ -1017,6 +1018,9 @@ class Model: actuator_trnid: transmission id: joint, tendon, site (nu, 2) actuator_actadr: first activation address; -1: stateless (nu,) actuator_actnum: number of activation variables (nu,) + actuator_history: history buffer sizes (nu, 2) + actuator_historyadr: history buffer address (nu,) + actuator_delay: delay in seconds (nu,) actuator_ctrllimited: is control limited (nu,) actuator_forcelimited: is force limited (nu,) actuator_actlimited: is activation limited (nu,) @@ -1041,6 +1045,10 @@ class Model: sensor_dim: number of scalar outputs (nsensor,) sensor_adr: address in sensor array (nsensor,) sensor_cutoff: cutoff for real and positive; 0: ignore (nsensor,) + sensor_history: history buffer sizes (nsensor, 2) + sensor_historyadr: history buffer address (nsensor,) + sensor_delay: delay in seconds (nsensor,) + sensor_interval: sensor interval and phase (nsensor, 2) plugin: globally registered plugin slot number (nplugin,) plugin_attr: config attributes of geom plugin (nplugin, 3) M_rownnz: number of non-zeros in each row of qM (nv,) @@ -1177,6 +1185,7 @@ class Model: nplugin: int ngravcomp: int nsensordata: int + nhistory: int opt: Option stat: Statistic qpos0: array("*", "nq", float) @@ -1380,6 +1389,9 @@ class Model: actuator_trnid: array("nu", wp.vec2i) actuator_actadr: array("nu", int) actuator_actnum: array("nu", int) + actuator_history: array("nu", wp.vec2i) + actuator_historyadr: array("nu", int) + actuator_delay: array("nu", float) actuator_ctrllimited: array("nu", bool) actuator_forcelimited: array("nu", bool) actuator_actlimited: array("nu", bool) @@ -1404,6 +1416,10 @@ class Model: sensor_dim: array("nsensor", int) sensor_adr: array("nsensor", int) sensor_cutoff: array("nsensor", float) + sensor_history: array("nsensor", wp.vec2i) + sensor_historyadr: array("nsensor", int) + sensor_delay: array("nsensor", float) + sensor_interval: array("nsensor", wp.vec2) plugin: array("nplugin", int) plugin_attr: array("nplugin", wp.vec3f) M_rownnz: array("nv", int) @@ -1588,6 +1604,7 @@ class Data: qpos: position (nworld, nq) qvel: velocity (nworld, nv) act: actuator activation (nworld, na) + history: history buffer for delays (nworld, nhistory) qacc_warmstart: acceleration used for warmstart (nworld, nv) ctrl: control (nworld, nu) qfrc_applied: applied generalized force (nworld, nv) @@ -1679,6 +1696,7 @@ class Data: qpos: array("nworld", "nq", float) qvel: array("nworld", "nv", float) act: array("nworld", "na", float) + history: array("nworld", "nhistory", float) qacc_warmstart: array("nworld", "nv", float) ctrl: array("nworld", "nu", float) qfrc_applied: array("nworld", "nv", float)