-
Notifications
You must be signed in to change notification settings - Fork 18
Fully variable nonbonded lambda interpolation #403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
from jax.config import config; config.update("jax_enable_x64", True) | ||
|
||
import copy | ||
import functools | ||
import numpy as np | ||
import jax.numpy as jnp | ||
|
||
from common import GradientTest | ||
from timemachine.lib import potentials | ||
|
@@ -83,4 +84,104 @@ def test_nonbonded(self): | |
test_interpolated_potential, | ||
rtol, | ||
precision=precision | ||
) | ||
|
||
|
||
def test_nonbonded_advanced(self): | ||
|
||
# This test checks that we can supply arbitrary transformations of lambda to | ||
# the nonbonded potential, and that the resulting derivatives (both du/dp and du/dl) | ||
# are correct. | ||
|
||
np.random.seed(4321) | ||
D = 3 | ||
|
||
cutoff = 1.0 | ||
size = 36 | ||
|
||
water_coords = self.get_water_coords(D, sort=False) | ||
coords = water_coords[:size] | ||
padding = 0.2 | ||
diag = np.amax(coords, axis=0) - np.amin(coords, axis=0) + padding | ||
box = np.eye(3) | ||
np.fill_diagonal(box, diag) | ||
|
||
N = coords.shape[0] | ||
|
||
lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32) | ||
lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32) | ||
|
||
|
||
def transform_q(lamb): | ||
return lamb*lamb | ||
|
||
def transform_s(lamb): | ||
return jnp.sin(lamb*np.pi/2) | ||
|
||
def transform_e(lamb): | ||
return jnp.cos(lamb*np.pi/2) | ||
|
||
def transform_w(lamb): | ||
return (1-lamb*lamb) | ||
|
||
# E = 0 # DEBUG! | ||
qlj_src, ref_potential, test_potential = prepare_water_system( | ||
coords, | ||
lambda_plane_idxs, | ||
lambda_offset_idxs, | ||
p_scale=1.0, | ||
cutoff=cutoff | ||
) | ||
|
||
qlj_dst, _, _ = prepare_water_system( | ||
coords, | ||
lambda_plane_idxs, | ||
lambda_offset_idxs, | ||
p_scale=1.0, | ||
cutoff=cutoff | ||
) | ||
|
||
def interpolate_params(lamb, qlj_src, qlj_dst): | ||
new_q = (1-transform_q(lamb))*qlj_src[:, 0] + transform_q(lamb)*qlj_dst[:, 0] | ||
new_s = (1-transform_s(lamb))*qlj_src[:, 1] + transform_s(lamb)*qlj_dst[:, 1] | ||
new_e = (1-transform_e(lamb))*qlj_src[:, 2] + transform_e(lamb)*qlj_dst[:, 2] | ||
return jnp.stack([new_q, new_s, new_e], axis=1) | ||
|
||
def u_reference(x, params, box, lamb): | ||
d4 = cutoff*(lambda_plane_idxs + lambda_offset_idxs*transform_w(lamb)) | ||
d4 = jnp.expand_dims(d4, axis=-1) | ||
x = jnp.concatenate((x, d4), axis=1) | ||
|
||
qlj_src = params[:len(params)//2] | ||
qlj_dst = params[len(params)//2:] | ||
qlj = interpolate_params(lamb, qlj_src, qlj_dst) | ||
return ref_potential(x, qlj, box, lamb) | ||
|
||
for precision, rtol in [(np.float64, 1e-8), (np.float32, 1e-4)]: | ||
|
||
for lamb in [0.0, 0.2, 1.0]: | ||
|
||
qlj = np.concatenate([qlj_src, qlj_dst]) | ||
|
||
print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape) | ||
|
||
args = copy.deepcopy(test_potential.args) | ||
args.append("lambda*lambda") # transform q | ||
args.append("sin(lambda*PI/2)") # transform sigma | ||
args.append("cos(lambda*PI/2)") # transform epsilon | ||
args.append("1-lambda*lambda") # transform w | ||
Comment on lines
+169
to
+172
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! This seems highly flexible, and already covers the primary use case I can imagine. (After optimizing protocols represented parametrically as Where should we document the allowable syntax for these expressions? A few specific questions:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Right now no, but you can always just hard code in the constants. It woudln't hard to add in other constants though.
Yes absolutely, but right now the C++ code is written as
Currently no, but if these are forcefield independent attributes, we may be able to support them without too much difficulty. We probably won't be able to do derivatives for the per-atom attributes though, since the forward-mode AD/CSD is only efficient for
Anything that doesn't break the compiler :) |
||
|
||
test_interpolated_potential = potentials.NonbondedInterpolated( | ||
*args, | ||
) | ||
|
||
self.compare_forces( | ||
coords, | ||
qlj, | ||
box, | ||
lamb, | ||
u_reference, | ||
test_interpolated_potential, | ||
rtol, | ||
precision=precision | ||
) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@maxentile usage example and test here