-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Detect non-trainable parameters in AM1CCCHandler #737
Changes from all commits
8f8c17e
074b679
8e74707
757b1f8
40495e4
fe95e72
a819cce
158305b
1ae638b
15a2fc2
b268c55
f895f39
6506133
0ab6796
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,11 +10,13 @@ | |
from rdkit import Chem | ||
from rdkit.Chem import AllChem, rdmolops | ||
|
||
from timemachine.constants import ONE_4PI_EPS0 | ||
from timemachine.constants import DEFAULT_FF, ONE_4PI_EPS0 | ||
from timemachine.datasets import fetch_freesolv | ||
from timemachine.ff import Forcefield | ||
from timemachine.ff.charges import AM1CCC_CHARGES | ||
from timemachine.ff.handlers import bonded, nonbonded | ||
from timemachine.ff.handlers.deserialize import deserialize_handlers | ||
from timemachine.ff.handlers.utils import get_spurious_param_idxs, get_symmetry_classes | ||
|
||
|
||
def test_harmonic_bond(): | ||
|
@@ -569,6 +571,37 @@ def test_am1ccc_throws_error_on_phosphorus(): | |
assert "unsupported element" in str(e) | ||
|
||
|
||
def test_am1ccc_symmetric_patterns():
    """Exercise the symmetric-pattern safeguards of the default AM1CCC handler.

    Checks that
    * symmetric_pattern_mask has one entry per SMIRKS pattern, 25 of them True
    * the summed squared output of static_parameterize can be evaluated and
      differentiated with jax.grad
    * ValueError is raised, on evaluation and under jax.grad, once the
      parameters of symmetric patterns are set to a non-zero value
    """
    handler = Forcefield.load_from_file(DEFAULT_FF).q_handle
    smirks = handler.smirks
    params = handler.params
    mol = fetch_freesolv()[123]

    mask = handler.symmetric_pattern_mask
    assert len(mask) == len(handler.smirks)
    assert sum(mask) == 25

    def sum_of_squares(trial_params):
        assigned = handler.static_parameterize(trial_params, smirks, mol)
        return np.sum(assigned ** 2)

    grad_sum_of_squares = jax.grad(sum_of_squares)

    # the shipped parameters must pass validation in both modes
    sum_of_squares(params)
    grad_sum_of_squares(params)

    # expect problems when parameters associated with symmetric patterns are 1.0
    invalid_params = np.array(params)
    invalid_params[mask] = 1.0

    for fxn in (sum_of_squares, grad_sum_of_squares):
        with pytest.raises(ValueError):
            fxn(invalid_params)
|
||
|
||
def test_am1_differences(): | ||
|
||
ff_raw = open("timemachine/ff/params/smirnoff_1_1_0_ccc.py").read() | ||
|
@@ -899,3 +932,38 @@ def test_lennard_jones_handler(): | |
# if a parameter is > 99 then its adjoint should be zero (converse isn't necessarily true since) | ||
mask = np.argwhere(params > 90) | ||
assert np.all(adjoints[mask] == 0.0) | ||
|
||
|
||
def test_symmetry_classes():
    """Check that get_symmetry_classes yields one entry per atom.

    Runs over the first 10 molecules returned by fetch_freesolv.
    """
    for mol in fetch_freesolv()[:10]:
        n_labels = len(get_symmetry_classes(mol))
        assert n_labels == mol.GetNumAtoms()
