diff --git a/aeon/networks/__init__.py b/aeon/networks/__init__.py index d774abe102..d80a21c6da 100644 --- a/aeon/networks/__init__.py +++ b/aeon/networks/__init__.py @@ -19,6 +19,7 @@ "AEBiGRUNetwork", "DisjointCNNNetwork", "RecurrentNetwork", + "InformerNetwork", ] from aeon.networks._ae_abgru import AEAttentionBiGRUNetwork from aeon.networks._ae_bgru import AEBiGRUNetwork @@ -32,6 +33,7 @@ from aeon.networks._encoder import EncoderNetwork from aeon.networks._fcn import FCNNetwork from aeon.networks._inception import InceptionNetwork +from aeon.networks._informer import InformerNetwork from aeon.networks._lite import LITENetwork from aeon.networks._mlp import MLPNetwork from aeon.networks._resnet import ResNetNetwork diff --git a/aeon/networks/_informer.py b/aeon/networks/_informer.py new file mode 100644 index 0000000000..f02c058afa --- /dev/null +++ b/aeon/networks/_informer.py @@ -0,0 +1,869 @@ +"""Informer Network for time series forecasting.""" + +__maintainer__ = [""] + + +from aeon.networks.base import BaseDeepLearningNetwork +from aeon.utils.validation._dependencies import _check_soft_dependencies + +if _check_soft_dependencies(["tensorflow"], severity="none"): + + from aeon.utils.networks.attention import ( + AttentionLayer, + KerasProbAttention, + ) + + +class InformerNetwork(BaseDeepLearningNetwork): + """ + TensorFlow implementation of the Informer network for time series forecasting. + + The Informer network is a Transformer-based architecture designed for + long sequence time-series forecasting. It uses ProbSparse self-attention + mechanism and distilling operation to reduce computational complexity. + + Parameters + ---------- + encoder_input_len : int, default=96 + Encoder input sequence length. + decoder_input_len : int, default=48 + Start token length for decoder. + prediction_horizon : int, default=24 + Prediction sequence length. + factor : int, default=5 + ProbSparse attention factor. + model_dimension : int, default=512 + Model dimension. + num_attention_heads : int, default=8 + Number of attention heads. + encoder_layers : int, default=3 + Number of encoder layers. + decoder_layers : int, default=2 + Number of decoder layers. + feedforward_dim : int, default=512 + Feed forward network dimension. + dropout : float, default=0.0 + Dropout rate. + attention_type : str, default='prob' + Attention mechanism type ('prob' or 'full'). + activation : str, default='gelu' + Activation function. + distil : bool, default=True + Whether to use distilling operation. + mix : bool, default=True + Whether to use mix attention in decoder. + + References + ---------- + .. [1] Zhou, H., Zhang, S., Peng, J., Zhang, S., Li, J., Xiong, H., & Zhang, W. + (2021). Informer: Beyond efficient transformer for long sequence + time-series forecasting. In Proceedings of the AAAI conference on + artificial intelligence (Vol. 35, No. 12, pp. 11106-11115). 
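+
+    Examples
+    --------
+    A minimal sketch of building the network graph with small settings
+    (chosen here only to keep the example light); TensorFlow must be
+    installed for ``build_network``:
+
+    >>> from aeon.networks import InformerNetwork
+    >>> network = InformerNetwork(
+    ...     encoder_input_len=48,
+    ...     decoder_input_len=24,
+    ...     prediction_horizon=12,
+    ...     model_dimension=64,
+    ...     num_attention_heads=2,
+    ...     encoder_layers=1,
+    ...     decoder_layers=1,
+    ... )
+    >>> inputs, outputs = network.build_network((72, 2))  # doctest: +SKIP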
+ """ + + _config = { + "python_dependencies": ["tensorflow"], + "python_version": "<3.13", + "structure": "transformer", + } + + def __init__( + self, + encoder_input_len: int = 96, + decoder_input_len: int = 48, + prediction_horizon: int = 24, + factor: int = 5, + model_dimension: int = 512, + num_attention_heads: int = 8, + encoder_layers: int = 3, + decoder_layers: int = 2, + feedforward_dim: int = 512, + dropout: float = 0.0, + attention_type: str = "prob", + activation: str = "gelu", + distil: bool = True, + mix: bool = True, + ): + self.encoder_input_len = encoder_input_len + self.decoder_input_len = decoder_input_len + self.prediction_horizon = prediction_horizon + self.factor = factor + self.model_dimension = model_dimension + self.num_attention_heads = num_attention_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.feedforward_dim = feedforward_dim + self.dropout = dropout + self.attention_type = attention_type + self.activation = activation + self.distil = distil + self.mix = mix + + super().__init__() + + def _token_embedding(self, input_tensor, c_in, model_dimension): + """ + Token embedding layer using 1D convolution with causal padding. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor to be processed. + c_in : int + Number of input channels. + model_dimension : int + Dimension of the model (number of output filters). + + Returns + ------- + tf.Tensor + Output tensor after token embedding transformation. + """ + import tensorflow as tf + + x = tf.keras.layers.Conv1D( + filters=model_dimension, + kernel_size=3, + padding="causal", + activation="linear", + )(input_tensor) + x = tf.keras.layers.LeakyReLU()(x) + return x + + def _positional_embedding(self, input_tensor, model_dimension, max_len=5000): + """ + Positional embedding layer that computes positional encodings. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor to get positional embeddings for. + model_dimension : int + Dimension of the model. + max_len : int, optional + Maximum length of the sequence, by default 5000 + + Returns + ------- + tf.Tensor + Positional encoding tensor matching input tensor's sequence length. + """ + import math + + import numpy as np + import tensorflow as tf + + # Compute the positional encodings + pe = np.zeros((max_len, model_dimension), dtype=np.float32) + position = np.expand_dims(np.arange(0, max_len, dtype=np.float32), 1) + div_term = np.exp( + np.arange(0, model_dimension, 2, dtype=np.float32) + * -(math.log(10000.0) / model_dimension) + ) + + pe[:, 0::2] = np.sin(position * div_term) + pe[:, 1::2] = np.cos(position * div_term) + + # Convert to tensor and add batch dimension + pe_tensor = tf.expand_dims(tf.convert_to_tensor(pe), 0) + + # Return positional embeddings for the input tensor's sequence length + return pe_tensor[:, : input_tensor.shape[1]] + + def _data_embedding( + self, + input_tensor, + c_in, + model_dimension, + dropout=0.1, + max_len=5000, + ): + """ + Combine token and positional embeddings for the input tensor. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor to be processed. + c_in : int + Number of input channels. + model_dimension : int + Dimension of the model (number of output filters). + dropout : float, optional + Dropout rate, by default 0.1 + max_len : int, optional + Maximum length of the sequence for positional embedding + + Returns + ------- + tf.Tensor + Output tensor after data embedding transformation. 
+ """ + import tensorflow as tf + + # Get token embeddings + token_emb = self._token_embedding(input_tensor, c_in, model_dimension) + + # Get positional embeddings + pos_emb = self._positional_embedding(input_tensor, model_dimension, max_len) + + # Combine embeddings + x = token_emb + pos_emb + + # Apply dropout + x = tf.keras.layers.Dropout(dropout)(x) + + return x + + def _conv_layer(self, input_tensor, c_in): + """ + Convolutional layer with batch normalization, ELU, and max pooling. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor to be processed. + c_in : int + Number of input channels (filters for the convolution). + + Returns + ------- + tf.Tensor + Output tensor after convolution and pooling operations. + """ + import tensorflow as tf + + # Apply 1D convolution with causal padding + x = tf.keras.layers.Conv1D(filters=c_in, kernel_size=3, padding="causal")( + input_tensor + ) + + # Apply batch normalization + x = tf.keras.layers.BatchNormalization()(x) + + # Apply ELU activation + x = tf.keras.layers.ELU()(x) + + # Apply max pooling for downsampling + x = tf.keras.layers.MaxPool1D(pool_size=3, strides=2)(x) + + return x + + def _attention_out( + self, + input_tensor, + attention_type, + mask_flag, + model_dimension, + num_attention_heads, + factor=5, + dropout=0.1, + attn_mask=None, + ): + """ + Attention output layer applying either ProbAttention or FullAttention. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor for attention computation. + attention_type : str + Type of attention mechanism ('prob' or 'full'). + mask_flag : bool + Whether to use attention masking. + model_dimension : int + Model dimension. + num_attention_heads : int + Number of attention heads. + factor : int, optional + Attention factor for ProbSparse attention, by default 5 + dropout : float, optional + Dropout rate, by default 0.1 + attn_mask : tf.Tensor, optional + Attention mask tensor, by default None + + Returns + ------- + tf.Tensor + Output tensor after attention computation. + """ + import tensorflow as tf + + if attention_type == "prob": + prob_attention = KerasProbAttention( + mask_flag=mask_flag, + factor=factor, + attention_dropout=dropout, + ) + + output = AttentionLayer( + attention=prob_attention, + d_model=model_dimension, + n_heads=num_attention_heads, + d_keys=model_dimension // num_attention_heads, # 512 // 8 = 64 + d_values=model_dimension // num_attention_heads, # 512 // 8 = 64 + )(input_tensor, attn_mask=attn_mask) + + else: + queries, keys, values = input_tensor + output = tf.keras.layers.MultiHeadAttention( + num_heads=num_attention_heads, # 8 + key_dim=model_dimension // num_attention_heads, # 512 // 8 = 64 + value_dim=model_dimension // num_attention_heads, # 512 // 8 = 64 + dropout=dropout, + use_bias=True, + )( + query=queries, # (32, 20, 512) + key=keys, # (32, 20, 512) + value=values, # (32, 20, 512) + attention_mask=attn_mask, + use_causal_mask=mask_flag, + ) + + return output + + def _encoder_layer( + self, + input_tensor, + model_dimension, + feedforward_dim=None, + dropout=0.1, + activation="relu", + attn_mask=None, + attention_type="prob", + mask_flag=True, + num_attention_heads=8, + factor=5, + ): + """ + Apply encoder layer with multi-head attention and feed-forward network. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor of shape [B, L, D] where B is batch size, + L is sequence length, D is model dimension. + model_dimension : int + Model dimension (must match input tensor's last dimension). 
+ feedforward_dim : int, optional + Feed-forward network dimension + dropout : float, optional + Dropout rate, by default 0.1 + activation : str, optional + Activation function ('relu' or 'gelu'), by default "relu" + attn_mask : tf.Tensor, optional + Attention mask tensor, by default None + + Returns + ------- + tf.Tensor + Output tensor after encoder layer processing. + """ + import tensorflow as tf + + # Set default feedforward_dim if not provided + if feedforward_dim is None: + feedforward_dim = 4 * model_dimension + + # Self-attention using the _attention_out function with parameters + attn_output = self._attention_out( + input_tensor=[input_tensor, input_tensor, input_tensor], + attention_type=attention_type, + mask_flag=mask_flag, + model_dimension=model_dimension, + num_attention_heads=num_attention_heads, + factor=factor, + dropout=dropout, + attn_mask=attn_mask, + ) + + # Apply dropout and residual connection + x = input_tensor + tf.keras.layers.Dropout(dropout)(attn_output) + + # First layer normalization + x = tf.keras.layers.LayerNormalization()(x) + + # Store for second residual connection + residual = x + + # Feed-forward network + # First 1D convolution (expansion) + y = tf.keras.layers.Conv1D(filters=feedforward_dim, kernel_size=1)(x) + + # Apply activation function + if activation == "relu": + y = tf.keras.layers.ReLU()(y) + else: # gelu + y = tf.keras.layers.Activation("gelu")(y) + + # Apply dropout + y = tf.keras.layers.Dropout(dropout)(y) + + # Second 1D convolution (compression back to d_model) + y = tf.keras.layers.Conv1D(filters=model_dimension, kernel_size=1)(y) + + # Apply dropout + y = tf.keras.layers.Dropout(dropout)(y) + + # Second residual connection and layer normalization + output = tf.keras.layers.LayerNormalization()(residual + y) + + return output + + def _encoder( + self, + input_tensor, + encoder_layers, + model_dimension, + feedforward_dim=None, + dropout=0.1, + activation="relu", + attn_mask=None, + attention_type="prob", + mask_flag=True, + num_attention_heads=8, + factor=5, + use_conv_layers=False, + c_in=None, + use_norm=True, + ): + """ + Apply encoder stack with multiple encoder layers and optional conv layers. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor of shape [B, L, D] + encoder_layers : int + Number of encoder layers to stack. + model_dimension : int + Model dimension (must match input tensor's last dimension). + feedforward_dim : int, optional + Feed-forward network dimension + dropout : float, optional + Dropout rate, by default 0.1 + activation : str, optional + Activation function ('relu' or 'gelu'), by default "relu" + attn_mask : tf.Tensor, optional + Attention mask tensor, by default None + attention_type : str, optional + Type of attention mechanism ('prob' or 'full') + mask_flag : bool, optional + Whether to use attention masking, by default True + num_attention_heads : int, optional + Number of attention heads, by default 8 + factor : int, optional + Attention factor for ProbSparse attention, by default 5 + use_conv_layers : bool, optional + Whether to use convolutional layers between encoder layers + c_in : int, optional + Number of input channels for convolutional layers + use_norm : bool, optional + Whether to apply final layer normalization, by default True + + Returns + ------- + tf.Tensor + Output tensor after encoder stack processing. 
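+
+        Notes
+        -----
+        When ``use_conv_layers`` is True, each of the first
+        ``encoder_layers - 1`` encoder layers is followed by ``_conv_layer``
+        (Conv1D, batch normalization, ELU and max pooling with pool size 3
+        and stride 2), which roughly halves the sequence length; this is the
+        distilling operation of the Informer architecture. For example, a
+        sequence of length 96 is reduced to 47 and then to 23 by two such
+        layers. The final encoder layer is applied without a following
+        convolutional layer.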
+ """ + import tensorflow as tf + + # Set default values + if c_in is None: + c_in = model_dimension + + x = input_tensor + + # Apply encoder layers with optional convolutional layers + if use_conv_layers: + # Apply paired encoder and conv layers + for _ in range(encoder_layers - 1): + # Apply encoder layer + x = self._encoder_layer( + input_tensor=x, + model_dimension=model_dimension, + feedforward_dim=feedforward_dim, + dropout=dropout, + activation=activation, + attn_mask=attn_mask, + attention_type=attention_type, + mask_flag=mask_flag, + num_attention_heads=num_attention_heads, + factor=factor, + ) + + # Apply convolutional layer for downsampling + x = self._conv_layer( + input_tensor=x, + c_in=c_in, + ) + + # Apply final encoder layer (without conv layer) + x = self._encoder_layer( + input_tensor=x, + model_dimension=model_dimension, + feedforward_dim=feedforward_dim, + dropout=dropout, + activation=activation, + attn_mask=attn_mask, + attention_type=attention_type, + mask_flag=mask_flag, + num_attention_heads=num_attention_heads, + factor=factor, + ) + + else: + # Apply only encoder layers without convolutional layers + for _ in range(encoder_layers): + x = self._encoder_layer( + input_tensor=x, + model_dimension=model_dimension, + feedforward_dim=feedforward_dim, + dropout=dropout, + activation=activation, + attn_mask=attn_mask, + attention_type=attention_type, + mask_flag=mask_flag, + num_attention_heads=num_attention_heads, + factor=factor, + ) + + # Apply optional final layer normalization + if use_norm: + x = tf.keras.layers.LayerNormalization()(x) + + return x + + def _decoder_layer( + self, + input_tensor, + cross_tensor, + model_dimension, + feedforward_dim=None, + dropout=0.1, + activation="relu", + x_mask=None, + cross_mask=None, + self_attention_type="prob", + cross_attention_type="prob", + mask_flag=True, + num_attention_heads=8, + factor=5, + ): + """ + Apply decoder layer with self-attention, cross-attention, and FFN. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor of shape [B, L, D] + cross_tensor : tf.Tensor + Cross-attention input tensor (encoder output) of shape [B, L_enc, D] + model_dimension : int + Model dimension (must match input tensor's last dimension). + feedforward_dim : int, optional + Feed-forward network dimension + dropout : float, optional + Dropout rate, by default 0.1 + activation : str, optional + Activation function ('relu' or 'gelu'), by default "relu" + x_mask : tf.Tensor, optional + Self-attention mask tensor, by default None + cross_mask : tf.Tensor, optional + Cross-attention mask tensor, by default None + self_attention_type : str, optional + Type of self-attention mechanism ('prob' or 'full') + cross_attention_type : str, optional + Type of cross-attention mechanism ('prob' or 'full') + mask_flag : bool, optional + Whether to use attention masking, by default True + num_attention_heads : int, optional + Number of attention heads, by default 8 + factor : int, optional + Attention factor for ProbSparse attention, by default 5 + + Returns + ------- + tf.Tensor + Output tensor after decoder layer processing with same shape. 
+ """ + import tensorflow as tf + + # Set default feedforward_dim if not provided + if feedforward_dim is None: + feedforward_dim = 4 * model_dimension + + # Self-attention block + self_attn_output = self._attention_out( + input_tensor=[input_tensor, input_tensor, input_tensor], + attention_type=self_attention_type, + mask_flag=mask_flag, + model_dimension=model_dimension, + num_attention_heads=num_attention_heads, + factor=factor, + dropout=dropout, + attn_mask=x_mask, + ) + + # Apply dropout and first residual connection + x = input_tensor + tf.keras.layers.Dropout(dropout)(self_attn_output) + + # First layer normalization + x = tf.keras.layers.LayerNormalization()(x) + + # Cross-attention block + cross_attn_output = self._attention_out( + input_tensor=[x, cross_tensor, cross_tensor], + attention_type=cross_attention_type, + mask_flag=mask_flag, + model_dimension=model_dimension, + num_attention_heads=num_attention_heads, + factor=factor, + dropout=dropout, + attn_mask=cross_mask, + ) + + # Apply dropout and second residual connection + x = x + tf.keras.layers.Dropout(dropout)(cross_attn_output) + + # Second layer normalization + x = tf.keras.layers.LayerNormalization()(x) + + # Store for third residual connection + residual = x + + # Feed-forward network + # First 1D convolution (expansion) + y = tf.keras.layers.Conv1D(filters=feedforward_dim, kernel_size=1)(x) + + # Apply activation function + if activation == "relu": + y = tf.keras.layers.ReLU()(y) + else: # gelu + y = tf.keras.layers.Activation("gelu")(y) + + # Apply dropout + y = tf.keras.layers.Dropout(dropout)(y) + + # Second 1D convolution (compression back to d_model) + y = tf.keras.layers.Conv1D(filters=model_dimension, kernel_size=1)(y) + + # Apply dropout + y = tf.keras.layers.Dropout(dropout)(y) + + # Third residual connection and final layer normalization + output = tf.keras.layers.LayerNormalization()(residual + y) + + return output + + def _decoder( + self, + input_tensor, + cross_tensor, + decoder_layers, + model_dimension, + feedforward_dim=None, + dropout=0.1, + activation="relu", + x_mask=None, + cross_mask=None, + self_attention_type="prob", + cross_attention_type="prob", + mask_flag=True, + num_attention_heads=8, + factor=5, + use_norm=True, + ): + """ + Apply decoder stack with multiple decoder layers and optional normalization. + + Parameters + ---------- + input_tensor : tf.Tensor + Decoder input tensor of shape [B, L_dec, D] + cross_tensor : tf.Tensor + Cross-attention input tensor (encoder output) of shape [B, L_enc, D] + decoder_layers : int + Number of decoder layers to stack. + model_dimension : int + Model dimension (must match input tensor's last dimension). 
+ feedforward_dim : int, optional + Feed-forward network dimension + dropout : float, optional + Dropout rate, by default 0.1 + activation : str, optional + Activation function ('relu' or 'gelu'), by default "relu" + x_mask : tf.Tensor, optional + Self-attention mask tensor for decoder, by default None + cross_mask : tf.Tensor, optional + Cross-attention mask tensor, by default None + self_attention_type : str, optional + Type of self-attention mechanism ('prob' or 'full') + cross_attention_type : str, optional + Type of cross-attention mechanism ('prob' or 'full') + mask_flag : bool, optional + Whether to use attention masking, by default True + num_attention_heads : int, optional + Number of attention heads, by default 8 + factor : int, optional + Attention factor for ProbSparse attention, by default 5 + use_norm : bool, optional + Whether to apply final layer normalization, by default True + + Returns + ------- + tf.Tensor + Output tensor after decoder stack processing. + """ + import tensorflow as tf + + x = input_tensor + + # Apply multiple decoder layers + for _ in range(decoder_layers): + x = self._decoder_layer( + input_tensor=x, + cross_tensor=cross_tensor, + model_dimension=model_dimension, + feedforward_dim=feedforward_dim, + dropout=dropout, + activation=activation, + x_mask=x_mask, + cross_mask=cross_mask, + self_attention_type=self_attention_type, + cross_attention_type=cross_attention_type, + mask_flag=mask_flag, + num_attention_heads=num_attention_heads, + factor=factor, + ) + + # Apply optional final layer normalization + if use_norm: + x = tf.keras.layers.LayerNormalization()(x) + + return x + + def _preprocess_time_series( + self, data, encoder_input_len, decoder_input_len, prediction_horizon + ): + """ + Preprocess time series data of shape (None, n_timepoints, n_channels). 
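+
+        The encoder input is the first ``encoder_input_len`` time points of
+        the series. The decoder input concatenates the last
+        ``decoder_input_len`` time points of the encoder window with a
+        placeholder segment of length ``prediction_horizon``, giving a
+        decoder input of length ``decoder_input_len + prediction_horizon``.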
+ + Parameters + ---------- + data : tf.Tensor + Input tensor of shape (None, n_timepoints, n_channels) + encoder_input_len : int + Encoder input sequence length + decoder_input_len : int + Known decoder input length + prediction_horizon : int + Prediction length + + Returns + ------- + tuple + (x_enc, x_dec) where: + - x_enc: Encoder input tensor of shape (None, encoder_input_len, n_channels) + - x_dec: Decoder input tensor of shape (None, + decoder_input_len + prediction_horizon, n_channels) + """ + import tensorflow as tf + + # Get tensor dimensions - handle None batch dimension + batch_size, n_timepoints, n_channels = data.shape + + # Encoder input: first seq_len timepoints + x_enc = data[:, :encoder_input_len, :] # (None, encoder_input_len, n_channels) + + # Decoder input construction + x_dec_known = data[ + :, encoder_input_len - decoder_input_len : encoder_input_len, : + ] # (None, decoder_input_len, n_channels) + + # Unknown part: zeros for prediction horizon + x_dec_pred = data[:, :prediction_horizon, :] + + # Concatenate known and prediction parts + x_dec = tf.keras.layers.Concatenate(axis=1)([x_dec_known, x_dec_pred]) + + return x_enc, x_dec + + def build_network(self, input_shape, **kwargs): + """Build the complete Informer architecture for time series forecasting.""" + import tensorflow as tf + + # Get input dimensions + n_timepoints, n_channels = input_shape + + input_data = tf.keras.layers.Input( + shape=input_shape, + name="time_series_input", + ) + + encoder_input, decoder_input = self._preprocess_time_series( + data=input_data, + encoder_input_len=self.encoder_input_len, + decoder_input_len=self.decoder_input_len, + prediction_horizon=self.prediction_horizon, + ) + + # Encoder embedding + enc_embedded = self._data_embedding( + input_tensor=encoder_input, + c_in=n_channels, + model_dimension=self.model_dimension, + dropout=self.dropout, + max_len=self.encoder_input_len, + ) + + # Encoder processing + enc_output = self._encoder( + input_tensor=enc_embedded, + encoder_layers=self.encoder_layers, + model_dimension=self.model_dimension, + feedforward_dim=self.feedforward_dim, + dropout=self.dropout, + activation=self.activation, + attention_type=self.attention_type, + mask_flag=False, + num_attention_heads=self.num_attention_heads, + factor=self.factor, + use_conv_layers=self.distil, + c_in=self.model_dimension, + use_norm=True, + ) + + # Decoder embedding + dec_embedded = self._data_embedding( + input_tensor=decoder_input, + c_in=n_channels, + model_dimension=self.model_dimension, + dropout=self.dropout, + max_len=self.decoder_input_len + self.prediction_horizon, + ) + + # Decoder processing + dec_output = self._decoder( + input_tensor=dec_embedded, + cross_tensor=enc_output, + decoder_layers=self.decoder_layers, + model_dimension=self.model_dimension, + feedforward_dim=self.feedforward_dim, + dropout=self.dropout, + activation=self.activation, + self_attention_type=self.attention_type, + cross_attention_type="full", + mask_flag=self.mix, + num_attention_heads=self.num_attention_heads, + factor=self.factor, + use_norm=True, + ) + + # Final projection to output dimension + output = tf.keras.layers.Dense(n_channels, name="output_projection")(dec_output) + + # Extract only the prediction part (last out_len timesteps) + output = output[:, -self.prediction_horizon :, :] + + return input_data, output diff --git a/aeon/networks/tests/test_all_networks.py b/aeon/networks/tests/test_all_networks.py index 9ca85474fb..924e1f2623 100644 --- a/aeon/networks/tests/test_all_networks.py +++ 
b/aeon/networks/tests/test_all_networks.py @@ -75,6 +75,11 @@ def test_all_networks_params(network): f"{network.__name__} not to be tested (AE networks have their own tests)." ) + if network._config["structure"] == "transformer": + pytest.skip( + f"{network.__name__} not to be tested (transformers have their own tests)." + ) + if not ( _check_soft_dependencies( network._config["python_dependencies"], severity="none" diff --git a/aeon/networks/tests/test_informer.py b/aeon/networks/tests/test_informer.py new file mode 100644 index 0000000000..8c59fa30d5 --- /dev/null +++ b/aeon/networks/tests/test_informer.py @@ -0,0 +1,178 @@ +"""Tests for the Informer Network Model.""" + +import random + +import pytest + +from aeon.networks import InformerNetwork +from aeon.utils.validation._dependencies import _check_soft_dependencies + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize( + "encoder_input_len,decoder_input_len," + "prediction_horizon,model_dimension,num_attention_heads," + "encoder_layers,decoder_layers", + [ + (96, 48, 24, 512, 8, 3, 2), + (48, 24, 12, 256, 4, 2, 1), + (120, 60, 30, 128, 2, 1, 1), + (72, 36, 18, 64, 1, 2, 2), + ], +) +def test_informer_network_init( + encoder_input_len, + decoder_input_len, + prediction_horizon, + model_dimension, + num_attention_heads, + encoder_layers, + decoder_layers, +): + """Test whether InformerNetwork initializes correctly for various parameters.""" + informer = InformerNetwork( + encoder_input_len=encoder_input_len, + decoder_input_len=decoder_input_len, + prediction_horizon=prediction_horizon, + model_dimension=model_dimension, + num_attention_heads=num_attention_heads, + encoder_layers=encoder_layers, + decoder_layers=decoder_layers, + factor=random.choice([3, 5, 7]), + dropout=random.choice([0.0, 0.1, 0.2]), + attention_type=random.choice(["prob", "full"]), + activation=random.choice(["relu", "gelu"]), + ) + + inputs, outputs = informer.build_network((encoder_input_len + decoder_input_len, 5)) + assert inputs is not None + assert outputs is not None + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize( + "attention_type,activation", + [("prob", "relu"), ("full", "gelu"), ("prob", "gelu"), ("full", "relu")], +) +def test_informer_network_attention_activation(attention_type, activation): + """Test InformerNetwork with different attention and activation.""" + informer = InformerNetwork( + encoder_input_len=96, + decoder_input_len=48, + prediction_horizon=24, + model_dimension=128, + num_attention_heads=4, + encoder_layers=2, + decoder_layers=1, + attention_type=attention_type, + activation=activation, + ) + + inputs, outputs = informer.build_network((144, 3)) + assert inputs is not None + assert outputs is not None + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize( + "distil,mix,factor", + [(True, True, 5), (False, False, 3), (True, False, 7), (False, True, 2)], +) +def test_informer_network_distil_mix_factor(distil, mix, factor): + """Test whether InformerNetwork works with different configurations.""" + informer = InformerNetwork( + encoder_input_len=48, + decoder_input_len=24, + prediction_horizon=12, + model_dimension=64, + num_attention_heads=2, + encoder_layers=1, + decoder_layers=1, + 
distil=distil, + mix=mix, + factor=factor, + ) + + inputs, outputs = informer.build_network((72, 2)) + assert inputs is not None + assert outputs is not None + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_informer_network_default_parameters(): + """Test whether InformerNetwork works with default parameters.""" + informer = InformerNetwork() + + inputs, outputs = informer.build_network((120, 1)) + assert inputs is not None + assert outputs is not None + + # Check default values + assert informer.encoder_input_len == 96 + assert informer.decoder_input_len == 48 + assert informer.prediction_horizon == 24 + assert informer.model_dimension == 512 + assert informer.num_attention_heads == 8 + assert informer.encoder_layers == 3 + assert informer.decoder_layers == 2 + assert informer.attention_type == "prob" + assert informer.activation == "gelu" + assert informer.distil + assert informer.mix + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_informer_network_parameter_validation(): + """Test whether InformerNetwork handles edge case parameters correctly.""" + informer = InformerNetwork( + encoder_input_len=12, + decoder_input_len=6, + prediction_horizon=3, + model_dimension=32, + num_attention_heads=1, + encoder_layers=1, + decoder_layers=1, + factor=1, + dropout=0.0, + ) + + inputs, outputs = informer.build_network((18, 1)) + assert inputs is not None + assert outputs is not None + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_informer_network_different_channels(): + """Test whether InformerNetwork works with different numbers of input channels.""" + for n_channels in [1, 3, 5, 10]: + informer = InformerNetwork( + encoder_input_len=48, + decoder_input_len=24, + prediction_horizon=12, + model_dimension=64, + num_attention_heads=2, + encoder_layers=1, + decoder_layers=1, + ) + + inputs, outputs = informer.build_network((72, n_channels)) + assert inputs is not None + assert outputs is not None diff --git a/aeon/utils/networks/attention.py b/aeon/utils/networks/attention.py new file mode 100644 index 0000000000..5201bc5134 --- /dev/null +++ b/aeon/utils/networks/attention.py @@ -0,0 +1,351 @@ +"""Full Attention, ProbSparseAttention and Attention Layer.""" + +from aeon.utils.validation._dependencies import _check_soft_dependencies + +if _check_soft_dependencies(["tensorflow"], severity="none"): + import numpy as np + import tensorflow as tf + from tensorflow.keras.layers import Dropout, Layer + + @tf.keras.utils.register_keras_serializable(package="aeon") + class KerasProbAttention(Layer): + """Keras implementation of ProbSparse Attention mechanism for Informer.""" + + def __init__( + self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, **kwargs + ): + """Initialize KerasProbAttention layer.""" + super().__init__(**kwargs) + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.attention_dropout = attention_dropout + self.dropout = Dropout(attention_dropout) + + def build(self, input_shape): + """Build the layer.""" + super().build(input_shape) + + def compute_output_shape(self, input_shape): + """Compute output shape for the layer.""" + # Return the same shape as queries input + return input_shape[0] # queries shape + + def 
compute_output_spec(self, input_spec): + """Compute output spec for the layer.""" + return input_spec[0] # Return queries spec + + def _prob_QK(self, Q, K, sample_k, n_top): + """Compute probabilistic QK with fixed dimension handling.""" + B, H, L, _ = tf.shape(Q)[0], tf.shape(Q)[1], tf.shape(Q)[2], tf.shape(Q)[3] + S = tf.shape(K)[2] + + # Ensure sample_k doesn't exceed available dimensions + sample_k = tf.minimum(sample_k, L) + n_top = tf.minimum(n_top, S) # Ensure n_top doesn't exceed sequence length + + # Expand K for sampling + K_expand = tf.expand_dims(K, axis=2) # [B, H, 1, L, E] + K_expand = tf.tile(K_expand, [1, 1, S, 1, 1]) # [B, H, S, L, E] + + # Generate random indices - ensure they're within bounds + indx_q_seq = tf.random.uniform([S], maxval=L, dtype=tf.int32) + indx_k_seq = tf.random.uniform([sample_k], maxval=L, dtype=tf.int32) + + # Gather operations for sampling + indices_s = tf.range(S) + K_sample = tf.gather(K_expand, indices_s, axis=2) + K_sample = tf.gather(K_sample, indx_q_seq, axis=2) + K_sample = tf.gather(K_sample, indx_k_seq, axis=3) + + # Matrix multiplication for Q_K_sample + Q_expanded = tf.expand_dims(Q, axis=-2) # [B, H, S, 1, E] + K_sample_transposed = tf.transpose(K_sample, perm=[0, 1, 2, 4, 3]) + Q_K_sample = tf.squeeze(tf.matmul(Q_expanded, K_sample_transposed), axis=-2) + + # Sparsity measurement calculation + M_max = tf.reduce_max(Q_K_sample, axis=-1) + M_mean = tf.reduce_sum(Q_K_sample, axis=-1) / tf.cast(sample_k, tf.float32) + M = M_max - M_mean + + # Top-k selection with dynamic k + actual_k = tf.minimum(n_top, tf.shape(M)[-1]) + _, M_top = tf.nn.top_k(M, k=actual_k, sorted=False) + + # Create indices for gather_nd + batch_range = tf.range(B) + head_range = tf.range(H) + batch_indices = tf.tile( + tf.expand_dims(tf.expand_dims(batch_range, 1), 2), [1, H, actual_k] + ) + + head_indices = tf.tile( + tf.expand_dims(tf.expand_dims(head_range, 0), 2), [B, 1, actual_k] + ) + + # Stack indices for gather_nd + idx = tf.stack([batch_indices, head_indices, M_top], axis=-1) + + # Reduce Q and calculate final Q_K + Q_reduce = tf.gather_nd(Q, idx) + K_transposed = tf.transpose(K, perm=[0, 1, 3, 2]) + Q_K = tf.matmul(Q_reduce, K_transposed) + + return Q_K, M_top + + def _get_initial_context(self, V, L_Q): + """Get initial context using Keras-compatible operations.""" + if not self.mask_flag: + # Sum reduction and broadcasting + V_sum = tf.reduce_sum(V, axis=-2) # [B, H, D] + V_sum_expanded = tf.expand_dims(V_sum, axis=-2) # [B, H, 1, D] + context = tf.tile(V_sum_expanded, [1, 1, L_Q, 1]) # [B, H, L_Q, D] + else: + # Cumulative sum for masked attention + context = tf.cumsum(V, axis=-2) + + return context + + def _create_prob_mask(self, B, H, L, index, scores): + """Create probability mask for tf.where compatibility.""" + # Create base mask with ones + _mask = tf.ones((L, tf.shape(scores)[-1]), dtype=tf.float32) + + # Create upper triangular matrix (including diagonal) + mask_a = tf.linalg.band_part( + _mask, 0, -1 + ) # Upper triangular matrix of 0s and 1s + + # Create diagonal matrix + mask_b = tf.linalg.band_part(_mask, 0, 0) # Diagonal matrix of 0s and 1s + + # Subtract diagonal from upper triangular to get strict upper triangular + _mask = tf.cast(mask_a - mask_b, dtype=tf.float32) + + # Broadcast to [B, H, L, scores.shape[-1]] + _mask_ex = tf.broadcast_to(_mask, [B, H, L, tf.shape(scores)[-1]]) + + # Create indexing tensors + batch_indices = tf.range(B)[:, None, None] + head_indices = tf.range(H)[None, :, None] + + # Extract indicator using advanced indexing 
+ indicator = tf.gather_nd( + _mask_ex, + tf.stack( + [ + tf.broadcast_to(batch_indices, tf.shape(index)), + tf.broadcast_to(head_indices, tf.shape(index)), + index, + ], + axis=-1, + ), + ) + + # Reshape to match scores shape + prob_mask_float = tf.reshape(indicator, tf.shape(scores)) + + # **KEY FIX**: Convert to boolean tensor + prob_mask_bool = tf.cast(prob_mask_float, tf.bool) + + return prob_mask_bool + + def _update_context(self, context_in, V, scores, index, L_Q): + """Update context using Keras-compatible operations.""" + if self.mask_flag: + # Apply simple masking + attn_mask = self._create_prob_mask( + tf.shape(V)[0], tf.shape(V)[1], L_Q, index, scores + ) + + # Apply mask with large negative value + large_neg = -1e9 + mask_value = tf.where(attn_mask, 0.0, large_neg) + scores = scores + mask_value + + # Softmax activation + attn = tf.nn.softmax(scores, axis=-1) + attn = self.dropout(attn) + + # Create indices for scatter update + B, H = tf.shape(V)[0], tf.shape(V)[1] + index_shape = tf.shape(index)[-1] + + batch_indices = tf.tile( + tf.expand_dims(tf.expand_dims(tf.range(B), 1), 2), [1, H, index_shape] + ) + + head_indices = tf.tile( + tf.expand_dims(tf.expand_dims(tf.range(H), 0), 2), [B, 1, index_shape] + ) + + idx = tf.stack([batch_indices, head_indices, index], axis=-1) + + # Matrix multiplication and scatter update + attn_V = tf.matmul(attn, V) + context_updated = tf.tensor_scatter_nd_update(context_in, idx, attn_V) + + return context_updated + + def call(self, inputs, attention_mask=None, training=None): + """Run forward pass with fixed tensor operations.""" + queries, keys, values = inputs + + # Get shapes + # B = tf.shape(queries)[0] + L = tf.shape(queries)[1] # sequence length + # H = tf.shape(queries)[2] # number of heads + D = tf.shape(queries)[3] # dimension per head + S = tf.shape(keys)[1] # source sequence length + + # Reshape tensors - transpose to [B, H, L, D] + queries = tf.transpose(queries, perm=[0, 2, 1, 3]) # [B, H, L, D] + keys = tf.transpose(keys, perm=[0, 2, 1, 3]) # [B, H, S, D] + values = tf.transpose(values, perm=[0, 2, 1, 3]) # [B, H, S, D] + + # Calculate sampling parameters with bounds checking + # Use tf.py_function to handle numpy operations safely + def safe_log_calc(seq_len, factor): + if hasattr(seq_len, "numpy"): + return int(factor * np.ceil(np.log(max(seq_len.numpy(), 2)))) + else: + return int(factor * np.ceil(np.log(20))) # fallback + + U = tf.py_function( + func=lambda: safe_log_calc(S, self.factor), inp=[], Tout=tf.int32 + ) + + u = tf.py_function( + func=lambda: safe_log_calc(L, self.factor), inp=[], Tout=tf.int32 + ) + + # Ensure U and u are within reasonable bounds + U = tf.minimum(U, S) # Can't select more than available + u = tf.minimum(u, L) + + # Probabilistic QK computation + scores_top, index = self._prob_QK(queries, keys, u, U) + + # Apply scale factor + scale = self.scale or (1.0 / tf.sqrt(tf.cast(D, tf.float32))) + scores_top = scores_top * scale + + # Get initial context + context = self._get_initial_context(values, L) + + # Update context with selected queries + context = self._update_context(context, values, scores_top, index, L) + + # Transpose back to original format [B, L, H, D] + context = tf.transpose(context, perm=[0, 2, 1, 3]) + + return context + + def get_config(self): + """Return the config of the layer.""" + config = super().get_config() + config.update( + { + "mask_flag": self.mask_flag, + "factor": self.factor, + "scale": self.scale, + "attention_dropout": self.attention_dropout, + } + ) + return config + + 
@classmethod + def from_config(cls, config): + """Create layer from config.""" + return cls(**config) + + @tf.keras.utils.register_keras_serializable(package="aeon") + class AttentionLayer(Layer): + """Keras multi-head attention layer using a custom attention mechanism.""" + + def __init__( + self, attention, d_model, n_heads, d_keys=None, d_values=None, **kwargs + ): + super().__init__(**kwargs) + self.d_keys = d_keys or (d_model // n_heads) + self.d_values = d_values or (d_model // n_heads) + self.d_model = d_model + self.n_heads = n_heads + + # Store the attention mechanism + self.inner_attention = attention + + # Projection layers + self.query_projection = tf.keras.layers.Dense( + self.d_keys * n_heads, name="query_proj" + ) + + self.key_projection = tf.keras.layers.Dense( + self.d_keys * n_heads, name="key_proj" + ) + + self.value_projection = tf.keras.layers.Dense( + self.d_values * n_heads, name="value_proj" + ) + + self.out_projection = tf.keras.layers.Dense(d_model, name="output_proj") + + def build(self, input_shape): + """Build the layer.""" + # Build the projection layers + super().build(input_shape) + + def compute_output_shape(self, input_shape): + """Compute output shape for the layer.""" + batch_size, seq_length, _ = input_shape[0] + return (batch_size, seq_length, self.d_model) + + def call(self, inputs, attn_mask=None, training=None): + """Run forward pass for the attention layer.""" + queries, keys, values = inputs + + # Get batch size and sequence lengths dynamically + B = tf.shape(queries)[0] + L = tf.shape(queries)[1] # target sequence length + S = tf.shape(keys)[1] # source sequence length + H = self.n_heads + + # Apply projections + queries_proj = self.query_projection(queries) # [B, L, d_keys * n_heads] + keys_proj = self.key_projection(keys) # [B, S, d_keys * n_heads] + values_proj = self.value_projection(values) # [B, S, d_values * n_heads] + + # Reshape to multi-head format: [B, L/S, H, d_keys/d_values] + queries_reshaped = tf.reshape(queries_proj, (B, L, H, self.d_keys)) + keys_reshaped = tf.reshape(keys_proj, (B, S, H, self.d_keys)) + values_reshaped = tf.reshape(values_proj, (B, S, H, self.d_values)) + + # Apply inner attention mechanism + attention_output = self.inner_attention( + [queries_reshaped, keys_reshaped, values_reshaped], + attention_mask=attn_mask, + training=training, + ) + + # Reshape attention output back to [B, L, H * d_values] + attention_flattened = tf.reshape( + attention_output, (B, L, H * self.d_values) + ) + + # Final output projection + output = self.out_projection(attention_flattened) + + return output + + def get_config(self): + """Return the config of the layer.""" + config = super().get_config() + config.update( + { + "d_model": self.d_model, + "n_heads": self.n_heads, + "d_keys": self.d_keys, + "d_values": self.d_values, + } + ) + return config