Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/07-lode-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,9 @@ def forward(
neighbor_indices: Optional[torch.Tensor] = None,
neighbor_distances: Optional[torch.Tensor] = None,
periodic: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Update meshes
assert self.potential.smearing is not None # otherwise mypy complains
Expand Down
76 changes: 63 additions & 13 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Optional

import torch

Expand All @@ -9,15 +9,17 @@ def _validate_parameters(
positions: torch.Tensor,
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
periodic: Union[torch.Tensor, None] = None,
periodic: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
) -> None:
dtype = positions.dtype
device = positions.device

# check shape, dtype and device of positions
num_atoms = len(positions)
if list(positions.shape) != [len(positions), 3]:
num_atoms = positions.shape[-2]
if list(positions.shape) != [num_atoms, 3]:
raise ValueError(
"`positions` must be a tensor with shape [n_atoms, 3], got tensor "
f"with shape {list(positions.shape)}"
Expand All @@ -40,14 +42,6 @@ def _validate_parameters(
f"device of `cell` ({cell.device}) must be same as that of the `positions` class ({device})"
)

if smearing is not None and torch.equal(
cell.det(), torch.tensor(0.0, dtype=cell.dtype, device=cell.device)
):
raise ValueError(
"provided `cell` has a determinant of 0 and therefore is not valid for "
"periodic calculation"
)

# check shape, dtype & device of `charges`
if charges.dim() != 2:
raise ValueError(
Expand Down Expand Up @@ -120,3 +114,59 @@ def _validate_parameters(
f"device of `periodic` ({periodic.device}) must be same as that of "
f"the `positions` class ({device})"
)

if pair_mask is not None:
if pair_mask.shape != neighbor_indices[:, 0].shape:
raise ValueError(
"`pair_mask` must have the same shape as the number of neighbors, "
f"got tensor with shape {list(pair_mask.shape)} while the number of "
f"neighbors is {neighbor_indices.shape[0]}"
)

if pair_mask.device != device:
raise ValueError(
f"device of `pair_mask` ({pair_mask.device}) must be same as that "
f"of the `positions` class ({device})"
)

if pair_mask.dtype != torch.bool:
raise TypeError(
f"type of `pair_mask` ({pair_mask.dtype}) must be torch.bool"
)

if node_mask is not None:
if node_mask.shape != (num_atoms,):
raise ValueError(
"`node_mask` must have shape [n_atoms], got tensor with shape "
f"{list(node_mask.shape)} where n_atoms is {num_atoms}"
)

if node_mask.device != device:
raise ValueError(
f"device of `node_mask` ({node_mask.device}) must be same as that "
f"of the `positions` class ({device})"
)

if node_mask.dtype != torch.bool:
raise TypeError(
f"type of `node_mask` ({node_mask.dtype}) must be torch.bool"
)

if kvectors is not None:
if kvectors.shape[1] != 3:
raise ValueError(
"`kvectors` must be a tensor of shape [n_kvecs, 3], got "
f"tensor with shape {list(kvectors.shape)}"
)

if kvectors.device != device:
raise ValueError(
f"device of `kvectors` ({kvectors.device}) must be same as that of "
f"the `positions` class ({device})"
)

if kvectors.dtype != dtype:
raise TypeError(
f"type of `kvectors` ({kvectors.dtype}) must be same as that of the "
f"`positions` class ({dtype})"
)
31 changes: 25 additions & 6 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,24 @@ def _compute_rspace(
charges: torch.Tensor,
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
pair_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Compute the pair potential terms V(r_ij) for each pair of atoms (i,j)
# contained in the neighbor list
with profiler.record_function("compute bare potential"):
if self.potential.smearing is None:
if self.potential.exclusion_radius is None:
potentials_bare = self.potential.from_dist(neighbor_distances)
else:
potentials_bare = self.potential.from_dist(neighbor_distances) * (
1 - self.potential.f_cutoff(neighbor_distances)
potentials_bare = self.potential.from_dist(
neighbor_distances, pair_mask
)
else:
potentials_bare = self.potential.from_dist(
neighbor_distances, pair_mask
) * (1 - self.potential.f_cutoff(neighbor_distances, pair_mask))
else:
potentials_bare = self.potential.sr_from_dist(neighbor_distances)
potentials_bare = self.potential.sr_from_dist(
neighbor_distances, pair_mask
)

# Multiply the bare potential terms V(r_ij) with the corresponding charges
# of ``atom j'' to obtain q_j*V(r_ij). Since each atom j can be a neighbor of
Expand Down Expand Up @@ -109,6 +114,9 @@ def forward(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
periodic: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
):
r"""
Compute the potential "energy".
Expand Down Expand Up @@ -145,22 +153,31 @@ def forward(
:param periodic: optional torch.tensor of shape ``(3,)`` indicating which
directions are periodic (True) and which are not (False). If not
provided, full periodicity is assumed.
:param node_mask: Optional torch.tensor of shape ``(len(positions),)`` that
indicates which of the atoms are masked.
:param pair_mask: Optional torch.tensor containing a mask to be applied to the
result.
:param kvectors: Optional precomputed k-vectors to be used in the Fourier
space part of the calculation.
"""
_validate_parameters(
charges=charges,
cell=cell,
positions=positions,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
periodic=periodic,
pair_mask=pair_mask,
node_mask=node_mask,
kvectors=kvectors,
)

# Compute short-range (SR) part using a real space sum
potential_sr = self._compute_rspace(
charges=charges,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
pair_mask=pair_mask,
)

if self.potential.smearing is None:
Expand All @@ -171,6 +188,8 @@ def forward(
cell=cell,
positions=positions,
periodic=periodic,
kvectors=kvectors,
node_mask=node_mask,
)

return self.prefactor * (potential_sr + potential_lr)
1 change: 0 additions & 1 deletion src/torchpme/calculators/calculator_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def forward(
positions=positions,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_vectors.norm(dim=-1),
smearing=self.potential.smearing,
)

# Compute short-range (SR) part using a real space sum
Expand Down
24 changes: 15 additions & 9 deletions src/torchpme/calculators/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,23 @@ def _compute_kspace(
cell: torch.Tensor,
positions: torch.Tensor,
periodic: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Define k-space cutoff from required real-space resolution
k_cutoff = 2 * torch.pi / self.lr_wavelength
if kvectors is None:
k_cutoff = 2 * torch.pi / self.lr_wavelength

# Compute number of times each basis vector of the reciprocal space can be
# scaled until the cutoff is reached
basis_norms = torch.linalg.norm(cell, dim=1)
ns_float = k_cutoff * basis_norms / 2 / torch.pi
ns = torch.ceil(ns_float).long()
# Compute number of times each basis vector of the reciprocal space can be
# scaled until the cutoff is reached
basis_norms = torch.linalg.norm(cell, dim=1)
ns_float = k_cutoff * basis_norms / 2 / torch.pi
ns = torch.ceil(ns_float).long()

# Generate k-vectors and evaluate
kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)
knorm_sq = torch.sum(kvectors**2, dim=1)
# Generate k-vectors and evaluate
kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)

knorm_sq = torch.sum(kvectors**2, dim=-1)

# G(k) is the Fourier transform of the Coulomb potential
# generated by a Gaussian charge density
Expand Down Expand Up @@ -140,4 +144,6 @@ def _compute_kspace(
energy -= 2 * prefac * charge_tot * ivolume
# Compensate for double counting of pairs (i,j) and (j,i)
energy += self.potential.pbc_correction(periodic, positions, cell, charges)
if node_mask is not None:
energy = energy * node_mask.unsqueeze(-1)
return energy / 2
4 changes: 4 additions & 0 deletions src/torchpme/calculators/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,16 @@ def _compute_kspace(
cell: torch.Tensor,
positions: torch.Tensor,
periodic: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: Kernel function `G` and initialization of `MeshInterpolator` only depend
# on `cell`. Caching may save up to 15% but issues with AD need to be resolved.

# Compute number of times each basis vector of the reciprocal space can be
# scaled until the cutoff is reached
if node_mask is not None or kvectors is not None:
raise ValueError("Batching not implemented for mesh-based calculators")
ns = get_ns_mesh(cell, self.mesh_spacing)

self.mesh_interpolator.update(cell, ns)
Expand Down
2 changes: 2 additions & 0 deletions src/torchpme/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .kspace_filter import KSpaceFilter, KSpaceKernel, P3MKSpaceFilter
from .kvectors import (
compute_batched_kvectors,
generate_kvectors_for_ewald,
generate_kvectors_for_mesh,
get_ns_mesh,
Expand All @@ -26,6 +27,7 @@
"distances",
"generate_kvectors_for_ewald",
"generate_kvectors_for_mesh",
"compute_batched_kvectors",
"get_ns_mesh",
"gamma",
"gammaincc_over_powerlaw",
Expand Down
31 changes: 31 additions & 0 deletions src/torchpme/lib/kvectors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch.nn.utils.rnn import pad_sequence


def get_ns_mesh(cell: torch.Tensor, mesh_spacing: float):
Expand Down Expand Up @@ -133,3 +134,33 @@ def generate_kvectors_for_ewald(
calculators like PME.
"""
return _generate_kvectors(cell=cell, ns=ns, for_ewald=True).reshape(-1, 3)


def compute_batched_kvectors(
lr_wavelength: float,
cells: torch.Tensor,
) -> torch.Tensor:
r"""
Generate k-vectors for multiple systems in batches.

:param lr_wavelength: Spatial resolution used for the long-range (reciprocal space)
part of the Ewald sum. More concretely, all Fourier space vectors with a
wavelength >= this value will be kept. If not set to a global value, it will be
set to half the smearing parameter to ensure convergence of the
long-range part to a relative precision of 1e-5.
:param cell: torch.tensor of shape ``(B, 3, 3)``, where ``cell[i]`` is the i-th
basis vector of the unit cell for system i in the batch of size B.

"""
all_kvectors = []
k_cutoff = 2 * torch.pi / lr_wavelength
for cell in cells:
basis_norms = torch.linalg.norm(cell, dim=1)
ns_float = k_cutoff * basis_norms / 2 / torch.pi
ns = torch.ceil(ns_float).long()
kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)
all_kvectors.append(kvectors)
# We do not return masks here; instead, we rely on the fact that for the Coulomb
# potential, the k = 0 vector is ignored in the calculations and can therefore be
# safely padded with zeros.
return pad_sequence(all_kvectors, batch_first=True)
18 changes: 12 additions & 6 deletions src/torchpme/potentials/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,24 @@ def __init__(
else:
self.register_buffer("weights", initial_weights)

def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
potentials = [pot.from_dist(dist) for pot in self.potentials]
def from_dist(
self, dist: torch.Tensor, pair_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
potentials = [pot.from_dist(dist, pair_mask) for pot in self.potentials]
potentials = torch.stack(potentials, dim=-1)
return torch.inner(self.weights, potentials)

def sr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
potentials = [pot.sr_from_dist(dist) for pot in self.potentials]
def sr_from_dist(
self, dist: torch.Tensor, pair_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
potentials = [pot.sr_from_dist(dist, pair_mask) for pot in self.potentials]
potentials = torch.stack(potentials, dim=-1)
return torch.inner(self.weights, potentials)

def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
potentials = [pot.lr_from_dist(dist) for pot in self.potentials]
def lr_from_dist(
self, dist: torch.Tensor, pair_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
potentials = [pot.lr_from_dist(dist, pair_mask) for pot in self.potentials]
potentials = torch.stack(potentials, dim=-1)
return torch.inner(self.weights, potentials)

Expand Down
Loading