|
||
|
||
def test_spurious_param_idxs():
    """Check the indices reported by get_spurious_param_idxs.

    * every reported index belongs to a symmetric bond pattern
    * a non-empty result is reported for exactly 3 of the selected molecules
    """
    mols_by_name = {mol.GetProp("_Name"): mol for mol in fetch_freesolv()}

    # expect three of these to trigger
    # NOTE(review): reviewers asked for the three triggering molecules to be
    # listed separately; which three they are is not recorded here -- TODO confirm
    names = ["mobley_1743409", "mobley_1755375", "mobley_1760914", "mobley_1770205", "mobley_1781152"]
    select_mols = [mols_by_name[name] for name in names]

    handler = Forcefield.load_from_file(DEFAULT_FF).q_handle
    symmetric_idxs = set(np.where(handler.symmetric_pattern_mask)[0])

    n_mols_flagged = 0
    for mol in select_mols:
        flagged_idxs = get_spurious_param_idxs(mol, handler)
        assert set(flagged_idxs).issubset(symmetric_idxs)

        if len(flagged_idxs) > 0:
            n_mols_flagged += 1

    assert n_mols_flagged == 3
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -5,13 +5,12 @@ | |||||||
import jax.numpy as jnp | ||||||||
import networkx as nx | ||||||||
import numpy as np | ||||||||
from rdkit import Chem | ||||||||
|
||||||||
from timemachine import constants | ||||||||
from timemachine.ff.handlers.bcc_aromaticity import AromaticityModel | ||||||||
from timemachine.ff.handlers.bcc_aromaticity import match_smirks as oe_match_smirks | ||||||||
from timemachine.ff.handlers.serialize import SerializableMixIn | ||||||||
from timemachine.ff.handlers.utils import canonicalize_bond | ||||||||
from timemachine.ff.handlers.utils import canonicalize_bond, check_bond_smarts_symmetric, convert_to_oe | ||||||||
from timemachine.ff.handlers.utils import match_smirks as rd_match_smirks | ||||||||
from timemachine.graph_utils import convert_to_nx | ||||||||
|
||||||||
|
@@ -26,23 +25,6 @@ | |||||||
ELF10_MODELS = (AM1ELF10, AM1BCCELF10) | ||||||||
|
||||||||
|
||||||||
def convert_to_oe(mol):
    """Convert an ROMol into an OEMol.

    The molecule is written to a mol block with RDKit and read back through an
    in-memory OpenEye stream configured for the SDF format.

    Parameters
    ----------
    mol: Chem.ROMol
        molecule to convert

    Returns
    -------
    oechem.OEMol
        copy of the last molecule read from the stream

    Raises
    ------
    ValueError
        if no molecule could be read back from the mol block
    """

    # imported here for optional dependency
    from openeye import oechem

    mol_block = Chem.MolToMolBlock(mol)
    ims = oechem.oemolistream()
    ims.SetFormat(oechem.OEFormat_SDF)
    ims.openstring(mol_block)

    oemol = None
    try:
        for buf_mol in ims.GetOEMols():
            oemol = oechem.OEMol(buf_mol)
    finally:
        # release the stream even if reading fails part-way
        ims.close()

    if oemol is None:
        # previously surfaced as an UnboundLocalError on the return statement
        raise ValueError("no molecule could be read from the RDKit mol block")
    return oemol
