Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 57 additions & 11 deletions newton/_src/sensors/sensor_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,6 @@ def __init__(
c_shapes = [self.ObjectType.TOTAL]
c_bodies = []

contact_pairs = (
set(map(tuple, model.shape_contact_pairs.list()))
if getattr(model, "shape_contact_pairs", None) is not None
else None
)

TOTAL = self.ObjectType.TOTAL
wc = model.world_count
shape_ws = model.shape_world_start.list()
Expand All @@ -332,13 +326,65 @@ def bucket(indices, ws):
c_body_b, c_body_g = bucket(c_bodies, body_ws)
c_shape_b, c_shape_g = bucket(c_shapes, shape_ws)

_homogeneous = wc > 1 and (
len(set(body_ws[i + 1] - body_ws[i] for i in range(wc))) == 1
and len(set(shape_ws[i + 1] - shape_ws[i] for i in range(wc))) == 1
and all(len(s_body_b[w]) == len(s_body_b[0]) for w in range(1, wc))
and all(len(s_shape_b[w]) == len(s_shape_b[0]) for w in range(1, wc))
and all(len(c_body_b[w]) == len(c_body_b[0]) for w in range(1, wc))
and all(len(c_shape_b[w]) == len(c_shape_b[0]) for w in range(1, wc))
)

if getattr(model, "shape_contact_pairs", None) is not None:
if _homogeneous and wc > 1:
import numpy as np
cp_np = model.shape_contact_pairs.numpy()
lo, hi = shape_ws[0], shape_ws[1]
mask = (cp_np[:, 0] >= lo) & (cp_np[:, 0] < hi) & (cp_np[:, 1] >= lo) & (cp_np[:, 1] < hi)
contact_pairs = set(map(tuple, cp_np[mask].tolist()))
else:
contact_pairs = set(map(tuple, model.shape_contact_pairs.numpy()))
else:
contact_pairs = None

per_world_results = []
for w in range(wc):
wb = ([TOTAL] if TOTAL in c_bodies else []) + c_body_g + c_body_b[w]
wsh = ([TOTAL] if TOTAL in c_shapes else []) + c_shape_g + c_shape_b[w]
per_world_results.append(
self._assemble_sensor_mappings(s_body_b[w], s_shape_b[w], wb, wsh, model.body_shapes, contact_pairs)
if _homogeneous:
wb0 = ([TOTAL] if TOTAL in c_bodies else []) + c_body_g + c_body_b[0]
wsh0 = ([TOTAL] if TOTAL in c_shapes else []) + c_shape_g + c_shape_b[0]
r0 = self._assemble_sensor_mappings(
s_body_b[0], s_shape_b[0], wb0, wsh0, model.body_shapes, contact_pairs
)
sp0, raw0, n_readings0, counter_idx0, so_kinds0, cp_kinds0 = r0

for w in range(wc):
b_off = body_ws[w] - body_ws[0]
s_off = shape_ws[w] - shape_ws[0]

sp_w = [(a + s_off if a >= 0 else a, b + s_off if b >= 0 else b) for a, b in sp0]

so_w = []
for idx, kind in so_kinds0:
off = b_off if kind == self.ObjectType.BODY else s_off
so_w.append((idx + off, kind))

cp_w = []
for idx, kind in cp_kinds0:
if kind == TOTAL:
cp_w.append((idx, kind))
else:
off = b_off if kind == self.ObjectType.BODY else s_off
cp_w.append((idx + off, kind))

per_world_results.append((sp_w, raw0, n_readings0, counter_idx0, so_w, cp_w))
else:
for w in range(wc):
wb = ([TOTAL] if TOTAL in c_bodies else []) + c_body_g + c_body_b[w]
wsh = ([TOTAL] if TOTAL in c_shapes else []) + c_shape_g + c_shape_b[w]
per_world_results.append(
self._assemble_sensor_mappings(
s_body_b[w], s_shape_b[w], wb, wsh, model.body_shapes, contact_pairs
)
)

max_r = max((r[2] for r in per_world_results), default=0)
self.sensing_objs = [r[4] for r in per_world_results]
Expand Down
Loading