Skip to content

Fully variable nonbonded lambda interpolation #403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 102 additions & 1 deletion tests/test_parameter_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from jax.config import config; config.update("jax_enable_x64", True)

import copy
import functools
import numpy as np
import jax.numpy as jnp

from common import GradientTest
from timemachine.lib import potentials
Expand Down Expand Up @@ -83,4 +84,104 @@ def test_nonbonded(self):
test_interpolated_potential,
rtol,
precision=precision
)


def test_nonbonded_advanced(self):

# This test checks that we can supply arbitrary transformations of lambda to
# the nonbonded potential, and that the resulting derivatives (both du/dp and du/dl)
# are correct.

np.random.seed(4321)
D = 3

cutoff = 1.0
size = 36

water_coords = self.get_water_coords(D, sort=False)
coords = water_coords[:size]
padding = 0.2
diag = np.amax(coords, axis=0) - np.amin(coords, axis=0) + padding
box = np.eye(3)
np.fill_diagonal(box, diag)

N = coords.shape[0]

lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)
lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)


def transform_q(lamb):
return lamb*lamb

def transform_s(lamb):
return jnp.sin(lamb*np.pi/2)

def transform_e(lamb):
return jnp.cos(lamb*np.pi/2)

def transform_w(lamb):
return (1-lamb*lamb)

# E = 0 # DEBUG!
qlj_src, ref_potential, test_potential = prepare_water_system(
coords,
lambda_plane_idxs,
lambda_offset_idxs,
p_scale=1.0,
cutoff=cutoff
)

qlj_dst, _, _ = prepare_water_system(
coords,
lambda_plane_idxs,
lambda_offset_idxs,
p_scale=1.0,
cutoff=cutoff
)

def interpolate_params(lamb, qlj_src, qlj_dst):
new_q = (1-transform_q(lamb))*qlj_src[:, 0] + transform_q(lamb)*qlj_dst[:, 0]
new_s = (1-transform_s(lamb))*qlj_src[:, 1] + transform_s(lamb)*qlj_dst[:, 1]
new_e = (1-transform_e(lamb))*qlj_src[:, 2] + transform_e(lamb)*qlj_dst[:, 2]
return jnp.stack([new_q, new_s, new_e], axis=1)

def u_reference(x, params, box, lamb):
d4 = cutoff*(lambda_plane_idxs + lambda_offset_idxs*transform_w(lamb))
d4 = jnp.expand_dims(d4, axis=-1)
x = jnp.concatenate((x, d4), axis=1)

qlj_src = params[:len(params)//2]
qlj_dst = params[len(params)//2:]
qlj = interpolate_params(lamb, qlj_src, qlj_dst)
return ref_potential(x, qlj, box, lamb)

for precision, rtol in [(np.float64, 1e-8), (np.float32, 1e-4)]:

for lamb in [0.0, 0.2, 1.0]:
Copy link
Owner Author

@proteneer proteneer Apr 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maxentile usage example and test here


qlj = np.concatenate([qlj_src, qlj_dst])

print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape)

args = copy.deepcopy(test_potential.args)
args.append("lambda*lambda") # transform q
args.append("sin(lambda*PI/2)") # transform sigma
args.append("cos(lambda*PI/2)") # transform epsilon
args.append("1-lambda*lambda") # transform w
Comment on lines +169 to +172
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This seems highly flexible, and already covers the primary use case I can imagine. (After optimizing protocols represented parametrically as protocol(lam, weights) = dot(weights, basis_expand(lam)), we can export to a string where the optimized weights appear as literals.

Where should we document the allowable syntax for these expressions? A few specific questions:

  • Are there any other global constants that can be referenced in these expressions aside from PI?
  • Is it possible to define and reuse variables here? (Something like 0.1 * x + 0.2 * x*x + 0.3 * x*x*x ; x=(1 - lambda*lambda);)
  • Is it possible to reference per-atom attributes in these expressions?
  • Is there a practical limit on the length of these expressions?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any other global constants that can be referenced in these expressions aside from PI?

Right now no, but you can always just hard code in the constants. It woudln't hard to add in other constants though.

Is it possible to define and reuse variables here? (Something like 0.1 * x + 0.2 * xx + 0.3 * xxx ; x=(1 - lambdalambda);)

Yes absolutely, but right now the C++ code is written as return CUSTOM_EXPRESSION; to support proper branching (eg. if statements beyond simple ternary operators). Currently it's similar to lambda expressions, but we can definitely relax this without difficulty.

Is it possible to reference per-atom attributes in these expressions?

Currently no, but if these are forcefield independent attributes, we may be able to support them without too much difficulty. We probably won't be able to do derivatives for the per-atom attributes though, since the forward-mode AD/CSD is only efficient for R^1->R^N.

Is there a practical limit on the length of these expressions?

Anything that doesn't break the compiler :)


test_interpolated_potential = potentials.NonbondedInterpolated(
*args,
)

self.compare_forces(
coords,
qlj,
box,
lamb,
u_reference,
test_interpolated_potential,
rtol,
precision=precision
)
2 changes: 1 addition & 1 deletion timemachine/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/eigen)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/${CUB_SRC_DIR})

set_property(TARGET ${LIBRARY_NAME} PROPERTY CUDA_STANDARD 14)
target_link_libraries(${LIBRARY_NAME} PRIVATE -lcurand -lcudart -lcudadevrt)
target_link_libraries(${LIBRARY_NAME} PRIVATE -lcurand -lcuda -lcudart -lcudadevrt -lnvrtc)
set_target_properties(${LIBRARY_NAME} PROPERTIES PREFIX "")

install(TARGETS ${LIBRARY_NAME} DESTINATION "lib")
2 changes: 2 additions & 0 deletions timemachine/cpp/src/gpu_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ curandStatus_t templateCurandNormal(





#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
Expand Down
Loading