Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions acestep/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch
import numpy as np
import torchaudio
import soundfile as sf
from loguru import logger


Expand Down Expand Up @@ -159,12 +160,15 @@ def _save_mp3(
temp_wav_path = Path(temp_wav.name)

try:
torchaudio.save(
str(temp_wav_path),
tensor_to_save,
int(target_sample_rate),
channels_first=True,
backend='soundfile',
# Write WAV directly via soundfile to avoid torchaudio's torchcodec
# dispatch, which fails on environments with incompatible FFmpeg
# shared libraries (e.g. Colab with PyTorch 2.10+cu128).
audio_np = tensor_to_save.numpy()
if audio_np.ndim == 2:
audio_np = audio_np.T # (C, N) -> (N, C) for soundfile
sf.write(
str(temp_wav_path), audio_np, int(target_sample_rate),
subtype="PCM_16",
)
cmd = [
'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error',
Expand Down
34 changes: 26 additions & 8 deletions acestep/audio_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ def test__save_mp3_uses_default_settings_when_not_overridden(self):
output_path = Path(self.temp_dir) / "test.mp3"

with (
patch('acestep.audio_utils.torchaudio.save') as mock_torchaudio_save,
patch('acestep.audio_utils.sf.write') as mock_sf_write,
patch('acestep.audio_utils.subprocess.run') as mock_subprocess_run,
):
saver._save_mp3(self.sample_audio, output_path, self.sample_rate)

mock_torchaudio_save.assert_called_once()
save_args = mock_torchaudio_save.call_args[0]
self.assertEqual(save_args[2], 48000)
mock_sf_write.assert_called_once()
sf_args = mock_sf_write.call_args
self.assertEqual(sf_args[0][2], 48000) # sample_rate arg

cmd = mock_subprocess_run.call_args[0][0]
self.assertIn('libmp3lame', cmd)
Expand All @@ -153,7 +153,7 @@ def test__save_mp3_uses_custom_bitrate_and_sample_rate(self):

with (
patch('acestep.audio_utils.torchaudio.functional.resample', return_value=self.sample_audio) as mock_resample,
patch('acestep.audio_utils.torchaudio.save') as mock_torchaudio_save,
patch('acestep.audio_utils.sf.write') as mock_sf_write,
patch('acestep.audio_utils.subprocess.run') as mock_subprocess_run,
):
saver._save_mp3(
Expand All @@ -165,14 +165,32 @@ def test__save_mp3_uses_custom_bitrate_and_sample_rate(self):
)

mock_resample.assert_called_once_with(self.sample_audio, 48000, 44100)
mock_torchaudio_save.assert_called_once()
save_args = mock_torchaudio_save.call_args[0]
self.assertEqual(save_args[2], 44100)
mock_sf_write.assert_called_once()
sf_args = mock_sf_write.call_args
self.assertEqual(sf_args[0][2], 44100) # sample_rate arg

cmd = mock_subprocess_run.call_args[0][0]
self.assertIn('320k', cmd)
self.assertIn('44100', cmd)

def test__save_mp3_does_not_call_torchaudio_save(self):
    """Guard against regressing to ``torchaudio.save`` inside ``_save_mp3``.

    Background for the regression: on PyTorch 2.10+, ``torchaudio.save()``
    is dispatched through torchcodec regardless of ``backend='soundfile'``,
    which crashes where the FFmpeg shared libraries are incompatible
    (e.g. Colab). The MP3 path therefore has to bypass it entirely.
    """
    audio_saver = AudioSaver()
    mp3_path = Path(self.temp_dir, "test_no_torchcodec.mp3")

    # Stub out the WAV writer and the ffmpeg subprocess so nothing touches
    # the real encoder; only the torchaudio.save mock is inspected.
    with patch('acestep.audio_utils.sf.write'):
        with patch('acestep.audio_utils.subprocess.run'):
            with patch('acestep.audio_utils.torchaudio.save') as torchaudio_save_mock:
                audio_saver._save_mp3(self.sample_audio, mp3_path, self.sample_rate)

    # The mock keeps its call record after the patch is undone.
    torchaudio_save_mock.assert_not_called()

def test_save_audio_opus_uses_ffmpeg_backend(self):
"""Opus format should use ffmpeg backend like MP3."""
saver = AudioSaver()
Expand Down