Allow computing all-pairs potential on subset of atoms, add consistency check #660

Merged: 11 commits merged into master from all-pairs-on-subset on Mar 16, 2022

Conversation

@mcwitt mcwitt (Collaborator) commented Mar 2, 2022

Related: #472

  • Adds a constructor argument atom_idxs to NonbondedAllPairs; this is used to select a subset of atoms for computing the all-pairs potential
  • Adds consistency check comparing the result of the full Nonbonded(exclusions, scales) potential with the sum of
    • NonbondedAllPairs(host)
    • NonbondedAllPairs(ligand)
    • NonbondedInteractionGroup(host, ligand)
    • NonbondedPairList(exclusions, scales)
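The decomposition being checked can be sketched with a toy pairwise energy (illustrative only: the function names and the inverse-distance energy are stand-ins, not the actual potentials or API, and exclusions are modeled here as a subtractive pair-list correction):

```python
import itertools
import numpy as np

def pair_energy(xi, xj):
    # toy stand-in for the real nonbonded kernel: inverse distance
    return 1.0 / np.linalg.norm(xi - xj)

def all_pairs(conf, atom_idxs):
    # analogue of NonbondedAllPairs restricted to a subset of atoms
    return sum(pair_energy(conf[i], conf[j])
               for i, j in itertools.combinations(atom_idxs, 2))

def interaction_group(conf, row_idxs, col_idxs):
    # analogue of NonbondedInteractionGroup: only pairs spanning the two sets
    return sum(pair_energy(conf[i], conf[j])
               for i in row_idxs for j in col_idxs)

def pair_list(conf, pairs, scales):
    # analogue of NonbondedPairList: explicit pairs with per-pair scales
    return sum(s * pair_energy(conf[i], conf[j])
               for (i, j), s in zip(pairs, scales))

rng = np.random.default_rng(2022)
conf = rng.uniform(0.0, 3.0, size=(10, 3))   # 10 toy atoms
host, ligand = range(0, 7), range(7, 10)
exclusions = [(0, 1), (7, 8)]                # toy excluded pairs
scales = [-1.0, -1.0]                        # cancel the excluded terms

# "full" potential: every pair counted once, with excluded pairs cancelled
full = all_pairs(conf, range(10)) + pair_list(conf, exclusions, scales)

# decomposition checked by this PR's consistency test
parts = (all_pairs(conf, host)
         + all_pairs(conf, ligand)
         + interaction_group(conf, host, ligand)
         + pair_list(conf, exclusions, scales))

np.testing.assert_allclose(full, parts)
```

The identity holds because host-host, ligand-ligand, and host-ligand pairs partition the set of all pairs, and the same exclusion correction appears on both sides.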

Notes for review

  • Commits starting with Rename are purely string substitutions and should contain no other changes -- these were mainly to reflect that we're no longer just doing permutations, since the potential can be computed on a subset of atoms
  • Most of the implementation changes to allow computing on a subset have been squashed into 71cc0e5

@mcwitt mcwitt force-pushed the all-pairs-on-subset branch from 24fb419 to 4bb9dba on March 2, 2022 02:23
@mcwitt mcwitt mentioned this pull request Mar 2, 2022
@mcwitt mcwitt force-pushed the all-pairs-on-subset branch 2 times, most recently from 739d651 to 834604f on March 3, 2022 17:01
@mcwitt mcwitt changed the base branch from master to fanout-summed-potential March 3, 2022 17:01
@mcwitt mcwitt force-pushed the all-pairs-on-subset branch 2 times, most recently from 605ae9d to 180ff2e on March 3, 2022 21:36
@mcwitt mcwitt force-pushed the fanout-summed-potential branch 2 times, most recently from d70204e to 0e090b9 on March 4, 2022 17:05
@mcwitt mcwitt force-pushed the all-pairs-on-subset branch from 180ff2e to e063057 on March 4, 2022 18:01
@mcwitt mcwitt changed the base branch from fanout-summed-potential to master March 4, 2022 18:01
@mcwitt mcwitt force-pushed the all-pairs-on-subset branch 2 times, most recently from cfc2a80 to a79be9a on March 4, 2022 19:31

if interpolated:
    # TODO: why does interpolation break bitwise equivalence?
    np.testing.assert_allclose(du_dp_test, du_dp_ref, rtol=1e-10, atol=1e-10)
@mcwitt mcwitt Mar 4, 2022
Not completely sure yet why the du_dps aren't bitwise equivalent (only in the interpolated case)

@mcwitt mcwitt Mar 16, 2022

Resolved offline in discussion with @proteneer. Bitwise equivalence for the interpolated case is not possible with the current structure of the code, ultimately because the distributive property of multiplication does not preserve bitwise equivalence (even when summation is done in fixed point), i.e. c * fixed_sum(a, b) != fixed_sum(c * a, c * b) (see example in gist).

Added a note in 15a17b7.
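The effect can be shown with a minimal sketch, assuming a toy fixed-point scheme with a deliberately coarse scale to make the rounding visible (this is not the actual fixed-point representation used in the code):

```python
S = 10  # toy fixed-point scale; coarse on purpose so rounding is visible

def to_fixed(x):
    # convert a float to the nearest fixed-point integer
    return round(x * S)

def fixed_sum(a, b):
    # sum in fixed point, then convert back to float
    return (to_fixed(a) + to_fixed(b)) / S

a = b = 0.06
c = 0.5

lhs = c * fixed_sum(a, b)        # multiply after the fixed-point sum
rhs = fixed_sum(c * a, c * b)    # multiply before the fixed-point sum

# 0.06 rounds up to 1/10, but 0.5 * 0.06 = 0.03 rounds down to 0/10,
# so lhs = 0.1 while rhs = 0.0: distributivity does not survive rounding
assert lhs != rhs
```

In the full-potential path the scale factor multiplies a single fixed-point sum, while in the decomposed path it multiplies each summand before the fixed-point accumulation, so the two paths round differently.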

@mcwitt mcwitt marked this pull request as ready for review March 4, 2022 19:46
@mcwitt mcwitt marked this pull request as draft March 4, 2022 20:00
@mcwitt mcwitt force-pushed the all-pairs-on-subset branch from 518c76a to cf9a360 on March 4, 2022 20:37
@mcwitt mcwitt force-pushed the all-pairs-on-subset branch 5 times, most recently from 752e49c to 91cfe42 on March 4, 2022 23:09
@mcwitt mcwitt marked this pull request as ready for review March 4, 2022 23:52
@mcwitt mcwitt requested review from maxentile, proteneer and badisa March 5, 2022 01:22
@@ -376,26 +407,31 @@ void NonbondedAllPairs<RealType, Interpolated>::execute_device(

// coords are N,3
if (d_du_dx) {
k_inv_permute_accum<<<dimGrid, tpb, 0, stream>>>(N, d_perm_, d_sorted_du_dx_, d_du_dx);
k_scatter_accum<<<dim3(ceil_divide(K_, tpb), 3, 1), tpb, 0, stream>>>(
If I understand the scatter/gather paradigm, aren't scatter and accumulate antonyms?

@mcwitt mcwitt Mar 14, 2022

Oh, maybe this is confusing. My intent was to generalize the naming when going from the existing K = N case to the more general K <= N case:

  • permute -> gather
  • inverse permute -> scatter

This is intended to be independent of whether we accumulate or assign to the result array. Does that make sense?
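In index-array terms, the convention can be pictured with a small numpy sketch (the names here are illustrative, not the kernel signatures):

```python
import numpy as np

N = 6
x = np.arange(N, dtype=float) * 10   # values for all N atoms
atom_idxs = np.array([4, 0, 2])      # the K <= N atoms in the subset

# "gather" (formerly "permute"): pull the subset into a dense K-length array
sorted_x = x[atom_idxs]              # [40., 0., 20.]

# "scatter" (formerly "inverse permute"): push per-subset results back to
# their positions in the full N-length output array
out = np.zeros(N)
out[atom_idxs] = sorted_x            # scatter-assign

# when K == N and atom_idxs is a permutation, gather/scatter reduce to
# permute/inverse-permute, which is where the old names came from
```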


I am more familiar with the fork-and-join paradigm; it seemed like fork and scatter are equivalent, and gather and join are equivalent, in which case this naming seemed backwards? I might just not understand the paradigm; it was unexpected to see a scatter (or what I interpreted as a broadcast/fork) used to reduce.

Doesn't matter as long as it is a consistent convention.


Ah, got it. I think this PR is consistently applying the convention described here, but I'm open to other suggestions.


The picture I had for "scatter" was https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html .

(the jax / xla documentation for scatter / gather is a bit less intuitive for me)

> This is intended to be independent of whether we accumulate or assign to the result array. Does that make sense?

Will need to double-check whether these are completely independent (the case that allows repeated target indices needs some reduction operation ("accumulate sum"), while the case that disallows repeated indices can allow direct assignment).

> The picture I had for "scatter" was https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html .

Err, posted that before I saw @mcwitt's link #660 (comment), didn't mean to override that. (I think the only difference is that the wiki definition and the current function assume no repeated scatter idxs, while pytorch_scatter reduces over any repeated scatter idxs. For the current use, only unique idxs make sense.)

> I might just not understand the paradigm, was unexpected to see a scatter (or what I interpreted as a broadcast/fork) to reduce.

I had misread the y[idxs[i]] += x[i] as a reduction, but @mcwitt clarified offline that it's always a single addition to whatever was in the output array before.
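The accumulate-vs-assign distinction from this thread can be sketched with numpy, where `np.add.at` stands in for the kernel's scatter-accumulate (an illustration, not the CUDA implementation):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])

# unique indices: scatter-accumulate does one addition per output slot,
# so out[idxs[i]] += x[i] just adds onto whatever was there before
out = np.full(4, 100.0)
idxs = np.array([0, 2, 3])
np.add.at(out, idxs, x)    # out -> [101., 100., 102., 103.]

# repeated indices: a reduction is unavoidable; plain fancy-index
# assignment would silently keep only one of the colliding values
rep = np.array([0, 0, 1])
out2 = np.zeros(2)
np.add.at(out2, rep, x)    # out2 -> [3., 3.]  (colliding writes summed)
out3 = np.zeros(2)
out3[rep] = x              # out3 -> [2., 3.]  (last write wins)
```

With unique indices (the only case that makes sense here), scatter-assign and scatter-accumulate differ only in whether the prior contents of the output array are kept.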

@maxentile maxentile left a comment


Looks good to me!

@proteneer proteneer left a comment


lgtm!

@mcwitt mcwitt force-pushed the all-pairs-on-subset branch from 15a17b7 to 34e9375 on March 16, 2022 14:57
@mcwitt mcwitt enabled auto-merge (rebase) March 16, 2022 15:00
@mcwitt mcwitt merged commit 7caf80d into master Mar 16, 2022
@mcwitt mcwitt deleted the all-pairs-on-subset branch March 16, 2022 15:54
@proteneer proteneer added the cr_cppcuda C++ and CUDA label Apr 11, 2022
4 participants