Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle undetermined energies in BAR calculations #1098

Merged
merged 42 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
5cc41b2
Simplify: avoid recomputing u_kln
mcwitt Jul 24, 2023
f4eb94d
Clean: use builtin alias for lru_cache(None)
mcwitt Jul 24, 2023
b651d53
Compute 2-state delta f using MBAR
mcwitt Jul 25, 2023
7e8a9b2
Add failing test
mcwitt Jul 25, 2023
c51e808
Replace NaN with inf in u_kln
mcwitt Jul 25, 2023
0b72da5
Clean: use works_from_ukln
mcwitt Jul 25, 2023
da1d363
Fix test
mcwitt Jul 25, 2023
06066fb
Update test to use df_from_u_kln
mcwitt Jul 26, 2023
352c27d
Add tests for uniform distributions with partial and zero overlap
mcwitt Jul 26, 2023
7b9c853
Clean: use df_and_err_from_u_kln
mcwitt Jul 26, 2023
f3f2442
Strengthen tests: add comparison with exact dlogZ
mcwitt Jul 26, 2023
135dad6
Clean: add file-level nogpu mark
mcwitt Jul 26, 2023
d3120b8
Strengthen partial overlap test to compare with exact result
mcwitt Jul 26, 2023
0ed9719
Fix typo, add docstring note about u_kln convention
mcwitt Jul 26, 2023
75da78a
Refactor to avoid forwarding kwargs
mcwitt Jul 26, 2023
8da93ba
Fix missed update
mcwitt Jul 26, 2023
4ab37d8
Fix sign errors
mcwitt Jul 26, 2023
1c54793
Strengthen test: also check that error estimate is consistent
mcwitt Jul 26, 2023
7495063
Add failing test
mcwitt Jul 26, 2023
54e4b5c
Switch to -log(overlap) as cost function for bisection
mcwitt Jul 26, 2023
7d41f55
Fix docstring
mcwitt Jul 26, 2023
5f93cf6
Remove relative tolerance setting
mcwitt Jul 26, 2023
ec4cc2b
Clean, remove test case with zero overlap
mcwitt Jul 26, 2023
35fe889
Merge branch 'master' into fix/handle-energy-overflow-bisection
mcwitt Jul 26, 2023
dc2c1b2
Add assertion for self-consistent iteration method
mcwitt Jul 27, 2023
358edc3
Remove timeout logic from bootstrap_bar
mcwitt Jul 27, 2023
fd95420
Reduce number of bootstrap samples
mcwitt Jul 27, 2023
11e30ff
Increase relative tolerance, reduce max iterations for MBAR
mcwitt Jul 27, 2023
be4d5c5
Use pymbar version released on pypi
mcwitt Jul 27, 2023
92a9f70
Catch pymbar exception on incomplete convergence
mcwitt Jul 27, 2023
4792f09
Update n_boostrap for consistency
mcwitt Jul 28, 2023
2e0c9c1
Clean: axis=-1 -> axis=2
mcwitt Jul 28, 2023
f4a4b54
Fix and clean test
mcwitt Jul 28, 2023
f7fe229
Add assertions for pymbar behavior
mcwitt Jul 28, 2023
0ddf9a0
Remove mentions of timeout in docstring
mcwitt Jul 28, 2023
f0f3414
Fix name of function in docstring, tweak formatting
mcwitt Jul 28, 2023
f7a1ab4
Add option to specify max solver iterations for bootstrapping
mcwitt Jul 28, 2023
6950d82
Fix typo
mcwitt Jul 28, 2023
28c6d9f
Remove obsolete assertion failure message
mcwitt Jul 28, 2023
e52fd30
Assert finite results with nan and inf inputs
mcwitt Jul 28, 2023
e01c2a3
Merge branch 'master' into fix/handle-energy-overflow-bisection
mcwitt Jul 28, 2023
4a900c6
Merge branch 'master' into fix/handle-energy-overflow-bisection
mcwitt Jul 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 129 additions & 32 deletions tests/test_bar.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,164 @@
from functools import partial
from typing import Tuple

import numpy as np
import pymbar
import pytest
from pymbar.testsystems import ExponentialTestCase, gaussian_work_example
from numpy.typing import NDArray
from pymbar.testsystems import ExponentialTestCase

