diff --git a/contrib/render.py b/contrib/render.py index a70aa4e0b..fa0cb6af7 100644 --- a/contrib/render.py +++ b/contrib/render.py @@ -18,7 +18,7 @@ Usage: mjwarp-render [flags] Example: - mjwarp-render benchmark/humanoid/humanoid.xml --nworld=1 --cam=0 --width=512 --height=512 + mjwarp-render benchmarks/humanoid/humanoid.xml --nworld=1 --cam=0 --width=512 --height=512 """ import sys @@ -42,6 +42,7 @@ _HEIGHT = flags.DEFINE_integer("height", 512, "render height (pixels)") _RENDER_RGB = flags.DEFINE_bool("rgb", True, "render RGB image") _RENDER_DEPTH = flags.DEFINE_bool("depth", True, "render depth image") +_RENDER_SEG = flags.DEFINE_bool("seg", False, "render segmentation image") _USE_TEXTURES = flags.DEFINE_bool("textures", True, "use textures") _USE_SHADOWS = flags.DEFINE_bool("shadows", False, "use shadows") _DEVICE = flags.DEFINE_string("device", None, "override the default Warp device") @@ -207,6 +208,7 @@ def _main(argv: Sequence[str]): (render_width, render_height), _RENDER_RGB.value, _RENDER_DEPTH.value, + _RENDER_SEG.value, _USE_TEXTURES.value, _USE_SHADOWS.value, enabled_geom_groups=[0, 1, 2], diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index 6de859326..a8e86b02c 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -46,6 +46,7 @@ from mujoco_warp._src.forward import rungekutta4 as rungekutta4 from mujoco_warp._src.forward import step1 as step1 from mujoco_warp._src.forward import step2 as step2 +from mujoco_warp._src.grad import COLLISION_GRAD_FIELDS as COLLISION_GRAD_FIELDS from mujoco_warp._src.grad import SMOOTH_GRAD_FIELDS as SMOOTH_GRAD_FIELDS from mujoco_warp._src.grad import SOLVER_GRAD_FIELDS as SOLVER_GRAD_FIELDS from mujoco_warp._src.grad import diff_forward as diff_forward diff --git a/mujoco_warp/_src/adjoint.py b/mujoco_warp/_src/adjoint.py index 10d79bc2c..70f8cf937 100644 --- a/mujoco_warp/_src/adjoint.py +++ b/mujoco_warp/_src/adjoint.py @@ -6,6 +6,8 @@ Import this module via ``grad.py`` dont import it 
directly """ +import os + import warp as wp from mujoco_warp._src import math @@ -15,6 +17,110 @@ from mujoco_warp._src.block_cholesky import create_blocked_cholesky_solve_func from mujoco_warp._src.warp_util import cache_kernel +# --------------------------------------------------------------------------- +# Phase 3: efc-level gradient kernels for collision chain +# --------------------------------------------------------------------------- + + +@wp.kernel +def _efc_J_grad_kernel( + # In: + v: wp.array2d(dtype=float), + efc_force: wp.array2d(dtype=float), + nefc: wp.array(dtype=int), + nv: int, + njmax: int, + # Out: + efc_J_grad_out: wp.array3d(dtype=float), +): + """Compute adj_efc_J[i, j] = v[j] * efc_force[i]. + + From KKT: F(qacc) = M*qacc - qfrc_smooth - J^T*f = 0 + The derivative of J^T*f w.r.t. J[i,j] is f[i] * delta, and the + adjoint vector v gives the sensitivity: adj_J[i,j] = v[j] * f[i]. + """ + worldid, efcid, dofid = wp.tid() + if efcid < nefc[worldid] and dofid < nv: + efc_J_grad_out[worldid, efcid, dofid] = v[worldid, dofid] * efc_force[worldid, efcid] + + +@wp.kernel +def _efc_pos_grad_kernel( + # In: + efc_aref_grad: wp.array2d(dtype=float), + contact_solref: wp.array(dtype=wp.vec2), + contact_solimp: wp.array(dtype=types.vec5), + contact_includemargin: wp.array(dtype=float), + contact_dist: wp.array(dtype=float), + contact_efc_address: wp.array2d(dtype=int), + contact_worldid: wp.array(dtype=int), + contact_type: wp.array(dtype=int), + nacon: wp.array(dtype=int), + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + # Out: + efc_pos_grad_out: wp.array2d(dtype=float), +): + """Compute adj_efc_pos from adj_efc_aref. + + From efc_aref = -k * imp * pos - b * vel, d(aref)/d(pos) = -k*imp. + So adj_efc_pos = adj_efc_aref * (-k * imp). + We iterate over contacts and their first dimension (normal direction). 
+ """ + conid = wp.tid() + if conid >= nacon[0]: + return + if not (contact_type[conid] & 1): # ContactType.CONSTRAINT + return + + efcid = contact_efc_address[conid, 0] + if efcid < 0: + return + + worldid = contact_worldid[conid] + timestep = opt_timestep[worldid % opt_timestep.shape[0]] + + solref = contact_solref[conid] + solimp = contact_solimp[conid] + includemargin = contact_includemargin[conid] + pos_val = contact_dist[conid] - includemargin + + # Recompute k and imp (same as _efc_row) + timeconst = solref[0] + dampratio = solref[1] + dmin = solimp[0] + dmax = solimp[1] + width = solimp[2] + mid = solimp[3] + power = solimp[4] + + if not (opt_disableflags & types.DisableBit.REFSAFE): + timeconst = wp.max(timeconst, 2.0 * timestep) + + dmin = wp.clamp(dmin, types.MJ_MINIMP, types.MJ_MAXIMP) + dmax = wp.clamp(dmax, types.MJ_MINIMP, types.MJ_MAXIMP) + width = wp.max(types.MJ_MINVAL, width) + mid = wp.clamp(mid, types.MJ_MINIMP, types.MJ_MAXIMP) + power = wp.max(1.0, power) + + dmax_sq = dmax * dmax + k = 1.0 / (dmax_sq * timeconst * timeconst * dampratio * dampratio) + k = wp.where(solref[0] <= 0.0, -solref[0] / dmax_sq, k) + + imp_x = wp.abs(pos_val) / width + imp_a = (1.0 / wp.pow(mid, power - 1.0)) * wp.pow(imp_x, power) + imp_b = 1.0 - (1.0 / wp.pow(1.0 - mid, power - 1.0)) * wp.pow(1.0 - imp_x, power) + imp_y = wp.where(imp_x < mid, imp_a, imp_b) + imp = dmin + imp_y * (dmax - dmin) + imp = wp.clamp(imp, dmin, dmax) + imp = wp.where(imp_x > 1.0, dmax, imp) + + # d(aref)/d(pos) = -k * imp + daref_dpos = -k * imp + + adj_aref = efc_aref_grad[worldid, efcid] + efc_pos_grad_out[worldid, efcid] = adj_aref * daref_dpos + @wp.func_grad(math.quat_integrate) def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat): @@ -129,6 +235,17 @@ def _copy_grad_kernel( dst[worldid, dofid] = src[worldid, dofid] +@wp.kernel +def _accumulate_grad_kernel( + # In: + src: wp.array2d(dtype=float), + # Out: + dst: wp.array2d(dtype=float), +): + worldid, dofid = 
wp.tid() + dst[worldid, dofid] = dst[worldid, dofid] + src[worldid, dofid] + + @cache_kernel def _adjoint_cholesky_tile(nv: int): @wp.kernel(module="unique", enable_backward=False) @@ -182,9 +299,7 @@ def kernel( out: wp.array3d(dtype=float), ): worldid = wp.tid() - wp.static(create_blocked_cholesky_func(tile_size))( - H[worldid], nv_runtime, hfactor_tmp[worldid] - ) + wp.static(create_blocked_cholesky_func(tile_size))(H[worldid], nv_runtime, hfactor_tmp[worldid]) wp.static(create_blocked_cholesky_solve_func(tile_size, matrix_size))( hfactor_tmp[worldid], b[worldid], nv_runtime, out[worldid] ) @@ -219,9 +334,7 @@ def _solve_hessian_system(m: types.Model, d: types.Data, b, out): if d.solver_hfactor.shape[1] > 0: # Solve-only using stored Cholesky factor wp.launch_tiled( - _adjoint_cholesky_blocked( - types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad - ), + _adjoint_cholesky_blocked(types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad), dim=d.nworld, inputs=[d.solver_hfactor, b_3d, m.nv], outputs=[out_3d], @@ -237,13 +350,9 @@ def _solve_hessian_system(m: types.Model, d: types.Data, b, out): inputs=[m.nv], outputs=[d.solver_h], ) - hfactor_tmp = wp.zeros( - (d.nworld, m.nv_pad, m.nv_pad), dtype=float - ) + hfactor_tmp = wp.zeros((d.nworld, m.nv_pad, m.nv_pad), dtype=float) wp.launch_tiled( - _adjoint_cholesky_full_blocked( - types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad - ), + _adjoint_cholesky_full_blocked(types.TILE_SIZE_JTDAJ_DENSE, m.nv_pad), dim=d.nworld, inputs=[d.solver_h, b_3d, m.nv, hfactor_tmp], outputs=[out_3d], @@ -251,39 +360,106 @@ def _solve_hessian_system(m: types.Model, d: types.Data, b, out): ) -def solver_implicit_adjoint(m: types.Model, d: types.Data): +def solver_implicit_adjoint(m: types.Model, d: types.Data, qacc_array=None, qacc_smooth_ref=None): """Implicit differentiation adjoint for constraint solver. - Called during tape backward. Reads d.qacc.grad (set by downstream), - solves H*v = adj_qacc, writes d.qacc_smooth.grad = M*v. + Called during tape backward. 
Reads qacc_array.grad (set by downstream + integrator adjoint), solves H*v = adj_qacc, accumulates into + qacc_smooth_ref.grad += M*v. + + Args: + qacc_array: The array whose .grad contains the incoming adjoint. + Defaults to d.qacc when called from diff_forward(). + Integrators pass their local qacc array when it differs + from d.qacc (e.g. euler with implicit damping). + qacc_smooth_ref: The qacc_smooth array whose .grad receives the + accumulated adjoint. Captured at record time for + correct gradient isolation when intermediate arrays + are cloned between substeps. Defaults to d.qacc_smooth. """ nv = m.nv if nv == 0: return + if qacc_array is None: + qacc_array = d.qacc + + if qacc_smooth_ref is None: + qacc_smooth_ref = d.qacc_smooth + + adj_qacc = qacc_array.grad + if adj_qacc is None: + return + + if os.environ.get("MJW_DEBUG_ADJOINT") == "1": + import torch + + adj_norm = wp.to_torch(adj_qacc).norm().item() + print(f"[adjoint] |adj_qacc|={adj_norm:.6e}, njmax={d.njmax}") + if d.njmax == 0: - # Solver was identity (qacc = qacc_smooth), copy adjoint through + # Solver was identity (qacc = qacc_smooth), accumulate adjoint through wp.launch( - _copy_grad_kernel, + _accumulate_grad_kernel, dim=(d.nworld, nv), - inputs=[d.qacc.grad], - outputs=[d.qacc_smooth.grad], + inputs=[adj_qacc], + outputs=[qacc_smooth_ref.grad], ) return if m.opt.solver != types.SolverType.NEWTON: # CG solver: no Hessian stored, fall back to identity wp.launch( - _copy_grad_kernel, + _accumulate_grad_kernel, dim=(d.nworld, nv), - inputs=[d.qacc.grad], - outputs=[d.qacc_smooth.grad], + inputs=[adj_qacc], + outputs=[qacc_smooth_ref.grad], ) return # Solve H * v = adj_qacc v = wp.zeros((d.nworld, m.nv_pad), dtype=float) - _solve_hessian_system(m, d, d.qacc.grad, v) + _solve_hessian_system(m, d, adj_qacc, v) + + # adj_qacc_smooth += M * v (accumulate, not overwrite) + tmp = wp.zeros((d.nworld, m.nv_pad), dtype=float) + support.mul_m(m, d, tmp, v) + wp.launch( + _accumulate_grad_kernel, + 
dim=(d.nworld, nv), + inputs=[tmp], + outputs=[qacc_smooth_ref.grad], + ) - # adj_qacc_smooth = M * v - support.mul_m(m, d, d.qacc_smooth.grad, v) + # Phase 3: compute efc-level gradients for collision chain + if d.njmax > 0: + efc_J = d.efc.J + if hasattr(efc_J, "grad") and efc_J.grad is not None: + wp.launch( + _efc_J_grad_kernel, + dim=(d.nworld, d.njmax_pad, m.nv_pad), + inputs=[v, d.efc.force, d.nefc, m.nv, d.njmax], + outputs=[efc_J.grad], + ) + + efc_aref = d.efc.aref + efc_pos = d.efc.pos + if hasattr(efc_aref, "grad") and efc_aref.grad is not None and hasattr(efc_pos, "grad") and efc_pos.grad is not None: + wp.launch( + _efc_pos_grad_kernel, + dim=d.naconmax, + inputs=[ + efc_aref.grad, + d.contact.solref, + d.contact.solimp, + d.contact.includemargin, + d.contact.dist, + d.contact.efc_address, + d.contact.worldid, + d.contact.type, + d.nacon, + m.opt.timestep, + m.opt.disableflags, + ], + outputs=[efc_pos.grad], + ) diff --git a/mujoco_warp/_src/bvh.py b/mujoco_warp/_src/bvh.py index 0efdc1344..60c4093cf 100644 --- a/mujoco_warp/_src/bvh.py +++ b/mujoco_warp/_src/bvh.py @@ -189,12 +189,12 @@ def _compute_bvh_bounds( upper_out: wp.array(dtype=wp.vec3), group_out: wp.array(dtype=int), ): - world_id, geom_local_id = wp.tid() + worldid, geom_local_id = wp.tid() geom_id = enabled_geom_ids[geom_local_id] - pos = geom_xpos_in[world_id, geom_id] - rot = geom_xmat_in[world_id, geom_id] - size = geom_size[world_id % geom_size.shape[0], geom_id] + pos = geom_xpos_in[worldid, geom_id] + rot = geom_xmat_in[worldid, geom_id] + size = geom_size[worldid % geom_size.shape[0], geom_id] type = geom_type[geom_id] # TODO: Investigate branch elimination with static loop unrolling @@ -218,9 +218,9 @@ def _compute_bvh_bounds( hfield_center = pos + rot[:, 2] * size[2] lower_bound, upper_bound = _compute_box_bounds(hfield_center, rot, size) - lower_out[world_id * bvh_ngeom + geom_local_id] = lower_bound - upper_out[world_id * bvh_ngeom + geom_local_id] = upper_bound - 
group_out[world_id * bvh_ngeom + geom_local_id] = world_id + lower_out[worldid * bvh_ngeom + geom_local_id] = lower_bound + upper_out[worldid * bvh_ngeom + geom_local_id] = upper_bound + group_out[worldid * bvh_ngeom + geom_local_id] = worldid @wp.kernel @@ -235,14 +235,70 @@ def compute_bvh_group_roots( group_root_out[tid] = root +@wp.kernel +def _compute_flex_bvh_bounds( + # Model: + flex_vertadr: wp.array(dtype=int), + flex_vertnum: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), + # Data in: + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), + # In: + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + bvh_ngeom: int, + total_bvh_size: int, + # Out: + lower_out: wp.array(dtype=wp.vec3), + upper_out: wp.array(dtype=wp.vec3), + group_out: wp.array(dtype=int), +): + worldid, flexlocalid = wp.tid() + + flex_id = flex_geom_flexid[flexlocalid] + edge_id = flex_geom_edgeid[flexlocalid] + out_idx = worldid * total_bvh_size + bvh_ngeom + flexlocalid + radius = flex_radius[flex_id] + inflate = wp.vec3(radius, radius, radius) + + if edge_id >= 0: # capsule (1D edge) + edge = flex_edge[edge_id] + vert_adr = flex_vertadr[flex_id] + v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]] + v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]] + lower_out[out_idx] = wp.min(v0, v1) - inflate + upper_out[out_idx] = wp.max(v0, v1) + inflate + else: # mesh (2D/3D) + vert_adr = flex_vertadr[flex_id] + nvert = flex_vertnum[flex_id] + min_bound = wp.vec3(MJ_MAXVAL, MJ_MAXVAL, MJ_MAXVAL) + max_bound = wp.vec3(-MJ_MAXVAL, -MJ_MAXVAL, -MJ_MAXVAL) + for i in range(nvert): + v = flexvert_xpos_in[worldid, vert_adr + i] + min_bound = wp.min(min_bound, v) + max_bound = wp.max(max_bound, v) + lower_out[out_idx] = min_bound - inflate + upper_out[out_idx] = max_bound + inflate + + group_out[out_idx] = worldid + + def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, nworld: int): """Build a global BVH 
for all geometries in all worlds.""" + total_bvh_size = rc.bvh_ngeom + rc.bvh_nflexgeom + geom_type = wp.array(mjm.geom_type, dtype=int) geom_dataid = wp.array(mjm.geom_dataid, dtype=int) geom_size = wp.array(np.tile(mjm.geom_size[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3) geom_xpos = wp.array(np.tile(mjd.geom_xpos[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3) geom_xmat = wp.array(np.tile(mjd.geom_xmat.reshape(mjm.ngeom, 3, 3)[np.newaxis, :, :, :], (nworld, 1, 1, 1)), dtype=wp.mat33) + flex_vertadr = wp.array(mjm.flex_vertadr, dtype=int) + flex_vertnum = wp.array(mjm.flex_vertnum, dtype=int) + flex_edge = wp.array(mjm.flex_edge, dtype=wp.vec2i) + flex_radius = wp.array(mjm.flex_radius, dtype=float) + wp.launch( kernel=_compute_bvh_bounds, dim=(nworld, rc.bvh_ngeom), @@ -252,7 +308,7 @@ def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, geom_size, geom_xpos, geom_xmat, - rc.bvh_ngeom, + total_bvh_size, rc.enabled_geom_ids, rc.mesh_bounds_size, rc.hfield_bounds_size, @@ -262,6 +318,26 @@ def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, ], ) + flexvert_xpos = wp.array(np.tile(mjd.flexvert_xpos[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3) + wp.launch( + kernel=_compute_flex_bvh_bounds, + dim=(nworld, rc.bvh_nflexgeom), + inputs=[ + flex_vertadr, + flex_vertnum, + flex_edge, + flex_radius, + flexvert_xpos, + rc.flex_geom_flexid, + rc.flex_geom_edgeid, + rc.bvh_ngeom, + total_bvh_size, + rc.lower, + rc.upper, + rc.group, + ], + ) + bvh = wp.Bvh(rc.lower, rc.upper, groups=rc.group, constructor="sah") # BVH handle must be stored to avoid garbage collection @@ -277,6 +353,8 @@ def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, def refit_scene_bvh(m: Model, d: Data, rc: RenderContext): + total_bvh_size = rc.bvh_ngeom + rc.bvh_nflexgeom + wp.launch( kernel=_compute_bvh_bounds, dim=(d.nworld, rc.bvh_ngeom), @@ -286,7 +364,7 @@ def refit_scene_bvh(m: Model, d: Data, rc: 
RenderContext): m.geom_size, d.geom_xpos, d.geom_xmat, - rc.bvh_ngeom, + total_bvh_size, rc.enabled_geom_ids, rc.mesh_bounds_size, rc.hfield_bounds_size, @@ -296,6 +374,26 @@ def refit_scene_bvh(m: Model, d: Data, rc: RenderContext): ], ) + if rc.bvh_nflexgeom > 0: + wp.launch( + kernel=_compute_flex_bvh_bounds, + dim=(d.nworld, rc.bvh_nflexgeom), + inputs=[ + m.flex_vertadr, + m.flex_vertnum, + m.flex_edge, + m.flex_radius, + d.flexvert_xpos, + rc.flex_geom_flexid, + rc.flex_geom_edgeid, + rc.bvh_ngeom, + total_bvh_size, + rc.lower, + rc.upper, + rc.group, + ], + ) + rc.bvh.refit() @@ -500,6 +598,12 @@ def build_hfield_bvh( @wp.kernel def accumulate_flex_vertex_normals( # Model: + nflex: int, + flex_dim: wp.array(dtype=int), + flex_vertadr: wp.array(dtype=int), + flex_elemadr: wp.array(dtype=int), + flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elem: wp.array(dtype=int), # Data in: flexvert_xpos_in: wp.array2d(dtype=wp.vec3), @@ -509,10 +613,22 @@ def accumulate_flex_vertex_normals( """Accumulate per-vertex normals by summing adjacent face normals.""" worldid, elemid = wp.tid() - elem_base = elemid * 3 - i0 = flex_elem[elem_base + 0] - i1 = flex_elem[elem_base + 1] - i2 = flex_elem[elem_base + 2] + for i in range(nflex): + locid = elemid - flex_elemadr[i] + if locid >= 0 and locid < flex_elemnum[i]: + f = i + break + + if flex_dim[f] == 1 or flex_dim[f] == 3: + return + + local_elemid = elemid - flex_elemadr[f] + elem_adr = flex_elemdataadr[f] + vert_adr = flex_vertadr[f] + elem_base = elem_adr + local_elemid * 3 + i0 = vert_adr + flex_elem[elem_base + 0] + i1 = vert_adr + flex_elem[elem_base + 1] + i2 = vert_adr + flex_elem[elem_base + 2] v0 = flexvert_xpos_in[worldid, i0] v1 = flexvert_xpos_in[worldid, i1] @@ -718,12 +834,11 @@ def _build_flex_3d_shells( @wp.kernel -def _update_flex_face_points( +def _update_flex_2d_face_points( # Model: - nflex: int, - flex_dim: wp.array(dtype=int), flex_vertadr: wp.array(dtype=int), 
flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_shelldataadr: wp.array(dtype=int), flex_elem: wp.array(dtype=int), flex_shell: wp.array(dtype=int), @@ -732,149 +847,150 @@ def _update_flex_face_points( flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: flexvert_norm_in: wp.array2d(dtype=wp.vec3), - flex_elemdataadr: wp.array(dtype=int), - flex_faceadr: wp.array(dtype=int), - flex_workadr: wp.array(dtype=int), - flex_worknum: wp.array(dtype=int), - nfaces: int, + flex_id: int, + nface: int, smooth: bool, # Out: face_point_out: wp.array(dtype=wp.vec3), ): worldid, workid = wp.tid() - # identify which flex this work item belongs to - f = int(0) - locid = int(0) - for i in range(nflex): - locid = workid - flex_workadr[i] - if locid >= 0 and locid < flex_worknum[i]: - f = i - break - - dim = flex_dim[f] - face_offset = flex_faceadr[f] - world_face_offset = worldid * nfaces - vert_adr = flex_vertadr[f] + elem_adr = flex_elemdataadr[flex_id] + vert_adr = flex_vertadr[flex_id] + radius = flex_radius[flex_id] + nelem = flex_elemnum[flex_id] + world_face_offset = worldid * nface - if dim == 2: - radius = flex_radius[f] - elem_count = flex_elemnum[f] - - if locid < elem_count: - # 2D element faces - elemid = locid - elem_adr = flex_elemdataadr[f] - ebase = elem_adr + elemid * 3 - i0 = vert_adr + flex_elem[ebase + 0] - i1 = vert_adr + flex_elem[ebase + 1] - i2 = vert_adr + flex_elem[ebase + 2] - - v0 = flexvert_xpos_in[worldid, i0] - v1 = flexvert_xpos_in[worldid, i1] - v2 = flexvert_xpos_in[worldid, i2] - - # TODO: Use static conditional - if smooth: - n0 = flexvert_norm_in[worldid, i0] - n1 = flexvert_norm_in[worldid, i1] - n2 = flexvert_norm_in[worldid, i2] - else: - face_nrm = wp.cross(v1 - v0, v2 - v0) - face_nrm = wp.normalize(face_nrm) - n0 = face_nrm - n1 = face_nrm - n2 = face_nrm - - p0_pos = v0 + radius * n0 - p1_pos = v1 + radius * n1 - p2_pos = v2 + radius * n2 - - p0_neg = v0 - radius * n0 - p1_neg = v1 - radius * n1 - p2_neg = v2 
- radius * n2 - - face_id0 = world_face_offset + face_offset + (2 * elemid) - base0 = face_id0 * 3 - face_point_out[base0 + 0] = p0_pos - face_point_out[base0 + 1] = p1_pos - face_point_out[base0 + 2] = p2_pos - - face_id1 = world_face_offset + face_offset + (2 * elemid + 1) - base1 = face_id1 * 3 - face_point_out[base1 + 0] = p0_neg - face_point_out[base1 + 1] = p1_neg - face_point_out[base1 + 2] = p2_neg - else: - # 2D shell faces - shellid = locid - elem_count - shell_adr = flex_shelldataadr[f] - sbase = shell_adr + 2 * shellid - i0 = vert_adr + flex_shell[sbase + 0] - i1 = vert_adr + flex_shell[sbase + 1] + if workid < nelem: + # 2D element faces + elemid = workid + ebase = elem_adr + elemid * 3 + i0 = vert_adr + flex_elem[ebase + 0] + i1 = vert_adr + flex_elem[ebase + 1] + i2 = vert_adr + flex_elem[ebase + 2] - v0 = flexvert_xpos_in[worldid, i0] - v1 = flexvert_xpos_in[worldid, i1] + v0 = flexvert_xpos_in[worldid, i0] + v1 = flexvert_xpos_in[worldid, i1] + v2 = flexvert_xpos_in[worldid, i2] + # TODO: Use static conditional + if smooth: n0 = flexvert_norm_in[worldid, i0] n1 = flexvert_norm_in[worldid, i1] - - shell_face_offset = face_offset + (2 * elem_count) - face_id0 = world_face_offset + shell_face_offset + (2 * shellid) - base0 = face_id0 * 3 - face_point_out[base0 + 0] = v0 + radius * n0 - face_point_out[base0 + 1] = v1 - radius * n1 - face_point_out[base0 + 2] = v1 + radius * n1 - - face_id1 = world_face_offset + shell_face_offset + (2 * shellid + 1) - base1 = face_id1 * 3 - face_point_out[base1 + 0] = v1 - radius * n1 - face_point_out[base1 + 1] = v0 + radius * n0 - face_point_out[base1 + 2] = v0 - radius * n0 + n2 = flexvert_norm_in[worldid, i2] + else: + face_nrm = wp.cross(v1 - v0, v2 - v0) + face_nrm = wp.normalize(face_nrm) + n0 = face_nrm + n1 = face_nrm + n2 = face_nrm + + p0_pos = v0 + radius * n0 + p1_pos = v1 + radius * n1 + p2_pos = v2 + radius * n2 + + p0_neg = v0 - radius * n0 + p1_neg = v1 - radius * n1 + p2_neg = v2 - radius * n2 + + 
face_id0 = world_face_offset + (2 * elemid) + base0 = face_id0 * 3 + face_point_out[base0 + 0] = p0_pos + face_point_out[base0 + 1] = p1_pos + face_point_out[base0 + 2] = p2_pos + + face_id1 = world_face_offset + (2 * elemid + 1) + base1 = face_id1 * 3 + face_point_out[base1 + 0] = p0_neg + face_point_out[base1 + 1] = p1_neg + face_point_out[base1 + 2] = p2_neg else: - # 3D shell faces - shellid = locid - shell_adr = flex_shelldataadr[f] - sbase = shell_adr + shellid * 3 + # 2D shell faces + shell_adr = flex_shelldataadr[flex_id] + shellid = workid - nelem + sbase = shell_adr + 2 * shellid i0 = vert_adr + flex_shell[sbase + 0] i1 = vert_adr + flex_shell[sbase + 1] - i2 = vert_adr + flex_shell[sbase + 2] v0 = flexvert_xpos_in[worldid, i0] v1 = flexvert_xpos_in[worldid, i1] - v2 = flexvert_xpos_in[worldid, i2] - face_id = world_face_offset + face_offset + shellid - fbase = face_id * 3 + n0 = flexvert_norm_in[worldid, i0] + n1 = flexvert_norm_in[worldid, i1] - face_point_out[fbase + 0] = v0 - face_point_out[fbase + 1] = v1 - face_point_out[fbase + 2] = v2 + shell_face_offset = 2 * nelem + face_id0 = world_face_offset + shell_face_offset + (2 * shellid) + base0 = face_id0 * 3 + face_point_out[base0 + 0] = v0 + radius * n0 + face_point_out[base0 + 1] = v1 - radius * n1 + face_point_out[base0 + 2] = v1 + radius * n1 + face_id1 = world_face_offset + shell_face_offset + (2 * shellid + 1) + base1 = face_id1 * 3 + face_point_out[base1 + 0] = v1 - radius * n1 + face_point_out[base1 + 1] = v0 + radius * n0 + face_point_out[base1 + 2] = v0 - radius * n0 -def build_flex_bvh( - mjm: mujoco.MjModel, mjd: mujoco.MjData, nworld: int, constructor: str = "sah", leaf_size: int = 2 -) -> tuple[wp.Mesh, wp.array, wp.array, wp.array, wp.array, wp.array, int]: - """Create a Warp mesh BVH from flex data.""" - if (mjm.flex_dim == 1).any(): - raise ValueError("1D Flex objects are not currently supported.") - nflex = mjm.nflex +@wp.kernel +def _update_flex_3d_face_points( + # Model: + 
flex_vertadr: wp.array(dtype=int), + flex_shelldataadr: wp.array(dtype=int), + flex_shell: wp.array(dtype=int), + # Data in: + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), + # In: + flex_id: int, + nface: int, + # Out: + face_point_out: wp.array(dtype=wp.vec3), +): + worldid, shellid = wp.tid() + + shell_adr = flex_shelldataadr[flex_id] + vert_adr = flex_vertadr[flex_id] + + face_id = worldid * nface + shellid + fbase = face_id * 3 + + sbase = shell_adr + shellid * 3 + i0 = vert_adr + flex_shell[sbase + 0] + i1 = vert_adr + flex_shell[sbase + 1] + i2 = vert_adr + flex_shell[sbase + 2] + + face_point_out[fbase + 0] = flexvert_xpos_in[worldid, i0] + face_point_out[fbase + 1] = flexvert_xpos_in[worldid, i1] + face_point_out[fbase + 2] = flexvert_xpos_in[worldid, i2] + + +def build_flex_bvh( + mjm: mujoco.MjModel, + mjd: mujoco.MjData, + nworld: int, + flex_id: int, + constructor: str = "sah", + leaf_size: int = 2, +) -> tuple[wp.Mesh, wp.array, wp.array, wp.array, int]: + """Create a Warp mesh BVH for a single 2D or 3D flex.""" nflexvert = mjm.nflexvert - nflexelemdata = len(mjm.flex_elem) + flex_dim = wp.array(mjm.flex_dim, dtype=int) + flex_elemadr = wp.array(mjm.flex_elemadr, dtype=int) + flex_elemnum = wp.array(mjm.flex_elemnum, dtype=int) flex_elem = wp.array(mjm.flex_elem, dtype=int) + flex_elemdataadr = wp.array(mjm.flex_elemdataadr, dtype=int) + flex_vertadr = wp.array(mjm.flex_vertadr, dtype=int) flexvert_xpos = wp.array(np.tile(mjd.flexvert_xpos[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3) - flex_faceadr = [0] - for f in range(nflex): - if mjm.flex_dim[f] == 2: - flex_faceadr.append(flex_faceadr[-1] + 2 * mjm.flex_elemnum[f] + 2 * mjm.flex_shellnum[f]) - elif mjm.flex_dim[f] == 3: - flex_faceadr.append(flex_faceadr[-1] + mjm.flex_shellnum[f]) + dim = int(mjm.flex_dim[flex_id]) + nelem = int(mjm.flex_elemnum[flex_id]) + nshell = int(mjm.flex_shellnum[flex_id]) - nface = int(flex_faceadr[-1]) - flex_faceadr = flex_faceadr[:-1] + if dim == 2: + nface = 2 
* nelem + 2 * nshell + else: + nface = nshell face_point = wp.empty(nface * 3 * nworld, dtype=wp.vec3) face_index = wp.empty(nface * 3 * nworld, dtype=wp.int32) @@ -885,8 +1001,8 @@ def build_flex_bvh( wp.launch( kernel=accumulate_flex_vertex_normals, - dim=(nworld, nflexelemdata // 3), - inputs=[flex_elem, flexvert_xpos], + dim=(nworld, mjm.nflexelem), + inputs=[mjm.nflex, flex_dim, flex_vertadr, flex_elemadr, flex_elemnum, flex_elemdataadr, flex_elem, flexvert_xpos], outputs=[flexvert_norm], ) @@ -896,60 +1012,56 @@ def build_flex_bvh( inputs=[flexvert_norm], ) - for f in range(nflex): - dim = mjm.flex_dim[f] - elem_adr = mjm.flex_elemdataadr[f] - nelem = mjm.flex_elemnum[f] - shell_adr = mjm.flex_shelldataadr[f] - nshell = mjm.flex_shellnum[f] - vert_adr = mjm.flex_vertadr[f] + elem_adr = mjm.flex_elemdataadr[flex_id] + shell_adr = mjm.flex_shelldataadr[flex_id] + vert_adr = mjm.flex_vertadr[flex_id] - if dim == 2: - wp.launch( - kernel=_build_flex_2d_elements, - dim=(nworld, nelem), - inputs=[ - flex_elem, - flexvert_xpos, - flexvert_norm, - elem_adr, - vert_adr, - flex_faceadr[f], - mjm.flex_radius[f], - nface, - ], - outputs=[face_point, face_index, group], - ) - - wp.launch( - kernel=_build_flex_2d_sides, - dim=(nworld, nshell), - inputs=[ - flex_shell, - flexvert_xpos, - flexvert_norm, - shell_adr, - vert_adr, - flex_faceadr[f] + 2 * nelem, - mjm.flex_radius[f], - nface, - ], - outputs=[face_point, face_index, group], - ) - elif dim == 3: - wp.launch( - kernel=_build_flex_3d_shells, - dim=(nworld, nshell), - inputs=[ - flex_shell, - flexvert_xpos, - shell_adr, - vert_adr, - flex_faceadr[f], - nface, - ], - outputs=[face_point, face_index, group], - ) + if dim == 2: + wp.launch( + kernel=_build_flex_2d_elements, + dim=(nworld, nelem), + inputs=[ + flex_elem, + flexvert_xpos, + flexvert_norm, + elem_adr, + vert_adr, + 0, # face_offset + mjm.flex_radius[flex_id], + nface, + ], + outputs=[face_point, face_index, group], + ) + + wp.launch( + 
kernel=_build_flex_2d_sides, + dim=(nworld, nshell), + inputs=[ + flex_shell, + flexvert_xpos, + flexvert_norm, + shell_adr, + vert_adr, + 2 * nelem, # face_offset + mjm.flex_radius[flex_id], + nface, + ], + outputs=[face_point, face_index, group], + ) + elif dim == 3: + wp.launch( + kernel=_build_flex_3d_shells, + dim=(nworld, nshell), + inputs=[ + flex_shell, + flexvert_xpos, + shell_adr, + vert_adr, + 0, # face_offset + nface, + ], + outputs=[face_point, face_index, group], + ) flex_mesh = wp.Mesh( points=face_point, @@ -967,24 +1079,23 @@ def build_flex_bvh( outputs=[group_root], ) - return ( - flex_mesh, - face_point, - group_root, - flex_shell, - flex_faceadr, - nface, - ) + return flex_mesh, group_root def refit_flex_bvh(m: Model, d: Data, rc: RenderContext): - """Refit the flex BVH.""" + """Refit per-flex BVHs.""" flexvert_norm = wp.zeros(d.flexvert_xpos.shape, dtype=wp.vec3) wp.launch( kernel=accumulate_flex_vertex_normals, - dim=(d.nworld, m.nflexelemdata // 3), + dim=(d.nworld, m.nflexelem), inputs=[ + m.nflex, + m.flex_dim, + m.flex_vertadr, + m.flex_elemadr, + m.flex_elemnum, + m.flex_elemdataadr, m.flex_elem, d.flexvert_xpos, ], @@ -993,32 +1104,49 @@ def refit_flex_bvh(m: Model, d: Data, rc: RenderContext): wp.launch( kernel=normalize_vertex_normals, - dim=(d.nworld, m.nflexvert), + dim=(d.nworld, d.flexvert_xpos.shape[1]), inputs=[flexvert_norm], ) - wp.launch( - kernel=_update_flex_face_points, - dim=(d.nworld, rc.flex_nwork), - inputs=[ - m.nflex, - m.flex_dim, - m.flex_vertadr, - m.flex_elemnum, - m.flex_shelldataadr, - m.flex_elem, - m.flex_shell, - m.flex_radius, - d.flexvert_xpos, - flexvert_norm, - rc.flex_elemdataadr, - rc.flex_faceadr, - rc.flex_workadr, - rc.flex_worknum, - rc.flex_nface, - rc.flex_render_smooth, - ], - outputs=[rc.flex_face_point], - ) + for i in range(m.nflex): + if rc.flex_dim_np[i] == 1: + continue + mesh = rc.flex_mesh_registry[i] + nface = mesh.points.shape[0] // (3 * d.nworld) + + if rc.flex_dim_np[i] == 2: + 
wp.launch( + kernel=_update_flex_2d_face_points, + dim=(d.nworld, nface // 2), + inputs=[ + m.flex_vertadr, + m.flex_elemnum, + m.flex_elemdataadr, + m.flex_shelldataadr, + m.flex_elem, + m.flex_shell, + m.flex_radius, + d.flexvert_xpos, + flexvert_norm, + i, + nface, + rc.flex_render_smooth, + ], + outputs=[mesh.points], + ) + else: + wp.launch( + kernel=_update_flex_3d_face_points, + dim=(d.nworld, nface), + inputs=[ + m.flex_vertadr, + m.flex_shelldataadr, + m.flex_shell, + d.flexvert_xpos, + i, + nface, + ], + outputs=[mesh.points], + ) - rc.flex_mesh.refit() + mesh.refit() diff --git a/mujoco_warp/_src/bvh_test.py b/mujoco_warp/_src/bvh_test.py index ee15d65b5..5241334d3 100644 --- a/mujoco_warp/_src/bvh_test.py +++ b/mujoco_warp/_src/bvh_test.py @@ -33,9 +33,12 @@ def _assert_eq(a, b, name): @dataclasses.dataclass class MinimalRenderContext: bvh_ngeom: int + bvh_nflexgeom: int enabled_geom_ids: wp.array mesh_bounds_size: wp.array hfield_bounds_size: wp.array + flex_geom_flexid: wp.array + flex_geom_edgeid: wp.array lower: wp.array upper: wp.array group: wp.array @@ -53,9 +56,12 @@ def _create_minimal_context(mjm, nworld, enabled_geom_groups=None): return MinimalRenderContext( bvh_ngeom=bvh_ngeom, + bvh_nflexgeom=0, enabled_geom_ids=wp.array(geom_enabled_idx, dtype=int), mesh_bounds_size=wp.zeros(max(mjm.nmesh, 1), dtype=wp.vec3), hfield_bounds_size=wp.zeros(max(mjm.nhfield, 1), dtype=wp.vec3), + flex_geom_flexid=wp.zeros(max(mjm.nflex, 1), dtype=int), + flex_geom_edgeid=wp.zeros(max(mjm.nflex, 1), dtype=int), lower=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3), upper=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3), group=wp.zeros(nworld * bvh_ngeom, dtype=int), @@ -211,12 +217,18 @@ def test_accumulate_flex_vertex_normals(self): dtype=wp.vec3, ) flex_elem = wp.array([0, 1, 2, 1, 3, 2], dtype=int) + flex_elemdataadr = wp.array([0], dtype=int) + flex_elemadr = wp.array([0], dtype=int) + flex_elemnum = wp.array([len(flex_elem)], dtype=int) + flex_vertadr = 
wp.array([0], dtype=int) + flex_dim = wp.array([2], dtype=int) + flex_id = 0 flexvert_norm = wp.zeros((nworld, nvert), dtype=wp.vec3) wp.launch( kernel=bvh.accumulate_flex_vertex_normals, dim=(nworld, nelem), - inputs=[flex_elem, flexvert_xpos], + inputs=[1, flex_dim, flex_vertadr, flex_elemadr, flex_elemnum, flex_elemdataadr, flex_elem, flexvert_xpos], outputs=[flexvert_norm], ) @@ -252,7 +264,7 @@ def test_build_flex_bvh(self): mjm, mjd, m, d = test_data.fixture("flex/floppy.xml") - flex_mesh, face_point, group_root, flex_shell, flex_faceadr, nface = bvh.build_flex_bvh(mjm, mjd, 1) + flex_mesh, face_point, flex_shell, group_root, nface = bvh.build_flex_bvh(mjm, mjd, 1, 0) self.assertNotEqual(flex_mesh.id, wp.uint64(0), "flex_mesh id") diff --git a/mujoco_warp/_src/collision_flex.py b/mujoco_warp/_src/collision_flex.py index 3523730b0..13d8b0e0d 100644 --- a/mujoco_warp/_src/collision_flex.py +++ b/mujoco_warp/_src/collision_flex.py @@ -398,6 +398,7 @@ def _flex_narrowphase_dim2( flex_vertadr: wp.array(dtype=int), flex_elemadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elem: wp.array(dtype=int), flex_radius: wp.array(dtype=float), # Data in: @@ -443,7 +444,7 @@ def _flex_narrowphase_dim2( tri_radius = flex_radius[flexid] tri_margin = flex_margin[flexid] - elem_data_idx = elemid * 3 + elem_data_idx = flex_elemdataadr[flexid] + (elemid - flex_elemadr[flexid]) * 3 v0_local = flex_elem[elem_data_idx] v1_local = flex_elem[elem_data_idx + 1] v2_local = flex_elem[elem_data_idx + 2] @@ -709,6 +710,7 @@ def flex_narrowphase(m: Model, d: Data): m.flex_vertadr, m.flex_elemadr, m.flex_elemnum, + m.flex_elemdataadr, m.flex_elem, m.flex_radius, d.geom_xpos, diff --git a/mujoco_warp/_src/collision_smooth.py b/mujoco_warp/_src/collision_smooth.py new file mode 100644 index 000000000..abf370525 --- /dev/null +++ b/mujoco_warp/_src/collision_smooth.py @@ -0,0 +1,789 @@ +# Copyright 2025 The Newton Developers +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Smooth (differentiable) collision recomputation for autodifferentiation. + +This module provides differentiable replacements for the collision pipeline's +contact geometry and constraint assembly. It runs *after* the discrete pipeline +and overwrites contact.{dist, pos, frame} and efc.{J, pos, D, aref, vel} with +smooth values that Warp can differentiate through. + +Supported geom type pairs: + - plane-sphere, sphere-sphere, sphere-capsule + - capsule-capsule (2 contacts), plane-capsule (2 contacts) + +Unsupported types (box, mesh, convex, etc.) are no-ops that keep discrete +values (zero gradient through those contacts). 
+""" + +from typing import Tuple + +import warp as wp + +from mujoco_warp._src import support +from mujoco_warp._src import types +from mujoco_warp._src.types import MJ_MINVAL +from mujoco_warp._src.types import DisableBit + +wp.set_module_options({"enable_backward": True}) + + +# ============================================================================ +# Custom types (matching collision_primitive_core.py) +# ============================================================================ + + +class mat23f(wp.types.matrix(shape=(2, 3), dtype=wp.float32)): + pass + + +# ============================================================================ +# Smooth distance functions +# ============================================================================ + + +@wp.func +def smooth_plane_sphere( + plane_normal: wp.vec3, + plane_pos: wp.vec3, + sphere_pos: wp.vec3, + sphere_radius: float, +) -> Tuple[float, wp.vec3]: + """Plane-sphere distance (already smooth).""" + dist = wp.dot(sphere_pos - plane_pos, plane_normal) - sphere_radius + pos = sphere_pos - plane_normal * (sphere_radius + 0.5 * dist) + return dist, pos + + +@wp.func +def smooth_sphere_sphere( + pos1: wp.vec3, + radius1: float, + pos2: wp.vec3, + radius2: float, +) -> Tuple[float, wp.vec3, wp.vec3]: + """Sphere-sphere distance with smooth normalization at coincident centers.""" + dir = pos2 - pos1 + raw_dist = wp.length(dir) + # Smooth normalization: replaces if dist==0 branch + n = dir / wp.max(raw_dist, 1e-8) + dist = raw_dist - (radius1 + radius2) + pos = pos1 + n * (radius1 + 0.5 * dist) + return dist, pos, n + + +@wp.func +def smooth_sphere_capsule( + sphere_pos: wp.vec3, + sphere_radius: float, + capsule_pos: wp.vec3, + capsule_axis: wp.vec3, + capsule_radius: float, + capsule_half_length: float, +) -> Tuple[float, wp.vec3, wp.vec3]: + """Sphere-capsule distance using wp.clamp (subdifferentiable at boundary).""" + segment = capsule_axis * capsule_half_length + seg_start = capsule_pos - segment + seg_end 
= capsule_pos + segment + + # Closest point on capsule centerline to sphere center + ab = seg_end - seg_start + t = wp.dot(sphere_pos - seg_start, ab) / (wp.dot(ab, ab) + 1e-6) + pt = seg_start + wp.clamp(t, 0.0, 1.0) * ab + + return smooth_sphere_sphere(sphere_pos, sphere_radius, pt, capsule_radius) + + +@wp.func +def smooth_capsule_capsule( + cap1_pos: wp.vec3, + cap1_axis: wp.vec3, + cap1_radius: float, + cap1_half_length: float, + cap2_pos: wp.vec3, + cap2_axis: wp.vec3, + cap2_radius: float, + cap2_half_length: float, + margin: float, +) -> Tuple[wp.vec2, mat23f, mat23f]: + """Capsule-capsule distance returning 2 contacts, regularized for parallel axes.""" + contact_dist = wp.vec2(wp.inf, wp.inf) + contact_pos = mat23f() + contact_normal = mat23f() + + axis1 = cap1_axis * cap1_half_length + axis2 = cap2_axis * cap2_half_length + dif = cap1_pos - cap2_pos + + ma = wp.dot(axis1, axis1) + mb = -wp.dot(axis1, axis2) + mc = wp.dot(axis2, axis2) + u = -wp.dot(axis1, dif) + v = wp.dot(axis2, dif) + det = ma * mc - mb * mb + + # Regularized determinant: smooth handling of near-parallel axes + det_abs = wp.abs(det) + det_sign = wp.where(det >= 0.0, 1.0, -1.0) + det_reg = det_sign * wp.max(det_abs, 1e-10) + + # Blend: use non-parallel path when |det| > threshold, parallel otherwise + # Smooth blending factor + blend_threshold = 1e-8 + alpha = wp.min(det_abs / wp.max(blend_threshold, 1e-15), 1.0) + + # -- Non-parallel path -- + inv_det = 1.0 / det_reg + x1_np = (mc * u - mb * v) * inv_det + x2_np = (ma * v - mb * u) * inv_det + + # Clamp with recomputation (smooth via wp.clamp) + x1_np = wp.clamp(x1_np, -1.0, 1.0) + x2_np = wp.clamp(x2_np, -1.0, 1.0) + + # Re-clamp for consistency + x2_np = wp.clamp((v + mb * x1_np) / wp.max(mc, 1e-10), -1.0, 1.0) + x1_np = wp.clamp((u - mb * x2_np) / wp.max(ma, 1e-10), -1.0, 1.0) + + vec1_np = cap1_pos + axis1 * x1_np + vec2_np = cap2_pos + axis2 * x2_np + dist_np, pos_np, normal_np = smooth_sphere_sphere(vec1_np, cap1_radius, vec2_np, 
cap2_radius) + + # -- Parallel path: test 4 endpoint pairs, keep first 2 -- + # Endpoint 1: x1 = 1 + vec1_a = cap1_pos + axis1 + x2_a = wp.clamp((v - mb) / wp.max(mc, 1e-10), -1.0, 1.0) + vec2_a = cap2_pos + axis2 * x2_a + dist_a, pos_a, normal_a = smooth_sphere_sphere(vec1_a, cap1_radius, vec2_a, cap2_radius) + + # Endpoint 2: x1 = -1 + vec1_b = cap1_pos - axis1 + x2_b = wp.clamp((v + mb) / wp.max(mc, 1e-10), -1.0, 1.0) + vec2_b = cap2_pos + axis2 * x2_b + dist_b, pos_b, normal_b = smooth_sphere_sphere(vec1_b, cap1_radius, vec2_b, cap2_radius) + + # Endpoint 3: x2 = 1 + vec2_c = cap2_pos + axis2 + x1_c = wp.clamp((u - mb) / wp.max(ma, 1e-10), -1.0, 1.0) + vec1_c = cap1_pos + axis1 * x1_c + dist_c, pos_c, normal_c = smooth_sphere_sphere(vec1_c, cap1_radius, vec2_c, cap2_radius) + + # Endpoint 4: x2 = -1 + vec2_d = cap2_pos - axis2 + x1_d = wp.clamp((u + mb) / wp.max(ma, 1e-10), -1.0, 1.0) + vec1_d = cap1_pos + axis1 * x1_d + dist_d, pos_d, normal_d = smooth_sphere_sphere(vec1_d, cap1_radius, vec2_d, cap2_radius) + + # Sort 4 endpoints by distance, pick best 2 for parallel contacts + # Contact 0: best of all 4 + par_dist0 = dist_a + par_pos0 = pos_a + par_normal0 = normal_a + + if dist_b < par_dist0: + par_dist0 = dist_b + par_pos0 = pos_b + par_normal0 = normal_b + if dist_c < par_dist0: + par_dist0 = dist_c + par_pos0 = pos_c + par_normal0 = normal_c + if dist_d < par_dist0: + par_dist0 = dist_d + par_pos0 = pos_d + par_normal0 = normal_d + + # Contact 1: second best + par_dist1 = wp.inf + par_pos1 = wp.vec3(0.0) + par_normal1 = wp.vec3(1.0, 0.0, 0.0) + + if dist_a <= margin and dist_a != par_dist0: + par_dist1 = dist_a + par_pos1 = pos_a + par_normal1 = normal_a + if dist_b <= margin and dist_b != par_dist0: + if dist_b < par_dist1: + par_dist1 = dist_b + par_pos1 = pos_b + par_normal1 = normal_b + if dist_c <= margin and dist_c != par_dist0: + if dist_c < par_dist1: + par_dist1 = dist_c + par_pos1 = pos_c + par_normal1 = normal_c + if dist_d <= margin and dist_d 
!= par_dist0: + if dist_d < par_dist1: + par_dist1 = dist_d + par_pos1 = pos_d + par_normal1 = normal_d + + # Blend between non-parallel (1 contact) and parallel (2 contacts) + # Non-parallel: contact 0 = np result, contact 1 = inf + # Parallel: contact 0, 1 from sorted endpoints + blend_dist0 = alpha * dist_np + (1.0 - alpha) * par_dist0 + blend_pos0 = alpha * pos_np + (1.0 - alpha) * par_pos0 + blend_normal0 = alpha * normal_np + (1.0 - alpha) * par_normal0 + # Renormalize blended normal + blend_normal0 = blend_normal0 / wp.max(wp.length(blend_normal0), 1e-8) + + # Contact 1: only from parallel path (non-parallel has 1 contact) + blend_dist1 = (1.0 - alpha) * par_dist1 + alpha * wp.inf + + if blend_dist0 <= margin: + contact_dist[0] = blend_dist0 + contact_pos[0] = blend_pos0 + contact_normal[0] = blend_normal0 + + if blend_dist1 <= margin: + contact_dist[1] = blend_dist1 + contact_pos[1] = par_pos1 + contact_normal[1] = par_normal1 + + return contact_dist, contact_pos, contact_normal + + +@wp.func +def smooth_plane_capsule( + plane_normal: wp.vec3, + plane_pos: wp.vec3, + capsule_pos: wp.vec3, + capsule_axis: wp.vec3, + capsule_radius: float, + capsule_half_length: float, +) -> Tuple[wp.vec2, mat23f, wp.mat33]: + """Plane-capsule distance returning 2 contacts (already smooth).""" + n = plane_normal + axis = capsule_axis + segment = axis * capsule_half_length + + # Build contact frame (smooth version matching collision_primitive_core.py) + proj = axis - n * wp.dot(n, axis) + proj_len = wp.length(proj) + b = proj / wp.max(proj_len, 1e-8) + + # Fallback when capsule axis is nearly parallel to plane normal + if proj_len < 0.5: + if -0.5 < n[1] and n[1] < 0.5: + b = wp.vec3(0.0, 1.0, 0.0) + else: + b = wp.vec3(0.0, 0.0, 1.0) + + c = wp.cross(n, b) + frame = wp.mat33(n[0], n[1], n[2], b[0], b[1], b[2], c[0], c[1], c[2]) + + # Two contacts at capsule endpoints + dist1, pos1 = smooth_plane_sphere(n, plane_pos, capsule_pos + segment, capsule_radius) + dist2, pos2 = 
smooth_plane_sphere(n, plane_pos, capsule_pos - segment, capsule_radius) + + dist = wp.vec2(dist1, dist2) + pos = mat23f(pos1[0], pos1[1], pos1[2], pos2[0], pos2[1], pos2[2]) + + return dist, pos, frame + + +@wp.func +def smooth_make_frame(normal: wp.vec3) -> wp.mat33: + """Construct contact frame from normal with smooth tangent directions.""" + a = normal / wp.max(wp.length(normal), 1e-8) + + # Gram-Schmidt orthogonalization (same as math.orthogonals but using + # wp.where instead of branching on a[1] for smoother gradients) + y = wp.vec3(0.0, 1.0, 0.0) + z = wp.vec3(0.0, 0.0, 1.0) + b = wp.where((-0.5 < a[1]) and (a[1] < 0.5), y, z) + b = b - a * wp.dot(a, b) + b_len = wp.length(b) + b = b / wp.max(b_len, 1e-8) + c = wp.cross(a, b) + + return wp.mat33( + a[0], + a[1], + a[2], + b[0], + b[1], + b[2], + c[0], + c[1], + c[2], + ) + + +# ============================================================================ +# Smooth contact recomputation kernel +# ============================================================================ + + +@wp.kernel +def _smooth_recompute_kernel( + # Model (constants): + geom_type: wp.array(dtype=int), + geom_size: wp.array2d(dtype=wp.vec3), + geom_bodyid: wp.array(dtype=int), + # Data in (differentiable): + geom_xpos_in: wp.array2d(dtype=wp.vec3), + geom_xmat_in: wp.array2d(dtype=wp.mat33), + # Data in (integer, no grad): + nacon_in: wp.array(dtype=int), + contact_geom_in: wp.array(dtype=wp.vec2i), + contact_geomcollisionid_in: wp.array(dtype=int), + contact_worldid_in: wp.array(dtype=int), + # Data out (differentiable): + contact_dist_out: wp.array(dtype=float), + contact_pos_out: wp.array(dtype=wp.vec3), + contact_frame_out: wp.array(dtype=wp.mat33), +): + cid = wp.tid() + + if cid >= nacon_in[0]: + return + + geoms = contact_geom_in[cid] + g1 = geoms[0] + g2 = geoms[1] + + # Skip flex contacts (geom id = -1) + if g1 < 0 or g2 < 0: + return + + worldid = contact_worldid_in[cid] + subcid = contact_geomcollisionid_in[cid] + t1 = 
geom_type[g1] + t2 = geom_type[g2] + + # Geom poses (differentiable from Phase 1 kinematics) + pos1 = geom_xpos_in[worldid, g1] + pos2 = geom_xpos_in[worldid, g2] + mat1 = geom_xmat_in[worldid, g1] + mat2 = geom_xmat_in[worldid, g2] + + # Geom sizes (model constants — use worldid=0 for batched models) + size_id = worldid % geom_size.shape[0] + size1 = geom_size[size_id, g1] + size2 = geom_size[size_id, g2] + + # Dispatch based on geom type pair + # Geom types: PLANE=0, HFIELD=1, SPHERE=2, CAPSULE=3, ELLIPSOID=4, + # CYLINDER=5, BOX=6, MESH=7, SDF=8 + + handled = False + + # plane-sphere + if t1 == 0 and t2 == 2: + plane_normal = wp.vec3(mat1[0, 2], mat1[1, 2], mat1[2, 2]) + dist, pos = smooth_plane_sphere(plane_normal, pos1, pos2, size2[0]) + frame = smooth_make_frame(plane_normal) + contact_dist_out[cid] = dist + contact_pos_out[cid] = pos + contact_frame_out[cid] = frame + handled = True + + # sphere-sphere + if not handled and t1 == 2 and t2 == 2: + dist, pos, normal = smooth_sphere_sphere(pos1, size1[0], pos2, size2[0]) + frame = smooth_make_frame(normal) + contact_dist_out[cid] = dist + contact_pos_out[cid] = pos + contact_frame_out[cid] = frame + handled = True + + # sphere-capsule + if not handled and t1 == 2 and t2 == 3: + cap_axis = wp.vec3(mat2[0, 2], mat2[1, 2], mat2[2, 2]) + dist, pos, normal = smooth_sphere_capsule(pos1, size1[0], pos2, cap_axis, size2[0], size2[1]) + frame = smooth_make_frame(normal) + contact_dist_out[cid] = dist + contact_pos_out[cid] = pos + contact_frame_out[cid] = frame + handled = True + + # capsule-capsule (2 contacts via geomcollisionid) + if not handled and t1 == 3 and t2 == 3: + cap1_axis = wp.vec3(mat1[0, 2], mat1[1, 2], mat1[2, 2]) + cap2_axis = wp.vec3(mat2[0, 2], mat2[1, 2], mat2[2, 2]) + dists, positions, normals = smooth_capsule_capsule( + pos1, + cap1_axis, + size1[0], + size1[1], + pos2, + cap2_axis, + size2[0], + size2[1], + 1e10, # large margin so we always compute both contacts + ) + if subcid == 0: + 
contact_dist_out[cid] = dists[0] + contact_pos_out[cid] = wp.vec3(positions[0, 0], positions[0, 1], positions[0, 2]) + normal0 = wp.vec3(normals[0, 0], normals[0, 1], normals[0, 2]) + contact_frame_out[cid] = smooth_make_frame(normal0) + else: + contact_dist_out[cid] = dists[1] + contact_pos_out[cid] = wp.vec3(positions[1, 0], positions[1, 1], positions[1, 2]) + normal1 = wp.vec3(normals[1, 0], normals[1, 1], normals[1, 2]) + contact_frame_out[cid] = smooth_make_frame(normal1) + handled = True + + # plane-capsule (2 contacts via geomcollisionid) + if not handled and t1 == 0 and t2 == 3: + plane_normal = wp.vec3(mat1[0, 2], mat1[1, 2], mat1[2, 2]) + cap_axis = wp.vec3(mat2[0, 2], mat2[1, 2], mat2[2, 2]) + dists, positions, frame = smooth_plane_capsule(plane_normal, pos1, pos2, cap_axis, size2[0], size2[1]) + if subcid == 0: + contact_dist_out[cid] = dists[0] + contact_pos_out[cid] = wp.vec3(positions[0, 0], positions[0, 1], positions[0, 2]) + else: + contact_dist_out[cid] = dists[1] + contact_pos_out[cid] = wp.vec3(positions[1, 0], positions[1, 1], positions[1, 2]) + contact_frame_out[cid] = frame + handled = True + + # Unsupported types: no-op (keeps discrete values, zero gradient) + + +# ============================================================================ +# Differentiable constraint assembly kernel +# ============================================================================ + + +@wp.func +def _smooth_efc_row( + opt_disableflags: int, + worldid: int, + timestep: float, + efcid: int, + pos_aref: float, + pos_imp: float, + invweight: float, + solref: wp.vec2, + solimp: types.vec5, + margin: float, + vel: float, + # Out: + pos_out: wp.array2d(dtype=float), + D_out: wp.array2d(dtype=float), + aref_out: wp.array2d(dtype=float), + vel_out: wp.array2d(dtype=float), +): + """Smooth reimplementation of _efc_row for differentiable constraint params.""" + timeconst = solref[0] + dampratio = solref[1] + dmin = solimp[0] + dmax = solimp[1] + width = solimp[2] + mid 
= solimp[3] + power = solimp[4] + + if not (opt_disableflags & DisableBit.REFSAFE): + timeconst = wp.max(timeconst, 2.0 * timestep) + + dmin = wp.clamp(dmin, types.MJ_MINIMP, types.MJ_MAXIMP) + dmax = wp.clamp(dmax, types.MJ_MINIMP, types.MJ_MAXIMP) + width = wp.max(MJ_MINVAL, width) + mid = wp.clamp(mid, types.MJ_MINIMP, types.MJ_MAXIMP) + power = wp.max(1.0, power) + + dmax_sq = dmax * dmax + k = 1.0 / (dmax_sq * timeconst * timeconst * dampratio * dampratio) + b = 2.0 / (dmax * timeconst) + k = wp.where(solref[0] <= 0.0, -solref[0] / dmax_sq, k) + b = wp.where(solref[1] <= 0.0, -solref[1] / dmax, b) + + imp_x = wp.abs(pos_imp) / width + imp_a = (1.0 / wp.pow(mid, power - 1.0)) * wp.pow(imp_x, power) + imp_b = 1.0 - (1.0 / wp.pow(1.0 - mid, power - 1.0)) * wp.pow(1.0 - imp_x, power) + imp_y = wp.where(imp_x < mid, imp_a, imp_b) + imp = dmin + imp_y * (dmax - dmin) + imp = wp.clamp(imp, dmin, dmax) + imp = wp.where(imp_x > 1.0, dmax, imp) + + D_out[worldid, efcid] = 1.0 / wp.max(invweight * (1.0 - imp) / imp, MJ_MINVAL) + vel_out[worldid, efcid] = vel + aref_out[worldid, efcid] = -k * imp * pos_aref - b * vel + pos_out[worldid, efcid] = pos_aref + margin + + +@wp.kernel +def _smooth_contact_to_efc_kernel( + # Model constants: + nv: int, + opt_timestep: wp.array(dtype=float), + opt_disableflags: int, + opt_impratio_invsqrt: wp.array(dtype=float), + body_parentid: wp.array(dtype=int), + body_rootid: wp.array(dtype=int), + body_weldid: wp.array(dtype=int), + body_dofnum: wp.array(dtype=int), + body_dofadr: wp.array(dtype=int), + body_invweight0: wp.array2d(dtype=wp.vec2), + dof_bodyid: wp.array(dtype=int), + dof_parentid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + # Data in (differentiable): + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + qvel_in: wp.array2d(dtype=float), + contact_dist_in: wp.array(dtype=float), + contact_pos_in: wp.array(dtype=wp.vec3), + contact_frame_in: wp.array(dtype=wp.mat33), + 
contact_friction_in: wp.array(dtype=types.vec5), + contact_includemargin_in: wp.array(dtype=float), + contact_solref_in: wp.array(dtype=wp.vec2), + contact_solimp_in: wp.array(dtype=types.vec5), + # Data in (integer, no grad): + nacon_in: wp.array(dtype=int), + contact_efc_address_in: wp.array2d(dtype=int), + contact_dim_in: wp.array(dtype=int), + contact_worldid_in: wp.array(dtype=int), + contact_geom_in: wp.array(dtype=wp.vec2i), + contact_type_in: wp.array(dtype=int), + njmax_in: int, + # Data out (differentiable): + efc_J_out: wp.array3d(dtype=float), + efc_pos_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), +): + conid, dimid = wp.tid() + + if conid >= nacon_in[0]: + return + + # Only process constraint contacts + if not (contact_type_in[conid] & 1): # ContactType.CONSTRAINT = 1 + return + + condim = contact_dim_in[conid] + if condim == 1 and dimid > 0: + return + elif condim > 1 and dimid >= 2 * (condim - 1): + return + + # Read efc_address — skip if -1 (not active) + efcid = contact_efc_address_in[conid, dimid] + if efcid < 0: + return + if efcid >= njmax_in: + return + + worldid = contact_worldid_in[conid] + timestep = opt_timestep[worldid % opt_timestep.shape[0]] + impratio_invsqrt = opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] + + geom = contact_geom_in[conid] + body1 = geom_bodyid[geom[0]] + body2 = geom_bodyid[geom[1]] + + con_pos = contact_pos_in[conid] + frame = contact_frame_in[conid] + includemargin = contact_includemargin_in[conid] + pos = contact_dist_in[conid] - includemargin + + # Pyramidal invweight computation + body_invweight0_id = worldid % body_invweight0.shape[0] + invweight = body_invweight0[body_invweight0_id, body1][0] + body_invweight0[body_invweight0_id, body2][0] + + fri0 = float(0.0) + frii = float(0.0) + dimid2 = int(0) + if condim > 1: + dimid2 = dimid / 2 + 1 + friction = contact_friction_in[conid] + fri0 = 
friction[0] + frii = friction[dimid2 - 1] + invweight = invweight + fri0 * fri0 * invweight + invweight = invweight * 2.0 * fri0 * fri0 * impratio_invsqrt * impratio_invsqrt + + Jqvel = float(0.0) + + # Skip fixed bodies + body1 = body_weldid[body1] + body2 = body_weldid[body2] + + da1 = int(body_dofadr[body1] + body_dofnum[body1] - 1) + da2 = int(body_dofadr[body2] + body_dofnum[body2] - 1) + + # Dense Jacobian computation (AD requires dense) + da = wp.max(da1, da2) + dofid = int(nv - 1) + + while True: + if dofid < 0: + break + + if dofid == da: + jac1p, jac1r = support.jac_dof( + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + con_pos, + body1, + dofid, + worldid, + ) + jac2p, jac2r = support.jac_dof( + body_parentid, + body_rootid, + dof_bodyid, + subtree_com_in, + cdof_in, + con_pos, + body2, + dofid, + worldid, + ) + + J = float(0.0) + Ji = float(0.0) + + for xyz in range(3): + jacp_dif = jac2p[xyz] - jac1p[xyz] + J += frame[0, xyz] * jacp_dif + + if condim > 1: + if dimid2 < 3: + Ji += frame[dimid2, xyz] * jacp_dif + else: + Ji += frame[dimid2 - 3, xyz] * (jac2r[xyz] - jac1r[xyz]) + + if condim > 1: + if dimid % 2 == 0: + J += Ji * frii + else: + J -= Ji * frii + + efc_J_out[worldid, efcid, dofid] = J + Jqvel += J * qvel_in[worldid, dofid] + + # Advance tree pointers + if da1 == da: + da1 = dof_parentid[da1] + if da2 == da: + da2 = dof_parentid[da2] + da = wp.max(da1, da2) + dofid -= 1 + else: + efc_J_out[worldid, efcid, dofid] = 0.0 + dofid -= 1 + + # Compute constraint equation row + _smooth_efc_row( + opt_disableflags, + worldid, + timestep, + efcid, + pos, + pos, + invweight, + contact_solref_in[conid], + contact_solimp_in[conid], + includemargin, + Jqvel, + efc_pos_out, + efc_D_out, + efc_aref_out, + efc_vel_out, + ) + + +# ============================================================================ +# Python launchers +# ============================================================================ + + +def 
smooth_recompute_contacts(m: types.Model, d: types.Data): + """Overwrite contact.{dist, pos, frame} with smooth differentiable values.""" + if d.naconmax == 0: + return + + wp.launch( + _smooth_recompute_kernel, + dim=d.naconmax, + inputs=[ + # Model constants + m.geom_type, + m.geom_size, + m.geom_bodyid, + # Data in (differentiable) + d.geom_xpos, + d.geom_xmat, + # Data in (integer) + d.nacon, + d.contact.geom, + d.contact.geomcollisionid, + d.contact.worldid, + ], + outputs=[ + d.contact.dist, + d.contact.pos, + d.contact.frame, + ], + ) + + +def smooth_contact_to_efc(m: types.Model, d: types.Data): + """Overwrite efc.{J, pos, D, aref, vel} with smooth differentiable values.""" + if d.naconmax == 0 or d.njmax == 0: + return + + wp.launch( + _smooth_contact_to_efc_kernel, + dim=(d.naconmax, m.nmaxpyramid), + inputs=[ + # Model constants + m.nv, + m.opt.timestep, + m.opt.disableflags, + m.opt.impratio_invsqrt, + m.body_parentid, + m.body_rootid, + m.body_weldid, + m.body_dofnum, + m.body_dofadr, + m.body_invweight0, + m.dof_bodyid, + m.dof_parentid, + m.geom_bodyid, + # Data in (differentiable) + d.subtree_com, + d.cdof, + d.qvel, + d.contact.dist, + d.contact.pos, + d.contact.frame, + d.contact.friction, + d.contact.includemargin, + d.contact.solref, + d.contact.solimp, + # Data in (integer) + d.nacon, + d.contact.efc_address, + d.contact.dim, + d.contact.worldid, + d.contact.geom, + d.contact.type, + d.njmax, + ], + outputs=[ + d.efc.J, + d.efc.pos, + d.efc.D, + d.efc.aref, + d.efc.vel, + ], + ) diff --git a/mujoco_warp/_src/derivative.py b/mujoco_warp/_src/derivative.py index 575a769d2..5deb1efd2 100644 --- a/mujoco_warp/_src/derivative.py +++ b/mujoco_warp/_src/derivative.py @@ -15,6 +15,7 @@ import warp as wp +from mujoco_warp._src.support import next_act from mujoco_warp._src.types import BiasType from mujoco_warp._src.types import Data from mujoco_warp._src.types import DisableBit @@ -30,18 +31,24 @@ @wp.kernel def _qderiv_actuator_passive_vel( # Model: + 
opt_timestep: wp.array(dtype=float), actuator_dyntype: wp.array(dtype=int), actuator_gaintype: wp.array(dtype=int), actuator_biastype: wp.array(dtype=int), actuator_actadr: wp.array(dtype=int), actuator_actnum: wp.array(dtype=int), actuator_forcelimited: wp.array(dtype=bool), + actuator_actlimited: wp.array(dtype=bool), + actuator_dynprm: wp.array2d(dtype=vec10f), actuator_gainprm: wp.array2d(dtype=vec10f), actuator_biasprm: wp.array2d(dtype=vec10f), + actuator_actearly: wp.array(dtype=bool), actuator_forcerange: wp.array2d(dtype=wp.vec2), + actuator_actrange: wp.array2d(dtype=wp.vec2), # Data in: act_in: wp.array2d(dtype=float), ctrl_in: wp.array2d(dtype=float), + act_dot_in: wp.array2d(dtype=float), actuator_force_in: wp.array2d(dtype=float), # Out: vel_out: wp.array2d(dtype=float), @@ -76,9 +83,24 @@ def _qderiv_actuator_passive_vel( vel = float(bias) if actuator_dyntype[actid] != DynType.NONE: if gain != 0.0: - act_first = actuator_actadr[actid] - act_last = act_first + actuator_actnum[actid] - 1 - vel += gain * act_in[worldid, act_last] + act_adr = actuator_actadr[actid] + actuator_actnum[actid] - 1 + + # use next activation if actearly is set (matching forward pass) + if actuator_actearly[actid]: + act = next_act( + opt_timestep[worldid % opt_timestep.shape[0]], + actuator_dyntype[actid], + actuator_dynprm[worldid % actuator_dynprm.shape[0], actid], + actuator_actrange[worldid % actuator_actrange.shape[0], actid], + act_in[worldid, act_adr], + act_dot_in[worldid, act_adr], + 1.0, + actuator_actlimited[actid], + ) + else: + act = act_in[worldid, act_adr] + + vel += gain * act else: if gain != 0.0: vel += gain * ctrl_in[worldid, actid] @@ -95,10 +117,9 @@ def _nonzero_mask(x: float) -> float: @wp.kernel -def _qderiv_actuator_passive_actuation_sparse( +def _qderiv_actuator_passive_actuation_dense( # Model: nu: int, - is_sparse: bool, # Data in: moment_rownnz_in: wp.array2d(dtype=int), moment_rowadr_in: wp.array2d(dtype=int), @@ -142,12 +163,63 @@ def 
_qderiv_actuator_passive_actuation_sparse( qderiv_contrib += moment_i * moment_j * vel - if is_sparse: - qDeriv_out[worldid, 0, elemid] = qderiv_contrib - else: - qDeriv_out[worldid, dofiid, dofjid] = qderiv_contrib - if dofiid != dofjid: - qDeriv_out[worldid, dofjid, dofiid] = qderiv_contrib + qDeriv_out[worldid, dofiid, dofjid] = qderiv_contrib + if dofiid != dofjid: + qDeriv_out[worldid, dofjid, dofiid] = qderiv_contrib + + +@wp.kernel +def _qderiv_actuator_passive_actuation_sparse( + # Model: + M_rownnz: wp.array(dtype=int), + M_rowadr: wp.array(dtype=int), + # Data in: + moment_rownnz_in: wp.array2d(dtype=int), + moment_rowadr_in: wp.array2d(dtype=int), + moment_colind_in: wp.array2d(dtype=int), + actuator_moment_in: wp.array2d(dtype=float), + # In: + vel_in: wp.array2d(dtype=float), + qMj: wp.array(dtype=int), + # Out: + qDeriv_out: wp.array3d(dtype=float), +): + worldid, actid = wp.tid() + + vel = vel_in[worldid, actid] + if vel == 0.0: + return + + rownnz = moment_rownnz_in[worldid, actid] + rowadr = moment_rowadr_in[worldid, actid] + + for i in range(rownnz): + rowadri = rowadr + i + moment_i = actuator_moment_in[worldid, rowadri] + if moment_i == 0.0: + continue + dofi = moment_colind_in[worldid, rowadri] + + for j in range(i + 1): + rowadrj = rowadr + j + moment_j = actuator_moment_in[worldid, rowadrj] + if moment_j == 0.0: + continue + dofj = moment_colind_in[worldid, rowadrj] + + contrib = moment_i * moment_j * vel + + # Search the corresponding elemid + # TODO: This could be precalculated for improved performance + row = dofi + col = dofj + row_startk = M_rowadr[row] - 1 + row_nnz = M_rownnz[row] + for k in range(row_nnz): + row_startk += 1 + if qMj[row_startk] == col: + wp.atomic_add(qDeriv_out[worldid, 0], row_startk, contrib) + break @wp.kernel @@ -268,27 +340,41 @@ def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d(dtype=float)): _qderiv_actuator_passive_vel, dim=(d.nworld, m.nu), inputs=[ + m.opt.timestep, m.actuator_dyntype, 
m.actuator_gaintype, m.actuator_biastype, m.actuator_actadr, m.actuator_actnum, m.actuator_forcelimited, + m.actuator_actlimited, + m.actuator_dynprm, m.actuator_gainprm, m.actuator_biasprm, + m.actuator_actearly, m.actuator_forcerange, + m.actuator_actrange, d.act, d.ctrl, + d.act_dot, d.actuator_force, ], outputs=[vel], ) - wp.launch( - _qderiv_actuator_passive_actuation_sparse, - dim=(d.nworld, qMi.size), - inputs=[m.nu, m.is_sparse, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment, vel, qMi, qMj], - outputs=[out], - ) + if m.is_sparse: + wp.launch( + _qderiv_actuator_passive_actuation_sparse, + dim=(d.nworld, m.nu), + inputs=[m.M_rownnz, m.M_rowadr, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment, vel, qMj], + outputs=[out], + ) + else: + wp.launch( + _qderiv_actuator_passive_actuation_dense, + dim=(d.nworld, qMi.size), + inputs=[m.nu, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment, vel, qMi, qMj], + outputs=[out], + ) wp.launch( _qderiv_actuator_passive, dim=(d.nworld, qMi.size), diff --git a/mujoco_warp/_src/derivative_test.py b/mujoco_warp/_src/derivative_test.py index da4d4b3bc..cc4746d29 100644 --- a/mujoco_warp/_src/derivative_test.py +++ b/mujoco_warp/_src/derivative_test.py @@ -209,6 +209,261 @@ def test_step_tendon_serial_chain_no_nan(self): self.assertFalse(np.any(np.isnan(mjd.qpos))) self.assertFalse(np.any(np.isnan(mjd.qvel))) + def test_smooth_vel_sparse_tendon_coupled(self): + """Tests qDeriv kernel with nv > 32 and moment_rownnz > 1. + + Builds a chain of 35 DOFs (forcing sparse path) with a fixed tendon + coupling two joints, producing an actuator with moment_rownnz=2. 
+ """ + # Build a chain long enough to force sparse (nv > 32) + xml = f""" + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + + mjm, mjd, m, d = test_data.fixture( + xml=xml, + keyframe=0, + overrides={"opt.jacobian": mujoco.mjtJacobian.mjJAC_SPARSE}, + ) + + self.assertTrue(m.is_sparse, "Model should use sparse path (nv > 32)") + + mujoco.mj_step(mjm, mjd) + + out_smooth_vel = wp.zeros((1, 1, m.nM), dtype=float) + mjw.deriv_smooth_vel(m, d, out_smooth_vel) + + mjw_out = np.zeros((m.nv, m.nv)) + for elem, (i, j) in enumerate(zip(m.qM_fullm_i.numpy(), m.qM_fullm_j.numpy())): + mjw_out[i, j] = out_smooth_vel.numpy()[0, 0, elem] + mjw_out[j, i] = out_smooth_vel.numpy()[0, 0, elem] + + mj_qDeriv = np.zeros((mjm.nv, mjm.nv)) + mujoco.mju_sparse2dense(mj_qDeriv, mjd.qDeriv, mjm.D_rownnz, mjm.D_rowadr, mjm.D_colind) + + mj_qM = np.zeros((m.nv, m.nv)) + mujoco.mj_fullM(mjm, mj_qM, mjd.qM) + mj_out = mj_qM - mjm.opt.timestep * mj_qDeriv + + self.assertFalse(np.any(np.isnan(mjw_out))) + _assert_eq(mjw_out, mj_out, "qM - dt * qDeriv (sparse tendon coupled)") + + def test_actearly_derivative(self): + """Implicit derivatives should use next activation when actearly is set.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + """, + keyframe=0, + ) + + # both should have same act_dot (ctrl = 1 for integrator dynamics) + _assert_eq(d.act_dot.numpy()[0, 0], d.act_dot.numpy()[0, 1], "act_dot") + + # compute qDeriv using deriv_smooth_vel + out_smooth_vel = wp.zeros(d.qM.shape, dtype=float) + mjw.deriv_smooth_vel(m, d, out_smooth_vel) + mjw_out = out_smooth_vel.numpy()[0, : m.nv, : m.nv] + + # with actearly=true and nonzero act_dot, derivative should differ + # because actearly uses next activation: act + 
act_dot*dt + # for our model: next_act = 0 + 1*1 = 1, current_act = 0 + # derivative adds gain_vel * act to qDeriv diagonal + # qDeriv = qM - dt * actuator_vel_derivative + # for independent bodies with mass=1: qM diagonal = 1.0 + # actearly=true: vel = gain_vel * next_act = 1 * 1 = 1, out = 1 - 1*1 = 0 + # actearly=false: vel = gain_vel * current_act = 1 * 0 = 0, out = 1 - 1*0 = 1 + self.assertNotAlmostEqual( + mjw_out[0, 0], + mjw_out[1, 1], + msg="actearly=true should use next activation in derivative", + ) + _assert_eq(mjw_out[0, 0], 0.0, "actearly=true: qM - dt*gain_vel*next_act = 1 - 1*1 = 0") + _assert_eq(mjw_out[1, 1], 1.0, "actearly=false: qM - dt*gain_vel*current_act = 1 - 1*0 = 1") + def test_forcerange_clamped_derivative(self): """Implicit integration is more accurate than Euler with active forcerange clamping.""" xml = """ diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 67114e07d..e9df483c7 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -18,6 +18,7 @@ import warp as wp from mujoco_warp._src import collision_driver +from mujoco_warp._src import collision_smooth from mujoco_warp._src import constraint from mujoco_warp._src import derivative from mujoco_warp._src import island @@ -27,6 +28,7 @@ from mujoco_warp._src import smooth from mujoco_warp._src import solver from mujoco_warp._src import util_misc +from mujoco_warp._src.support import next_act from mujoco_warp._src.support import xfrc_accumulate from mujoco_warp._src.types import MJ_MINVAL from mujoco_warp._src.types import BiasType @@ -128,37 +130,6 @@ def _next_velocity( qvel_out[worldid, dofid] = qvel_in[worldid, dofid] + qacc_scale_in * qacc_in[worldid, dofid] * timestep -# TODO(team): kernel analyzer array slice? 
-@wp.func -def _next_act( - # Model: - opt_timestep: float, # kernel_analyzer: ignore - actuator_dyntype: int, # kernel_analyzer: ignore - actuator_dynprm: vec10f, # kernel_analyzer: ignore - actuator_actrange: wp.vec2, # kernel_analyzer: ignore - # Data In: - act_in: float, # kernel_analyzer: ignore - act_dot_in: float, # kernel_analyzer: ignore - # In: - act_dot_scale: float, - clamp: bool, -) -> float: - # advance actuation - if actuator_dyntype == DynType.FILTEREXACT: - tau = wp.max(MJ_MINVAL, actuator_dynprm[0]) - act = act_in + act_dot_scale * act_dot_in * tau * (1.0 - wp.exp(-opt_timestep / tau)) - elif actuator_dyntype == DynType.USER: - return act_in - else: - act = act_in + act_dot_scale * act_dot_in * opt_timestep - - # clamp to actrange - if clamp: - act = wp.clamp(act, actuator_actrange[0], actuator_actrange[1]) - - return act - - @wp.kernel def _next_activation( # Model: @@ -185,7 +156,7 @@ def _next_activation( actadr = actuator_actadr[uid] actnum = actuator_actnum[uid] for j in range(actadr, actadr + actnum): - act = _next_act( + act = next_act( opt_timestep[opt_timestep_id], actuator_dyntype[uid], actuator_dynprm[actuator_dynprm_id, uid], @@ -365,11 +336,12 @@ def euler(m: Model, d: Data): """Euler integrator, semi-implicit in velocity.""" # integrate damping implicitly if not (m.opt.disableflags & (DisableBit.EULERDAMP | DisableBit.DAMPER)): - qacc = wp.empty((d.nworld, m.nv), dtype=float) + ad_active = d.qpos.requires_grad + qacc = wp.empty((d.nworld, m.nv), dtype=float, requires_grad=ad_active) if m.is_sparse: qM = wp.clone(d.qM) - qLD = wp.empty((d.nworld, 1, m.nC), dtype=float) - qLDiagInv = wp.empty((d.nworld, m.nv), dtype=float) + qLD = wp.empty((d.nworld, 1, m.nC), dtype=float, requires_grad=ad_active) + qLDiagInv = wp.empty((d.nworld, m.nv), dtype=float, requires_grad=ad_active) wp.launch( _euler_damp_qfrc_sparse, dim=(d.nworld, m.nv), @@ -386,8 +358,10 @@ def euler(m: Model, d: Data): outputs=[qacc], block_dim=m.block_dim.euler_dense, ) + 
_record_solver_adjoint(m, d, qacc_array=qacc) _advance(m, d, qacc) else: + _record_solver_adjoint(m, d, qacc_array=d.qacc) _advance(m, d, d.qacc) @@ -498,14 +472,15 @@ def rungekutta4(m: Model, d: Data): A = [0.5, 0.5, 1.0] # diagonal only B = [1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0] + ad_active = d.qpos.requires_grad qpos_t0 = wp.clone(d.qpos) qvel_t0 = wp.clone(d.qvel) - qvel_rk = wp.zeros((d.nworld, m.nv), dtype=float) - qacc_rk = wp.zeros((d.nworld, m.nv), dtype=float) + qvel_rk = wp.zeros((d.nworld, m.nv), dtype=float, requires_grad=ad_active) + qacc_rk = wp.zeros((d.nworld, m.nv), dtype=float, requires_grad=ad_active) if m.na: act_t0 = wp.clone(d.act) - act_dot_rk = wp.zeros((d.nworld, m.na), dtype=float) + act_dot_rk = wp.zeros((d.nworld, m.na), dtype=float, requires_grad=ad_active) else: act_t0 = None act_dot_rk = None @@ -525,6 +500,7 @@ def rungekutta4(m: Model, d: Data): wp.copy(d.act, act_t0) wp.copy(d.act_dot, act_dot_rk) + _record_solver_adjoint(m, d, qacc_array=qacc_rk) _advance(m, d, qacc_rk, qvel_rk) @@ -532,6 +508,7 @@ def rungekutta4(m: Model, d: Data): def implicit(m: Model, d: Data): """Integrates fully implicit in velocity.""" if ~(m.opt.disableflags | ~(DisableBit.ACTUATION | DisableBit.SPRING | DisableBit.DAMPER)): + ad_active = d.qpos.requires_grad if m.is_sparse: qDeriv = wp.empty((d.nworld, 1, m.nM), dtype=float) qLD = wp.empty((d.nworld, 1, m.nC), dtype=float) @@ -540,10 +517,12 @@ def implicit(m: Model, d: Data): qLD = wp.empty(d.qM.shape, dtype=float) qLDiagInv = wp.empty((d.nworld, m.nv), dtype=float) derivative.deriv_smooth_vel(m, d, qDeriv) - qacc = wp.empty((d.nworld, m.nv), dtype=float) + qacc = wp.empty((d.nworld, m.nv), dtype=float, requires_grad=ad_active) smooth.factor_solve_i(m, d, qDeriv, qLD, qLDiagInv, qacc, d.efc.Ma) + _record_solver_adjoint(m, d, qacc_array=qacc) _advance(m, d, qacc) else: + _record_solver_adjoint(m, d, qacc_array=d.qacc) _advance(m, d, d.qacc) @@ -567,7 +546,15 @@ def fwd_position(m: Model, d: Data, 
factorize: bool = True): smooth.factor_m(m, d) if m.opt.run_collision_detection: collision_driver.collision(m, d) + # Phase 3: smooth collision recomputation for AD + tape = wp._src.context.runtime.tape + if tape is not None and d.qpos.requires_grad: + collision_smooth.smooth_recompute_contacts(m, d) constraint.make_constraint(m, d) + # Phase 3: differentiable constraint assembly for AD + tape = wp._src.context.runtime.tape + if tape is not None and d.qpos.requires_grad: + collision_smooth.smooth_contact_to_efc(m, d) # TODO(team): remove False after island features are more complete if False and not (m.opt.disableflags & DisableBit.ISLAND): island.island(m, d) @@ -720,7 +707,7 @@ def _actuator_force( if dyntype == DynType.INTEGRATOR or dyntype == DynType.NONE: act = act_in[worldid, act_last] - ctrl_act = _next_act( + ctrl_act = next_act( opt_timestep[worldid % opt_timestep.shape[0]], dyntype, dynprm, @@ -950,11 +937,7 @@ def fwd_actuation(m: Model, d: Data): ) # clone to break input/output aliasing for correct AD; skip when not # recording a backward tape to avoid unnecessary allocation + copy. - qfrc_actuator_in = ( - wp.clone(d.qfrc_actuator) - if d.qfrc_actuator.requires_grad - else d.qfrc_actuator - ) + qfrc_actuator_in = wp.clone(d.qfrc_actuator) if d.qfrc_actuator.requires_grad else d.qfrc_actuator wp.launch( _qfrc_actuator_gravcomp_limits, dim=(d.nworld, m.nv), @@ -1012,10 +995,92 @@ def fwd_acceleration(m: Model, d: Data, factorize: bool = False): else: smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth) + # Custom adjoint for M_inv solve on the dense path. + # The tile Cholesky kernels have enable_backward=False, so the tape cannot + # propagate qacc_smooth.grad -> qfrc_smooth.grad automatically. We record + # a callback that performs the VJP: qfrc_smooth.grad += M_inv * qacc_smooth.grad + # (M is symmetric so M_inv^T = M_inv). 
+ _record_fwd_accel_adjoint(m, d) + + +def _record_fwd_accel_adjoint(m: Model, d: Data): + """Record custom adjoint for the M_inv solve in fwd_acceleration. + + On the dense path, _tile_cholesky_factorize_solve has enable_backward=False. + This record_func propagates qacc_smooth.grad -> qfrc_smooth.grad via M_inv, + using the already-factored d.qLD from the forward pass. + + Array references are captured at record time (not through d) so that + intermediate array cloning between substeps routes each substep's adjoint + to the correct .grad memory. + """ + tape = wp._src.context.runtime.tape + if tape is not None and d.qpos.requires_grad and not m.is_sparse: + from mujoco_warp._src.adjoint import _accumulate_grad_kernel + + # Capture current array refs for correct gradient isolation across substeps + qacc_smooth_ref = d.qacc_smooth + qfrc_smooth_ref = d.qfrc_smooth + + def _adjoint(m=m, d=d, qacc_smooth=qacc_smooth_ref, qfrc_smooth=qfrc_smooth_ref): + adj_qacc_smooth = qacc_smooth.grad + if adj_qacc_smooth is None: + return + # qfrc_smooth.grad += M_inv * qacc_smooth.grad + tmp = wp.zeros_like(qfrc_smooth) + smooth.solve_m(m, d, tmp, adj_qacc_smooth) + if qfrc_smooth.grad is None: + qfrc_smooth.grad = tmp + else: + wp.launch( + _accumulate_grad_kernel, + dim=(d.nworld, m.nv), + inputs=[tmp], + outputs=[qfrc_smooth.grad], + ) + + tape.record_func(_adjoint, [qacc_smooth_ref, qfrc_smooth_ref]) + + +def _record_solver_adjoint(m: Model, d: Data, qacc_array=None): + """Record the solver implicit differentiation adjoint on the active tape. + + Args: + qacc_array: The array whose .grad will receive the incoming adjoint from + the integrator backward. Defaults to d.qacc (correct when + the integrator uses d.qacc directly, e.g. eulerdamp disabled). + Integrators that create a local qacc must pass it here. + + Array references are captured at record time so that intermediate array + cloning between substeps routes each substep's adjoint correctly. 
+ """ + tape = wp._src.context.runtime.tape + if tape is not None and d.qpos.requires_grad: + from mujoco_warp._src.adjoint import solver_implicit_adjoint + + if qacc_array is None: + qacc_array = d.qacc + + # Capture qacc_smooth ref at record time for gradient isolation + qacc_smooth_ref = d.qacc_smooth + + tape.record_func( + lambda m=m, d=d, qa=qacc_array, qs=qacc_smooth_ref: solver_implicit_adjoint( + m, d, qacc_array=qa, qacc_smooth_ref=qs + ), + [qacc_array, qacc_smooth_ref], + ) + @event_scope -def forward(m: Model, d: Data): - """Forward dynamics.""" +def forward(m: Model, d: Data, record_solver_adjoint: bool = True): + """Forward dynamics. + + Args: + record_solver_adjoint: If True, record the solver implicit differentiation + adjoint on the tape. Set to False when called from step() since the + integrator records its own adjoint at the correct tape position. + """ energy = m.opt.enableflags & EnableBit.ENERGY fwd_position(m, d, factorize=False) @@ -1042,25 +1107,88 @@ def forward(m: Model, d: Data): solver.solve(m, d) - # Record implicit differentiation adjoint on the active tape - tape = wp._src.context.runtime.tape - if tape is not None and d.qpos.requires_grad: - from mujoco_warp._src.adjoint import solver_implicit_adjoint - - tape.record_func( - lambda m=m, d=d: solver_implicit_adjoint(m, d), - [d.qacc, d.qacc_smooth], - ) + # Record implicit differentiation adjoint on the active tape. + # When called from step(), the integrator handles this instead (at the + # correct tape position between factor_solve_i and _advance). + if record_solver_adjoint: + _record_solver_adjoint(m, d) sensor.sensor_acc(m, d) +def _isolate_intermediates_for_ad(m: Model, d: Data): + """Allocate fresh intermediate arrays for per-substep gradient isolation. + + In tape-all mode (single wp.Tape over multiple step() calls), intermediate + arrays like qfrc_smooth and qacc_smooth are overwritten each substep but + share a single .grad array. 
This causes backward to accumulate adjoint + contributions from ALL substeps into the same memory (~250,000x amplification + for 16 substeps). + + By allocating fresh arrays at the start of each step(), each substep writes + to its own memory. The tape records operations on these unique arrays, and + backward routes each substep's adjoint to the correct .grad memory. + + Only called when AD is active (d.qpos.requires_grad). + """ + nw = d.nworld + nv = m.nv + nu = m.nu + + # --- Force arrays --- + d.qfrc_smooth = wp.zeros((nw, nv), dtype=float, requires_grad=True) + d.qacc_smooth = wp.zeros((nw, nv), dtype=float, requires_grad=True) + d.qfrc_actuator = wp.zeros((nw, nv), dtype=float, requires_grad=True) + d.actuator_force = wp.zeros((nw, nu), dtype=float, requires_grad=True) + d.qacc = wp.zeros((nw, nv), dtype=float, requires_grad=True) + d.qfrc_bias = wp.zeros((nw, nv), dtype=float, requires_grad=True) + d.qfrc_passive = wp.zeros((nw, nv), dtype=float, requires_grad=True) + + # --- Kinematics arrays --- + # These use Warp vector/matrix dtypes (vec3, mat33, etc.) so use + # zeros_like to match the exact dtype and shape from the existing arrays. + d.xpos = wp.zeros_like(d.xpos, requires_grad=True) + d.xmat = wp.zeros_like(d.xmat, requires_grad=True) + d.xipos = wp.zeros_like(d.xipos, requires_grad=True) + d.ximat = wp.zeros_like(d.ximat, requires_grad=True) + d.subtree_com = wp.zeros_like(d.subtree_com, requires_grad=True) + d.cinert = wp.zeros_like(d.cinert, requires_grad=True) + d.cdof = wp.zeros_like(d.cdof, requires_grad=True) + d.cdof_dot = wp.zeros_like(d.cdof_dot, requires_grad=True) + d.cvel = wp.zeros_like(d.cvel, requires_grad=True) + d.crb = wp.zeros_like(d.crb, requires_grad=True) + d.cacc = wp.zeros_like(d.cacc, requires_grad=True) + + # --- Mass matrix --- + # Shapes depend on sparse vs dense; zeros_like handles both. 
+ d.qM = wp.zeros_like(d.qM, requires_grad=True) + d.qLD = wp.zeros_like(d.qLD, requires_grad=True) + d.qLDiagInv = wp.zeros((nw, nv), dtype=float, requires_grad=True) + + # --- Geometry / joint kinematics --- + d.geom_xpos = wp.zeros_like(d.geom_xpos, requires_grad=True) + d.geom_xmat = wp.zeros_like(d.geom_xmat, requires_grad=True) + d.xanchor = wp.zeros_like(d.xanchor, requires_grad=True) + d.xaxis = wp.zeros_like(d.xaxis, requires_grad=True) + d.subtree_linvel = wp.zeros_like(d.subtree_linvel, requires_grad=True) + d.subtree_angmom = wp.zeros_like(d.subtree_angmom, requires_grad=True) + + # --- Actuator arrays --- + d.actuator_velocity = wp.zeros((nw, nu), dtype=float, requires_grad=True) + + @event_scope def step(m: Model, d: Data): """Advance simulation.""" # TODO(team): mj_checkPos # TODO(team): mj_checkVel - forward(m, d) + + # Allocate fresh intermediate arrays when AD is active to prevent + # cross-substep gradient accumulation in tape-all mode. + if d.qpos.requires_grad: + _isolate_intermediates_for_ad(m, d) + + forward(m, d, record_solver_adjoint=False) # TODO(team): mj_checkAcc if m.opt.integrator == IntegratorType.EULER: @@ -1108,15 +1236,8 @@ def step2(m: Model, d: Data): fwd_acceleration(m, d) solver.solve(m, d) - # Record implicit differentiation adjoint on the active tape - tape = wp._src.context.runtime.tape - if tape is not None and d.qpos.requires_grad: - from mujoco_warp._src.adjoint import solver_implicit_adjoint - - tape.record_func( - lambda m=m, d=d: solver_implicit_adjoint(m, d), - [d.qacc, d.qacc_smooth], - ) + # The solver adjoint record_func is handled by the integrator below, + # NOT here — see euler()/implicit() for details. 
sensor.sensor_acc(m, d) # TODO(team): mj_checkAcc diff --git a/mujoco_warp/_src/forward_test.py b/mujoco_warp/_src/forward_test.py index 3b6a9828b..a8b484621 100644 --- a/mujoco_warp/_src/forward_test.py +++ b/mujoco_warp/_src/forward_test.py @@ -649,6 +649,12 @@ def oscillator(m, d): np.testing.assert_allclose(d.act.numpy()[0, 0], np.cos(2 * np.pi * frequency * t_next), atol=1e-3) np.testing.assert_allclose(d.act.numpy()[0, 1], np.sin(2 * np.pi * frequency * t_next), atol=1e-3) + def test_multiflex(self): + """Tests multiflex model with different flex dimensions.""" + _, _, m, d = test_data.fixture("flex/multiflex.xml") + + mjw.forward(m, d) + if __name__ == "__main__": wp.init() diff --git a/mujoco_warp/_src/grad.py b/mujoco_warp/_src/grad.py index e30fd0f30..c3dc30477 100644 --- a/mujoco_warp/_src/grad.py +++ b/mujoco_warp/_src/grad.py @@ -93,17 +93,37 @@ "sensordata", ) -SOLVER_GRAD_FIELDS: tuple = ( - "qfrc_constraint", +SOLVER_GRAD_FIELDS: tuple = ("qfrc_constraint",) + +COLLISION_GRAD_FIELDS: tuple = ( + # Contact geometry (written by smooth_recompute_contacts) + "contact.dist", + "contact.pos", + "contact.frame", + # Constraint arrays (written by smooth_contact_to_efc) + "efc.J", + "efc.pos", + "efc.D", + "efc.aref", + "efc.vel", ) +def _resolve_field(d: Data, name: str): + """Resolve a field name, supporting dotted paths like 'contact.dist'.""" + if "." 
in name: + obj_name, field_name = name.split(".", 1) + obj = getattr(d, obj_name, None) + return getattr(obj, field_name, None) if obj else None + return getattr(d, name, None) + + def enable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None: """Enables gradient tracking on Data arrays.""" if fields is None: fields = SMOOTH_GRAD_FIELDS for name in fields: - arr = getattr(d, name, None) + arr = _resolve_field(d, name) if arr is not None and isinstance(arr, wp.array): arr.requires_grad = True @@ -111,7 +131,7 @@ def enable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None: def disable_grad(d: Data) -> None: """Disables gradient tracking on all Data arrays.""" for name in SMOOTH_GRAD_FIELDS: - arr = getattr(d, name, None) + arr = _resolve_field(d, name) if arr is not None and isinstance(arr, wp.array): arr.requires_grad = False @@ -132,8 +152,7 @@ def _warn_if_cg_solver(m: Model, d: Data): """Warn if CG solver is used with constraints (gradients will be zero).""" if d.njmax > 0 and m.opt.solver != SolverType.NEWTON: warnings.warn( - "Differentiable solver requires Newton. CG solver " - "gradients through constraints will be zero.", + "Differentiable solver requires Newton. CG solver gradients through constraints will be zero.", stacklevel=3, ) diff --git a/mujoco_warp/_src/grad_test.py b/mujoco_warp/_src/grad_test.py index 5ad2b85c5..c799b660e 100644 --- a/mujoco_warp/_src/grad_test.py +++ b/mujoco_warp/_src/grad_test.py @@ -20,6 +20,7 @@ import mujoco_warp as mjw from mujoco_warp import test_data from mujoco_warp._src import math +from mujoco_warp._src.grad import _resolve_field from mujoco_warp._src.grad import enable_grad # tolerance for AD vs finite-difference comparison @@ -120,6 +121,31 @@ """ +# Freejoint root + hinge child with actuator, for full step gradient test. +_FREE_HINGE_XML = """ + + + + + + + + + + + + + + + + + + + +""" + def _fd_gradient(fn, x_np, eps=1e-3): """Central-difference gradient of scalar fn w.r.t. 
x_np.""" @@ -160,6 +186,7 @@ class GradSmoothTest(parameterized.TestCase): @parameterized.parameters( ("hinge", _SIMPLE_HINGE_XML), ("slide", _SIMPLE_SLIDE_XML), + ("free", _SIMPLE_FREE_XML), ) def test_kinematics_grad(self, name, xml): """dL/dqpos through kinematics(): loss = sum(xpos).""" @@ -368,6 +395,53 @@ def eval_loss(ctrl_np): err_msg="euler step grad mismatch", ) + @absltest.skipIf( + wp.get_device().is_cuda and wp.get_device().arch < 70, + "tile kernels (cuSolverDx) require sm_70+", + ) + def test_euler_step_grad_free(self): + """Full Euler step gradient for freejoint + hinge model: dL/dctrl.""" + xml = _FREE_HINGE_XML + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.step(m, d) + wp.launch( + _sum_xpos_kernel, + dim=(d.nworld, m.nbody), + inputs=[d.xpos, loss], + ) + tape.backward(loss=loss) + ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() + tape.zero() + + def eval_loss(ctrl_np): + _, _, _, d_fd = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d_fd) + d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) + mjw.step(m, d_fd) + l = wp.zeros(1, dtype=float) + wp.launch( + _sum_xpos_kernel, + dim=(d_fd.nworld, m.nbody), + inputs=[d_fd.xpos, l], + ) + return l.numpy()[0] + + ctrl_np = mjd.ctrl.copy() + fd_grad = _fd_gradient(eval_loss, ctrl_np) + + np.testing.assert_allclose( + ad_grad, + fd_grad, + atol=_FD_TOL, + rtol=_FD_TOL, + err_msg="euler step grad (freejoint+hinge) mismatch", + ) + @wp.kernel def _quat_integrate_kernel( @@ -813,6 +887,288 @@ def test_make_diff_data_custom_fields(self): self.assertFalse(d.qvel.requires_grad) self.assertFalse(d.ctrl.requires_grad) + def test_enable_backward_module_flags(self): + """Verify enable_backward is set correctly on all AD-relevant modules.""" + from mujoco_warp._src import collision_smooth + from mujoco_warp._src import derivative + from mujoco_warp._src import forward as 
forward_mod + from mujoco_warp._src import passive + from mujoco_warp._src import smooth + + # Modules that SHOULD have enable_backward=True + for mod in [smooth, forward_mod, passive, derivative, collision_smooth]: + opts = wp.get_module_options(mod) + self.assertTrue( + opts.get("enable_backward", False), + f"{mod.__name__} should have enable_backward=True", + ) + + # Modules that should NOT have enable_backward + from mujoco_warp._src import collision_driver + from mujoco_warp._src import constraint + from mujoco_warp._src import solver + + for mod in [constraint, solver, collision_driver]: + opts = wp.get_module_options(mod) + self.assertFalse( + opts.get("enable_backward", False), + f"{mod.__name__} should have enable_backward=False", + ) + + def test_enable_grad_all_smooth_fields(self): + """All SMOOTH_GRAD_FIELDS are toggled by enable_grad.""" + mjm = mujoco.MjModel.from_xml_string(_SIMPLE_HINGE_XML) + d = mjw.make_data(mjm) + + mjw.enable_grad(d) + for name in mjw.SMOOTH_GRAD_FIELDS: + arr = _resolve_field(d, name) + if arr is not None and isinstance(arr, wp.array): + self.assertTrue( + arr.requires_grad, + f"SMOOTH_GRAD_FIELDS field '{name}' not enabled by enable_grad", + ) + + mjw.disable_grad(d) + for name in mjw.SMOOTH_GRAD_FIELDS: + arr = _resolve_field(d, name) + if arr is not None and isinstance(arr, wp.array): + self.assertFalse( + arr.requires_grad, + f"SMOOTH_GRAD_FIELDS field '{name}' not disabled by disable_grad", + ) + + def test_forward_without_grad_no_error(self): + """Forward pipeline without enable_grad works (no errors, no gradients).""" + mjm, mjd, m, d = test_data.fixture(xml=_SIMPLE_HINGE_XML, keyframe=0) + # Do NOT call enable_grad + mjw.kinematics(m, d) + mjw.com_pos(m, d) + mjw.crb(m, d) + + # Verify no requires_grad is set + self.assertFalse(d.qpos.requires_grad) + self.assertFalse(d.xpos.requires_grad) + + def test_diff_step_produces_nonzero_gradients(self): + """diff_step with enable_grad produces nonzero gradients.""" + mjm, mjd, 
m, d = test_data.fixture(xml=_SIMPLE_HINGE_XML, keyframe=0) + enable_grad(d) + + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.kinematics(m, d) + mjw.com_pos(m, d) + wp.launch( + _sum_xpos_kernel, + dim=(d.nworld, m.nbody), + inputs=[d.xpos, loss], + ) + tape.backward(loss=loss) + + ad_grad = d.qpos.grad.numpy()[0, : mjm.nq] + # With a non-zero keyframe, kinematics gradients should be nonzero + self.assertTrue( + np.any(np.abs(ad_grad) > 1e-6), + "enable_grad + tape should produce nonzero gradients", + ) + + +# ---- Test models for integrator gradient path ---- + +_HINGE_EULERDAMP_DISABLED_XML = """ + + + + + + + + + + + + + + + + + + + + +""" + +_HINGE_EULERDAMP_ENABLED_XML = """ + + +""" + + +class GradIntegratorTest(parameterized.TestCase): + """Tests that exercise the gradient path through the integrator. + + Unlike test_euler_step_grad (which uses loss on xpos and bypasses the + integrator), these tests use loss on qpos after step(), verifying that + gradients flow through: ctrl -> actuation -> acceleration -> solver adjoint + -> integrator -> qpos. 
+ """ + + @absltest.skipIf( + wp.get_device().is_cuda and wp.get_device().arch < 70, + "tile kernels (cuSolverDx) require sm_70+", + ) + def test_euler_qpos_grad_no_eulerdamp(self): + """dL/dctrl through step() measured on qpos, eulerdamp disabled.""" + xml = _HINGE_EULERDAMP_DISABLED_XML + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + # AD gradient + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.step(m, d) + wp.launch( + _sum_qpos_kernel, + dim=(d.nworld, mjm.nq), + inputs=[d.qpos, loss], + ) + tape.backward(loss=loss) + ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() + tape.zero() + + # FD gradient + def eval_loss(ctrl_np): + _, _, _, d_fd = test_data.fixture(xml=xml, keyframe=0) + d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) + mjw.step(m, d_fd) + l = wp.zeros(1, dtype=float) + wp.launch( + _sum_qpos_kernel, + dim=(d_fd.nworld, mjm.nq), + inputs=[d_fd.qpos, l], + ) + return l.numpy()[0] + + ctrl_np = mjd.ctrl.copy() + fd_grad = _fd_gradient(eval_loss, ctrl_np, eps=1e-3) + + self.assertTrue( + np.linalg.norm(ad_grad) > 1e-6, + f"AD gradient should be nonzero, got |grad|={np.linalg.norm(ad_grad):.3e}", + ) + np.testing.assert_allclose( + ad_grad, fd_grad, atol=_FD_TOL, rtol=_FD_TOL, + err_msg="AD vs FD mismatch for dL(qpos)/dctrl (eulerdamp disabled)", + ) + + @absltest.skipIf( + wp.get_device().is_cuda and wp.get_device().arch < 70, + "tile kernels (cuSolverDx) require sm_70+", + ) + def test_euler_qpos_grad_with_eulerdamp(self): + """dL/dctrl through step() measured on qpos, eulerdamp enabled.""" + xml = _HINGE_EULERDAMP_ENABLED_XML + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + # AD gradient + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.step(m, d) + wp.launch( + _sum_qpos_kernel, + dim=(d.nworld, mjm.nq), + inputs=[d.qpos, loss], + ) + tape.backward(loss=loss) + ad_grad = d.ctrl.grad.numpy()[0, : 
mjm.nu].copy() + tape.zero() + + # FD gradient + def eval_loss(ctrl_np): + _, _, _, d_fd = test_data.fixture(xml=xml, keyframe=0) + d_fd.ctrl = wp.array(ctrl_np.reshape(1, -1), dtype=float) + mjw.step(m, d_fd) + l = wp.zeros(1, dtype=float) + wp.launch( + _sum_qpos_kernel, + dim=(d_fd.nworld, mjm.nq), + inputs=[d_fd.qpos, l], + ) + return l.numpy()[0] + + ctrl_np = mjd.ctrl.copy() + fd_grad = _fd_gradient(eval_loss, ctrl_np, eps=1e-3) + + self.assertTrue( + np.linalg.norm(ad_grad) > 1e-6, + f"AD gradient should be nonzero, got |grad|={np.linalg.norm(ad_grad):.3e}", + ) + np.testing.assert_allclose( + ad_grad, fd_grad, atol=_FD_TOL, rtol=_FD_TOL, + err_msg="AD vs FD mismatch for dL(qpos)/dctrl (eulerdamp enabled)", + ) + + @absltest.skipIf( + wp.get_device().is_cuda and wp.get_device().arch < 70, + "tile kernels (cuSolverDx) require sm_70+", + ) + @absltest.skipIf( + wp.get_device().is_cuda and wp.get_device().arch < 70, + "tile kernels (cuSolverDx) require sm_70+", + ) + def test_multistep_qpos_grad_nonzero(self): + """dL/dctrl through 2 steps produces nonzero gradient.""" + xml = _HINGE_EULERDAMP_DISABLED_XML + mjm, mjd, m, d = test_data.fixture(xml=xml, keyframe=0) + enable_grad(d) + + loss = wp.zeros(1, dtype=float, requires_grad=True) + tape = wp.Tape() + with tape: + mjw.step(m, d) + mjw.step(m, d) + wp.launch( + _sum_qpos_kernel, + dim=(d.nworld, mjm.nq), + inputs=[d.qpos, loss], + ) + tape.backward(loss=loss) + ad_grad = d.ctrl.grad.numpy()[0, : mjm.nu].copy() + tape.zero() + + # Multi-step AD vs FD accuracy is limited by shared-array accumulation + # across steps (a known Warp tape limitation). Here we just verify the + # gradient is nonzero — single-step FD accuracy is tested above. 
+ self.assertTrue( + np.linalg.norm(ad_grad) > 1e-6, + f"Multi-step AD gradient should be nonzero, got |grad|={np.linalg.norm(ad_grad):.3e}", + ) + if __name__ == "__main__": absltest.main() diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index aca9c94fe..c1f29b6f6 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -669,6 +669,167 @@ def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) -> return int(valid_sizes[np.searchsorted(valid_sizes, njmax)]) +def _body_pair_nnz(mjm: mujoco.MjModel, body1: int, body2: int) -> int: + """Returns the number of unique DOFs in the kinematic tree union of two bodies.""" + body1 = mjm.body_weldid[body1] + body2 = mjm.body_weldid[body2] + da1 = mjm.body_dofadr[body1] + mjm.body_dofnum[body1] - 1 + da2 = mjm.body_dofadr[body2] + mjm.body_dofnum[body2] - 1 + nnz = 0 + while da1 >= 0 or da2 >= 0: + da = max(da1, da2) + if da1 == da: + da1 = mjm.dof_parentid[da1] + if da2 == da: + da2 = mjm.dof_parentid[da2] + nnz += 1 + return nnz + + +def _default_njmax_nnz(mjm: mujoco.MjModel, nconmax: int, njmax: int) -> int: + """Returns a heuristic estimate for the number of non-zeros in the sparse constraint Jacobian. + + Assumes all equality, friction, and limit constraints are active and computes + their non-zeros. For contacts, assumes njmax contact rows at the maximum + body-pair non-zeros from all enabled collision pairs. + + Args: + mjm: The model containing kinematic and dynamic information (host). + nconmax: Maximum number of contacts per world. + njmax: Maximum number of constraint rows per world. + + Returns: + Estimated number of non-zeros in the constraint Jacobian. 
+ """ + total_nnz = 0 + + def _eq_bodies(i): + """Returns body pair for equality constraint i.""" + obj1id, obj2id = mjm.eq_obj1id[i], mjm.eq_obj2id[i] + if mjm.eq_objtype[i] == mujoco.mjtObj.mjOBJ_SITE: + return mjm.site_bodyid[obj1id], mjm.site_bodyid[obj2id] + return obj1id, obj2id + + # equality constraints (assume all active) + for i in range(mjm.neq): + eq_type = mjm.eq_type[i] + + if eq_type == mujoco.mjtEq.mjEQ_CONNECT: + total_nnz += 3 * _body_pair_nnz(mjm, *_eq_bodies(i)) + + elif eq_type == mujoco.mjtEq.mjEQ_WELD: + total_nnz += 6 * _body_pair_nnz(mjm, *_eq_bodies(i)) + + elif eq_type == mujoco.mjtEq.mjEQ_JOINT: + total_nnz += 2 if mjm.eq_obj2id[i] >= 0 else 1 + + elif eq_type == mujoco.mjtEq.mjEQ_TENDON: + obj1id = mjm.eq_obj1id[i] + obj2id = mjm.eq_obj2id[i] + rownnz1 = mjm.ten_J_rownnz[obj1id] if obj1id < mjm.ntendon else 0 + if obj2id >= 0 and obj2id < mjm.ntendon: + rowadr1 = mjm.ten_J_rowadr[obj1id] + rowadr2 = mjm.ten_J_rowadr[obj2id] + rownnz2 = mjm.ten_J_rownnz[obj2id] + cols = set() + for j in range(rownnz1): + cols.add(mjm.ten_J_colind[rowadr1 + j]) + for j in range(rownnz2): + cols.add(mjm.ten_J_colind[rowadr2 + j]) + total_nnz += len(cols) + else: + total_nnz += rownnz1 + + elif eq_type == mujoco.mjtEq.mjEQ_FLEX: + obj1id = mjm.eq_obj1id[i] + if obj1id < mjm.nflex: + edge_start = mjm.flex_edgeadr[obj1id] + edge_count = mjm.flex_edgenum[obj1id] + for e in range(edge_count): + total_nnz += mjm.flexedge_J_rownnz[edge_start + e] + + # friction constraints + total_nnz += (mjm.dof_frictionloss > 0).sum() + for i in range(mjm.ntendon): + if mjm.tendon_frictionloss[i] > 0: + total_nnz += mjm.ten_J_rownnz[i] + + # limit constraints (assume all active) + for i in range(mjm.njnt): + if mjm.jnt_limited[i]: + jnt_type = mjm.jnt_type[i] + if jnt_type == mujoco.mjtJoint.mjJNT_BALL: + total_nnz += 3 + elif jnt_type in (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE): + total_nnz += 1 + for i in range(mjm.ntendon): + if mjm.tendon_limited[i]: + 
total_nnz += mjm.ten_J_rownnz[i] + + # contact constraints: njmax rows at max body-pair non-zeros + max_contact_nnz = 0 + + # contact pairs + for i in range(mjm.npair): + g1, g2 = mjm.pair_geom1[i], mjm.pair_geom2[i] + b1, b2 = mjm.geom_bodyid[g1], mjm.geom_bodyid[g2] + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, b1, b2)) + + # filter geom-geom pairs (unique body pairs, filtered) + body_pair_seen = set() + for i in range(mjm.ngeom): + bi = mjm.geom_bodyid[i] + cti, cai = mjm.geom_contype[i], mjm.geom_conaffinity[i] + for j in range(i + 1, mjm.ngeom): + bj = mjm.geom_bodyid[j] + if bi == bj: + continue + if mjm.body_weldid[bi] == 0 and mjm.body_weldid[bj] == 0: + continue + bp = (min(bi, bj), max(bi, bj)) + if bp in body_pair_seen: + continue + ctj, caj = mjm.geom_contype[j], mjm.geom_conaffinity[j] + if not ((cti & caj) or (ctj & cai)): + continue + body_pair_seen.add(bp) + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, bi, bj)) + + # flex vertex contacts + for fi in range(mjm.nflex): + fct = mjm.flex_contype[fi] + fca = mjm.flex_conaffinity[fi] + + vert_start = mjm.flex_vertadr[fi] + vert_count = mjm.flex_vertnum[fi] + flex_bodies = {mjm.flex_vertbodyid[vert_start + v] for v in range(vert_count)} + + geom_bodies = set() + for g in range(mjm.ngeom): + ct, ca = mjm.geom_contype[g], mjm.geom_conaffinity[g] + if (fct & ca) or (ct & fca): + geom_bodies.add(mjm.geom_bodyid[g]) + + for fb in flex_bodies: + for gb in geom_bodies: + if fb != gb: + max_contact_nnz = max(max_contact_nnz, _body_pair_nnz(mjm, fb, gb)) + + # flex self-collision + if mjm.flex_selfcollide[fi]: + flex_body_list = sorted(flex_bodies) + for idx1 in range(len(flex_body_list)): + for idx2 in range(idx1 + 1, len(flex_body_list)): + max_contact_nnz = max( + max_contact_nnz, + _body_pair_nnz(mjm, flex_body_list[idx1], flex_body_list[idx2]), + ) + + total_nnz += njmax * max_contact_nnz + + return int(min(max(total_nnz, 1), njmax * mjm.nv)) + + def _resolve_batch_size(na: int | 
None, n: int | None, nworld: int, default: int) -> int: if na is not None: return na @@ -747,9 +908,11 @@ def make_data( sizes["naconmax"] = naconmax sizes["njmax"] = njmax - # TODO(team): heuristic for constraint Jacobian number of non-zeros - if njmax_nnz is None or not is_sparse(mjm): - njmax_nnz = njmax * mjm.nv + if njmax_nnz is None: + if is_sparse(mjm): + njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) + else: + njmax_nnz = njmax * mjm.nv contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)}) contact.efc_address = wp.array(np.full((naconmax, sizes["nmaxpyramid"]), -1, dtype=int), dtype=int) @@ -935,9 +1098,11 @@ def put_data( sizes["naconmax"] = naconmax sizes["njmax"] = njmax - # TODO(team): heuristic for constraint Jacobian number of non-zeros - if njmax_nnz is None or not is_sparse(mjm): - njmax_nnz = njmax * mjm.nv + if njmax_nnz is None: + if is_sparse(mjm): + njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) + else: + njmax_nnz = njmax * mjm.nv # ensure static geom positions are computed # TODO: remove once MjData creation semantics are fixed @@ -987,23 +1152,27 @@ def put_data( efc = types.Constraint(**efc_kwargs) if is_sparse(mjm): - # TODO(team): process efc_J sparsity structure for nv row shift - efc.J_rownnz = wp.array(np.full((nworld, njmax), mjm.nv, dtype=int), dtype=int) - efc.J_rowadr = wp.array( - np.tile(np.arange(0, njmax * mjm.nv, mjm.nv) if mjm.nv else np.zeros(njmax, dtype=int), (nworld, 1)), dtype=int - ) - efc.J_colind = wp.array(np.tile(np.arange(mjm.nv), (nworld, njmax)).reshape((nworld, 1, -1))[:, :, :njmax_nnz], dtype=int) - - mj_efc_J = np.zeros((mjd.nefc, mjm.nv)) + J_rownnz = np.zeros(njmax, dtype=np.int32) + J_rowadr = np.zeros(njmax, dtype=np.int32) + J_colind = np.zeros(njmax_nnz, dtype=np.int32) + J = np.zeros(njmax_nnz, dtype=np.float64) if mjd.nefc: if mujoco.mj_isSparse(mjm): - mujoco.mju_sparse2dense(mj_efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, 
mjd.efc_J_colind) + J_rownnz[: mjd.nefc] = mjd.efc_J_rownnz[: mjd.nefc] + J_rowadr[: mjd.nefc] = mjd.efc_J_rowadr[: mjd.nefc] + nnz = int(mjd.efc_J_rownnz[: mjd.nefc].sum()) + J_colind[:nnz] = mjd.efc_J_colind[:nnz] + J[:nnz] = mjd.efc_J[:nnz] else: - mj_efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) - efc_J = np.zeros((njmax, mjm.nv), dtype=float) - efc_J[: mjd.nefc, : mjm.nv] = mj_efc_J - efc_J_flat = np.tile(efc_J.reshape(-1), (nworld, 1, 1)).reshape((nworld, 1, -1))[:, :, :njmax_nnz] - efc.J = wp.array(efc_J_flat, dtype=float) + dense_J = mjd.efc_J.reshape((-1, mjm.nv))[: mjd.nefc] + mujoco.mju_dense2sparse( + J[: mjd.nefc * mjm.nv], dense_J, J_rownnz[: mjd.nefc], J_rowadr[: mjd.nefc], J_colind[: mjd.nefc * mjm.nv] + ) + + efc.J_rownnz = wp.array(np.tile(J_rownnz, (nworld, 1)), dtype=int) + efc.J_rowadr = wp.array(np.tile(J_rowadr, (nworld, 1)), dtype=int) + efc.J_colind = wp.array(np.tile(J_colind, (nworld, 1)).reshape((nworld, 1, -1)), dtype=int) + efc.J = wp.array(np.tile(J, (nworld, 1)).reshape((nworld, 1, -1)), dtype=float) else: efc.J_rownnz = wp.zeros((nworld, 0), dtype=int) efc.J_rowadr = wp.zeros((nworld, 0), dtype=int) @@ -2567,64 +2736,28 @@ def create_render_context( hfield_bounds_size_arr = wp.array(hfield_bounds_size, dtype=wp.vec3) # Flex BVHs - flex_bvh_id = wp.uint64(0) - flex_group_root = wp.zeros(nworld, dtype=int) - flex_mesh = None - flex_face_point = None - flex_elemdataadr = None - flex_shell = None - flex_shelldataadr = None - flex_faceadr = None - flex_nface = 0 - flex_radius = None - flex_vertflexid = None - flex_workadr = None - flex_worknum = None - flex_nwork = 0 - - if mjm.nflex > 0: - ( - fmesh, - face_point, - flex_group_roots, - flex_shell_data, - flex_faceadr_data, - flex_nface, - ) = bvh.build_flex_bvh(mjm, mjd, nworld) - - flex_mesh = fmesh - flex_bvh_id = fmesh.id - flex_face_point = face_point - flex_group_root = flex_group_roots - flex_elemdataadr = wp.array(mjm.flex_elemdataadr, dtype=int) - flex_shell = flex_shell_data - 
flex_shelldataadr = wp.array(mjm.flex_shelldataadr, dtype=int) - flex_faceadr = wp.array(flex_faceadr_data, dtype=int) - flex_radius = wp.array(mjm.flex_radius, dtype=float) - - # Compute flex_vertflexid: maps each flex vertex to its flex index - flex_vertflexid_data = np.zeros(mjm.nflexvert, dtype=np.int32) - for flexid in range(mjm.nflex): - vert_start = mjm.flex_vertadr[flexid] - vert_end = vert_start + mjm.flex_vertnum[flexid] - flex_vertflexid_data[vert_start:vert_end] = flexid - flex_vertflexid = wp.array(flex_vertflexid_data, dtype=int) - - # precompute work item layout for unified refit kernel - nflex = mjm.nflex - workadr = np.zeros(nflex, dtype=np.int32) - worknum = np.zeros(nflex, dtype=np.int32) - cumsum = 0 - for f in range(nflex): - workadr[f] = cumsum - if mjm.flex_dim[f] == 2: - worknum[f] = mjm.flex_elemnum[f] + mjm.flex_shellnum[f] - else: - worknum[f] = mjm.flex_shellnum[f] - cumsum += worknum[f] - flex_workadr = wp.array(workadr, dtype=int) - flex_worknum = wp.array(worknum, dtype=int) - flex_nwork = int(cumsum) + nflex = mjm.nflex + flex_registry = {} + + # Scene BVH flex primitives: 1D → one capsule per edge, 2D/3D → one box per flex + flex_geom_flexid = [] + flex_geom_edgeid = [] + flex_bvh_id = np.full(nflex, 0, dtype=wp.uint64) + flex_group_root = np.zeros((nworld, nflex), dtype=int) # world-major: render kernels index flex_group_root[worldid, flexid] + + for f in range(nflex): + if mjm.flex_dim[f] == 1: + edge_adr = mjm.flex_edgeadr[f] + flex_geom_flexid.extend([f] * mjm.flex_edgenum[f]) + flex_geom_edgeid.extend([edge_adr + e for e in range(mjm.flex_edgenum[f])]) + flex_group_root[:, f] = np.zeros(nworld, dtype=int) + else: + flex_geom_flexid.append(f) + flex_geom_edgeid.append(-1) + fmesh, group_root = bvh.build_flex_bvh(mjm, mjd, nworld, f) + flex_registry[f] = fmesh + flex_bvh_id[f] = fmesh.id + flex_group_root[:, f] = group_root.numpy() textures_registry = [] for i in range(mjm.ntex): @@ -2743,26 +2876,20 @@ def create_render_context( hfield_registry=hfield_registry, hfield_bvh_id=hfield_bvh_id_arr, 
hfield_bounds_size=hfield_bounds_size_arr, - flex_mesh=flex_mesh, + flex_mesh_registry=flex_registry, flex_rgba=wp.array(mjm.flex_rgba, dtype=wp.vec4), - flex_bvh_id=flex_bvh_id, - flex_face_point=flex_face_point, - flex_faceadr=flex_faceadr, - flex_nface=flex_nface, - flex_nwork=flex_nwork, - flex_group_root=flex_group_root, - flex_elemdataadr=flex_elemdataadr, - flex_shell=flex_shell, - flex_shelldataadr=flex_shelldataadr, - flex_radius=flex_radius, - flex_workadr=flex_workadr, - flex_worknum=flex_worknum, + flex_bvh_id=wp.array(flex_bvh_id, dtype=wp.uint64), + flex_group_root=wp.array(flex_group_root, dtype=int), flex_render_smooth=flex_render_smooth, + bvh_nflexgeom=len(flex_geom_flexid), + flex_dim_np=mjm.flex_dim, + flex_geom_flexid=wp.array(flex_geom_flexid, dtype=int), + flex_geom_edgeid=wp.array(flex_geom_edgeid, dtype=int), bvh=None, bvh_id=None, - lower=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3), - upper=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3), - group=wp.zeros(nworld * bvh_ngeom, dtype=int), + lower=wp.zeros(nworld * (bvh_ngeom + len(flex_geom_flexid)), dtype=wp.vec3), + upper=wp.zeros(nworld * (bvh_ngeom + len(flex_geom_flexid)), dtype=wp.vec3), + group=wp.zeros(nworld * (bvh_ngeom + len(flex_geom_flexid)), dtype=int), group_root=wp.zeros(nworld, dtype=int), ray=ray, rgb_data=wp.zeros((nworld, ri), dtype=wp.uint32), diff --git a/mujoco_warp/_src/math.py b/mujoco_warp/_src/math.py index a4384854a..f212914ea 100644 --- a/mujoco_warp/_src/math.py +++ b/mujoco_warp/_src/math.py @@ -83,6 +83,35 @@ def quat_to_mat(quat: wp.quat) -> wp.mat33: ) +@wp.func +def quat_z2vec(vec: wp.vec3) -> wp.quat: + """Compute quaternion performing rotation from z-axis to given vector.""" + quat = wp.quat(0.0, 0.0, 0.0, 1.0) + + # normalize vector; if too small, no rotation + norm = wp.length(vec) + if norm < types.MJ_MINVAL: + return quat + vec = vec / norm + + axis = wp.vec3(-vec[1], vec[0], 0.0) + a = wp.length(axis) + + # almost parallel + if a < types.MJ_MINVAL: + # 
opposite: 180 deg rotation around x axis + if vec[2] < 0.0: + quat = wp.quat(1.0, 0.0, 0.0, 0.0) + return quat + + # make quaternion from angle and axis + axis = axis / a + angle = wp.atan2(a, vec[2]) + quat = axis_angle_to_quat(axis, angle) + + return quat + + @wp.func def quat_inv(quat: wp.quat) -> wp.quat: return wp.quat(quat[0], -quat[1], -quat[2], -quat[3]) diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index f5c098e4e..f8f534af1 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -574,6 +574,7 @@ def _flex_elasticity( flex_edgeadr: wp.array(dtype=int), flex_elemadr: wp.array(dtype=int), flex_elemnum: wp.array(dtype=int), + flex_elemdataadr: wp.array(dtype=int), flex_elemedgeadr: wp.array(dtype=int), flex_vertbodyid: wp.array(dtype=int), flex_elem: wp.array(dtype=int), @@ -599,6 +600,7 @@ def _flex_elasticity( f = i break + local_elemid = elemid - flex_elemadr[f] dim = flex_dim[f] nvert = dim + 1 nedge = nvert * (nvert - 1) / 2 @@ -612,10 +614,11 @@ def _flex_elasticity( else: kD = 0.0 + elem_data_adr = flex_elemdataadr[f] + local_elemid * (dim + 1) gradient = wp.matrix(0.0, shape=(6, 6)) for e in range(nedge): - vert0 = flex_elem[(dim + 1) * elemid + edges[e, 0]] - vert1 = flex_elem[(dim + 1) * elemid + edges[e, 1]] + vert0 = flex_elem[elem_data_adr + edges[e, 0]] + vert1 = flex_elem[elem_data_adr + edges[e, 1]] xpos0 = flexvert_xpos_in[worldid, vert0] xpos1 = flexvert_xpos_in[worldid, vert1] for i in range(3): @@ -624,7 +627,7 @@ def _flex_elasticity( elongation = wp.spatial_vectorf(0.0) for e in range(nedge): - idx = flex_elemedge[elemid * nedge + e] + idx = flex_elemedge[flex_elemedgeadr[f] + local_elemid * nedge + e] vel = flexedge_velocity_in[worldid, flex_edgeadr[f] + idx] deformed = flexedge_length_in[worldid, flex_edgeadr[f] + idx] reference = flexedge_length0[flex_edgeadr[f] + idx] @@ -647,7 +650,7 @@ def _flex_elasticity( force[edges[ed2, i], x] -= elongation[ed1] * gradient[ed2, 3 * i + x] * 
metric[ed1, ed2] for v in range(nvert): - vert = flex_elem[(dim + 1) * elemid + v] + vert = flex_elem[elem_data_adr + v] bodyid = flex_vertbodyid[flex_vertadr[f] + vert] for x in range(3): wp.atomic_add(qfrc_spring_out, worldid, body_dofadr[bodyid] + x, force[v, x]) @@ -784,6 +787,7 @@ def passive(m: Model, d: Data): m.flex_edgeadr, m.flex_elemadr, m.flex_elemnum, + m.flex_elemdataadr, m.flex_elemedgeadr, m.flex_vertbodyid, m.flex_elem, diff --git a/mujoco_warp/_src/ray.py b/mujoco_warp/_src/ray.py index 5bc2a9a28..44c56f962 100644 --- a/mujoco_warp/_src/ray.py +++ b/mujoco_warp/_src/ray.py @@ -752,7 +752,8 @@ def ray_mesh_with_bvh_anyhit( @wp.func def ray_flex_with_bvh( # In: - bvh_id: wp.uint64, + flex_bvh_id: wp.array(dtype=wp.uint64), + flexid: int, group_root: int, pnt: wp.vec3, vec: wp.vec3, @@ -769,7 +770,7 @@ def ray_flex_with_bvh( n = wp.vec3(0.0, 0.0, 0.0) f = int(-1) - hit = wp.mesh_query_ray(bvh_id, pnt, vec, max_t, t, u, v, sign, n, f, group_root) + hit = wp.mesh_query_ray(flex_bvh_id[flexid], pnt, vec, max_t, t, u, v, sign, n, f, group_root) if hit: return t, n, u, v, f @@ -777,6 +778,23 @@ def ray_flex_with_bvh( return -1.0, wp.vec3(0.0, 0.0, 0.0), 0.0, 0.0, -1 +@wp.func +def ray_flex_with_bvh_anyhit( + # In: + flex_bvh_id: wp.array(dtype=wp.uint64), + flexid: int, + group_root: int, + pnt: wp.vec3, + vec: wp.vec3, + max_t: float, +) -> bool: + """Returns True if there is any hit for ray flex intersections. + + Requires wp.Mesh be constructed and their ids to be passed. Flex are already in world space. + """ + return wp.mesh_query_ray_anyhit(flex_bvh_id[flexid], pnt, vec, max_t, group_root) + + @wp.func def ray_geom(pos: wp.vec3, mat: wp.mat33, size: wp.vec3, pnt: wp.vec3, vec: wp.vec3, geomtype: int) -> Tuple[float, wp.vec3]: """Returns distance along ray to intersection with geom and normal at intersection point. 
diff --git a/mujoco_warp/_src/render.py b/mujoco_warp/_src/render.py index dbaf001b0..47a67cbd5 100644 --- a/mujoco_warp/_src/render.py +++ b/mujoco_warp/_src/render.py @@ -23,6 +23,7 @@ from mujoco_warp._src.ray import ray_cylinder from mujoco_warp._src.ray import ray_ellipsoid from mujoco_warp._src.ray import ray_flex_with_bvh +from mujoco_warp._src.ray import ray_flex_with_bvh_anyhit from mujoco_warp._src.ray import ray_mesh_with_bvh from mujoco_warp._src.ray import ray_mesh_with_bvh_anyhit from mujoco_warp._src.ray import ray_plane @@ -90,17 +91,26 @@ def cast_ray( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: bvh_id: wp.uint64, group_root: int, - world_id: int, + worldid: int, bvh_ngeom: int, + flex_bvh_ngeom: int, enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), ray_origin_world: wp.vec3, ray_dir_world: wp.vec3, ) -> Tuple[int, float, wp.vec3, float, float, int, int]: @@ -114,91 +124,127 @@ def cast_ray( query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root) bounds_nr = int(0) + ngeom = bvh_ngeom + flex_bvh_ngeom while wp.bvh_query_next(query, bounds_nr, dist): gi_global = bounds_nr - gi_bvh_local = gi_global - (world_id * bvh_ngeom) - gi = enabled_geom_ids[gi_bvh_local] + local_id = gi_global - (worldid * ngeom) + d = float(-1.0) hit_mesh_id = int(-1) u = float(0.0) v = float(0.0) f = int(-1) n = wp.vec3(0.0, 0.0, 0.0) + hit_geom_id = int(-1) + + if local_id < bvh_ngeom: + gi = 
enabled_geom_ids[local_id] + gtype = geom_type[gi] + else: + gi = local_id - bvh_ngeom + gtype = GeomType.FLEX + + hit_geom_id = gi # TODO: Investigate branch elimination with static loop unrolling - if geom_type[gi] == GeomType.PLANE: + if gtype == GeomType.PLANE: d, n = ray_plane( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.HFIELD: + if gtype == GeomType.HFIELD: d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh( hfield_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, dist, ) - if geom_type[gi] == GeomType.SPHERE: + if gtype == GeomType.SPHERE: d, n = ray_sphere( - geom_xpos_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi][0] * geom_size[world_id % geom_size.shape[0], gi][0], + geom_xpos_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi][0] * geom_size[worldid % geom_size.shape[0], gi][0], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.ELLIPSOID: + if gtype == GeomType.ELLIPSOID: d, n = ray_ellipsoid( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CAPSULE: + if gtype == GeomType.CAPSULE: d, n = ray_capsule( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CYLINDER: + if gtype == GeomType.CYLINDER: d, n = 
ray_cylinder( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.BOX: + if gtype == GeomType.BOX: d, all, n = ray_box( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.MESH: + if gtype == GeomType.MESH: d, n, u, v, f, hit_mesh_id = ray_mesh_with_bvh( mesh_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, dist, ) + if gtype == GeomType.FLEX: + hit_geom_id = -2 + flexid = flex_geom_flexid[gi] + edge_id = flex_geom_edgeid[gi] + + if edge_id >= 0: + edge = flex_edge[edge_id] + vert_adr = flex_vertadr[flexid] + v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]] + v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]] + pos = 0.5 * (v0 + v1) + vec = v1 - v0 + + length = wp.length(vec) + edgeq = math.quat_z2vec(vec) + mat = math.quat_to_mat(edgeq) + size = wp.vec3(flex_radius[flexid], 0.5 * length, 0.0) + + d, n = ray_capsule(pos, mat, size, ray_origin_world, ray_dir_world) + hit_mesh_id = flexid + else: + flex_gr = flex_group_root[worldid, flexid] + d, n, u, v, f = ray_flex_with_bvh(flex_bvh_id, flexid, flex_gr, ray_origin_world, ray_dir_world, dist) + if d >= 0.0: + hit_mesh_id = flexid if d >= 0.0 and d < dist: dist = d normal = n - geom_id = gi + geom_id = hit_geom_id bary_u = u bary_v = v face_idx = f @@ -213,17 +259,26 @@ def cast_ray_first_hit( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: 
wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: bvh_id: wp.uint64, group_root: int, - world_id: int, + worldid: int, bvh_ngeom: int, + bvh_nflexgeom: int, enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), ray_origin_world: wp.vec3, ray_dir_world: wp.vec3, max_dist: float, @@ -231,81 +286,119 @@ def cast_ray_first_hit( """A simpler version of casting rays that only checks for the first hit.""" query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root) bounds_nr = int(0) + ngeom = bvh_ngeom + bvh_nflexgeom while wp.bvh_query_next(query, bounds_nr, max_dist): gi_global = bounds_nr - gi_bvh_local = gi_global - (world_id * bvh_ngeom) - gi = enabled_geom_ids[gi_bvh_local] + local_id = gi_global - (worldid * ngeom) + + d = float(-1.0) + n = wp.vec3(0.0, 0.0, 0.0) + + if local_id < bvh_ngeom: + gi = enabled_geom_ids[local_id] + gtype = geom_type[gi] + else: + gi = local_id - bvh_ngeom + gtype = GeomType.FLEX # TODO: Investigate branch elimination with static loop unrolling - if geom_type[gi] == GeomType.PLANE: + if gtype == GeomType.PLANE: d, n = ray_plane( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.HFIELD: + if gtype == GeomType.HFIELD: d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh( hfield_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, 
gi], ray_origin_world, ray_dir_world, max_dist, ) - if geom_type[gi] == GeomType.SPHERE: + if gtype == GeomType.SPHERE: d, n = ray_sphere( - geom_xpos_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi][0] * geom_size[world_id % geom_size.shape[0], gi][0], + geom_xpos_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi][0] * geom_size[worldid % geom_size.shape[0], gi][0], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.ELLIPSOID: + if gtype == GeomType.ELLIPSOID: d, n = ray_ellipsoid( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CAPSULE: + if gtype == GeomType.CAPSULE: d, n = ray_capsule( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.CYLINDER: + if gtype == GeomType.CYLINDER: d, n = ray_cylinder( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.BOX: + if gtype == GeomType.BOX: d, all, n = ray_box( - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], - geom_size[world_id % geom_size.shape[0], gi], + geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], + geom_size[worldid % geom_size.shape[0], gi], ray_origin_world, ray_dir_world, ) - if geom_type[gi] == GeomType.MESH: + if gtype == GeomType.MESH: hit = ray_mesh_with_bvh_anyhit( mesh_bvh_id, geom_dataid[gi], - geom_xpos_in[world_id, gi], - geom_xmat_in[world_id, gi], + 
geom_xpos_in[worldid, gi], + geom_xmat_in[worldid, gi], ray_origin_world, ray_dir_world, max_dist, ) d = 0.0 if hit else -1.0 + if gtype == GeomType.FLEX: + flexid = flex_geom_flexid[gi] + edge_id = flex_geom_edgeid[gi] + + if edge_id >= 0: + edge = flex_edge[edge_id] + vert_adr = flex_vertadr[flexid] + v0 = flexvert_xpos_in[worldid, vert_adr + edge[0]] + v1 = flexvert_xpos_in[worldid, vert_adr + edge[1]] + pos = 0.5 * (v0 + v1) + vec = v1 - v0 + + length = wp.length(vec) + edgeq = math.quat_z2vec(vec) + mat = math.quat_to_mat(edgeq) + size = wp.vec3(flex_radius[flexid], 0.5 * length, 0.0) + + d, n = ray_capsule(pos, mat, size, ray_origin_world, ray_dir_world) + else: + hit = ray_flex_with_bvh_anyhit( + flex_bvh_id, + flexid, + flex_group_root[worldid, flexid], + ray_origin_world, + ray_dir_world, + max_dist, + ) + d = 0.0 if hit else -1.0 if d >= 0.0 and d < max_dist: return True @@ -319,18 +412,27 @@ def compute_lighting( geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: use_shadows: bool, bvh_id: wp.uint64, group_root: int, bvh_ngeom: int, + bvh_nflexgeom: int, enabled_geom_ids: wp.array(dtype=int), - world_id: int, + worldid: int, mesh_bvh_id: wp.array(dtype=wp.uint64), hfield_bvh_id: wp.array(dtype=wp.uint64), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), lightactive: bool, lighttype: int, lightcastshadow: bool, @@ -381,15 +483,24 @@ def compute_lighting( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, bvh_id, group_root, - world_id, + worldid, bvh_ngeom, 
+ bvh_nflexgeom, enabled_geom_ids, mesh_bvh_id, hfield_bvh_id, + flex_geom_flexid, + flex_geom_edgeid, + flex_bvh_id, + flex_group_root, shadow_origin, L, max_t, @@ -431,6 +542,9 @@ def _render_megakernel( light_type: wp.array2d(dtype=int), light_castshadow: wp.array2d(dtype=bool), light_active: wp.array2d(dtype=bool), + flex_vertadr: wp.array(dtype=int), + flex_edge: wp.array(dtype=wp.vec2i), + flex_radius: wp.array(dtype=float), mesh_faceadr: wp.array(dtype=int), mat_texid: wp.array3d(dtype=int), mat_texrepeat: wp.array2d(dtype=wp.vec2), @@ -442,10 +556,12 @@ def _render_megakernel( cam_xmat_in: wp.array2d(dtype=wp.mat33), light_xpos_in: wp.array2d(dtype=wp.vec3), light_xdir_in: wp.array2d(dtype=wp.vec3), + flexvert_xpos_in: wp.array2d(dtype=wp.vec3), # In: nrender: int, use_shadows: bool, bvh_ngeom: int, + bvh_nflexgeom: int, cam_res: wp.array(dtype=wp.vec2i), cam_id_map: wp.array(dtype=int), ray: wp.array(dtype=wp.vec3), @@ -457,8 +573,8 @@ def _render_megakernel( render_seg: wp.array(dtype=bool), bvh_id: wp.uint64, group_root: wp.array(dtype=int), - flex_bvh_id: wp.uint64, - flex_group_root: wp.array(dtype=int), + flex_bvh_id: wp.array(dtype=wp.uint64), + flex_group_root: wp.array2d(dtype=int), enabled_geom_ids: wp.array(dtype=int), mesh_bvh_id: wp.array(dtype=wp.uint64), mesh_facetexcoord: wp.array(dtype=wp.vec3i), @@ -466,26 +582,28 @@ def _render_megakernel( mesh_texcoord_offsets: wp.array(dtype=int), hfield_bvh_id: wp.array(dtype=wp.uint64), flex_rgba: wp.array(dtype=wp.vec4), + flex_geom_flexid: wp.array(dtype=int), + flex_geom_edgeid: wp.array(dtype=int), textures: wp.array(dtype=wp.Texture2D), # Out: rgb_out: wp.array2d(dtype=wp.uint32), depth_out: wp.array2d(dtype=float), seg_out: wp.array2d(dtype=int), ): - world_idx, ray_idx = wp.tid() + worldid, rayid = wp.tid() - # Map global ray_idx -> (cam_idx, ray_idx_local) using cumulative sizes + # Map global rayid -> (cam_idx, rayid_local) using cumulative sizes cam_idx = int(-1) - ray_idx_local = int(-1) + 
rayid_local = int(-1) accum = int(0) for i in range(nrender): num_i = cam_res[i][0] * cam_res[i][1] - if ray_idx < accum + num_i: + if rayid < accum + num_i: cam_idx = i - ray_idx_local = ray_idx - accum + rayid_local = rayid - accum break accum += num_i - if cam_idx == -1 or ray_idx_local < 0: + if cam_idx == -1 or rayid_local < 0: return if not render_rgb[cam_idx] and not render_depth[cam_idx] and not render_seg[cam_idx]: @@ -495,17 +613,17 @@ def _render_megakernel( mujoco_cam_id = cam_id_map[cam_idx] if wp.static(rc.use_precomputed_rays): - ray_dir_local_cam = ray[ray_idx] + ray_dir_local_cam = ray[rayid] else: img_w = cam_res[cam_idx][0] img_h = cam_res[cam_idx][1] - px = ray_idx_local % img_w - py = ray_idx_local // img_w + px = rayid_local % img_w + py = rayid_local // img_w ray_dir_local_cam = compute_ray( cam_projection[mujoco_cam_id], - cam_fovy[world_idx % cam_fovy.shape[0], mujoco_cam_id], + cam_fovy[worldid % cam_fovy.shape[0], mujoco_cam_id], cam_sensorsize[mujoco_cam_id], - cam_intrinsic[world_idx % cam_intrinsic.shape[0], mujoco_cam_id], + cam_intrinsic[worldid % cam_intrinsic.shape[0], mujoco_cam_id], img_w, img_h, px, @@ -513,41 +631,37 @@ def _render_megakernel( wp.static(rc.znear), ) - ray_dir_world = cam_xmat_in[world_idx, mujoco_cam_id] @ ray_dir_local_cam - ray_origin_world = cam_xpos_in[world_idx, mujoco_cam_id] + ray_dir_world = cam_xmat_in[worldid, mujoco_cam_id] @ ray_dir_local_cam + ray_origin_world = cam_xpos_in[worldid, mujoco_cam_id] geom_id, dist, normal, u, v, f, mesh_id = cast_ray( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, bvh_id, - group_root[world_idx], - world_idx, + group_root[worldid], + worldid, bvh_ngeom, + bvh_nflexgeom, enabled_geom_ids, mesh_bvh_id, hfield_bvh_id, + flex_geom_flexid, + flex_geom_edgeid, + flex_bvh_id, + flex_group_root, ray_origin_world, ray_dir_world, ) - if wp.static(m.nflex > 0): - d, n, u, v, f = ray_flex_with_bvh( - 
flex_bvh_id, - flex_group_root[world_idx], - ray_origin_world, - ray_dir_world, - dist, - ) - if d >= 0.0 and d < dist: - dist = d - normal = n - geom_id = -2 - if render_seg[cam_idx] and geom_id != -1: - seg_out[world_idx, seg_adr[cam_idx] + ray_idx_local] = geom_id + seg_out[worldid, seg_adr[cam_idx] + rayid_local] = geom_id # Early Out if geom_id == -1: @@ -558,7 +672,7 @@ def _render_megakernel( # In camera-local coordinates, the optical axis is -Z. The Z-component of the # normalized ray direction is negative, so -ray_dir_local_cam[2] gives cos(θ) # between the ray and the optical axis. - depth_out[world_idx, depth_adr[cam_idx] + ray_idx_local] = dist * (-ray_dir_local_cam[2]) + depth_out[worldid, depth_adr[cam_idx] + rayid_local] = dist * (-ray_dir_local_cam[2]) if not render_rgb[cam_idx]: return @@ -567,31 +681,30 @@ def _render_megakernel( hit_point = ray_origin_world + ray_dir_world * dist if geom_id == -2: - # TODO: Currently flex textures are not supported, and only the first rgba value - # is used until further flex support is added. 
- color = flex_rgba[0] - elif geom_matid[world_idx % geom_matid.shape[0], geom_id] == -1: - color = geom_rgba[world_idx % geom_rgba.shape[0], geom_id] + # We encode flex_id in mesh_id for flex ray hits during cast_ray + color = flex_rgba[mesh_id] + elif geom_matid[worldid % geom_matid.shape[0], geom_id] == -1: + color = geom_rgba[worldid % geom_rgba.shape[0], geom_id] else: - color = mat_rgba[world_idx % mat_rgba.shape[0], geom_matid[world_idx % geom_matid.shape[0], geom_id]] + color = mat_rgba[worldid % mat_rgba.shape[0], geom_matid[worldid % geom_matid.shape[0], geom_id]] base_color = wp.vec3(color[0], color[1], color[2]) hit_color = base_color if wp.static(rc.use_textures): if geom_id != -2: - mat_id = geom_matid[world_idx % geom_matid.shape[0], geom_id] + mat_id = geom_matid[worldid % geom_matid.shape[0], geom_id] if mat_id >= 0: - tex_id = mat_texid[world_idx % mat_texid.shape[0], mat_id, 1] + tex_id = mat_texid[worldid % mat_texid.shape[0], mat_id, 1] if tex_id >= 0: tex_color = sample_texture( geom_type, mesh_faceadr, geom_id, - mat_texrepeat[world_idx % mat_texrepeat.shape[0], mat_id], + mat_texrepeat[worldid % mat_texrepeat.shape[0], mat_id], textures[tex_id], - geom_xpos_in[world_idx, geom_id], - geom_xmat_in[world_idx, geom_id], + geom_xpos_in[worldid, geom_id], + geom_xmat_in[worldid, geom_id], mesh_facetexcoord, mesh_texcoord, mesh_texcoord_offsets, @@ -616,21 +729,30 @@ def _render_megakernel( geom_type, geom_dataid, geom_size, + flex_vertadr, + flex_edge, + flex_radius, geom_xpos_in, geom_xmat_in, + flexvert_xpos_in, use_shadows, bvh_id, - group_root[world_idx], + group_root[worldid], bvh_ngeom, + bvh_nflexgeom, enabled_geom_ids, - world_idx, + worldid, mesh_bvh_id, hfield_bvh_id, - light_active[world_idx % light_active.shape[0], l], - light_type[world_idx % light_type.shape[0], l], - light_castshadow[world_idx % light_castshadow.shape[0], l], - light_xpos_in[world_idx, l], - light_xdir_in[world_idx, l], + flex_geom_flexid, + flex_geom_edgeid, + 
flex_bvh_id, + flex_group_root, + light_active[worldid % light_active.shape[0], l], + light_type[worldid % light_type.shape[0], l], + light_castshadow[worldid % light_castshadow.shape[0], l], + light_xpos_in[worldid, l], + light_xdir_in[worldid, l], normal, hit_point, ) @@ -639,7 +761,7 @@ def _render_megakernel( hit_color = wp.min(result, wp.vec3(1.0, 1.0, 1.0)) hit_color = wp.max(hit_color, wp.vec3(0.0, 0.0, 0.0)) - rgb_out[world_idx, rgb_adr[cam_idx] + ray_idx_local] = pack_rgba_to_uint32( + rgb_out[worldid, rgb_adr[cam_idx] + rayid_local] = pack_rgba_to_uint32( hit_color[0] * 255.0, hit_color[1] * 255.0, hit_color[2] * 255.0, @@ -662,6 +784,9 @@ def _render_megakernel( m.light_type, m.light_castshadow, m.light_active, + m.flex_vertadr, + m.flex_edge, + m.flex_radius, m.mesh_faceadr, m.mat_texid, m.mat_texrepeat, @@ -672,9 +797,11 @@ def _render_megakernel( d.cam_xmat, d.light_xpos, d.light_xdir, + d.flexvert_xpos, rc.nrender, rc.use_shadows, rc.bvh_ngeom, + rc.bvh_nflexgeom, rc.cam_res, rc.cam_id_map, rc.ray, @@ -695,6 +822,8 @@ def _render_megakernel( rc.mesh_texcoord_offsets, rc.hfield_bvh_id, rc.flex_rgba, + rc.flex_geom_flexid, + rc.flex_geom_edgeid, rc.textures, ], outputs=[ diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 0254ec21f..7ffe73878 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -126,94 +126,113 @@ def _kinematics_branch( jntadr = body_jntadr[bodyid] jntnum = body_jntnum[bodyid] + # Check for freejoint — handled separately because it reads position and + # quaternion directly from qpos rather than composing with the parent + # transform. We use an integer flag instead of ``continue`` because + # Warp's AD replay for ``continue`` inside a dynamic for-loop emits a + # goto that skips all adjoint code for that iteration, zeroing gradients. 
+ is_free = int(0) if jntnum == 1: jnt_type_ = jnt_type[jntadr] if jnt_type_ == JointType.FREE: - qadr = jnt_qposadr[jntadr] - xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2]) - xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6]) - xquat = wp.normalize(xquat) - - xpos_out[worldid, bodyid] = xpos - xquat_out[worldid, bodyid] = xquat - xanchor_out[worldid, jntadr] = xpos - xaxis_out[worldid, jntadr] = jnt_axis[worldid % jnt_axis.shape[0], jntadr] - continue + is_free = int(1) - # regular or no joints - # apply fixed translation and rotation relative to parent - jnt_pos_id = worldid % jnt_pos.shape[0] - pid = body_parentid[bodyid] + if is_free == int(1): + qadr = jnt_qposadr[jntadr] + xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2]) + xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6]) + xquat = wp.normalize(xquat) - # mocap bodies have world body as parent - mocapid = body_mocapid[bodyid] - if mocapid >= 0: - xpos = mocap_pos_in[worldid, mocapid] - xquat = mocap_quat_in[worldid, mocapid] + xanchor_out[worldid, jntadr] = xpos + xaxis_out[worldid, jntadr] = jnt_axis[worldid % jnt_axis.shape[0], jntadr] else: - xpos = body_pos[worldid % body_pos.shape[0], bodyid] - xquat = body_quat[worldid % body_quat.shape[0], bodyid] - - if pid >= 0: - xpos = math.rot_vec_quat(xpos, xquat_out[worldid, pid]) + xpos_out[worldid, pid] - xquat = math.mul_quat(xquat_out[worldid, pid], xquat) + # regular or no joints + # apply fixed translation and rotation relative to parent + jnt_pos_id = worldid % jnt_pos.shape[0] + pid = body_parentid[bodyid] + + # mocap bodies have world body as parent + mocapid = body_mocapid[bodyid] + if mocapid >= 0: + xpos = mocap_pos_in[worldid, mocapid] + xquat = mocap_quat_in[worldid, mocapid] + else: + xpos = body_pos[worldid % body_pos.shape[0], bodyid] + xquat = body_quat[worldid % body_quat.shape[0], bodyid] + + if pid >= 0: + xpos = math.rot_vec_quat(xpos, xquat_out[worldid, pid]) 
+ xpos_out[worldid, pid] + xquat = math.mul_quat(xquat_out[worldid, pid], xquat) + + # Unrolled joint processing — avoids nested dynamic-range loop which + # produces incorrect gradients in Warp's AD. + if jntnum >= 1: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) + if jntnum >= 2: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 1, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) + if jntnum >= 3: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 2, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) + if jntnum >= 4: + xpos, xquat = _process_joint( + xpos, + xquat, + jntadr + 3, + jnt_pos_id, + worldid, + qpos0, + jnt_type, + jnt_qposadr, + jnt_pos, + jnt_axis, + qpos, + xanchor_out, + xaxis_out, + ) - # Unrolled joint processing — avoids nested dynamic-range loop which - # produces incorrect gradients in Warp's AD. 
- if jntnum >= 1: - xpos, xquat = _process_joint( - xpos, xquat, jntadr, jnt_pos_id, worldid, qpos0, jnt_type, jnt_qposadr, jnt_pos, jnt_axis, qpos, xanchor_out, xaxis_out - ) - if jntnum >= 2: - xpos, xquat = _process_joint( - xpos, - xquat, - jntadr + 1, - jnt_pos_id, - worldid, - qpos0, - jnt_type, - jnt_qposadr, - jnt_pos, - jnt_axis, - qpos, - xanchor_out, - xaxis_out, - ) - if jntnum >= 3: - xpos, xquat = _process_joint( - xpos, - xquat, - jntadr + 2, - jnt_pos_id, - worldid, - qpos0, - jnt_type, - jnt_qposadr, - jnt_pos, - jnt_axis, - qpos, - xanchor_out, - xaxis_out, - ) - if jntnum >= 4: - xpos, xquat = _process_joint( - xpos, - xquat, - jntadr + 3, - jnt_pos_id, - worldid, - qpos0, - jnt_type, - jnt_qposadr, - jnt_pos, - jnt_axis, - qpos, - xanchor_out, - xaxis_out, - ) + xquat = wp.normalize(xquat) - xquat = wp.normalize(xquat) xpos_out[worldid, bodyid] = xpos xquat_out[worldid, bodyid] = xquat @@ -2151,20 +2170,18 @@ def _comvel_branch( jntid = body_jntadr[bodyid] jntnum = body_jntnum[bodyid] - if jntnum == 0: - cvel_out[worldid, bodyid] = cvel - continue - - # unrolled joint processing — avoids nested dynamic-range loop which - # produces incorrect gradients in warp's AD + # Use if/else instead of ``continue`` — Warp's AD replay for + # ``continue`` inside a dynamic for-loop skips adjoint code. 
if jntnum >= 1: + # unrolled joint processing — avoids nested dynamic-range loop which + # produces incorrect gradients in warp's AD cvel, dofid = _process_joint_vel(cvel, dofid, jntid, worldid, jnt_type, qvel, cdof, cdof_dot_out) - if jntnum >= 2: - cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 1, worldid, jnt_type, qvel, cdof, cdof_dot_out) - if jntnum >= 3: - cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 2, worldid, jnt_type, qvel, cdof, cdof_dot_out) - if jntnum >= 4: - cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 3, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 2: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 1, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 3: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 2, worldid, jnt_type, qvel, cdof, cdof_dot_out) + if jntnum >= 4: + cvel, dofid = _process_joint_vel(cvel, dofid, jntid + 3, worldid, jnt_type, qvel, cdof, cdof_dot_out) cvel_out[worldid, bodyid] = cvel diff --git a/mujoco_warp/_src/solver.py b/mujoco_warp/_src/solver.py index 58a255c55..6359c4394 100644 --- a/mujoco_warp/_src/solver.py +++ b/mujoco_warp/_src/solver.py @@ -2066,8 +2066,8 @@ def kernel( gauss_cost += (efc_Ma_in[worldid, ii] - qfrc_smooth_in[worldid, ii]) * ( qacc_in[worldid, ii] - qacc_smooth_in[worldid, ii] ) - wp.atomic_add(ctx_gauss_out, worldid, gauss_cost) - wp.atomic_add(ctx_cost_out, worldid, gauss_cost) + wp.atomic_add(ctx_gauss_out, worldid, 0.5 * gauss_cost) + wp.atomic_add(ctx_cost_out, worldid, 0.5 * gauss_cost) return kernel diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index d9a70490b..8a6f40fd0 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -18,18 +18,52 @@ import warp as wp from mujoco_warp._src.math import motion_cross +from mujoco_warp._src.types import MJ_MINVAL from mujoco_warp._src.types import ConeType from mujoco_warp._src.types import Data +from mujoco_warp._src.types import DynType from 
mujoco_warp._src.types import JointType from mujoco_warp._src.types import Model from mujoco_warp._src.types import State from mujoco_warp._src.types import vec5 +from mujoco_warp._src.types import vec10f from mujoco_warp._src.warp_util import cache_kernel from mujoco_warp._src.warp_util import event_scope wp.set_module_options({"enable_backward": False}) +# TODO(team): kernel analyzer array slice? +@wp.func +def next_act( + # Model: + opt_timestep: float, # kernel_analyzer: ignore + actuator_dyntype: int, # kernel_analyzer: ignore + actuator_dynprm: vec10f, # kernel_analyzer: ignore + actuator_actrange: wp.vec2, # kernel_analyzer: ignore + # Data In: + act_in: float, # kernel_analyzer: ignore + act_dot_in: float, # kernel_analyzer: ignore + # In: + act_dot_scale: float, + clamp: bool, +) -> float: + # advance actuation + if actuator_dyntype == DynType.FILTEREXACT: + tau = wp.max(MJ_MINVAL, actuator_dynprm[0]) + act = act_in + act_dot_scale * act_dot_in * tau * (1.0 - wp.exp(-opt_timestep / tau)) + elif actuator_dyntype == DynType.USER: + return act_in + else: + act = act_in + act_dot_scale * act_dot_in * opt_timestep + + # clamp to actrange + if clamp: + act = wp.clamp(act, actuator_actrange[0], actuator_actrange[1]) + + return act + + @cache_kernel def mul_m_sparse(check_skip: bool): @wp.kernel(module="unique") diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 21ac957bb..1eadb776c 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -344,6 +344,7 @@ class GeomType(enum.IntEnum): BOX: box MESH: mesh SDF: sdf + FLEX: flex """ PLANE = mujoco.mjtGeom.mjGEOM_PLANE @@ -355,6 +356,7 @@ class GeomType(enum.IntEnum): BOX = mujoco.mjtGeom.mjGEOM_BOX MESH = mujoco.mjtGeom.mjGEOM_MESH SDF = mujoco.mjtGeom.mjGEOM_SDF + FLEX = mujoco.mjtGeom.mjGEOM_FLEX # unsupported: NGEOMTYPES, ARROW*, LINE, SKIN, LABEL, NONE @@ -980,7 +982,8 @@ class Model: flex_edgenum: number of edges (nflex,) flex_elemadr: first element address (nflex,) 
flex_elemnum: number of elements (nflex,) - flex_elemedgeadr: first element address (nflex,) + flex_elemdataadr: first element vertex id address (nflex,) + flex_elemedgeadr: first element edge id address (nflex,) flex_shellnum: number of shells (nflex,) flex_shelldataadr: first shell data address (nflex,) flex_vertbodyid: vertex body ids (nflexvert,) @@ -1366,6 +1369,7 @@ class Model: flex_edgenum: array("nflex", int) flex_elemadr: array("nflex", int) flex_elemnum: array("nflex", int) + flex_elemdataadr: array("nflex", int) flex_elemedgeadr: array("nflex", int) flex_shellnum: array("nflex", int) flex_shelldataadr: array("nflex", int) @@ -1773,6 +1777,9 @@ class Data: njmax_nnz: number of non-zeros in constraint Jacobian nacon: number of detected contacts (across all worlds) (1,) ncollision: collision count from broadphase (1,) + solver_h: solver retained Hessian for backward pass + solver_hfactor: solver retained factored Hessian for backward pass + solver_Jaref: solver retained Jacobian reference for backward pass """ solver_niter: array("nworld", int) @@ -1898,21 +1905,12 @@ class RenderContext: hfield_registry: hfield BVH id to warp mesh mapping hfield_bvh_id: hfield BVH ids hfield_bounds_size: hfield bounds half-extents - flex_mesh: flex mesh + flex_mesh_registry: per-flex mesh BVH registry (prevents garbage collection) flex_rgba: flex rgba - flex_bvh_id: flex BVH id - flex_face_point: flex face points - flex_faceadr: flex face addresses - flex_nface: number of flex faces - flex_nwork: total flex work items for refit - flex_group_root: flex group roots - flex_elemdataadr: flex element data addresses - flex_shell: flex shell data - flex_shelldataadr: flex shell data addresses - flex_radius: flex radius - flex_workadr: flex work item addresses for refit - flex_worknum: flex work item counts for refit + flex_bvh_id: per-flex BVH ids + flex_group_root: per-flex group roots (nworld x n_flex_bvh) flex_render_smooth: whether to render flex meshes smoothly + flex_dim_np:
flex dimension per flex (1D/2D/3D) bvh: scene BVH bvh_id: scene BVH id lower: lower bounds @@ -1922,10 +1920,8 @@ class RenderContext: ray: rays rgb_data: RGB data rgb_adr: RGB addresses - rgb_size: per-camera RGB buffer sizes depth_data: depth data depth_adr: depth addresses - depth_size: per-camera depth buffer sizes render_rgb: per-camera RGB render flags render_depth: per-camera depth render flags seg_data: segmentation data (per-pixel geom IDs) @@ -1955,21 +1951,15 @@ class RenderContext: hfield_registry: dict hfield_bvh_id: array("nhfield", wp.uint64) hfield_bounds_size: array("nhfield", wp.vec3) - flex_mesh: wp.Mesh + flex_mesh_registry: dict flex_rgba: array("nflex", wp.vec4) - flex_bvh_id: wp.uint64 - flex_face_point: array("*", wp.vec3) - flex_faceadr: array("nflex", int) - flex_nface: int - flex_nwork: int - flex_group_root: array("nworld", int) - flex_elemdataadr: array("nflex", int) - flex_shell: array("*", int) - flex_shelldataadr: array("nflex", int) - flex_radius: array("nflex", float) - flex_workadr: array("nflex", int) - flex_worknum: array("nflex", int) + flex_bvh_id: array("*", wp.uint64) + flex_group_root: array("nworld", "*", int) flex_render_smooth: bool + bvh_nflexgeom: int + flex_dim_np: array("nflex", int) + flex_geom_flexid: array("*", int) + flex_geom_edgeid: array("*", int) bvh: wp.Bvh bvh_id: wp.uint64 lower: array("*", wp.vec3) diff --git a/mujoco_warp/test_data/flex/floppy.xml b/mujoco_warp/test_data/flex/floppy.xml index dfdb973dc..fa5fa6d63 100644 --- a/mujoco_warp/test_data/flex/floppy.xml +++ b/mujoco_warp/test_data/flex/floppy.xml @@ -27,6 +27,7 @@ + diff --git a/mujoco_warp/test_data/flex/multiflex.xml b/mujoco_warp/test_data/flex/multiflex.xml new file mode 100644 index 000000000..bd7423d6c --- /dev/null +++ b/mujoco_warp/test_data/flex/multiflex.xml @@ -0,0 +1,42 @@ + + diff --git a/mujoco_warp/test_data/flex/rope.xml b/mujoco_warp/test_data/flex/rope.xml new file mode 100644 index 000000000..8f07a1611 --- /dev/null +++ 
b/mujoco_warp/test_data/flex/rope.xml @@ -0,0 +1,31 @@ + + diff --git a/mujoco_warp/testspeed.py b/mujoco_warp/testspeed.py index bae97bf99..67d98a12c 100644 --- a/mujoco_warp/testspeed.py +++ b/mujoco_warp/testspeed.py @@ -18,7 +18,7 @@ Usage: mjwarp-testspeed [flags] Example: - mjwarp-testspeed benchmark/humanoid/humanoid.xml --nworld 4096 -o "opt.solver=cg" + mjwarp-testspeed benchmarks/humanoid/humanoid.xml --nworld 4096 -o "opt.solver=cg" """ import dataclasses diff --git a/mujoco_warp/viewer.py b/mujoco_warp/viewer.py index 28797ac84..5659ad6b8 100644 --- a/mujoco_warp/viewer.py +++ b/mujoco_warp/viewer.py @@ -18,7 +18,7 @@ Usage: mjwarp-viewer [flags] Example: - mjwarp-viewer benchmark/humanoid/humanoid.xml -o "opt.solver=cg" + mjwarp-viewer benchmarks/humanoid/humanoid.xml -o "opt.solver=cg" """ import copy