|
||||||||
|
||||||||
def oe_generate_conformations(oemol, sample_hydrogens=True): | ||||||||
"""Generate conformations for the input molecule. | ||||||||
The molecule is modified in place. | ||||||||
|
@@ -286,7 +268,7 @@ def apply_bond_charge_corrections(initial_charges, bond_idxs, deltas): | |||||||
|
||||||||
net_charge = jnp.sum(initial_charges) | ||||||||
final_net_charge = jnp.sum(final_charges) | ||||||||
net_charge_is_unchanged = jnp.isclose(final_net_charge, net_charge, atol=1e-5) | ||||||||
net_charge_is_unchanged = jnp.isclose(final_net_charge, net_charge, atol=1e-4) | ||||||||
|
||||||||
assert net_charge_is_unchanged | ||||||||
|
||||||||
|
@@ -524,7 +506,7 @@ def parameterize(self, mol): | |||||||
return self.partial_parameterize(self.params, mol) | ||||||||
|
||||||||
@staticmethod | ||||||||
def static_parameterize(params, smirks, mol): | ||||||||
def static_parameterize(params, smirks, mol, validate=True): | ||||||||
""" | ||||||||
Parameters | ||||||||
---------- | ||||||||
|
@@ -534,8 +516,12 @@ def static_parameterize(params, smirks, mol): | |||||||
SMIRKS patterns matching bonds, to be parsed using OpenEye Toolkits | ||||||||
mol: Chem.ROMol | ||||||||
molecule to be parameterized. | ||||||||
|
||||||||
validate: bool | ||||||||
check params, smirks | ||||||||
""" | ||||||||
if validate: | ||||||||
jkausrelay marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. After talking offline, my understanding is that this is intended to verify the FF and not the molecule's parameterization. What about moving this validation to the Forcefield class? It seems like longer term we will want to do more validation of the forcefield parameters as we train them, and potentially not just the non-bonded parameters. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, it might be more appropriate for this to be in the Forcefield constructor at this branch: timemachine/timemachine/ff/__init__.py Lines 71 to 73 in 776873a
Based on discussion with @jkausrelay yesterday, the initial thought was to put this check in the core function that gets called by everything else (static_parameterize). I'm agnostic about where the check should go, as long as it surfaces an error/warning before applying invalid parameters to a molecule. |
||||||||
AM1CCCHandler.static_validate(params, smirks) | ||||||||
|
||||||||
am1_charges = compute_or_load_am1_charges(mol) | ||||||||
bond_idxs, type_idxs = compute_or_load_bond_smirks_matches(mol, smirks) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could Basically check the result of the match instead of the string. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This may be preferable (won't have any false negatives or fragile string hacks) One consideration is that it makes this check a function of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would advocate for this, even if only as an additional check. The regex seems to have gaps and my regex-foo is not up for the task. |
||||||||
|
||||||||
|
@@ -545,3 +531,23 @@ def static_parameterize(params, smirks, mol): | |||||||
assert q_params.shape[0] == mol.GetNumAtoms() # check that return shape is consistent with input mol | ||||||||
|
||||||||
return q_params | ||||||||
|
||||||||
@staticmethod
def static_validate(params, smirks):
    """Reject parameter sets that give a symmetric bond pattern a non-zero value.

    Parameters
    ----------
    params:
        parameters, paired element-wise with smirks
    smirks:
        bond SMIRKS patterns

    Raises
    ------
    ValueError
        at the first pattern that check_bond_smarts_symmetric reports as
        symmetric while its paired parameter is not 0
    """
    for pattern, param in zip(smirks, params):
        if not check_bond_smarts_symmetric(pattern):
            continue
        if param != 0:
            raise ValueError(f"a symmetric bond pattern {pattern} has a non-zero parameter {param}")
|
||||||||
@property
def symmetric_pattern_mask(self):
    """Array of flags, one per entry of self.smirks, True where the pattern is symmetric

    Symmetry is decided by check_bond_smarts_symmetric.

    Notes
    -----
    * During training, avoid modifying any parameters where this mask is True
    """
    flags = []
    for pattern in self.smirks:
        flags.append(check_bond_smarts_symmetric(pattern))
    return np.array(flags)
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -1,6 +1,30 @@ | ||||||||
import re | ||||||||
|
||||||||
import jax | ||||||||
import numpy as np | ||||||||
from jax import grad | ||||||||
from jax import numpy as jnp | ||||||||
from numpy.typing import NDArray | ||||||||
from rdkit import Chem | ||||||||
|
||||||||
|
||||||||
def convert_to_oe(mol):
    """Convert an ROMol into an OEMol.

    The molecule is written to a mol block with RDKit and read back through an
    in-memory OpenEye stream configured for the SDF format.

    Parameters
    ----------
    mol: Chem.ROMol
        molecule to convert

    Returns
    -------
    oechem.OEMol
        copy of the last molecule read from the stream

    Raises
    ------
    ValueError
        if no molecule could be read back from the mol block
    """

    # imported here for optional dependency
    from openeye import oechem

    mol_block = Chem.MolToMolBlock(mol)
    ims = oechem.oemolistream()
    ims.SetFormat(oechem.OEFormat_SDF)
    ims.openstring(mol_block)

    oemol = None
    try:
        for buf_mol in ims.GetOEMols():
            oemol = oechem.OEMol(buf_mol)
    finally:
        # release the stream even if reading fails part-way
        ims.close()

    if oemol is None:
        # previously surfaced as an UnboundLocalError on the return statement
        raise ValueError("no molecule could be read from the RDKit mol block")
    return oemol
|
||||||||
|
||||||||
def canonicalize_bond(arr): | ||||||||
""" | ||||||||
Canonicalize a bonded interaction. If arr[0] < arr[-1] then arr is | ||||||||
|
@@ -67,3 +91,99 @@ def match_smirks(mol, smirks): | |||||||
matches.append(tuple(mas)) | ||||||||
|
||||||||
return matches | ||||||||
|
||||||||
|
||||||||
def check_bond_smarts_symmetric(bond_smarts: str) -> bool:
    """Match [<atom1>:1]*[<atom2>:2]
    and return whether atom1 and atom2 are identical strings

    Parameters
    ----------
    bond_smarts: str
        bond SMARTS with atoms tagged :1 and :2

    Returns
    -------
    bool
        True only if the whole string has the form [<atom1>:1]<b>[<atom2>:2],
        where <b> is a single character, and atom1 == atom2 as strings.
        False otherwise, including whenever the string does not have this form.

    Notes
    -----
    * The AM1CCC model contains symmetric patterns that must be assigned 0 parameters
        (Otherwise, undefined behavior when symmetric bond matches in an arbitrary direction)
    * Only checks string equivalence!
        for example
            check_bond_smarts_symmetric("[#6,#7:1]~[#7,#6:2]")
        will be a false negative
    * Does not handle all possible bond smarts
        for example
            "[#6:1]~[#6:2]~[#1]"
        or
            "[#6:1](~[#8])(~[#16:2])"
        will not be matched, will return False by default.
        However, for the bond smarts subset used by the AM1CCC model, this covers most cases
    * Review discussion flagged that this string-based check admits many false
        negatives and that checking the result of the SMIRKS match on a molecule
        may be preferable.
    """

    # fullmatch anchors both ends, so a partial match can never be reported
    match = re.fullmatch(r"\[(?P<atom1>.*)\:1\].\[(?P<atom2>.*)\:2\]", bond_smarts)

    if match is None:
        # TODO: possibly warn in this branch?
        #   (false negatives possible -- but are also possible in the other branch...)
        return False

    return match.group("atom1") == match.group("atom2")
