diff --git a/utils.py b/utils.py index 06c2206..6bb8f8f 100644 --- a/utils.py +++ b/utils.py @@ -12,7 +12,7 @@ def calculate_matmul_n_times(n_components, mat_a, mat_b): for i in range(n_components): mat_a_i = mat_a[:, i, :, :].squeeze(-2) - mat_b_i = mat_b[0, i, :, :].squeeze() + mat_b_i = mat_b[0, i, :, :] res[:, i, :, :] = mat_a_i.mm(mat_b_i).unsqueeze(1) return res