Skip to content

Commit

Permalink
Refactor (inds_i, inds_j) --> pairs, improve docs
Browse files Browse the repository at this point in the history
Addressing #578 (comment)
  • Loading branch information
maxentile committed Feb 18, 2022
1 parent 8f225c2 commit 213752b
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 63 deletions.
18 changes: 10 additions & 8 deletions tests/test_jax_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,14 @@ def _nonbonded_v3_clone(
box_4d = None
box = box_4d

# TODO: len(inds_i) == n_interactions -- may want to break this
# TODO: len(pairs) == n_interactions -- may want to break this
# up into more manageable blocks if n_interactions is large
inds_i, inds_j = get_all_pairs_indices(N)
pairs = get_all_pairs_indices(N)

lj, coulomb = nonbonded_v3_on_specific_pairs(conf, params, box, inds_i, inds_j, beta, cutoff)
lj, coulomb = nonbonded_v3_on_specific_pairs(conf, params, box, pairs, beta, cutoff)

# keep only eps > 0
inds_i, inds_j = pairs.T
eps = params[:, 2]
lj = np.where(eps[inds_i] > 0, lj, 0)
lj = np.where(eps[inds_j] > 0, lj, 0)
Expand Down Expand Up @@ -304,11 +305,11 @@ def test_vmap():
n_total = n_ligand + n_environment
conf, params, box, lamb, _, _, beta, cutoff, _, _ = generate_random_inputs(n_total, 3)

inds_i, inds_j = get_group_group_indices(n_ligand, n_environment)
inds_j += n_ligand
n_interactions = len(inds_i)
pairs = get_group_group_indices(n_ligand, n_environment)
pairs[:, 1] += n_ligand
n_interactions = len(pairs)

fixed_kwargs = dict(params=params, box=box, inds_l=inds_i, inds_r=inds_j, beta=beta, cutoff=cutoff)
fixed_kwargs = dict(params=params, box=box, pairs=pairs, beta=beta, cutoff=cutoff)

# signature: conf -> ljs, coulombs, where ljs.shape == (n_interactions, )
u_pairs = partial(nonbonded_v3_on_specific_pairs, **fixed_kwargs)
Expand Down Expand Up @@ -352,9 +353,10 @@ def u_a(x, box, params):
i_s, j_s = np.indices((split, N - split))
indices_left = i_s.flatten()
indices_right = j_s.flatten() + split
pairs = np.array([indices_left, indices_right]).T

def u_b(x, box, params):
vdw, es = nonbonded_v3_on_specific_pairs(x, params, box, indices_left, indices_right, beta, cutoff)
vdw, es = nonbonded_v3_on_specific_pairs(x, params, box, pairs, beta, cutoff)

return np.sum(vdw + es)

Expand Down
32 changes: 16 additions & 16 deletions tests/test_jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def test_get_all_pairs_indices():
"""check i < j < n"""
ns = onp.random.randint(5, 50, 10)
for n in ns:
inds_i, inds_j = get_all_pairs_indices(n)
assert (inds_i < inds_j).all()
assert (inds_j < n).all()
pairs = get_all_pairs_indices(n)
assert (pairs[:, 0] < pairs[:, 1]).all()
assert (pairs < n).all()


def test_get_group_group_indices():
Expand All @@ -70,11 +70,11 @@ def test_get_group_group_indices():
ms = onp.random.randint(5, 50, 10)

for n, m in zip(ns, ms):
inds_i, inds_j = get_group_group_indices(n, m)
assert (inds_i < n).all()
assert (inds_j < m).all()
pairs = get_group_group_indices(n, m)
assert (pairs[:, 0] < n).all()
assert (pairs[:, 1] < m).all()

assert len(inds_i) == n * m
assert len(pairs) == n * m


def test_compute_lifting_parameter():
Expand Down Expand Up @@ -118,22 +118,22 @@ def test_batched_neighbor_inds():
boxes = np.array([np.eye(3)] * n_confs)

n_alchemical = 50
inds_l, inds_r = get_group_group_indices(n=n_alchemical, m=n_particles - n_alchemical)
inds_r += n_alchemical
n_possible_interactions = len(inds_l)
inds_l, inds_r = get_group_group_indices(n=n_alchemical, m=n_particles - n_alchemical).T
pairs = np.array([inds_l, inds_r + n_alchemical]).T
n_possible_interactions = len(pairs)

full_distances = vmap(distance_on_pairs)(confs[:, inds_l], confs[:, inds_r], boxes)
assert full_distances.shape == (n_confs, n_possible_interactions)

neighbor_inds_l, neighbor_inds_r = batched_neighbor_inds(confs, inds_l, inds_r, cutoff, boxes)
n_neighbor_pairs = neighbor_inds_l.shape[1]
assert neighbor_inds_r.shape == (n_confs, n_neighbor_pairs)
batch_pairs = batched_neighbor_inds(confs, pairs, cutoff, boxes)
n_neighbor_pairs = batch_pairs.shape[1]
assert batch_pairs.shape == (n_confs, n_neighbor_pairs, 2)
assert n_neighbor_pairs <= n_possible_interactions

def d(conf, inds_l, inds_r, box):
return distance_on_pairs(conf[inds_l], conf[inds_r], box)
def d(conf, pairs, box):
return distance_on_pairs(conf[pairs[:, 0]], conf[pairs[:, 1]], box)

neighbor_distances = vmap(d)(confs, neighbor_inds_l, neighbor_inds_r, boxes)
neighbor_distances = vmap(d)(confs, batch_pairs, boxes)

assert neighbor_distances.shape == (n_confs, n_neighbor_pairs)
assert np.sum(neighbor_distances < cutoff) == np.sum(full_distances < cutoff)
79 changes: 43 additions & 36 deletions timemachine/potentials/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,22 @@ def get_all_pairs_indices(n: int) -> Tuple[Array, Array]:
"""all indices i, j such that i < j < n"""
n_interactions = n * (n - 1) / 2

inds_i, inds_j = np.triu_indices(n, k=1)
pairs = np.triu_indices(n, k=1)

assert len(inds_i) == n_interactions
assert pairs.shape == (n_interactions, 2)

return inds_i, inds_j
return pairs


def get_group_group_indices(n: int, m: int) -> Tuple[Array, Array]:
"""all indices i, j such that i < n, j < m"""
n_interactions = n * m

_inds_i, _inds_j = np.indices((n, m))
inds_i, inds_j = _inds_i.flatten(), _inds_j.flatten()
pairs = np.indices((n, m)).T

assert len(inds_i) == n_interactions
assert pairs.shape == (n_interactions, 2)

return inds_i, inds_j
return pairs


def compute_lifting_parameter(lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff):
Expand Down Expand Up @@ -109,20 +108,30 @@ def distance_on_pairs(ri, rj, box=None):
return dij


def batched_neighbor_inds(confs, inds_l, inds_r, cutoff, boxes):
"""Given candidate interacting pairs (inds_l, inds_r),
inds_l.shape == n_interactions
exclude most pairs whose distances are >= cutoff (neighbor_inds_l, neighbor_inds_r)
neighbor_inds_l.shape == (len(confs), max_n_neighbors)
where the total number of neighbors returned for each conf in confs is the same
max_n_neighbors
def batched_neighbor_inds(confs, pairs, cutoff, boxes):
"""Given candidate interacting pairs, exclude most pairs whose distances are >= cutoff
This padding causes some amount of wasted effort, but keeps things nice and fixed-dimensional
for later XLA steps
Parameters
----------
confs: (n_snapshots, n_atoms, 3) float array
pairs: (n_candidate_pairs, 2) integer array
cutoff: float
boxes: (n_snapshots, 3, 3) float array
Returns
-------
batch_pairs : (len(confs), max_n_neighbors, 2) array
where max_n_neighbors pairs are returned for each conf in confs
Notes
-----
* Padding causes some amount of wasted effort, but keeps things nice and fixed-dimensional for later XLA steps
* TODO [naming]: rename to get_interacting_pair_indices_batch or similar
* TODO [usability]: reorder input arguments in less surprising way
"""
assert len(confs.shape) == 3
distances = vmap(distance_on_pairs)(confs[:, inds_l], confs[:, inds_r], boxes)
assert distances.shape == (len(confs), len(inds_l))
distances = vmap(distance_on_pairs)(confs[:, pairs[:, 0]], confs[:, pairs[:, 1]], boxes)
assert distances.shape == (len(confs), len(pairs))

neighbor_masks = distances < cutoff
# how many total neighbors?
Expand All @@ -134,13 +143,11 @@ def batched_neighbor_inds(confs, inds_l, inds_r, cutoff, boxes):

# sorting in order of [falses, ..., trues]
keep_inds = np.argsort(neighbor_masks, axis=1)[:, -max_n_neighbors:]
neighbor_inds_l = inds_l[keep_inds]
neighbor_inds_r = inds_r[keep_inds]
batch_pairs = pairs[keep_inds]

assert neighbor_inds_l.shape == (len(confs), max_n_neighbors)
assert neighbor_inds_l.shape == neighbor_inds_r.shape
assert batch_pairs.shape == (len(confs), max_n_neighbors, 2)

return neighbor_inds_l, neighbor_inds_r
return batch_pairs


def get_ligand_dependent_indices_batch(confs, boxes, ligand_indices, cutoff=1.2):
Expand All @@ -159,36 +166,36 @@ def get_ligand_dependent_indices_batch(confs, boxes, ligand_indices, cutoff=1.2)
Returns
-------
(batch_inds_l, batch_inds_r)
each of shape (len(confs), n_pairs),
where n_pairs = maximum number of interacting pairs in confs
* batch_pairs: (n_snapshots, n_pairs, 2) int array
where n_pairs = maximum number of interacting pairs in any conf
Notes
-----
* Index arrays are padded so each conf has the same number of interacting pairs -- a small fraction of the returned
pairs ij will have distance(i, j) > cutoff, so these may need to be filtered / masked again at later steps
* TODO [naming]: change to return a single [n_pairs, 2] array instead of pair of [n_pairs,] arrays?
* TODO [flexibility]: accept environment_indices instead of inferring them?
"""
n_snapshots, n_atoms, _ = confs.shape
environment_indices = np.array(list(set(onp.arange(n_atoms)) - set(onp.array(ligand_indices))))

# (ligand, environment) pairs within distance cutoff
_inds_l, _inds_r = get_group_group_indices(len(ligand_indices), len(environment_indices))
inds_l, inds_r = ligand_indices[_inds_l], environment_indices[_inds_r]
neighbor_inds_l, neighbor_inds_r = batched_neighbor_inds(confs, inds_l, inds_r, cutoff, boxes)
pairs = get_group_group_indices(len(ligand_indices), len(environment_indices))
batch_ligand_environment_pairs = batched_neighbor_inds(confs, pairs, cutoff, boxes)
n_ligand_environment_pairs = batch_ligand_environment_pairs.shape[1]

# (ligand, ligand) pairs
_l, _r = get_all_pairs_indices(len(ligand_indices))
ligand_inds_l, ligand_inds_r = ligand_indices[_l], ligand_indices[_r]
_pairs = get_all_pairs_indices(len(ligand_indices))
n_ligand_ligand_pairs = len(_pairs)
ligand_ligand_pairs = ligand_indices[_pairs]
batch_ligand_ligand_pairs = np.repeat(ligand_ligand_pairs[np.newaxis, :], n_snapshots, 0)

# concatenate
batch_inds_l = np.hstack([neighbor_inds_l, np.repeat(ligand_inds_l[np.newaxis, :], n_snapshots, 0)])
batch_inds_r = np.hstack([neighbor_inds_r, np.repeat(ligand_inds_r[np.newaxis, :], n_snapshots, 0)])
n_pairs = n_ligand_environment_pairs + n_ligand_ligand_pairs
batch_pairs = np.hstack([batch_ligand_environment_pairs, batch_ligand_ligand_pairs])

assert batch_inds_l.shape == batch_inds_r.shape
assert batch_pairs.shape == (n_snapshots, n_pairs, 2)

return batch_inds_l, batch_inds_r
return batch_pairs


def distance(x, box):
Expand Down
7 changes: 4 additions & 3 deletions timemachine/potentials/nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,16 @@ def nonbonded_v3(
return np.sum(eij_total / 2)


def nonbonded_v3_on_specific_pairs(conf, params, box, inds_l, inds_r, beta: float, cutoff: Optional[float] = None):
def nonbonded_v3_on_specific_pairs(conf, params, box, pairs, beta: float, cutoff: Optional[float] = None):
"""See nonbonded_v3 docstring for more details
Notes
-----
* Responsibility of caller to ensure pair indices (inds_l, inds_r) are complete.
In case of parameter interpolation, more pairs need to be added.
* Responsibility of caller to ensure pair indices are complete.
"""

inds_l, inds_r = pairs

# distances and cutoff
dij = distance_on_pairs(conf[inds_l], conf[inds_r], box)
if cutoff is None:
Expand Down

0 comments on commit 213752b

Please sign in to comment.