Commit 56c6b43

fix transpose on attn mask

1 parent 358de26

File tree

1 file changed: +1, -1 lines changed


i6_models/parts/conformer/mhsa_rel_pos.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -224,7 +224,7 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
             q_with_bias_u.transpose(-3, -2),  # [B, #heads, T, F']
             k.transpose(-3, -2),  # [B, #heads, T', F']
             v.transpose(-3, -2),  # [B, #heads, T, F']
-            attn_mask=attn_bd_mask_scaled.transpose(-3, -2),  # [B, #heads, T, T']
+            attn_mask=attn_bd_mask_scaled,  # [B, #heads, T, T']
             dropout_p=self.att_weights_dropout.p if self.training else 0.0,
             scale=scale,
         )  # [B, #heads, T, F']
```
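A minimal sketch of why the transpose was dropped: `torch.nn.functional.scaled_dot_product_attention` expects `attn_mask` to be broadcastable to the attention-score shape `[B, #heads, T, T']`. The mask here is already laid out that way, so transposing dims -3 and -2 would produce `[B, T, #heads, T']`, which does not broadcast whenever `#heads != T`. All tensor sizes below are hypothetical, chosen only to illustrate the shapes.

```python
import torch
import torch.nn.functional as F

B, n_heads, T, F_ = 2, 4, 6, 8  # hypothetical sizes; note #heads != T

q = torch.randn(B, n_heads, T, F_)
k = torch.randn(B, n_heads, T, F_)
v = torch.randn(B, n_heads, T, F_)

# Additive mask already in the score layout [B, #heads, T, T'] -- no transpose needed.
attn_mask = torch.zeros(B, n_heads, T, T)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 6, 8]), i.e. [B, #heads, T, F']

# The pre-fix behaviour transposed the mask to [B, T, #heads, T'],
# which is not broadcastable against the [B, #heads, T, T'] scores:
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask.transpose(-3, -2))
except RuntimeError as e:
    print("transposed mask rejected:", type(e).__name__)
```

The q/k/v transposes in the diff remain correct: they move the head axis in front of time, `[B, T, #heads, F'] -> [B, #heads, T, F']`; only the mask, which was never in time-major layout, needed no such swap.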

0 commit comments