Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect non-trainable parameters in AM1CCCHandler #737

Closed
wants to merge 14 commits into from
70 changes: 69 additions & 1 deletion tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from rdkit import Chem
from rdkit.Chem import AllChem, rdmolops

from timemachine.constants import ONE_4PI_EPS0
from timemachine.constants import DEFAULT_FF, ONE_4PI_EPS0
from timemachine.datasets import fetch_freesolv
from timemachine.ff import Forcefield
from timemachine.ff.charges import AM1CCC_CHARGES
from timemachine.ff.handlers import bonded, nonbonded
from timemachine.ff.handlers.deserialize import deserialize_handlers
from timemachine.ff.handlers.utils import get_spurious_param_idxs, get_symmetry_classes


def test_harmonic_bond():
Expand Down Expand Up @@ -569,6 +571,37 @@ def test_am1ccc_throws_error_on_phosphorus():
assert "unsupported element" in str(e)


def test_am1ccc_symmetric_patterns():
"""Assert that AM1CCCHandler instances:
* have a symmetric_pattern_mask property with the expected shape and contents
* can be used with jax.grad
* raise ValueError when parameters associated with symmetric patterns != 0"""
ff = Forcefield.load_from_file(DEFAULT_FF)
q_handle = ff.q_handle
smirks, params = q_handle.smirks, q_handle.params
mol = fetch_freesolv()[123]

sym_pattern_mask = q_handle.symmetric_pattern_mask
assert len(sym_pattern_mask) == len(q_handle.smirks)
assert sum(sym_pattern_mask) == 25

def f(params):
return np.sum(q_handle.static_parameterize(params, smirks, mol) ** 2)

_ = f(params)
_ = jax.grad(f)(params)

# expect problems when parameters associated with symmetric patterns are 1.0
bad_params = np.array(params)
bad_params[sym_pattern_mask] = 1.0

with pytest.raises(ValueError):
_ = f(bad_params)

with pytest.raises(ValueError):
_ = jax.grad(f)(bad_params)


def test_am1_differences():

ff_raw = open("timemachine/ff/params/smirnoff_1_1_0_ccc.py").read()
Expand Down Expand Up @@ -899,3 +932,38 @@ def test_lennard_jones_handler():
# if a parameter is > 99 then its adjoint should be zero (converse isn't necessarily true since)
mask = np.argwhere(params > 90)
assert np.all(adjoints[mask] == 0.0)


def test_symmetry_classes():
"""Assert get_symmetry_classes returns arrays of expected length num_atoms"""
mols = fetch_freesolv()[:10]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this assert that some compounds share symmetry classes such that len(set(sym_classes)) != len(sym_classes)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so -- the numbering won't be comparable across compounds

for mol in mols:
sym_classes = get_symmetry_classes(mol)
assert len(sym_classes) == mol.GetNumAtoms()


def test_spurious_param_idxs():
"""Assert that get_spurious_param_idxs
* returns arrays with length sometimes > 0
* only ever returns indices of symmetric bond types"""

mols = fetch_freesolv()
mol_dict = {mol.GetProp("_Name"): mol for mol in mols}

# expect three of these to trigger
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Might be nice to separate out the 3 that trigger, making it easier to reason about the test at a later date

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call -- will fix!

names = ["mobley_1743409", "mobley_1755375", "mobley_1760914", "mobley_1770205", "mobley_1781152"]
select_mols = [mol_dict[name] for name in names]

ff = Forcefield.load_from_file(DEFAULT_FF)
q_handle = ff.q_handle
sym_idxs = set(np.where(q_handle.symmetric_pattern_mask)[0])

n_params_detected_per_mol = []
for mol in select_mols:
spurious_idxs = get_spurious_param_idxs(mol, q_handle)
assert set(spurious_idxs).issubset(sym_idxs)

n_params_detected_per_mol.append(len(spurious_idxs))

