Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle undetermined energies in BAR calculations #1098

Merged
merged 42 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
5cc41b2
Simplify: avoid recomputing u_kln
mcwitt Jul 24, 2023
f4eb94d
Clean: use builtin alias for lru_cache(None)
mcwitt Jul 24, 2023
b651d53
Compute 2-state delta f using MBAR
mcwitt Jul 25, 2023
7e8a9b2
Add failing test
mcwitt Jul 25, 2023
c51e808
Replace NaN with inf in u_kln
mcwitt Jul 25, 2023
0b72da5
Clean: use works_from_ukln
mcwitt Jul 25, 2023
da1d363
Fix test
mcwitt Jul 25, 2023
06066fb
Update test to use df_from_u_kln
mcwitt Jul 26, 2023
352c27d
Add tests for uniform distributions with partial and zero overlap
mcwitt Jul 26, 2023
7b9c853
Clean: use df_and_err_from_u_kln
mcwitt Jul 26, 2023
f3f2442
Strengthen tests: add comparison with exact dlogZ
mcwitt Jul 26, 2023
135dad6
Clean: add file-level nogpu mark
mcwitt Jul 26, 2023
d3120b8
Strengthen partial overlap test to compare with exact result
mcwitt Jul 26, 2023
0ed9719
Fix typo, add docstring note about u_kln convention
mcwitt Jul 26, 2023
75da78a
Refactor to avoid forwarding kwargs
mcwitt Jul 26, 2023
8da93ba
Fix missed update
mcwitt Jul 26, 2023
4ab37d8
Fix sign errors
mcwitt Jul 26, 2023
1c54793
Strengthen test: also check that error estimate is consistent
mcwitt Jul 26, 2023
7495063
Add failing test
mcwitt Jul 26, 2023
54e4b5c
Switch to -log(overlap) as cost function for bisection
mcwitt Jul 26, 2023
7d41f55
Fix docstring
mcwitt Jul 26, 2023
5f93cf6
Remove relative tolerance setting
mcwitt Jul 26, 2023
ec4cc2b
Clean, remove test case with zero overlap
mcwitt Jul 26, 2023
35fe889
Merge branch 'master' into fix/handle-energy-overflow-bisection
mcwitt Jul 26, 2023
dc2c1b2
Add assertion for self-consistent iteration method
mcwitt Jul 27, 2023
358edc3
Remove timeout logic from bootstrap_bar
mcwitt Jul 27, 2023
fd95420
Reduce number of bootstrap samples
mcwitt Jul 27, 2023
11e30ff
Increase relative tolerance, reduce max iterations for MBAR
mcwitt Jul 27, 2023
be4d5c5
Use pymbar version released on pypi
mcwitt Jul 27, 2023
92a9f70
Catch pymbar exception on incomplete convergence
mcwitt Jul 27, 2023
4792f09
Update n_boostrap for consistency
mcwitt Jul 28, 2023
2e0c9c1
Clean: axis=-1 -> axis=2
mcwitt Jul 28, 2023
f4a4b54
Fix and clean test
mcwitt Jul 28, 2023
f7fe229
Add assertions for pymbar behavior
mcwitt Jul 28, 2023
0ddf9a0
Remove mentions of timeout in docstring
mcwitt Jul 28, 2023
f0f3414
Fix name of function in docstring, tweak formatting
mcwitt Jul 28, 2023
f7a1ab4
Add option to specify max solver iterations for bootstrapping
mcwitt Jul 28, 2023
6950d82
Fix typo
mcwitt Jul 28, 2023
28c6d9f
Remove obsolete assertion failure message
mcwitt Jul 28, 2023
e52fd30
Assert finite results with nan and inf inputs
mcwitt Jul 28, 2023
e01c2a3
Merge branch 'master' into fix/handle-energy-overflow-bisection
mcwitt Jul 28, 2023
4a900c6
Merge branch 'master' into fix/handle-energy-overflow-bisection
mcwitt Jul 28, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy==1.23.5
scipy==1.10.1
pymbar==3.0.5
pymbar==3.1.0
networkx==2.8.8
matplotlib==3.7.1
rdkit==2022.3.5
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def build_extension(self, ext):
"jaxlib>0.4.1",
"networkx",
"numpy",
"pymbar>3.0.4,<4",
"pymbar>=3.0.6,<4",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Should this be 3.1.0 or higher? Not sure why this is different from requirements

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<3.0.6 definitely won't work because of choderalab/pymbar#425, which was merged in 3.0.6. In general I prefer to use version constraints in setup.py only to exclude known-incompatible versions.

