From a9dd9a0669c4b659e4e98cc05e32eba2cbc9d5d1 Mon Sep 17 00:00:00 2001 From: mhenrhcsen Date: Wed, 11 Feb 2026 20:15:57 +0100 Subject: [PATCH] Fix WavLM compute_bias to create tensors on model device `compute_bias` creates position tensors on CPU via `torch.arange()` without specifying a device, then transfers them to GPU. This causes a severe performance cliff when the model is on CUDA: for sequences above ~180 frames, encoding time jumps from ~0.1s to ~2.5s due to CPU-GPU synchronization overhead. Fix by passing `device=self.rel_attn_embed.weight.device` to `torch.arange()`, so all computation stays on the model's device. This also removes a now-redundant `.to()` call. Benchmarked on RTX 4090 with WavLM Base+ (used by Kanade tokenizer): - Before: 3.4s audio=0.12s, 3.8s audio=2.50s, 10s audio=2.21s - After: 3.4s audio=0.03s, 3.8s audio=0.02s, 10s audio=0.02s Co-Authored-By: Claude Opus 4.6 --- src/torchaudio/models/wav2vec2/wavlm_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchaudio/models/wav2vec2/wavlm_attention.py b/src/torchaudio/models/wav2vec2/wavlm_attention.py index fafddfeb95..fc60d43972 100644 --- a/src/torchaudio/models/wav2vec2/wavlm_attention.py +++ b/src/torchaudio/models/wav2vec2/wavlm_attention.py @@ -90,11 +90,11 @@ def compute_bias(self, query_length: int, key_length: int) -> Tensor: Returns: Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings """ - context_position = torch.arange(query_length, dtype=torch.long)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + device = self.rel_attn_embed.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # Shape (query_length, key_length) relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) - relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads) values = values.permute([2, 0, 1]) return values