Skip to content

Commit 64c628b

Browse files
fix: add input validation to cluster_conformations preventing hangs and UB
- Reject negative/zero cutoff_nm (previously caused infinite loop/segfault) - Reject non-square or non-2D rmsd_matrix (previously caused OOB reads in Numba) - Reject NaN/Inf values in rmsd_matrix (previously silently misclassified) - Document rmsd_matrix_angstrom per-call allocation cost - Document _gromos_loop in-place counts mutation and tie-breaking behavior - Add 8 new tests covering all validation paths and empty matrix edge case
1 parent 13a6d2b commit 64c628b

2 files changed

Lines changed: 69 additions & 1 deletion

File tree

src/mdpp/analysis/clustering.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@ class RMSDMatrixResult:
2929

3030
@property
3131
def rmsd_matrix_angstrom(self) -> NDArray[np.floating]:
32-
"""Return the RMSD matrix in Angstrom."""
32+
"""Return the RMSD matrix in Angstrom.
33+
34+
Note:
35+
Each access allocates a new ``(n_frames, n_frames)`` array.
36+
Cache the result in a local variable if you need it more
37+
than once -- at 120k frames this is ~54 GB per call.
38+
"""
3339
return self.rmsd_matrix_nm * 10.0
3440

3541

@@ -148,6 +154,14 @@ def _gromos_loop(
148154
whole loop runs in ~10 seconds on a modern CPU, versus ~100
149155
minutes for the original fully-recomputing Python implementation.
150156
157+
When multiple unassigned frames have the same neighbour count,
158+
the one with the smallest frame index is chosen as the cluster
159+
centre (deterministic tie-breaking).
160+
161+
Warning:
162+
``counts`` is **mutated in place** -- the caller must not
163+
reuse the array after this function returns.
164+
151165
Returns:
152166
``(labels, n_clusters, medoids)`` as three numpy arrays.
153167
"""
@@ -237,6 +251,12 @@ def cluster_conformations(
237251
"""
238252
if method != "gromos":
239253
raise ValueError(f"Unsupported clustering method: {method!r}. Use 'gromos'.")
254+
if cutoff_nm <= 0.0:
255+
raise ValueError(f"cutoff_nm must be positive, got {cutoff_nm!r}")
256+
if rmsd_matrix.ndim != 2 or rmsd_matrix.shape[0] != rmsd_matrix.shape[1]:
257+
raise ValueError(f"rmsd_matrix must be a square 2-D array, got shape {rmsd_matrix.shape}")
258+
if rmsd_matrix.size > 0 and not np.isfinite(rmsd_matrix).all():
259+
raise ValueError("rmsd_matrix contains NaN or Inf values")
240260

241261
# ``np.ascontiguousarray`` is a no-op for already-contiguous inputs
242262
# (which is what ``compute_rmsd_matrix`` returns) and avoids a

tests/analysis/test_clustering.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,54 @@ def test_all_frames_isolated(self) -> None:
318318
assert result.n_clusters == n
319319
assert len(np.unique(result.labels)) == n
320320

321+
def test_negative_cutoff_raises(self) -> None:
322+
"""A negative cutoff must raise ValueError, not hang."""
323+
rmsd = np.zeros((3, 3), dtype=np.float32)
324+
with pytest.raises(ValueError, match="cutoff_nm must be positive"):
325+
cluster_conformations(rmsd, cutoff_nm=-0.1)
326+
327+
def test_zero_cutoff_raises(self) -> None:
328+
"""A zero cutoff must raise ValueError."""
329+
rmsd = np.zeros((3, 3), dtype=np.float32)
330+
with pytest.raises(ValueError, match="cutoff_nm must be positive"):
331+
cluster_conformations(rmsd, cutoff_nm=0.0)
332+
333+
def test_non_square_matrix_raises(self) -> None:
334+
"""A non-square matrix must raise ValueError."""
335+
rmsd = np.zeros((3, 5), dtype=np.float32)
336+
with pytest.raises(ValueError, match="square 2-D array"):
337+
cluster_conformations(rmsd, cutoff_nm=0.1)
338+
339+
def test_1d_array_raises(self) -> None:
340+
"""A 1-D array must raise ValueError."""
341+
rmsd = np.zeros(10, dtype=np.float32)
342+
with pytest.raises(ValueError, match="square 2-D array"):
343+
cluster_conformations(rmsd, cutoff_nm=0.1)
344+
345+
def test_nan_in_matrix_raises(self) -> None:
346+
"""NaN values in the RMSD matrix must raise ValueError."""
347+
rmsd = np.zeros((4, 4), dtype=np.float32)
348+
rmsd[1, 2] = np.nan
349+
rmsd[2, 1] = np.nan
350+
with pytest.raises(ValueError, match="NaN or Inf"):
351+
cluster_conformations(rmsd, cutoff_nm=0.1)
352+
353+
def test_inf_in_matrix_raises(self) -> None:
354+
"""Inf values in the RMSD matrix must raise ValueError."""
355+
rmsd = np.zeros((4, 4), dtype=np.float32)
356+
rmsd[0, 3] = np.inf
357+
rmsd[3, 0] = np.inf
358+
with pytest.raises(ValueError, match="NaN or Inf"):
359+
cluster_conformations(rmsd, cutoff_nm=0.1)
360+
361+
def test_empty_matrix(self) -> None:
362+
"""A 0x0 matrix should produce zero clusters."""
363+
rmsd = np.zeros((0, 0), dtype=np.float32)
364+
result = cluster_conformations(rmsd, cutoff_nm=0.1)
365+
assert result.n_clusters == 0
366+
assert len(result.labels) == 0
367+
assert len(result.medoid_frames) == 0
368+
321369

322370
# ---------------------------------------------------------------------------
323371
# Wrapper dtype / memory tests: verify compute_rmsd_matrix does no redundant copy

0 commit comments

Comments
 (0)