Skip to content

Commit bfa3cdb

Browse files
committed
Add pair_mask parameter to potential calculations in Calculator and CoulombPotential
1 parent e9cc2a2 commit bfa3cdb

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

src/torchpme/calculators/calculator.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,21 @@ def _compute_rspace(
5151
charges: torch.Tensor,
5252
neighbor_indices: torch.Tensor,
5353
neighbor_distances: torch.Tensor,
54+
node_mask: torch.Tensor | None = None,
55+
pair_mask: torch.Tensor | None = None,
5456
) -> torch.Tensor:
5557
# Compute the pair potential terms V(r_ij) for each pair of atoms (i,j)
5658
# contained in the neighbor list
5759
with profiler.record_function("compute bare potential"):
5860
if self.potential.smearing is None:
5961
if self.potential.exclusion_radius is None:
60-
potentials_bare = self.potential.from_dist(neighbor_distances)
61-
else:
62-
potentials_bare = self.potential.from_dist(neighbor_distances) * (
63-
1 - self.potential.f_cutoff(neighbor_distances)
62+
potentials_bare = self.potential.from_dist(
63+
neighbor_distances, pair_mask=pair_mask
6464
)
65+
else:
66+
potentials_bare = self.potential.from_dist(
67+
neighbor_distances, pair_mask=pair_mask
68+
) * (1 - self.potential.f_cutoff(neighbor_distances))
6569
else:
6670
potentials_bare = self.potential.sr_from_dist(neighbor_distances)
6771

@@ -109,6 +113,8 @@ def forward(
109113
neighbor_indices: torch.Tensor,
110114
neighbor_distances: torch.Tensor,
111115
periodic: Optional[torch.Tensor] = None,
116+
node_mask: torch.Tensor | None = None,
117+
pair_mask: torch.Tensor | None = None,
112118
):
113119
r"""
114120
Compute the potential "energy".
@@ -161,6 +167,8 @@ def forward(
161167
charges=charges,
162168
neighbor_indices=neighbor_indices,
163169
neighbor_distances=neighbor_distances,
170+
node_mask=node_mask,
171+
pair_mask=pair_mask,
164172
)
165173

166174
if self.potential.smearing is None:

src/torchpme/potentials/coulomb.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,21 @@ def __init__(
3939
):
4040
super().__init__(smearing, exclusion_radius, exclusion_degree)
4141

42-
def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
42+
def from_dist(
43+
self, dist: torch.Tensor, pair_mask: torch.Tensor | None = None
44+
) -> torch.Tensor:
4345
"""
4446
Full :math:`1/r` potential as a function of :math:`r`.
4547
4648
:param dist: torch.tensor containing the distances at which the potential is to
4749
be evaluated.
4850
"""
49-
return 1.0 / dist
51+
result = 1.0 / dist.clamp(min=1e-12)
52+
53+
if pair_mask is not None:
54+
result = result * pair_mask # elementwise multiply, keeps shape fixed
55+
56+
return result
5057

5158
def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
5259
"""

0 commit comments

Comments
 (0)