diff --git a/fairseq/models/wav2vec/utils.py b/fairseq/models/wav2vec/utils.py index dd52d86242..6845c17ba8 100644 --- a/fairseq/models/wav2vec/utils.py +++ b/fairseq/models/wav2vec/utils.py @@ -5,6 +5,7 @@ import math import torch.nn.functional as F +import torch def pad_to_multiple(x, multiple, dim=-1, value=0): @@ -14,7 +15,9 @@ def pad_to_multiple(x, multiple, dim=-1, value=0): tsz = x.size(dim) m = tsz / multiple remainder = math.ceil(m) * multiple - tsz - if m.is_integer(): + if isinstance(m, torch.Tensor): + m = m.item() # Convert tensor to Python float + if isinstance(m, float) and m.is_integer(): return x, 0 pad_offset = (0,) * (-1 - dim) * 2