diff --git a/src/nnssl/architectures/evaMAE_module.py b/src/nnssl/architectures/evaMAE_module.py index 9040b130..93356731 100644 --- a/src/nnssl/architectures/evaMAE_module.py +++ b/src/nnssl/architectures/evaMAE_module.py @@ -160,7 +160,7 @@ def forward(self, x): decoded = encoded # Project back to output shape - decoded = rearrange(decoded, 'b (h w d) c -> b c w h d', h=W, w=H, d=D) + decoded = rearrange(decoded, 'b (h w d) c -> b c w h d', w=W, h=H, d=D) decoded = self.up_projection(decoded) if self.use_decoder: