Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect non-trainable parameters in AM1CCCHandler #737

Closed
wants to merge 14 commits into from

Conversation

maxentile
Copy link
Collaborator

@maxentile maxentile commented May 7, 2022

Problem:

I observed that it is currently possible for AM1CCCHandler to assign distinct parameters to indistinguishable atoms.

Based on discussion with @proteneer , we found that this behavior can be triggered when a symmetric bond pattern in AM1CCCHandler such as

('[#6X4:1]-[#6X4:2]', 0.0),

is assigned a non-zero parameter. When a BCC like ('[#6X4:1]-[#6X4:2]', +1.0) can match a pair of atoms in either direction, an arbitrary direction is selected, leading to undefined behavior in general. When applied to a symmetric molecule like cyclohexane, this can even result in different charges being applied to indistinguishable atoms.

In the initial, untrained force field file, BCCs defined by symmetric bond SMARTS all have parameter values of 0.0, so the undefined match direction has no effect on the assigned charges. However, gradients of realistic training objectives w.r.t. these parameters can be non-zero, so numerical optimization can readily exploit this undefined behavior, which is undesirable.


Possible workarounds:

  • Manually identify symmetric patterns, and set their parameter values to nan in the forcefield .py file, so that gradients w.r.t. these parameters would be nan-poisoned. (However, this would also have the effect of nan-poisoning the assigned charges themselves...)
  • Detect symmetric SMARTS patterns in the forcefield by string manipulation
  • Detect symmetric SMARTS patterns in the force field by inspection of the resulting query mols
  • Detect symmetric patterns at runtime by seeing if they match in both directions on a given molecule
  • Detect where initial parameter == 0.0. (However, there are some safely asymmetric patterns, such as
    ('[#6X3$(*=[#6]):1]-[#1:2]', 0.0),
    , where the initial parameter is 0.0.)
  • Modify definition of AM1CCCHandler so that it's safe for arbitrary SMARTS patterns to have arbitrary parameter values (e.g. by matching both directions rather than an arbitrary direction)

Current PR:

