
Fix WavLM compute_bias performance on CUDA (~100x speedup)#4

Open
mhenrichsen wants to merge 1 commit into frothywater:main from mhenrichsen:fix/wavlm-compute-bias-device

Conversation

@mhenrichsen

Summary

  • Monkey-patch WavLMSelfAttention.compute_bias to create tensors on the model device instead of CPU

Problem

torchaudio's WavLMSelfAttention.compute_bias() calls torch.arange() without a device argument, creating position tensors on the CPU and then transferring them to the GPU. This causes a severe performance cliff on CUDA: for sequences above ~180 frames (~3.6s of audio), encoding time jumps from ~0.1s to ~2.5s due to CPU-GPU synchronization overhead.

This makes Kanade encoding unexpectedly slow for any audio longer than ~3.6 seconds.
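The mechanism can be seen in a minimal sketch (illustrative only; `n` and the commented-out device handling are stand-ins, not torchaudio's actual code):

```python
import torch

n = 190  # frames for ~3.8 s of audio, past the observed cliff

# Upstream pattern: without a device argument the tensor is
# materialized on the CPU, regardless of where the model lives.
pos = torch.arange(n)  # defaults to CPU

# On a CUDA model this then requires pos.to("cuda"), a host-to-device
# copy that forces synchronization on every forward pass.

# Device-aware pattern: create the tensor on the model device directly.
# pos = torch.arange(n, device=model_device)  # no copy, no sync
```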

Fix

Apply a module-level monkey-patch in ssl_extractor.py that replaces compute_bias with a version that creates tensors directly on the model device. The patch is applied at import time so all WavLM-based models benefit automatically.

I've also submitted the proper fix upstream to torchaudio: pytorch/audio#4176. This monkey-patch can be removed once that lands in a torchaudio release.
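In spirit, the patch follows the sketch below, shown here with a toy stand-in module rather than torchaudio's real WavLMSelfAttention (the bucketing is deliberately simplified, and attribute names such as `rel_attn_embed` are assumptions about the upstream implementation, not the exact ssl_extractor.py code):

```python
import torch
from torch import nn


class ToySelfAttention(nn.Module):
    """Stand-in for a WavLM-style self-attention with relative position bias."""

    def __init__(self, num_buckets=32, num_heads=4):
        super().__init__()
        self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)

    def compute_bias(self, query_length, key_length):
        # Upstream-style bug: arange lands on CPU regardless of model device,
        # so the result must be copied to the GPU afterwards.
        context = torch.arange(query_length)[:, None]
        memory = torch.arange(key_length)[None, :]
        bucket = (memory - context).clamp(0, self.rel_attn_embed.num_embeddings - 1)
        return self.rel_attn_embed(bucket.to(self.rel_attn_embed.weight.device))


def _patched_compute_bias(self, query_length, key_length):
    # Create position tensors directly on the embedding's device,
    # avoiding the CPU round-trip and the synchronization stall.
    device = self.rel_attn_embed.weight.device
    context = torch.arange(query_length, device=device)[:, None]
    memory = torch.arange(key_length, device=device)[None, :]
    bucket = (memory - context).clamp(0, self.rel_attn_embed.num_embeddings - 1)
    return self.rel_attn_embed(bucket)


# Applied once at module import time, mirroring what ssl_extractor.py does,
# so every instance created afterwards picks up the fixed method.
ToySelfAttention.compute_bias = _patched_compute_bias
```

Because the patch replaces the method on the class rather than on an instance, it takes effect for all WavLM-based models without touching their construction code.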

Benchmark

RTX 4090, kanade-25hz-clean, torch.inference_mode(), best of 3:

| Audio duration | Before (s) | After (s) | Speedup |
|---|---|---|---|
| 3.4s (170 frames) | 0.118 | 0.028 | 4x |
| 3.8s (190 frames) | 2.525 | 0.021 | 120x |
| 5.0s (250 frames) | 2.312 | 0.021 | 110x |
| 8.0s (400 frames) | 1.986 | 0.013 | 153x |
| 10.0s (500 frames) | 2.206 | 0.018 | 123x |
| 15.0s (750 frames) | 2.174 | 0.015 | 145x |

🤖 Generated with Claude Code

Monkey-patch WavLMSelfAttention.compute_bias to create tensors
directly on the model device. The upstream torchaudio implementation
creates them on CPU, causing a ~100x slowdown for sequences longer
than ~180 frames (~3.6s audio) due to CPU-GPU synchronization.

Applied at module import time so all WavLM-based models benefit
automatically.

Upstream fix: pytorch/audio#4176

Co-Authored-By: Claude Opus 4.6 <[email protected]>
