diff --git a/tests/test_nonbonded_interaction_group.py b/tests/test_nonbonded_interaction_group.py new file mode 100644 index 000000000..a786c2c20 --- /dev/null +++ b/tests/test_nonbonded_interaction_group.py @@ -0,0 +1,410 @@ +import jax + +jax.config.update("jax_enable_x64", True) + +import numpy as np +import pytest +from common import GradientTest, prepare_reference_nonbonded +from simtk.openmm import app + +from timemachine.fe.utils import to_md_units +from timemachine.ff.handlers import openmm_deserializer +from timemachine.lib import potentials +from timemachine.lib.potentials import NonbondedInteractionGroup, NonbondedInteractionGroupInterpolated +from timemachine.potentials import jax_utils, nonbonded + + +@pytest.fixture(autouse=True) +def set_random_seed(): + np.random.seed(2022) + yield + + +@pytest.fixture() +def rng(): + return np.random.default_rng(2022) + + +@pytest.fixture +def example_system(): + pdb_path = "tests/data/5dfr_solv_equil.pdb" + host_pdb = app.PDBFile(pdb_path) + ff = app.ForceField("amber99sbildn.xml", "tip3p.xml") + return ( + ff.createSystem(host_pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False), + host_pdb.positions, + host_pdb.topology.getPeriodicBoxVectors(), + ) + + +@pytest.fixture +def example_nonbonded_params(example_system): + host_system, _, _ = example_system + host_fns, _ = openmm_deserializer.deserialize_system(host_system, cutoff=1.0) + + nonbonded_fn = None + for f in host_fns: + if isinstance(f, potentials.Nonbonded): + nonbonded_fn = f + + assert nonbonded_fn is not None + return nonbonded_fn.params + + +@pytest.fixture +def example_conf(example_system): + _, host_conf, _ = example_system + return np.array([[to_md_units(x), to_md_units(y), to_md_units(z)] for x, y, z in host_conf]) + + +@pytest.fixture +def example_box(example_system): + _, _, box = example_system + return np.asarray(box / box.unit) + + +def test_nonbonded_interaction_group_invalid_indices(): + def 
make_potential(ligand_idxs, num_atoms): + lambda_plane_idxs = [0] * num_atoms + lambda_offset_idxs = [0] * num_atoms + return NonbondedInteractionGroup(ligand_idxs, lambda_plane_idxs, lambda_offset_idxs, 1.0, 1.0).unbound_impl( + np.float64 + ) + + with pytest.raises(RuntimeError) as e: + make_potential([], 1) + assert "row_atom_idxs must be nonempty" in str(e) + + with pytest.raises(RuntimeError) as e: + make_potential([1, 1], 3) + assert "atom indices must be unique" in str(e) + + +def test_nonbonded_interaction_group_zero_interactions(rng: np.random.Generator): + num_atoms = 33 + num_atoms_ligand = 15 + beta = 2.0 + lamb = 0.1 + cutoff = 1.1 + box = 10.0 * np.eye(3) + conf = rng.uniform(0, 1, size=(num_atoms, 3)) + ligand_idxs = rng.choice(num_atoms, size=(num_atoms_ligand,), replace=False).astype(np.int32) + + # shift ligand atoms in x by twice the cutoff + conf[ligand_idxs, 0] += 2 * cutoff + + params = rng.uniform(0, 1, size=(num_atoms, 3)) + + potential = NonbondedInteractionGroup( + ligand_idxs, + np.zeros(num_atoms, dtype=np.int32), + np.zeros(num_atoms, dtype=np.int32), + beta, + cutoff, + ) + + du_dx, du_dp, du_dl, u = potential.unbound_impl(np.float64).execute(conf, params, box, lamb) + + assert (du_dx == 0).all() + assert (du_dp == 0).all() + assert du_dl == 0 + assert u == 0 + + +@pytest.mark.parametrize("lamb", [0.0, 0.1]) +@pytest.mark.parametrize("beta", [2.0]) +@pytest.mark.parametrize("cutoff", [1.1]) +@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)]) +@pytest.mark.parametrize("num_atoms_ligand", [1, 15]) +@pytest.mark.parametrize("num_atoms", [33, 231]) +def test_nonbonded_interaction_group_correctness( + num_atoms, + num_atoms_ligand, + precision, + rtol, + atol, + cutoff, + beta, + lamb, + example_nonbonded_params, + example_conf, + example_box, + rng, +): + "Compares with jax reference implementation." 
+ + conf = example_conf[:num_atoms] + params = example_nonbonded_params[:num_atoms, :] + + lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + + ligand_idxs = rng.choice(num_atoms, size=(num_atoms_ligand,), replace=False).astype(np.int32) + host_idxs = np.setdiff1d(np.arange(num_atoms), ligand_idxs) + + def ref_ixngroups(conf, params, box, lamb): + + # compute 4d coordinates + w = jax_utils.compute_lifting_parameter(lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff) + conf_4d = jax_utils.augment_dim(conf, w) + box_4d = (1000 * jax.numpy.eye(4)).at[:3, :3].set(box) + + vdW, electrostatics, _ = nonbonded.nonbonded_v3_interaction_groups( + conf_4d, params, box_4d, ligand_idxs, host_idxs, beta, cutoff + ) + return jax.numpy.sum(vdW + electrostatics) + + test_ixngroups = NonbondedInteractionGroup( + ligand_idxs, + lambda_plane_idxs, + lambda_offset_idxs, + beta, + cutoff, + ) + + GradientTest().compare_forces( + conf, + params, + example_box, + lamb=lamb, + ref_potential=ref_ixngroups, + test_potential=test_ixngroups, + rtol=rtol, + atol=atol, + precision=precision, + ) + + +@pytest.mark.parametrize("lamb", [0.0, 0.1, 0.9, 1.0]) +@pytest.mark.parametrize("beta", [2.0]) +@pytest.mark.parametrize("cutoff", [1.1]) +@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)]) +@pytest.mark.parametrize("num_atoms_ligand", [1, 15]) +@pytest.mark.parametrize("num_atoms", [33]) +def test_nonbonded_interaction_group_interpolated_correctness( + num_atoms, + num_atoms_ligand, + precision, + rtol, + atol, + cutoff, + beta, + lamb, + example_nonbonded_params, + example_conf, + example_box, + rng, +): + "Compares with jax reference implementation, with parameter interpolation." 
+ + conf = example_conf[:num_atoms] + params_initial = example_nonbonded_params[:num_atoms, :] + params_final = params_initial + rng.normal(0, 0.01, size=params_initial.shape) + params = np.concatenate((params_initial, params_final)) + + lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + + ligand_idxs = rng.choice(num_atoms, size=(num_atoms_ligand,), replace=False).astype(np.int32) + host_idxs = np.setdiff1d(np.arange(num_atoms), ligand_idxs) + + @nonbonded.interpolated + def ref_ixngroups(conf, params, box, lamb): + + # compute 4d coordinates + w = jax_utils.compute_lifting_parameter(lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff) + conf_4d = jax_utils.augment_dim(conf, w) + box_4d = (1000 * jax.numpy.eye(4)).at[:3, :3].set(box) + + vdW, electrostatics, _ = nonbonded.nonbonded_v3_interaction_groups( + conf_4d, params, box_4d, ligand_idxs, host_idxs, beta, cutoff + ) + return jax.numpy.sum(vdW + electrostatics) + + test_ixngroups = NonbondedInteractionGroupInterpolated( + ligand_idxs, + lambda_plane_idxs, + lambda_offset_idxs, + beta, + cutoff, + ) + + GradientTest().compare_forces( + conf, + params, + example_box, + lamb=lamb, + ref_potential=ref_ixngroups, + test_potential=test_ixngroups, + rtol=rtol, + atol=atol, + precision=precision, + ) + + +@pytest.mark.parametrize("lamb", [0.0, 0.1]) +@pytest.mark.parametrize("beta", [2.0]) +@pytest.mark.parametrize("cutoff", [1.1]) +@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)]) +@pytest.mark.parametrize("num_atoms_ligand", [1, 15]) +@pytest.mark.parametrize("num_atoms", [33, 231, 1050]) +def test_nonbonded_interaction_group_consistency_allpairs_lambda_planes( + num_atoms, + num_atoms_ligand, + precision, + rtol, + atol, + cutoff, + beta, + lamb, + example_nonbonded_params, + example_conf, + example_box, + rng: np.random.Generator, +): + """Compares with 
reference nonbonded_v3 potential, which computes + the sum of all pairwise interactions. This uses the identity + + U = U_A + U_B + U_AB + + where + - U is the all-pairs potential over all atoms + - U_A, U_B are all-pairs potentials for interacting groups A and + B, respectively + - U_AB is the "interaction group" potential, i.e. the sum of + pairwise interactions (a, b) where "a" is in A and "b" is in B + + U is computed using the reference potential over all atoms, and + U_A + U_B computed using the reference potential over all atoms + separated into 2 lambda planes according to which interacting + group they belong + """ + + conf = example_conf[:num_atoms] + params = example_nonbonded_params[:num_atoms, :] + + max_abs_offset_idx = 2 # i.e., lambda_offset_idxs in {-2, -1, 0, 1, 2} + lambda_offset_idxs = rng.integers(-max_abs_offset_idx, max_abs_offset_idx + 1, size=(num_atoms,), dtype=np.int32) + + def make_reference_nonbonded(lambda_plane_idxs): + return prepare_reference_nonbonded( + params=params, + exclusion_idxs=np.array([], dtype=np.int32), + scales=np.zeros((0, 2), dtype=np.float64), + lambda_plane_idxs=lambda_plane_idxs, + lambda_offset_idxs=lambda_offset_idxs, + beta=beta, + cutoff=cutoff, + ) + + ref_allpairs = make_reference_nonbonded(np.zeros(num_atoms, dtype=np.int32)) + + ligand_idxs = rng.choice(num_atoms, size=(num_atoms_ligand,), replace=False).astype(np.int32) + + # for reference U_A + U_B computation, ensure minimum distance + # between a host and ligand atom is at least one cutoff distance + # when lambda = 1 + lambda_plane_idxs = np.zeros(num_atoms, dtype=np.int32) + lambda_plane_idxs[ligand_idxs] = 2 * max_abs_offset_idx + 1 + + ref_allpairs_minus_ixngroups = make_reference_nonbonded(lambda_plane_idxs) + + def ref_ixngroups(*args): + return ref_allpairs(*args) - ref_allpairs_minus_ixngroups(*args) + + test_ixngroups = NonbondedInteractionGroup( + ligand_idxs, + np.zeros(num_atoms, dtype=np.int32), # lambda plane indices + lambda_offset_idxs, 
+ beta, + cutoff, + ) + + GradientTest().compare_forces( + conf, + params, + example_box, + lamb=lamb, + ref_potential=ref_ixngroups, + test_potential=test_ixngroups, + rtol=rtol, + atol=atol, + precision=precision, + ) + + +@pytest.mark.parametrize("lamb", [0.0, 0.1]) +@pytest.mark.parametrize("beta", [2.0]) +@pytest.mark.parametrize("cutoff", [1.1]) +@pytest.mark.parametrize("precision,rtol,atol", [(np.float64, 1e-8, 1e-8), (np.float32, 1e-4, 5e-4)]) +@pytest.mark.parametrize("num_atoms_ligand", [1, 15]) +@pytest.mark.parametrize("num_atoms", [33, 231]) +def test_nonbonded_interaction_group_consistency_allpairs_constant_shift( + num_atoms, + num_atoms_ligand, + precision, + rtol, + atol, + cutoff, + beta, + lamb, + example_nonbonded_params, + example_conf, + example_box, + rng: np.random.Generator, +): + """Compares with reference nonbonded_v3 potential, which computes + the sum of all pairwise interactions. This uses the identity + + U(x') - U(x) = U_AB(x') - U_AB(x) + + where + - U is the all-pairs potential over all atoms + - U_A, U_B are all-pairs potentials for interacting groups A and + B, respectively + - U_AB is the "interaction group" potential, i.e. the sum of + pairwise interactions (a, b) where "a" is in A and "b" is in B + - the transformation x -> x' does not affect U_A or U_B (e.g. 
a + constant translation applied to each atom in one group) + """ + + conf = example_conf[:num_atoms] + params = example_nonbonded_params[:num_atoms, :] + + lambda_plane_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + lambda_offset_idxs = rng.integers(-2, 3, size=(num_atoms,), dtype=np.int32) + + def ref_allpairs(conf): + return prepare_reference_nonbonded( + params=params, + exclusion_idxs=np.array([], dtype=np.int32), + scales=np.zeros((0, 2), dtype=np.float64), + lambda_plane_idxs=lambda_plane_idxs, + lambda_offset_idxs=lambda_offset_idxs, + beta=beta, + cutoff=cutoff, + )(conf, params, example_box, lamb) + + ligand_idxs = rng.choice(num_atoms, size=(num_atoms_ligand,), replace=False).astype(np.int32) + + def test_ixngroups(conf): + _, _, _, u = ( + NonbondedInteractionGroup( + ligand_idxs, + lambda_plane_idxs, + lambda_offset_idxs, + beta, + cutoff, + ) + .unbound_impl(precision) + .execute(conf, params, example_box, lamb) + ) + return u + + conf_prime = np.array(conf) + conf_prime[ligand_idxs] += rng.normal(0, 0.01, size=(3,)) + + ref_delta = ref_allpairs(conf_prime) - ref_allpairs(conf) + test_delta = test_ixngroups(conf_prime) - test_ixngroups(conf) + + np.testing.assert_allclose(ref_delta, test_delta, rtol=rtol, atol=atol) diff --git a/tests/test_parameter_interpolation.py b/tests/test_parameter_interpolation.py index 099a5fd5d..4605b2f01 100644 --- a/tests/test_parameter_interpolation.py +++ b/tests/test_parameter_interpolation.py @@ -2,22 +2,13 @@ config.update("jax_enable_x64", True) import copy -import functools import jax.numpy as jnp import numpy as np from common import GradientTest, prepare_water_system from timemachine.lib import potentials - - -def interpolated_potential(conf, params, box, lamb, u_fn): - assert params.size % 2 == 0 - - CP = params.shape[0] // 2 - new_params = (1 - lamb) * params[:CP] + lamb * params[CP:] - - return u_fn(conf, new_params, box, lamb) +from timemachine.potentials import nonbonded class 
TestInterpolatedPotential(GradientTest): @@ -57,7 +48,7 @@ def test_nonbonded(self): print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape) - ref_interpolated_potential = functools.partial(interpolated_potential, u_fn=ref_potential) + ref_interpolated_potential = nonbonded.interpolated(ref_potential) test_interpolated_potential = potentials.NonbondedInterpolated(*test_potential.args) diff --git a/timemachine/cpp/CMakeLists.txt b/timemachine/cpp/CMakeLists.txt index 016b3301f..9a227fd29 100644 --- a/timemachine/cpp/CMakeLists.txt +++ b/timemachine/cpp/CMakeLists.txt @@ -42,8 +42,9 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS src/gpu_utils.cu src/vendored/hilbert.cpp src/nonbonded.cu - src/nonbonded_dense.cu - src/nonbonded_pairs.cu + src/nonbonded_all_pairs.cu + src/nonbonded_pair_list.cu + src/nonbonded_interaction_group.cu src/neighborlist.cu src/harmonic_bond.cu src/harmonic_angle.cu @@ -54,6 +55,7 @@ pybind11_add_module(${LIBRARY_NAME} SHARED NO_EXTRAS src/rmsd_align.cpp src/summed_potential.cu src/device_buffer.cu + src/kernels/k_nonbonded.cu src/kernels/nonbonded_common.cu ) diff --git a/timemachine/cpp/src/kernels/k_nonbonded.cu b/timemachine/cpp/src/kernels/k_nonbonded.cu new file mode 100644 index 000000000..719466555 --- /dev/null +++ b/timemachine/cpp/src/kernels/k_nonbonded.cu @@ -0,0 +1,91 @@ +#include "k_nonbonded.cuh" + +void __global__ k_coords_to_kv( + const int N, + const double *coords, + const double *box, + const unsigned int *bin_to_idx, + unsigned int *keys, + unsigned int *vals) { + + const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (atom_idx >= N) { + return; + } + + // these coords have to be centered + double bx = box[0 * 3 + 0]; + double by = box[1 * 3 + 1]; + double bz = box[2 * 3 + 2]; + + double binWidth = max(max(bx, by), bz) / 255.0; + + double x = coords[atom_idx * 3 + 0]; + double y = coords[atom_idx * 3 + 1]; + double z = coords[atom_idx * 3 + 2]; + + x -= bx * floor(x 
/ bx); + y -= by * floor(y / by); + z -= bz * floor(z / bz); + + unsigned int bin_x = x / binWidth; + unsigned int bin_y = y / binWidth; + unsigned int bin_z = z / binWidth; + + keys[atom_idx] = bin_to_idx[bin_x * 256 * 256 + bin_y * 256 + bin_z]; + // uncomment below if you want to preserve the atom ordering + // keys[atom_idx] = atom_idx; + vals[atom_idx] = atom_idx; +} + +// TODO: DRY with k_coords_to_kv +void __global__ k_coords_to_kv_gather( + const int N, + const unsigned int *atom_idxs, + const double *coords, + const double *box, + const unsigned int *bin_to_idx, + unsigned int *keys, + unsigned int *vals) { + + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx >= N) { + return; + } + + const int atom_idx = atom_idxs[idx]; + + // these coords have to be centered + double bx = box[0 * 3 + 0]; + double by = box[1 * 3 + 1]; + double bz = box[2 * 3 + 2]; + + double binWidth = max(max(bx, by), bz) / 255.0; + + double x = coords[atom_idx * 3 + 0]; + double y = coords[atom_idx * 3 + 1]; + double z = coords[atom_idx * 3 + 2]; + + x -= bx * floor(x / bx); + y -= by * floor(y / by); + z -= bz * floor(z / bz); + + unsigned int bin_x = x / binWidth; + unsigned int bin_y = y / binWidth; + unsigned int bin_z = z / binWidth; + + keys[idx] = bin_to_idx[bin_x * 256 * 256 + bin_y * 256 + bin_z]; + // uncomment below if you want to preserve the atom ordering + // keys[idx] = atom_idx; + vals[idx] = atom_idx; +} + +void __global__ k_arange(int N, unsigned int *arr) { + const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (atom_idx >= N) { + return; + } + arr[atom_idx] = atom_idx; +} diff --git a/timemachine/cpp/src/kernels/k_nonbonded_dense.cuh b/timemachine/cpp/src/kernels/k_nonbonded.cuh similarity index 90% rename from timemachine/cpp/src/kernels/k_nonbonded_dense.cuh rename to timemachine/cpp/src/kernels/k_nonbonded.cuh index c2cab8c16..00e1e7f5a 100644 --- a/timemachine/cpp/src/kernels/k_nonbonded_dense.cuh +++ 
b/timemachine/cpp/src/kernels/k_nonbonded.cuh @@ -5,6 +5,8 @@ #include "nonbonded_common.cuh" #include "surreal.cuh" +void __global__ k_arange(int N, unsigned int *arr); + // generate kv values from coordinates to be radix sorted void __global__ k_coords_to_kv( const int N, @@ -12,38 +14,17 @@ void __global__ k_coords_to_kv( const double *box, const unsigned int *bin_to_idx, unsigned int *keys, - unsigned int *vals) { - - const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (atom_idx >= N) { - return; - } - - // these coords have to be centered - double bx = box[0 * 3 + 0]; - double by = box[1 * 3 + 1]; - double bz = box[2 * 3 + 2]; - - double binWidth = max(max(bx, by), bz) / 255.0; + unsigned int *vals); - double x = coords[atom_idx * 3 + 0]; - double y = coords[atom_idx * 3 + 1]; - double z = coords[atom_idx * 3 + 2]; - - x -= bx * floor(x / bx); - y -= by * floor(y / by); - z -= bz * floor(z / bz); - - unsigned int bin_x = x / binWidth; - unsigned int bin_y = y / binWidth; - unsigned int bin_z = z / binWidth; - - keys[atom_idx] = bin_to_idx[bin_x * 256 * 256 + bin_y * 256 + bin_z]; - // uncomment below if you want to preserve the atom ordering - // keys[atom_idx] = atom_idx; - vals[atom_idx] = atom_idx; -} +// variant of k_coords_to_kv allowing the selection of a subset of coordinates +void __global__ k_coords_to_kv_gather( + const int N, // number of atoms in selection + const unsigned int *atom_idxs, // [N] indices of atoms to select + const double *coords, + const double *box, + const unsigned int *bin_to_idx, + unsigned int *keys, + unsigned int *vals); template void __global__ k_check_rebuild_box(const int N, const double *new_box, const double *old_box, int *rebuild) { @@ -282,7 +263,8 @@ template < bool COMPUTE_DU_DP> // void __device__ __forceinline__ v_nonbonded_unified( void __device__ v_nonbonded_unified( - const int N, + const int NC, + const int NR, const double *__restrict__ coords, const double *__restrict__ params, // [N] const 
double *__restrict__ box, @@ -316,6 +298,12 @@ void __device__ v_nonbonded_unified( // int lambda_offset_i = atom_i_idx < N ? lambda_offset_idxs[atom_i_idx] : 0; // int lambda_plane_i = atom_i_idx < N ? lambda_plane_idxs[atom_i_idx] : 0; + const int N = NC + NR; + + if (NR != 0) { + atom_i_idx += NC; + } + RealType ci_x = atom_i_idx < N ? coords[atom_i_idx * 3 + 0] : 0; RealType ci_y = atom_i_idx < N ? coords[atom_i_idx * 3 + 1] : 0; RealType ci_z = atom_i_idx < N ? coords[atom_i_idx * 3 + 2] : 0; @@ -348,15 +336,15 @@ void __device__ v_nonbonded_unified( // int lambda_offset_j = atom_j_idx < N ? lambda_offset_idxs[atom_j_idx] : 0; // int lambda_plane_j = atom_j_idx < N ? lambda_plane_idxs[atom_j_idx] : 0; - RealType cj_x = atom_j_idx < N ? coords[atom_j_idx * 3 + 0] : 0; - RealType cj_y = atom_j_idx < N ? coords[atom_j_idx * 3 + 1] : 0; - RealType cj_z = atom_j_idx < N ? coords[atom_j_idx * 3 + 2] : 0; - RealType cj_w = atom_j_idx < N ? coords_w[atom_j_idx] : 0; + RealType cj_x = atom_j_idx < NC ? coords[atom_j_idx * 3 + 0] : 0; + RealType cj_y = atom_j_idx < NC ? coords[atom_j_idx * 3 + 1] : 0; + RealType cj_z = atom_j_idx < NC ? coords[atom_j_idx * 3 + 2] : 0; + RealType cj_w = atom_j_idx < NC ? coords_w[atom_j_idx] : 0; - RealType dq_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 0] : 0; - RealType dsig_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 1] : 0; - RealType deps_dl_j = atom_j_idx < N ? dp_dl[atom_j_idx * 3 + 2] : 0; - RealType dw_dl_j = atom_j_idx < N ? dw_dl[atom_j_idx] : 0; + RealType dq_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 0] : 0; + RealType dsig_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 1] : 0; + RealType deps_dl_j = atom_j_idx < NC ? dp_dl[atom_j_idx * 3 + 2] : 0; + RealType dw_dl_j = atom_j_idx < NC ? 
dw_dl[atom_j_idx] : 0; unsigned long long gj_x = 0; unsigned long long gj_y = 0; @@ -366,9 +354,9 @@ void __device__ v_nonbonded_unified( int lj_param_idx_sig_j = atom_j_idx * 3 + 1; int lj_param_idx_eps_j = atom_j_idx * 3 + 2; - RealType qj = atom_j_idx < N ? params[charge_param_idx_j] : 0; - RealType sig_j = atom_j_idx < N ? params[lj_param_idx_sig_j] : 0; - RealType eps_j = atom_j_idx < N ? params[lj_param_idx_eps_j] : 0; + RealType qj = atom_j_idx < NC ? params[charge_param_idx_j] : 0; + RealType sig_j = atom_j_idx < NC ? params[lj_param_idx_sig_j] : 0; + RealType eps_j = atom_j_idx < NC ? params[lj_param_idx_eps_j] : 0; unsigned long long g_qj = 0; unsigned long long g_sigj = 0; @@ -403,11 +391,17 @@ void __device__ v_nonbonded_unified( d2ij += delta_w * delta_w; } + const bool valid_ij = + atom_i_idx < N && + ((NR == 0) ? atom_i_idx < atom_j_idx && atom_j_idx < N // all-pairs case, only compute the upper tri + // 0 <= i < N, i < j < N + : atom_j_idx < NC); // ixn groups case, compute all pairwise ixns + // NC <= i < N, 0 <= j < NC + // (ytz): note that d2ij must be *strictly* less than cutoff_squared. This is because we set the // non-interacting atoms to exactly real_cutoff*real_cutoff. This ensures that atoms who's 4th dimension // is set to cutoff are non-interacting. 
- if (d2ij < cutoff_squared && atom_j_idx > atom_i_idx && atom_j_idx < N && atom_i_idx < N) { - + if (d2ij < cutoff_squared && valid_ij) { // electrostatics RealType u; RealType es_prefactor; @@ -547,7 +541,8 @@ void __device__ v_nonbonded_unified( template void __global__ k_nonbonded_unified( - const int N, + const int NC, + const int NR, const double *__restrict__ coords, const double *__restrict__ params, // [N] const double *__restrict__ box, @@ -567,6 +562,12 @@ void __global__ k_nonbonded_unified( int row_block_idx = ixn_tiles[tile_idx]; int atom_i_idx = row_block_idx * 32 + threadIdx.x; + const int N = NC + NR; + + if (NR != 0) { + atom_i_idx += NC; + } + RealType dq_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 0] : 0; RealType dsig_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 1] : 0; RealType deps_dl_i = atom_i_idx < N ? dp_dl[atom_i_idx * 3 + 2] : 0; @@ -587,7 +588,8 @@ void __global__ k_nonbonded_unified( if (tile_is_vanilla) { v_nonbonded_unified( - N, + NC, + NR, coords, params, box, @@ -604,7 +606,8 @@ void __global__ k_nonbonded_unified( u_buffer); } else { v_nonbonded_unified( - N, + NC, + NR, coords, params, box, diff --git a/timemachine/cpp/src/kernels/k_nonbonded_pairs.cuh b/timemachine/cpp/src/kernels/k_nonbonded_pair_list.cuh similarity index 99% rename from timemachine/cpp/src/kernels/k_nonbonded_pairs.cuh rename to timemachine/cpp/src/kernels/k_nonbonded_pair_list.cuh index 7d95f18a2..0c6c9290b 100644 --- a/timemachine/cpp/src/kernels/k_nonbonded_pairs.cuh +++ b/timemachine/cpp/src/kernels/k_nonbonded_pair_list.cuh @@ -14,7 +14,7 @@ void __device__ __forceinline__ accumulate(unsigned long long *__restrict acc, u } template -void __global__ k_nonbonded_pairs( +void __global__ k_nonbonded_pair_list( const int M, // number of pairs const double *__restrict__ coords, const double *__restrict__ params, diff --git a/timemachine/cpp/src/nonbonded.hpp b/timemachine/cpp/src/nonbonded.hpp index 663c13a52..5c8d06e71 100644 --- 
a/timemachine/cpp/src/nonbonded.hpp +++ b/timemachine/cpp/src/nonbonded.hpp @@ -1,7 +1,7 @@ #pragma once -#include "nonbonded_dense.hpp" -#include "nonbonded_pairs.hpp" +#include "nonbonded_all_pairs.hpp" +#include "nonbonded_pair_list.hpp" #include "potential.hpp" #include @@ -10,10 +10,10 @@ namespace timemachine { template class Nonbonded : public Potential { private: - NonbondedDense dense_; + NonbondedAllPairs dense_; static const bool Negated = true; - NonbondedPairs exclusions_; // implement exclusions as negated NonbondedPairs + NonbondedPairList exclusions_; // implement exclusions as negated NonbondedPairList public: // these are marked public but really only intended for testing. diff --git a/timemachine/cpp/src/nonbonded_dense.cu b/timemachine/cpp/src/nonbonded_all_pairs.cu similarity index 90% rename from timemachine/cpp/src/nonbonded_dense.cu rename to timemachine/cpp/src/nonbonded_all_pairs.cu index 5a66343c5..b618e06dc 100644 --- a/timemachine/cpp/src/nonbonded_dense.cu +++ b/timemachine/cpp/src/nonbonded_all_pairs.cu @@ -10,10 +10,10 @@ #include "fixed_point.hpp" #include "gpu_utils.cuh" -#include "nonbonded_dense.hpp" +#include "nonbonded_all_pairs.hpp" #include "vendored/hilbert.h" -#include "k_nonbonded_dense.cuh" +#include "k_nonbonded.cuh" #include #include @@ -22,7 +22,7 @@ namespace timemachine { template -NonbondedDense::NonbondedDense( +NonbondedAllPairs::NonbondedAllPairs( const std::vector &lambda_plane_idxs, // [N] const std::vector &lambda_offset_idxs, // [N] const double beta, @@ -63,10 +63,6 @@ NonbondedDense::NonbondedDense( compute_add_du_dp_interpolated_( kernel_cache_.program(kernel_src.c_str()).kernel("k_add_du_dp_interpolated").instantiate()) { - if (lambda_offset_idxs.size() != N_) { - throw std::runtime_error("lambda offset idxs need to have size N"); - } - if (lambda_offset_idxs.size() != lambda_plane_idxs.size()) { throw std::runtime_error("lambda offset idxs and plane idxs need to be equivalent"); } @@ -138,12 +134,12 @@ 
NonbondedDense::NonbondedDense( gpuErrchk(cudaMalloc(&d_sort_storage_, d_sort_storage_bytes_)); }; -template NonbondedDense::~NonbondedDense() { +template NonbondedAllPairs::~NonbondedAllPairs() { gpuErrchk(cudaFree(d_lambda_plane_idxs_)); gpuErrchk(cudaFree(d_lambda_offset_idxs_)); gpuErrchk(cudaFree(d_du_dp_buffer_)); - gpuErrchk(cudaFree(d_perm_)); // nullptr if we never built nblist + gpuErrchk(cudaFree(d_perm_)); gpuErrchk(cudaFree(d_bin_to_idx_)); gpuErrchk(cudaFree(d_sorted_x_)); @@ -173,16 +169,16 @@ template NonbondedDense -void NonbondedDense::set_nblist_padding(double val) { +void NonbondedAllPairs::set_nblist_padding(double val) { nblist_padding_ = val; } -template void NonbondedDense::disable_hilbert_sort() { +template void NonbondedAllPairs::disable_hilbert_sort() { disable_hilbert_ = true; } template -void NonbondedDense::hilbert_sort( +void NonbondedAllPairs::hilbert_sort( const double *d_coords, const double *d_box, cudaStream_t stream) { const int tpb = 32; @@ -208,16 +204,8 @@ void NonbondedDense::hilbert_sort( gpuErrchk(cudaPeekAtLastError()); } -void __global__ k_arange(int N, unsigned int *arr) { - const int atom_idx = blockIdx.x * blockDim.x + threadIdx.x; - if (atom_idx >= N) { - return; - } - arr[atom_idx] = atom_idx; -} - template -void NonbondedDense::execute_device( +void NonbondedAllPairs::execute_device( const int N, const int P, const double *d_x, @@ -247,15 +235,17 @@ void NonbondedDense::execute_device( // g. note that du/dl is not an exact per-particle du/dl - it is only used for reduction purposes. if (N != N_) { - std::cout << N << " " << N_ << std::endl; - throw std::runtime_error("NonbondedDense::execute_device() N != N_"); + throw std::runtime_error( + "NonbondedAllPairs::execute_device(): expected N == N_, got N=" + std::to_string(N) + + ", N_=" + std::to_string(N_)); } const int M = Interpolated ? 
2 : 1; if (P != M * N_ * 3) { - std::cout << P << " " << N_ << std::endl; - throw std::runtime_error("NonbondedDense::execute_device() P != M*N_*3"); + throw std::runtime_error( + "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=" + std::to_string(P) + + ", M*N_*3=" + std::to_string(M * N_ * 3)); } // identify which tiles contain interpolated parameters @@ -266,8 +256,6 @@ void NonbondedDense::execute_device( dim3 dimGrid(B, 3, 1); // (ytz) see if we need to rebuild the neighborlist. - // (ytz + jfass): note that this logic needs to change if we use NPT later on since a resize in the box - // can introduce new interactions. k_check_rebuild_coords_and_box <<>>(N, d_x, d_nblist_x_, d_box, d_nblist_box_, nblist_padding_, d_rebuild_nblist_); gpuErrchk(cudaPeekAtLastError()); @@ -297,10 +285,14 @@ void NonbondedDense::execute_device( std::vector h_box(9); gpuErrchk(cudaMemcpyAsync(&h_box[0], d_box, 3 * 3 * sizeof(*d_box), cudaMemcpyDeviceToHost, stream)); + + // this stream needs to be synchronized so we can be sure that p_ixn_count_ is properly set. + cudaStreamSynchronize(stream); + // Verify that the cutoff and box size are valid together. If cutoff is greater than half the box // then a particle can interact with multiple periodic copies. const double db_cutoff = (cutoff_ + nblist_padding_) * 2; - cudaStreamSynchronize(stream); + // Verify that box is orthogonal and the width of the box in all dimensions is greater than twice the cutoff for (int i = 0; i < 9; i++) { if (i == 0 || i == 4 || i == 8) { @@ -334,7 +326,6 @@ void NonbondedDense::execute_device( gpuErrchk(cudaMemsetAsync(d_sorted_dp_dl_, 0, N * 3 * sizeof(*d_sorted_dp_dl_), stream)) } - // this stream needs to be synchronized so we can be sure that p_ixn_count_ is properly set. 
// reset buffers and sorted accumulators if (d_du_dx) { gpuErrchk(cudaMemsetAsync(d_sorted_du_dx_, 0, N * 3 * sizeof(*d_sorted_du_dx_), stream)) @@ -364,6 +355,7 @@ void NonbondedDense::execute_device( kernel_ptrs_[kernel_idx]<<>>( N, + 0, d_sorted_x_, d_sorted_p_, d_box, @@ -410,7 +402,7 @@ void NonbondedDense::execute_device( } template -void NonbondedDense::du_dp_fixed_to_float( +void NonbondedAllPairs::du_dp_fixed_to_float( const int N, const int P, const unsigned long long *du_dp, double *du_dp_float) { // In the interpolated case we have derivatives for the initial and final parameters @@ -426,9 +418,9 @@ void NonbondedDense::du_dp_fixed_to_float( } } -template class NonbondedDense; -template class NonbondedDense; -template class NonbondedDense; -template class NonbondedDense; +template class NonbondedAllPairs; +template class NonbondedAllPairs; +template class NonbondedAllPairs; +template class NonbondedAllPairs; } // namespace timemachine diff --git a/timemachine/cpp/src/nonbonded_dense.hpp b/timemachine/cpp/src/nonbonded_all_pairs.hpp similarity index 74% rename from timemachine/cpp/src/nonbonded_dense.hpp rename to timemachine/cpp/src/nonbonded_all_pairs.hpp index b0e82b00f..1ad6bbdec 100644 --- a/timemachine/cpp/src/nonbonded_dense.hpp +++ b/timemachine/cpp/src/nonbonded_all_pairs.hpp @@ -9,7 +9,8 @@ namespace timemachine { typedef void (*k_nonbonded_fn)( - const int N, + const int NC, + const int NR, const double *__restrict__ coords, const double *__restrict__ params, // [N] const double *__restrict__ box, @@ -25,7 +26,7 @@ typedef void (*k_nonbonded_fn)( unsigned long long *__restrict__ du_dl_buffer, unsigned long long *__restrict__ u_buffer); -template class NonbondedDense : public Potential { +template class NonbondedAllPairs : public Potential { private: std::array kernel_ptrs_; @@ -48,21 +49,28 @@ template class NonbondedDense : public Po unsigned int *d_perm_; // hilbert curve permutation - double *d_w_; // - double *d_dw_dl_; // - - double 
*d_sorted_x_; // - double *d_sorted_w_; // - double *d_sorted_dw_dl_; // - double *d_sorted_p_; // - double *d_unsorted_p_; // + double *d_w_; // 4D coordinates + double *d_dw_dl_; + + // "sorted" means + // - if hilbert sorting enabled, atoms are sorted according to the + // hilbert curve index + // - otherwise, atom ordering is preserved with respect to input + // + // "unsorted" means the atom ordering is preserved with respect to input + double *d_sorted_x_; // sorted coordinates + double *d_sorted_w_; // sorted 4D coordinates + double *d_sorted_dw_dl_; + double *d_sorted_p_; // sorted parameters + double *d_unsorted_p_; // unsorted parameters double *d_sorted_dp_dl_; double *d_unsorted_dp_dl_; - unsigned long long *d_sorted_du_dx_; // - unsigned long long *d_sorted_du_dp_; // - unsigned long long *d_du_dp_buffer_; // + unsigned long long *d_sorted_du_dx_; + unsigned long long *d_sorted_du_dp_; + unsigned long long *d_du_dp_buffer_; - unsigned int *d_bin_to_idx_; + // used for hilbert sorting + unsigned int *d_bin_to_idx_; // mapping from 256x256x256 grid to hilbert curve index unsigned int *d_sort_keys_in_; unsigned int *d_sort_keys_out_; unsigned int *d_sort_vals_in_; @@ -83,14 +91,14 @@ template class NonbondedDense : public Po void set_nblist_padding(double val); void disable_hilbert_sort(); - NonbondedDense( + NonbondedAllPairs( const std::vector &lambda_plane_idxs, // N const std::vector &lambda_offset_idxs, // N const double beta, const double cutoff, const std::string &kernel_src); - ~NonbondedDense(); + ~NonbondedAllPairs(); virtual void execute_device( const int N, diff --git a/timemachine/cpp/src/nonbonded_interaction_group.cu b/timemachine/cpp/src/nonbonded_interaction_group.cu new file mode 100644 index 000000000..022f0ff3d --- /dev/null +++ b/timemachine/cpp/src/nonbonded_interaction_group.cu @@ -0,0 +1,481 @@ +#include "vendored/jitify.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + 
+#include "fixed_point.hpp" +#include "gpu_utils.cuh" +#include "nonbonded_interaction_group.hpp" +#include "vendored/hilbert.h" + +#include "k_nonbonded.cuh" + +#include +#include +#include + +namespace timemachine { + +std::vector set_to_vector(const std::set &s) { + std::vector v(s.begin(), s.end()); + return v; +} + +template +NonbondedInteractionGroup::NonbondedInteractionGroup( + const std::set &row_atom_idxs, + const std::vector &lambda_plane_idxs, // [N] + const std::vector &lambda_offset_idxs, // [N] + const double beta, + const double cutoff, + const std::string &kernel_src) + : N_(lambda_offset_idxs.size()), NR_(row_atom_idxs.size()), NC_(N_ - NR_), cutoff_(cutoff), nblist_(NC_, NR_), + beta_(beta), d_sort_storage_(nullptr), d_sort_storage_bytes_(0), nblist_padding_(0.1), disable_hilbert_(false), + kernel_ptrs_({// enumerate over every possible kernel combination + // U: Compute U + // X: Compute DU_DL + // L: Compute DU_DX + // P: Compute DU_DP + // U X L P + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified, + &k_nonbonded_unified}), + compute_w_coords_instance_(kernel_cache_.program(kernel_src.c_str()).kernel("k_compute_w_coords").instantiate()), + compute_permute_interpolated_( + kernel_cache_.program(kernel_src.c_str()).kernel("k_permute_interpolated").instantiate()), + compute_add_du_dp_interpolated_( + kernel_cache_.program(kernel_src.c_str()).kernel("k_add_du_dp_interpolated").instantiate()) { + + if (NR_ == 0) { + throw std::runtime_error("row_atom_idxs must be nonempty"); + } + + if (lambda_offset_idxs.size() != lambda_plane_idxs.size()) { + throw std::runtime_error("lambda offset idxs and plane idxs need to be equivalent"); + } + + // compute 
set of column atoms as set difference + std::vector all_atom_idxs(N_); + std::iota(all_atom_idxs.begin(), all_atom_idxs.end(), 0); + std::set col_atom_idxs; + std::set_difference( + all_atom_idxs.begin(), + all_atom_idxs.end(), + row_atom_idxs.begin(), + row_atom_idxs.end(), + std::inserter(col_atom_idxs, col_atom_idxs.end())); + + std::vector col_atom_idxs_v(set_to_vector(col_atom_idxs)); + gpuErrchk(cudaMalloc(&d_col_atom_idxs_, NC_ * sizeof(*d_col_atom_idxs_))); + gpuErrchk( + cudaMemcpy(d_col_atom_idxs_, &col_atom_idxs_v[0], NC_ * sizeof(*d_col_atom_idxs_), cudaMemcpyHostToDevice)); + + std::vector row_atom_idxs_v(set_to_vector(row_atom_idxs)); + gpuErrchk(cudaMalloc(&d_row_atom_idxs_, NR_ * sizeof(*d_row_atom_idxs_))); + gpuErrchk( + cudaMemcpy(d_row_atom_idxs_, &row_atom_idxs_v[0], NR_ * sizeof(*d_row_atom_idxs_), cudaMemcpyHostToDevice)); + + gpuErrchk(cudaMalloc(&d_lambda_plane_idxs_, N_ * sizeof(*d_lambda_plane_idxs_))); + gpuErrchk(cudaMemcpy( + d_lambda_plane_idxs_, &lambda_plane_idxs[0], N_ * sizeof(*d_lambda_plane_idxs_), cudaMemcpyHostToDevice)); + + gpuErrchk(cudaMalloc(&d_lambda_offset_idxs_, N_ * sizeof(*d_lambda_offset_idxs_))); + gpuErrchk(cudaMemcpy( + d_lambda_offset_idxs_, &lambda_offset_idxs[0], N_ * sizeof(*d_lambda_offset_idxs_), cudaMemcpyHostToDevice)); + + gpuErrchk(cudaMalloc(&d_perm_, N_ * sizeof(*d_perm_))); + + gpuErrchk(cudaMalloc(&d_sorted_x_, N_ * 3 * sizeof(*d_sorted_x_))); + + gpuErrchk(cudaMalloc(&d_w_, N_ * sizeof(*d_w_))); + gpuErrchk(cudaMalloc(&d_dw_dl_, N_ * sizeof(*d_dw_dl_))); + gpuErrchk(cudaMalloc(&d_sorted_w_, N_ * sizeof(*d_sorted_w_))); + gpuErrchk(cudaMalloc(&d_sorted_dw_dl_, N_ * sizeof(*d_sorted_dw_dl_))); + + gpuErrchk(cudaMalloc(&d_unsorted_p_, N_ * 3 * sizeof(*d_unsorted_p_))); // interpolated + gpuErrchk(cudaMalloc(&d_sorted_p_, N_ * 3 * sizeof(*d_sorted_p_))); // interpolated + gpuErrchk(cudaMalloc(&d_unsorted_dp_dl_, N_ * 3 * sizeof(*d_unsorted_dp_dl_))); // interpolated + 
gpuErrchk(cudaMalloc(&d_sorted_dp_dl_, N_ * 3 * sizeof(*d_sorted_dp_dl_))); // interpolated + gpuErrchk(cudaMalloc(&d_sorted_du_dx_, N_ * 3 * sizeof(*d_sorted_du_dx_))); + gpuErrchk(cudaMalloc(&d_sorted_du_dp_, N_ * 3 * sizeof(*d_sorted_du_dp_))); + gpuErrchk(cudaMalloc(&d_du_dp_buffer_, N_ * 3 * sizeof(*d_du_dp_buffer_))); + + gpuErrchk(cudaMallocHost(&p_ixn_count_, 1 * sizeof(*p_ixn_count_))); + + gpuErrchk(cudaMalloc(&d_nblist_x_, N_ * 3 * sizeof(*d_nblist_x_))); + gpuErrchk(cudaMemset(d_nblist_x_, 0, N_ * 3 * sizeof(*d_nblist_x_))); // set non-sensical positions + gpuErrchk(cudaMalloc(&d_nblist_box_, 3 * 3 * sizeof(*d_nblist_x_))); + gpuErrchk(cudaMemset(d_nblist_box_, 0, 3 * 3 * sizeof(*d_nblist_x_))); + gpuErrchk(cudaMalloc(&d_rebuild_nblist_, 1 * sizeof(*d_rebuild_nblist_))); + gpuErrchk(cudaMallocHost(&p_rebuild_nblist_, 1 * sizeof(*p_rebuild_nblist_))); + + gpuErrchk(cudaMalloc(&d_sort_keys_in_, N_ * sizeof(d_sort_keys_in_))); + gpuErrchk(cudaMalloc(&d_sort_keys_out_, N_ * sizeof(d_sort_keys_out_))); + gpuErrchk(cudaMalloc(&d_sort_vals_in_, N_ * sizeof(d_sort_vals_in_))); + + // initialize hilbert curve + std::vector bin_to_idx(256 * 256 * 256); + for (int i = 0; i < 256; i++) { + for (int j = 0; j < 256; j++) { + for (int k = 0; k < 256; k++) { + + bitmask_t hilbert_coords[3]; + hilbert_coords[0] = i; + hilbert_coords[1] = j; + hilbert_coords[2] = k; + + unsigned int bin = static_cast(hilbert_c2i(3, 8, hilbert_coords)); + bin_to_idx[i * 256 * 256 + j * 256 + k] = bin; + } + } + } + + gpuErrchk(cudaMalloc(&d_bin_to_idx_, 256 * 256 * 256 * sizeof(*d_bin_to_idx_))); + gpuErrchk( + cudaMemcpy(d_bin_to_idx_, &bin_to_idx[0], 256 * 256 * 256 * sizeof(*d_bin_to_idx_), cudaMemcpyHostToDevice)); + + // estimate size needed to do radix sorting, this can use uninitialized data. 
+ cub::DeviceRadixSort::SortPairs( + d_sort_storage_, + d_sort_storage_bytes_, + d_sort_keys_in_, + d_sort_keys_out_, + d_sort_vals_in_, + d_perm_, + std::max(NC_, NR_)); + + gpuErrchk(cudaPeekAtLastError()); + gpuErrchk(cudaMalloc(&d_sort_storage_, d_sort_storage_bytes_)); +}; + +template +NonbondedInteractionGroup::~NonbondedInteractionGroup() { + gpuErrchk(cudaFree(d_col_atom_idxs_)); + gpuErrchk(cudaFree(d_row_atom_idxs_)); + + gpuErrchk(cudaFree(d_lambda_plane_idxs_)); + gpuErrchk(cudaFree(d_lambda_offset_idxs_)); + gpuErrchk(cudaFree(d_du_dp_buffer_)); + gpuErrchk(cudaFree(d_perm_)); + + gpuErrchk(cudaFree(d_bin_to_idx_)); + gpuErrchk(cudaFree(d_sorted_x_)); + + gpuErrchk(cudaFree(d_w_)); + gpuErrchk(cudaFree(d_dw_dl_)); + gpuErrchk(cudaFree(d_sorted_w_)); + gpuErrchk(cudaFree(d_sorted_dw_dl_)); + gpuErrchk(cudaFree(d_unsorted_p_)); + gpuErrchk(cudaFree(d_sorted_p_)); + gpuErrchk(cudaFree(d_unsorted_dp_dl_)); + gpuErrchk(cudaFree(d_sorted_dp_dl_)); + gpuErrchk(cudaFree(d_sorted_du_dx_)); + gpuErrchk(cudaFree(d_sorted_du_dp_)); + + gpuErrchk(cudaFree(d_sort_keys_in_)); + gpuErrchk(cudaFree(d_sort_keys_out_)); + gpuErrchk(cudaFree(d_sort_vals_in_)); + gpuErrchk(cudaFree(d_sort_storage_)); + + gpuErrchk(cudaFreeHost(p_ixn_count_)); + + gpuErrchk(cudaFree(d_nblist_x_)); + gpuErrchk(cudaFree(d_nblist_box_)); + gpuErrchk(cudaFree(d_rebuild_nblist_)); + gpuErrchk(cudaFreeHost(p_rebuild_nblist_)); +}; + +template +void NonbondedInteractionGroup::set_nblist_padding(double val) { + nblist_padding_ = val; +} + +template +void NonbondedInteractionGroup::disable_hilbert_sort() { + disable_hilbert_ = true; +} + +template +void NonbondedInteractionGroup::hilbert_sort( + const int N, + const unsigned int *d_atom_idxs, + const double *d_coords, + const double *d_box, + unsigned int *d_perm, + cudaStream_t stream) { + + const int tpb = 32; + const int B = ceil_divide(N, tpb); + + k_coords_to_kv_gather<<>>( + N, d_atom_idxs, d_coords, d_box, d_bin_to_idx_, d_sort_keys_in_, 
d_sort_vals_in_); + + gpuErrchk(cudaPeekAtLastError()); + + cub::DeviceRadixSort::SortPairs( + d_sort_storage_, + d_sort_storage_bytes_, + d_sort_keys_in_, + d_sort_keys_out_, + d_sort_vals_in_, + d_perm, + N, + 0, // begin bit + sizeof(*d_sort_keys_in_) * 8, // end bit + stream // cudaStream + ); + + gpuErrchk(cudaPeekAtLastError()); +} + +template +void NonbondedInteractionGroup::execute_device( + const int N, + const int P, + const double *d_x, + const double *d_p, // 2 * N * 3 + const double *d_box, // 3 * 3 + const double lambda, + unsigned long long *d_du_dx, + unsigned long long *d_du_dp, + unsigned long long *d_du_dl, + unsigned long long *d_u, + cudaStream_t stream) { + + // (ytz) the nonbonded algorithm proceeds as follows: + + // (done in constructor), construct a hilbert curve mapping each of the 256x256x256 cells into an index. + // a. decide if we need to rebuild the neighborlist, if so: + // - look up which cell each particle belongs to, and its linear index along the hilbert curve. + // - use radix pair sort keyed on the hilbert index with values equal to the atomic index + // - resulting sorted values is the permutation array. + // - permute lambda plane/offsets, coords + // b. else: + // - permute new coords + // c. permute parameters + // d. compute the nonbonded interactions using the neighborlist + // e. inverse permute the forces, du/dps into the original index. + // f. u and du/dl is buffered into a per-particle array, and then reduced. + // g. note that du/dl is not an exact per-particle du/dl - it is only used for reduction purposes. + + if (N != N_) { + throw std::runtime_error( + "NonbondedAllPairs::execute_device(): expected N == N_, got N=" + std::to_string(N) + + ", N_=" + std::to_string(N_)); + } + + const int M = Interpolated ? 
2 : 1; + + if (P != M * N_ * 3) { + throw std::runtime_error( + "NonbondedAllPairs::execute_device(): expected P == M*N_*3, got P=" + std::to_string(P) + + ", M*N_*3=" + std::to_string(M * N_ * 3)); + } + + // identify which tiles contain interpolated parameters + + const int tpb = 32; + const int B = ceil_divide(N_, tpb); + + // (ytz) see if we need to rebuild the neighborlist. + k_check_rebuild_coords_and_box + <<>>(N_, d_x, d_nblist_x_, d_box, d_nblist_box_, nblist_padding_, d_rebuild_nblist_); + + gpuErrchk(cudaPeekAtLastError()); + + // we can optimize this away by doing the check on the GPU directly. + gpuErrchk(cudaMemcpyAsync( + p_rebuild_nblist_, d_rebuild_nblist_, 1 * sizeof(*p_rebuild_nblist_), cudaMemcpyDeviceToHost, stream)); + gpuErrchk(cudaStreamSynchronize(stream)); // slow! + + dim3 dimGrid(B, 3, 1); + + if (p_rebuild_nblist_[0] > 0) { + + // (ytz): update the permutation index before building neighborlist, as the neighborlist is tied + // to a particular sort order + if (!disable_hilbert_) { + this->hilbert_sort(NC_, d_col_atom_idxs_, d_x, d_box, d_perm_, stream); + this->hilbert_sort(NR_, d_row_atom_idxs_, d_x, d_box, d_perm_ + NC_, stream); + } else { + gpuErrchk(cudaMemcpyAsync( + d_perm_, d_col_atom_idxs_, NC_ * sizeof(*d_col_atom_idxs_), cudaMemcpyDeviceToDevice, stream)); + gpuErrchk(cudaMemcpyAsync( + d_perm_ + NC_, d_row_atom_idxs_, NR_ * sizeof(*d_row_atom_idxs_), cudaMemcpyDeviceToDevice, stream)); + } + + // compute new coordinates, new lambda_idxs, new_plane_idxs + k_permute<<>>(N_, d_perm_, d_x, d_sorted_x_); + gpuErrchk(cudaPeekAtLastError()); + + nblist_.build_nblist_device( + NC_, NR_, d_sorted_x_, d_sorted_x_ + 3 * NC_, d_box, cutoff_ + nblist_padding_, stream); + gpuErrchk(cudaMemcpyAsync( + p_ixn_count_, nblist_.get_ixn_count(), 1 * sizeof(*p_ixn_count_), cudaMemcpyDeviceToHost, stream)); + + std::vector h_box(9); + gpuErrchk(cudaMemcpyAsync(&h_box[0], d_box, 3 * 3 * sizeof(*d_box), cudaMemcpyDeviceToHost, stream)); + + // this 
stream needs to be synchronized so we can be sure that p_ixn_count_ is properly set. + cudaStreamSynchronize(stream); + + // Verify that the cutoff and box size are valid together. If cutoff is greater than half the box + // then a particle can interact with multiple periodic copies. + const double db_cutoff = (cutoff_ + nblist_padding_) * 2; + + // Verify that box is orthogonal and the width of the box in all dimensions is greater than twice the cutoff + for (int i = 0; i < 9; i++) { + if (i == 0 || i == 4 || i == 8) { + if (h_box[i] < db_cutoff) { + throw std::runtime_error( + "Cutoff with padding is more than half of the box width, neighborlist is no longer reliable"); + } + } else if (h_box[i] != 0.0) { + throw std::runtime_error("Provided non-ortholinear box, unable to compute nonbonded energy"); + } + } + + gpuErrchk(cudaMemsetAsync(d_rebuild_nblist_, 0, sizeof(*d_rebuild_nblist_), stream)); + gpuErrchk(cudaMemcpyAsync(d_nblist_x_, d_x, N * 3 * sizeof(*d_x), cudaMemcpyDeviceToDevice, stream)); + gpuErrchk(cudaMemcpyAsync(d_nblist_box_, d_box, 3 * 3 * sizeof(*d_box), cudaMemcpyDeviceToDevice, stream)); + } else { + k_permute<<>>(N, d_perm_, d_x, d_sorted_x_); + gpuErrchk(cudaPeekAtLastError()); + } + + // if the neighborlist is empty, we can return early + if (p_ixn_count_[0] == 0) { + return; + } + + // do parameter interpolation here + if (Interpolated) { + CUresult result = compute_permute_interpolated_.configure(dimGrid, tpb, 0, stream) + .launch(lambda, N, d_perm_, d_p, d_sorted_p_, d_sorted_dp_dl_); + if (result != 0) { + throw std::runtime_error("Driver call to k_permute_interpolated failed"); + } + } else { + k_permute<<>>(N, d_perm_, d_p, d_sorted_p_); + gpuErrchk(cudaPeekAtLastError()); + gpuErrchk(cudaMemsetAsync(d_sorted_dp_dl_, 0, N * 3 * sizeof(*d_sorted_dp_dl_), stream)) + } + + // reset buffers and sorted accumulators + if (d_du_dx) { + gpuErrchk(cudaMemsetAsync(d_sorted_du_dx_, 0, N * 3 * sizeof(*d_sorted_du_dx_), stream)) + } + if (d_du_dp) { 
+ gpuErrchk(cudaMemsetAsync(d_sorted_du_dp_, 0, N * 3 * sizeof(*d_sorted_du_dp_), stream)) + } + + // update new w coordinates + // (tbd): cache lambda value for equilibrium calculations + CUresult result = compute_w_coords_instance_.configure(B, tpb, 0, stream) + .launch(N, lambda, cutoff_, d_lambda_plane_idxs_, d_lambda_offset_idxs_, d_w_, d_dw_dl_); + if (result != 0) { + throw std::runtime_error("Driver call to k_compute_w_coords"); + } + + gpuErrchk(cudaPeekAtLastError()); + k_permute_2x<<>>(N, d_perm_, d_w_, d_dw_dl_, d_sorted_w_, d_sorted_dw_dl_); + gpuErrchk(cudaPeekAtLastError()); + + // look up which kernel we need for this computation + int kernel_idx = 0; + kernel_idx |= d_du_dp ? 1 << 0 : 0; + kernel_idx |= d_du_dl ? 1 << 1 : 0; + kernel_idx |= d_du_dx ? 1 << 2 : 0; + kernel_idx |= d_u ? 1 << 3 : 0; + + kernel_ptrs_[kernel_idx]<<>>( + NC_, + NR_, + d_sorted_x_, + d_sorted_p_, + d_box, + d_sorted_dp_dl_, + d_sorted_w_, + d_sorted_dw_dl_, + beta_, + cutoff_, + nblist_.get_ixn_tiles(), + nblist_.get_ixn_atoms(), + d_sorted_du_dx_, + d_sorted_du_dp_, + d_du_dl, // switch to nullptr if we don't request du_dl + d_u // switch to nullptr if we don't request energies + ); + + gpuErrchk(cudaPeekAtLastError()); + + // coords are N,3 + if (d_du_dx) { + k_inv_permute_accum<<>>(N, d_perm_, d_sorted_du_dx_, d_du_dx); + gpuErrchk(cudaPeekAtLastError()); + } + + // params are N,3 + // this needs to be an accumulated permute + if (d_du_dp) { + k_inv_permute_assign<<>>(N, d_perm_, d_sorted_du_dp_, d_du_dp_buffer_); + gpuErrchk(cudaPeekAtLastError()); + } + + if (d_du_dp) { + if (Interpolated) { + CUresult result = compute_add_du_dp_interpolated_.configure(dimGrid, tpb, 0, stream) + .launch(lambda, N, d_du_dp_buffer_, d_du_dp); + if (result != 0) { + throw std::runtime_error("Driver call to k_add_du_dp_interpolated failed"); + } + } else { + k_add_ull_to_ull<<>>(N, d_du_dp_buffer_, d_du_dp); + } + gpuErrchk(cudaPeekAtLastError()); + } +} + +template +void 
NonbondedInteractionGroup::du_dp_fixed_to_float( + const int N, const int P, const unsigned long long *du_dp, double *du_dp_float) { + + // In the interpolated case we have derivatives for the initial and final parameters + const int num_tuples = Interpolated ? N * 2 : N; + + for (int i = 0; i < num_tuples; i++) { + const int idx_charge = i * 3 + 0; + const int idx_sig = i * 3 + 1; + const int idx_eps = i * 3 + 2; + du_dp_float[idx_charge] = FIXED_TO_FLOAT_DU_DP(du_dp[idx_charge]); + du_dp_float[idx_sig] = FIXED_TO_FLOAT_DU_DP(du_dp[idx_sig]); + du_dp_float[idx_eps] = FIXED_TO_FLOAT_DU_DP(du_dp[idx_eps]); + } +} + +template class NonbondedInteractionGroup; +template class NonbondedInteractionGroup; +template class NonbondedInteractionGroup; +template class NonbondedInteractionGroup; + +} // namespace timemachine diff --git a/timemachine/cpp/src/nonbonded_interaction_group.hpp b/timemachine/cpp/src/nonbonded_interaction_group.hpp new file mode 100644 index 000000000..b3ab817bb --- /dev/null +++ b/timemachine/cpp/src/nonbonded_interaction_group.hpp @@ -0,0 +1,134 @@ +#pragma once + +#include "neighborlist.hpp" +#include "potential.hpp" +#include "vendored/jitify.hpp" +#include +#include +#include + +namespace timemachine { + +typedef void (*k_nonbonded_fn)( + const int NC, + const int NR, + const double *__restrict__ coords, + const double *__restrict__ params, // [N] + const double *__restrict__ box, + const double *__restrict__ dl_dp, + const double *__restrict__ coords_w, // 4D coords + const double *__restrict__ dw_dl, // 4D derivatives + const double beta, + const double cutoff, + const int *__restrict__ ixn_tiles, + const unsigned int *__restrict__ ixn_atoms, + unsigned long long *__restrict__ du_dx, + unsigned long long *__restrict__ du_dp, + unsigned long long *__restrict__ du_dl_buffer, + unsigned long long *__restrict__ u_buffer); + +template class NonbondedInteractionGroup : public Potential { + +private: + const int N_; // N_ = NC_ + NR_ + const int NR_; 
// number of row atoms + const int NC_; // number of column atoms + + std::array kernel_ptrs_; + + unsigned int *d_col_atom_idxs_; + unsigned int *d_row_atom_idxs_; + + int *d_lambda_plane_idxs_; + int *d_lambda_offset_idxs_; + int *p_ixn_count_; // pinned memory + + double beta_; + double cutoff_; + Neighborlist nblist_; + + double nblist_padding_; + double *d_nblist_x_; // coords which were used to compute the nblist + double *d_nblist_box_; // box which was used to rebuild the nblist + int *d_rebuild_nblist_; // whether or not we have to rebuild the nblist + int *p_rebuild_nblist_; // pinned + + unsigned int *d_perm_; // hilbert curve permutation + + double *d_w_; // 4D coordinates + double *d_dw_dl_; + + // "sorted" means + // - if hilbert sorting enabled, atoms are sorted into contiguous + // blocks by interaction group, and each block is hilbert-sorted + // independently + // - otherwise, atoms are sorted into contiguous blocks by + // interaction group, with arbitrary ordering within each block + // + // "unsorted" means the atom ordering is preserved with respect to input + double *d_sorted_x_; // sorted coordinates + double *d_sorted_w_; // sorted 4D coordinates + double *d_sorted_dw_dl_; + double *d_sorted_p_; // sorted parameters + double *d_unsorted_p_; // unsorted parameters + double *d_sorted_dp_dl_; + double *d_unsorted_dp_dl_; + unsigned long long *d_sorted_du_dx_; + unsigned long long *d_sorted_du_dp_; + unsigned long long *d_du_dp_buffer_; + + // used for hilbert sorting + unsigned int *d_bin_to_idx_; // mapping from 256x256x256 grid to hilbert curve index + unsigned int *d_sort_keys_in_; + unsigned int *d_sort_keys_out_; + unsigned int *d_sort_vals_in_; + unsigned int *d_sort_storage_; + size_t d_sort_storage_bytes_; + + bool disable_hilbert_; + + void hilbert_sort( + const int N, + const unsigned int *d_atom_idxs, + const double *d_x, + const double *d_box, + unsigned int *d_perm, + cudaStream_t stream); + + jitify::JitCache kernel_cache_; + 
jitify::KernelInstantiation compute_w_coords_instance_; + jitify::KernelInstantiation compute_permute_interpolated_; + jitify::KernelInstantiation compute_add_du_dp_interpolated_; + +public: + // these are marked public but really only intended for testing. + void set_nblist_padding(double val); + void disable_hilbert_sort(); + + NonbondedInteractionGroup( + const std::set &row_atom_idxs, + const std::vector &lambda_plane_idxs, // N + const std::vector &lambda_offset_idxs, // N + const double beta, + const double cutoff, + const std::string &kernel_src); + + ~NonbondedInteractionGroup(); + + virtual void execute_device( + const int N, + const int P, + const double *d_x, + const double *d_p, + const double *d_box, + const double lambda, + unsigned long long *d_du_dx, + unsigned long long *d_du_dp, + unsigned long long *d_du_dl, + unsigned long long *d_u, + cudaStream_t stream) override; + + void du_dp_fixed_to_float(const int N, const int P, const unsigned long long *du_dp, double *du_dp_float) override; +}; + +} // namespace timemachine diff --git a/timemachine/cpp/src/nonbonded_pairs.cu b/timemachine/cpp/src/nonbonded_pair_list.cu similarity index 87% rename from timemachine/cpp/src/nonbonded_pairs.cu rename to timemachine/cpp/src/nonbonded_pair_list.cu index f6b7c41ef..f08cc7eca 100644 --- a/timemachine/cpp/src/nonbonded_pairs.cu +++ b/timemachine/cpp/src/nonbonded_pair_list.cu @@ -1,14 +1,14 @@ #include "gpu_utils.cuh" -#include "k_nonbonded_pairs.cuh" +#include "k_nonbonded_pair_list.cuh" #include "math_utils.cuh" -#include "nonbonded_pairs.hpp" +#include "nonbonded_pair_list.hpp" #include #include namespace timemachine { template -NonbondedPairs::NonbondedPairs( +NonbondedPairList::NonbondedPairList( const std::vector &pair_idxs, // [M, 2] const std::vector &scales, // [M, 2] const std::vector &lambda_plane_idxs, // [N] @@ -72,7 +72,7 @@ NonbondedPairs::NonbondedPairs( }; template -NonbondedPairs::~NonbondedPairs() { +NonbondedPairList::~NonbondedPairList() { 
gpuErrchk(cudaFree(d_pair_idxs_)); gpuErrchk(cudaFree(d_scales_)); gpuErrchk(cudaFree(d_lambda_plane_idxs_)); @@ -89,7 +89,7 @@ NonbondedPairs::~NonbondedPairs() { }; template -void NonbondedPairs::execute_device( +void NonbondedPairList::execute_device( const int N, const int P, const double *d_x, @@ -129,7 +129,7 @@ void NonbondedPairs::execute_device( int num_blocks_pairs = ceil_divide(M_, tpb); - k_nonbonded_pairs<<>>( + k_nonbonded_pair_list<<>>( M_, d_x, Interpolated ? d_p_interp_ : d_p, @@ -162,9 +162,9 @@ void NonbondedPairs::execute_device( } } -// TODO: this implementation is duplicated from NonbondedDense. Worth adding NonbondedBase? +// TODO: this implementation is duplicated from NonbondedAllPairs template -void NonbondedPairs::du_dp_fixed_to_float( +void NonbondedPairList::du_dp_fixed_to_float( const int N, const int P, const unsigned long long *du_dp, double *du_dp_float) { // In the interpolated case we have derivatives for the initial and final parameters @@ -180,16 +180,16 @@ void NonbondedPairs::du_dp_fixed_to_float( } } -template class NonbondedPairs; -template class NonbondedPairs; +template class NonbondedPairList; +template class NonbondedPairList; -template class NonbondedPairs; -template class NonbondedPairs; +template class NonbondedPairList; +template class NonbondedPairList; -template class NonbondedPairs; -template class NonbondedPairs; +template class NonbondedPairList; +template class NonbondedPairList; -template class NonbondedPairs; -template class NonbondedPairs; +template class NonbondedPairList; +template class NonbondedPairList; } // namespace timemachine diff --git a/timemachine/cpp/src/nonbonded_pairs.hpp b/timemachine/cpp/src/nonbonded_pair_list.hpp similarity index 94% rename from timemachine/cpp/src/nonbonded_pairs.hpp rename to timemachine/cpp/src/nonbonded_pair_list.hpp index dac1c6e6d..a2770bd36 100644 --- a/timemachine/cpp/src/nonbonded_pairs.hpp +++ b/timemachine/cpp/src/nonbonded_pair_list.hpp @@ -6,7 +6,7 @@ 
namespace timemachine { -template class NonbondedPairs : public Potential { +template class NonbondedPairList : public Potential { private: int *d_pair_idxs_; // [M, 2] @@ -36,7 +36,7 @@ template class NonbondedPai jitify::KernelInstantiation compute_add_du_dp_interpolated_; public: - NonbondedPairs( + NonbondedPairList( const std::vector &pair_idxs, // [M, 2] const std::vector &scales, // [M, 2] const std::vector &lambda_plane_idxs, // [N] @@ -45,7 +45,7 @@ template class NonbondedPai const double cutoff, const std::string &kernel_src); - ~NonbondedPairs(); + ~NonbondedPairList(); virtual void execute_device( const int N, diff --git a/timemachine/cpp/src/wrap_kernels.cpp b/timemachine/cpp/src/wrap_kernels.cpp index a0a5549ef..d695bb14c 100644 --- a/timemachine/cpp/src/wrap_kernels.cpp +++ b/timemachine/cpp/src/wrap_kernels.cpp @@ -14,8 +14,7 @@ #include "integrator.hpp" #include "neighborlist.hpp" #include "nonbonded.hpp" -#include "nonbonded_dense.hpp" -#include "nonbonded_pairs.hpp" +#include "nonbonded_interaction_group.hpp" #include "periodic_torsion.hpp" #include "potential.hpp" #include "rmsd_align.hpp" @@ -707,6 +706,74 @@ template void declare_nonbonded(py::modul py::arg("transform_lambda_w") = "lambda"); } +std::set unique_idxs(const std::vector &idxs) { + std::set unique_idxs(idxs.begin(), idxs.end()); + if (unique_idxs.size() < idxs.size()) { + throw std::runtime_error("atom indices must be unique"); + } + return unique_idxs; +} + +template +void declare_nonbonded_interaction_group(py::module &m, const char *typestr) { + using Class = timemachine::NonbondedInteractionGroup; + std::string pyclass_name = std::string("NonbondedInteractionGroup_") + typestr; + py::class_, timemachine::Potential>( + m, pyclass_name.c_str(), py::buffer_protocol(), py::dynamic_attr()) + .def("set_nblist_padding", &timemachine::NonbondedInteractionGroup::set_nblist_padding) + .def( + "disable_hilbert_sort", + &timemachine::NonbondedInteractionGroup::disable_hilbert_sort) + 
.def( + py::init([](const py::array_t &row_atom_idxs_i, + const py::array_t &lambda_plane_idxs_i, + const py::array_t &lambda_offset_idxs_i, + const double beta, + const double cutoff, + const std::string &transform_lambda_charge = "lambda", + const std::string &transform_lambda_sigma = "lambda", + const std::string &transform_lambda_epsilon = "lambda", + const std::string &transform_lambda_w = "lambda") { + std::vector row_atom_idxs(row_atom_idxs_i.size()); + std::memcpy(row_atom_idxs.data(), row_atom_idxs_i.data(), row_atom_idxs_i.size() * sizeof(int)); + std::set unique_row_atom_idxs(unique_idxs(row_atom_idxs)); + + std::vector lambda_plane_idxs(lambda_plane_idxs_i.size()); + std::memcpy( + lambda_plane_idxs.data(), lambda_plane_idxs_i.data(), lambda_plane_idxs_i.size() * sizeof(int)); + + std::vector lambda_offset_idxs(lambda_offset_idxs_i.size()); + std::memcpy( + lambda_offset_idxs.data(), lambda_offset_idxs_i.data(), lambda_offset_idxs_i.size() * sizeof(int)); + + std::string dir_path = dirname(__FILE__); + std::string kernel_dir = dir_path + "/kernels"; + std::string src_path = kernel_dir + "/k_lambda_transformer_jit.cuh"; + std::ifstream t(src_path); + std::string source_str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); + source_str = std::regex_replace(source_str, std::regex("KERNEL_DIR"), kernel_dir); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_CHARGE"), transform_lambda_charge); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_SIGMA"), transform_lambda_sigma); + source_str = + std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_EPSILON"), transform_lambda_epsilon); + source_str = std::regex_replace(source_str, std::regex("CUSTOM_EXPRESSION_W"), transform_lambda_w); + + return new timemachine::NonbondedInteractionGroup( + unique_row_atom_idxs, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff, source_str); + }), + py::arg("row_atom_idxs_i"), + 
py::arg("lambda_plane_idxs_i"), + py::arg("lambda_offset_idxs_i"), + py::arg("beta"), + py::arg("cutoff"), + py::arg("transform_lambda_charge") = "lambda", + py::arg("transform_lambda_sigma") = "lambda", + py::arg("transform_lambda_epsilon") = "lambda", + py::arg("transform_lambda_w") = "lambda"); +} + void declare_barostat(py::module &m) { using Class = timemachine::MonteCarloBarostat; @@ -812,5 +879,11 @@ PYBIND11_MODULE(custom_ops, m) { declare_nonbonded(m, "f64"); declare_nonbonded(m, "f32"); + declare_nonbonded_interaction_group(m, "f64_interpolated"); + declare_nonbonded_interaction_group(m, "f32_interpolated"); + + declare_nonbonded_interaction_group(m, "f64"); + declare_nonbonded_interaction_group(m, "f32"); + declare_context(m); } diff --git a/timemachine/lib/potentials.py b/timemachine/lib/potentials.py index 7d64b8ba1..26346a5d8 100644 --- a/timemachine/lib/potentials.py +++ b/timemachine/lib/potentials.py @@ -278,3 +278,20 @@ def unbound_impl(self, precision): custom_ctor = getattr(custom_ops, cls_name_base) return custom_ctor(*self.args) + + +class NonbondedInteractionGroup(CustomOpWrapper): + pass + + +class NonbondedInteractionGroupInterpolated(NonbondedInteractionGroup): + def unbound_impl(self, precision): + cls_name_base = "NonbondedInteractionGroup" + if precision == np.float64: + cls_name_base += "_f64_interpolated" + else: + cls_name_base += "_f32_interpolated" + + custom_ctor = getattr(custom_ops, cls_name_base) + + return custom_ctor(*self.args) diff --git a/timemachine/potentials/nonbonded.py b/timemachine/potentials/nonbonded.py index 0cf224eb5..14de0136d 100644 --- a/timemachine/potentials/nonbonded.py +++ b/timemachine/potentials/nonbonded.py @@ -1,3 +1,5 @@ +import functools + import jax.numpy as np from jax.ops import index, index_update from jax.scipy.special import erfc @@ -242,7 +244,7 @@ def apply_cutoff(x): # vdW by Lennard-Jones sig_ij = apply_cutoff(sig[inds_l] + sig[inds_r]) eps_ij = apply_cutoff(eps[inds_l] * eps[inds_r]) - vdW 
= lennard_jones(dij, sig_ij, eps_ij) + vdW = np.where(eps_ij != 0, lennard_jones(dij, sig_ij, eps_ij), 0) # Electrostatics by direct-space part of PME qij = apply_cutoff(charges[inds_l] * charges[inds_r]) @@ -251,6 +253,33 @@ def apply_cutoff(x): return vdW, electrostatics +def nonbonded_v3_interaction_groups(conf, params, box, inds_l, inds_r, beta: float, cutoff: Optional[float] = None): + """Nonbonded interactions between all pairs of atoms $(i, j)$ + where $i$ is in the first set and $j$ in the second. + + See nonbonded_v3 docstring for more details + """ + pairs = np.stack(np.meshgrid(inds_l, inds_r)).reshape(2, -1).T + vdW, electrostatics = nonbonded_v3_on_specific_pairs(conf, params, box, pairs[:, 0], pairs[:, 1], beta, cutoff) + return vdW, electrostatics, pairs + + +def interpolated(u_fn): + @functools.wraps(u_fn) + def wrapper(conf, params, box, lamb): + + # params is expected to be the concatenation of initial + # (lambda = 0) and final (lambda = 1) parameters, each of + # length num_atoms + assert params.size % 2 == 0 + num_atoms = params.shape[0] // 2 + + new_params = (1 - lamb) * params[:num_atoms] + lamb * params[num_atoms:] + return u_fn(conf, new_params, box, lamb) + + return wrapper + + def validate_coulomb_cutoff(cutoff=1.0, beta=2.0, threshold=1e-2): """check whether f(r) = erfc(beta * r) <= threshold at r = cutoff following https://github.com/proteneer/timemachine/pull/424#discussion_r629678467"""