diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 8192cc007..4c0f03392 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -24,6 +24,7 @@ from .collision_primitive import geom_collision_pair from .collision_primitive import write_contact from .math import make_frame +from .math import safe_normalize from .math import upper_trid_index from .types import MJ_MAX_EPAFACES from .types import MJ_MAX_EPAHORIZON @@ -333,9 +334,9 @@ def eval_ccd_write_contact( geomtype2, ) - for i in range(ncontact): - points[i] = 0.5 * (witness1[i] + witness2[i]) - normal = witness1[0] - witness2[0] + normal, is_safe = safe_normalize(witness1[0] - witness2[0]) + if not is_safe: + return 0 frame = make_frame(normal) # flip if collision sensor @@ -348,7 +349,7 @@ def eval_ccd_write_contact( naconmax_in, i, dist, - points[i], + 0.5 * (witness1[i] + witness2[i]), frame, margin, gap, @@ -678,7 +679,11 @@ def ccd_kernel( hfield_contact_pos[count, 1] = pos[1] hfield_contact_pos[count, 2] = pos[2] - frame = make_frame(w1 - w2) + normal, is_safe = safe_normalize(w1 - w2) + if not is_safe: + continue + + frame = make_frame(normal) normal = wp.vec3(frame[0, 0], frame[0, 1], frame[0, 2]) hfield_contact_normal[count, 0] = normal[0] hfield_contact_normal[count, 1] = normal[1] diff --git a/mujoco_warp/_src/math.py b/mujoco_warp/_src/math.py index 54e3e4749..b865ea2df 100644 --- a/mujoco_warp/_src/math.py +++ b/mujoco_warp/_src/math.py @@ -184,7 +184,7 @@ def orthonormal(normal: wp.vec3) -> wp.vec3: dir = wp.vec3(-normal[1] * normal[0], 1.0 - normal[1] * normal[1], -normal[1] * normal[2]) else: dir = wp.vec3(-normal[2] * normal[0], -normal[2] * normal[1], 1.0 - normal[2] * normal[2]) - dir, _ = gjk_normalize(dir) + dir, _ = safe_normalize(dir) return dir @@ -194,12 +194,12 @@ def orthonormal_to_z(normal: wp.vec3) -> wp.vec3: dir = wp.vec3(1.0 - normal[0] * normal[0], -normal[0] * normal[1], -normal[0] * normal[2]) else: dir = wp.vec3(-normal[1] * normal[0], 1.0 - normal[1] * normal[1], -normal[1] * normal[2]) - dir, _ = gjk_normalize(dir) + dir, _ = safe_normalize(dir) return dir @wp.func -def gjk_normalize(a: wp.vec3): +def safe_normalize(a: wp.vec3): norm = wp.length(a) if norm > 1e-8 and norm < 1e12: return a / norm, True @@ -207,13 +207,13 @@ def gjk_normalize(a: wp.vec3): @wp.func -def make_frame(a: wp.vec3): - a = wp.normalize(a) - b, c = orthogonals(a) +def make_frame(n: wp.vec3): + """Returns frame given a normal vector.""" + b, c = orthogonals(n) # fmt: off return wp.mat33( - a.x, a.y, a.z, + n.x, n.y, n.z, b.x, b.y, b.z, c.x, c.y, c.z )