Skip to content

Commit bf2cb7e

Browse files
committed
修复不同多音频格式引发的兼容性问题。
1 parent f09c6c7 commit bf2cb7e

5 files changed

Lines changed: 86 additions & 27 deletions

File tree

main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ def main():
2626

2727
if __name__ == '__main__':
2828
sys.exit(main())
29+

tasks/audio_generation_task.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pathlib import Path
55
import shutil
66
import wave
7+
from mutagen import File as MutagenFile
8+
from mutagen.mp3 import MP3
79

810
from PySide6.QtCore import QThread, Signal
911

@@ -37,35 +39,58 @@ def _safe_copy(src: Path, dst: Path):
3739

3840
def _save_one_note_wav(self, index, note_dict, generation_profile):
3941
"""保存单条讲稿并返回索引、音频时长和是否命中缓存的标志"""
40-
path = self.reviewer_page.wav_temp_path / f'{note_dict["page"]}_{index + 1}.wav'
42+
output_ext = self.tts_engine.get_output_extension()
43+
path = self.reviewer_page.wav_temp_path / f'{note_dict["page"]}_{index + 1}.{output_ext}'
4144
cache_key = self.tts_engine.build_audio_cache_key(note_dict['text'], generation_profile)
42-
cache_path = self.audio_cache_path / f'{cache_key}.wav'
45+
cache_path = self.audio_cache_path / f'{cache_key}.{output_ext}'
4346

4447
cache_hit = False
4548
if cache_path.exists() and cache_path.stat().st_size > 0:
4649
self._safe_copy(cache_path, path)
4750
cache_hit = True
4851
else:
49-
temp_path = self.audio_cache_path / f'{cache_key}.{index}.tmp.wav'
52+
temp_path = self.audio_cache_path / f'{cache_key}.{index}.tmp.{output_ext}'
5053
self.tts_engine.save_file(note_dict['text'], str(temp_path))
5154
temp_path.replace(cache_path)
5255
self._safe_copy(cache_path, path)
5356

54-
duration = self.get_wav_duration(path)
55-
return index, duration, cache_key, cache_hit
57+
duration = self.get_audio_duration(path)
58+
return index, duration, cache_key, output_ext, cache_hit
5659

5760
@staticmethod
58-
def get_wav_duration(path: Path) -> float:
59-
"""使用 wave 模块读取 wav 时长"""
60-
with wave.open(str(path), 'rb') as wav_file:
61-
return wav_file.getnframes() / float(wav_file.getframerate())
61+
def get_audio_duration(path: Path) -> float:
62+
"""读取音频时长(支持 wav/mp3)"""
63+
suffix = path.suffix.lower()
64+
if suffix == '.wav':
65+
try:
66+
with wave.open(str(path), 'rb') as wav_file:
67+
return wav_file.getnframes() / float(wav_file.getframerate())
68+
except Exception:
69+
pass
70+
71+
if suffix == '.mp3':
72+
audio = MP3(str(path))
73+
if getattr(audio, 'info', None):
74+
length = float(getattr(audio.info, 'length', 0.0))
75+
if length > 0:
76+
return length
77+
78+
# 对其他格式尝试 mutagen
79+
audio = MutagenFile(str(path))
80+
if getattr(audio, 'info', None):
81+
length = float(getattr(audio.info, 'length', 0.0))
82+
if length > 0:
83+
return length
84+
85+
raise RuntimeError(f'无法读取音频时长:{path.name}')
6286

6387
def save_wav(self):
6488
"""调用 TTS 保存文字为 wav"""
6589
notes_list = self.reviewer_page.notes_list
6690
total = len(notes_list)
6791
info_list = [0.0] * total
6892
cache_key_list = [''] * total
93+
cache_ext_list = [''] * total
6994
cache_hit_count = 0
7095
generation_profile = self.tts_engine.get_generation_profile()
7196

@@ -80,25 +105,28 @@ def save_wav(self):
80105

81106
for future in as_completed(future_map):
82107
index = future_map[future]
83-
result_index, duration, cache_key, cache_hit = future.result()
108+
result_index, duration, cache_key, cache_ext, cache_hit = future.result()
84109
info_list[result_index] = duration
85110
cache_key_list[result_index] = cache_key
111+
cache_ext_list[result_index] = cache_ext
86112
if cache_hit:
87113
cache_hit_count += 1
88114

89115
completed += 1
90116
self.signal_import_index.emit(completed)
91117
else:
92118
for index, note_dict in enumerate(notes_list):
93-
result_index, duration, cache_key, cache_hit = self._save_one_note_wav(index, note_dict, generation_profile)
119+
result_index, duration, cache_key, cache_ext, cache_hit = self._save_one_note_wav(index, note_dict, generation_profile)
94120
info_list[result_index] = duration
95121
cache_key_list[result_index] = cache_key
122+
cache_ext_list[result_index] = cache_ext
96123
if cache_hit:
97124
cache_hit_count += 1
98125
self.signal_import_index.emit(index + 1)
99126

100127
self.reviewer_page.notes_duration_list = info_list
101128
self.reviewer_page.note_cache_keys = cache_key_list
129+
self.reviewer_page.note_cache_exts = cache_ext_list
102130
self.signal_cache_hit_count.emit(cache_hit_count)
103131

104132
def save_countdown_wav(self):

tts_engine.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,13 @@ def get_generation_profile(self):
191191
'voice_index': int(self._voice_index_map.get(mode, 0)),
192192
}
193193

194+
def get_output_extension(self, mode: Optional[str] = None) -> str:
195+
"""返回指定引擎的默认输出后缀(不带点)"""
196+
target_mode = mode or self.get_mode()
197+
if target_mode == 'edge':
198+
return 'mp3'
199+
return 'wav'
200+
194201
@staticmethod
195202
def normalize_text_for_cache(text: str) -> str:
196203
"""最小化规整文本,降低空白差异导致的重复生成"""

ui/pages/reviewer_page.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import json
44
import re
55
from datetime import datetime
6-
import wave
76
from pathlib import Path
87

98
import pyautogui
@@ -63,6 +62,7 @@ def __init__(self, context: AppContext, parent=None):
6362
self.notes_list = [] # 每块讲稿
6463
self.notes_duration_list = []
6564
self.note_cache_keys = []
65+
self.note_cache_exts = []
6666
self.cache_hit_count = 0 # 新增:缓存命中数
6767
self.is_play_notes = False
6868
self.is_import = False
@@ -317,6 +317,7 @@ def clean_and_reset(self):
317317
self.current_index = 0
318318
self.notes_duration_list = []
319319
self.note_cache_keys = []
320+
self.note_cache_exts = []
320321
self.cache_hit_count = 0
321322
self.is_import = False
322323
self.check_import()
@@ -327,8 +328,7 @@ def refresh_notes_duration_list(self):
327328
self.load_audio_files()
328329
duration_list = []
329330
for path in self.media_list:
330-
with wave.open(str(path), 'rb') as wav_file:
331-
duration = wav_file.getnframes() / float(wav_file.getframerate())
331+
duration = AudioGenerationTask.get_audio_duration(path)
332332
duration_list.append(duration)
333333
self.notes_duration_list = duration_list
334334

@@ -397,13 +397,14 @@ def mark_split(self):
397397

398398
@staticmethod
399399
def clean_temp_folder(path: Path):
400-
"""清理缓存 wav"""
401-
for file_path in path.glob('*.wav'):
402-
try:
403-
file_path.unlink()
404-
print(f'已清理 {file_path.name}')
405-
except Exception as e:
406-
print(f'清理文件失败: {file_path.name}, 原因: {e}')
400+
"""清理临时音频(wav/mp3)"""
401+
for pattern in ('*.wav', '*.mp3'):
402+
for file_path in path.glob(pattern):
403+
try:
404+
file_path.unlink()
405+
print(f'已清理 {file_path.name}')
406+
except Exception as e:
407+
print(f'清理文件失败: {file_path.name}, 原因: {e}')
407408
print('转换完成')
408409

409410
def thread_print_index(self, import_index):
@@ -449,6 +450,10 @@ def save_session_record(self):
449450
for item in self.notes_list
450451
]
451452

