diff --git a/newton/_src/sensors/sensor_contact.py b/newton/_src/sensors/sensor_contact.py index 942e004967..4c22b0a670 100644 --- a/newton/_src/sensors/sensor_contact.py +++ b/newton/_src/sensors/sensor_contact.py @@ -304,12 +304,6 @@ def __init__( c_shapes = [self.ObjectType.TOTAL] c_bodies = [] - contact_pairs = ( - set(map(tuple, model.shape_contact_pairs.list())) - if getattr(model, "shape_contact_pairs", None) is not None - else None - ) - TOTAL = self.ObjectType.TOTAL wc = model.world_count shape_ws = model.shape_world_start.list() @@ -332,13 +326,65 @@ def bucket(indices, ws): c_body_b, c_body_g = bucket(c_bodies, body_ws) c_shape_b, c_shape_g = bucket(c_shapes, shape_ws) + _homogeneous = wc > 1 and ( + len(set(body_ws[i + 1] - body_ws[i] for i in range(wc))) == 1 + and len(set(shape_ws[i + 1] - shape_ws[i] for i in range(wc))) == 1 + and all(len(s_body_b[w]) == len(s_body_b[0]) for w in range(1, wc)) + and all(len(s_shape_b[w]) == len(s_shape_b[0]) for w in range(1, wc)) + and all(len(c_body_b[w]) == len(c_body_b[0]) for w in range(1, wc)) + and all(len(c_shape_b[w]) == len(c_shape_b[0]) for w in range(1, wc)) + ) + + if getattr(model, "shape_contact_pairs", None) is not None: + if _homogeneous and wc > 1: + import numpy as np + cp_np = model.shape_contact_pairs.numpy() + lo, hi = shape_ws[0], shape_ws[1] + mask = (cp_np[:, 0] >= lo) & (cp_np[:, 0] < hi) & (cp_np[:, 1] >= lo) & (cp_np[:, 1] < hi) + contact_pairs = set(map(tuple, cp_np[mask].tolist())) + else: + contact_pairs = set(map(tuple, model.shape_contact_pairs.numpy())) + else: + contact_pairs = None + per_world_results = [] - for w in range(wc): - wb = ([TOTAL] if TOTAL in c_bodies else []) + c_body_g + c_body_b[w] - wsh = ([TOTAL] if TOTAL in c_shapes else []) + c_shape_g + c_shape_b[w] - per_world_results.append( - self._assemble_sensor_mappings(s_body_b[w], s_shape_b[w], wb, wsh, model.body_shapes, contact_pairs) + if _homogeneous: + wb0 = ([TOTAL] if TOTAL in c_bodies else []) + c_body_g + c_body_b[0] + wsh0 = ([TOTAL] if TOTAL in c_shapes else []) + c_shape_g + c_shape_b[0] + r0 = self._assemble_sensor_mappings( + s_body_b[0], s_shape_b[0], wb0, wsh0, model.body_shapes, contact_pairs ) + sp0, raw0, n_readings0, counter_idx0, so_kinds0, cp_kinds0 = r0 + + for w in range(wc): + b_off = body_ws[w] - body_ws[0] + s_off = shape_ws[w] - shape_ws[0] + + sp_w = [(a + s_off if a >= 0 else a, b + s_off if b >= 0 else b) for a, b in sp0] + + so_w = [] + for idx, kind in so_kinds0: + off = b_off if kind == self.ObjectType.BODY else s_off + so_w.append((idx + off, kind)) + + cp_w = [] + for idx, kind in cp_kinds0: + if kind == TOTAL: + cp_w.append((idx, kind)) + else: + off = b_off if kind == self.ObjectType.BODY else s_off + cp_w.append((idx + off, kind)) + + per_world_results.append((sp_w, raw0, n_readings0, counter_idx0, so_w, cp_w)) + else: + for w in range(wc): + wb = ([TOTAL] if TOTAL in c_bodies else []) + c_body_g + c_body_b[w] + wsh = ([TOTAL] if TOTAL in c_shapes else []) + c_shape_g + c_shape_b[w] + per_world_results.append( + self._assemble_sensor_mappings( + s_body_b[w], s_shape_b[w], wb, wsh, model.body_shapes, contact_pairs + ) + ) max_r = max((r[2] for r in per_world_results), default=0) self.sensing_objs = [r[4] for r in per_world_results]