diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 049216029c..73b22d5306 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -10,11 +10,13 @@ from rdkit import Chem from rdkit.Chem import AllChem, rdmolops -from timemachine.constants import ONE_4PI_EPS0 +from timemachine.constants import DEFAULT_FF, ONE_4PI_EPS0 +from timemachine.datasets import fetch_freesolv from timemachine.ff import Forcefield from timemachine.ff.charges import AM1CCC_CHARGES from timemachine.ff.handlers import bonded, nonbonded from timemachine.ff.handlers.deserialize import deserialize_handlers +from timemachine.ff.handlers.utils import get_spurious_param_idxs, get_symmetry_classes def test_harmonic_bond(): @@ -569,6 +571,37 @@ def test_am1ccc_throws_error_on_phosphorus(): assert "unsupported element" in str(e) +def test_am1ccc_symmetric_patterns(): + """Assert that AM1CCCHandler instances: + * have a symmetric_pattern_mask property with the expected shape and contents + * can be used with jax.grad + * raise ValueError when parameters associated with symmetric patterns != 0""" + ff = Forcefield.load_from_file(DEFAULT_FF) + q_handle = ff.q_handle + smirks, params = q_handle.smirks, q_handle.params + mol = fetch_freesolv()[123] + + sym_pattern_mask = q_handle.symmetric_pattern_mask + assert len(sym_pattern_mask) == len(q_handle.smirks) + assert sum(sym_pattern_mask) == 25 + + def f(params): + return np.sum(q_handle.static_parameterize(params, smirks, mol) ** 2) + + _ = f(params) + _ = jax.grad(f)(params) + + # expect problems when parameters associated with symmetric patterns are 1.0 + bad_params = np.array(params) + bad_params[sym_pattern_mask] = 1.0 + + with pytest.raises(ValueError): + _ = f(bad_params) + + with pytest.raises(ValueError): + _ = jax.grad(f)(bad_params) + + def test_am1_differences(): ff_raw = open("timemachine/ff/params/smirnoff_1_1_0_ccc.py").read() @@ -899,3 +932,38 @@ def test_lennard_jones_handler(): # if a parameter is > 99 then its adjoint should be zero (converse isn't necessarily true since) mask = np.argwhere(params > 90) assert np.all(adjoints[mask] == 0.0) + + +def test_symmetry_classes(): + """Assert get_symmetry_classes returns arrays of expected length num_atoms""" + mols = fetch_freesolv()[:10] + for mol in mols: + sym_classes = get_symmetry_classes(mol) + assert len(sym_classes) == mol.GetNumAtoms() + + +def test_spurious_param_idxs(): + """Assert that get_spurious_param_idxs + * returns arrays with length sometimes > 0 + * only ever returns indices of symmetric bond types""" + + mols = fetch_freesolv() + mol_dict = {mol.GetProp("_Name"): mol for mol in mols} + + # expect three of these to trigger + names = ["mobley_1743409", "mobley_1755375", "mobley_1760914", "mobley_1770205", "mobley_1781152"] + select_mols = [mol_dict[name] for name in names] + + ff = Forcefield.load_from_file(DEFAULT_FF) + q_handle = ff.q_handle + sym_idxs = set(np.where(q_handle.symmetric_pattern_mask)[0]) + + n_params_detected_per_mol = [] + for mol in select_mols: + spurious_idxs = get_spurious_param_idxs(mol, q_handle) + assert set(spurious_idxs).issubset(sym_idxs) + + n_params_detected_per_mol.append(len(spurious_idxs)) + + n_mols_with_any_detected_params = (np.array(n_params_detected_per_mol) > 0).sum() + assert n_mols_with_any_detected_params == 3 diff --git a/timemachine/ff/handlers/nonbonded.py b/timemachine/ff/handlers/nonbonded.py index b924e57805..129ec5b92f 100644 --- a/timemachine/ff/handlers/nonbonded.py +++ b/timemachine/ff/handlers/nonbonded.py @@ -5,13 +5,12 @@ import jax.numpy as jnp import networkx as nx import numpy as np -from rdkit import Chem from timemachine import constants from timemachine.ff.handlers.bcc_aromaticity import AromaticityModel from timemachine.ff.handlers.bcc_aromaticity import match_smirks as oe_match_smirks from timemachine.ff.handlers.serialize import SerializableMixIn -from timemachine.ff.handlers.utils import canonicalize_bond +from timemachine.ff.handlers.utils import canonicalize_bond, check_bond_smarts_symmetric, convert_to_oe from timemachine.ff.handlers.utils import match_smirks as rd_match_smirks from timemachine.graph_utils import convert_to_nx @@ -26,23 +25,6 @@ ELF10_MODELS = (AM1ELF10, AM1BCCELF10) -def convert_to_oe(mol): - """Convert an ROMol into an OEMol""" - - # imported here for optional dependency - from openeye import oechem - - mb = Chem.MolToMolBlock(mol) - ims = oechem.oemolistream() - ims.SetFormat(oechem.OEFormat_SDF) - ims.openstring(mb) - - for buf_mol in ims.GetOEMols(): - oemol = oechem.OEMol(buf_mol) - ims.close() - return oemol - - def oe_generate_conformations(oemol, sample_hydrogens=True): """Generate conformations for the input molecule. The molecule is modified in place. @@ -286,7 +268,7 @@ def apply_bond_charge_corrections(initial_charges, bond_idxs, deltas): net_charge = jnp.sum(initial_charges) final_net_charge = jnp.sum(final_charges) - net_charge_is_unchanged = jnp.isclose(final_net_charge, net_charge, atol=1e-5) + net_charge_is_unchanged = jnp.isclose(final_net_charge, net_charge, atol=1e-4) assert net_charge_is_unchanged @@ -524,7 +506,7 @@ def parameterize(self, mol): return self.partial_parameterize(self.params, mol) @staticmethod - def static_parameterize(params, smirks, mol): + def static_parameterize(params, smirks, mol, validate=True): """ Parameters ---------- @@ -534,8 +516,12 @@ def static_parameterize(params, smirks, mol): SMIRKS patterns matching bonds, to be parsed using OpenEye Toolkits mol: Chem.ROMol molecule to be parameterized. - + validate: bool + check params, smirks """ + if validate: + AM1CCCHandler.static_validate(params, smirks) + am1_charges = compute_or_load_am1_charges(mol) bond_idxs, type_idxs = compute_or_load_bond_smirks_matches(mol, smirks) @@ -545,3 +531,23 @@ def static_parameterize(params, smirks, mol): assert q_params.shape[0] == mol.GetNumAtoms() # check that return shape is consistent with input mol return q_params + + @staticmethod + def static_validate(params, smirks): + """Raise ValueError if any symmetric bond patterns have non-zero parameter""" + for pattern, param in zip(smirks, params): + pattern_symmetric = check_bond_smarts_symmetric(pattern) + param_nonzero = param != 0 + if pattern_symmetric and param_nonzero: + raise ValueError(f"a symmetric bond pattern {pattern} has a non-zero parameter {param}") + + @property + def symmetric_pattern_mask(self): + """Boolean array where mask[i] == True if self.smirks[i] symmetric + + Notes + ----- + * During training, avoid modifying any parameters where this mask is True + """ + mask = np.array([check_bond_smarts_symmetric(pattern) for pattern in self.smirks]) + return mask diff --git a/timemachine/ff/handlers/utils.py b/timemachine/ff/handlers/utils.py index 5ae6e1aa7c..017e9a16dc 100644 --- a/timemachine/ff/handlers/utils.py +++ b/timemachine/ff/handlers/utils.py @@ -1,6 +1,30 @@ +import re + +import jax +import numpy as np +from jax import grad +from jax import numpy as jnp +from numpy.typing import NDArray from rdkit import Chem +def convert_to_oe(mol): + """Convert an ROMol into an OEMol""" + + # imported here for optional dependency + from openeye import oechem + + mb = Chem.MolToMolBlock(mol) + ims = oechem.oemolistream() + ims.SetFormat(oechem.OEFormat_SDF) + ims.openstring(mb) + + for buf_mol in ims.GetOEMols(): + oemol = oechem.OEMol(buf_mol) + ims.close() + return oemol + + def canonicalize_bond(arr): """ Canonicalize a bonded interaction. If arr[0] < arr[-1] then arr is @@ -67,3 +91,99 @@ def match_smirks(mol, smirks): matches.append(tuple(mas)) return matches + + +def check_bond_smarts_symmetric(bond_smarts: str) -> bool: + """Match [:1]*[:2] + and return whether atom1 and atom2 are identical strings + + Notes + ----- + * The AM1CCC model contains symmetric patterns that must be assigned 0 parameters + (Otherwise, undefined behavior when symmetric bond matches in an arbitrary direction) + * Only checks string equivalence! + for example + check_bond_smarts_symmetric("[#6,#7:1]~[#7,#6:2]") + will be a false negative + * Does not handle all possible bond smarts + for example + "[#6:1]~[#6:2]~[#1]" + or + "[#6:1](~[#8])(~[#16:2])" + will not be matched, will return False by default. + However, for the bond smarts subset used by the AM1CCC model, this covers most cases + """ + + pattern = re.compile(r"\[(?P.*)\:1\].\[(?P.*)\:2\]") + match = pattern.match(bond_smarts) + + if type(match) is re.Match: + complete = match.span() == (0, len(bond_smarts)) + symmetric = match.group("atom1") == match.group("atom2") + return complete and symmetric + else: + # TODO: possibly warn in this branch? + # (false negatives possible -- but are also possible in the other branch...) + return False + + +def get_symmetry_classes(rdmol: Chem.Mol) -> NDArray: + """[atom.GetSymmetryClass() for atom in mol], + just renumbered for convenience""" + + # imported here for optional dependency + from openeye import oechem + + oemol = convert_to_oe(rdmol) + oechem.OEPerceiveSymmetry(oemol) + symmetry_classes = np.array([atom.GetSymmetryClass() for atom in oemol.GetAtoms()]) + n_classes = len(set(symmetry_classes)) + + # make indexy / contiguous from 0 to n_classes + idx_map = {old_idx: new_idx for (new_idx, old_idx) in enumerate(set(symmetry_classes))} + symmetry_classes = np.array([idx_map[old_idx] for old_idx in symmetry_classes]) + assert set(symmetry_classes) == set(range(n_classes)) + + return symmetry_classes + + +def get_spurious_param_idxs(mol, handle) -> NDArray: + """Find all indices i such that adjusting handle.params[i] can + result in distinct parameters being assigned to indistinguishable atoms in mol. + + Optimizing the parameters associated with these indices should be avoided. + """ + + symmetry_classes = get_symmetry_classes(mol) + smirks = handle.smirks + + def assign_params(ff_params): + return handle.static_parameterize(ff_params, smirks, mol, validate=False) + + def compute_spuriosity(ff_params): + # apply parameters + sys_params = assign_params(ff_params) + + # compute the mean per symmetry class + class_sums = jax.ops.segment_sum(sys_params, symmetry_classes) + class_means = class_sums / np.bincount(symmetry_classes) + + # expect no atom can be adjusted independently of others in its symmetry class + expected_constant_within_class = class_means[symmetry_classes] + assert expected_constant_within_class.shape == sys_params.shape + deviation_from_class_means = sys_params - expected_constant_within_class + spuriosity = jnp.sum(deviation_from_class_means ** 2) + + return spuriosity + + # TODO: may also want to try several points in the parameter space, + # randomly or systematically flipping signs... + trial_params = np.ones(len(handle.params)) # TODO: generalize + assert trial_params.shape == handle.params.shape + + # get idxs where component of gradient w.r.t. trial_params is != 0 + thresh = 1e-4 + g = grad(compute_spuriosity)(trial_params) + spurious_idxs = np.where(np.abs(g) > thresh)[0] + + return spurious_idxs