diff --git a/xlstm/blocks/mlstm/layer.py b/xlstm/blocks/mlstm/layer.py
index 0f828ec..279d944 100644
--- a/xlstm/blocks/mlstm/layer.py
+++ b/xlstm/blocks/mlstm/layer.py
@@ -72,13 +72,15 @@ def __init__(self, config: mLSTMLayerConfig):
             )
         )
 
-        self.conv1d = CausalConv1d(
-            config=CausalConv1dConfig(
-                feature_dim=self.config._inner_embedding_dim,
-                kernel_size=self.config.conv1d_kernel_size,
+        if self.config.conv1d_kernel_size > 0:
+            self.conv1d = CausalConv1d(
+                config=CausalConv1dConfig(
+                    feature_dim=self.config._inner_embedding_dim,
+                    kernel_size=self.config.conv1d_kernel_size,
+                )
             )
-        )
-        self.conv_act_fn = nn.SiLU()
+            self.conv_act_fn = nn.SiLU()
+
         self.mlstm_cell = mLSTMCell(
             config=mLSTMCellConfig(
                 context_length=self.config.context_length,
@@ -106,8 +108,11 @@ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.config._inner_embedding_dim, dim=-1)
 
         # mlstm branch
-        x_mlstm_conv = self.conv1d(x_mlstm)
-        x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
+        if self.config.conv1d_kernel_size > 0:
+            x_mlstm_conv = self.conv1d(x_mlstm)
+            x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
+        else:
+            x_mlstm_conv_act = x_mlstm
 
         q = self.q_proj(x_mlstm_conv_act)
         k = self.k_proj(x_mlstm_conv_act)
@@ -137,8 +142,12 @@ def step(
         x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.config._inner_embedding_dim, dim=-1)
 
         # mlstm branch
-        x_mlstm_conv, conv_state = self.conv1d.step(x_mlstm, conv_state=conv_state)
-        x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
+        if self.config.conv1d_kernel_size > 0:
+            x_mlstm_conv, conv_state = self.conv1d.step(x_mlstm, conv_state=conv_state)
+            x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
+        else:  # no conv
+            x_mlstm_conv_act = x_mlstm
+            conv_state = None
 
         q = self.q_proj(x_mlstm_conv_act)
         k = self.k_proj(x_mlstm_conv_act)
diff --git a/xlstm/components/conv.py b/xlstm/components/conv.py
index 5144fd2..0e95a27 100644
--- a/xlstm/components/conv.py
+++ b/xlstm/components/conv.py
@@ -90,7 +90,8 @@ def __init__(self, config: CausalConv1dConfig):
             bias=self.config.causal_conv_bias,
             **self.config.conv1d_kwargs,
         )  # B, C, L
-        self.reset_parameters()
+        if self.config.kernel_size > 0:
+            self.reset_parameters()
 
     def reset_parameters(self, **kwargs):
         self.conv.reset_parameters()
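
For reviewers, a minimal usage sketch of what this change enables: constructing an mLSTMLayer with conv1d_kernel_size=0 so the causal-conv branch is skipped and q/k are computed directly from x_mlstm. Only conv1d_kernel_size and context_length appear in the diff; the other config fields and shapes below are assumptions based on the existing mLSTMLayerConfig and are illustrative, not part of this change.

    # Sketch only: field names other than conv1d_kernel_size/context_length are assumed
    # from the existing mLSTMLayerConfig; values are illustrative.
    import torch

    from xlstm.blocks.mlstm.layer import mLSTMLayer, mLSTMLayerConfig

    cfg = mLSTMLayerConfig(
        embedding_dim=128,     # model width (assumed field name)
        context_length=256,    # passed through to the inner mLSTMCell
        conv1d_kernel_size=0,  # 0 now disables the CausalConv1d branch entirely
    )
    layer = mLSTMLayer(cfg)

    x = torch.randn(2, 256, 128)  # (batch, seq_len, embedding_dim)
    y = layer(x)                  # no conv/SiLU applied before the q/k projections
    assert y.shape == x.shape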