diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index 3959a99d7..b957d115a 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -22,11 +22,24 @@ from mujoco_warp._src.types import ContactType from mujoco_warp._src.types import vec5 from mujoco_warp._src.types import vec11 +from mujoco_warp._src.warp_util import cache_kernel from mujoco_warp._src.warp_util import event_scope wp.set_module_options({"enable_backward": False}) +def _reinterpret(arr, dtype, shape): + """Reinterpret array memory as a different dtype/shape (zero-copy view).""" + # This allows a better memory access pattern for certain usecases + return wp.array( + ptr=arr.ptr, + dtype=dtype, + shape=shape, + device=arr.device, + copy=False, + ) + + @wp.kernel def _zero_constraint_counts( # Data out: @@ -121,10 +134,8 @@ def _efc_equality_connect( nv: int, nsite: int, opt_timestep: wp.array(dtype=float), - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), body_invweight0: wp.array2d(dtype=wp.vec2), - dof_bodyid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), @@ -132,6 +143,7 @@ def _efc_equality_connect( eq_solref: wp.array2d(dtype=wp.vec2), eq_solimp: wp.array2d(dtype=vec5), eq_data: wp.array2d(dtype=vec11), + dof_affects_body: wp.array2d(dtype=int), eq_connect_adr: wp.array(dtype=int), # Data in: qvel_in: wp.array2d(dtype=float), @@ -196,9 +208,8 @@ def _efc_equality_connect( Jqvel = wp.vec3f(0.0, 0.0, 0.0) for dofid in range(nv): # TODO: parallelize jacp1, _ = support.jac( - body_parentid, body_rootid, - dof_bodyid, + dof_affects_body, subtree_com_in, cdof_in, pos1, @@ -207,9 +218,8 @@ def _efc_equality_connect( worldid, ) jacp2, _ = support.jac( - body_parentid, body_rootid, - dof_bodyid, + dof_affects_body, subtree_com_in, cdof_in, pos2, @@ -725,10 +735,8 @@ def _efc_equality_weld( nv: int, nsite: int, opt_timestep: wp.array(dtype=float), - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), body_invweight0: wp.array2d(dtype=wp.vec2), - dof_bodyid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), site_quat: wp.array2d(dtype=wp.quat), eq_obj1id: wp.array(dtype=int), @@ -737,6 +745,7 @@ def _efc_equality_weld( eq_solref: wp.array2d(dtype=wp.vec2), eq_solimp: wp.array2d(dtype=vec5), eq_data: wp.array2d(dtype=vec11), + dof_affects_body: wp.array2d(dtype=int), eq_wld_adr: wp.array(dtype=int), # Data in: qvel_in: wp.array2d(dtype=float), @@ -812,9 +821,8 @@ def _efc_equality_weld( for dofid in range(nv): # TODO: parallelize jacp1, jacr1 = support.jac( - body_parentid, body_rootid, - dof_bodyid, + dof_affects_body, subtree_com_in, cdof_in, pos1, @@ -823,9 +831,8 @@ def _efc_equality_weld( worldid, ) jacp2, jacr2 = support.jac( - body_parentid, body_rootid, - dof_bodyid, + dof_affects_body, subtree_com_in, cdof_in, pos2, @@ -1210,373 +1217,297 @@ def _efc_limit_tendon( ) -@wp.kernel -def _efc_contact_pyramidal( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_impratio_invsqrt: wp.array(dtype=float), - body_parentid: wp.array(dtype=int), - body_rootid: wp.array(dtype=int), - body_weldid: wp.array(dtype=int), - body_dofnum: wp.array(dtype=int), - body_dofadr: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=wp.vec2), - dof_bodyid: wp.array(dtype=int), - dof_parentid: wp.array(dtype=int), - geom_bodyid: wp.array(dtype=int), - # Data in: - qvel_in: wp.array2d(dtype=float), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - njmax_in: int, - nacon_in: wp.array(dtype=int), - # In: - refsafe_in: int, - dist_in: wp.array(dtype=float), - condim_in: wp.array(dtype=int), - includemargin_in: wp.array(dtype=float), - worldid_in: wp.array(dtype=int), - geom_in: wp.array(dtype=wp.vec2i), - pos_in: wp.array(dtype=wp.vec3), - frame_in: wp.array(dtype=wp.mat33), - friction_in: wp.array(dtype=vec5), - solref_in: wp.array(dtype=wp.vec2), - solimp_in: wp.array(dtype=vec5), - type_in: wp.array(dtype=int), - # Data out: - nefc_out: wp.array(dtype=int), - contact_efc_address_out: wp.array2d(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), -): - conid, dimid = wp.tid() +@cache_kernel +def _efc_contact_init(cone_type: types.ConeType): + IS_ELLIPTIC = cone_type == types.ConeType.ELLIPTIC + + @wp.kernel(module="unique", enable_backward=False) + def kernel( + # Data in: + njmax_in: int, + nacon_in: wp.array(dtype=int), + # In: + dist_in: wp.array(dtype=float), + condim_in: wp.array(dtype=int), + includemargin_in: wp.array(dtype=float), + worldid_in: wp.array(dtype=int), + type_in: wp.array(dtype=int), + # Data out: + nefc_out: wp.array(dtype=int), + contact_efc_address_out: wp.array2d(dtype=int), + efc_conid_out: wp.array2d(dtype=int), + ): + conid = wp.tid() + + if conid >= nacon_in[0]: + return - if conid >= nacon_in[0]: - return + if not type_in[conid] & ContactType.CONSTRAINT: + return - if not type_in[conid] & ContactType.CONSTRAINT: - return + condim = condim_in[conid] - condim = condim_in[conid] + if wp.static(IS_ELLIPTIC): + nrows = condim + else: + nrows = 1 + if condim > 1: + nrows = 2 * (condim - 1) - if condim == 1 and dimid > 0: - return - elif condim > 1 and dimid >= 2 * (condim - 1): - return + includemargin = includemargin_in[conid] + pos = dist_in[conid] - includemargin + active = pos < 0 - includemargin = includemargin_in[conid] - pos = dist_in[conid] - includemargin - active = pos < 0 + if not active: + for dimid in range(nrows): + contact_efc_address_out[conid, dimid] = -1 + return - if active: worldid = worldid_in[conid] - efcid = wp.atomic_add(nefc_out, worldid, 1) - if efcid >= njmax_in: - contact_efc_address_out[conid, dimid] = -1 + base_efcid = wp.atomic_add(nefc_out, worldid, nrows) + + if base_efcid + nrows > njmax_in: + for dimid in range(nrows): + contact_efc_address_out[conid, dimid] = -1 return - timestep = opt_timestep[worldid % opt_timestep.shape[0]] - impratio_invsqrt = opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] - contact_efc_address_out[conid, dimid] = efcid + for dimid in range(nrows): + efcid = base_efcid + dimid + contact_efc_address_out[conid, dimid] = efcid + efc_conid_out[worldid, efcid] = conid + + return kernel + + +@cache_kernel +def _efc_contact_jac_tiled(tile_size: int, cone_type: types.ConeType): + TILE_SIZE = tile_size + IS_ELLIPTIC = cone_type == types.ConeType.ELLIPTIC + + @wp.kernel(module="unique", enable_backward=False) + def kernel( + # Model: + body_rootid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + dof_affects_body: wp.array2d(dtype=int), + # Data in: + ne_in: wp.array(dtype=int), + nf_in: wp.array(dtype=int), + nl_in: wp.array(dtype=int), + nefc_in: wp.array(dtype=int), + qvel_in: wp.array2d(dtype=float), + subtree_com_in: wp.array2d(dtype=wp.vec3), + cdof_in: wp.array2d(dtype=wp.spatial_vector), + contact_efc_address_in: wp.array2d(dtype=int), + efc_conid_in: wp.array2d(dtype=int), + njmax_in: int, + # In: + nv_padded: int, + condim_in: wp.array(dtype=int), + geom_in: wp.array(dtype=wp.vec2i), + pos_in: wp.array(dtype=wp.vec3), + frame_in: wp.array2d(dtype=wp.vec3), + friction_in: wp.array2d(dtype=float), + # Data out: + efc_J_out: wp.array3d(dtype=float), + efc_Jqvel_out: wp.array2d(dtype=float), + ): + worldid, dof_block_id, tid = wp.tid() + + dof_start = dof_block_id * TILE_SIZE + if dof_start >= nv_padded: + return - geom = geom_in[conid] - body1 = geom_bodyid[geom[0]] - body2 = geom_bodyid[geom[1]] + cdof_tile = wp.tile_load(cdof_in[worldid], shape=TILE_SIZE, offset=dof_start, bounds_check=True) + qvel_tile = wp.tile_load(qvel_in[worldid], shape=TILE_SIZE, offset=dof_start, bounds_check=True) - con_pos = pos_in[conid] - frame = frame_in[conid] + efcid_start = ne_in[worldid] + nf_in[worldid] + nl_in[worldid] + efcid_end = wp.min(nefc_in[worldid], njmax_in) - # pyramidal has common invweight across all edges - body_invweight0_id = worldid % body_invweight0.shape[0] - invweight = body_invweight0[body_invweight0_id, body1][0] + body_invweight0[body_invweight0_id, body2][0] + prev_conid = int(-1) + condim = int(0) - if condim > 1: - dimid2 = dimid / 2 + 1 + for efcid in range(efcid_start, efcid_end): + conid = efc_conid_in[worldid, efcid] - friction = friction_in[conid] - fri0 = friction[0] - frii = friction[dimid2 - 1] - invweight = invweight + fri0 * fri0 * invweight - invweight = invweight * 2.0 * fri0 * fri0 * impratio_invsqrt * impratio_invsqrt + if conid != prev_conid: + prev_conid = conid + condim = condim_in[conid] - Jqvel = float(0.0) + geom = geom_in[conid] + body1 = geom_bodyid[geom[0]] + body2 = geom_bodyid[geom[1]] - # skip fixed bodies - body1 = body_weldid[body1] - body2 = body_weldid[body2] - - da1 = body_dofadr[body1] + body_dofnum[body1] - 1 - da2 = body_dofadr[body2] + body_dofnum[body2] - 1 - da = wp.max(da1, da2) - - for dofid in range(nv - 1, -1, -1): - if dofid == da: - # TODO(team): contact_jacobian - jac1p, jac1r = support.jac( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - con_pos, - body1, - dofid, - worldid, - ) - jac2p, jac2r = support.jac( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - con_pos, - body2, - dofid, - worldid, - ) + con_pos = pos_in[conid] + offset1 = con_pos - subtree_com_in[worldid, body_rootid[body1]] + offset2 = con_pos - subtree_com_in[worldid, body_rootid[body2]] - J = float(0.0) - Ji = float(0.0) - if condim > 1: - dimid2 = dimid / 2 + 1 + affects1_tile = wp.tile_load(dof_affects_body[body1], shape=TILE_SIZE, offset=dof_start, bounds_check=False) + affects2_tile = wp.tile_load(dof_affects_body[body2], shape=TILE_SIZE, offset=dof_start, bounds_check=False) - for xyz in range(3): - jacp_dif = jac2p[xyz] - jac1p[xyz] - J += frame[0, xyz] * jacp_dif + jacp1_tile = wp.tile_map(support._compute_jacp, cdof_tile, offset1, affects1_tile) + jacp2_tile = wp.tile_map(support._compute_jacp, cdof_tile, offset2, affects2_tile) + jacp_dif_tile = wp.tile_map(wp.sub, jacp2_tile, jacp1_tile) - if condim > 1: - if dimid2 < 3: - Ji += frame[dimid2, xyz] * jacp_dif - else: - Ji += frame[dimid2 - 3, xyz] * (jac2r[xyz] - jac1r[xyz]) + jacr1_tile = wp.tile_map(support._compute_jacr, cdof_tile, affects1_tile) + jacr2_tile = wp.tile_map(support._compute_jacr, cdof_tile, affects2_tile) + jacr_dif_tile = wp.tile_map(wp.sub, jacr2_tile, jacr1_tile) - if condim > 1: - if dimid % 2 == 0: - J += Ji * frii - else: - J -= Ji * frii + base_efcid = contact_efc_address_in[conid, 0] - efc_J_out[worldid, efcid, dofid] = J - Jqvel += J * qvel_in[worldid, dofid] + if not wp.static(IS_ELLIPTIC): + frame_0 = frame_in[conid, 0] + Ji_0p_tile = wp.tile_map(wp.dot, jacp_dif_tile, frame_0) - # Advance tree pointers and recompute da for next iteration - if da1 == da: - da1 = dof_parentid[da1] - if da2 == da: - da2 = dof_parentid[da2] - da = wp.max(da1, da2) + if condim > 1: + Ji_0r_tile = wp.tile_map(wp.dot, jacr_dif_tile, frame_0) + frame_1 = frame_in[conid, 1] + Ji_1p_tile = wp.tile_map(wp.dot, jacp_dif_tile, frame_1) + Ji_1r_tile = wp.tile_map(wp.dot, jacr_dif_tile, frame_1) + frame_2 = frame_in[conid, 2] + Ji_2p_tile = wp.tile_map(wp.dot, jacp_dif_tile, frame_2) + Ji_2r_tile = wp.tile_map(wp.dot, jacr_dif_tile, frame_2) + + if wp.static(IS_ELLIPTIC): + dimid = efcid - base_efcid + if dimid < 3: + frame_idx = dimid + else: + frame_idx = dimid - 3 + + frame_row = frame_in[conid, frame_idx] + + if dimid < 3: + J_tile = wp.tile_map(wp.dot, jacp_dif_tile, frame_row) + else: + J_tile = wp.tile_map(wp.dot, jacr_dif_tile, frame_row) else: - efc_J_out[worldid, efcid, dofid] = 0.0 - - if condim == 1: - efc_type = ConstraintType.CONTACT_FRICTIONLESS - else: - efc_type = ConstraintType.CONTACT_PYRAMIDAL - - _update_efc_row( - worldid, - timestep, - refsafe_in, - efcid, - pos, - pos, - invweight, - solref_in[conid], - solimp_in[conid], - includemargin, - Jqvel, - 0.0, - efc_type, - conid, - efc_type_out, - efc_id_out, - efc_pos_out, - efc_margin_out, - efc_D_out, - efc_vel_out, - efc_aref_out, - efc_frictionloss_out, - ) - - -@wp.kernel -def _efc_contact_elliptic( - # Model: - nv: int, - opt_timestep: wp.array(dtype=float), - opt_impratio_invsqrt: wp.array(dtype=float), - body_parentid: wp.array(dtype=int), - body_rootid: wp.array(dtype=int), - body_weldid: wp.array(dtype=int), - body_dofnum: wp.array(dtype=int), - body_dofadr: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=wp.vec2), - dof_bodyid: wp.array(dtype=int), - dof_parentid: wp.array(dtype=int), - geom_bodyid: wp.array(dtype=int), - # Data in: - qvel_in: wp.array2d(dtype=float), - subtree_com_in: wp.array2d(dtype=wp.vec3), - cdof_in: wp.array2d(dtype=wp.spatial_vector), - njmax_in: int, - nacon_in: wp.array(dtype=int), - # In: - refsafe_in: int, - dist_in: wp.array(dtype=float), - condim_in: wp.array(dtype=int), - includemargin_in: wp.array(dtype=float), - worldid_in: wp.array(dtype=int), - geom_in: wp.array(dtype=wp.vec2i), - pos_in: wp.array(dtype=wp.vec3), - frame_in: wp.array(dtype=wp.mat33), - friction_in: wp.array(dtype=vec5), - solref_in: wp.array(dtype=wp.vec2), - solreffriction_in: wp.array(dtype=wp.vec2), - solimp_in: wp.array(dtype=vec5), - type_in: wp.array(dtype=int), - # Data out: - nefc_out: wp.array(dtype=int), - contact_efc_address_out: wp.array2d(dtype=int), - efc_type_out: wp.array2d(dtype=int), - efc_id_out: wp.array2d(dtype=int), - efc_J_out: wp.array3d(dtype=float), - efc_pos_out: wp.array2d(dtype=float), - efc_margin_out: wp.array2d(dtype=float), - efc_D_out: wp.array2d(dtype=float), - efc_vel_out: wp.array2d(dtype=float), - efc_aref_out: wp.array2d(dtype=float), - efc_frictionloss_out: wp.array2d(dtype=float), -): - conid, dimid = wp.tid() - - if conid >= nacon_in[0]: - return - - if not type_in[conid] & ContactType.CONSTRAINT: - return - - condim = condim_in[conid] - - if dimid > condim - 1: - return - - includemargin = includemargin_in[conid] - pos = dist_in[conid] - includemargin - active = pos < 0.0 - - if active: - worldid = worldid_in[conid] - - efcid = wp.atomic_add(nefc_out, worldid, 1) - if efcid >= njmax_in: - contact_efc_address_out[conid, dimid] = -1 + J_tile = Ji_0p_tile + if condim > 1: + dimid = efcid - base_efcid + dimid2 = dimid / 2 + 1 + frii = friction_in[conid, dimid2 - 1] + frii_sign = frii * (1.0 - 2.0 * float(dimid & 1)) + + if dimid2 == 1: + J_tile = wp.tile_map(wp.add, J_tile, wp.tile_map(wp.mul, Ji_1p_tile, frii_sign)) + elif dimid2 == 2: + J_tile = wp.tile_map(wp.add, J_tile, wp.tile_map(wp.mul, Ji_2p_tile, frii_sign)) + elif dimid2 == 3: + J_tile = wp.tile_map(wp.add, J_tile, wp.tile_map(wp.mul, Ji_0r_tile, frii_sign)) + elif dimid2 == 4: + J_tile = wp.tile_map(wp.add, J_tile, wp.tile_map(wp.mul, Ji_1r_tile, frii_sign)) + else: + J_tile = wp.tile_map(wp.add, J_tile, wp.tile_map(wp.mul, Ji_2r_tile, frii_sign)) + + wp.tile_store(efc_J_out[worldid, efcid], J_tile, offset=dof_start, bounds_check=True) + + Jqvel_tile = wp.tile_map(wp.mul, J_tile, qvel_tile) + Jqvel_tile = wp.tile_reduce(wp.add, Jqvel_tile) + if tid == 0: + wp.atomic_add(efc_Jqvel_out, worldid, efcid, Jqvel_tile[0]) + + return kernel + + +@cache_kernel +def _efc_contact_update(cone_type: types.ConeType): + IS_ELLIPTIC = cone_type == types.ConeType.ELLIPTIC + + @wp.kernel(module="unique", enable_backward=False) + def kernel( + # Model: + opt_timestep: wp.array(dtype=float), + opt_impratio_invsqrt: wp.array(dtype=float), + body_invweight0: wp.array2d(dtype=wp.vec2), + geom_bodyid: wp.array(dtype=int), + # Data in: + ne_in: wp.array(dtype=int), + nf_in: wp.array(dtype=int), + nl_in: wp.array(dtype=int), + nefc_in: wp.array(dtype=int), + contact_efc_address_in: wp.array2d(dtype=int), + efc_conid_in: wp.array2d(dtype=int), + efc_Jqvel_in: wp.array2d(dtype=float), + # In: + refsafe_in: int, + condim_in: wp.array(dtype=int), + includemargin_in: wp.array(dtype=float), + dist_in: wp.array(dtype=float), + geom_in: wp.array(dtype=wp.vec2i), + friction_in: wp.array2d(dtype=float), + solref_in: wp.array(dtype=wp.vec2), + solreffriction_in: wp.array(dtype=wp.vec2), + solimp_in: wp.array(dtype=vec5), + # Data out: + efc_type_out: wp.array2d(dtype=int), + efc_id_out: wp.array2d(dtype=int), + efc_pos_out: wp.array2d(dtype=float), + efc_margin_out: wp.array2d(dtype=float), + efc_D_out: wp.array2d(dtype=float), + efc_vel_out: wp.array2d(dtype=float), + efc_aref_out: wp.array2d(dtype=float), + efc_frictionloss_out: wp.array2d(dtype=float), + ): + worldid, contact_idx = wp.tid() + + efcid_start = ne_in[worldid] + nf_in[worldid] + nl_in[worldid] + efcid_end = nefc_in[worldid] + efcid = efcid_start + contact_idx + + if efcid >= efcid_end: return - timestep = opt_timestep[worldid % opt_timestep.shape[0]] - impratio_invsqrt = opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] - contact_efc_address_out[conid, dimid] = efcid - - con_pos = pos_in[conid] - frame = frame_in[conid] + conid = efc_conid_in[worldid, efcid] + condim = condim_in[conid] geom = geom_in[conid] body1 = geom_bodyid[geom[0]] body2 = geom_bodyid[geom[1]] - Jqvel = float(0.0) - - # skip fixed bodies - body1 = body_weldid[body1] - body2 = body_weldid[body2] - - da1 = body_dofadr[body1] + body_dofnum[body1] - 1 - da2 = body_dofadr[body2] + body_dofnum[body2] - 1 - da = wp.max(da1, da2) - - for dofid in range(nv - 1, -1, -1): - if dofid == da: - # TODO(team): contact jacobian - jac1p, jac1r = support.jac( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - con_pos, - body1, - dofid, - worldid, - ) - jac2p, jac2r = support.jac( - body_parentid, - body_rootid, - dof_bodyid, - subtree_com_in, - cdof_in, - con_pos, - body2, - dofid, - worldid, - ) - - J = float(0.0) - for xyz in range(3): - if dimid < 3: - jac_dif = jac2p[xyz] - jac1p[xyz] - J += frame[dimid, xyz] * jac_dif - else: - jac_dif = jac2r[xyz] - jac1r[xyz] - J += frame[dimid - 3, xyz] * jac_dif - - efc_J_out[worldid, efcid, dofid] = J - Jqvel += J * qvel_in[worldid, dofid] - - # Advance tree pointers and recompute da for next iteration - if da1 == da: - da1 = dof_parentid[da1] - if da2 == da: - da2 = dof_parentid[da2] - da = wp.max(da1, da2) - else: - efc_J_out[worldid, efcid, dofid] = 0.0 + Jqvel = efc_Jqvel_in[worldid, efcid] + timestep = opt_timestep[worldid % opt_timestep.shape[0]] + impratio_invsqrt = opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]] body_invweight0_id = worldid % body_invweight0.shape[0] invweight = body_invweight0[body_invweight0_id, body1][0] + body_invweight0[body_invweight0_id, body2][0] + includemargin = includemargin_in[conid] + pos = dist_in[conid] - includemargin + ref = solref_in[conid] pos_aref = pos - if dimid > 0: - solreffriction = solreffriction_in[conid] + if condim == 1: + efc_type = ConstraintType.CONTACT_FRICTIONLESS + else: + if wp.static(IS_ELLIPTIC): + efc_type = ConstraintType.CONTACT_ELLIPTIC - # non-normal directions use solreffriction (if non-zero) - if solreffriction[0] or solreffriction[1]: - ref = solreffriction + base_efcid = contact_efc_address_in[conid, 0] + dimid = efcid - base_efcid - invweight = invweight * impratio_invsqrt * impratio_invsqrt - friction = friction_in[conid] + if dimid > 0: + solreffriction = solreffriction_in[conid] - if dimid > 1: - fri0 = friction[0] - frii = friction[dimid - 1] - fri = fri0 * fri0 / (frii * frii) - invweight *= fri + if solreffriction[0] != 0.0 or solreffriction[1] != 0.0: + ref = solreffriction - pos_aref = 0.0 + invweight = invweight * impratio_invsqrt * impratio_invsqrt - if condim == 1: - efc_type = ConstraintType.CONTACT_FRICTIONLESS - else: - efc_type = ConstraintType.CONTACT_ELLIPTIC + if dimid > 1: + fri0 = friction_in[conid, 0] + frii = friction_in[conid, dimid - 1] + fri = fri0 * fri0 / (frii * frii) + invweight = invweight * fri + + pos_aref = 0.0 + else: + efc_type = ConstraintType.CONTACT_PYRAMIDAL + fri0 = friction_in[conid, 0] + invweight = invweight + fri0 * fri0 * invweight + invweight = invweight * 2.0 * fri0 * fri0 * impratio_invsqrt * impratio_invsqrt _update_efc_row( worldid, @@ -1603,6 +1534,8 @@ def _efc_contact_elliptic( efc_frictionloss_out, ) + return kernel + @event_scope def make_constraint(m: types.Model, d: types.Data): @@ -1624,10 +1557,8 @@ def make_constraint(m: types.Model, d: types.Data): m.nv, m.nsite, m.opt.timestep, - m.body_parentid, m.body_rootid, m.body_invweight0, - m.dof_bodyid, m.site_bodyid, m.eq_obj1id, m.eq_obj2id, @@ -1635,6 +1566,7 @@ def make_constraint(m: types.Model, d: types.Data): m.eq_solref, m.eq_solimp, m.eq_data, + m.dof_affects_body, m.eq_connect_adr, d.qvel, d.eq_active, @@ -1667,10 +1599,8 @@ def make_constraint(m: types.Model, d: types.Data): m.nv, m.nsite, m.opt.timestep, - m.body_parentid, m.body_rootid, m.body_invweight0, - m.dof_bodyid, m.site_bodyid, m.site_quat, m.eq_obj1id, @@ -1679,6 +1609,7 @@ def make_constraint(m: types.Model, d: types.Data): m.eq_solref, m.eq_solimp, m.eq_data, + m.dof_affects_body, m.eq_wld_adr, d.qvel, d.eq_active, @@ -1982,102 +1913,100 @@ def make_constraint(m: types.Model, d: types.Data): # contact if not (m.opt.disableflags & types.DisableBit.CONTACT): - if m.opt.cone == types.ConeType.PYRAMIDAL: - wp.launch( - _efc_contact_pyramidal, - dim=(d.naconmax, m.nmaxpyramid), + # Reinterpret frame and friction arrays for optimized memory access + contact_frame_2d = _reinterpret(d.contact.frame, wp.vec3, (d.naconmax, 3)) + contact_friction_2d = _reinterpret(d.contact.friction, float, (d.naconmax, 5)) + + wp.launch( + _efc_contact_init(m.opt.cone), + dim=(d.naconmax,), + inputs=[ + d.njmax, + d.nacon, + d.contact.dist, + d.contact.dim, + d.contact.includemargin, + d.contact.worldid, + d.contact.type, + ], + outputs=[ + d.nefc, + d.contact.efc_address, + d.efc.conid, + ], + ) + + if m.nv_pad > 0 and m.nv > 0: + tile_size = m.block_dim.contact_jac_tiled + n_dof_blocks = (m.nv_pad + tile_size - 1) // tile_size + + # Zero Jqvel since we use atomic_add to accumulate across DOF blocks + d.efc.Jqvel.zero_() + + wp.launch_tiled( + _efc_contact_jac_tiled(tile_size, m.opt.cone), + dim=(d.nworld, n_dof_blocks), inputs=[ - m.nv, - m.opt.timestep, - m.opt.impratio_invsqrt, - m.body_parentid, m.body_rootid, - m.body_weldid, - m.body_dofnum, - m.body_dofadr, - m.body_invweight0, - m.dof_bodyid, - m.dof_parentid, m.geom_bodyid, - d.qvel, - d.subtree_com, - d.cdof, - d.njmax, - d.nacon, - refsafe, - d.contact.dist, - d.contact.dim, - d.contact.includemargin, - d.contact.worldid, - d.contact.geom, - d.contact.pos, - d.contact.frame, - d.contact.friction, - d.contact.solref, - d.contact.solimp, - d.contact.type, - ], - outputs=[ + m.dof_affects_body, + d.ne, + d.nf, + d.nl, d.nefc, - d.contact.efc_address, - d.efc.type, - d.efc.id, - d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, - ], - ) - elif m.opt.cone == types.ConeType.ELLIPTIC: - wp.launch( - _efc_contact_elliptic, - dim=(d.naconmax, m.nmaxcondim), - inputs=[ - m.nv, - m.opt.timestep, - m.opt.impratio_invsqrt, - m.body_parentid, - m.body_rootid, - m.body_weldid, - m.body_dofnum, - m.body_dofadr, - m.body_invweight0, - m.dof_bodyid, - m.dof_parentid, - m.geom_bodyid, d.qvel, d.subtree_com, d.cdof, + d.contact.efc_address, + d.efc.conid, d.njmax, - d.nacon, - refsafe, - d.contact.dist, + m.nv_pad, d.contact.dim, - d.contact.includemargin, - d.contact.worldid, d.contact.geom, d.contact.pos, - d.contact.frame, - d.contact.friction, - d.contact.solref, - d.contact.solreffriction, - d.contact.solimp, - d.contact.type, + contact_frame_2d, + contact_friction_2d, ], outputs=[ - d.nefc, - d.contact.efc_address, - d.efc.type, - d.efc.id, d.efc.J, - d.efc.pos, - d.efc.margin, - d.efc.D, - d.efc.vel, - d.efc.aref, - d.efc.frictionloss, + d.efc.Jqvel, ], + block_dim=tile_size, ) + + wp.launch( + _efc_contact_update(m.opt.cone), + dim=(d.nworld, d.njmax), + inputs=[ + m.opt.timestep, + m.opt.impratio_invsqrt, + m.body_invweight0, + m.geom_bodyid, + d.ne, + d.nf, + d.nl, + d.nefc, + d.contact.efc_address, + d.efc.conid, + d.efc.Jqvel, + refsafe, + d.contact.dim, + d.contact.includemargin, + d.contact.dist, + d.contact.geom, + contact_friction_2d, + d.contact.solref, + d.contact.solreffriction, + d.contact.solimp, + ], + outputs=[ + d.efc.type, + d.efc.id, + d.efc.pos, + d.efc.margin, + d.efc.D, + d.efc.vel, + d.efc.aref, + d.efc.frictionloss, + ], + ) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 4f0901d0a..d7bc4beec 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -281,6 +281,20 @@ def _check_friction(name: str, id_: int, condim: int, friction, checks): m.jnt_limited_ball_adr = np.nonzero(mjm.jnt_limited & (mjm.jnt_type == mujoco.mjtJoint.mjJNT_BALL))[0] m.dof_tri_row, m.dof_tri_col = np.tril_indices(mjm.nv) + dof_affects_body = np.zeros((mjm.nbody, m.nv_pad), dtype=np.int32) + for bodyid in range(mjm.nbody): + for dofid in range(mjm.nv): + dof_bodyid_ = mjm.dof_bodyid[dofid] + in_tree = int(dof_bodyid_ == 0) + parentid = bodyid + while parentid != 0: + if parentid == dof_bodyid_: + in_tree = 1 + break + parentid = mjm.body_parentid[parentid] + dof_affects_body[bodyid, dofid] = in_tree + m.dof_affects_body = dof_affects_body + # precalculated geom pairs filterparent = not (mjm.opt.disableflags & types.DisableBit.FILTERPARENT) diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index 26de5bfe8..c6497438b 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -238,11 +238,10 @@ def _spring_damper_tendon_passive( def _gravity_force( # Model: opt_gravity: wp.array(dtype=wp.vec3), - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), body_mass: wp.array2d(dtype=float), body_gravcomp: wp.array2d(dtype=float), - dof_bodyid: wp.array(dtype=int), + dof_affects_body: wp.array2d(dtype=int), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), subtree_com_in: wp.array2d(dtype=wp.vec3), @@ -258,7 +257,7 @@ def _gravity_force( if gravcomp: force = -gravity * body_mass[worldid % body_mass.shape[0], bodyid] * gravcomp pos = xipos_in[worldid, bodyid] - jac, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos, bodyid, dofid, worldid) + jac, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos, bodyid, dofid, worldid) wp.atomic_add(qfrc_gravcomp_out[worldid], dofid, wp.dot(jac, force)) @@ -814,11 +813,10 @@ def passive(m: Model, d: Data): dim=(d.nworld, m.nbody - 1, m.nv), inputs=[ m.opt.gravity, - m.body_parentid, m.body_rootid, m.body_mass, m.body_gravcomp, - m.dof_bodyid, + m.dof_affects_body, d.xipos, d.subtree_com, d.cdof, diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 010283b1e..b961d14be 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -240,10 +240,8 @@ def _flex_vertices( def _flex_edges( # Model: nflex: int, - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), body_dofadr: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), flex_vertadr: wp.array(dtype=int), flex_edgeadr: wp.array(dtype=int), flex_edgenum: wp.array(dtype=int), @@ -251,6 +249,7 @@ def _flex_edges( flex_edge: wp.array(dtype=wp.vec2i), flexedge_J_rowadr: wp.array(dtype=int), flexedge_J_colind: wp.array(dtype=int), + dof_affects_body: wp.array2d(dtype=int), # Data in: qvel_in: wp.array2d(dtype=float), subtree_com_in: wp.array2d(dtype=wp.vec3), @@ -306,33 +305,33 @@ def _flex_edges( # TODO(team): jacdif - jacp1, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos1, b1, dofi0, worldid) - jacp2, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos2, b2, dofi0, worldid) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos1, b1, dofi0, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos2, b2, dofi0, worldid) jacdif = jacp2 - jacp1 Ji0 = wp.dot(jacdif, edge) - jacp1, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos1, b1, dofi1, worldid) - jacp2, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos2, b2, dofi1, worldid) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos1, b1, dofi1, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos2, b2, dofi1, worldid) jacdif = jacp2 - jacp1 Ji1 = wp.dot(jacdif, edge) - jacp1, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos1, b1, dofi2, worldid) - jacp2, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos2, b2, dofi2, worldid) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos1, b1, dofi2, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos2, b2, dofi2, worldid) jacdif = jacp2 - jacp1 Ji2 = wp.dot(jacdif, edge) - jacp1, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos1, b1, dofj0, worldid) - jacp2, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos2, b2, dofj0, worldid) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos1, b1, dofj0, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos2, b2, dofj0, worldid) jacdif = jacp2 - jacp1 Jj0 = wp.dot(jacdif, edge) - jacp1, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos1, b1, dofj1, worldid) - jacp2, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos2, b2, dofj1, worldid) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos1, b1, dofj1, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos2, b2, dofj1, worldid) jacdif = jacp2 - jacp1 Jj1 = wp.dot(jacdif, edge) - jacp1, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos1, b1, dofj2, worldid) - jacp2, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos2, b2, dofj2, worldid) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos1, b1, dofj2, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pos2, b2, dofj2, worldid) jacdif = jacp2 - jacp1 Jj2 = wp.dot(jacdif, edge) @@ -413,10 +412,8 @@ def flex(m: Model, d: Data): dim=(d.nworld, m.nflexedge), inputs=[ m.nflex, - m.body_parentid, m.body_rootid, m.body_dofadr, - m.dof_bodyid, m.flex_vertadr, m.flex_edgeadr, m.flex_edgenum, @@ -424,6 +421,7 @@ def flex(m: Model, d: Data): m.flex_edge, m.flexedge_J_rowadr, m.flexedge_J_colind, + m.dof_affects_body, d.qvel, d.subtree_com, d.cdof, @@ -1533,7 +1531,6 @@ def rne_postconstraint(m: Model, d: Data): def _tendon_dot( # Model: nv: int, - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), jnt_type: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), @@ -1546,6 +1543,7 @@ def _tendon_dot( wrap_type: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), wrap_prm: wp.array(dtype=float), + dof_affects_body: wp.array2d(dtype=int), # Data in: site_xpos_in: wp.array2d(dtype=wp.vec3), subtree_com_in: wp.array2d(dtype=wp.vec3), @@ -1632,12 +1630,12 @@ def _tendon_dot( # TODO(team): parallelize? for i in range(nv): jac1, _ = support.jac_dot( - body_parentid, body_rootid, jnt_type, jnt_dofadr, dof_bodyid, dof_jntid, + dof_affects_body, subtree_com_in, cdof_in, cvel_in, @@ -1648,12 +1646,12 @@ def _tendon_dot( worldid, ) jac2, _ = support.jac_dot( - body_parentid, body_rootid, jnt_type, jnt_dofadr, dof_bodyid, dof_jntid, + dof_affects_body, subtree_com_in, cdof_in, cvel_in, @@ -1670,9 +1668,8 @@ def _tendon_dot( # get endpoint Jacobians, subtract jac1, _ = support.jac( - body_parentid, body_rootid, - dof_bodyid, + dof_affects_body, subtree_com_in, cdof_in, wpnt0, @@ -1681,9 +1678,8 @@ def _tendon_dot( worldid, ) jac2, _ = support.jac( - body_parentid, body_rootid, - dof_bodyid, + dof_affects_body, subtree_com_in, cdof_in, wpnt1, @@ -1766,7 +1762,6 @@ def tendon_bias(m: Model, d: Data, qfrc: wp.array2d(dtype=float)): dim=(d.nworld, m.ntendon), inputs=[ m.nv, - m.body_parentid, m.body_rootid, m.jnt_type, m.jnt_dofadr, @@ -1779,6 +1774,7 @@ def tendon_bias(m: Model, d: Data, qfrc: wp.array2d(dtype=float)): m.wrap_type, m.wrap_objid, m.wrap_prm, + m.dof_affects_body, d.site_xpos, d.subtree_com, d.cdof, @@ -1915,7 +1911,6 @@ def com_vel(m: Model, d: Data): def _transmission( # Model: nv: int, - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), body_weldid: wp.array(dtype=int), body_dofnum: wp.array(dtype=int), @@ -1923,7 +1918,6 @@ def _transmission( jnt_type: wp.array(dtype=int), jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), dof_parentid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), site_quat: wp.array2d(dtype=wp.quat), @@ -1935,6 +1929,7 @@ def _transmission( actuator_trnid: wp.array(dtype=wp.vec2i), actuator_gear: wp.array2d(dtype=wp.spatial_vector), actuator_cranklength: wp.array(dtype=float), + dof_affects_body: wp.array2d(dtype=int), # Data in: qpos_in: wp.array2d(dtype=float), xquat_in: wp.array2d(dtype=wp.quat), @@ -2032,15 +2027,13 @@ def _transmission( # get Jacobians of axis(jacA) and vec(jac) # mj_jacPointAxis jacp, jacr = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_xpos_idslider, site_bodyid[idslider], i, worldid + body_rootid, dof_affects_body, subtree_com_in, cdof_in, site_xpos_idslider, site_bodyid[idslider], i, worldid ) jacS = jacp jacA = wp.cross(jacr, axis) # mj_jacSite - jac, _ = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_xpos_id, site_bodyid[id], i, worldid - ) + jac, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, site_xpos_id, site_bodyid[id], i, worldid) jac -= jacS # apply the chain rule @@ -2092,9 +2085,8 @@ def _transmission( # TODO(team): parallelize for i in range(nv): jacp, jacr = support.jac( - body_parentid, body_rootid, - dof_bodyid, + dof_affects_body, subtree_com_in, cdof_in, site_xpos_in[worldid, siteid], @@ -2164,12 +2156,12 @@ def _transmission( # TODO(team): parallelize for i in range(nv): jacp, jacr = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_xpos, site_bodyid[siteid], i, worldid + body_rootid, dof_affects_body, subtree_com_in, cdof_in, site_xpos, site_bodyid[siteid], i, worldid ) # jacref: global Jacobian of reference site jacpref, jacrref = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, ref_xpos, site_bodyid[refid], i, worldid + body_rootid, dof_affects_body, subtree_com_in, cdof_in, ref_xpos, site_bodyid[refid], i, worldid ) jacpdif = jacp - jacpref @@ -2201,11 +2193,10 @@ def _transmission( def _transmission_body_moment( # Model: opt_cone: int, - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), actuator_trnid: wp.array(dtype=wp.vec2i), + dof_affects_body: wp.array2d(dtype=int), actuator_trntype_body_adr: wp.array(dtype=int), # Data in: subtree_com_in: wp.array2d(dtype=wp.vec3), @@ -2281,8 +2272,8 @@ def _transmission_body_moment( normal = wp.vec3(contact_frame[0, 0], contact_frame[0, 1], contact_frame[0, 2]) # get Jacobian difference - jacp1, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, contact_pos, b1, dofid, worldid) - jacp2, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, contact_pos, b2, dofid, worldid) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, contact_pos, b1, dofid, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, contact_pos, b2, dofid, worldid) jacdif = jacp2 - jacp1 # project Jacobian along the normal of the contact frame @@ -2319,7 +2310,6 @@ def transmission(m: Model, d: Data): dim=[d.nworld, m.nu], inputs=[ m.nv, - m.body_parentid, m.body_rootid, m.body_weldid, m.body_dofnum, @@ -2327,7 +2317,6 @@ def transmission(m: Model, d: Data): m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, - m.dof_bodyid, m.dof_parentid, m.site_bodyid, m.site_quat, @@ -2339,6 +2328,7 @@ def transmission(m: Model, d: Data): m.actuator_trnid, m.actuator_gear, m.actuator_cranklength, + m.dof_affects_body, d.qpos, d.xquat, d.site_xpos, @@ -2359,11 +2349,10 @@ def transmission(m: Model, d: Data): dim=(m.nacttrnbody, d.naconmax, m.nv), inputs=[ m.opt.cone, - m.body_parentid, m.body_rootid, - m.dof_bodyid, m.geom_bodyid, m.actuator_trnid, + m.dof_affects_body, m.actuator_trntype_body_adr, d.subtree_com, d.cdof, @@ -2789,11 +2778,10 @@ def _joint_tendon( def _spatial_site_tendon( # Model: nv: int, - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), + dof_affects_body: wp.array2d(dtype=int), tendon_site_pair_adr: wp.array(dtype=int), wrap_site_pair_adr: wp.array(dtype=int), wrap_pulley_scale: wp.array(dtype=float), @@ -2831,8 +2819,8 @@ def _spatial_site_tendon( if body0 != body1: # TODO(team): parallelize for i in range(nv): - jacp1, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pnt0, body0, i, worldid) - jacp2, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pnt1, body1, i, worldid) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pnt0, body0, i, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, pnt1, body1, i, worldid) J = wp.dot(jacp2 - jacp1, vec) if J: @@ -2843,15 +2831,14 @@ def _spatial_site_tendon( def _spatial_geom_tendon( # Model: nv: int, - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), geom_size: wp.array2d(dtype=wp.vec3), site_bodyid: wp.array(dtype=int), wrap_type: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), wrap_prm: wp.array(dtype=float), + dof_affects_body: wp.array2d(dtype=int), tendon_geom_adr: wp.array(dtype=int), wrap_geom_adr: wp.array(dtype=int), wrap_pulley_scale: wp.array(dtype=float), @@ -2934,25 +2921,17 @@ def _spatial_geom_tendon( J = float(0.0) # site-geom if dif_body_sitegeom: - jacp_site0, _ = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_pnt0, bodyid_site0, i, worldid - ) + jacp_site0, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, site_pnt0, bodyid_site0, i, worldid) - jacp_geom0, _ = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, geom_pnt0, bodyid_geom, i, worldid - ) + jacp_geom0, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, geom_pnt0, bodyid_geom, i, worldid) J += wp.dot(jacp_geom0 - jacp_site0, vec_sitegeom) # geom-site if dif_body_geomsite: - jacp_geom1, _ = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, geom_pnt1, bodyid_geom, i, worldid - ) + jacp_geom1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, geom_pnt1, bodyid_geom, i, worldid) - jacp_site1, _ = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_pnt1, bodyid_site1, i, worldid - ) + jacp_site1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, site_pnt1, bodyid_site1, i, worldid) J += wp.dot(jacp_site1 - jacp_geom1, vec_geomsite) @@ -2973,12 +2952,8 @@ def _spatial_geom_tendon( if bodyid_site0 != bodyid_site1: # TODO(team): parallelize for i in range(nv): - jacp1, _ = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_pnt0, bodyid_site0, i, worldid - ) - jacp2, _ = support.jac( - body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, site_pnt1, bodyid_site1, i, worldid - ) + jacp1, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, site_pnt0, bodyid_site0, i, worldid) + jacp2, _ = support.jac(body_rootid, dof_affects_body, subtree_com_in, cdof_in, site_pnt1, bodyid_site1, i, worldid) J = wp.dot(jacp2 - jacp1, vec_sitesite) @@ -3181,11 +3156,10 @@ def tendon(m: Model, d: Data): dim=(d.nworld, m.wrap_site_pair_adr.size), inputs=[ m.nv, - m.body_parentid, m.body_rootid, - m.dof_bodyid, m.site_bodyid, m.wrap_objid, + m.dof_affects_body, m.tendon_site_pair_adr, m.wrap_site_pair_adr, m.wrap_pulley_scale, @@ -3202,15 +3176,14 @@ def tendon(m: Model, d: Data): dim=(d.nworld, m.wrap_geom_adr.size), inputs=[ m.nv, - m.body_parentid, m.body_rootid, - m.dof_bodyid, m.geom_bodyid, m.geom_size, m.site_bodyid, m.wrap_type, m.wrap_objid, m.wrap_prm, + m.dof_affects_body, m.tendon_geom_adr, m.wrap_geom_adr, m.wrap_pulley_scale, diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index edf6567b6..ffc4c64c1 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -363,12 +363,27 @@ def transform_force(frc: wp.spatial_vector, offset: wp.vec3) -> wp.spatial_vecto return transform_force(force, torque, offset) +@wp.func +def _compute_jacp(cdof_clip: wp.spatial_vector, offset: wp.vec3, affect: int) -> wp.vec3: + if affect == 0: + return wp.vec3(0.0, 0.0, 0.0) + cdof_lin = wp.spatial_bottom(cdof_clip) + cdof_ang = wp.spatial_top(cdof_clip) + return cdof_lin + wp.cross(cdof_ang, offset) + + +@wp.func +def _compute_jacr(cdof_clip: wp.spatial_vector, affect: int) -> wp.vec3: + if affect == 0: + return wp.vec3(0.0, 0.0, 0.0) + return wp.spatial_top(cdof_clip) + + @wp.func def jac( # Model: - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - dof_bodyid: wp.array(dtype=int), + dof_affects_body: wp.array2d(dtype=int), # Data in: subtree_com_in: wp.array2d(dtype=wp.vec3), cdof_in: wp.array2d(dtype=wp.spatial_vector), @@ -378,26 +393,11 @@ def jac( dofid: int, worldid: int, ) -> Tuple[wp.vec3, wp.vec3]: - dof_bodyid_ = dof_bodyid[dofid] - in_tree = int(dof_bodyid_ == 0) - parentid = bodyid - while parentid != 0: - if parentid == dof_bodyid_: - in_tree = 1 - break - parentid = body_parentid[parentid] - - if not in_tree: - return wp.vec3(0.0), wp.vec3(0.0) - offset = point - wp.vec3(subtree_com_in[worldid, body_rootid[bodyid]]) - cdof = cdof_in[worldid, dofid] - cdof_ang = wp.spatial_top(cdof) - cdof_lin = wp.spatial_bottom(cdof) - - jacp = cdof_lin + wp.cross(cdof_ang, offset) - jacr = cdof_ang + affect = dof_affects_body[bodyid, dofid] + jacp = _compute_jacp(cdof, offset, affect) + jacr = _compute_jacr(cdof, affect) return jacp, jacr @@ -405,12 +405,12 @@ def jac( @wp.func def jac_dot( # Model: - body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), jnt_type: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), dof_bodyid: wp.array(dtype=int), dof_jntid: wp.array(dtype=int), + dof_affects_body: wp.array2d(dtype=int), # Data in: subtree_com_in: wp.array2d(dtype=wp.vec3), cdof_in: wp.array2d(dtype=wp.spatial_vector), @@ -422,16 +422,7 @@ def jac_dot( dofid: int, worldid: int, ) -> Tuple[wp.vec3, wp.vec3]: - dof_bodyid_ = dof_bodyid[dofid] - in_tree = int(dof_bodyid_ == 0) - parentid = bodyid - while parentid != 0: - if parentid == dof_bodyid_: - in_tree = 1 - break - parentid = body_parentid[parentid] - - if not in_tree: + if not dof_affects_body[bodyid, dofid]: return wp.vec3(0.0), wp.vec3(0.0) com = subtree_com_in[worldid, body_rootid[bodyid]] diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 691c72320..79b741065 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -43,6 +43,8 @@ class BlockDim: # collision_driver segmented_sort: int = 128 + # constraint + contact_jac_tiled: int = 32 # forward euler_dense: int = 32 actuator_velocity: int = 32 @@ -1049,6 +1051,7 @@ class Model: jnt_limited_ball_adr: limited/ball jntadr dof_tri_row: dof lower triangle row (used in solver) dof_tri_col: dof lower triangle col (used in solver) + dof_affects_body: precomputed mask: does DOF affect body (nbody, nv) nxn_geom_pair: collision pair geom ids [-2, ngeom-1] nxn_geom_pair_filtered: valid collision pair geom ids [-1, ngeom - 1] @@ -1401,6 +1404,7 @@ class Model: jnt_limited_ball_adr: wp.array(dtype=int) dof_tri_row: wp.array(dtype=int) dof_tri_col: wp.array(dtype=int) + dof_affects_body: array("nbody", "nv_pad", int) nxn_geom_pair: wp.array(dtype=wp.vec2i) nxn_geom_pair_filtered: wp.array(dtype=wp.vec2i) nxn_pairid: wp.array(dtype=wp.vec2i) @@ -1512,7 +1516,9 @@ class Constraint: Attributes: type: constraint type (ConstraintType) (nworld, njmax) id: id of object of specific type (nworld, njmax) + conid: contact id for each efc row (nworld, njmax) J: constraint Jacobian (nworld, njmax_pad, nv_pad) + Jqvel: J @ qvel for contacts (nworld, njmax) pos: constraint position (equality, contact) (nworld, njmax) margin: inclusion margin (contact) (nworld, njmax) D: constraint mass (nworld, njmax_pad) @@ -1527,7 +1533,9 @@ class Constraint: type: array("nworld", "njmax", int) id: array("nworld", "njmax", int) + conid: array("nworld", "njmax", int) J: array("nworld", "njmax_pad", "nv_pad", float) + Jqvel: array("nworld", "njmax", float) pos: array("nworld", "njmax", float) margin: array("nworld", "njmax", float) D: array("nworld", "njmax_pad", float)