This draft PR adds two heuristic ways to detect parameter indices that currently should not be trained in AM1CCCHandler:

  • check_bond_smarts_symmetric -- string manipulation heuristic -- regex match "[<atom1>:1]*[<atom2>:2]" and return whether atom1 and atom2 are identical strings... If applied to

    'AM1CCC': {'patterns': [('[#6X4:1]-[#16X1-1:2]', 0.9818644744993272),
    , will flag 25 patterns as symmetric.

  • get_spurious_param_idxs -- intended to be similar to the use case of gradient-based force field optimization (defines an objective function that can only be increased if spuriously distinct parameters are assigned to indistinguishable atoms, takes gradient w.r.t. params, reports indices of any non-zero components of this gradient). This was how I initially detected the problem, and in principle it could be useful for validating any model that assigns parameters to atoms.


TBD:

Not sure exactly where best to use this information yet.

  • Add a "trainable" vs. "non-trainable" mask to the AM1CCCHandler object? (This seems most natural to me, since training script already makes a selection of which parameters to train -- want to use this information to narrow the selection of trainable parameters.)
  • Throw an exception due to undefined behavior in AM1CCCHandler.static_parameterize when parameters associated with symmetric bond SMARTS are != 0?

@maxentile maxentile marked this pull request as ready for review May 9, 2022 20:49
@maxentile maxentile requested a review from jkausrelay May 10, 2022 13:04
tests/test_handlers.py Outdated Show resolved Hide resolved
@maxentile maxentile requested a review from badisa May 10, 2022 17:04

# TODO: may also want to try several points in the parameter space,
# randomly or systematically flipping signs...
trial_params = np.ones(len(handle.params)) # TODO: generalize
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handle.params + 1.0 ? So it will always be a perturbation to the original params.

"""
if validate:
AM1CCCHandler.static_validate(params, smirks)

am1_charges = compute_or_load_am1_charges(mol)
bond_idxs, type_idxs = compute_or_load_bond_smirks_matches(mol, smirks)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could compute_or_load_bond_smirks_matches return the list of symmetric bonds and then do the check for delta[symmetric_bond_idxs] != 0 here?

Basically check the result of the match instead of the string.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be preferable (won't have any false negatives or fragile string hacks)

One consideration is that it makes this check a function of params, smirks, mol, rather than a function of params, smirks only.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would advocate for this, even if only as an additional check. The regex seems to have gaps and my regex-foo is not up for the task.

will be a false negative
* Does not handle all possible bond smarts
for example
"[#6:1]~[#6:2]~[#1]"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea how likely these patterns are to become a problem in the future?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already a problem: there are two patterns in the current model that are not parsed correctly

[#6X3:1](~[#8X1,#16X1])(~[#16X1:2])
and
[#6X3:1](~[#8X1,#16X1])(~[#8X1:2])

Since these don't match the regex, they return False by default.

Don't have a sense of what SMARTS subset may be exercised by future edits to this force field.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about using the following pattern?

pattern = re.compile(r"(\[(?P<atom1>.+)\:1\].+\[(?P<atom2>.+)\:2\]`)
...
complete = match.span() == (0, bond_smarts.index(":2]") + 3)

Would cover those two cases. Unless you also need to match the 16X1 to the 16X1 in the middle? Which I guess you could do the hacky thing of having a third group in the middle and check if the atom1 and atom2 are represented somewhere in the middle....


if type(match) is re.Match:
complete = match.span() == (0, len(bond_smarts))
symmetric = match.group("atom1") == match.group("atom2")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be advantageous to do the following?

Suggested change
symmetric = match.group("atom1") == match.group("atom2")
symmetric = list(sorted(match.group("atom1").split(","))) == list(sorted(match.group("atom2").split(",")))

Seems to work in the case of [#6,#7:1]~[#7,#6:2]

That said, nothing in our AM1BCC charges seems to have this lack of ordering, so probably unimportant

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can double-check this -- one issue with .split(",") is that comma-separated lists can appear inside of nested structures. For example, applying match.group("atom1").split(",") to this pattern

('[#6X3$(*=[#7,#15]):1]-[#16X1-1:2]', 2.320877731439586),

returns ['#6X3$(*=[#7', '#15])'] which looks undesirable -- would probably want to sort recursively

However, there are other ways you could produce syntactically distinct SMARTS with identical meaning. Not sure if there's a robust "canonicalize SMARTS" function...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Increasingly disliking the regex approach

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, many possibilities for false negatives. Alternative in #739 may be preferable


def test_symmetry_classes():
"""Assert get_symmetry_classes returns arrays of expected length num_atoms"""
mols = fetch_freesolv()[:10]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this assert that some compounds share symmetry classes such that len(set(sym_classes)) != len(sym_classes)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so -- the numbering won't be comparable across compounds

mols = fetch_freesolv()
mol_dict = {mol.GetProp("_Name"): mol for mol in mols}

# expect three of these to trigger
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Might be nice to separate out the 3 that trigger, making it easier to reason about the test at a later date

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call -- will fix!

return symmetry_classes


def get_spurious_param_idxs(mol, handle) -> NDArray:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is spurious the right term? This seems more specific to indistinguishable param idxs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Indistinguishable" sounds like an improvement.

(Although still not perfect, since I don't want to detect "redundant / indistinguishable parameters", but instead I want to detect "parameters whose variation can introduce spurious differences between atoms that should be indistinguishable")

"""
if validate:
Copy link
Collaborator

@badisa badisa May 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After talking offline about how this is intended to verify the FF and not the molecule's parameterization. What about moving this validation to the Forcefield class? It seems like longer term we will want to do more validation of the forcefield parameters as we train them, and potentially not just the non-bonded parameters.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it might be more appropriate for this to be in the Forcefield constructor at this branch:

if isinstance(handle, nonbonded.AM1CCCHandler):
assert self.q_handle is None
self.q_handle = handle

Based on discussion with @jkausrelay yesterday, the initial thought was to put this check in the core function that gets called by everything else (static_parameterize, which is called by partial_parameterize, which is called by parameterize, ...), so it would cover the most use cases (such as training, where we don't necessarily create an instance of the Forcefield class for every optimizer step).

I'm agnostic about where the check should go, as long as it surfaces an error/warning before applying invalid parameters to a molecule.

@maxentile
Copy link
Collaborator Author

maxentile commented May 10, 2022

Should clarify: the reason the get_spurious_idxs function is more complicated than just "check if any charges in a symmetry class are different" is that AM1CCC already assigns slightly different partial charges to atoms that should be indistinguishable (because the AM1 charges are currently not symmetrized).

For example, the carbons in benzene are assigned charges that differ by <0.1%, but are not exactly the same.

import numpy as np
from rdkit import Chem
from timemachine.ff import Forcefield
from timemachine.constants import DEFAULT_FF

# molecule with 2 atomic symmetry classes
benzene = 'C1=CC=CC=C1'
mol = Chem.MolFromSmiles(benzene)
mol = Chem.AddHs(mol)

# expect 2 distinct charges, {+q, -q}
ff = Forcefield.load_from_file(DEFAULT_FF)
assigned_charges = np.array(ff.q_handle.parameterize(mol))

# observe 3 distinct charges
print(set(assigned_charges))  # {1.5335002, -1.5333823, -1.5340897}

get_spurious_idxs is designed to ignore that asymmetry, and instead check whether you can make asymmetry worse by locally varying the adjustable parameters.

@maxentile maxentile mentioned this pull request May 10, 2022
@maxentile
Copy link
Collaborator Author

Thanks for the discussion of this approach.

Closing since superseded by #739 and #738

@maxentile maxentile closed this May 13, 2022
@maxentile maxentile deleted the symmetric-bccs branch June 15, 2022 13:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants