Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 100 additions & 74 deletions mujoco_warp/_src/collision_convex.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def ccd_hfield_kernel_builder(
geomtype2: int,
gjk_iterations: int,
epa_iterations: int,
geomgeomid: int,
):
"""Kernel builder for heightfield CCD collisions (no multiccd args)."""

Expand Down Expand Up @@ -205,6 +206,7 @@ def ccd_hfield_kernel(
geom_xpos_in: wp.array2d(dtype=wp.vec3),
geom_xmat_in: wp.array2d(dtype=wp.mat33),
naconmax_in: int,
naccdmax_in: int,
ncollision_in: wp.array(dtype=int),
# In:
collision_pair_in: wp.array(dtype=wp.vec2i),
Expand All @@ -218,6 +220,7 @@ def ccd_hfield_kernel(
epa_pr_in: wp.array2d(dtype=wp.vec3),
epa_norm2_in: wp.array2d(dtype=float),
epa_horizon_in: wp.array2d(dtype=int),
nccd_in: wp.array(dtype=int),
# Data out:
contact_dist_out: wp.array(dtype=float),
contact_pos_out: wp.array(dtype=wp.vec3),
Expand All @@ -234,18 +237,18 @@ def ccd_hfield_kernel(
contact_geomcollisionid_out: wp.array(dtype=int),
nacon_out: wp.array(dtype=int),
):
tid = wp.tid()
if tid >= ncollision_in[0]:
collisionid = wp.tid()
if collisionid >= ncollision_in[0]:
return

geoms = collision_pair_in[tid]
geoms = collision_pair_in[collisionid]
g1 = geoms[0]
g2 = geoms[1]

if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2:
return

worldid = collision_worldid_in[tid]
worldid = collision_worldid_in[collisionid]

# height field filter
no_hf_collision, xmin, xmax, ymin, ymax, zmin, zmax = _hfield_filter(
Expand All @@ -269,6 +272,11 @@ def ccd_hfield_kernel(
if no_hf_collision:
return

ccdid = wp.atomic_add(nccd_in, wp.static(geomgeomid), 1)
if ccdid >= naccdmax_in:
wp.printf("CCD overflow - please increase naccdmax to %u\n", ccdid)
return

_, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params(
geom_condim,
geom_priority,
Expand All @@ -287,7 +295,7 @@ def ccd_hfield_kernel(
pair_friction,
collision_pair_in,
collision_pairid_in,
tid,
collisionid,
worldid,
)

Expand Down Expand Up @@ -368,16 +376,16 @@ def ccd_hfield_kernel(
geom2.margin = margin

# EPA memory
epa_vert1 = epa_vert1_in[tid]
epa_vert2 = epa_vert2_in[tid]
epa_vert_index1 = epa_vert_index1_in[tid]
epa_vert_index2 = epa_vert_index2_in[tid]
epa_face = epa_face_in[tid]
epa_pr = epa_pr_in[tid]
epa_norm2 = epa_norm2_in[tid]
epa_horizon = epa_horizon_in[tid]
epa_vert1 = epa_vert1_in[ccdid]
epa_vert2 = epa_vert2_in[ccdid]
epa_vert_index1 = epa_vert_index1_in[ccdid]
epa_vert_index2 = epa_vert_index2_in[ccdid]
epa_face = epa_face_in[ccdid]
epa_pr = epa_pr_in[ccdid]
epa_norm2 = epa_norm2_in[ccdid]
epa_horizon = epa_horizon_in[ccdid]

collision_pairid = collision_pairid_in[tid]
collision_pairid = collision_pairid_in[collisionid]

# process all prisms in subgrid
count = int(0)
Expand Down Expand Up @@ -692,6 +700,7 @@ def ccd_kernel_builder(
gjk_iterations: int,
epa_iterations: int,
use_multiccd: bool,
geomgeomid: int,
):
"""Kernel builder for non-heightfield CCD collisions (no hfield args)."""

Expand Down Expand Up @@ -725,7 +734,7 @@ def eval_ccd_write_contact(
geom2: Geom,
geoms: wp.vec2i,
worldid: int,
tid: int,
ccdid: int,
margin: float,
gap: float,
condim: int,
Expand Down Expand Up @@ -773,14 +782,14 @@ def eval_ccd_write_contact(
geomtype2,
x1,
x2,
epa_vert1_in[tid],
epa_vert2_in[tid],
epa_vert_index1_in[tid],
epa_vert_index2_in[tid],
epa_face_in[tid],
epa_pr_in[tid],
epa_norm2_in[tid],
epa_horizon_in[tid],
epa_vert1_in[ccdid],
epa_vert2_in[ccdid],
epa_vert_index1_in[ccdid],
epa_vert_index2_in[ccdid],
epa_face_in[ccdid],
epa_pr_in[ccdid],
epa_norm2_in[ccdid],
epa_horizon_in[ccdid],
)

if dist >= 0.0 and pairid[1] == -1:
Expand All @@ -802,22 +811,22 @@ def eval_ccd_write_contact(

if multiccd_idx > -1:
ncollision, witness1, witness2 = multicontact(
multiccd_polygon_in[tid],
multiccd_clipped_in[tid],
multiccd_pnormal_in[tid],
multiccd_pdist_in[tid],
multiccd_idx1_in[tid],
multiccd_idx2_in[tid],
multiccd_n1_in[tid],
multiccd_n2_in[tid],
multiccd_endvert_in[tid],
multiccd_face1_in[tid],
multiccd_face2_in[tid],
epa_vert1_in[tid],
epa_vert2_in[tid],
epa_vert_index1_in[tid],
epa_vert_index2_in[tid],
epa_face_in[tid, multiccd_idx],
multiccd_polygon_in[ccdid],
multiccd_clipped_in[ccdid],
multiccd_pnormal_in[ccdid],
multiccd_pdist_in[ccdid],
multiccd_idx1_in[ccdid],
multiccd_idx2_in[ccdid],
multiccd_n1_in[ccdid],
multiccd_n2_in[ccdid],
multiccd_endvert_in[ccdid],
multiccd_face1_in[ccdid],
multiccd_face2_in[ccdid],
epa_vert1_in[ccdid],
epa_vert2_in[ccdid],
epa_vert_index1_in[ccdid],
epa_vert_index2_in[ccdid],
epa_face_in[ccdid, multiccd_idx],
w1,
w2,
geom1,
Expand Down Expand Up @@ -914,6 +923,7 @@ def ccd_kernel(
geom_xpos_in: wp.array2d(dtype=wp.vec3),
geom_xmat_in: wp.array2d(dtype=wp.mat33),
naconmax_in: int,
naccdmax_in: int,
ncollision_in: wp.array(dtype=int),
# In:
collision_pair_in: wp.array(dtype=wp.vec2i),
Expand All @@ -938,6 +948,7 @@ def ccd_kernel(
multiccd_endvert_in: wp.array2d(dtype=wp.vec3),
multiccd_face1_in: wp.array2d(dtype=wp.vec3),
multiccd_face2_in: wp.array2d(dtype=wp.vec3),
nccd_in: wp.array(dtype=int),
# Data out:
contact_dist_out: wp.array(dtype=float),
contact_pos_out: wp.array(dtype=wp.vec3),
Expand All @@ -954,18 +965,23 @@ def ccd_kernel(
contact_geomcollisionid_out: wp.array(dtype=int),
nacon_out: wp.array(dtype=int),
):
tid = wp.tid()
if tid >= ncollision_in[0]:
collisionid = wp.tid()
if collisionid >= ncollision_in[0]:
return

geoms = collision_pair_in[tid]
geoms = collision_pair_in[collisionid]
g1 = geoms[0]
g2 = geoms[1]

if geom_type[g1] != geomtype1 or geom_type[g2] != geomtype2:
return

worldid = collision_worldid_in[tid]
ccdid = wp.atomic_add(nccd_in, wp.static(geomgeomid), 1)
if ccdid >= naccdmax_in:
wp.printf("CCD overflow - please increase naccdmax to %u\n", ccdid)
return

worldid = collision_worldid_in[collisionid]

_, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params(
geom_condim,
Expand All @@ -985,7 +1001,7 @@ def ccd_kernel(
pair_friction,
collision_pair_in,
collision_pairid_in,
tid,
collisionid,
worldid,
)

Expand Down Expand Up @@ -1039,7 +1055,7 @@ def ccd_kernel(
geom2,
geoms,
worldid,
tid,
ccdid,
margin,
gap,
condim,
Expand All @@ -1049,7 +1065,7 @@ def ccd_kernel(
solimp,
geom1.pos,
geom2.pos,
collision_pairid_in[tid],
collision_pairid_in[collisionid],
contact_dist_out,
contact_pos_out,
contact_frame_out,
Expand Down Expand Up @@ -1087,20 +1103,21 @@ def convex_narrowphase(m: Model, d: Data, ctx: CollisionContext, collision_table
"""

def _pair_count(p1: int, p2: int) -> int:
return m.geom_pair_type_count[upper_trid_index(len(GeomType), p1, p2)]
idx = upper_trid_index(len(GeomType), p1, p2)
return m.geom_pair_type_count[idx], idx

ncollision = sum(_pair_count(g[0].value, g[1].value) for g in collision_table)
ncollision = sum(_pair_count(g[0].value, g[1].value)[0] for g in collision_table)
# no convex collisions, early return

if ncollision == 0:
return

# compute nmaxpolygon and nmaxmeshdeg given the geom pairs for the model
nboxbox = _pair_count(GeomType.BOX.value, GeomType.BOX.value)
nboxbox = _pair_count(GeomType.BOX.value, GeomType.BOX.value)[0]
if (GeomType.BOX, GeomType.BOX) not in collision_table:
nboxbox = 0
nboxmesh = _pair_count(GeomType.BOX.value, GeomType.MESH.value)
nmeshmesh = _pair_count(GeomType.MESH.value, GeomType.MESH.value)
nboxmesh = _pair_count(GeomType.BOX.value, GeomType.MESH.value)[0]
nmeshmesh = _pair_count(GeomType.MESH.value, GeomType.MESH.value)[0]

epa_iterations = 16 if nboxbox == ncollision else m.opt.ccd_iterations

Expand All @@ -1117,22 +1134,25 @@ def _pair_count(p1: int, p2: int) -> int:
nmaxpolygon = max(m.nmaxpolygon, minval)
nmaxmeshdeg = max(m.nmaxmeshdeg, 3)

# ccd collider count
nccd = wp.zeros(len(GeomType) * (len(GeomType) + 1) // 2, dtype=int)

# epa_vert1: vertices in EPA polytope in geom 1 space
epa_vert1 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3)
epa_vert1 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3)
# epa_vert2: vertices in EPA polytope in geom 2 space
epa_vert2 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=wp.vec3)
epa_vert2 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=wp.vec3)
# epa_vert_index1: vertex indices in EPA polytope for geom 1
epa_vert_index1 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=int)
epa_vert_index1 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=int)
# epa_vert_index2: vertex indices in EPA polytope for geom 2 (naconmax, 5 + CCDiter)
epa_vert_index2 = wp.empty(shape=(d.naconmax, 5 + epa_iterations), dtype=int)
epa_vert_index2 = wp.empty(shape=(d.naccdmax, 5 + epa_iterations), dtype=int)
# epa_face: faces of polytope represented by three indices
epa_face = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int)
epa_face = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=int)
# epa_pr: projection of origin on polytope faces
epa_pr = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3)
epa_pr = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3)
# epa_norm2: epa_pr * epa_pr
epa_norm2 = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=float)
epa_norm2 = wp.empty(shape=(d.naccdmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=float)
# epa_horizon: index pair (i j) of edges on horizon
epa_horizon = wp.empty(shape=(d.naconmax, MJ_MAX_EPAHORIZON), dtype=int)
epa_horizon = wp.empty(shape=(d.naccdmax, MJ_MAX_EPAHORIZON), dtype=int)

# Contact outputs
contact_outputs = [
Expand All @@ -1156,9 +1176,10 @@ def _pair_count(p1: int, p2: int) -> int:
for geom_pair in collision_table:
g1 = geom_pair[0].value
g2 = geom_pair[1].value
if (g1 == GeomType.HFIELD or g2 == GeomType.HFIELD) and _pair_count(g1, g2):
count, geomgeomid = _pair_count(g1, g2)
if (g1 == GeomType.HFIELD or g2 == GeomType.HFIELD) and count:
wp.launch(
ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations),
ccd_hfield_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, geomgeomid),
dim=d.naconmax,
inputs=[
m.opt.ccd_tolerance,
Expand Down Expand Up @@ -1203,6 +1224,7 @@ def _pair_count(p1: int, p2: int) -> int:
d.geom_xpos,
d.geom_xmat,
d.naconmax,
d.naccdmax,
d.ncollision,
ctx.collision_pair,
ctx.collision_pairid,
Expand All @@ -1215,41 +1237,43 @@ def _pair_count(p1: int, p2: int) -> int:
epa_pr,
epa_norm2,
epa_horizon,
nccd,
],
outputs=contact_outputs,
)

# Allocate multiccd arrays only for non-heightfield collisions
# multiccd_polygon: clipped contact surface
multiccd_polygon = wp.empty(shape=(d.naconmax, 2 * nmaxpolygon), dtype=wp.vec3)
multiccd_polygon = wp.empty(shape=(d.naccdmax, 2 * nmaxpolygon), dtype=wp.vec3)
# multiccd_clipped: clipped contact surface (intermediate)
multiccd_clipped = wp.empty(shape=(d.naconmax, 2 * nmaxpolygon), dtype=wp.vec3)
multiccd_clipped = wp.empty(shape=(d.naccdmax, 2 * nmaxpolygon), dtype=wp.vec3)
# multiccd_pnormal: plane normal of clipping polygon
multiccd_pnormal = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3)
multiccd_pnormal = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3)
# multiccd_pdist: plane distance of clipping polygon
multiccd_pdist = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=float)
multiccd_pdist = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=float)
# multiccd_idx1: list of normal index candidates for Geom 1
multiccd_idx1 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=int)
multiccd_idx1 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=int)
# multiccd_idx2: list of normal index candidates for Geom 2
multiccd_idx2 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=int)
multiccd_idx2 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=int)
# multiccd_n1: list of normal candidates for Geom 1
multiccd_n1 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3)
multiccd_n1 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3)
# multiccd_n2: list of normal candidates for Geom 1
multiccd_n2 = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3)
multiccd_n2 = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3)
# multiccd_endvert: list of edge vertices candidates
multiccd_endvert = wp.empty(shape=(d.naconmax, nmaxmeshdeg), dtype=wp.vec3)
multiccd_endvert = wp.empty(shape=(d.naccdmax, nmaxmeshdeg), dtype=wp.vec3)
# multiccd_face1: contact face
multiccd_face1 = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3)
multiccd_face1 = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3)
# multiccd_face2: contact face
multiccd_face2 = wp.empty(shape=(d.naconmax, nmaxpolygon), dtype=wp.vec3)
multiccd_face2 = wp.empty(shape=(d.naccdmax, nmaxpolygon), dtype=wp.vec3)

# Launch non-heightfield collision kernels (no hfield args, 78 args total)
for geom_pair in collision_table:
g1 = geom_pair[0].value
g2 = geom_pair[1].value
if g1 != GeomType.HFIELD and g2 != GeomType.HFIELD and _pair_count(g1, g2):
count, geomgeomid = _pair_count(g1, g2)
if g1 != GeomType.HFIELD and g2 != GeomType.HFIELD and count:
wp.launch(
ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd),
ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd, geomgeomid),
dim=d.naconmax,
inputs=[
m.opt.ccd_tolerance,
Expand Down Expand Up @@ -1288,6 +1312,7 @@ def _pair_count(p1: int, p2: int) -> int:
d.geom_xpos,
d.geom_xmat,
d.naconmax,
d.naccdmax,
d.ncollision,
ctx.collision_pair,
ctx.collision_pairid,
Expand All @@ -1311,6 +1336,7 @@ def _pair_count(p1: int, p2: int) -> int:
multiccd_endvert,
multiccd_face1,
multiccd_face2,
nccd,
],
outputs=contact_outputs,
)
Loading
Loading