From 19c463e0b6e675a4a8ffc96d9c07d54e15dc9e7c Mon Sep 17 00:00:00 2001 From: Gustavo Korzune Gurgel Date: Fri, 28 Jun 2024 02:31:19 +0200 Subject: [PATCH] Refactor conv1d call to ensure consistent state handling - Switched the assignment of `x_conv` and `conv_state` based on `return_last_state` condition - Ensured `conv_state` is assigned when `return_last_state` is True Signed-off-by: Gustavo Korzune Gurgel --- xlstm/blocks/slstm/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xlstm/blocks/slstm/layer.py b/xlstm/blocks/slstm/layer.py index 368d263..a6fcd16 100644 --- a/xlstm/blocks/slstm/layer.py +++ b/xlstm/blocks/slstm/layer.py @@ -132,9 +132,9 @@ def forward( if self.config.conv1d_kernel_size > 0: if return_last_state: - x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state) + x_conv, conv_state = self.conv1d(x, conv_state, return_last_state=return_last_state) else: - x_conv, conv_state = self.conv1d( + x_conv = self.conv1d( x, conv_state, return_last_state=return_last_state ) x_conv = self.conv_act_fn(x_conv)