
Commit f3ed312

jamiepine and claude committed
fix(offline): guard inference paths with HF_HUB_OFFLINE (#462)
PR #443 wrapped the model *load* path with `force_offline_if_cached` so cached models don't phone home at startup. The context manager restores `HF_HUB_OFFLINE` on exit, which left inference paths (generate, transcribe, voice-prompt creation) unguarded, and `qwen_tts`, `mlx_audio`, and `transformers` perform lazy tokenizer/processor/config lookups during inference. With internet on, those lookups are near-instant and invisible; with internet off, `requests` hangs on DNS or connect until the network returns.

This is exactly what users in #462 describe: model shows "Loaded", internet drops, generation "thinks" forever, internet comes back, generation completes. Chatterbox and LuxTTS don't exhibit this because their engine libs resolve everything through already-cached paths at load time.

Fix: wrap each inference-sync body with `force_offline_if_cached(True, ...)`. Since inference only runs after a successful load, weights are known to be on disk, so `is_cached=True` is unconditional. Also adds the load-time guard that was missing from `qwen_custom_voice_backend.py`; CustomVoice previously had no offline protection at all.

Paths patched:

- PyTorchTTSBackend.create_voice_prompt (create_voice_clone_prompt)
- PyTorchTTSBackend.generate (generate_voice_clone)
- PyTorchSTTBackend.transcribe (Whisper generate + decoder-prompt-ids)
- MLXTTSBackend.generate (mlx_audio generate, all branches)
- MLXSTTBackend.transcribe (mlx_audio whisper generate)
- QwenCustomVoiceBackend._load_model_sync + generate

Does not address the secondary `check_model_inputs() missing 'func'` error reported in the same issue; that's a `transformers` 5.x version-skew bug on the install path, separate concern.

Fixes #462.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
1 parent e3f7cd9 commit f3ed312

File tree

3 files changed: +111 -75 lines changed

- backend/backends/mlx_backend.py
- backend/backends/pytorch_backend.py
- backend/backends/qwen_custom_voice_backend.py
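
All three diffs below lean on one helper, `force_offline_if_cached`, imported from `backend/utils/hf_offline_patch.py` (see the import added in `qwen_custom_voice_backend.py`). Its implementation is not part of this commit; going only by the behavior the commit message describes (set `HF_HUB_OFFLINE` for the duration of the block when the model is already cached, restore the previous value on exit), a minimal sketch could look like the following. The real helper likely does more, for example reporting progress under `model_name` and re-syncing flags that `huggingface_hub` may have read into constants at import time:

import os
from contextlib import contextmanager

@contextmanager
def force_offline_if_cached(is_cached: bool, model_name: str):
    """Force HF_HUB_OFFLINE=1 around the body when the model is already
    on disk, so lazy hub lookups fail fast instead of hanging."""
    # model_name is assumed to be used only for logging/progress reporting.
    if not is_cached:
        # Not cached yet: downloads must stay possible.
        yield
        return
    previous = os.environ.get("HF_HUB_OFFLINE")
    os.environ["HF_HUB_OFFLINE"] = "1"
    try:
        yield
    finally:
        # Restore whatever the caller had before (set or unset).
        if previous is None:
            os.environ.pop("HF_HUB_OFFLINE", None)
        else:
            os.environ["HF_HUB_OFFLINE"] = previous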

backend/backends/mlx_backend.py

Lines changed: 36 additions & 24 deletions

@@ -195,6 +195,8 @@ async def generate(
 
         logger.info("Generating audio for text: %s", text)
 
+        model_name = f"qwen-tts-{self._current_model_size}"
+
         def _generate_sync():
             """Run synchronous generation in thread pool."""
             # MLX generate() returns a generator yielding GenerationResult objects
@@ -220,36 +222,40 @@ def _generate_sync():
                 logger.warning("Regenerating without voice prompt.")
                 ref_audio = None
 
-            # Check if model supports voice cloning via generate method
-            # MLX API may support ref_audio parameter directly
-            try:
-                # Try with voice cloning parameters if supported
-                if ref_audio:
-                    # Check if generate accepts ref_audio parameter
-                    import inspect
-
-                    sig = inspect.signature(self.model.generate)
-                    if "ref_audio" in sig.parameters:
-                        # Generate with voice cloning
-                        for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
-                            audio_chunks.append(np.array(result.audio))
-                            sample_rate = result.sample_rate
+            # Model is loaded → weights are on disk. Force offline so
+            # lazy tokenizer/config lookups inside mlx_audio don't hang
+            # when the user is disconnected (issue #462).
+            with force_offline_if_cached(True, model_name):
+                # Check if model supports voice cloning via generate method
+                # MLX API may support ref_audio parameter directly
+                try:
+                    # Try with voice cloning parameters if supported
+                    if ref_audio:
+                        # Check if generate accepts ref_audio parameter
+                        import inspect
+
+                        sig = inspect.signature(self.model.generate)
+                        if "ref_audio" in sig.parameters:
+                            # Generate with voice cloning
+                            for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
+                                audio_chunks.append(np.array(result.audio))
+                                sample_rate = result.sample_rate
+                        else:
+                            # Fallback: generate without voice cloning
+                            for result in self.model.generate(text, lang_code=lang):
+                                audio_chunks.append(np.array(result.audio))
+                                sample_rate = result.sample_rate
                     else:
-                        # Fallback: generate without voice cloning
+                        # No voice prompt, generate normally
                         for result in self.model.generate(text, lang_code=lang):
                             audio_chunks.append(np.array(result.audio))
                             sample_rate = result.sample_rate
-                else:
-                    # No voice prompt, generate normally
+            except Exception as e:
+                # If voice cloning fails, try without it
+                logger.warning("Voice cloning failed, generating without voice prompt: %s", e)
                 for result in self.model.generate(text, lang_code=lang):
                     audio_chunks.append(np.array(result.audio))
                     sample_rate = result.sample_rate
-            except Exception as e:
-                # If voice cloning fails, try without it
-                logger.warning("Voice cloning failed, generating without voice prompt: %s", e)
-                for result in self.model.generate(text, lang_code=lang):
-                    audio_chunks.append(np.array(result.audio))
-                    sample_rate = result.sample_rate
 
             # Concatenate all chunks
             if audio_chunks:
@@ -343,6 +349,8 @@ async def transcribe(
         """
        await self.load_model_async(model_size)
 
+        progress_model_name = f"whisper-{self.model_size}"
+
         def _transcribe_sync():
             """Run synchronous transcription in thread pool."""
             # MLX Whisper transcription using generate method
@@ -351,7 +359,11 @@ def _transcribe_sync():
             if language:
                 decode_options["language"] = language
 
-            result = self.model.generate(str(audio_path), **decode_options)
+            # Model is loaded → weights are on disk. Force offline so
+            # lazy tokenizer/config lookups don't hang when the user is
+            # disconnected (issue #462).
+            with force_offline_if_cached(True, progress_model_name):
+                result = self.model.generate(str(audio_path), **decode_options)
 
             # Extract text from result
             if isinstance(result, str):
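
Every call site patched here has the same shape: the guard wraps the blocking body inside the `_*_sync` helper, and the helper runs off the event loop via `asyncio.to_thread`. A distilled version of the pattern, with names taken from the diff above:

def _generate_sync():
    # Inference only runs after a successful load, so the weights are
    # known to be on disk and is_cached=True is safe unconditionally.
    with force_offline_if_cached(True, model_name):
        return self.model.generate(text, lang_code=lang)

result = await asyncio.to_thread(_generate_sync)

Putting the `with` inside the sync function keeps the process-global env-var flip open only while the blocking calls actually run; note that `os.environ` is shared across threads, so a concurrent task could still observe the flag during generation, and the tight scope just minimizes that window.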

backend/backends/pytorch_backend.py

Lines changed: 55 additions & 38 deletions

@@ -172,13 +172,19 @@ async def create_voice_prompt(
             # This shouldn't happen in practice, but handle it
             return {"prompt": cached_prompt}, True
 
+        model_name = f"qwen-tts-{self._current_model_size}"
+
         def _create_prompt_sync():
             """Run synchronous voice prompt creation in thread pool."""
-            return self.model.create_voice_clone_prompt(
-                ref_audio=str(audio_path),
-                ref_text=reference_text,
-                x_vector_only_mode=False,
-            )
+            # Model is loaded → weights are on disk. Force offline so
+            # lazy tokenizer/config lookups inside qwen_tts don't hang
+            # when the user is disconnected (issue #462).
+            with force_offline_if_cached(True, model_name):
+                return self.model.create_voice_clone_prompt(
+                    ref_audio=str(audio_path),
+                    ref_text=reference_text,
+                    x_vector_only_mode=False,
+                )
 
         # Run blocking operation in thread pool
         voice_prompt_items = await asyncio.to_thread(_create_prompt_sync)
@@ -221,19 +227,24 @@ async def generate(
         # Load model
         await self.load_model_async(None)
 
+        model_name = f"qwen-tts-{self._current_model_size}"
+
         def _generate_sync():
             """Run synchronous generation in thread pool."""
             # Set seed if provided
             if seed is not None:
                 manual_seed(seed, self.device)
 
-            # Generate audio - this is the blocking operation
-            wavs, sample_rate = self.model.generate_voice_clone(
-                text=text,
-                voice_clone_prompt=voice_prompt,
-                language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
-                instruct=instruct,
-            )
+            # Model is loaded → weights are on disk. Force offline so
+            # lazy tokenizer/config lookups inside qwen_tts don't hang
+            # when the user is disconnected (issue #462).
+            with force_offline_if_cached(True, model_name):
+                wavs, sample_rate = self.model.generate_voice_clone(
+                    text=text,
+                    voice_clone_prompt=voice_prompt,
+                    language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
+                    instruct=instruct,
+                )
             return wavs[0], sample_rate
 
         # Run blocking inference in thread pool to avoid blocking event loop
@@ -331,40 +342,46 @@ async def transcribe(
         """
         await self.load_model_async(model_size)
 
+        progress_model_name = f"whisper-{self.model_size}"
+
         def _transcribe_sync():
             """Run synchronous transcription in thread pool."""
             # Load audio
             audio, sr = load_audio(audio_path, sample_rate=16000)
 
-            # Process audio
-            inputs = self.processor(
-                audio,
-                sampling_rate=16000,
-                return_tensors="pt",
-            )
-            inputs = inputs.to(self.device)
-
-            # Generate transcription
-            # If language is provided, force it; otherwise let Whisper auto-detect
-            generate_kwargs = {}
-            if language:
-                forced_decoder_ids = self.processor.get_decoder_prompt_ids(
-                    language=language,
-                    task="transcribe",
+            # Model is loaded → weights are on disk. Force offline so
+            # `get_decoder_prompt_ids` and any lazy tokenizer lookups
+            # don't hang when the user is disconnected (issue #462).
+            with force_offline_if_cached(True, progress_model_name):
+                # Process audio
+                inputs = self.processor(
+                    audio,
+                    sampling_rate=16000,
+                    return_tensors="pt",
                 )
-                generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
+                inputs = inputs.to(self.device)
+
+                # Generate transcription
+                # If language is provided, force it; otherwise let Whisper auto-detect
+                generate_kwargs = {}
+                if language:
+                    forced_decoder_ids = self.processor.get_decoder_prompt_ids(
+                        language=language,
+                        task="transcribe",
+                    )
+                    generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
 
-            with torch.no_grad():
-                predicted_ids = self.model.generate(
-                    inputs["input_features"],
-                    **generate_kwargs,
-                )
+                with torch.no_grad():
+                    predicted_ids = self.model.generate(
+                        inputs["input_features"],
+                        **generate_kwargs,
+                    )
 
-            # Decode
-            transcription = self.processor.batch_decode(
-                predicted_ids,
-                skip_special_tokens=True,
-            )[0]
+                # Decode
+                transcription = self.processor.batch_decode(
+                    predicted_ids,
+                    skip_special_tokens=True,
+                )[0]
 
             return transcription.strip()
 
backend/backends/qwen_custom_voice_backend.py

Lines changed: 20 additions & 13 deletions

@@ -28,6 +28,7 @@
     combine_voice_prompts as _combine_voice_prompts,
     model_load_progress,
 )
+from ..utils.hf_offline_patch import force_offline_if_cached
 
 logger = logging.getLogger(__name__)
 
@@ -104,18 +105,19 @@ def _load_model_sync(self, model_size: str) -> None:
         model_path = self._get_model_path(model_size)
         logger.info("Loading Qwen CustomVoice %s on %s...", model_size, self.device)
 
-        if self.device == "cpu":
-            self.model = Qwen3TTSModel.from_pretrained(
-                model_path,
-                torch_dtype=torch.float32,
-                low_cpu_mem_usage=False,
-            )
-        else:
-            self.model = Qwen3TTSModel.from_pretrained(
-                model_path,
-                device_map=self.device,
-                torch_dtype=torch.bfloat16,
-            )
+        with force_offline_if_cached(is_cached, model_name):
+            if self.device == "cpu":
+                self.model = Qwen3TTSModel.from_pretrained(
+                    model_path,
+                    torch_dtype=torch.float32,
+                    low_cpu_mem_usage=False,
+                )
+            else:
+                self.model = Qwen3TTSModel.from_pretrained(
+                    model_path,
+                    device_map=self.device,
+                    torch_dtype=torch.bfloat16,
+                )
 
         self._current_model_size = model_size
         self.model_size = model_size
@@ -184,6 +186,7 @@ async def generate(
         await self.load_model_async(None)
 
         speaker = voice_prompt.get("preset_voice_id") or QWEN_CV_DEFAULT_SPEAKER
+        model_name = f"qwen-custom-voice-{self._current_model_size}"
 
         def _generate_sync():
             if seed is not None:
@@ -203,7 +206,11 @@ def _generate_sync():
             if instruct:
                 kwargs["instruct"] = instruct
 
-            wavs, sample_rate = self.model.generate_custom_voice(**kwargs)
+            # Model is loaded → weights are on disk. Force offline so
+            # lazy tokenizer/config lookups inside qwen_tts don't hang
+            # when the user is disconnected (issue #462).
+            with force_offline_if_cached(True, model_name):
+                wavs, sample_rate = self.model.generate_custom_voice(**kwargs)
             return wavs[0], sample_rate
 
         audio, sample_rate = await asyncio.to_thread(_generate_sync)
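
Unlike the inference paths, the load path cannot assume the weights exist, so `_load_model_sync` passes a real `is_cached` value; the hunk doesn't show where `is_cached` and `model_name` come from (they are defined earlier in the method). Purely as illustration, a hypothetical cache check might look like this:

import os

def _looks_cached(model_path: str) -> bool:
    # Hypothetical helper, not from this repo: treat the model as cached
    # when its local directory already holds at least one weight file.
    if not os.path.isdir(model_path):
        return False
    return any(
        name.endswith((".safetensors", ".bin"))
        for name in os.listdir(model_path)
    )

is_cached = _looks_cached(model_path)
model_name = f"qwen-custom-voice-{model_size}"  # mirrors the naming used in generate()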
