Allow computing all-pairs potential on subset of atoms, add consistency check #660
Conversation
tests/nonbonded/test_consistency.py (outdated)

```python
if interpolated:
    # TODO: why does interpolation break bitwise equivalence?
    np.testing.assert_allclose(du_dp_test, du_dp_ref, rtol=1e-10, atol=1e-10)
```
Not completely sure yet why the du_dps aren't bitwise equivalent (only in the interpolated case)
Resolved offline in discussion with @proteneer. Bitwise equivalence for the interpolated case is not possible with the current structure of the code, ultimately because the distributive property of multiplication breaks bitwise equivalence (even when summation is done in fixed point); i.e., c * fixed_sum(a, b) != fixed_sum(c * a, c * b)
(see example in gist).
Added a note in 15a17b7.
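The distributive-property point above can be demonstrated with a minimal sketch (illustrative only, not code from this PR; `FIXED_SCALE`, `to_fixed`, `from_fixed`, and `fixed_sum` are hypothetical names): addition in fixed point is exact and order-independent, but multiplying by `c` before vs. after the fixed-point sum rounds each product to fixed point separately, so the results can differ in the last bits.

```python
# Illustrative sketch: why c * fixed_sum(a, b) != fixed_sum(c*a, c*b)
# even though the fixed-point summation itself is exact.
FIXED_SCALE = 2**32  # hypothetical fixed-point scale factor

def to_fixed(x: float) -> int:
    # round a double to the nearest fixed-point value
    return int(round(x * FIXED_SCALE))

def from_fixed(n: int) -> float:
    return n / FIXED_SCALE

def fixed_sum(*xs: float) -> float:
    # exact integer summation of the rounded fixed-point values
    return from_fixed(sum(to_fixed(x) for x in xs))

a, b, c = 0.1, 0.2, 3.7
lhs = c * fixed_sum(a, b)        # multiply after the fixed-point sum
rhs = fixed_sum(c * a, c * b)    # multiply before; c*a, c*b rounded separately

assert lhs != rhs                # not bitwise equivalent...
assert abs(lhs - rhs) < 1e-9     # ...but numerically very close
```

This is consistent with the test above needing `assert_allclose` with tight tolerances rather than exact equality in the interpolated case.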
```diff
@@ -376,26 +407,31 @@ void NonbondedAllPairs<RealType, Interpolated>::execute_device(
   // coords are N,3
   if (d_du_dx) {
-    k_inv_permute_accum<<<dimGrid, tpb, 0, stream>>>(N, d_perm_, d_sorted_du_dx_, d_du_dx);
+    k_scatter_accum<<<dim3(ceil_divide(K_, tpb), 3, 1), tpb, 0, stream>>>(
```
If I understand the scatter/gather paradigm, scatter and accumulate are antonyms?
Oh, maybe this is confusing. My intent was to generalize the naming when going from the existing K=N case to the more general K<=N case:
- permute -> gather
- inverse permute -> scatter
This is intended to be independent of whether we accumulate or assign to the result array. Does that make sense?
I am more familiar with the fork/join paradigm, and it seemed like fork and scatter are equivalent, and gather and join are equivalent. In which case this seemed backwards? I might just not understand the paradigm; it was unexpected to see a scatter (or what I interpreted as a broadcast/fork) used to reduce.

Doesn't matter as long as it is a consistent convention.
Ah, got it. I think this PR is consistently applying the convention described here, but I'm open to other suggestions.
The picture I had for "scatter" was https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html .
(the jax / xla documentation for scatter / gather is a bit less intuitive for me)
> This is intended to be independent of whether we accumulate or assign to the result array. Does that make sense?

Will need to double-check whether these are completely independent (the case that allows repeated target indices needs some reduction operation ("accumulate sum"), while the case that disallows repeated indices can allow direct assignment).
> The picture I had for "scatter" was https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html .

Err, posted that before I saw @mcwitt's link #660 (comment); didn't mean to override that. (I think the only difference is that the wiki definition and the current function assume no repeated scatter idxs, while pytorch_scatter reduces over any repeated scatter idxs. For the current use, only unique idxs make sense.)

> I might just not understand the paradigm, was unexpected to see a scatter (or what I interpreted as a broadcast/fork) to reduce.

I had misread the `y[idxs[i]] += x[i]` as a reduction, but @mcwitt clarified offline that it's always a single addition to whatever was in the output array before.
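For concreteness, here is an illustrative NumPy sketch (not code from this PR) of "gather" vs. "scatter-accumulate" in the sense used above, with unique indices so each touched output slot receives exactly one addition:

```python
import numpy as np

idxs = np.array([4, 0, 2])        # unique target/source indices (K <= N)
x = np.array([10.0, 20.0, 30.0])  # compacted array of length K

# gather: out[i] = src[idxs[i]]
src = np.arange(5, dtype=float)
gathered = src[idxs]              # array([4., 0., 2.])

# scatter-accumulate: out[idxs[i]] += x[i]; with unique idxs this is a
# single addition into each touched slot, not a reduction over repeats
out = np.ones(5)
np.add.at(out, idxs, x)           # out is now [21., 1., 31., 1., 11.]
```

With repeated indices, `np.add.at` would reduce over the repeats (the pytorch_scatter behavior); with unique indices, as here, accumulate vs. assign is the only remaining distinction.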
Looks good to me!
lgtm!
Better reflects the generality of these operations; there is no restriction that idxs and the input array have the same length.
Also consolidates the interpolated and non-interpolated cases.
Related: #472

Adds `atom_idxs` to `NonbondedAllPairs`; this is used to select a subset of atoms for computing the all-pairs potential. Adds a consistency check comparing the `Nonbonded(exclusions, scales)` potential with the sum of:

- `NonbondedAllPairs(host)`
- `NonbondedAllPairs(ligand)`
- `NonbondedInteractionGroup(host, ligand)`
- `NonbondedPairList(exclusions, scales)`
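The host/ligand decomposition underlying this check can be sketched on a toy pairwise potential (illustrative only; a bare 1/r "potential" with no exclusions, cutoff, or parameters, so only the all-pairs/interaction-group split is exercised):

```python
import numpy as np

def all_pairs(coords):
    # sum over i < j of 1 / |r_i - r_j|
    n = len(coords)
    u = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            u += 1.0 / np.linalg.norm(coords[i] - coords[j])
    return u

def interaction_group(a, b):
    # sum over pairs with one atom in each group
    return sum(1.0 / np.linalg.norm(ra - rb) for ra in a for rb in b)

rng = np.random.default_rng(0)
host = rng.uniform(size=(5, 3))
ligand = rng.uniform(size=(3, 3)) + 2.0  # offset so groups don't overlap

# every pair is either host-host, ligand-ligand, or host-ligand
lhs = all_pairs(np.concatenate([host, ligand]))
rhs = all_pairs(host) + all_pairs(ligand) + interaction_group(host, ligand)
np.testing.assert_allclose(lhs, rhs)
```

The identity is exact mathematically but only approximate in floating point (summation order differs), matching the tolerance-based comparison used in the test file above.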
Notes for review

The rename commits are purely string substitutions and should contain no other changes; these were mainly to reflect that we're no longer just doing permutations, since the potential can be computed on a subset of atoms.