diff --git a/src/torchaudio/models/wav2vec2/wavlm_attention.py b/src/torchaudio/models/wav2vec2/wavlm_attention.py
index fafddfeb95..fc60d43972 100644
--- a/src/torchaudio/models/wav2vec2/wavlm_attention.py
+++ b/src/torchaudio/models/wav2vec2/wavlm_attention.py
@@ -90,11 +90,11 @@ def compute_bias(self, query_length: int, key_length: int) -> Tensor:
         Returns:
             Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings
         """
-        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
-        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+        device = self.rel_attn_embed.weight.device
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
         relative_position = memory_position - context_position  # Shape (query_length, key_length)
         relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
-        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
         values = self.rel_attn_embed(relative_position_bucket)  # Shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1])
         return values
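
Note on the change: the original code built the arange index tensors on the CPU and only moved the bucketed result onto the embedding's device afterwards, which costs a host-to-device copy on every call when the model runs on an accelerator. Creating the tensors on self.rel_attn_embed.weight.device up front removes that transfer and keeps the whole computation on one device. Below is a minimal standalone sketch of the same pattern; the embedding sizes and the clamp stand-in for _relative_positions_bucket are assumptions for illustration only, not torchaudio's actual bucketing logic.

    import torch

    # Sketch of the device-up-front pattern from the diff. rel_attn_embed and
    # the bucket/head sizes are made up for demonstration; the real
    # _relative_positions_bucket body is not part of the diff, so a simple
    # clamp stands in for it here.
    rel_attn_embed = torch.nn.Embedding(32, 12)  # (num_buckets, num_heads), assumed sizes
    query_length, key_length = 5, 5

    device = rel_attn_embed.weight.device
    context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    relative_position = memory_position - context_position  # (query_length, key_length)

    # Stand-in for _relative_positions_bucket: map raw offsets into the
    # embedding's valid index range. The tensor is already on the embedding's
    # device, so no trailing .to(...) is needed.
    relative_position_bucket = relative_position.clamp(0, rel_attn_embed.num_embeddings - 1)

    values = rel_attn_embed(relative_position_bucket)  # (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1])  # (num_heads, query_length, key_length)
    print(values.shape)  # torch.Size([12, 5, 5])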