Skip to content

Commit

Permalink
Updates tolerances of plotting forward and reverse dg (#1114)
Browse files Browse the repository at this point in the history
* Change to MBAR in #1098
produces much more significant differences than BAR in cases of poor convergence
  • Loading branch information
badisa authored Aug 9, 2023
1 parent f195729 commit d7af228
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
9 changes: 5 additions & 4 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ def test_forward_and_reverse_ddg_plot_validation():
plot_forward_and_reverse_ddg(dummy_solv_ukln, dummy_complex_ukln[0])


def test_forward_and_reverse_dg_plot():
@pytest.mark.parametrize("ukln_shape", [(47, 2, 2, 2000), (5, 2, 2, 10)])
def test_forward_and_reverse_dg_plot(ukln_shape):
rng = np.random.default_rng(2023)
ukln_shape = (47, 2, 2, 2000)
dummy_ukln = rng.random(size=ukln_shape)
dummy_ukln = rng.random(size=ukln_shape) * 1000

plot_forward_and_reverse_dg(dummy_ukln)
frames_per_step = min(ukln_shape[-1], 100)
plot_forward_and_reverse_dg(dummy_ukln, frames_per_step=frames_per_step)


def test_forward_and_reverse_dg_plot_validation():
Expand Down
14 changes: 8 additions & 6 deletions timemachine/fe/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,18 +293,20 @@ def plot_fwd_reverse_predictions(
assert len(fwd) == len(fwd_err)
assert len(rev) == len(rev_err)

# Assert that first and last values are very close
assert np.allclose(fwd[-1], rev[-1])
assert np.allclose(fwd_err[-1], rev_err[-1])

# Assert that first and last values are close
assert np.allclose(fwd[-1], rev[-1], atol=1), f"{fwd[-1]} not close to {rev[-1]}"
if np.isfinite(fwd_err).all() and np.isfinite(rev_err.all()):
assert np.allclose(fwd_err[-1], rev_err[-1], atol=1)
fwd_mask = np.isfinite(fwd_err)
rev_mask = np.isfinite(rev_err)
xs = np.linspace(1.0 / len(fwd), 1.0, len(fwd))

plt.figure(figsize=(6, 6))
plt.title(f"{energy_type} Convergence Over Time")
plt.plot(xs, fwd, label=f"Forward {energy_type}", marker="o")
plt.fill_between(xs, fwd - fwd_err, fwd + fwd_err, alpha=0.25)
plt.fill_between(xs[fwd_mask], fwd[fwd_mask] - fwd_err[fwd_mask], fwd[fwd_mask] + fwd_err[fwd_mask], alpha=0.25)
plt.plot(xs, rev, label=f"Reverse {energy_type}", marker="o")
plt.fill_between(xs, rev - rev_err, rev + rev_err, alpha=0.25)
plt.fill_between(xs[rev_mask], rev[rev_mask] - rev_err[rev_mask], rev[rev_mask] + rev_err[rev_mask], alpha=0.25)
plt.axhline(fwd[-1], linestyle="--")
plt.xlabel("Fraction of simulation time")
plt.ylabel(f"{energy_type} (kcal/mol)")
Expand Down

0 comments on commit d7af228

Please sign in to comment.