diff --git a/docs/source/models.rst b/docs/source/models.rst index 53b39cae5..cd9048b0a 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -30,7 +30,8 @@ and you should take into account. Here is an overview over the pros and cons of :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3 :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4 - :py:class:`~pytorch_forecasting.model.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3 + :py:class:`~pytorch_forecasting.models.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3 + :py:class:`~pytorch_forecasting.models.xlstm.xLSTMTime`, "x", "x", "x", "", "", "", "", "x", "", 3 .. [#deepvar] Accounting for correlations using a multivariate loss function which converts the network into a DeepVAR model. diff --git a/pytorch_forecasting/layers/__init__.py b/pytorch_forecasting/layers/__init__.py index 43a8db84c..ab9f2dad6 100644 --- a/pytorch_forecasting/layers/__init__.py +++ b/pytorch_forecasting/layers/__init__.py @@ -2,7 +2,12 @@ Architectural deep learning layers from `nn.Module`. """ -from pytorch_forecasting.layers._attention import AttentionLayer, FullAttention +from pytorch_forecasting.layers._attention import ( + AttentionLayer, + FullAttention, + TriangularCausalMask, +) +from pytorch_forecasting.layers._decomposition import SeriesDecomposition from pytorch_forecasting.layers._embeddings import ( DataEmbedding_inverted, EnEmbedding, @@ -15,15 +20,32 @@ from pytorch_forecasting.layers._output._flatten_head import ( FlattenHead, ) +from pytorch_forecasting.layers._recurrent._mlstm import ( + mLSTMCell, + mLSTMLayer, + mLSTMNetwork, +) +from pytorch_forecasting.layers._recurrent._slstm import ( + sLSTMCell, + sLSTMLayer, + sLSTMNetwork, +) __all__ = [ "FullAttention", - "TriangularCausalMask", "AttentionLayer", + "TriangularCausalMask", "DataEmbedding_inverted", "EnEmbedding", "PositionalEmbedding", "Encoder", "EncoderLayer", "FlattenHead", + "mLSTMCell", + "mLSTMLayer", + "mLSTMNetwork", + "sLSTMCell", + "sLSTMLayer", + "sLSTMNetwork", + "SeriesDecomposition", ] diff --git a/pytorch_forecasting/layers/_attention/__init__.py b/pytorch_forecasting/layers/_attention/__init__.py index cdfc6c3e2..f6827bae1 100644 --- a/pytorch_forecasting/layers/_attention/__init__.py +++ b/pytorch_forecasting/layers/_attention/__init__.py @@ -3,6 +3,9 @@ """ from pytorch_forecasting.layers._attention._attention_layer import AttentionLayer -from pytorch_forecasting.layers._attention._full_attention import FullAttention +from pytorch_forecasting.layers._attention._full_attention import ( + FullAttention, + TriangularCausalMask, +) -__all__ = ["AttentionLayer", "FullAttention"] +__all__ = ["AttentionLayer", "FullAttention", "TriangularCausalMask"] diff --git a/pytorch_forecasting/layers/_recurrent/__init__.py b/pytorch_forecasting/layers/_recurrent/__init__.py new file mode 100644 index 000000000..6bd58b976 --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/__init__.py @@ -0,0 +1,21 @@ +"""Recurrent Layers for Pytorch-Forecasting""" + +from pytorch_forecasting.layers._recurrent._mlstm import ( + mLSTMCell, + mLSTMLayer, + mLSTMNetwork, +) +from pytorch_forecasting.layers._recurrent._slstm import ( + sLSTMCell, + sLSTMLayer, + sLSTMNetwork, +) + +__all__ = [ + "mLSTMCell", + "mLSTMLayer", + 
"mLSTMNetwork", + "sLSTMCell", + "sLSTMLayer", + "sLSTMNetwork", +] diff --git a/pytorch_forecasting/layers/_recurrent/_mlstm/__init__.py b/pytorch_forecasting/layers/_recurrent/_mlstm/__init__.py new file mode 100644 index 000000000..812bd90d6 --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/_mlstm/__init__.py @@ -0,0 +1,7 @@ +"""mLSTM layer""" + +from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell +from pytorch_forecasting.layers._recurrent._mlstm.layer import mLSTMLayer +from pytorch_forecasting.layers._recurrent._mlstm.network import mLSTMNetwork + +__all__ = ["mLSTMCell", "mLSTMLayer", "mLSTMNetwork"] diff --git a/pytorch_forecasting/layers/_recurrent/_mlstm/cell.py b/pytorch_forecasting/layers/_recurrent/_mlstm/cell.py new file mode 100644 index 000000000..ae41cf129 --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/_mlstm/cell.py @@ -0,0 +1,156 @@ +import math + +import torch +import torch.nn as nn + + +class mLSTMCell(nn.Module): + """Implements the Matrix Long Short-Term Memory (mLSTM) Cell. + + Implements the mLSTM algorithm as described in the paper: + (https://arxiv.org/pdf/2407.10240). + + Parameters + ---------- + input_size : int + Size of the input feature vector. + hidden_size : int + Number of hidden units in the LSTM cell. + dropout : float, optional + Dropout rate applied to inputs and hidden states, by default 0.2. + layer_norm : bool, optional + If True, apply Layer Normalization to gates and interactions, by default True. + + Attributes + ---------- + Wq : nn.Linear + Linear layer for computing the query vector. + Wk : nn.Linear + Linear layer for computing the key vector. + Wv : nn.Linear + Linear layer for computing the value vector. + Wi : nn.Linear + Linear layer for the input gate. + Wf : nn.Linear + Linear layer for the forget gate. + Wo : nn.Linear + Linear layer for the output gate. + dropout : nn.Dropout + Dropout regularization layer. + ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm + Optional layer normalization layers for respective computations. + """ + + def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.layer_norm = layer_norm + + self.Wq = nn.Linear(input_size, hidden_size) + self.Wk = nn.Linear(input_size, hidden_size) + self.Wv = nn.Linear(input_size, hidden_size) + + self.Wi = nn.Linear(input_size, hidden_size) + self.Wf = nn.Linear(input_size, hidden_size) + self.Wo = nn.Linear(input_size, hidden_size) + + self.dropout = nn.Dropout(dropout) + + if layer_norm: + self.ln_q = nn.LayerNorm(hidden_size) + self.ln_k = nn.LayerNorm(hidden_size) + self.ln_v = nn.LayerNorm(hidden_size) + self.ln_i = nn.LayerNorm(hidden_size) + self.ln_f = nn.LayerNorm(hidden_size) + self.ln_o = nn.LayerNorm(hidden_size) + + self.sigmoid = nn.Sigmoid() + self.tanh = nn.Tanh() + + def forward(self, x, h_prev, c_prev, n_prev): + """Compute the next hidden, cell, and normalized states in the mLSTM cell. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. 
+ h_prev : torch.Tensor + Previous hidden state + c_prev : torch.Tensor + Previous cell state + n_prev : torch.Tensor + Previous normalized state + + Returns + ------- + tuple of torch.Tensor: + h : torch.Tensor + Current hidden state + c : torch.Tensor + Current cell state + n : torch.Tensor + Current normalized state + """ + + batch_size = x.size(0) + assert ( + x.dim() == 2 + ), f"Input should be 2D (batch_size, input_size), got {x.dim()}D" + assert h_prev.size() == ( + batch_size, + self.hidden_size, + ), f"h_prev shape mismatch: {h_prev.size()}" + assert c_prev.size() == ( + batch_size, + self.hidden_size, + ), f"c_prev shape mismatch: {c_prev.size()}" + assert n_prev.size() == ( + batch_size, + self.hidden_size, + ), f"n_prev shape mismatch: {n_prev.size()}" + + x = self.dropout(x) + h_prev = self.dropout(h_prev) + + q = self.Wq(x) + k = self.Wk(x) / math.sqrt(self.hidden_size) + v = self.Wv(x) + + if self.layer_norm: + q = self.ln_q(q) + k = self.ln_k(k) + v = self.ln_v(v) + + i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x)) + f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x)) + o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x)) + + k_expanded = k.unsqueeze(-1) + v_expanded = v.unsqueeze(-2) + + kv_interaction = k_expanded @ v_expanded + + kv_sum = kv_interaction.sum(dim=1) + + c = f * c_prev + i * kv_sum + n = f * n_prev + i * k + + epsilon = 1e-8 + normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon) + h = o * self.tanh(c * normalized_n) + + return h, c, n + + def init_hidden(self, batch_size, device=None): + """ + Initialize hidden, cell, and normalization states. + """ + if device is None: + device = next(self.parameters()).device + shape = (batch_size, self.hidden_size) + return ( + torch.zeros(shape, device=device), + torch.zeros(shape, device=device), + torch.zeros(shape, device=device), + ) diff --git a/pytorch_forecasting/layers/_recurrent/_mlstm/layer.py b/pytorch_forecasting/layers/_recurrent/_mlstm/layer.py new file mode 100644 index 000000000..af47e1fb6 --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/_mlstm/layer.py @@ -0,0 +1,151 @@ +import torch +import torch.nn as nn + +from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell + + +class mLSTMLayer(nn.Module): + """Implements a mLSTM (Matrix LSTM) layer. + + This class stacks multiple mLSTM cells to form a deep recurrent layer. + It supports residual connections, layer normalization, and dropout. + + Parameters + ---------- + input_size : int + The number of features in the input. + hidden_size : int + The number of features in the hidden state. + num_layers : int + The number of mLSTM layers to stack. + dropout : float, optional + Dropout probability applied to the inputs and intermediate layers, + by default 0.2. + layer_norm : bool, optional + Whether to use layer normalization in each mLSTM cell, by default True. + residual_conn : bool, optional + Whether to enable residual connections between layers, by default True. + + Attributes + ---------- + cells : nn.ModuleList + A list containing all mLSTM cells in the layer. + dropout : nn.Dropout + Dropout layer applied between layers. 
+ + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers, + dropout=0.2, + layer_norm=True, + residual_conn=True, + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.layer_norm = layer_norm + self.residual_conn = residual_conn + self.dropout = nn.Dropout(dropout) + + self.cells = nn.ModuleList( + [ + mLSTMCell( + input_size if i == 0 else hidden_size, + hidden_size, + dropout, + layer_norm, + ) + for i in range(num_layers) + ] + ) + + def forward(self, x, h=None, c=None, n=None): + """Forward pass through the mLSTM layer. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h : torch.Tensor, optional + Initial hidden states for all layers + If None, initialized to zeros, by default None. + c : torch.Tensor, optional + Initial cell states for all layers + If None, initialized to zeros, by default None. + n : torch.Tensor, optional + Initial normalized states for all layers + If None, initialized to zeros, by default None. + + Returns + ------- + tuple + output : torch.Tensor + Final output tensor from the last layer + (h, c, n) : tuple of torch.Tensor + Final hidden, cell, and normalized states for all layers: + - h : torch.Tensor + - c : torch.Tensor + - n : torch.Tensor + """ + + x = x.transpose(0, 1) + batch_size, seq_len, _ = x.size() + + if h is None or c is None or n is None: + h, c, n = self.init_hidden(batch_size) + + outputs = [] + + for t in range(seq_len): + layer_input = x[:, t, :] + next_hidden_states = [] + next_cell_states = [] + next_norm_states = [] + + for i, cell in enumerate(self.cells): + h_i, c_i, n_i = cell(layer_input, h[i], c[i], n[i]) + + if self.residual_conn and i > 0: + h_i = h_i + layer_input + + layer_input = h_i + + next_hidden_states.append(h_i) + next_cell_states.append(c_i) + next_norm_states.append(n_i) + + h = torch.stack(next_hidden_states) + c = torch.stack(next_cell_states) + n = torch.stack(next_norm_states) + + outputs.append(h[-1]) + + output = torch.stack(outputs, dim=1) + + output = output.transpose(0, 1) + + return output, (h, c, n) + + def init_hidden(self, batch_size, device=None): + """ + Initialize hidden, cell, and normalization states for all layers. + """ + if device is None: + device = next(self.parameters()).device + hidden_states, cell_states, norm_states = zip( + *[ + self.cells[i].init_hidden(batch_size, device=device) + for i in range(self.num_layers) + ] + ) + + return ( + torch.stack(hidden_states), + torch.stack(cell_states), + torch.stack(norm_states), + ) diff --git a/pytorch_forecasting/layers/_recurrent/_mlstm/network.py b/pytorch_forecasting/layers/_recurrent/_mlstm/network.py new file mode 100644 index 000000000..a39262ed3 --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/_mlstm/network.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn + +from pytorch_forecasting.layers._recurrent._mlstm.layer import mLSTMLayer + + +class mLSTMNetwork(nn.Module): + """Implements the mLSTM Network, a complete model based on stacked mLSTM layers. + + This network combines stacked mLSTM layers and a fully connected output layer. + + Parameters + ---------- + input_size : int + Number of features in the input. + hidden_size : int + Number of features in the hidden state of each mLSTM layer. + num_layers : int + Number of mLSTM layers to stack. + output_size : int + Number of features in the output. + dropout : float, optional + Dropout probability for the mLSTM layers, by default 0.0. 
+ use_layer_norm : bool, optional + Whether to use layer normalization in the mLSTM layers, by default True. + use_residual : bool, optional + Whether to use residual connections in the mLSTM layers, by default True. + + Attributes + ---------- + mlstm_layer : mLSTMLayer + Stacked mLSTM layers used for processing input sequences. + fc : nn.Linear + Fully connected layer to generate final output. + + + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers, + output_size, + dropout=0.0, + use_layer_norm=True, + use_residual=True, + ): + super().__init__() + + self.mlstm_layer = mLSTMLayer( + input_size, + hidden_size, + num_layers, + dropout, + use_layer_norm, + use_residual, + ) + self.fc = nn.Linear(hidden_size, output_size) + + def forward(self, x, h=None, c=None, n=None): + """Forward pass through the mLSTM Network. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h : torch.Tensor, optional + Initial hidden states for all layers. + If None, initialized to zeros, by default None. + c : torch.Tensor, optional + Initial cell states for all layers. + If None, initialized to zeros, by default None. + n : torch.Tensor, optional + Initial normalized states for all layers. + If None, initialized to zeros, by default None. + + Returns + ------- + tuple + output : torch.Tensor + Final output tensor from the fully connected layer. + (h, c, n) : tuple of torch.Tensor + Final hidden, cell, and normalized states for all layers: + - h : torch.Tensor + - c : torch.Tensor + - n : torch.Tensor + """ + output, (h, c, n) = self.mlstm_layer(x, h, c, n) + + output = self.fc(output[-1]) + + return output, (h, c, n) + + def init_hidden(self, batch_size, device=None): + """Initialize hidden, cell, and normalization states.""" + if device is None: + device = next(self.parameters()).device + return self.mlstm_layer.init_hidden(batch_size, device=device) diff --git a/pytorch_forecasting/layers/_recurrent/_slstm/__init__.py b/pytorch_forecasting/layers/_recurrent/_slstm/__init__.py new file mode 100644 index 000000000..0b0243da7 --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/_slstm/__init__.py @@ -0,0 +1,7 @@ +"""sLSTM layer""" + +from pytorch_forecasting.layers._recurrent._slstm.cell import sLSTMCell +from pytorch_forecasting.layers._recurrent._slstm.layer import sLSTMLayer +from pytorch_forecasting.layers._recurrent._slstm.network import sLSTMNetwork + +__all__ = ["sLSTMCell", "sLSTMLayer", "sLSTMNetwork"] diff --git a/pytorch_forecasting/layers/_recurrent/_slstm/cell.py b/pytorch_forecasting/layers/_recurrent/_slstm/cell.py new file mode 100644 index 000000000..8a09bc24b --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/_slstm/cell.py @@ -0,0 +1,149 @@ +import math + +import torch +import torch.nn as nn + + +class sLSTMCell(nn.Module): + """Implements the stabilized LSTM cell + + Implements the sLSTM algorithm as described in the paper: + (https://arxiv.org/pdf/2407.10240). + + Parameters + ---------- + input_size : int + Number of input features for the cell. + hidden_size : int + Number of features in the hidden state of the cell. + dropout : float, optional + Dropout probability for the cell's input and hidden state, by default 0.0. + use_layer_norm : bool, optional + Whether to use layer normalization for the cell's internal computations, + by default True. + + Attributes + ---------- + input_weights : nn.Linear + Linear layer for processing input features into gate computations. 
+ hidden_weights : nn.Linear + Linear layer for processing hidden state features into gate computations. + ln_cell : nn.LayerNorm + Layer normalization for the cell state, applied if use_layer_norm is True. + ln_hidden : nn.LayerNorm + Layer normalization for the output hidden state, + applied if use_layer_norm is True. + ln_input : nn.LayerNorm + Layer normalization for input gates, applied if use_layer_norm is True. + ln_hidden_update : nn.LayerNorm + Layer normalization for hidden state gates, applied if use_layer_norm is True. + dropout_layer : nn.Dropout + Dropout layer applied to inputs and hidden states. + grad_clip : float + Gradient clipping threshold to improve training stability. + eps : float + Small constant for numerical stability in calculations. + tanh : nn.Tanh + Tanh activation function. + sigmoid : nn.Sigmoid + Sigmoid activation function. + """ + + def __init__(self, input_size, hidden_size, dropout=0.0, use_layer_norm=True): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.use_layer_norm = use_layer_norm + self.eps = 1e-6 + + self.input_weights = nn.Linear(input_size, 4 * hidden_size) + self.hidden_weights = nn.Linear(hidden_size, 4 * hidden_size) + + if use_layer_norm: + self.ln_cell = nn.LayerNorm(hidden_size) + self.ln_hidden = nn.LayerNorm(hidden_size) + self.ln_input = nn.LayerNorm(4 * hidden_size) + self.ln_hidden_update = nn.LayerNorm(4 * hidden_size) + + self.dropout_layer = nn.Dropout(dropout) + + self.reset_parameters() + + self.grad_clip = 5.0 + + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + + def reset_parameters(self): + """Initialize parameters using Xavier/Glorot initialization""" + std = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + weight.data.uniform_(-std, std) + + def normalized_exp_gate(self, pre_gate): + """Compute normalized exponential gate activation""" + centered = pre_gate - torch.mean(pre_gate, dim=1, keepdim=True) + exp_val = torch.exp(torch.clamp(centered, min=-5.0, max=5.0)) + normalizer = torch.sum(exp_val, dim=1, keepdim=True) + self.eps + return exp_val / normalizer + + def forward(self, x, h_prev, c_prev): + """Forward pass with stabilized exponential gating. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h_prev : torch.Tensor + Previous hidden state tensor. + c_prev : torch.Tensor + Previous cell state tensor. + + Returns + ------- + h : torch.Tensor + Updated hidden state tensor. + c : torch.Tensor + Updated cell state tensor. 
+ """ + + x = self.dropout_layer(x) + h_prev = self.dropout_layer(h_prev) + + gates_x = self.input_weights(x) + gates_h = self.hidden_weights(h_prev) + + if self.use_layer_norm: + gates_x = self.ln_input(gates_x) + gates_h = self.ln_hidden_update(gates_h) + + gates = gates_x + gates_h + i, f, g, o = gates.chunk(4, dim=1) + + i = self.normalized_exp_gate(i) + f = self.normalized_exp_gate(f) + gate_sum = i + f + i = i / (gate_sum + self.eps) + f = f / (gate_sum + self.eps) + + c_tilde = self.tanh(g) + c = f * c_prev + i * c_tilde + if self.use_layer_norm: + c = self.ln_cell(c) + + o = self.sigmoid(o) + c_out = self.tanh(c) + if self.use_layer_norm: + c_out = self.ln_hidden(c_out) + h = o * c_out + + return h, c + + def init_hidden(self, batch_size, device=None): + if device is None: + device = next(self.parameters()).device + return ( + torch.zeros(batch_size, self.hidden_size, device=device), + torch.zeros(batch_size, self.hidden_size, device=device), + ) diff --git a/pytorch_forecasting/layers/_recurrent/_slstm/layer.py b/pytorch_forecasting/layers/_recurrent/_slstm/layer.py new file mode 100644 index 000000000..86875fcac --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/_slstm/layer.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn + +from pytorch_forecasting.layers._recurrent._slstm.cell import sLSTMCell + + +class sLSTMLayer(nn.Module): + """Implements the sLSTM Layer, which consists of multiple stacked sLSTM cells. + + This layer is designed for sequence modeling tasks, supporting multiple layers + with optional residual connections and layer normalization. + + Parameters + ---------- + input_size : int + Number of features in the input. + hidden_size : int + Number of features in the hidden state of each sLSTM cell. + num_layers : int, optional + Number of stacked sLSTM layers, by default 1. + dropout : float, optional + Dropout probability for the input of each sLSTM cell, by default 0.0. + use_layer_norm : bool, optional + Whether to use layer normalization for each sLSTM cell, by default True. + use_residual : bool, optional + Whether to use residual connections in each sLSTM layer, by default True. + + Attributes + ---------- + cells : nn.ModuleList + List of sLSTMCell objects, one for each layer. + input_projection : nn.Linear or None + Linear layer for projecting input to match hidden state size, + used when residual connections are enabled. + layer_norm_layers : nn.ModuleList + List of LayerNorm layers, one for each sLSTM layer (if use_layer_norm is True). + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers=1, + dropout=0.0, + use_layer_norm=True, + use_residual=True, + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.dropout = dropout + self.use_layer_norm = use_layer_norm + self.use_residual = use_residual + + self.input_projection = None + if self.use_residual and input_size != hidden_size: + self.input_projection = nn.Linear(input_size, hidden_size, bias=False) + + self.cells = nn.ModuleList( + [ + sLSTMCell( + input_size if layer == 0 else hidden_size, + hidden_size, + dropout=dropout, + use_layer_norm=use_layer_norm, + ) + for layer in range(num_layers) + ] + ) + + if self.use_layer_norm: + self.layer_norm_layers = nn.ModuleList( + [nn.LayerNorm(hidden_size) for _ in range(num_layers)] + ) + + def forward(self, x, h=None, c=None): + """Forward pass through the sLSTM Layer. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. 
+ h : list of torch.Tensor, optional + Initial hidden states for each layer. + If None, hidden states are initialized to zeros. + c : list of torch.Tensor, optional + Initial cell states for each layer. + If None, cell states are initialized to zeros. + + Returns + ------- + output : torch.Tensor + Tensor containing hidden states for each time step. + (h, c) : tuple of lists + Final hidden and cell states for each layer. + """ + seq_len, batch_size, _ = x.size() + + if h is None or c is None: + h, c = self.init_hidden(batch_size, device=x.device) + + outputs = [] + + for t in range(seq_len): + input_t = x[t] + layer_input = input_t + + for layer in range(self.num_layers): + h[layer], c[layer] = self.cells[layer](layer_input, h[layer], c[layer]) + + if self.use_residual: + if layer == 0 and self.input_projection is not None: + residual = self.input_projection(layer_input) + else: + residual = ( + layer_input + if (layer_input.size(-1) == self.hidden_size) + else 0 + ) + h[layer] = h[layer] + residual + + if self.use_layer_norm: + h[layer] = self.layer_norm_layers[layer](h[layer]) + + layer_input = h[layer] + + outputs.append(h[-1]) + + output = torch.stack(outputs) + + h = [hi.detach() for hi in h] + c = [ci.detach() for ci in c] + + return output, (h, c) + + def init_hidden(self, batch_size, device=None): + """Initialize hidden and cell states for each layer.""" + if device is None: + device = next(self.parameters()).device + return ( + [ + torch.zeros(batch_size, self.hidden_size, device=device) + for _ in range(self.num_layers) + ], + [ + torch.zeros(batch_size, self.hidden_size, device=device) + for _ in range(self.num_layers) + ], + ) diff --git a/pytorch_forecasting/layers/_recurrent/_slstm/network.py b/pytorch_forecasting/layers/_recurrent/_slstm/network.py new file mode 100644 index 000000000..7833f055c --- /dev/null +++ b/pytorch_forecasting/layers/_recurrent/_slstm/network.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +from pytorch_forecasting.layers._recurrent._slstm.layer import sLSTMLayer + + +class sLSTMNetwork(nn.Module): + """Implements the Stabilized LSTM Network with multiple sLSTM layers. + + This network combines sLSTM layers with a fully connected output layer for + prediction. + + Parameters + ---------- + input_size : int + Number of features in the input. + hidden_size : int + Number of features in the hidden state of each sLSTM layer. + num_layers : int + Number of stacked sLSTM layers in the network. + output_size : int + Number of features in the output prediction. + dropout : float, optional + Dropout probability for the input of each sLSTM layer, by default 0.0. + use_layer_norm : bool, optional + Whether to use layer normalization in each sLSTM layer, by default True. + + Attributes + ---------- + slstm_layer : sLSTMLayer + Stacked sLSTM layers used for processing input sequences. + fc : nn.Linear + Fully connected layer to generate the final output predictions. + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers, + output_size, + dropout=0.0, + use_layer_norm=True, + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.output_size = output_size + self.dropout = dropout + + self.slstm_layer = sLSTMLayer( + input_size, + hidden_size, + num_layers, + dropout, + use_layer_norm, + ) + self.fc = nn.Linear(hidden_size, output_size) + + def forward(self, x, h=None, c=None): + """ + Forward pass through the sLSTM network. 
+ + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h : list of torch.Tensor, optional + Initial hidden states for each layer. + If None, hidden states are initialized to zeros. + c : list of torch.Tensor, optional + Initial cell states for each layer. + If None, cell states are initialized to zeros. + + Returns + ------- + output : torch.Tensor + Tensor containing the final output predictions. + (h, c) : tuple of lists + Final hidden and cell states for each layer. + """ + output, (h, c) = self.slstm_layer(x, h, c) + output = self.fc(output[-1]) + return output, (h, c) + + def init_hidden(self, batch_size, device=None): + """Initialize hidden and cell states for the entire network.""" + if device is None: + device = next(self.parameters()).device + return self.slstm_layer.init_hidden(batch_size, device=device) diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index 0a9d600f8..07335a08f 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -20,6 +20,7 @@ ) from pytorch_forecasting.models.tide import TiDEModel from pytorch_forecasting.models.timexer import TimeXer +from pytorch_forecasting.models.xlstm import xLSTMTime __all__ = [ "NBeats", @@ -39,4 +40,5 @@ "DecoderMLP", "TiDEModel", "TimeXer", + "xLSTMTime", ] diff --git a/pytorch_forecasting/models/xlstm/__init__.py b/pytorch_forecasting/models/xlstm/__init__.py new file mode 100644 index 000000000..df658560c --- /dev/null +++ b/pytorch_forecasting/models/xlstm/__init__.py @@ -0,0 +1,6 @@ +"""xLSTMTime implementation for forecasting.""" + +from pytorch_forecasting.models.xlstm._xlstm import xLSTMTime +from pytorch_forecasting.models.xlstm._xlstm_pkg import xLSTMTime_pkg + +__all__ = ["xLSTMTime", "xLSTMTime_pkg"] diff --git a/pytorch_forecasting/models/xlstm/_xlstm.py b/pytorch_forecasting/models/xlstm/_xlstm.py new file mode 100644 index 000000000..c304c67a8 --- /dev/null +++ b/pytorch_forecasting/models/xlstm/_xlstm.py @@ -0,0 +1,181 @@ +from copy import copy +from typing import Literal, Optional, Union + +import torch +from torch import nn + +from pytorch_forecasting.layers import SeriesDecomposition, mLSTMNetwork, sLSTMNetwork +from pytorch_forecasting.metrics import SMAPE, Metric +from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel + + +class xLSTMTime(AutoRegressiveBaseModel): + """ + xLSTMTime is a long‑term time series forecasting architecture built on the + extended LSTM (xLSTM) design, incorporating either the scalar-memory + stabilized LSTM (sLSTM) or the matrix-memory mLSTM variant. This model + enhances classical LSTM by adding exponential gating and richer memory + dynamics, and combines series decomposition and normalization layers to + produce robust forecasts over extended horizons. + + It is based on this paper: https://arxiv.org/pdf/2407.10240 and + https://github.com/muslehal/xLSTMTime + """ + + @classmethod + def _pkg(cls): + """Package for the model.""" + from pytorch_forecasting.models.xlstm._xlstm_pkg import xLSTMTime_pkg + + return xLSTMTime_pkg + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + xlstm_type: Literal["slstm", "mlstm"] = "slstm", + num_layers: int = 1, + decomposition_kernel: int = 25, + input_projection_size: Optional[int] = None, + dropout: float = 0.1, + loss: Metric = SMAPE(), + **kwargs, + ): + """ + Initialise the model. 
+ + Parameters + ---------- + input_size : int + Number of input continuous features per time step. + hidden_size : int + Hidden size of the xLSTM network; also used by batch norm / LSTM internals. + output_size : int + Number of output features per time step (forecast horizon). + xlstm_type : {"slstm", "mlstm"}, default "slstm" + Specifies which xLSTM variant to use: + - "slstm": stabilized LSTM with scalar memory, + - "mlstm": matrix-memory variant for higher capacity and scalability. + num_layers : int, default 1 + Number of recurrent layers in the sLSTM or mLSTM network. + decomposition_kernel : int, default 25 + Kernel size for series decomposition into trend and seasonal components. + input_projection_size : int, optional + If specified, the encoded input (trend + seasonal) is projected to this size + before being fed to the xLSTM; otherwise equals hidden_size. + dropout : float, default 0.1 + Dropout rate applied within the recurrent layers. + loss : pytorch_forecasting.metrics.Metric, default SMAPE() + Loss (and evaluation metric) used during training. + """ + if "target" in kwargs: + del kwargs["target"] + if "target_lags" in kwargs: + del kwargs["target_lags"] + self.save_hyperparameters() + super().__init__(loss=loss, **kwargs) + + if xlstm_type not in ["slstm", "mlstm"]: + raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'") + + self.xlstm_type = xlstm_type + + self.decomposition = SeriesDecomposition(decomposition_kernel) + self.batch_norm = nn.BatchNorm1d(hidden_size) + + self.input_projection_size = input_projection_size or hidden_size + + self.input_linear = nn.Linear(input_size * 2, self.input_projection_size) + + if xlstm_type == "mlstm": + self.lstm = mLSTMNetwork( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_size=hidden_size, + dropout=dropout, + ) + else: # slstm + self.lstm = sLSTMNetwork( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + output_size=hidden_size, + dropout=dropout, + ) + + self.output_linear = nn.Linear(hidden_size, output_size) + self.instance_norm = nn.InstanceNorm1d(output_size) + + def forward( + self, + x: dict[str, torch.Tensor], + hidden_states: Optional[ + Union[ + tuple[torch.Tensor, torch.Tensor], + tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ] + ] = None, + ) -> dict[str, torch.Tensor]: + """Forward Pass for the model.""" + encoder_cont = x["encoder_cont"] + batch_size, seq_len, n_features = encoder_cont.shape + + seasonal, trend = self.decomposition(encoder_cont) + + x = torch.cat([trend, seasonal], dim=-1) + + x = self.input_linear(x) + + x = x.transpose(1, 2) + x = self.batch_norm(x) + x = x.transpose(1, 2) + + if hidden_states is None: + hidden_states = self.lstm.init_hidden(batch_size, device=x.device) + + x = x.transpose(0, 1) + output, hidden_states = self.lstm(x, *hidden_states) + + if isinstance(output, tuple): + output = output[0] + + if output.dim() == 2: + output = output.unsqueeze(0) + + output = self.output_linear(output) + + output = output.transpose(1, 2) + output = self.instance_norm(output) + output = output.transpose(1, 2) + + output = output[0, ..., : self.hparams.output_size] + return self.to_network_output(prediction=output) + + @classmethod + def from_dataset(cls, dataset, **kwargs): + """ + Create model from dataset and set parameters related to covariates. 
+
+        Parameters
+        ----------
+        dataset : TimeSeriesDataSet
+        **kwargs : additional arguments such as hyperparameters for the model
+
+        Returns
+        -------
+        xLSTMTime
+        """
+        from pytorch_forecasting.data.encoders import NaNLabelEncoder
+
+        assert not isinstance(
+            dataset.target_normalizer, NaNLabelEncoder
+        ), "only regression tasks are supported - target must not be categorical"
+
+        new_kwargs = copy(kwargs)
+        new_kwargs.update(
+            cls.deduce_default_output_parameters(dataset, kwargs, SMAPE())
+        )
+
+        return super().from_dataset(dataset, **new_kwargs)
diff --git a/pytorch_forecasting/models/xlstm/_xlstm_pkg.py b/pytorch_forecasting/models/xlstm/_xlstm_pkg.py
new file mode 100644
index 000000000..1a10fe660
--- /dev/null
+++ b/pytorch_forecasting/models/xlstm/_xlstm_pkg.py
@@ -0,0 +1,77 @@
+"""xLSTMTime package container."""
+
+from pytorch_forecasting.models.base._base_object import _BasePtForecaster
+
+
+class xLSTMTime_pkg(_BasePtForecaster):
+    """xLSTMTime package container."""
+
+    _tags = {
+        "info:name": "xLSTMTime",
+        "info:compute": 3,
+        "info:pred_type": ["point"],
+        "info:y_type": ["numeric"],
+        "authors": ["muslehal", "phoeenniixx"],
+        "capability:exogenous": True,
+        "capability:multivariate": True,
+        "capability:pred_int": False,
+        "capability:flexible_history_length": True,
+        "capability:cold_start": False,
+    }
+
+    @classmethod
+    def get_cls(cls):
+        """Get model class."""
+        from pytorch_forecasting.models import xLSTMTime
+
+        return xLSTMTime
+
+    @classmethod
+    def get_base_test_params(cls):
+        """
+        Return testing parameter settings for the trainer.
+
+        Returns
+        -------
+        params : dict or list of dict, default = {}
+            Parameters to create testing instances of the class
+        """
+
+        params = [
+            {},
+            {"xlstm_type": "mlstm"},
+            {"num_layers": 2},
+            {"xlstm_type": "slstm", "input_projection_size": 32},
+            {
+                "xlstm_type": "mlstm",
+                "decomposition_kernel": 13,
+                "dropout": 0.2,
+            },
+        ]
+        defaults = {"hidden_size": 32, "input_size": 1, "output_size": 1}
+        for param in params:
+            param.update(defaults)
+        return params
+
+    @classmethod
+    def _get_test_dataloaders_from(cls, params):
+        """
+        Get dataloaders from parameters.
+
+        Parameters
+        ----------
+        params : dict
+            Parameters to create dataloaders.
+            One of the elements in the list returned by ``get_test_train_params``.
+
+        Returns
+        -------
+        dataloaders : Dict[str, DataLoader]
+            Train, validation, and test dataloaders created from the parameters.
+ """ + from pytorch_forecasting.tests._data_scenarios import ( + dataloaders_fixed_window_without_covariates, + ) + + return dataloaders_fixed_window_without_covariates() diff --git a/tests/test_models/test_x_lstm.py b/tests/test_models/test_x_lstm.py new file mode 100644 index 000000000..a527957cf --- /dev/null +++ b/tests/test_models/test_x_lstm.py @@ -0,0 +1,110 @@ +import shutil + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger +import pytest + +from pytorch_forecasting.metrics import SMAPE +from pytorch_forecasting.models.xlstm._xlstm import xLSTMTime + + +def _integration( + dataloaders_fixed_window_without_covariates, tmp_path, xlstm_type="slstm", **kwargs +): + train_dataloader = dataloaders_fixed_window_without_covariates["train"] + val_dataloader = dataloaders_fixed_window_without_covariates["val"] + test_dataloader = dataloaders_fixed_window_without_covariates["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + trainer = pl.Trainer( + max_epochs=3, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + ) + + model_kwargs = { + "input_size": 1, + "output_size": 1, + "hidden_size": 32, + "xlstm_type": xlstm_type, + "learning_rate": 0.01, + "loss": SMAPE(), + } + + model_kwargs.update(kwargs) + + net = xLSTMTime.from_dataset(train_dataloader.dataset, **model_kwargs) + + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + + net = xLSTMTime.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) + + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + ) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + ) + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"xlstm_type": "mlstm"}, + {"num_layers": 2}, + {"xlstm_type": "slstm", "input_projection_size": 32}, + { + "xlstm_type": "mlstm", + "decomposition_kernel": 13, + "dropout": 0.2, + }, + ], +) +def test_integration(dataloaders_fixed_window_without_covariates, tmp_path, kwargs): + _integration(dataloaders_fixed_window_without_covariates, tmp_path, **kwargs) + + +@pytest.fixture(scope="session") +def model(dataloaders_fixed_window_without_covariates): + dataset = dataloaders_fixed_window_without_covariates["train"].dataset + net = xLSTMTime.from_dataset( + dataset, + input_size=1, + hidden_size=32, + output_size=1, + xlstm_type="slstm", + learning_rate=0.01, + loss=SMAPE(), + ) + return net