diff --git a/xlstm/blocks/slstm/layer.py b/xlstm/blocks/slstm/layer.py index 368d263..a6fcd16 100644 --- a/xlstm/blocks/slstm/layer.py +++ b/xlstm/blocks/slstm/layer.py @@ -132,9 +132,9 @@ def forward( if self.config.conv1d_kernel_size > 0: if return_last_state: - x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state) + x_conv, conv_state = self.conv1d(x, conv_state, return_last_state=return_last_state) else: - x_conv, conv_state = self.conv1d( + x_conv = self.conv1d( x, conv_state, return_last_state=return_last_state ) x_conv = self.conv_act_fn(x_conv)