diff --git a/acestep/audio_utils.py b/acestep/audio_utils.py index 0cdc607ad..281a4b566 100644 --- a/acestep/audio_utils.py +++ b/acestep/audio_utils.py @@ -19,6 +19,7 @@ import torch import numpy as np import torchaudio +import soundfile as sf from loguru import logger @@ -159,12 +160,15 @@ def _save_mp3( temp_wav_path = Path(temp_wav.name) try: - torchaudio.save( - str(temp_wav_path), - tensor_to_save, - int(target_sample_rate), - channels_first=True, - backend='soundfile', + # Write WAV directly via soundfile to avoid torchaudio's torchcodec + # dispatch, which fails on environments with incompatible FFmpeg + # shared libraries (e.g. Colab with PyTorch 2.10+cu128). + audio_np = tensor_to_save.numpy() + if audio_np.ndim == 2: + audio_np = audio_np.T # (C, N) -> (N, C) for soundfile + sf.write( + str(temp_wav_path), audio_np, int(target_sample_rate), + subtype="PCM_16", ) cmd = [ 'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error', diff --git a/acestep/audio_utils_test.py b/acestep/audio_utils_test.py index 627e4358e..84799e657 100644 --- a/acestep/audio_utils_test.py +++ b/acestep/audio_utils_test.py @@ -131,14 +131,14 @@ def test__save_mp3_uses_default_settings_when_not_overridden(self): output_path = Path(self.temp_dir) / "test.mp3" with ( - patch('acestep.audio_utils.torchaudio.save') as mock_torchaudio_save, + patch('acestep.audio_utils.sf.write') as mock_sf_write, patch('acestep.audio_utils.subprocess.run') as mock_subprocess_run, ): saver._save_mp3(self.sample_audio, output_path, self.sample_rate) - mock_torchaudio_save.assert_called_once() - save_args = mock_torchaudio_save.call_args[0] - self.assertEqual(save_args[2], 48000) + mock_sf_write.assert_called_once() + sf_args = mock_sf_write.call_args + self.assertEqual(sf_args[0][2], 48000) # sample_rate arg cmd = mock_subprocess_run.call_args[0][0] self.assertIn('libmp3lame', cmd) @@ -153,7 +153,7 @@ def test__save_mp3_uses_custom_bitrate_and_sample_rate(self): with ( patch('acestep.audio_utils.torchaudio.functional.resample', return_value=self.sample_audio) as mock_resample, - patch('acestep.audio_utils.torchaudio.save') as mock_torchaudio_save, + patch('acestep.audio_utils.sf.write') as mock_sf_write, patch('acestep.audio_utils.subprocess.run') as mock_subprocess_run, ): saver._save_mp3( @@ -165,14 +165,32 @@ def test__save_mp3_uses_custom_bitrate_and_sample_rate(self): ) mock_resample.assert_called_once_with(self.sample_audio, 48000, 44100) - mock_torchaudio_save.assert_called_once() - save_args = mock_torchaudio_save.call_args[0] - self.assertEqual(save_args[2], 44100) + mock_sf_write.assert_called_once() + sf_args = mock_sf_write.call_args + self.assertEqual(sf_args[0][2], 44100) # sample_rate arg cmd = mock_subprocess_run.call_args[0][0] self.assertIn('320k', cmd) self.assertIn('44100', cmd) + def test__save_mp3_does_not_call_torchaudio_save(self): + """Regression: _save_mp3 must not call torchaudio.save (torchcodec bypass). + + On PyTorch 2.10+, torchaudio.save() unconditionally routes through + torchcodec even with backend='soundfile', crashing on environments + with incompatible FFmpeg shared libraries (e.g. Colab). + """ + saver = AudioSaver() + output_path = Path(self.temp_dir) / "test_no_torchcodec.mp3" + + with ( + patch('acestep.audio_utils.sf.write'), + patch('acestep.audio_utils.subprocess.run'), + patch('acestep.audio_utils.torchaudio.save') as mock_ta_save, + ): + saver._save_mp3(self.sample_audio, output_path, self.sample_rate) + mock_ta_save.assert_not_called() + def test_save_audio_opus_uses_ffmpeg_backend(self): """Opus format should use ffmpeg backend like MP3.""" saver = AudioSaver()