[ENH] xLSTMTime implementation #1709
Merged
38 commits
665825a  initial commit (phoeenniixx)
5e57d34  linting (phoeenniixx)
e498848  adding some tests and a little in debug in `sLSTM` structure (phoeenniixx)
38e4c9c  new baseclass implementation (phoeenniixx)
a72c8c6  Update __init__.py (phoeenniixx)
b3b3e55  little debug in `predict` method (phoeenniixx)
87f4ff4  trying the baseclass predict function and removing the test files (phoeenniixx)
a6b2da9  refactor `__init__.py` (phoeenniixx)
39e2b6f  Merge branch 'sktime:main' into xLSTMTime (phoeenniixx)
f67509a  linting (phoeenniixx)
46a9e74  Update layer.py (phoeenniixx)
7e7d915  docs (phoeenniixx)
31cd4de  linting (phoeenniixx)
c72bff9  Update __init__.py (phoeenniixx)
62e97ae  Update __init__.py (phoeenniixx)
93f0913  Merge branch 'sktime:main' into xLSTMTime (phoeenniixx)
66900bc  Merge branch 'sktime:main' into xLSTMTime (phoeenniixx)
acb23e7  Adding tests (phoeenniixx)
0b85284  Merge branch 'sktime:main' into xLSTMTime (phoeenniixx)
b01754e  Merge branch 'sktime:main' into xLSTMTime (phoeenniixx)
5e666b4  Merge branch 'sktime:main' into xLSTMTime (phoeenniixx)
0a149a7  Merge branch 'main' into pr/1709 (fkiraly)
9b21892  Merge branch 'main' into xLSTMTime (phoeenniixx)
2eda66f  refactor code (phoeenniixx)
942e717  add pkg class (phoeenniixx)
5556d71  linting (phoeenniixx)
2dca593  add docstrings and debug (phoeenniixx)
8adcb31  Merge branch 'main' into xLSTMTime (phoeenniixx)
1bc559c  add GH credits (phoeenniixx)
fd4b2ba  Merge branch 'main' into xLSTMTime (phoeenniixx)
6a7cc23  update documentation (phoeenniixx)
60d1651  Merge remote-tracking branch 'origin/xLSTMTime' into xLSTMTime (phoeenniixx)
96ec23d  add TriangularCausalMask (phoeenniixx)
6a40b7a  refactor files (phoeenniixx)
40beee8  Merge branch 'main' into xLSTMTime (phoeenniixx)
1cfaf9c  refactor files (phoeenniixx)
ed189de  Merge remote-tracking branch 'origin/xLSTMTime' into xLSTMTime (phoeenniixx)
7addfad  update models.rst (phoeenniixx)
pytorch_forecasting/layers/_recurrent/__init__.py
@@ -0,0 +1,21 @@
"""Recurrent Layers for Pytorch-Forecasting"""

from pytorch_forecasting.layers._recurrent._mlstm import (
    mLSTMCell,
    mLSTMLayer,
    mLSTMNetwork,
)
from pytorch_forecasting.layers._recurrent._slstm import (
    sLSTMCell,
    sLSTMLayer,
    sLSTMNetwork,
)

__all__ = [
    "mLSTMCell",
    "mLSTMLayer",
    "mLSTMNetwork",
    "sLSTMCell",
    "sLSTMLayer",
    "sLSTMNetwork",
]
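
Given this layout, downstream code can import all six classes from the `_recurrent` subpackage in one place. A small sketch (assuming only the package structure shown in this diff):

    # Imports follow the re-exports declared in __all__ above.
    from pytorch_forecasting.layers._recurrent import (
        mLSTMCell,
        mLSTMLayer,
        mLSTMNetwork,
        sLSTMCell,
        sLSTMLayer,
        sLSTMNetwork,
    )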
pytorch_forecasting/layers/_recurrent/_mlstm/__init__.py
@@ -0,0 +1,7 @@
"""mLSTM layer"""

from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell
from pytorch_forecasting.layers._recurrent._mlstm.layer import mLSTMLayer
from pytorch_forecasting.layers._recurrent._mlstm.network import mLSTMNetwork

__all__ = ["mLSTMCell", "mLSTMLayer", "mLSTMNetwork"]
pytorch_forecasting/layers/_recurrent/_mlstm/cell.py
@@ -0,0 +1,156 @@
import math

import torch
import torch.nn as nn


class mLSTMCell(nn.Module):
    """Implements the Matrix Long Short-Term Memory (mLSTM) Cell.

    Implements the mLSTM algorithm as described in the paper
    (https://arxiv.org/pdf/2407.10240).

    Parameters
    ----------
    input_size : int
        Size of the input feature vector.
    hidden_size : int
        Number of hidden units in the LSTM cell.
    dropout : float, optional
        Dropout rate applied to inputs and hidden states, by default 0.2.
    layer_norm : bool, optional
        If True, apply Layer Normalization to gates and interactions, by default True.

    Attributes
    ----------
    Wq : nn.Linear
        Linear layer for computing the query vector.
    Wk : nn.Linear
        Linear layer for computing the key vector.
    Wv : nn.Linear
        Linear layer for computing the value vector.
    Wi : nn.Linear
        Linear layer for the input gate.
    Wf : nn.Linear
        Linear layer for the forget gate.
    Wo : nn.Linear
        Linear layer for the output gate.
    dropout : nn.Dropout
        Dropout regularization layer.
    ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm
        Optional layer normalization layers for the respective computations.
    """

    def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.layer_norm = layer_norm

        self.Wq = nn.Linear(input_size, hidden_size)
        self.Wk = nn.Linear(input_size, hidden_size)
        self.Wv = nn.Linear(input_size, hidden_size)

        self.Wi = nn.Linear(input_size, hidden_size)
        self.Wf = nn.Linear(input_size, hidden_size)
        self.Wo = nn.Linear(input_size, hidden_size)

        self.dropout = nn.Dropout(dropout)

        if layer_norm:
            self.ln_q = nn.LayerNorm(hidden_size)
            self.ln_k = nn.LayerNorm(hidden_size)
            self.ln_v = nn.LayerNorm(hidden_size)
            self.ln_i = nn.LayerNorm(hidden_size)
            self.ln_f = nn.LayerNorm(hidden_size)
            self.ln_o = nn.LayerNorm(hidden_size)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, h_prev, c_prev, n_prev):
        """Compute the next hidden, cell, and normalized states in the mLSTM cell.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_size).
        h_prev : torch.Tensor
            Previous hidden state.
        c_prev : torch.Tensor
            Previous cell state.
        n_prev : torch.Tensor
            Previous normalized state.

        Returns
        -------
        tuple of torch.Tensor
            h : torch.Tensor
                Current hidden state.
            c : torch.Tensor
                Current cell state.
            n : torch.Tensor
                Current normalized state.
        """
        batch_size = x.size(0)
        assert (
            x.dim() == 2
        ), f"Input should be 2D (batch_size, input_size), got {x.dim()}D"
        assert h_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"h_prev shape mismatch: {h_prev.size()}"
        assert c_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"c_prev shape mismatch: {c_prev.size()}"
        assert n_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"n_prev shape mismatch: {n_prev.size()}"

        x = self.dropout(x)
        h_prev = self.dropout(h_prev)

        # Query, key, and value projections; keys are scaled as in attention.
        q = self.Wq(x)
        k = self.Wk(x) / math.sqrt(self.hidden_size)
        v = self.Wv(x)

        if self.layer_norm:
            q = self.ln_q(q)
            k = self.ln_k(k)
            v = self.ln_v(v)

        # Input, forget, and output gates.
        i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x))
        f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x))
        o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x))

        # Outer product of key and value forms the matrix-memory update.
        k_expanded = k.unsqueeze(-1)
        v_expanded = v.unsqueeze(-2)

        kv_interaction = k_expanded @ v_expanded

        kv_sum = kv_interaction.sum(dim=1)

        # Gated updates of the cell state and the normalizer state.
        c = f * c_prev + i * kv_sum
        n = f * n_prev + i * k

        # Normalize the normalizer state and emit the hidden state.
        epsilon = 1e-8
        normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon)
        h = o * self.tanh(c * normalized_n)

        return h, c, n

    def init_hidden(self, batch_size, device=None):
        """Initialize hidden, cell, and normalization states."""
        if device is None:
            device = next(self.parameters()).device
        shape = (batch_size, self.hidden_size)
        return (
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
        )
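
For orientation, a minimal usage sketch of the cell as defined above; the batch, feature, and sequence sizes are illustrative, not taken from the PR:

    import torch

    from pytorch_forecasting.layers._recurrent import mLSTMCell

    cell = mLSTMCell(input_size=8, hidden_size=16)

    # Start from zero states and step through a toy sequence of length 5.
    h, c, n = cell.init_hidden(batch_size=4)
    for t in range(5):
        x_t = torch.randn(4, 8)  # one timestep of input features
        h, c, n = cell(x_t, h, c, n)

    print(h.shape)  # torch.Size([4, 16])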
Why does this line get removed?
I didn't see any imports for `TriangularCausalMask`.
I actually found it in `layers._attention._full_attention`; it is not imported even in the `__init__` of `layers._attention`. I will add it to both locations. At first, I thought it didn't exist 😅
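
For context on the mask under discussion: the PR diff shown here does not include its definition, but a triangular causal mask in attention code commonly follows the Informer-style pattern sketched below. This is a sketch of that common pattern, not the actual class in `layers._attention._full_attention`:

    import torch

    class TriangularCausalMask:
        """Boolean mask that hides future positions in self-attention.

        Sketch of the common Informer-style pattern; the real
        implementation in this PR may differ in details.
        """

        def __init__(self, B, L, device="cpu"):
            # True above the main diagonal marks positions to mask out,
            # so position t can only attend to positions <= t.
            with torch.no_grad():
                self._mask = torch.triu(
                    torch.ones(B, 1, L, L, dtype=torch.bool), diagonal=1
                ).to(device)

        @property
        def mask(self):
            return self._mask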