@@ -51,17 +51,21 @@ def _compute_rspace(
51
51
charges : torch .Tensor ,
52
52
neighbor_indices : torch .Tensor ,
53
53
neighbor_distances : torch .Tensor ,
54
+ node_mask : torch .Tensor | None = None ,
55
+ pair_mask : torch .Tensor | None = None ,
54
56
) -> torch .Tensor :
55
57
# Compute the pair potential terms V(r_ij) for each pair of atoms (i,j)
56
58
# contained in the neighbor list
57
59
with profiler .record_function ("compute bare potential" ):
58
60
if self .potential .smearing is None :
59
61
if self .potential .exclusion_radius is None :
60
- potentials_bare = self .potential .from_dist (neighbor_distances )
61
- else :
62
- potentials_bare = self .potential .from_dist (neighbor_distances ) * (
63
- 1 - self .potential .f_cutoff (neighbor_distances )
62
+ potentials_bare = self .potential .from_dist (
63
+ neighbor_distances , pair_mask = pair_mask
64
64
)
65
+ else :
66
+ potentials_bare = self .potential .from_dist (
67
+ neighbor_distances , pair_mask = pair_mask
68
+ ) * (1 - self .potential .f_cutoff (neighbor_distances ))
65
69
else :
66
70
potentials_bare = self .potential .sr_from_dist (neighbor_distances )
67
71
@@ -109,6 +113,8 @@ def forward(
109
113
neighbor_indices : torch .Tensor ,
110
114
neighbor_distances : torch .Tensor ,
111
115
periodic : Optional [torch .Tensor ] = None ,
116
+ node_mask : torch .Tensor | None = None ,
117
+ pair_mask : torch .Tensor | None = None ,
112
118
):
113
119
r"""
114
120
Compute the potential "energy".
@@ -161,6 +167,8 @@ def forward(
161
167
charges = charges ,
162
168
neighbor_indices = neighbor_indices ,
163
169
neighbor_distances = neighbor_distances ,
170
+ node_mask = node_mask ,
171
+ pair_mask = pair_mask ,
164
172
)
165
173
166
174
if self .potential .smearing is None :
0 commit comments