from timemachine.fe.bar import (
bar_with_bootstrapped_uncertainty,
bootstrap_bar,
compute_fwd_and_reverse_df_over_time,
df_and_err_from_u_kln,
df_from_u_kln,
pair_overlap_from_ukln,
works_from_ukln,
)


def make_gaussian_ukln_example(
params_a: Tuple[float, float], params_b: Tuple[float, float], seed: int = 0, n_samples: int = 2000
) -> Tuple[NDArray, float]:
"""Generate 2-state u_kln matrix for a pair of normal distributions."""

def u(mu, sigma, x):
return (x - mu) ** 2 / (2 * sigma ** 2)

mu_a, sigma_a = params_a
mu_b, sigma_b = params_b

u_a = partial(u, mu_a, sigma_a)
u_b = partial(u, mu_b, sigma_b)

rng = np.random.default_rng(seed)

x_a = rng.normal(mu_a, sigma_a, (n_samples,))
x_b = rng.normal(mu_b, sigma_b, (n_samples,))

u_kln = np.array([[u_a(x_a), u_a(x_b)], [u_b(x_a), u_b(x_b)]])

dlogZ = np.log(sigma_a) - np.log(sigma_b)

return u_kln, dlogZ


@pytest.fixture
def partial_overlap_uniform_ukln_example():
def u_a(x):
"""Unif[0.0, 1.0], with log_Z = 0"""
in_bounds = (x > 0) * (x < 1)
return np.where(in_bounds, 0, +np.inf)

def u_b(x):
"""Unif[0.5, 1.5], with log_Z = -5"""
x_ = x - 0.5
return u_a(x_) + 5.0

rng = np.random.default_rng(2023)

x_a = rng.uniform(0, 1, (1000,))
x_b = rng.uniform(0.5, 1.5, (1000,))

assert np.isfinite(u_a(x_a)).all()
assert np.isfinite(u_b(x_b)).all()

u_kln = np.array([[u_a(x_a), u_a(x_b)], [u_b(x_a), u_b(x_b)]])
mcwitt marked this conversation as resolved.
Show resolved Hide resolved

return u_kln


@pytest.mark.nogpu
def test_bootstrap_bar():
@pytest.mark.parametrize("sigma", [0.1, 1.0, 10.0])
def test_bootstrap_bar(sigma):
np.random.seed(0)
n_bootstrap = 1000
n_bootstrap = 100

for sigma_F in [0.1, 1, 10]:
# default rbfe instance size, varying difficulty
w_F, w_R = gaussian_work_example(2000, 2000, sigma_F=sigma_F, seed=0)
# default rbfe instance size, varying difficulty
u_kln, dlogZ = make_gaussian_ukln_example((0.0, 1.0), (1.0, sigma))

# estimate 3 times
df_ref, ddf_ref = pymbar.BAR(w_F, w_R)
df_0, bootstrap_samples = bootstrap_bar(w_F, w_R, n_bootstrap=n_bootstrap)
df_1, bootstrap_sigma = bar_with_bootstrapped_uncertainty(w_F, w_R)
# estimate 3 times
df_ref, df_err_ref = df_and_err_from_u_kln(u_kln)
df_0, bootstrap_samples = bootstrap_bar(u_kln, n_bootstrap=n_bootstrap)
df_1, bootstrap_sigma = bar_with_bootstrapped_uncertainty(u_kln)

# assert estimates identical, uncertainties comparable
print(f"stddev(w_F) = {sigma_F}, bootstrap uncertainty = {bootstrap_sigma}, pymbar.BAR uncertainty = {ddf_ref}")
assert df_0 == df_ref
assert df_1 == df_ref
assert len(bootstrap_samples) == n_bootstrap, "timed out on default problem size!"
np.testing.assert_approx_equal(bootstrap_sigma, ddf_ref, significant=1)
# assert estimates identical, uncertainties comparable
print(f"bootstrap uncertainty = {bootstrap_sigma}, pymbar.MBAR uncertainty = {df_err_ref}")
assert df_0 == df_ref
assert df_1 == df_ref
assert len(bootstrap_samples) == n_bootstrap, "timed out on default problem size!"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is this still relevant?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kept assertion but removed misleading message in 28c6d9f