n_mols_with_any_detected_params = (np.array(n_params_detected_per_mol) > 0).sum()
assert n_mols_with_any_detected_params == 3
50 changes: 28 additions & 22 deletions timemachine/ff/handlers/nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import jax.numpy as jnp
import networkx as nx
import numpy as np
from rdkit import Chem

from timemachine import constants
from timemachine.ff.handlers.bcc_aromaticity import AromaticityModel
from timemachine.ff.handlers.bcc_aromaticity import match_smirks as oe_match_smirks
from timemachine.ff.handlers.serialize import SerializableMixIn
from timemachine.ff.handlers.utils import canonicalize_bond
from timemachine.ff.handlers.utils import canonicalize_bond, check_bond_smarts_symmetric, convert_to_oe
from timemachine.ff.handlers.utils import match_smirks as rd_match_smirks
from timemachine.graph_utils import convert_to_nx

Expand All @@ -26,23 +25,6 @@
ELF10_MODELS = (AM1ELF10, AM1BCCELF10)


def convert_to_oe(mol):
"""Convert an ROMol into an OEMol"""

# imported here for optional dependency
from openeye import oechem

mb = Chem.MolToMolBlock(mol)
ims = oechem.oemolistream()
ims.SetFormat(oechem.OEFormat_SDF)
ims.openstring(mb)

for buf_mol in ims.GetOEMols():
oemol = oechem.OEMol(buf_mol)
ims.close()
return oemol


def oe_generate_conformations(oemol, sample_hydrogens=True):
"""Generate conformations for the input molecule.
The molecule is modified in place.
Expand Down Expand Up @@ -286,7 +268,7 @@ def apply_bond_charge_corrections(initial_charges, bond_idxs, deltas):

net_charge = jnp.sum(initial_charges)
final_net_charge = jnp.sum(final_charges)
net_charge_is_unchanged = jnp.isclose(final_net_charge, net_charge, atol=1e-5)
net_charge_is_unchanged = jnp.isclose(final_net_charge, net_charge, atol=1e-4)

assert net_charge_is_unchanged

Expand Down Expand Up @@ -524,7 +506,7 @@ def parameterize(self, mol):
return self.partial_parameterize(self.params, mol)

@staticmethod
def static_parameterize(params, smirks, mol):
def static_parameterize(params, smirks, mol, validate=True):
"""
Parameters
----------
Expand All @@ -534,8 +516,12 @@ def static_parameterize(params, smirks, mol):
SMIRKS patterns matching bonds, to be parsed using OpenEye Toolkits
mol: Chem.ROMol
molecule to be parameterized.

validate: bool
check params, smirks
"""
if validate:
jkausrelay marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

@badisa badisa May 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After talking offline about how this is intended to verify the FF and not the molecule's parameterization. What about moving this validation to the Forcefield class? It seems like longer term we will want to do more validation of the forcefield parameters as we train them, and potentially not just the non-bonded parameters.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it might be more appropriate for this to be in the Forcefield constructor at this branch:

if isinstance(handle, nonbonded.AM1CCCHandler):
assert self.q_handle is None
self.q_handle = handle

Based on discussion with @jkausrelay yesterday, the initial thought was to put this check in the core function that gets called by everything else (static_parameterize, which is called by partial_parameterize, which is called by parameterize, ...), so it would cover the most use cases (such as training, where we don't necessarily create an instance of the Forcefield class for every optimizer step).

I'm agnostic about where the check should go, as long as it surfaces an error/warning before applying invalid parameters to a molecule.

AM1CCCHandler.static_validate(params, smirks)

am1_charges = compute_or_load_am1_charges(mol)
bond_idxs, type_idxs = compute_or_load_bond_smirks_matches(mol, smirks)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could compute_or_load_bond_smirks_matches return the list of symmetric bonds and then do the check for delta[symmetric_bond_idxs] != 0 here?

Basically check the result of the match instead of the string.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be preferable (won't have any false negatives or fragile string hacks)

One consideration is that it makes this check a function of params, smirks, mol, rather than a function of params, smirks only.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would advocate for this, even if only as an additional check. The regex seems to have gaps and my regex-foo is not up for the task.


Expand All @@ -545,3 +531,23 @@ def static_parameterize(params, smirks, mol):
assert q_params.shape[0] == mol.GetNumAtoms() # check that return shape is consistent with input mol

return q_params

@staticmethod
def static_validate(params, smirks):
"""Raise ValueError if any symmetric bond patterns have non-zero parameter"""
for pattern, param in zip(smirks, params):
pattern_symmetric = check_bond_smarts_symmetric(pattern)
param_nonzero = param != 0
if pattern_symmetric and param_nonzero:
raise ValueError(f"a symmetric bond pattern {pattern} has a non-zero parameter {param}")

@property
def symmetric_pattern_mask(self):
"""Boolean array where mask[i] == True if self.smirks[i] symmetric

