Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
mcwitt committed Aug 28, 2023
1 parent 59b35c3 commit 799865a
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 0 deletions.
46 changes: 46 additions & 0 deletions timemachine/fe/free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,52 @@ def run_sims_hrex(
n_swap_attempts_per_iter: Optional[int] = None,
verbose: bool = True,
) -> Tuple[PairBarResult, List[StoredArrays], List[NDArray], HrexDiagnostics]:
r"""Sample from a sequence of states using nearest-neighbor Hamiltonian Replica EXchange (HREX).
This implementation uses a method described in [1] (in section III.B.2) to generate effectively uncorrelated
permutations by attempting many consecutive nearest-neighbor swap moves. By default, the number of swap moves is
determined as a function of the number of states (:math:`K`) as :math`N_{\text{swaps}} = K^4`, a heuristic also
described in [1].
References
----------
[1]: http://dx.doi.org/10.1063/1.3660669
Parameters
----------
initial_states: sequence of InitialState
States to sample
md_params: MDParams
Parameters used to simulate new states
n_frames_per_iter: int
Number of frames to sample using MD per iteration
temperature: float
Temperature in K
n_swap_attempts_per_iter: int or None, optional
Number of nearest-neighbor swaps to attempt per iteration. Defaults to len(initial_states) ** 4.
verbose: bool, optional
Whether to print diagnostic information
Returns
-------
PairBarResult
results of pair BAR free energy analysis
List[StoredArrays]
coord trajectories
List[NDArray]
box trajectories
HrexDiagnostics
HREX statistics (e.g. swap rates, replica-state distribution)
"""

if n_swap_attempts_per_iter is None:
n_swap_attempts_per_iter = get_swap_attempts_per_iter_heuristic(len(initial_states))
Expand Down
59 changes: 59 additions & 0 deletions timemachine/fe/rbfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,65 @@ def estimate_relative_free_energy_bisection_hrex(
n_frames_bisection: int = 100,
n_frames_per_iter: int = 10,
) -> SimulationResult:
"""
Estimate relative free energy between mol_a and mol_b using Hamiltonian Replica EXchange (HREX) sampling of a
sequence of intermediate states determined by bisection. Molecules should be aligned to each other and within the
host environment.
Parameters
----------
mol_a: Chem.Mol
initial molecule
mol_b: Chem.Mol
target molecule
core: list of 2-tuples
atom_mapping of atoms in mol_a into atoms in mol_b
ff: Forcefield
Forcefield to be used for the system
host_config: HostConfig or None
Configuration for the host system. If None, then the vacuum leg is run.
md_params: MDParams, optional
Parameters for the equilibration and production MD
prefix: str, optional
A prefix to append to figures
lambda_interval: (float, float) or None, optional
Minimum and maximum value of lambda for the transformation; typically (0, 1), but sometimes useful to choose
other values for testing.
n_windows: int or None, optional
Number of windows used for interpolating the lambda schedule with additional windows. Defaults to
`DEFAULT_NUM_WINDOWS` windows.
min_overlap: float or None, optional
If not None, terminate bisection early when the BAR overlap between all neighboring pairs of states exceeds this
value. When given, the final number of windows may be less than or equal to n_windows.
keep_idxs: list of int or None, optional
If None, return only the end-state frames. Otherwise if not None, use only for debugging, and this
will return the frames corresponding to the idxs of interest.
min_cutoff: float or None, optional
Throw error if any atom moves more than this distance (nm) after minimization
n_frames_bisection: int or None, optional
Number of frames to sample using MD during the initial bisection phase used to determine lambda spacing
n_frames_per_iter: int or None, optional
Number of frames to sample using MD per HREX iteration
Returns
-------
SimulationResult
Collected data from the simulation (see class for storage information). Returned frames and boxes
are defined by keep_idxs.
"""

if n_windows is None:
n_windows = DEFAULT_NUM_WINDOWS
Expand Down
39 changes: 39 additions & 0 deletions timemachine/md/hrex.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,45 @@ def run_hrex(
n_samples_per_iter: int,
n_swap_attempts_per_iter: Optional[int] = None,
) -> Tuple[List[List[_Samples]], HrexDiagnostics]:
"""Sample from a sequence of states using Hamiltonian Replica EXchange (HREX).
Parameters
----------
replicas: sequence of _Replica
Sequence of initial states of each replica
sample_replica: (_Replica, StateIdx, n_samples: int) -> _Samples
Local sampling function. Should return n_samples samples from the given replica and state
replica_from_samples: _Samples -> _Replica
Function that returns a replica state given a sequence of local samples. This is used to update the state of
individual replicas following local sampling.
neighbor_pairs: sequence of (StateIdx, StateIdx)
Pairs of states for which to attempt swap moves
get_log_q_fn: sequence of _Replica -> ((ReplicaIdx, StateIdx) -> float)
Function that returns a function from replica-state pairs to log unnormalized probability. Note that this is
equivalent to the simpler signature (_Replica, StateIdx) -> float; the "curried" form here is to allow for the
implementation to compute the full matrix as a batch operation when this is more efficient.
n_samples: int
Total number of local samples (e.g. MD frames)
n_samples_per_iter: int
Number of local samples (e.g. MD frames) per HREX iteration
n_swap_attempts_per_iter: int or None, optional
Number of neighbor swaps to attempt per iteration. Default is given by :py:func:`get_swap_attempts_per_iter_heuristic`.
Returns
-------
List[List[_Samples]]
samples grouped by state and iteration
HrexDiagnostics
HREX statistics (e.g. swap rates, replica-state distribution)
"""

n_replicas = len(replicas)

Expand Down
2 changes: 2 additions & 0 deletions timemachine/md/moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def acceptance_fraction(self) -> float:


class Identity(MonteCarloMove[_State]):
"""Move that leaves the state unchanged and is always accepted"""

def move(self, x: _State) -> _State:
self._n_proposed += 1
self._n_accepted += 1
Expand Down

0 comments on commit 799865a

Please sign in to comment.