Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/kanade_tokenizer/module/ssl_extractor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
import math

import torch
import torch.nn as nn
import torchaudio
import torchaudio.pipelines as pipelines
from torch import Tensor
from torchaudio.models.wav2vec2 import Wav2Vec2Model
from torchaudio.models.wav2vec2.components import ConvLayerBlock
from torchaudio.models.wav2vec2.wavlm_attention import WavLMSelfAttention

from ..util import get_logger

logger = get_logger()


def _patch_wavlm_compute_bias():
"""Monkey-patch WavLMSelfAttention.compute_bias to create tensors on the
model device instead of CPU. The upstream implementation calls
``torch.arange()`` without a ``device`` argument, which triggers an
expensive CPU-to-CUDA synchronisation for sequences longer than ~180
frames (~3.6 s of audio), causing a ~100x slowdown.

Upstream fix: https://github.com/pytorch/audio/pull/4176
"""

def _compute_bias(self, query_length: int, key_length: int) -> Tensor:
device = self.rel_attn_embed.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position
relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
values = self.rel_attn_embed(relative_position_bucket)
values = values.permute([2, 0, 1])
return values

WavLMSelfAttention.compute_bias = _compute_bias
logger.debug("Patched WavLMSelfAttention.compute_bias for device-aware tensor creation.")


_patch_wavlm_compute_bias()


# Map of friendly names to torchaudio pipeline bundles
MODEL_REGISTRY = {
"wav2vec2_base": pipelines.WAV2VEC2_BASE,
Expand Down