Skip to content

Commit

Permalink
Minor clean up to tests (#1099)
Browse files Browse the repository at this point in the history
* Cleans up test logic

* Use the cutoff/beta/padding specified by the potential
* Removes references to du_dl
* Sitches to using np.testing.assert_allclose to have similar style of
error
* Fix up typo in docstring
* Avoid duplicate work in test

* Removes benchmark test that is always skipped

* Prefer the `tests/test_benchmarks.py` for benchmarking
  • Loading branch information
badisa authored Jul 26, 2023
1 parent 5b046f9 commit c4369bd
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 56 deletions.
10 changes: 4 additions & 6 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import jax
import numpy as np
import pytest
from hilbertcurve.hilbertcurve import HilbertCurve
from numpy.typing import NDArray
from rdkit import Chem
Expand Down Expand Up @@ -262,7 +261,6 @@ def compare_forces(
test_du_dx_2, test_du_dp_2, test_u_2 = test_potential.unbound_impl.execute_selective(
x, params, box, compute_du_dx, compute_du_dp, compute_u
)

np.testing.assert_array_equal(test_du_dx, test_du_dx_2)
np.testing.assert_array_equal(test_u, test_u_2)
np.testing.assert_array_equal(test_du_dp, test_du_dp_2)
Expand Down Expand Up @@ -447,7 +445,7 @@ def check_split_ixns(
# Should be the same as the new code with the orig ff
sum_grad_new, sum_u_new = compute_new_grad_u(ffs.ref, precision, coords0, box, lamb, num_water_atoms, host_bps)

assert sum_u_ref == pytest.approx(sum_u_new, rel=rtol, abs=atol)
np.testing.assert_allclose(sum_u_ref, sum_u_new, rtol=rtol, atol=atol)
np.testing.assert_allclose(sum_grad_ref, sum_grad_new, rtol=rtol, atol=atol)

# Compute the grads, potential with the intramolecular terms scaled
Expand All @@ -462,7 +460,7 @@ def check_split_ixns(
expected_u = sum_u_ref - LL_u_ref + LL_u_intra
expected_grad = sum_grad_ref - LL_grad_ref + LL_grad_intra

assert expected_u == pytest.approx(sum_u_intra, rel=rtol, abs=atol)
np.testing.assert_allclose(expected_u, sum_u_intra, rtol=rtol, atol=atol)
np.testing.assert_allclose(expected_grad, sum_grad_intra, rtol=rtol, atol=atol)

# Compute the grads, potential with the ligand-water terms scaled
Expand All @@ -487,7 +485,7 @@ def check_split_ixns(
expected_u = sum_u_ref - WL_u_ref + WL_u_solv
expected_grad = sum_grad_ref - WL_grad_ref + WL_grad_solv

assert expected_u == pytest.approx(sum_u_solv, rel=rtol, abs=atol)
np.testing.assert_allclose(expected_u, sum_u_solv, rtol=rtol, atol=atol)
np.testing.assert_allclose(expected_grad, sum_grad_solv, rtol=rtol, atol=atol)

# Compute the grads, potential with the protein-ligand terms scaled
Expand All @@ -512,5 +510,5 @@ def check_split_ixns(
expected_u = sum_u_ref - PL_u_ref + PL_u_prot
expected_grad = sum_grad_ref - PL_grad_ref + PL_grad_prot

assert expected_u == pytest.approx(sum_u_prot, rel=rtol, abs=atol)
np.testing.assert_allclose(expected_u, sum_u_prot, rtol=rtol, atol=atol)
np.testing.assert_allclose(expected_grad, sum_grad_prot, rtol=rtol, atol=atol)
5 changes: 2 additions & 3 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,8 @@ def test_set_and_get():

def test_fwd_mode():
"""
This test ensures that we can reverse-mode differentiate
observables that are dU_dlambdas of each state. We provide
adjoints with respect to each computed dU/dLambda.
This test verifies that stepping forward in time matches whether using the
reference or the GPU platform.
"""

np.random.seed(4321)
Expand Down
64 changes: 17 additions & 47 deletions tests/test_nonbonded.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# (ytz): check test and run benchmark with pytest:
# pytest -xsv tests/test_nonbonded.py::TestNonbonded::test_dhfr && nvprof pytest -xsv tests/test_nonbonded.py::TestNonbonded::test_benchmark
import copy
import itertools
import unittest
from dataclasses import replace
from typing import cast
Expand Down Expand Up @@ -30,8 +26,8 @@ def setUp(self):
self.nonbonded_fn = cast(potentials.Nonbonded, nonbonded_bp.potential)
self.nonbonded_params = nonbonded_bp.params
self.host_conf = host_coords
self.beta = 2.0
self.cutoff = 1.1
self.beta = self.nonbonded_fn.beta
self.cutoff = self.nonbonded_fn.cutoff

def test_nblist_hilbert(self):
"""
Expand All @@ -48,7 +44,7 @@ def test_nblist_hilbert(self):
ref_nonbonded_impl = replace(self.nonbonded_fn, disable_hilbert_sort=True).to_gpu(precision).unbound_impl
test_nonbonded_impl = self.nonbonded_fn.to_gpu(precision).unbound_impl

padding = 0.1
padding = self.nonbonded_fn.nblist_padding
deltas = np.random.rand(N, 3) - 0.5 # [-0.5, +0.5]
divisor = 0.5 * (2 * np.sqrt(3)) / padding
# if deltas are kept under +- p/(2*sqrt(3)) then no rebuild gets triggered
Expand Down Expand Up @@ -83,7 +79,7 @@ def test_nblist_hilbert(self):

def test_nblist_rebuild(self):
"""
This test makes sure that periodically rebuilding the neighborlist no impact on numerical results. The
This test makes sure that periodically rebuilding the neighborlist has no impact on numerical results. The
computed forces, energies, etc. should be bitwise identical.
"""

Expand Down Expand Up @@ -139,20 +135,21 @@ def test_correctness(self):

rng = np.random.default_rng(2022)

for atom_idxs in [None, np.array(rng.choice(N, N // 2, replace=False), dtype=np.int32)]:
test_conf = self.host_conf[:N]

# strip out parts of the system
test_exclusions = []
test_scales = []
for (i, j), (sa, sb) in zip(self.nonbonded_fn.exclusion_idxs, self.nonbonded_fn.scale_factors):
if i < N and j < N:
test_exclusions.append((i, j))
test_scales.append((sa, sb))

test_conf = self.host_conf[:N]
test_exclusions = np.array(test_exclusions, dtype=np.int32)
test_scales = np.array(test_scales, dtype=np.float64)
test_params = self.nonbonded_params[:N, :]

# strip out parts of the system
test_exclusions = []
test_scales = []
for (i, j), (sa, sb) in zip(self.nonbonded_fn.exclusion_idxs, self.nonbonded_fn.scale_factors):
if i < N and j < N:
test_exclusions.append((i, j))
test_scales.append((sa, sb))
test_exclusions = np.array(test_exclusions, dtype=np.int32)
test_scales = np.array(test_scales, dtype=np.float64)
test_params = self.nonbonded_params[:N, :]
for atom_idxs in [None, np.array(rng.choice(N, N // 2, replace=False), dtype=np.int32)]:

potential = potentials.Nonbonded(
N, test_exclusions, test_scales, self.beta, self.cutoff, atom_idxs=atom_idxs
Expand All @@ -163,33 +160,6 @@ def test_correctness(self):
test_conf, test_params, self.box, potential, potential.to_gpu(precision), rtol=rtol, atol=atol
)

@unittest.skip("benchmark-only")
def test_benchmark(self):
"""
This is mainly for benchmarking nonbonded computations on the initial state.
"""

precision = np.float32

nb_fn = copy.deepcopy(self.nonbonded_fn)

impl = nb_fn.to_gpu(precision).unbound_impl

for combo in itertools.product([False, True], repeat=4):

compute_du_dx, compute_du_dp, compute_u = combo

for trip in range(50):

test_du_dx, test_du_dp, test_u = impl.execute_selective(
self.host_conf,
[self.nonbonded_params],
self.box,
compute_du_dx,
compute_du_dp,
compute_u,
)


class TestNonbondedWater(GradientTest):
def test_nblist_box_resize(self):
Expand Down

0 comments on commit c4369bd

Please sign in to comment.