Skip to content

Commit

Permalink
Add jitify code
Browse files Browse the repository at this point in the history
Clean-up

WIP

Fix du_dps

Clean-up

Improve default kwargs

Allow lambda for w to be interpolated

Add test for lambda_w interpolation

AVoid hard coded source paths

WIP

Update paths

WIP

WIP

Update src path

WIP

Trigger tests
  • Loading branch information
proteneer committed May 3, 2021
1 parent 905736c commit e228bc3
Show file tree
Hide file tree
Showing 11 changed files with 4,836 additions and 162 deletions.
103 changes: 102 additions & 1 deletion tests/test_parameter_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from jax.config import config; config.update("jax_enable_x64", True)

import copy
import functools
import numpy as np
import jax.numpy as jnp

from common import GradientTest
from timemachine.lib import potentials
Expand Down Expand Up @@ -83,4 +84,104 @@ def test_nonbonded(self):
test_interpolated_potential,
rtol,
precision=precision
)


def test_nonbonded_advanced(self):

# This test checks that we can supply arbitrary transformations of lambda to
# the nonbonded potential, and that the resulting derivatives (both du/dp and du/dl)
# are correct.

np.random.seed(4321)
D = 3

cutoff = 1.0
size = 36

water_coords = self.get_water_coords(D, sort=False)
coords = water_coords[:size]
padding = 0.2
diag = np.amax(coords, axis=0) - np.amin(coords, axis=0) + padding
box = np.eye(3)
np.fill_diagonal(box, diag)

N = coords.shape[0]

lambda_plane_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)
lambda_offset_idxs = np.random.randint(low=0, high=2, size=N, dtype=np.int32)

for precision, rtol in [(np.float64, 1e-8), (np.float32, 1e-4)]:

# E = 0 # DEBUG!
qlj_src, ref_potential, test_potential = prepare_water_system(
coords,
lambda_plane_idxs,
lambda_offset_idxs,
p_scale=1.0,
cutoff=cutoff
)

qlj_dst, _, _ = prepare_water_system(
coords,
lambda_plane_idxs,
lambda_offset_idxs,
p_scale=1.0,
cutoff=cutoff
)

def transform_q(lamb):
return lamb*lamb

def transform_s(lamb):
return jnp.sin(lamb*np.pi/2)

def transform_e(lamb):
return jnp.cos(lamb*np.pi/2)

def transform_w(lamb):
return (1-lamb*lamb)

def interpolate_params(lamb, qlj_src, qlj_dst):
new_q = (1-transform_q(lamb))*qlj_src[:, 0] + transform_q(lamb)*qlj_dst[:, 0]
new_s = (1-transform_s(lamb))*qlj_src[:, 1] + transform_s(lamb)*qlj_dst[:, 1]
new_e = (1-transform_e(lamb))*qlj_src[:, 2] + transform_e(lamb)*qlj_dst[:, 2]
return jnp.stack([new_q, new_s, new_e], axis=1)

def u_reference(x, params, box, lamb):
d4 = cutoff*(lambda_plane_idxs + lambda_offset_idxs*transform_w(lamb))
d4 = jnp.expand_dims(d4, axis=-1)
x = jnp.concatenate((x, d4), axis=1)

qlj_src = params[:len(params)//2]
qlj_dst = params[len(params)//2:]
qlj = interpolate_params(lamb, qlj_src, qlj_dst)
return ref_potential(x, qlj, box, lamb)


for lamb in [0.0, 0.2, 1.0]:

qlj = np.concatenate([qlj_src, qlj_dst])

print("lambda", lamb, "cutoff", cutoff, "precision", precision, "xshape", coords.shape)

args = copy.deepcopy(test_potential.args)
args.append("lambda*lambda") # transform q
args.append("sin(lambda*PI/2)") # transform sigma
args.append("cos(lambda*PI/2)") # transform epsilon
args.append("1-lambda*lambda") # transform w

test_interpolated_potential = potentials.NonbondedInterpolated(
*args,
)

self.compare_forces(
coords,
qlj,
box,
lamb,
u_reference,
test_interpolated_potential,
rtol,
precision=precision
)
2 changes: 1 addition & 1 deletion timemachine/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/eigen)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/${CUB_SRC_DIR})

set_property(TARGET ${LIBRARY_NAME} PROPERTY CUDA_STANDARD 14)
target_link_libraries(${LIBRARY_NAME} PRIVATE -lcurand -lcudart -lcudadevrt)
target_link_libraries(${LIBRARY_NAME} PRIVATE -lcurand -lcuda -lcudart -lcudadevrt -lnvrtc)
set_target_properties(${LIBRARY_NAME} PROPERTIES PREFIX "")

install(TARGETS ${LIBRARY_NAME} DESTINATION "lib")
12 changes: 12 additions & 0 deletions timemachine/cpp/src/gpu_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ curandStatus_t templateCurandNormal(





#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
Expand All @@ -36,6 +38,16 @@ inline void curandAssert(curandStatus_t code, const char *file, int line, bool a
}
}

#define NVRTC_SAFE_CALL(x) \
do { \
nvrtcResult result = x; \
if (result != NVRTC_SUCCESS) { \
std::cerr << "\nerror: " #x " failed with error " \
<< nvrtcGetErrorString(result) << '\n'; \
exit(1); \
} \
} while(0)

// safe is for use of gpuErrchk
template<typename T>
T* gpuErrchkCudaMallocAndCopy(const T *host_array, int count) {
Expand Down
Loading

0 comments on commit e228bc3

Please sign in to comment.