When upgrading, I opted to use the latest release <4.

"rdkit",
"scipy",
"matplotlib",
Expand Down
192 changes: 157 additions & 35 deletions tests/test_bar.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,192 @@
from functools import partial
from typing import Tuple

import numpy as np
import pymbar
import pytest
from pymbar.testsystems import ExponentialTestCase, gaussian_work_example
from numpy.typing import NDArray
from pymbar.testsystems import ExponentialTestCase

from timemachine.fe.bar import (
bar_with_bootstrapped_uncertainty,
bootstrap_bar,
compute_fwd_and_reverse_df_over_time,
df_and_err_from_u_kln,
df_from_u_kln,
pair_overlap_from_ukln,
ukln_to_ukn,
works_from_ukln,
)

pytestmark = [pytest.mark.nogpu]


def make_gaussian_ukln_example(
params_a: Tuple[float, float], params_b: Tuple[float, float], seed: int = 0, n_samples: int = 2000
) -> Tuple[NDArray, float]:
"""Generate 2-state u_kln matrix for a pair of normal distributions."""

def u(mu, sigma, x):
return (x - mu) ** 2 / (2 * sigma ** 2)

mu_a, sigma_a = params_a
mu_b, sigma_b = params_b

u_a = partial(u, mu_a, sigma_a)
u_b = partial(u, mu_b, sigma_b)

rng = np.random.default_rng(seed)

x_a = rng.normal(mu_a, sigma_a, (n_samples,))
x_b = rng.normal(mu_b, sigma_b, (n_samples,))

u_kln = np.array([[u_a(x_a), u_a(x_b)], [u_b(x_a), u_b(x_b)]])

dlogZ = np.log(sigma_a) - np.log(sigma_b)

return u_kln, dlogZ


def make_partial_overlap_uniform_ukln_example(dlogZ: float, n_samples: int = 100) -> NDArray:
"""Generate 2-state u_kln matrix for uniform distributions with partial overlap"""

def u_a(x):
"""Unif[0.0, 1.0], with log(Z) = 0"""
in_bounds = (x > 0) * (x < 1)
return np.where(in_bounds, 0, +np.inf)

def u_b(x):
"""Unif[0.5, 1.5], with log(Z) = dlogZ"""
x_ = x - 0.5
return u_a(x_) + dlogZ

rng = np.random.default_rng(2023)

x_a = rng.uniform(0, 1, (n_samples,))
x_b = rng.uniform(0.5, 1.5, (n_samples,))

assert np.isfinite(u_a(x_a)).all()
assert np.isfinite(u_b(x_b)).all()

u_kln = np.array([[u_a(x_a), u_a(x_b)], [u_b(x_a), u_b(x_b)]])

@pytest.mark.nogpu
def test_bootstrap_bar():
return u_kln


@pytest.mark.parametrize("sigma", [0.1, 1.0, 10.0])
def test_bootstrap_bar(sigma):
np.random.seed(0)
n_bootstrap = 1000
n_bootstrap = 100

for sigma_F in [0.1, 1, 10]:
# default rbfe instance size, varying difficulty
w_F, w_R = gaussian_work_example(2000, 2000, sigma_F=sigma_F, seed=0)
# default rbfe instance size, varying difficulty
u_kln, dlogZ = make_gaussian_ukln_example((0.0, 1.0), (1.0, sigma))

# estimate 3 times
df_ref, ddf_ref = pymbar.BAR(w_F, w_R)
df_0, bootstrap_samples = bootstrap_bar(w_F, w_R, n_bootstrap=n_bootstrap)
df_1, bootstrap_sigma = bar_with_bootstrapped_uncertainty(w_F, w_R)
# estimate 3 times
df_ref, df_err_ref = df_and_err_from_u_kln(u_kln)
df_0, bootstrap_samples = bootstrap_bar(u_kln, n_bootstrap=n_bootstrap)
df_1, bootstrap_sigma = bar_with_bootstrapped_uncertainty(u_kln)

