Merged
Changes from all commits
38 commits
665825a
initial commit
phoeenniixx Nov 9, 2024
5e57d34
linting
phoeenniixx Nov 9, 2024
e498848
adding some tests and a little in debug in `sLSTM` structure
phoeenniixx Nov 9, 2024
38e4c9c
new baseclass implementation
phoeenniixx Dec 12, 2024
a72c8c6
Update __init__.py
phoeenniixx Dec 13, 2024
b3b3e55
little debug in `predict` method
phoeenniixx Dec 23, 2024
87f4ff4
trying the baseclass predict function and removing the test files
phoeenniixx Dec 24, 2024
a6b2da9
refactor `__init__.py`
phoeenniixx Jan 6, 2025
39e2b6f
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Jan 6, 2025
f67509a
linting
phoeenniixx Jan 6, 2025
46a9e74
Update layer.py
phoeenniixx Jan 6, 2025
7e7d915
docs
phoeenniixx Jan 6, 2025
31cd4de
linting
phoeenniixx Jan 6, 2025
c72bff9
Update __init__.py
phoeenniixx Jan 6, 2025
62e97ae
Update __init__.py
phoeenniixx Jan 6, 2025
93f0913
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Jan 13, 2025
66900bc
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Jan 21, 2025
acb23e7
Adding tests
phoeenniixx Jan 21, 2025
0b85284
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Jan 24, 2025
b01754e
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Feb 10, 2025
5e666b4
Merge branch 'sktime:main' into xLSTMTime
phoeenniixx Mar 7, 2025
0a149a7
Merge branch 'main' into pr/1709
fkiraly Jun 6, 2025
9b21892
Merge branch 'main' into xLSTMTime
phoeenniixx Jul 29, 2025
2eda66f
refactor code
phoeenniixx Jul 30, 2025
942e717
add pkg class
phoeenniixx Jul 30, 2025
5556d71
linting
phoeenniixx Jul 30, 2025
2dca593
add docstrings and debug
phoeenniixx Jul 31, 2025
8adcb31
Merge branch 'main' into xLSTMTime
phoeenniixx Jul 31, 2025
1bc559c
add GH credits
phoeenniixx Jul 31, 2025
fd4b2ba
Merge branch 'main' into xLSTMTime
phoeenniixx Jul 31, 2025
6a7cc23
update documentation
phoeenniixx Jul 31, 2025
60d1651
Merge remote-tracking branch 'origin/xLSTMTime' into xLSTMTime
phoeenniixx Jul 31, 2025
96ec23d
add TriangularCausalMask
phoeenniixx Jul 31, 2025
6a40b7a
refactor files
phoeenniixx Aug 5, 2025
40beee8
Merge branch 'main' into xLSTMTime
phoeenniixx Aug 5, 2025
1cfaf9c
refactor files
phoeenniixx Aug 6, 2025
ed189de
Merge remote-tracking branch 'origin/xLSTMTime' into xLSTMTime
phoeenniixx Aug 6, 2025
7addfad
update models.rst
phoeenniixx Aug 6, 2025
3 changes: 2 additions & 1 deletion docs/source/models.rst
@@ -30,7 +30,8 @@ and you should take into account. Here is an overview over the pros and cons of
:py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1
:py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3
:py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4
:py:class:`~pytorch_forecasting.model.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3
:py:class:`~pytorch_forecasting.models.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3
:py:class:`~pytorch_forecasting.models.xlstm.xLSTMTime`, "x", "x", "x", "", "", "", "", "x", "", 3

.. [#deepvar] Accounting for correlations using a multivariate loss function which converts the network into a DeepVAR model.

26 changes: 24 additions & 2 deletions pytorch_forecasting/layers/__init__.py
@@ -2,7 +2,12 @@
Architectural deep learning layers from `nn.Module`.
"""

from pytorch_forecasting.layers._attention import AttentionLayer, FullAttention
from pytorch_forecasting.layers._attention import (
AttentionLayer,
FullAttention,
TriangularCausalMask,
)
from pytorch_forecasting.layers._decomposition import SeriesDecomposition
from pytorch_forecasting.layers._embeddings import (
DataEmbedding_inverted,
EnEmbedding,
@@ -15,15 +20,32 @@
from pytorch_forecasting.layers._output._flatten_head import (
FlattenHead,
)
from pytorch_forecasting.layers._recurrent._mlstm import (
mLSTMCell,
mLSTMLayer,
mLSTMNetwork,
)
from pytorch_forecasting.layers._recurrent._slstm import (
sLSTMCell,
sLSTMLayer,
sLSTMNetwork,
)

__all__ = [
"FullAttention",
"TriangularCausalMask",
Collaborator: why does this line get removed?

Member Author: I didn't see any imports for TriangularCausalMask.

Member Author: I actually found it in layers._attention._full_attention; it is not imported even in the __init__ of layers._attention. I will add it to both locations. At first, I thought it didn't exist 😅
"AttentionLayer",
"TriangularCausalMask",
"DataEmbedding_inverted",
"EnEmbedding",
"PositionalEmbedding",
"Encoder",
"EncoderLayer",
"FlattenHead",
"mLSTMCell",
"mLSTMLayer",
"mLSTMNetwork",
"sLSTMCell",
"sLSTMLayer",
"sLSTMNetwork",
"SeriesDecomposition",
]
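For reference, a minimal import sketch of the names newly exposed at the package level by this change, based only on the `__all__` list above (module paths and export names are taken from this diff, nothing beyond it):

```python
# Smoke-test sketch: the new recurrent and attention exports should resolve
# directly from the top-level layers namespace after this change.
from pytorch_forecasting.layers import (
    AttentionLayer,
    FullAttention,
    TriangularCausalMask,
    mLSTMCell,
    mLSTMLayer,
    mLSTMNetwork,
    sLSTMCell,
    sLSTMLayer,
    sLSTMNetwork,
)

# The imports above are classes; instantiate them as needed.
print(mLSTMCell, sLSTMNetwork, TriangularCausalMask)
```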
7 changes: 5 additions & 2 deletions pytorch_forecasting/layers/_attention/__init__.py
@@ -3,6 +3,9 @@
"""

from pytorch_forecasting.layers._attention._attention_layer import AttentionLayer
from pytorch_forecasting.layers._attention._full_attention import FullAttention
from pytorch_forecasting.layers._attention._full_attention import (
FullAttention,
TriangularCausalMask,
)

__all__ = ["AttentionLayer", "FullAttention"]
__all__ = ["AttentionLayer", "FullAttention", "TriangularCausalMask"]
21 changes: 21 additions & 0 deletions pytorch_forecasting/layers/_recurrent/__init__.py
@@ -0,0 +1,21 @@
"""Recurrent Layers for Pytorch-Forecasting"""

from pytorch_forecasting.layers._recurrent._mlstm import (
mLSTMCell,
mLSTMLayer,
mLSTMNetwork,
)
from pytorch_forecasting.layers._recurrent._slstm import (
sLSTMCell,
sLSTMLayer,
sLSTMNetwork,
)

__all__ = [
"mLSTMCell",
"mLSTMLayer",
"mLSTMNetwork",
"sLSTMCell",
"sLSTMLayer",
"sLSTMNetwork",
]
7 changes: 7 additions & 0 deletions pytorch_forecasting/layers/_recurrent/_mlstm/__init__.py
@@ -0,0 +1,7 @@
"""mLSTM layer"""

from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell
from pytorch_forecasting.layers._recurrent._mlstm.layer import mLSTMLayer
from pytorch_forecasting.layers._recurrent._mlstm.network import mLSTMNetwork

__all__ = ["mLSTMCell", "mLSTMLayer", "mLSTMNetwork"]
156 changes: 156 additions & 0 deletions pytorch_forecasting/layers/_recurrent/_mlstm/cell.py
@@ -0,0 +1,156 @@
import math

import torch
import torch.nn as nn


class mLSTMCell(nn.Module):
"""Implements the Matrix Long Short-Term Memory (mLSTM) Cell.

Implements the mLSTM algorithm as described in the xLSTMTime paper
(https://arxiv.org/pdf/2407.10240).

Parameters
----------
input_size : int
Size of the input feature vector.
hidden_size : int
Number of hidden units in the LSTM cell.
dropout : float, optional
Dropout rate applied to inputs and hidden states, by default 0.2.
layer_norm : bool, optional
If True, apply Layer Normalization to gates and interactions, by default True.

Attributes
----------
Wq : nn.Linear
Linear layer for computing the query vector.
Wk : nn.Linear
Linear layer for computing the key vector.
Wv : nn.Linear
Linear layer for computing the value vector.
Wi : nn.Linear
Linear layer for the input gate.
Wf : nn.Linear
Linear layer for the forget gate.
Wo : nn.Linear
Linear layer for the output gate.
dropout : nn.Dropout
Dropout regularization layer.
ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm
Optional layer normalization layers for respective computations.
"""

def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.layer_norm = layer_norm

self.Wq = nn.Linear(input_size, hidden_size)
self.Wk = nn.Linear(input_size, hidden_size)
self.Wv = nn.Linear(input_size, hidden_size)

self.Wi = nn.Linear(input_size, hidden_size)
self.Wf = nn.Linear(input_size, hidden_size)
self.Wo = nn.Linear(input_size, hidden_size)

self.dropout = nn.Dropout(dropout)

if layer_norm:
self.ln_q = nn.LayerNorm(hidden_size)
self.ln_k = nn.LayerNorm(hidden_size)
self.ln_v = nn.LayerNorm(hidden_size)
self.ln_i = nn.LayerNorm(hidden_size)
self.ln_f = nn.LayerNorm(hidden_size)
self.ln_o = nn.LayerNorm(hidden_size)

self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()

def forward(self, x, h_prev, c_prev, n_prev):
"""Compute the next hidden, cell, and normalized states in the mLSTM cell.

Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, input_size)
h_prev : torch.Tensor
Previous hidden state
c_prev : torch.Tensor
Previous cell state
n_prev : torch.Tensor
Previous normalized state

Returns
-------
tuple of torch.Tensor:
h : torch.Tensor
Current hidden state
c : torch.Tensor
Current cell state
n : torch.Tensor
Current normalized state
"""

batch_size = x.size(0)
assert (
x.dim() == 2
), f"Input should be 2D (batch_size, input_size), got {x.dim()}D"
assert h_prev.size() == (
batch_size,
self.hidden_size,
), f"h_prev shape mismatch: {h_prev.size()}"
assert c_prev.size() == (
batch_size,
self.hidden_size,
), f"c_prev shape mismatch: {c_prev.size()}"
assert n_prev.size() == (
batch_size,
self.hidden_size,
), f"n_prev shape mismatch: {n_prev.size()}"

x = self.dropout(x)
h_prev = self.dropout(h_prev)

q = self.Wq(x)
k = self.Wk(x) / math.sqrt(self.hidden_size)
v = self.Wv(x)

if self.layer_norm:
q = self.ln_q(q)
k = self.ln_k(k)
v = self.ln_v(v)

i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x))
f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x))
o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x))

k_expanded = k.unsqueeze(-1)
v_expanded = v.unsqueeze(-2)

kv_interaction = k_expanded @ v_expanded

kv_sum = kv_interaction.sum(dim=1)

c = f * c_prev + i * kv_sum
n = f * n_prev + i * k

epsilon = 1e-8
normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon)
h = o * self.tanh(c * normalized_n)

return h, c, n

def init_hidden(self, batch_size, device=None):
"""
Initialize hidden, cell, and normalization states.
"""
if device is None:
device = next(self.parameters()).device
shape = (batch_size, self.hidden_size)
return (
torch.zeros(shape, device=device),
torch.zeros(shape, device=device),
torch.zeros(shape, device=device),
)
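As a worked illustration of the recurrence above (the cell state is updated from the key-value interaction, the normalizer state rescales the readout, and the hidden state is the gated output), here is a short usage sketch of `mLSTMCell`. It relies only on the constructor, `init_hidden`, and `forward` signatures shown in this file; the sizes are arbitrary example values.

```python
import torch

from pytorch_forecasting.layers import mLSTMCell

# Arbitrary example dimensions, chosen only for illustration.
batch_size, input_size, hidden_size = 4, 8, 16

cell = mLSTMCell(input_size=input_size, hidden_size=hidden_size, dropout=0.2, layer_norm=True)

# init_hidden returns zero-initialized hidden, cell, and normalizer states,
# each of shape (batch_size, hidden_size).
h, c, n = cell.init_hidden(batch_size)

# One recurrent step over a single time slice of shape (batch_size, input_size).
x_t = torch.randn(batch_size, input_size)
h, c, n = cell(x_t, h, c, n)

print(h.shape, c.shape, n.shape)  # each torch.Size([4, 16])
```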