453+
if len(self.note_cache_exts) != len(self.notes_list):
454+
output_ext = self.ctx.tts_engine.get_output_extension()
455+
self.note_cache_exts = [output_ext] * len(self.notes_list)
456+
452457
durations = self.notes_duration_list[:]
453458
if len(durations) != len(self.notes_list):
454459
durations = [0.0] * len(self.notes_list)
@@ -461,6 +466,7 @@ def save_session_record(self):
461466
'text': note['text'],
462467
'duration': float(durations[index]),
463468
'cache_key': self.note_cache_keys[index],
469+
'cache_ext': self.note_cache_exts[index],
464470
})
465471

466472
now = datetime.now()
@@ -531,6 +537,7 @@ def load_session_record(self, record_path: Path):
531537
notes_list = []
532538
duration_list = []
533539
cache_keys = []
540+
cache_exts = []
534541

535542
for idx, item in enumerate(items):
536543
page = int(item.get('page', 0))
@@ -540,16 +547,29 @@ def load_session_record(self, record_path: Path):
540547
profile = record.get('generation_profile', {})
541548
cache_key = self.ctx.tts_engine.build_audio_cache_key(text, profile)
542549

543-
cache_path = self.audio_cache_path / f'{cache_key}.wav'
544-
545-
if not cache_path.exists() or cache_path.stat().st_size <= 0:
550+
cache_ext = str(item.get('cache_ext', '')).strip().lower().lstrip('.')
551+
ext_candidates = [cache_ext] if cache_ext else []
552+
for ext in ('wav', 'mp3'):
553+
if ext not in ext_candidates:
554+
ext_candidates.append(ext)
555+
556+
cache_path = None
557+
for ext in ext_candidates:
558+
candidate = self.audio_cache_path / f'{cache_key}.{ext}'
559+
if candidate.exists() and candidate.stat().st_size > 0:
560+
cache_path = candidate
561+
cache_ext = ext
562+
break
563+
564+
if cache_path is None:
546565
missing_list.append(f'第{page}页-第{idx + 1}条')
547566
continue
548567

549568
media_list.append(cache_path)
550569
notes_list.append({'page': page, 'text': text})
551570
duration_list.append(float(item.get('duration', 0.0)))
552571
cache_keys.append(cache_key)
572+
cache_exts.append(cache_ext)
553573

554574
if missing_list:
555575
missing_text = '、'.join(missing_list[:10])
@@ -572,6 +592,7 @@ def load_session_record(self, record_path: Path):
572592
self.notes_list = notes_list
573593
self.notes_duration_list = duration_list
574594
self.note_cache_keys = cache_keys
595+
self.note_cache_exts = cache_exts
575596

576597
self.media_list = media_list
577598
self.current_index = 0
@@ -608,9 +629,10 @@ def play_notes(self):
608629
self.play_audio()
609630

610631
def load_audio_files(self):
611-
"""查找所有 wav,添加到 media_list 中"""
632+
"""查找所有正文音频(wav/mp3),添加到 media_list 中"""
633+
audio_files = list(self.wav_temp_path.glob('*.wav')) + list(self.wav_temp_path.glob('*.mp3'))
612634
audio_files = sorted(
613-
self.wav_temp_path.glob('*.wav'),
635+
audio_files,
614636
key=lambda path: [int(part) if part.isdigit() else part for part in path.stem.split('_')]
615637
)
616638
self.media_list = audio_files

ui/pages/settings_page.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def preview_audio(self):
125125
text = '这是一个试听音频,用于测试当前的语音设置'
126126
temp_dir = Path('./temp').resolve()
127127
temp_dir.mkdir(parents=True, exist_ok=True)
128-
preview_path = temp_dir / f'preview_{int(time.time())}.wav'
128+
preview_ext = self.ctx.tts_engine.get_output_extension()
129+
preview_path = temp_dir / f'preview_{int(time.time())}.{preview_ext}'
129130

130131
self.previewButton.setEnabled(False)
131132
self.previewButton.setText('正在生成...')

0 commit comments

Comments
 (0)