# assert estimates identical, uncertainties comparable
print(f"stddev(w_F) = {sigma_F}, bootstrap uncertainty = {bootstrap_sigma}, pymbar.BAR uncertainty = {ddf_ref}")
assert df_0 == df_ref
assert df_1 == df_ref
assert len(bootstrap_samples) == n_bootstrap, "timed out on default problem size!"
np.testing.assert_approx_equal(bootstrap_sigma, ddf_ref, significant=1)
# assert estimates identical, uncertainties comparable
print(f"bootstrap uncertainty = {bootstrap_sigma}, pymbar.MBAR uncertainty = {df_err_ref}")
assert df_0 == df_ref
assert df_1 == df_ref
assert len(bootstrap_samples) == n_bootstrap
np.testing.assert_approx_equal(bootstrap_sigma, df_err_ref, significant=1)

# assert bootstrap estimate is consistent with exact result
assert df_1 == pytest.approx(dlogZ, abs=2.0 * bootstrap_sigma)

@pytest.mark.nogpu
def test_pair_overlap_from_ukln():
def gaussian_overlap(p1, p2):
def make_gaussian(params):
mu, sigma = params

def u(x):
return (x - mu) ** 2 / (2 * sigma ** 2)
@pytest.mark.parametrize("sigma", [0.1, 1.0, 10.0])
def test_df_from_u_kln_consistent_with_df_and_err_from_u_kln(sigma):
u_kln, _ = make_gaussian_ukln_example((0.0, 1.0), (1.0, sigma))
df_ref, _ = df_and_err_from_u_kln(u_kln)
df = df_from_u_kln(u_kln)
assert df == df_ref


@pytest.mark.parametrize("sigma", [0.1, 1.0, 10.0])
def test_df_and_err_from_u_kln_approximates_exact_result(sigma):
u_kln, dlogZ = make_gaussian_ukln_example((0.0, 1.0), (1.0, sigma))
df, df_err = df_and_err_from_u_kln(u_kln)
assert df == pytest.approx(dlogZ, abs=2.0 * df_err)


@pytest.mark.parametrize("sigma", [0.3, 1.0, 10.0])
def test_df_and_err_from_u_kln_consistent_with_pymbar_bar(sigma):
"""Compare the estimator used for 2-state delta fs (currently MBAR) with pymbar.BAR as reference."""
u_kln, _ = make_gaussian_ukln_example((0.0, 1.0), (1.0, sigma))
w_F, w_R = works_from_ukln(u_kln)

df_ref, df_err_ref = pymbar.BAR(w_F, w_R)
df, df_err = df_and_err_from_u_kln(u_kln)

assert df == pytest.approx(df_ref, rel=0.05, abs=0.01)
np.testing.assert_approx_equal(df_err, df_err_ref, significant=1)


rng = np.random.default_rng(2022)
x = rng.normal(mu, sigma, 100)
def test_df_and_err_from_u_kln_partial_overlap():
dlogZ = 5.0
u_kln = make_partial_overlap_uniform_ukln_example(dlogZ)

return u, x
w_F, w_R = works_from_ukln(u_kln)

u1, x1 = make_gaussian(p1)
u2, x2 = make_gaussian(p2)
# this example has some infinite work values
assert np.any(np.isinf(w_F))
assert np.any(np.isinf(w_R))

u_kln = np.array([[u1(x1), u1(x2)], [u2(x1), u2(x2)]])
# but no NaNs
assert not np.any(np.isnan(w_F))
assert not np.any(np.isnan(w_R))

return pair_overlap_from_ukln(u_kln)
# pymbar.BAR warns and returns zero for df and uncertainty with default method
assert pymbar.BAR(w_F, w_R) == (0.0, 0.0)

# pymbar.BAR returns NaNs with self-consistent iteration method
df_sci, df_err_sci = pymbar.BAR(w_F, w_R, method="self-consistent-iteration")
assert np.isnan(df_sci)
assert np.isnan(df_err_sci)

df, df_err = df_and_err_from_u_kln(u_kln)
assert df == pytest.approx(dlogZ, abs=2.0 * df_err)
assert np.isfinite(df_err) and df_err > 0.0


def test_df_from_u_kln_does_not_raise_on_incomplete_convergence():
u_kln = make_partial_overlap_uniform_ukln_example(5.0)

# pymbar raises an exception on incomplete convergence when computing covariances
u_kn, N_k = ukln_to_ukn(u_kln)
mbar = pymbar.MBAR(u_kn, N_k, maximum_iterations=1)
with pytest.raises(pymbar.utils.ParameterError):
_ = mbar.getFreeEnergyDifferences()

