Hamiltonian replica exchange #1128

Merged
merged 82 commits into from
Sep 6, 2023
Changes from 63 commits
b55826a
Fix mypy warning
mcwitt Aug 23, 2023
9e30ad8
Add type annotation
mcwitt Aug 11, 2023
8b1cb63
Allow n_eq_steps=0
mcwitt Aug 16, 2023
d8bb58d
Add hrex generic implementation, rbfe specialization
mcwitt Aug 23, 2023
cc44b70
Add plots
mcwitt Aug 24, 2023
2252981
Add unit tests
mcwitt Aug 10, 2023
caaada8
Add basic integration test
mcwitt Aug 8, 2023
d83acb2
Use hrex rbfe in run_leg functions
mcwitt Aug 18, 2023
d38b083
Revert "Use hrex rbfe in run_leg functions"
mcwitt Aug 28, 2023
59b35c3
Cleanup existing docstrings
mcwitt Aug 28, 2023
799865a
Add docstrings
mcwitt Aug 28, 2023
39f7b1e
Move documentation
mcwitt Aug 28, 2023
493fe97
Consistently document references
mcwitt Aug 28, 2023
18338d1
Transpose plot for consistency
mcwitt Aug 28, 2023
fbe76f6
Clean for readability
mcwitt Aug 28, 2023
f9058f6
Fix plot legend
mcwitt Aug 28, 2023
543cac9
Rename: Choice -> MixtureOfMoves, Sequential -> SequenceOfMoves
mcwitt Aug 28, 2023
06db523
Remove mixture with identity move
mcwitt Aug 28, 2023
b3f1b3f
Fix: remove mixture with identity move
mcwitt Aug 28, 2023
61f68a6
Add (0, 0) to neighbor pairs to ensure aperiodicity
mcwitt Aug 28, 2023
bb49dd0
Tweak test thresholds, clean
mcwitt Aug 28, 2023
2e74b00
Fix: remove acceptance rate stats for (0, 0) pair
mcwitt Aug 28, 2023
eb7b6d0
Improve axis labels, naming consistency
mcwitt Aug 28, 2023
20cd12e
Fix aspect ratio, tweak subplot spacing
mcwitt Aug 28, 2023
341c286
Add transition matrix to diagnostics, plot
mcwitt Aug 29, 2023
e433699
Add relaxation time estimate to diagnostics
mcwitt Aug 29, 2023
0418a7c
Set square aspect ratio
mcwitt Aug 29, 2023
951cf26
Fix: set DEBUG=False
mcwitt Aug 29, 2023
61d1a4c
Annotate replica-state distribution heatmap
mcwitt Aug 29, 2023
770647e
Move nontrivial computations to module functions, add docstrings
mcwitt Aug 29, 2023
b4bc816
Split up large compound expression
mcwitt Aug 29, 2023
a76dc49
Remove redundant plot
mcwitt Aug 29, 2023
372fbb9
Use eigvals instead of eigvalsh
mcwitt Aug 29, 2023
be0ef78
Add one-liner docstrings to plotting functions
mcwitt Aug 29, 2023
142ecc2
Remove redundant np.array calls
mcwitt Aug 30, 2023
8a7fe73
Update timemachine/fe/free_energy.py
mcwitt Aug 30, 2023
aa3d9da
Update timemachine/md/hrex.py
mcwitt Aug 30, 2023
3b831d2
Standardize description of md_params parameter in docstrings
mcwitt Aug 30, 2023
df187e3
Add arxiv url
mcwitt Aug 30, 2023
ccbcb61
Change test from xfail to skip
mcwitt Aug 30, 2023
001202d
Avoid holding potentials for all states in memory simultaneously
mcwitt Aug 30, 2023
6ff38e4
Avoid reconstructing potentials
mcwitt Aug 30, 2023
c66710c
Add debug flag, transition matrix plot
mcwitt Sep 1, 2023
c6cda93
Precompute log_q values instead of U
mcwitt Sep 1, 2023
0282e13
Simplify
mcwitt Sep 1, 2023
9639782
Fix: computation of log_q using stale replicas
mcwitt Sep 1, 2023
5998f91
Fix heatmap annotations
mcwitt Sep 1, 2023
70d41f6
Strengthen test
mcwitt Sep 2, 2023
bd65bb4
Use raw strings for docstrings with math
mcwitt Sep 5, 2023
9753e8a
Fixes for plot readability with >10 states
mcwitt Sep 5, 2023
eae1151
Remove unnecessary iter conversion
mcwitt Sep 5, 2023
123795f
Misc fixes to plots
mcwitt Sep 5, 2023
2878505
Rename Hrex -> HREX
mcwitt Sep 5, 2023
22c1d8d
Remove leading underscore from type variables
mcwitt Sep 5, 2023
46707c9
Remove unused dataclass decorator
mcwitt Sep 5, 2023
57e8e3c
Add frozen=True to HREX
mcwitt Sep 5, 2023
3122a02
Misc fixes to plots and comments
mcwitt Sep 5, 2023
9e4fbfb
Invert order of swap attempts and MD sampling
mcwitt Sep 5, 2023
43755ba
Fix transition matrix normalization
mcwitt Sep 5, 2023
b041a7b
Minor plot improvements
mcwitt Sep 5, 2023
8406c21
Expand docstring
mcwitt Sep 5, 2023
7994a93
Add comment on symmetry
mcwitt Sep 5, 2023
d09e2fd
Fix replica state distribution convergence plot
mcwitt Sep 5, 2023
01755c6
Assert that columns of transition matrix sum to 1
mcwitt Sep 5, 2023
514d967
Use eigvalsh, add note on asymptotic symmetry assumption
mcwitt Sep 5, 2023
f67536c
Symmetrize transition matrix before eigvalsh to reduce variance
mcwitt Sep 6, 2023
814a8c8
Expand docstring to mention ordering of states
mcwitt Sep 6, 2023
6a31ac4
Add identity move to mixture only in the case of 2 states
mcwitt Sep 6, 2023
966d565
Clean up equilibration
mcwitt Sep 6, 2023
33f0e72
Add rbfe benchmark with and without hrex
mcwitt Sep 6, 2023
839cded
Revise heuristic for number of swap attempts
mcwitt Sep 6, 2023
b6d6635
Print diagnostics every N frames
mcwitt Sep 6, 2023
e8c505a
Modify test parameters to reduce runtime, increase hrex frequency
mcwitt Sep 6, 2023
448e867
Update usage in test
mcwitt Sep 6, 2023
90a0441
Standardize on 5 frames per iteration for now
mcwitt Sep 6, 2023
e9562aa
Add todo on context creation overhead
mcwitt Sep 6, 2023
bd05573
Tweak test parameters
mcwitt Sep 6, 2023
d8b91f4
Add stub for Context
mcwitt Sep 6, 2023
2db9fd7
Warn when setting global PRNG state, add todo
mcwitt Sep 6, 2023
d0f5a03
Add test for reproducibility of hrex rbfe
mcwitt Sep 6, 2023
82a499f
Reduce number of frames for benchmark
mcwitt Sep 6, 2023
e475136
Merge branch 'master' into hrex
mcwitt Sep 6, 2023
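Commit 61f68a6 adds a (0, 0) pair to the neighbor pairs "to ensure aperiodicity". The toy sketch below illustrates the concern under simplifying assumptions that are not taken from the PR: two states, every swap accepted, and each swap attempt drawing one pair uniformly at random from the list. The function `replica_in_state_0_by_iter` and its parameters are hypothetical names for illustration; this is not how `run_hrex` is implemented.

```python
import random


def replica_in_state_0_by_iter(neighbor_pairs, n_iters, n_swaps_per_iter, seed=0):
    """Record which replica occupies state 0 after each iteration.

    Assumption (illustration only): every proposed swap is accepted,
    mimicking the limit where swap acceptance rates approach 100%.
    """
    rng = random.Random(seed)
    replica_by_state = [0, 1]  # replica_by_state[state] = replica index
    history = []
    for _ in range(n_iters):
        for _ in range(n_swaps_per_iter):
            i, j = rng.choice(neighbor_pairs)
            # Swapping (0, 0) is a no-op, so it acts as an identity move
            replica_by_state[i], replica_by_state[j] = replica_by_state[j], replica_by_state[i]
        history.append(replica_by_state[0])
    return history


# Only the (0, 1) pair: an even number of always-accepted swaps is the identity
without_identity = replica_in_state_0_by_iter([(0, 1)], n_iters=200, n_swaps_per_iter=2)

# With (0, 0) added, the number of effective swaps per iteration is random
with_identity = replica_in_state_0_by_iter([(0, 0), (0, 1)], n_iters=200, n_swaps_per_iter=2)
```

In the first case replica 0 is found in state 0 at the end of every iteration, so the replicas never appear to mix; in the second case both replicas are observed in state 0.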
279 changes: 279 additions & 0 deletions tests/hrex/test_hrex_1d.py
@@ -0,0 +1,279 @@
from dataclasses import dataclass, replace
from functools import partial
from typing import Callable, List, Protocol, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytest
import scipy
from numpy.typing import NDArray
from scipy.special import logsumexp

from timemachine.fe.plots import (
    plot_hrex_replica_state_distribution,
    plot_hrex_replica_state_distribution_convergence,
    plot_hrex_replica_state_distribution_heatmap,
    plot_hrex_swap_acceptance_rates_convergence,
    plot_hrex_transition_matrix,
)
from timemachine.md.hrex import HREXDiagnostics, ReplicaIdx, StateIdx, run_hrex
from timemachine.md.moves import MetropolisHastingsMove

DEBUG = False


class Distribution(Protocol):
    def sample(self, n_samples: int) -> NDArray:
        ...

    def log_q(self, x: float) -> float:
        ...


@dataclass
class Uniform:
    x_1: float
    x_2: float

    def sample(self, n_samples: int) -> NDArray:
        return np.random.uniform(self.x_1, self.x_2, size=(n_samples,))

    def log_q(self, x: float) -> float:
        return 0.0 if self.x_1 < x <= self.x_2 else -np.inf


@dataclass
class GaussianMixture:
    locs: NDArray
    scales: NDArray
    log_weights: NDArray

    def __post_init__(self):
        assert len(self.locs) == len(self.scales) == len(self.log_weights)

    def sample(self, n_samples: int) -> NDArray:
        (n_components,) = self.locs.shape
        probs = np.exp(self.log_weights - logsumexp(self.log_weights))
        components = np.random.choice(n_components, p=probs, size=(n_samples,))
        xs = np.random.normal(self.locs, self.scales, size=(n_samples, n_components))
        return xs[np.arange(n_samples), components]

    def log_q(self, x: float) -> float:
        x_ = np.atleast_1d(np.asarray(x))
        log_q = -((x_[:, None] - self.locs) ** 2) / (2 * self.scales ** 2)
        return logsumexp(log_q + self.log_weights, axis=1)


def gaussian(loc: float, scale: float, log_weight: float = 0.0) -> Distribution:
    return GaussianMixture(np.array([loc]), np.array([scale]), np.array([log_weight]))


@dataclass
class LocalMove(MetropolisHastingsMove[float]):
    def __init__(self, proposal: Callable[[float], Distribution], target: Distribution):
        super().__init__()
        self.proposal = proposal
        self.target = target

    def propose_with_log_q_diff(self, x: float) -> Tuple[float, float]:
        x_p = self.proposal(x).sample(1).item()
        log_q_diff = self.target.log_q(x_p) - self.target.log_q(x)
        return x_p, log_q_diff


def run_hrex_with_local_proposal(
    states: Sequence[Distribution],
    initial_replicas: Sequence[float],
    proposal: Callable[[float], Distribution],
    n_samples=10_000,
    n_samples_per_iter=10,
):
    assert len(states) == len(initial_replicas)

    state_idxs = [StateIdx(i) for i, _ in enumerate(states)]
    neighbor_pairs = list(zip(state_idxs, state_idxs[1:]))

    # Add (0, 0) to the list of neighbor pairs considered for swap moves to ensure that performing a fixed number of
    # neighbor swaps is aperiodic in cases where swap acceptance rates approach 100%
    neighbor_pairs = [(StateIdx(0), StateIdx(0))] + neighbor_pairs

    def sample_replica(replica: float, state_idx: StateIdx, n_samples: int) -> List[float]:
        """Sample replica using local moves in the specified state"""
        move = LocalMove(proposal, states[state_idx])
        samples = move.sample_chain(replica, n_samples)
        return samples

    def get_log_q_fn(replicas: Sequence[float]) -> Callable[[ReplicaIdx, StateIdx], float]:
        log_q_matrix = np.array(
            [[states[state_idx].log_q(replicas[replica_idx]) for state_idx in state_idxs] for replica_idx in state_idxs]
        )

        def log_q(replica_idx: ReplicaIdx, state_idx: StateIdx) -> float:
            return log_q_matrix[replica_idx, state_idx]

        return log_q

    def replica_from_samples(xs: List[float]) -> float:
        return xs[-1]

    samples_by_state_by_iter, diagnostics = run_hrex(
        initial_replicas,
        sample_replica,
        replica_from_samples,
        neighbor_pairs,
        get_log_q_fn,
        n_samples=n_samples,
        n_samples_per_iter=n_samples_per_iter,
    )

    # Remove stats for (0, 0) pair
    diagnostics = replace(
        diagnostics,
        fraction_accepted_by_pair_by_iter=[
            fraction_accepted_by_pair[1:] for fraction_accepted_by_pair in diagnostics.fraction_accepted_by_pair_by_iter
        ],
    )

    if DEBUG:
        plot_hrex_diagnostics(diagnostics)
        plt.show()

    return samples_by_state_by_iter, diagnostics


@pytest.mark.parametrize("seed", range(5))
def test_hrex_different_distributions_same_free_energy(seed):
    np.random.seed(seed)

    locs = [0.0, 0.5, 1.0]
    states = [gaussian(loc, 0.3) for loc in locs]
    initial_replicas = locs

    proposal_radius = 0.1
    proposal = lambda x: gaussian(x, proposal_radius)

    samples_by_state_by_iter, diagnostics = run_hrex_with_local_proposal(states, initial_replicas, proposal)

    samples_by_state = np.concatenate(samples_by_state_by_iter, axis=1)

    # KS test assumes independent samples
    # Use a rough estimate of autocorrelation time to subsample correlated MCMC samples
    tau = round(1 / proposal_radius ** 2)

    (n_samples,) = samples_by_state[0].shape

    ks_pvalues = [
        scipy.stats.ks_2samp(samples[tau::tau], state.sample(n_samples)).pvalue
        for samples, state in zip(samples_by_state, states)
    ]

    np.testing.assert_array_less(0.01, ks_pvalues)

    final_swap_acceptance_rates = diagnostics.cumulative_swap_acceptance_rates[-1]
    assert np.all(final_swap_acceptance_rates > 0.2)

    # Swap acceptance rates should be approximately equal between pairs
    assert np.all(np.abs(final_swap_acceptance_rates - final_swap_acceptance_rates.mean()) < 0.02)

    # Fraction of time spent in each state for each replica should be close to uniform
    n_iters = diagnostics.cumulative_replica_state_counts.shape[0]
    final_replica_state_density = diagnostics.cumulative_replica_state_counts[-1] / n_iters
    assert np.all(np.abs(final_replica_state_density - np.mean(final_replica_state_density)) < 0.2)


@pytest.mark.parametrize("seed", range(5))
def test_hrex_same_distributions_different_free_energies(seed):
    np.random.seed(seed)

    states = [gaussian(0.0, 0.3, log_weight) for log_weight in [-1.0, 0.0, 1.0]]
    initial_replicas = [0.0] * len(states)

    proposal_radius = 0.1
    proposal = lambda x: gaussian(x, proposal_radius)

    samples_by_state_by_iter, diagnostics = run_hrex_with_local_proposal(states, initial_replicas, proposal)

    samples_by_state = np.concatenate(samples_by_state_by_iter, axis=1)

    # KS test assumes independent samples
    # Use a rough estimate of autocorrelation time to subsample correlated MCMC samples
    tau = round(1 / proposal_radius ** 2)

    (n_samples,) = samples_by_state[0].shape

    ks_pvalues = [
        scipy.stats.ks_2samp(samples[tau::tau], state.sample(n_samples)).pvalue
        for samples, state in zip(samples_by_state, states)
    ]

    np.testing.assert_array_less(0.01, ks_pvalues)
    assert np.all(diagnostics.cumulative_swap_acceptance_rates == 1.0)  # difference in log(q) for swaps is always zero

    n_iters = diagnostics.cumulative_replica_state_counts.shape[0]
    final_replica_state_density = diagnostics.cumulative_replica_state_counts[-1] / n_iters

    # Fraction of time spent in each state for each replica should be close to uniform
    assert np.all(np.abs(final_replica_state_density - np.mean(final_replica_state_density)) < 0.05)


@pytest.mark.parametrize("seed", [0, 1, 2, 3, 5])
def test_hrex_gaussian_mixture(seed):
    """Use HREX to sample from a mixture of two gaussians with ~zero overlap."""

    np.random.seed(seed)

    states = [
        GaussianMixture(np.array([0.0, 1.0]), scales=np.array([0.1, 0.1]), log_weights=np.array([0.0, 0.0])),
        gaussian(0.5, 0.5),
    ]

    # start replicas at x=0
    initial_replicas = [0.0, 0.0]

    proposal_radius = 0.1
    proposal = lambda x: gaussian(x, proposal_radius)

    samples_by_state_by_iter, diagnostics = run_hrex_with_local_proposal(states, initial_replicas, proposal)

    samples_by_state = np.concatenate(samples_by_state_by_iter, axis=1)
    hrex_samples = samples_by_state[0]  # samples from gaussian mixture

    (n_samples,) = hrex_samples.shape

    local_samples_ = LocalMove(proposal, states[0]).sample_chain(initial_replicas[0], n_samples)
    local_samples = np.array(local_samples_)

    # HREX should sample the energy well at x=1
    assert np.any(hrex_samples > 1.0)

    target_samples = states[0].sample(n_samples)

    # KS test assumes independent samples
    # Use a rough estimate of autocorrelation time to subsample correlated MCMC samples
    tau = round(1 / proposal_radius ** 2)

    def compute_ks_pvalue(samples):
        return scipy.stats.ks_2samp(samples[tau::tau], target_samples).pvalue

    assert compute_ks_pvalue(local_samples) == pytest.approx(0.0, abs=1e-10)  # local sampling alone is insufficient
    assert compute_ks_pvalue(hrex_samples) > 0.01

    final_swap_acceptance_rates = diagnostics.cumulative_swap_acceptance_rates[-1]
    assert final_swap_acceptance_rates[0] > 0.2

    if DEBUG:
        plt.figure()
        hist = partial(plt.hist, density=True, bins=50, alpha=0.7)
        hist(target_samples, label="target")
        hist(hrex_samples, label="hrex")
        hist(local_samples, label="local")
        plt.legend()
        plt.show()


def plot_hrex_diagnostics(diagnostics: HREXDiagnostics):
    plot_hrex_swap_acceptance_rates_convergence(diagnostics.cumulative_swap_acceptance_rates)
    plot_hrex_transition_matrix(diagnostics.transition_matrix)
    plot_hrex_replica_state_distribution(diagnostics.cumulative_replica_state_counts)
    plot_hrex_replica_state_distribution_heatmap(diagnostics.cumulative_replica_state_counts)
    plot_hrex_replica_state_distribution_convergence(diagnostics.cumulative_replica_state_counts)
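The comment in `test_hrex_same_distributions_different_free_energies` notes that the difference in log(q) for swaps is always zero. A minimal standalone sketch of why, assuming the standard Metropolis criterion for exchanging two configurations between two states: when the states differ only by a constant log weight, the constants cancel in the acceptance ratio. The helpers `gaussian_log_q` and `swap_log_acceptance` are hypothetical and are not part of `timemachine`; `gaussian_log_q` mirrors the single-component case of `GaussianMixture.log_q` above.

```python
import math


def gaussian_log_q(loc, scale, log_weight=0.0):
    """Unnormalized log density of a single weighted gaussian (hypothetical helper)."""

    def log_q(x):
        return log_weight - (x - loc) ** 2 / (2 * scale ** 2)

    return log_q


def swap_log_acceptance(log_q_i, log_q_j, x_i, x_j):
    """Log Metropolis ratio for exchanging x_i (in state i) with x_j (in state j)."""
    return (log_q_i(x_j) + log_q_j(x_i)) - (log_q_i(x_i) + log_q_j(x_j))


# Same shape, different log weights: the weights cancel, so swaps are always accepted
delta_same_shape = swap_log_acceptance(
    gaussian_log_q(0.0, 0.3, -1.0), gaussian_log_q(0.0, 0.3, 1.0), x_i=0.1, x_j=-0.4
)

# Different locations: swaps can be rejected
delta_shifted = swap_log_acceptance(gaussian_log_q(0.0, 0.3), gaussian_log_q(0.5, 0.3), x_i=0.0, x_j=0.5)
accept_prob_shifted = min(1.0, math.exp(delta_shifted))
```

This is consistent with the two tests: the shifted-gaussian test only requires swap acceptance rates above 0.2, while the weighted-gaussian test asserts a rate of exactly 1.0.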
77 changes: 77 additions & 0 deletions tests/hrex/test_hrex_alchemical.py
@@ -0,0 +1,77 @@
from importlib import resources
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pytest

from timemachine.fe.free_energy import HostConfig, MDParams, SimulationResult
from timemachine.fe.plots import (
    plot_hrex_replica_state_distribution_convergence,
    plot_hrex_replica_state_distribution_heatmap,
    plot_hrex_swap_acceptance_rates_convergence,
    plot_hrex_transition_matrix,
)
from timemachine.fe.rbfe import estimate_relative_free_energy_bisection_hrex
from timemachine.ff import Forcefield
from timemachine.md import builders
from timemachine.testsystems.relative import get_hif2a_ligand_pair_single_topology

DEBUG = False


@pytest.mark.nightly(reason="Slow")
@pytest.mark.parametrize("host", [None, "complex", "solvent"])
def test_hrex_rbfe_hif2a(host: Optional[str]):
    mol_a, mol_b, core = get_hif2a_ligand_pair_single_topology()
    forcefield = Forcefield.load_default()
    md_params = MDParams(n_frames=1000, n_eq_steps=10_000, steps_per_frame=400, seed=2023)

    host_config: Optional[HostConfig] = None

    if host == "complex":
        with resources.path("timemachine.testsystems.data", "hif2a_nowater_min.pdb") as protein_path:
            host_sys, host_conf, box, _, num_water_atoms = builders.build_protein_system(
                str(protein_path), forcefield.protein_ff, forcefield.water_ff
            )
            box += np.diag([0.1, 0.1, 0.1])  # remove any possible clashes
        host_config = HostConfig(host_sys, host_conf, box, num_water_atoms)
    elif host == "solvent":
        host_sys, host_conf, box, _ = builders.build_water_system(4.0, forcefield.water_ff)
        box += np.diag([0.1, 0.1, 0.1])  # remove any possible clashes
        host_config = HostConfig(host_sys, host_conf, box, host_conf.shape[0])

    result = estimate_relative_free_energy_bisection_hrex(
        mol_a,
        mol_b,
        core,
        forcefield,
        host_config,
        md_params,
        lambda_interval=(0.0, 0.2),
        n_windows=5,
        n_frames_bisection=100,
        n_frames_per_iter=10,
    )

    if DEBUG:
        plot_hrex_rbfe_hif2a(result)

    assert result.hrex_diagnostics

    # Swap acceptance rates for all neighboring pairs should be >~ 20%
    final_swap_acceptance_rates = result.hrex_diagnostics.cumulative_swap_acceptance_rates[-1]
    assert np.all(final_swap_acceptance_rates > 0.2)

    # All replicas should have visited each state at least once
    final_replica_state_counts = result.hrex_diagnostics.cumulative_replica_state_counts[-1]
    assert np.all(final_replica_state_counts > 0)


def plot_hrex_rbfe_hif2a(result: SimulationResult):
    assert result.hrex_diagnostics
    plot_hrex_swap_acceptance_rates_convergence(result.hrex_diagnostics.cumulative_swap_acceptance_rates)
    plot_hrex_transition_matrix(result.hrex_diagnostics.transition_matrix)
    plot_hrex_replica_state_distribution_convergence(result.hrex_diagnostics.cumulative_replica_state_counts)
    plot_hrex_replica_state_distribution_heatmap(result.hrex_diagnostics.cumulative_replica_state_counts)
    plt.show()
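The final assertion in `test_hrex_rbfe_hif2a` checks that every replica has visited every state at least once, using the last entry of `cumulative_replica_state_counts`. The sketch below shows one way such cumulative counts could be accumulated from a record of which replica occupied each state at each iteration. The input format, the `[state][replica]` axis order, and the function name are assumptions made for illustration; they are not taken from `HREXDiagnostics`.

```python
def cumulative_counts(replica_idx_by_state_by_iter):
    """Return per-iteration running counts, indexed as counts[state][replica].

    Assumption (illustration only): entry [t][s] of the input is the index
    of the replica occupying state s at iteration t.
    """
    n_states = len(replica_idx_by_state_by_iter[0])
    counts = [[0] * n_states for _ in range(n_states)]
    cumulative = []
    for replica_idx_by_state in replica_idx_by_state_by_iter:
        for state_idx, replica_idx in enumerate(replica_idx_by_state):
            counts[state_idx][replica_idx] += 1
        # Store a copy so that later iterations do not mutate earlier snapshots
        cumulative.append([row[:] for row in counts])
    return cumulative


# Toy trajectory with 3 states and 5 iterations
trajectory = [[0, 1, 2], [1, 0, 2], [1, 2, 0], [2, 1, 0], [0, 2, 1]]
counts_by_iter = cumulative_counts(trajectory)

final_counts = counts_by_iter[-1]
all_visited = all(count > 0 for row in final_counts for count in row)
```

After the fourth iteration of this toy trajectory replica 1 has not yet visited state 2, so the same check applied to `counts_by_iter[3]` would fail; it passes only once the fifth iteration is included.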