diff --git a/spex/radial/simple/__init__.py b/spex/radial/simple/__init__.py index 8f17ccf..094fd34 100644 --- a/spex/radial/simple/__init__.py +++ b/spex/radial/simple/__init__.py @@ -1,4 +1,5 @@ from .bernstein import Bernstein +from .gaussian import Gaussian from .simple import Simple -__all__ = [Bernstein, Simple] +__all__ = [Gaussian, Bernstein, Simple] diff --git a/spex/radial/simple/gaussian.py b/spex/radial/simple/gaussian.py new file mode 100644 index 0000000..e8c07c4 --- /dev/null +++ b/spex/radial/simple/gaussian.py @@ -0,0 +1,71 @@ +import numpy as np +import torch + +import scipy + +from .simple import Simple + + +class Gaussian(Simple): + """ + Gaussian radial basis functions. + + The Gaussian basis functions are defined as ``exp(-gamma * (r - k * l / (K - 1)) ** 2)``, + where ``gamma = sqrt(2 * K) * (K - 1)``, l is the cutoff distance and ``k`` is the degree + of the basis. + + The basis is optionally be transformed with a learned linear layer (``trainable=True``), + optionally with a separate transformation per degree (``per_degree=True``). + The target number of features is specified by ``num_features``. + + Attributes: + cutoff (Tensor): Cutoff distance. + max_angular (int): Maximum spherical harmonic order. + n_per_l (list): Number of features per degree. + + """ + + # implementation heavily inspired by e3x. thanks! + + def __init__(self, *args, **kwargs): + """Initialise the Gaussian basis. + + Args: + cutoff (float): Cutoff distance. + num_radial (int): Number of radial basis functions. + max_angular (int): Maximum spherical harmonic order. + trainable (bool, optional): Whether a learned linear transformation is + applied. + per_degree (bool, optional): Whether to have a separate learned transform + per degree. + num_features (int, optional): Target number of features for learned + transformation. Defaults to ``num_radial``. + + """ + super().__init__(*args, **kwargs) + + K = torch.tensor(self.num_radial) + gamma = torch.sqrt(2 * K) * (K - 1) + + self.register_buffer("K", K) + self.register_buffer("gamma", gamma) + + def expand(self, r): + """Compute the Bernstein polynomial basis. + + Args: + r (Tensor): Input distances of shape ``[pair]``. + + Returns: + Expansion of shape ``[pair, num_radial]``. + """ + + r = r.unsqueeze(-1) + gamma = self.gamma + K = self.K + k = torch.arange(K, dtype=torch.float32) + y = torch.exp(-gamma/self.cutoff * (r - (k * self.cutoff) / (K - 1)) ** 2) + + return y + + diff --git a/tests/test_gaussian.py b/tests/test_gaussian.py new file mode 100644 index 0000000..27be98b --- /dev/null +++ b/tests/test_gaussian.py @@ -0,0 +1,54 @@ +import numpy as np +import torch + +from unittest import TestCase + + +class TestGaussian(TestCase): + """Basic test suite for the Gaussian class.""" + + def setUp(self): + self.num_radial = 128 + self.cutoff = 5.0 + self.max_angular = 3 + self.num_features = None + self.trainable = False + self.per_degree = False + + self.r = np.random.random(25) + + def test_jit(self): + """Test if Gaussian class works with TorchScript.""" + from spex.radial.simple import Gaussian + + radial = Gaussian( + cutoff=self.cutoff, + num_radial=self.num_radial, + max_angular=self.max_angular, + num_features=self.num_features, + trainable=self.trainable, + per_degree=self.per_degree, + ) + radial = torch.jit.script(radial) + radial(torch.tensor(self.r, dtype=torch.float32)) + + def test_hardcoded(self): + """Test Gaussian class with hardcoded parameters.""" + from spex.radial.simple import Gaussian + + radial = Gaussian( + cutoff=self.cutoff, + num_radial=self.num_radial, + max_angular=self.max_angular, + num_features=self.num_features, + ) + + assert radial.cutoff == self.cutoff + assert radial.num_radial == self.num_radial + assert radial.max_angular == self.max_angular + + # Validate function output against reference implementation + torch_r = torch.tensor(self.r, dtype=torch.float32) + torch_output = radial.expand(torch_r).detach().numpy() + + np.testing.assert_allclose(torch_output.shape, (self.r.shape[0], self.num_radial)) \ No newline at end of file