np.testing.assert_approx_equal(bootstrap_sigma, df_err_ref, significant=1)

# assert bootstrap estimate is consistent with exact result
assert df_1 == pytest.approx(dlogZ, abs=2.0 * bootstrap_sigma)


@pytest.mark.nogpu
def test_pair_overlap_from_ukln():
def gaussian_overlap(p1, p2):
def make_gaussian(params):
mu, sigma = params
@pytest.mark.parametrize("sigma", [0.1, 1.0, 10.0])
def test_df_and_err_from_u_kln_approximates_exact_result(sigma):
u_kln, dlogZ = make_gaussian_ukln_example((0.0, 1.0), (1.0, sigma))
df, df_err = df_and_err_from_u_kln(u_kln)
assert df == pytest.approx(dlogZ, abs=2.0 * df_err)


@pytest.mark.parametrize("sigma", [0.3, 1.0, 10.0])
def test_df_from_u_kln_compare_with_pymbar_bar(sigma):
"""Compare the estimator used for 2-state delta fs (currently MBAR) with pymbar.BAR as reference."""
u_kln, _ = make_gaussian_ukln_example((0.0, 1.0), (1.0, sigma))
w_F, w_R = works_from_ukln(u_kln)

def u(x):
return (x - mu) ** 2 / (2 * sigma ** 2)
df_ref, _ = pymbar.BAR(w_F, w_R)
df = df_from_u_kln(u_kln)

rng = np.random.default_rng(2022)
x = rng.normal(mu, sigma, 100)
assert df == pytest.approx(df_ref, rel=0.05, abs=0.01)

return u, x

u1, x1 = make_gaussian(p1)
u2, x2 = make_gaussian(p2)
def test_df_and_err_from_u_kln_partial_overlap(partial_overlap_uniform_ukln_example):
u_kln = partial_overlap_uniform_ukln_example

u_kln = np.array([[u1(x1), u1(x2)], [u2(x1), u2(x2)]])
w_F, w_R = works_from_ukln(u_kln)

return pair_overlap_from_ukln(u_kln)
# this example has some infinite work values
assert np.any(np.isinf(w_F)) or np.any(np.isinf(w_R))

# pymbar.BAR warns and returns zero for df and uncertainty if inf is present in either input
assert pymbar.BAR(w_F, w_R) == (0.0, 0.0)

df, df_err = df_and_err_from_u_kln(u_kln)
assert np.isfinite(df) and df != 0.0
mcwitt marked this conversation as resolved.
Show resolved Hide resolved
assert np.isfinite(df_err) and df_err > 0.0


@pytest.mark.parametrize("n", [1, 30, 1000])
def test_df_and_err_from_u_kln_zero_overlap(n):
# non-overlapping uniform distributions
ones = np.ones(n)
infs = np.inf * np.ones(n)
u_kln = np.array(
[
[ones, infs],
[infs, ones],
]
)

_, df_err = df_and_err_from_u_kln(u_kln)
assert np.isfinite(df_err) and (n == 1 or df_err > 0.0)


@pytest.mark.nogpu
def test_pair_overlap_from_ukln():
# identical distributions
np.testing.assert_allclose(gaussian_overlap((0, 1), (0, 1)), 1.0)
u_kln, _ = make_gaussian_ukln_example((0, 1), (0, 1))
assert pair_overlap_from_ukln(u_kln) == pytest.approx(1.0)

# non-overlapping
assert gaussian_overlap((0, 0.01), (1, 0.01)) < 1e-10
u_kln, _ = make_gaussian_ukln_example((0, 0.01), (1, 0.01))
assert pair_overlap_from_ukln(u_kln) < 1e-10

# overlapping
assert gaussian_overlap((0, 0.1), (0.5, 0.2)) > 0.1
u_kln, _ = make_gaussian_ukln_example((0, 0.1), (0.5, 0.2))
assert pair_overlap_from_ukln(u_kln) > 0.1


@pytest.mark.nogpu
Expand Down
32 changes: 32 additions & 0 deletions tests/test_free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@