# no exception if we don't compute uncertainty
_ = mbar.getFreeEnergyDifferences(compute_uncertainty=False)

# df_from_u_kln, df_and_err_from_u_kln wrappers do not raise exceptions
df = df_from_u_kln(u_kln, maximum_iterations=1)
assert np.isfinite(df)

df, ddf = df_and_err_from_u_kln(u_kln, maximum_iterations=1)
assert np.isfinite(df)
assert np.isnan(ddf) # returns NaN for uncertainty on incomplete convergence


def test_pair_overlap_from_ukln():
# identical distributions
np.testing.assert_allclose(gaussian_overlap((0, 1), (0, 1)), 1.0)
u_kln, _ = make_gaussian_ukln_example((0, 1), (0, 1))
assert pair_overlap_from_ukln(u_kln) == pytest.approx(1.0)

# non-overlapping
assert gaussian_overlap((0, 0.01), (1, 0.01)) < 1e-10
u_kln, _ = make_gaussian_ukln_example((0, 0.01), (1, 0.01))
assert pair_overlap_from_ukln(u_kln) < 1e-10

# overlapping
assert gaussian_overlap((0, 0.1), (0.5, 0.2)) > 0.1
u_kln, _ = make_gaussian_ukln_example((0, 0.1), (0.5, 0.2))
assert pair_overlap_from_ukln(u_kln) > 0.1


@pytest.mark.nogpu
@pytest.mark.parametrize("frames_per_step", [1, 5, 10])
def test_compute_fwd_and_reverse_df_over_time(frames_per_step):
seed = 2023
Expand Down
37 changes: 37 additions & 0 deletions tests/test_free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import patch

import numpy as np
import pymbar
import pytest
from hypothesis import example, given, seed
from hypothesis.strategies import integers
Expand All @@ -11,13 +12,16 @@

from timemachine.constants import DEFAULT_TEMP
from timemachine.fe import free_energy, topology, utils
from timemachine.fe.bar import ukln_to_ukn
from timemachine.fe.free_energy import (
BarResult,
HostConfig,
IndeterminateEnergyWarning,
MDParams,
MinOverlapWarning,
PairBarResult,
batches,
estimate_free_energy_bar,
make_pair_bar_plots,
run_sims_with_greedy_bisection,
sample,
Expand Down Expand Up @@ -351,3 +355,36 @@ def result_with_overlap(overlap):
]
results, _, _ = run_sims_with_greedy_bisection_partial(min_overlap=0.4)
assert len(results) == 1 + 2


def test_estimate_free_energy_bar_with_energy_overflow():
"""Ensure that we handle NaNs in u_kln inputs (e.g. due to overflow in potential evaluation)."""
rng = np.random.default_rng(2023)
u_kln = rng.uniform(-1, 1, (2, 2, 100))

_ = estimate_free_energy_bar(np.array([u_kln]), DEFAULT_TEMP)

u_kln_with_nan = np.array(u_kln)
u_kln_with_nan[0, 1, 10] = np.nan

# pymbar.MBAR fails with LinAlgError
with pytest.raises(SystemError, match="LinAlgError"):
u_kn, N_k = ukln_to_ukn(u_kln_with_nan)
_ = pymbar.MBAR(u_kn, N_k)

# should return finite results with warning
with pytest.warns(IndeterminateEnergyWarning, match="NaN"):
result_with_nan = estimate_free_energy_bar(np.array([u_kln_with_nan]), DEFAULT_TEMP)

assert np.isfinite(result_with_nan.dG)
assert np.isfinite(result_with_nan.dG_err)

u_kln_with_inf = np.array(u_kln)
u_kln_with_inf[0, 1, 10] = np.inf

# should give the same result with inf
result_with_inf = estimate_free_energy_bar(np.array([u_kln_with_inf]), DEFAULT_TEMP)
assert result_with_nan.dG == result_with_inf.dG
assert result_with_nan.dG_err == result_with_inf.dG_err
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: assert finite

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in e52fd30

np.testing.assert_array_equal(result_with_nan.dG_err_by_component, result_with_inf.dG_err_by_component)
np.testing.assert_array_equal(result_with_nan.overlap, result_with_inf.overlap)
Loading