Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion spex/radial/simple/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .bernstein import Bernstein
from .gaussian import Gaussian
from .simple import Simple

__all__ = [Bernstein, Simple]
__all__ = [Gaussian, Bernstein, Simple]
71 changes: 71 additions & 0 deletions spex/radial/simple/gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import torch

import scipy

from .simple import Simple


class Gaussian(Simple):
"""
Gaussian radial basis functions.

The Gaussian basis functions are defined as ``exp(-gamma * (r - k * l / (K - 1)) ** 2)``,
where ``gamma = sqrt(2 * K) * (K - 1)``, l is the cutoff distance and ``k`` is the degree
of the basis.

The basis is optionally be transformed with a learned linear layer (``trainable=True``),
optionally with a separate transformation per degree (``per_degree=True``).
The target number of features is specified by ``num_features``.

Attributes:
cutoff (Tensor): Cutoff distance.
max_angular (int): Maximum spherical harmonic order.
n_per_l (list): Number of features per degree.

"""

# implementation heavily inspired by e3x. thanks!

def __init__(self, *args, **kwargs):
"""Initialise the Gaussian basis.

Args:
cutoff (float): Cutoff distance.
num_radial (int): Number of radial basis functions.
max_angular (int): Maximum spherical harmonic order.
trainable (bool, optional): Whether a learned linear transformation is
applied.
per_degree (bool, optional): Whether to have a separate learned transform
per degree.
num_features (int, optional): Target number of features for learned
transformation. Defaults to ``num_radial``.

"""
super().__init__(*args, **kwargs)

K = torch.tensor(self.num_radial)
gamma = torch.sqrt(2 * K) * (K - 1)

self.register_buffer("K", K)
self.register_buffer("gamma", gamma)

def expand(self, r):
"""Compute the Bernstein polynomial basis.

Args:
r (Tensor): Input distances of shape ``[pair]``.

Returns:
Expansion of shape ``[pair, num_radial]``.
"""

r = r.unsqueeze(-1)
gamma = self.gamma
K = self.K
k = torch.arange(K, dtype=torch.float32)
y = torch.exp(-gamma/self.cutoff * (r - (k * self.cutoff) / (K - 1)) ** 2)

return y


54 changes: 54 additions & 0 deletions tests/test_gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import torch

from unittest import TestCase


class TestGaussian(TestCase):
"""Basic test suite for the Gaussian class."""

def setUp(self):
self.num_radial = 128
self.cutoff = 5.0
self.max_angular = 3
self.num_features = None
self.trainable = False
self.per_degree = False

self.r = np.random.random(25)

def test_jit(self):
"""Test if Gaussian class works with TorchScript."""
from spex.radial.simple import Gaussian

radial = Gaussian(
cutoff=self.cutoff,
num_radial=self.num_radial,
max_angular=self.max_angular,
num_features=self.num_features,
trainable=self.trainable,
per_degree=self.per_degree,
)
radial = torch.jit.script(radial)
radial(torch.tensor(self.r, dtype=torch.float32))

def test_hardcoded(self):
"""Test Gaussian class with hardcoded parameters."""
from spex.radial.simple import Gaussian

radial = Gaussian(
cutoff=self.cutoff,
num_radial=self.num_radial,
max_angular=self.max_angular,
num_features=self.num_features,
)

assert radial.cutoff == self.cutoff
assert radial.num_radial == self.num_radial
assert radial.max_angular == self.max_angular

# Validate function output against reference implementation
torch_r = torch.tensor(self.r, dtype=torch.float32)
torch_output = radial.expand(torch_r).detach().numpy()

np.testing.assert_allclose(torch_output.shape, (self.r.shape[0], self.num_radial))