Notes
-----
* During training, avoid modifying any parameters where this mask is True
"""
mask = np.array([check_bond_smarts_symmetric(pattern) for pattern in self.smirks])
return mask
120 changes: 120 additions & 0 deletions timemachine/ff/handlers/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
import re

import jax
import numpy as np
from jax import grad
from jax import numpy as jnp
from numpy.typing import NDArray
from rdkit import Chem


def convert_to_oe(mol):
"""Convert an ROMol into an OEMol"""

# imported here for optional dependency
from openeye import oechem

mb = Chem.MolToMolBlock(mol)
ims = oechem.oemolistream()
ims.SetFormat(oechem.OEFormat_SDF)
ims.openstring(mb)

for buf_mol in ims.GetOEMols():
oemol = oechem.OEMol(buf_mol)
ims.close()
return oemol


def canonicalize_bond(arr):
"""
Canonicalize a bonded interaction. If arr[0] < arr[-1] then arr is
Expand Down Expand Up @@ -67,3 +91,99 @@ def match_smirks(mol, smirks):
matches.append(tuple(mas))

return matches


def check_bond_smarts_symmetric(bond_smarts: str) -> bool:
"""Match [<atom1>:1]*[<atom2>:2]
and return whether atom1 and atom2 are identical strings

Notes
-----
* The AM1CCC model contains symmetric patterns that must be assigned 0 parameters
(Otherwise, undefined behavior when symmetric bond matches in an arbitrary direction)
* Only checks string equivalence!
for example
check_bond_smarts_symmetric("[#6,#7:1]~[#7,#6:2]")
will be a false negative
* Does not handle all possible bond smarts
for example
"[#6:1]~[#6:2]~[#1]"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea how likely these patterns are to become a problem in the future?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already a problem: there are two patterns in the current model that are not parsed correctly

[#6X3:1](~[#8X1,#16X1])(~[#16X1:2])
and
[#6X3:1](~[#8X1,#16X1])(~[#8X1:2])

Since these don't match the regex, they return False by default.

Don't have a sense of what SMARTS subset may be exercised by future edits to this force field.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about using the following pattern?

pattern = re.compile(r"(\[(?P<atom1>.+)\:1\].+\[(?P<atom2>.+)\:2\]`)
...
complete = match.span() == (0, bond_smarts.index(":2]") + 3)

Would cover those two cases. Unless you also need to match the 16X1 to the 16X1 in the middle? Which I guess you could do the hacky thing of having a third group in the middle and check if the atom1 and atom2 are represented somewhere in the middle....

or
"[#6:1](~[#8])(~[#16:2])"
will not be matched, will return False by default.
However, for the bond smarts subset used by the AM1CCC model, this covers most cases
"""

pattern = re.compile(r"\[(?P<atom1>.*)\:1\].\[(?P<atom2>.*)\:2\]")
match = pattern.match(bond_smarts)

