Skip to content

Commit

Permalink
Update nightly test for (x,) -> (x,v) return type (#736)
Browse files Browse the repository at this point in the history
* Address #734 (comment)
* Add -gpu tag to lint stage
  • Loading branch information
maxentile authored May 5, 2022
1 parent a929e24 commit fd14908
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ lint:
needs: ["docker_build"]
tags:
- timemachine
- gpu
rules:
- if: $CI_EXTERNAL_PULL_REQUEST_IID
- if: $NIGHTLY_TESTS
Expand Down
10 changes: 6 additions & 4 deletions tests/test_vacuum_importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def test_vacuum_importance_sampling():
# (ytz): hacky as hell, needs to be divisible by # of hyperthreaded cores
num_samples = 120000

weighted_samples, log_weights = enhanced.generate_log_weighted_samples(
weighted_xv_samples, log_weights = enhanced.generate_log_weighted_samples(
mol, temperature, state.U_easy, state.U_decharged, seed, num_batches=num_samples
)

enhanced_samples = enhanced.sample_from_log_weights(weighted_samples, log_weights, 100000)
enhanced_xv_samples = enhanced.sample_from_log_weights(weighted_xv_samples, log_weights, 100000)
enhanced_samples = np.array([x for (x, v) in enhanced_xv_samples])

@jax.jit
def get_torsion(x_l):
Expand All @@ -76,7 +77,7 @@ def get_torsion(x_l):
# check for symmetry about theta=0
assert np.mean((enhanced_torsions_lhs - enhanced_torsions_rhs[::-1]) ** 2) < 5e-2

weighted_samples, log_weights = enhanced.generate_log_weighted_samples(
weighted_xv_samples, log_weights = enhanced.generate_log_weighted_samples(
mol,
temperature,
state.U_decharged,
Expand All @@ -85,7 +86,8 @@ def get_torsion(x_l):
seed=seed,
)

vanilla_samples = enhanced.sample_from_log_weights(weighted_samples, log_weights, 100000)
vanilla_xv_samples = enhanced.sample_from_log_weights(weighted_xv_samples, log_weights, 100000)
vanilla_samples = np.array([x for (x, v) in vanilla_xv_samples])

vanilla_torsions = batch_torsion_fn(vanilla_samples).reshape(-1)
vanilla_samples_rhs, _ = np.histogram(vanilla_torsions, bins=50, range=(0, np.pi), density=True)
Expand Down

0 comments on commit fd14908

Please sign in to comment.