from timemachine.constants import DEFAULT_TEMP
from timemachine.fe import free_energy, topology, utils
from timemachine.fe.bar import mbar_from_u_kln
from timemachine.fe.free_energy import (
BarResult,
HostConfig,
IndeterminateEnergyWarning,
MDParams,
MinOverlapWarning,
PairBarResult,
batches,
estimate_free_energy_bar,
make_pair_bar_plots,
run_sims_with_greedy_bisection,
sample,
Expand Down Expand Up @@ -351,3 +354,32 @@ def result_with_overlap(overlap):
]
results, _, _ = run_sims_with_greedy_bisection_partial(min_overlap=0.4)
assert len(results) == 1 + 2


def test_estimate_free_energy_bar_with_energy_overflow():
"""Ensure that we handle NaNs in u_kln inputs (e.g. due to overflow in potential evaluation)."""
rng = np.random.default_rng(2023)
u_kln = rng.uniform(-1, 1, (2, 2, 100))

_ = estimate_free_energy_bar(np.array([u_kln]), DEFAULT_TEMP)

u_kln_with_nan = np.array(u_kln)
u_kln_with_nan[0, 1, 10] = np.nan

# pymbar.MBAR fails with LinAlgError
with pytest.raises(SystemError, match="LinAlgError"):
mbar_from_u_kln(u_kln_with_nan)

# should warn with NaN
with pytest.warns(IndeterminateEnergyWarning, match="NaN"):
result_with_nan = estimate_free_energy_bar(np.array([u_kln_with_nan]), DEFAULT_TEMP)

u_kln_with_inf = np.array(u_kln)
u_kln_with_inf[0, 1, 10] = np.inf

# should give the same result with inf
result_with_inf = estimate_free_energy_bar(np.array([u_kln_with_inf]), DEFAULT_TEMP)
assert result_with_nan.dG == result_with_inf.dG
assert result_with_nan.dG_err == result_with_inf.dG_err
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: assert finite

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in e52fd30

np.testing.assert_array_equal(result_with_nan.dG_err_by_component, result_with_inf.dG_err_by_component)
np.testing.assert_array_equal(result_with_nan.overlap, result_with_inf.overlap)
71 changes: 44 additions & 27 deletions timemachine/fe/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,37 @@ def dG_dw(w):
return dG_dw