if type(match) is re.Match:
complete = match.span() == (0, len(bond_smarts))
symmetric = match.group("atom1") == match.group("atom2")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be advantageous to do the following?

Suggested change
symmetric = match.group("atom1") == match.group("atom2")
symmetric = list(sorted(match.group("atom1").split(","))) == list(sorted(match.group("atom2").split(",")))

Seems to work in the case of [#6,#7:1]~[#7,#6:2]

That said, nothing in our AM1BCC charges seems to have this lack of ordering, so probably unimportant

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can double-check this -- one issue with .split(",") is that comma-separated lists can appear inside of nested structures. For example, applying match.group("atom1").split(",") to this pattern

('[#6X3$(*=[#7,#15]):1]-[#16X1-1:2]', 2.320877731439586),

returns ['#6X3$(*=[#7', '#15])'] which looks undesirable -- would probably want to sort recursively

However, there are other ways you could produce syntactically distinct SMARTS with identical meaning. Not sure if there's a robust "canonicalize SMARTS" function...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Increasingly disliking the regex approach

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, many possibilities for false negatives. Alternative in #739 may be preferable

return complete and symmetric
else:
# TODO: possibly warn in this branch?
# (false negatives possible -- but are also possible in the other branch...)
return False


def get_symmetry_classes(rdmol: Chem.Mol) -> NDArray:
"""[atom.GetSymmetryClass() for atom in mol],
just renumbered for convenience"""

# imported here for optional dependency
from openeye import oechem

oemol = convert_to_oe(rdmol)
oechem.OEPerceiveSymmetry(oemol)
symmetry_classes = np.array([atom.GetSymmetryClass() for atom in oemol.GetAtoms()])
n_classes = len(set(symmetry_classes))

# make indexy / contiguous from 0 to n_classes
idx_map = {old_idx: new_idx for (new_idx, old_idx) in enumerate(set(symmetry_classes))}
badisa marked this conversation as resolved.
Show resolved Hide resolved
symmetry_classes = np.array([idx_map[old_idx] for old_idx in symmetry_classes])
assert set(symmetry_classes) == set(range(n_classes))

return symmetry_classes


def get_spurious_param_idxs(mol, handle) -> NDArray:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is spurious the right term? This seems more specific to indistinguishable param idxs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Indistinguishable" sounds like an improvement.

(Although still not perfect, since I don't want to detect "redundant / indistinguishable parameters", but instead I want to detect "parameters whose variation can introduce spurious differences between atoms that should be indistinguishable")

"""Find all indices i such that adjusting handle.params[i] can
result in distinct parameters being assigned to indistinguishable atoms in mol.

Optimizing the parameters associated with these indices should be avoided.
"""

symmetry_classes = get_symmetry_classes(mol)
smirks = handle.smirks

def assign_params(ff_params):
return handle.static_parameterize(ff_params, smirks, mol, validate=False)

def compute_spuriosity(ff_params):
# apply parameters
sys_params = assign_params(ff_params)

# compute the mean per symmetry class
class_sums = jax.ops.segment_sum(sys_params, symmetry_classes)
class_means = class_sums / np.bincount(symmetry_classes)

# expect no atom can be adjusted independently of others in its symmetry class
expected_constant_within_class = class_means[symmetry_classes]
assert expected_constant_within_class.shape == sys_params.shape
deviation_from_class_means = sys_params - expected_constant_within_class
spuriosity = jnp.sum(deviation_from_class_means ** 2)

return spuriosity

# TODO: may also want to try several points in the parameter space,
# randomly or systematically flipping signs...
trial_params = np.ones(len(handle.params)) # TODO: generalize
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handle.params + 1.0 ? So it will always be a perturbation to the original params.

assert trial_params.shape == handle.params.shape

# get idxs where component of gradient w.r.t. trial_params is != 0
thresh = 1e-4
g = grad(compute_spuriosity)(trial_params)
spurious_idxs = np.where(np.abs(g) > thresh)[0]

return spurious_idxs