Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
sphinx
sphinx_book_theme
sphinx_book_theme
numpydoc
8 changes: 7 additions & 1 deletion doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = ["sphinx.ext.todo", "sphinx.ext.viewcode", "sphinx.ext.autodoc"]
extensions = [
"sphinx.ext.todo",
"sphinx.ext.viewcode",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"numpydoc",
]

templates_path = ["_templates"]
exclude_patterns = []
Expand Down
119 changes: 85 additions & 34 deletions nlgm/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,26 @@


class Encoder(nn.Module):
def __init__(self, hidden_dim=20, latent_dim=2):
"""
Encoder class for the geometric autoencoder.
"""
Encoder class for the geometric autoencoder.

Args:
hidden_dim (int): Number of hidden dimensions.
latent_dim (int): Number of latent dimensions.
"""
Parameters
----------
hidden_dim : int
Number of hidden dimensions.
latent_dim : int
Number of latent dimensions.

Methods
-------
forward

Attributes
----------
encoder
"""

def __init__(self, hidden_dim=20, latent_dim=2):
super(Encoder, self).__init__()

self.encoder = nn.Sequential(
Expand Down Expand Up @@ -45,25 +57,41 @@ def forward(self, x):
"""
Forward pass of the encoder.

Args:
x (torch.Tensor): Input tensor.
Parameters
----------
x : torch.Tensor
Input tensor.

Returns:
torch.Tensor: Encoded output tensor.
Returns
-------
tensor : torch.Tensor
Encoded output tensor.
"""
z = self.encoder(x)
return z


class Decoder(nn.Module):
def __init__(self, hidden_dim=20, latent_dim=2):
"""
Decoder class for the geometric autoencoder.
"""
Decoder class for the geometric autoencoder.

Args:
hidden_dim (int): Number of hidden dimensions.
latent_dim (int): Number of latent dimensions.
"""
Parameters
----------
hidden_dim : int
Number of hidden dimensions.
latent_dim : int
Number of latent dimensions.

Methods
-------
forward

Attributes
----------
decoder
"""

def __init__(self, hidden_dim=20, latent_dim=2):
super(Decoder, self).__init__()

self.decoder = nn.Sequential(
Expand All @@ -88,26 +116,45 @@ def forward(self, z):
"""
Forward pass of the decoder.

Args:
z (torch.Tensor): Encoded input tensor.
Parameters
----------
z : torch.Tensor
Encoded input tensor.

Returns:
torch.Tensor: Decoded output tensor.
Returns
-------
tensor : torch.Tensor
Decoded output tensor.
"""
x_recon = self.decoder(z)
return x_recon


class GeometricAutoencoder(nn.Module):
def __init__(self, signature, hidden_dim=20, latent_dim=2):
"""
Geometric Autoencoder class.
"""
Geometric Autoencoder class.

Parameters
----------
signature : list
List of signature dimensions.
hidden_dim : int
Number of hidden dimensions.
latent_dim : int
Number of latent dimensions.

Methods
-------
forward

Attributes
----------
geometry
encoder
decoder
"""

Args:
signature (list): List of signature dimensions.
hidden_dim (int): Number of hidden dimensions.
latent_dim (int): Number of latent dimensions.
"""
def __init__(self, signature, hidden_dim=20, latent_dim=2):
super(GeometricAutoencoder, self).__init__()
self.geometry = ProductManifold(signature)
self.encoder = Encoder(hidden_dim, latent_dim)
Expand All @@ -117,11 +164,15 @@ def forward(self, x):
"""
Forward pass of the geometric autoencoder.

Args:
x (torch.Tensor): Input tensor.
Parameters
----------
x : torch.Tensor
Input tensor.

Returns:
torch.Tensor: Decoded output tensor.
Returns
-------
tensor : torch.Tensor
Decoded output tensor.
"""
z = self.encoder(x)
z = self.geometry.exponential_map(z)
Expand Down