|
||||||||
|
||||||||
def get_symmetry_classes(rdmol: Chem.Mol) -> NDArray:
    """[atom.GetSymmetryClass() for atom in mol],
    just renumbered for convenience

    Parameters
    ----------
    rdmol: Chem.Mol
        molecule, converted to an OEMol before symmetry perception

    Returns
    -------
    NDArray
        one integer label per atom, taking every value in 0 .. n_classes - 1.
        Labels follow the sorted order of the symmetry class values reported
        by OpenEye, so the renumbering is deterministic.
    """

    # imported here for optional dependency
    from openeye import oechem

    oemol = convert_to_oe(rdmol)
    oechem.OEPerceiveSymmetry(oemol)
    raw_classes = np.array([atom.GetSymmetryClass() for atom in oemol.GetAtoms()])

    # make indexy / contiguous from 0 to n_classes:
    # np.unique sorts the distinct raw labels, and the inverse gives, for each
    # atom, the position of its label in that sorted list
    distinct_classes, symmetry_classes = np.unique(raw_classes, return_inverse=True)
    n_classes = len(distinct_classes)
    assert set(symmetry_classes) == set(range(n_classes))

    return symmetry_classes
|
||||||||
|
||||||||
def get_spurious_param_idxs(mol, handle) -> NDArray:
    """Find all indices i such that adjusting handle.params[i] can
    result in distinct parameters being assigned to indistinguishable atoms in mol.

    Optimizing the parameters associated with these indices should be avoided.

    Atoms are grouped by get_symmetry_classes(mol). An index is reported when
    the gradient of the within-class squared deviation of the assigned
    parameters, evaluated at an all-ones parameter vector, exceeds 1e-4 in
    magnitude.
    """

    atom_classes = get_symmetry_classes(mol)
    patterns = handle.smirks

    def compute_spuriosity(ff_params):
        # apply parameters (validation disabled, since the trial parameters are all ones)
        assigned = handle.static_parameterize(ff_params, patterns, mol, validate=False)

        # mean of the assigned values within each symmetry class
        per_class_total = jax.ops.segment_sum(assigned, atom_classes)
        per_class_mean = per_class_total / np.bincount(atom_classes)

        # expect no atom can be adjusted independently of others in its symmetry class
        per_atom_mean = per_class_mean[atom_classes]
        assert per_atom_mean.shape == assigned.shape
        residual = assigned - per_atom_mean

        return jnp.sum(residual ** 2)

    # TODO: may also want to try several points in the parameter space,
    #   randomly or systematically flipping signs...
    # NOTE(review): a reviewer suggested handle.params + 1.0, so that the trial
    #   point is always a perturbation of the original params
    trial_params = np.ones(len(handle.params))  # TODO: generalize
    assert trial_params.shape == handle.params.shape

    # get idxs where component of gradient w.r.t. trial_params is != 0
    thresh = 1e-4
    gradient = grad(compute_spuriosity)(trial_params)
    return np.where(np.abs(gradient) > thresh)[0]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this assert that some compounds share symmetry classes, such that
len(set(sym_classes)) != len(sym_classes)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so -- the numbering won't be comparable across compounds