def bootstrap_bar(w_F, w_R, n_bootstrap=1000, timeout=10) -> Tuple[float, NDArray]:
"""Subsample w_F, w_R with replacement and re-run BAR many times
def mbar_from_u_kln(u_kln: NDArray, **kwargs):
mcwitt marked this conversation as resolved.
Show resolved Hide resolved
"""Construct a pymbar.MBAR instance given a 2-state u_kln matrix."""
k, l, n = u_kln.shape
assert k == l == 2
u_kn = u_kln.reshape(k, -1)
assert u_kn.shape == (k, l * n)
N_k = n * np.ones(l)
return pymbar.MBAR(u_kln, N_k, **kwargs)


def df_and_err_from_u_kln(u_kln: NDArray, **kwargs) -> Tuple[float, float]:
"""Compute free energy difference and uncertainty given a 2-state u_kln matrix."""
mbar = mbar_from_u_kln(u_kln, **kwargs)
df, ddf = mbar.getFreeEnergyDifferences()
return df[1, 0], ddf[1, 0]


def df_from_u_kln(u_kln: NDArray, **kwargs) -> float:
"""Compute free energy difference given a 2-state u_kln matrix."""
mbar = mbar_from_u_kln(u_kln, **kwargs)
df = mbar.getFreeEnergyDifferences(compute_uncertainty=False)[0]
return df[1, 0]


def bootstrap_bar(u_kln: NDArray, n_bootstrap=1000, timeout=10) -> Tuple[float, NDArray]:
"""Given a 2-state u_kln matrix, subsample u_kln with replacement and re-run bar_from_u_kln many times
maxentile marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
w_F : array
forward works
w_R : array
reverse works
u_kln : array
2-state u_kln matrix
n_bootstrap : int
# bootstrap samples
timeout : int
maxentile marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -126,9 +148,11 @@ def bootstrap_bar(w_F, w_R, n_bootstrap=1000, timeout=10) -> Tuple[float, NDArra
* TODO[deboggle] -- upgrade from pymbar3 to pymbar4 and remove this
* TODO[performance] -- multiprocessing, if needed?
"""
full_bar_result = pymbar.BAR(w_F, w_R, compute_uncertainty=False)
mbar = mbar_from_u_kln(u_kln)

n_F, n_R = len(w_F), len(w_R)
full_bar_result = mbar.getFreeEnergyDifferences(compute_uncertainty=False)[0][1, 0]

_, _, n = u_kln.shape

bootstrap_samples = []

Expand All @@ -142,27 +166,22 @@ def bootstrap_bar(w_F, w_R, n_bootstrap=1000, timeout=10) -> Tuple[float, NDArra
if elapsed_time > timeout:
break

w_F_sample = rng.choice(w_F, size=(n_F,), replace=True)
w_R_sample = rng.choice(w_R, size=(n_R,), replace=True)
u_kln_sample = rng.choice(u_kln, size=(n,), replace=True, axis=-1)

bar_result = pymbar.BAR(
w_F=w_F_sample,
w_R=w_R_sample,
DeltaF=full_bar_result, # warm start
compute_uncertainty=False,
bar_result = df_from_u_kln(
u_kln_sample,
initial_f_k=mbar.f_k, # warm start
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

q: would it make sense to add bootstrap_maximum_iterations to signature of bootstrap_bar, then forward it here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, it seems like useful flexibility to be able to specify max iterations here. Added in f7a1ab4

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsure whether it would be useful to be able to specify max iterations separately for the point estimate and the bootstrap samples, but this can probably be added later if useful.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsure whether it would be useful to be able to specify max iterations separately for the point estimate and the bootstrap samples, but this can probably be added later if useful.

Separately makes sense to me (to control expense), but can be added later if needed.

relative_tolerance=1e-6, # reduce cost
mcwitt marked this conversation as resolved.
Show resolved Hide resolved
)

bootstrap_samples.append(bar_result)

return full_bar_result, np.array(bootstrap_samples)


def bar_with_bootstrapped_uncertainty(w_F, w_R, n_bootstrap=1000, timeout=10) -> Tuple[float, float]:
"""Drop-in replacement for pymbar.BAR(w_F, w_R) -> (df, ddf)
where first return is forwarded from pymbar.BAR but second return is computed by bootstrapping"""
def bar_with_bootstrapped_uncertainty(u_kln: NDArray, n_bootstrap=1000, timeout=10) -> Tuple[float, float]:
"""Given 2-state u_kln, returns free energy difference and uncertainty computed by bootstrapping."""

df, bootstrap_dfs = bootstrap_bar(w_F, w_R, n_bootstrap=n_bootstrap, timeout=timeout)
df, bootstrap_dfs = bootstrap_bar(u_kln, n_bootstrap=n_bootstrap, timeout=timeout)

# warn if bootstrap distribution deviates significantly from normality
normaltest_result = normaltest(bootstrap_dfs)
Expand All @@ -185,7 +204,7 @@ def works_from_ukln(u_kln: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:


def df_from_ukln_by_lambda(ukln_by_lambda: NDArray) -> Tuple[float, float]:
"""Extract dF and dF error compute by BAR over a series of lambda windows
"""Extract df and df error computed by BAR over a series of lambda windows

Parameters
----------
Expand All @@ -203,15 +222,13 @@ def df_from_ukln_by_lambda(ukln_by_lambda: NDArray) -> Tuple[float, float]:
win_errs = []
for lambda_idx in range(ukln_by_lambda.shape[0]):
window_ukln = ukln_by_lambda[lambda_idx]

w_fwd, w_rev = works_from_ukln(window_ukln)
dF, dF_err = pymbar.BAR(w_fwd, w_rev)
win_dfs.append(dF)
win_errs.append(dF_err)
df, df_err = df_and_err_from_u_kln(window_ukln)
win_dfs.append(df)
win_errs.append(df_err)
return np.sum(win_dfs), np.linalg.norm(win_errs) # type: ignore


def pair_overlap_from_ukln(u_kln: np.ndarray) -> float:
def pair_overlap_from_ukln(u_kln: NDArray) -> float:
"""Compute the off-diagonal entry of 2x2 MBAR overlap matrix,
and normalize to interval [0,1]

Expand Down
Loading