diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index f2228bce9..d1a64ebd9 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -969,23 +969,27 @@ def put_data( efc = types.Constraint(**efc_kwargs) if is_sparse(mjm): - # TODO(team): process efc_J sparsity structure for nv row shift - efc.J_rownnz = wp.array(np.full((nworld, njmax), mjm.nv, dtype=int), dtype=int) - efc.J_rowadr = wp.array( - np.tile(np.arange(0, njmax * mjm.nv, mjm.nv) if mjm.nv else np.zeros(njmax, dtype=int), (nworld, 1)), dtype=int - ) - efc.J_colind = wp.array(np.tile(np.arange(mjm.nv), (nworld, njmax)).reshape((nworld, 1, -1))[:, :, :njmax_nnz], dtype=int) - - mj_efc_J = np.zeros((mjd.nefc, mjm.nv)) + J_rownnz = np.zeros(njmax, dtype=np.int32) + J_rowadr = np.zeros(njmax, dtype=np.int32) + J_colind = np.zeros(njmax_nnz, dtype=np.int32) + J = np.zeros(njmax_nnz, dtype=np.float64) if mjd.nefc: if mujoco.mj_isSparse(mjm): - mujoco.mju_sparse2dense(mj_efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind) + J_rownnz[: mjd.nefc] = mjd.efc_J_rownnz[: mjd.nefc] + J_rowadr[: mjd.nefc] = mjd.efc_J_rowadr[: mjd.nefc] + nnz = int(mjd.efc_J_rownnz[: mjd.nefc].sum()) + J_colind[:nnz] = mjd.efc_J_colind[:nnz] + J[:nnz] = mjd.efc_J[:nnz] else: - mj_efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) - efc_J = np.zeros((njmax, mjm.nv), dtype=float) - efc_J[: mjd.nefc, : mjm.nv] = mj_efc_J - efc_J_flat = np.tile(efc_J.reshape(-1), (nworld, 1, 1)).reshape((nworld, 1, -1))[:, :, :njmax_nnz] - efc.J = wp.array(efc_J_flat, dtype=float) + dense_J = mjd.efc_J.reshape((-1, mjm.nv))[: mjd.nefc] + mujoco.mju_dense2sparse( + J[: mjd.nefc * mjm.nv], dense_J, J_rownnz[: mjd.nefc], J_rowadr[: mjd.nefc], J_colind[: mjd.nefc * mjm.nv] + ) + + efc.J_rownnz = wp.array(np.tile(J_rownnz, (nworld, 1)), dtype=int) + efc.J_rowadr = wp.array(np.tile(J_rowadr, (nworld, 1)), dtype=int) + efc.J_colind = wp.array(np.tile(J_colind, (nworld, 1)).reshape((nworld, 1, -1)), dtype=int) + efc.J = wp.array(np.tile(J, (nworld, 1)).reshape((nworld, 1, -1)), dtype=float) else: efc.J_rownnz = wp.zeros((nworld, 0), dtype=int) efc.J_rowadr = wp.zeros((nworld, 0), dtype=int)