Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 49 additions & 55 deletions mujoco_warp/_src/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,7 +2453,7 @@ def update_gradient_JTCJ_sparse(
nblocks_perblock: int,
dim_block: int,
# Out:
h_out: wp.array3d(dtype=float),
ctx_h_out: wp.array3d(dtype=float),
):
conid_start, elementid = wp.tid()

Expand All @@ -2468,20 +2468,37 @@ def update_gradient_JTCJ_sparse(

worldid = contact_worldid_in[conid]
if ctx_done_in[worldid]:
return
continue

condim = contact_dim_in[conid]

if condim == 1:
return
continue

# check contact status
if contact_dist_in[conid] - contact_includemargin_in[conid] >= 0.0:
return
continue

efcid0 = contact_efc_address_in[conid, 0]
if efc_state_in[worldid, efcid0] != types.ConstraintState.CONE:
return
continue

# All dims share the same sparsity pattern. Scan colind once to find
# the sparse positions of dof1id and dof2id. Skip if either is absent.
rownnz = efc_J_rownnz_in[worldid, efcid0]
rowadr0 = efc_J_rowadr_in[worldid, efcid0]
pos1 = int(-1)
pos2 = int(-1)
for k in range(rownnz):
col = efc_J_colind_in[worldid, 0, rowadr0 + k]
if col == dof1id:
pos1 = k
if col == dof2id:
pos2 = k
if pos1 >= 0 and pos2 >= 0:
break
if pos1 < 0 or pos2 < 0:
continue

fri = contact_friction_in[conid]
mu = fri[0] * opt_impratio_invsqrt[worldid % opt_impratio_invsqrt.shape[0]]
Expand All @@ -2490,7 +2507,7 @@ def update_gradient_JTCJ_sparse(
dm = math.safe_div(efc_D_in[worldid, efcid0], mu2 * (1.0 + mu2))

if dm == 0.0:
return
continue

n = ctx_Jaref_in[worldid, efcid0] * mu
u = types.vec6(n, 0.0, 0.0, 0.0, 0.0, 0.0)
Expand All @@ -2509,89 +2526,66 @@ def update_gradient_JTCJ_sparse(
t = wp.max(t, types.MJ_MINVAL)
ttt = wp.max(t * t * t, types.MJ_MINVAL)

# Precompute common subexpressions.
mu_over_t = math.safe_div(mu, t)
mu_n_over_ttt = mu * math.safe_div(n, ttt)
mu2_minus_mu_n_over_t = mu2 - mu * math.safe_div(n, t)

h = float(0.0)

for dim1id in range(condim):
if dim1id == 0:
efcid1 = efcid0
rowadr1 = rowadr0
dm_fri1 = dm * mu
else:
efcid1 = contact_efc_address_in[conid, dim1id]
rowadr1 = efc_J_rowadr_in[worldid, efcid1]
dm_fri1 = dm * fri[dim1id - 1]

# TODO(team): improve performance for sparse code path
rownnz1 = efc_J_rownnz_in[worldid, efcid1]
rowadr1 = efc_J_rowadr_in[worldid, efcid1]

efc_J11 = float(0.0)
efc_J12 = float(0.0)
for i1 in range(rownnz1):
sparseid1 = rowadr1 + i1
colind1 = efc_J_colind_in[worldid, 0, sparseid1]
if dof1id == colind1:
efc_J11 = efc_J_in[worldid, 0, sparseid1]
if dof2id == colind1:
efc_J12 = efc_J_in[worldid, 0, sparseid1]
if efc_J11 != 0.0 and efc_J12 != 0.0:
break
# Direct J reads using cached sparse positions.
efc_J11 = efc_J_in[worldid, 0, rowadr1 + pos1]
efc_J12 = efc_J_in[worldid, 0, rowadr1 + pos2]

ui = u[dim1id]

for dim2id in range(0, dim1id + 1):
if dim2id == 0:
efcid2 = efcid0
rowadr2 = rowadr0
dm_fri12 = dm_fri1 * mu
else:
efcid2 = contact_efc_address_in[conid, dim2id]
rowadr2 = efc_J_rowadr_in[worldid, efcid2]
dm_fri12 = dm_fri1 * fri[dim2id - 1]

rownnz2 = efc_J_rownnz_in[worldid, efcid2]
rowadr2 = efc_J_rowadr_in[worldid, efcid2]

efc_J21 = float(0.0)
efc_J22 = float(0.0)
for i2 in range(rownnz2):
sparseid2 = rowadr2 + i2
colind2 = efc_J_colind_in[worldid, 0, sparseid2]
if dof1id == colind2:
efc_J21 = efc_J_in[worldid, 0, sparseid2]
if dof2id == colind2:
efc_J22 = efc_J_in[worldid, 0, sparseid2]
if efc_J21 != 0.0 and efc_J22 != 0.0:
break
# Direct J reads using cached sparse positions.
efc_J21 = efc_J_in[worldid, 0, rowadr2 + pos1]
efc_J22 = efc_J_in[worldid, 0, rowadr2 + pos2]

uj = u[dim2id]

# set first row/column: (1, -mu/t * u)
if dim1id == 0 and dim2id == 0:
hcone = 1.0
elif dim1id == 0:
hcone = -math.safe_div(mu, t) * uj
hcone = -mu_over_t * uj
elif dim2id == 0:
hcone = -math.safe_div(mu, t) * ui
hcone = -mu_over_t * ui
else:
hcone = mu * math.safe_div(n, ttt) * ui * uj
hcone = mu_n_over_ttt * ui * uj

# add to diagonal: mu^2 - mu * n / t
if dim1id == dim2id:
hcone += mu2 - mu * math.safe_div(n, t)

# pre and post multiply by diag(mu, friction) scale by dm
if dim1id == 0:
fri1 = mu
else:
fri1 = fri[dim1id - 1]
hcone += mu2_minus_mu_n_over_t

if dim2id == 0:
fri2 = mu
else:
fri2 = fri[dim2id - 1]

hcone *= dm * fri1 * fri2
hcone *= dm_fri12

if hcone != 0.0:
h += hcone * efc_J11 * efc_J22

if dim1id != dim2id:
h += hcone * efc_J12 * efc_J21

h_out[worldid, dof1id, dof2id] += h
ctx_h_out[worldid, dof1id, dof2id] += h


@wp.kernel
Expand Down Expand Up @@ -3003,7 +2997,7 @@ def _set_h_qM_dense(
if SPARSE_CONSTRAINT_JACOBIAN:
wp.launch(
update_gradient_JTCJ_sparse,
dim=(d.naconmax, m.dof_tri_row.size),
dim=(dim_block, m.dof_tri_row.size),
inputs=[
m.opt.impratio_invsqrt,
m.dof_tri_row,
Expand Down
Loading