diff --git a/feflow/protocols/nonequilibrium_cycling.py b/feflow/protocols/nonequilibrium_cycling.py index 6a9f8ef..a906df0 100644 --- a/feflow/protocols/nonequilibrium_cycling.py +++ b/feflow/protocols/nonequilibrium_cycling.py @@ -23,15 +23,13 @@ from openff.units import unit from openff.units.openmm import to_openmm, from_openmm +from ..utils.data import serialize, deserialize + # Specific instance of logger for this module # logger = logging.getLogger(__name__) -class SimulationUnit(ProtocolUnit): - """ - Monolithic unit for simulation. It runs NEQ switching simulation from chemical systems and stores the - work computed in numpy-formatted files, to be analyzed by another unit. - """ +class SetupUnit(ProtocolUnit): @staticmethod def _check_states_compatibility(state_a, state_b): """ @@ -96,84 +94,28 @@ def _detect_phase(state_a, state_b): return detected_phase - @staticmethod - def extract_positions(context, hybrid_topology_factory, atom_selection_exp="not water"): - """ - Extract positions from initial and final systems based from the hybrid topology. - - Parameters - ---------- - context: openmm.Context - Current simulation context where from extract positions. - hybrid_topology_factory: perses.annihilation.relative.HybridTopologyFactory - Hybrid topology factory where to extract positions and mapping information - atom_selection_exp: str, optional - Atom selection expression using mdtraj syntax. Defaults to "not water" - - Returns - ------- - - Notes - ----- - It achieves this by taking the positions and indices from the initial and final states of - the transformation, and computing the overlap of these with the indices of the complete - hybrid topology, filtered by some mdtraj selection expression. - - 1. Get positions from context - 2. Get topology from HTF (already mdtraj topology) - 3. Merge that information into mdtraj.Trajectory - 4. 
Filter positions for initial/final according to selection string - """ - # TODO: Maybe we want this as a helper/utils function in perses. We also need tests for this. - import mdtraj as md - import numpy as np - - # Get positions from current openmm context - positions = context.getState(getPositions=True).getPositions(asNumpy=True) - - # Get topology from HTF - indices for initial and final topologies in hybrid topology - initial_indices = np.asarray(hybrid_topology_factory.initial_atom_indices) - final_indices = np.asarray(hybrid_topology_factory.final_atom_indices) - hybrid_topology = hybrid_topology_factory.hybrid_topology - selection = atom_selection_exp - md_trajectory = md.Trajectory(xyz=positions, topology=hybrid_topology) - selection_indices = md_trajectory.topology.select(selection) - - # Now we have to find the intersection/overlap between selected indices in the hybrid - # topology and the initial/final positions, respectively - initial_selected_indices = np.intersect1d(initial_indices, selection_indices) - final_selected_indices = np.intersect1d(final_indices, selection_indices) - initial_selected_positions = md_trajectory.xyz[0, initial_selected_indices, :] - final_selected_positions = md_trajectory.xyz[0, final_selected_indices, :] - - return initial_selected_positions, final_selected_positions - def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): """ - Execute the simulation part of the Nonequilibrium switching protocol using GUFE objects. + Execute the setup part of the nonequilibrium switching protocol. Parameters ---------- ctx: gufe.protocols.protocolunit.Context The gufe context for the unit. - state_a : gufe.ChemicalSystem The initial chemical system. - state_b : gufe.ChemicalSystem The objective chemical system. - mapping : dict[str, gufe.mapping.ComponentMapping] A dict featuring mappings between the two chemical systems. - settings : gufe.settings.model.Settings The full settings for the protocol. 
Returns ------- dict : dict[str, str] - Dictionary with paths to work arrays, both forward and reverse, and trajectory coordinates for systems - A and B. + Dictionary with paths to work arrays, both forward and reverse, and + trajectory coordinates for systems A and B. """ # needed imports import numpy as np @@ -185,18 +127,11 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): from openfe.protocols.openmm_rfe import _rfe_utils from feflow.utils.hybrid_topology import HybridTopologyFactoryModded as HybridTopologyFactory - # Setting up logging to file in shared filesystem - file_logger = logging.getLogger("neq-cycling") - output_log_path = ctx.shared / "perses-neq-cycling.log" - file_handler = logging.FileHandler(output_log_path, mode="w") - file_handler.setLevel(logging.DEBUG) # TODO: Set to INFO in production - log_formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S') - file_handler.setFormatter(log_formatter) - file_logger.addHandler(file_handler) - # Check compatibility between states (same receptor and solvent) self._check_states_compatibility(state_a, state_b) + phase = self._detect_phase(state_a, state_b) # infer phase from systems and components + # Get components from systems if found (None otherwise) -- NOTE: Uses hardcoded keys! receptor_a = state_a.components.get("protein") # receptor_b = state_b.components.get("protein") # Should not be needed @@ -320,10 +255,6 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): ) ####### END OF SETUP ######### - traj_save_frequency = settings.traj_save_frequency - work_save_frequency = settings.work_save_frequency # Note: this is divisor of traj save freq. 
- selection_expression = settings.atom_selection_expression - system = hybrid_factory.hybrid_system positions = hybrid_factory.hybrid_positions @@ -350,8 +281,140 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): # Minimize openmm.LocalEnergyMinimizer.minimize(context) - # Equilibrate - context.setVelocitiesToTemperature(temperature) + # SERIALIZE SYSTEM, STATE, INTEGRATOR + + system_ = context.getSystem() + state_ = context.getState(getPositions=True) + integrator_ = context.getIntegrator() + + system_outfile = ctx.shared / 'system.xml.bz2' + state_outfile = ctx.shared / 'state.xml.bz2' + integrator_outfile = ctx.shared / 'integrator.xml.bz2' + + serialize(system_, system_outfile) + serialize(state_, state_outfile) + serialize(integrator_, integrator_outfile) + + finally: + # Explicit cleanup for GPU resources + del context, integrator + + return {'system': system_outfile, + 'state': state_outfile, + 'integrator': integrator_outfile, + 'phase': phase, + 'initial_atom_indices': hybrid_factory.initial_atom_indices, + 'final_atom_indices': hybrid_factory.final_atom_indices, + } + + +class SimulationUnit(ProtocolUnit): + """ + Monolithic unit for simulation. It runs NEQ switching simulation from chemical systems and stores the + work computed in numpy-formatted files, to be analyzed by another unit. + """ + + @staticmethod + def extract_positions(context, initial_atom_indices, final_atom_indices): + """ + Extract positions from initial and final systems based from the hybrid topology. + + Parameters + ---------- + context: openmm.Context + Current simulation context where from extract positions. 
+        initial_atom_indices, final_atom_indices : array-like of int
+            Indices into the hybrid system of the atoms belonging to the initial and final end states.
+
+        Returns
+        -------
+        initial_positions, final_positions
+            Positions of the atoms of the initial and final end states, respectively.
+
+        Notes
+        -----
+        The end-state positions are obtained by indexing the positions of the full hybrid
+        system with the atom indices of each end state:
+
+        1. Get positions from context
+        2. Convert the initial/final atom indices to numpy arrays
+        3. Index the hybrid positions with each set of indices
+        """
+        import numpy as np
+
+        # Get positions from current openmm context
+        positions = context.getState(getPositions=True).getPositions(asNumpy=True)
+
+        # Get indices for initial and final topologies in hybrid topology
+        initial_indices = np.asarray(initial_atom_indices)
+        final_indices = np.asarray(final_atom_indices)
+
+        initial_positions = positions[initial_indices, :]
+        final_positions = positions[final_indices, :]
+
+        return initial_positions, final_positions
+
+    def _execute(self, ctx, *, setup, settings, **inputs):
+        """
+        Execute the simulation part of the Nonequilibrium switching protocol using GUFE objects.
+
+        Parameters
+        ----------
+        ctx : gufe.protocols.protocolunit.Context
+            The gufe context for the unit.
+        setup
+            Result of the setup unit; its ``outputs`` provide the serialized system, state and integrator files, the phase and the end-state atom indices.
+        settings : gufe.settings.model.Settings
+            The full settings for the protocol.
+
+        Returns
+        -------
+        dict : dict[str, str]
+            Dictionary with paths to work arrays, both forward and reverse, and trajectory coordinates for systems
+            A and B.
+        """
+        import numpy as np
+        import openmm
+        import openmm.unit as openmm_unit
+        from openmmtools.integrators import PeriodicNonequilibriumIntegrator
+
+        # Setting up logging to file in shared filesystem
+        file_logger = logging.getLogger("neq-cycling")
+        output_log_path = ctx.shared / "perses-neq-cycling.log"
+        file_handler = logging.FileHandler(output_log_path, mode="w")
+        file_handler.setLevel(logging.DEBUG)  # TODO: Set to INFO in production
+        log_formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+        file_handler.setFormatter(log_formatter)
+        file_logger.addHandler(file_handler)
+
+        # Get state, system, and integrator from setup unit
+        system = deserialize(setup.outputs['system'])
+        state = deserialize(setup.outputs['state'])
+        integrator = deserialize(setup.outputs['integrator'])
+        PeriodicNonequilibriumIntegrator.restore_interface(integrator)
+
+        # Get atom indices for either end of the hybrid topology
+        initial_atom_indices = setup.outputs['initial_atom_indices']
+        final_atom_indices = setup.outputs['final_atom_indices']
+
+        # Set up context
+        platform = get_openmm_platform(settings.platform)
+        context = openmm.Context(system, integrator, platform)
+        context.setState(state)
+
+        # Equilibrate
+        thermodynamic_settings = settings.thermo_settings
+        temperature = to_openmm(thermodynamic_settings.temperature)
+        context.setVelocitiesToTemperature(temperature)
+
+        # Extract settings used below (each local must read its own same-named setting)
+        eq_steps = settings.eq_steps
+        neq_steps = settings.neq_steps
+        traj_save_frequency = settings.traj_save_frequency
+        work_save_frequency = settings.work_save_frequency  # Note: this is divisor of traj save freq.
+ selection_expression = settings.atom_selection_expression + + try: # Prepare objects to store data -- empty lists so far forward_eq_old, forward_eq_new, forward_neq_old, forward_neq_new = list(), list(), list(), list() @@ -359,7 +422,7 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): # Coarse number of steps -- each coarse consists of work_save_frequency steps coarse_eq_steps = int(eq_steps/work_save_frequency) # Note: eq_steps is multiple of work save steps - coarse_neq_steps = int(neq_steps / work_save_frequency) # Note: neq_steps is multiple of work save steps + coarse_neq_steps = int(neq_steps/work_save_frequency) # Note: neq_steps is multiple of work save steps # TODO: Also get the GPU information (plain try-except with nvidia-smi) @@ -373,14 +436,15 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): if step % traj_save_frequency == 0: file_logger.debug(f"coarse step: {step}: saving trajectory (freq {traj_save_frequency})") initial_positions, final_positions = self.extract_positions(context, - hybrid_topology_factory=hybrid_factory, - atom_selection_exp=selection_expression) + initial_atom_indices, + final_atom_indices) forward_eq_old.append(initial_positions) forward_eq_new.append(final_positions) # Make sure trajectories are stored at the end of the eq loop file_logger.debug(f"coarse step: {step}: saving trajectory (freq {traj_save_frequency})") - initial_positions, final_positions = self.extract_positions(context, hybrid_topology_factory=hybrid_factory, - atom_selection_exp=selection_expression) + initial_positions, final_positions = self.extract_positions(context, + initial_atom_indices, + final_atom_indices) forward_eq_old.append(initial_positions) forward_eq_new.append(final_positions) @@ -397,13 +461,14 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): forward_works.append(integrator.get_protocol_work(dimensionless=True)) if fwd_step % traj_save_frequency == 0: 
initial_positions, final_positions = self.extract_positions(context, - hybrid_topology_factory=hybrid_factory, - atom_selection_exp=selection_expression) + initial_atom_indices, + final_atom_indices) forward_neq_old.append(initial_positions) forward_neq_new.append(final_positions) # Make sure trajectories are stored at the end of the neq loop - initial_positions, final_positions = self.extract_positions(context, hybrid_topology_factory=hybrid_factory, - atom_selection_exp=selection_expression) + initial_positions, final_positions = self.extract_positions(context, + initial_atom_indices, + final_atom_indices) forward_neq_old.append(initial_positions) forward_neq_new.append(final_positions) @@ -416,13 +481,14 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): integrator.step(work_save_frequency) if step % traj_save_frequency == 0: initial_positions, final_positions = self.extract_positions(context, - hybrid_topology_factory=hybrid_factory, - atom_selection_exp=selection_expression) + initial_atom_indices, + final_atom_indices) reverse_eq_new.append(initial_positions) # TODO: Maybe better naming not old/new but initial/final reverse_eq_old.append(final_positions) # Make sure trajectories are stored at the end of the eq loop - initial_positions, final_positions = self.extract_positions(context, hybrid_topology_factory=hybrid_factory, - atom_selection_exp=selection_expression) + initial_positions, final_positions = self.extract_positions(context, + initial_atom_indices, + final_atom_indices) reverse_eq_old.append(initial_positions) reverse_eq_new.append(final_positions) @@ -437,13 +503,14 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): reverse_works.append(integrator.get_protocol_work(dimensionless=True)) if rev_step % traj_save_frequency == 0: initial_positions, final_positions = self.extract_positions(context, - hybrid_topology_factory=hybrid_factory, - atom_selection_exp=selection_expression) + 
initial_atom_indices, + final_atom_indices) reverse_neq_old.append(initial_positions) reverse_neq_new.append(final_positions) # Make sure trajectories are stored at the end of the neq loop - initial_positions, final_positions = self.extract_positions(context, hybrid_topology_factory=hybrid_factory, - atom_selection_exp=selection_expression) + initial_positions, final_positions = self.extract_positions(context, + initial_atom_indices, + final_atom_indices) forward_eq_old.append(initial_positions) forward_eq_new.append(final_positions) @@ -455,6 +522,7 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): file_logger.info(f"replicate_{self.name} Nonequilibrium cycle total walltime: {cycle_walltime}") # Computing performance in ns/day + timestep = to_openmm(settings.timestep) simulation_time = 2*(eq_steps + neq_steps)*timestep walltime_in_seconds = cycle_walltime.total_seconds() * openmm_unit.seconds estimated_performance = simulation_time.value_in_unit( @@ -462,7 +530,7 @@ def _execute(self, ctx, *, state_a, state_b, mapping, settings, **inputs): file_logger.info(f"replicate_{self.name} Estimated performance: {estimated_performance} ns/day") # Serialize works - phase = self._detect_phase(state_a, state_b) # infer phase from systems and components + phase = setup.outputs['phase'] forward_work_path = ctx.shared / f"forward_{phase}_{self.name}.npy" reverse_work_path = ctx.shared / f"reverse_{phase}_{self.name}.npy" with open(forward_work_path, 'wb') as out_file: @@ -672,6 +740,7 @@ class NonEquilibriumCyclingProtocol(Protocol): of the same type of components as components in stateB. 
""" + _simulation_unit = SimulationUnit result_cls = NonEquilibriumCyclingProtocolResult def __init__(self, settings: Settings): @@ -712,8 +781,10 @@ def _create( # or JSON-serializable objects num_replicates = self.settings.num_replicates + setup = SetupUnit(state_a=stateA, state_b=stateB, mapping=mapping, settings=self.settings, name="setup") + simulations = [ - SimulationUnit(state_a=stateA, state_b=stateB, mapping=mapping, settings=self.settings, name=f"{replicate}") + self._simulation_unit(setup=setup, settings=self.settings, name=f"{replicate}") for replicate in range(num_replicates)] end = ResultUnit(name="result", simulations=simulations) diff --git a/feflow/settings/nonequilibrium_cycling.py b/feflow/settings/nonequilibrium_cycling.py index e2c97d3..903dc4e 100644 --- a/feflow/settings/nonequilibrium_cycling.py +++ b/feflow/settings/nonequilibrium_cycling.py @@ -69,7 +69,7 @@ class Config: platform = 'CUDA' traj_save_frequency: int = 2000 work_save_frequency: int = 500 - atom_selection_expression: str = "not water" + atom_selection_expression: str = "not water" # no longer used # Number of replicates to run (1 cycle/replicate) num_replicates: int = 1 diff --git a/feflow/tests/test_feflow.py b/feflow/tests/test_feflow.py deleted file mode 100644 index 118bfd0..0000000 --- a/feflow/tests/test_feflow.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -Unit and regression test for the feflow package. 
-"""
-
-# Import package, test suite, and other packages as needed
-import sys
-
-import pytest
-
-import feflow
-
-
-def test_feflow_imported():
-    """Sample test, will always pass so long as import statement worked."""
-    print("importing ", feflow.__name__)
-    assert "feflow" in sys.modules
-
-
-# Assert that a certain exception is raised
-def f():
-    raise SystemExit(1)
-
-
-def test_mytest():
-    with pytest.raises(SystemExit):
-        f()
diff --git a/feflow/utils/data.py b/feflow/utils/data.py
new file mode 100644
index 0000000..42dcf96
--- /dev/null
+++ b/feflow/utils/data.py
@@ -0,0 +1,77 @@
+import os
+import pathlib
+
+
+def serialize(item, filename: pathlib.Path):
+    """
+    Serialize an OpenMM System, State, or Integrator.
+
+    The output format is chosen from the suffix of ``filename``: ``.gz`` and ``.bz2``
+    produce compressed XML, any other suffix produces plain-text XML.
+
+    Parameters
+    ----------
+    item : System, State, or Integrator
+        The thing to be serialized
+    filename : pathlib.Path or str
+        The filename to serialize to. Missing parent directories are created.
+    """
+    from openmm import XmlSerializer
+
+    filename = pathlib.Path(filename)  # tolerate plain string paths
+
+    # Create parent directory if it doesn't exist (exist_ok avoids a check-then-create race)
+    os.makedirs(filename.parent, exist_ok=True)
+
+    serialized_thing = XmlSerializer.serialize(item)
+
+    # ``elif``, not a second ``if``: otherwise a ``.gz`` target would also reach the
+    # plain-text ``else`` branch and be overwritten with uncompressed XML.
+    if filename.suffix == '.gz':
+        import gzip
+        with gzip.open(filename, mode='wb') as outfile:
+            outfile.write(serialized_thing.encode())
+    elif filename.suffix == '.bz2':
+        import bz2
+        with bz2.open(filename, mode='wb') as outfile:
+            outfile.write(serialized_thing.encode())
+    else:
+        with open(filename, mode='w') as outfile:
+            outfile.write(serialized_thing)
+
+
+def deserialize(filename: pathlib.Path):
+    """
+    Deserialize an OpenMM System, State, or Integrator.
+
+    The input format is chosen from the suffix of ``filename``, mirroring ``serialize``.
+
+    Parameters
+    ----------
+    filename : pathlib.Path or str
+        The file to read the serialized object from
+
+    Returns
+    -------
+    item : System, State, or Integrator
+        The deserialized object
+    """
+    from openmm import XmlSerializer
+
+    filename = pathlib.Path(filename)  # tolerate plain string paths
+
+    # No directory creation here: reading must not modify the filesystem, and a missing
+    # file should surface as FileNotFoundError without leaving an empty directory behind.
+    if filename.suffix == '.gz':
+        import gzip
+        with gzip.open(filename, mode='rb') as infile:
+            serialized_thing = infile.read().decode()
+    elif filename.suffix == '.bz2':
+        import bz2
+        with bz2.open(filename, mode='rb') as infile:
+            serialized_thing = infile.read().decode()
+    else:
+        with open(filename, mode='r') as infile:
+            serialized_thing = infile.read()
+
+    return XmlSerializer.deserialize(serialized_thing)