Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 106 additions & 22 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,24 @@ def __init__(self, default_runtime: NanoTTSService) -> None:
self._lock = threading.Lock()
self._cpu_execution_lock = threading.Lock()
self._cpu_runtime: NanoTTSService | None = None
self._cuda_runtimes: dict[str, NanoTTSService] = {}

@staticmethod
def normalize_requested_execution_device(requested: str | None) -> str:
normalized = str(requested or "default").strip().lower()
if normalized not in {"default", "cpu"}:
return "default"
return normalized
if normalized in {"default", "cpu"}:
return normalized
# 允许 "cuda" 或 "cuda:N" 格式
if normalized == "cuda" or normalized.startswith("cuda:"):
parts = normalized.split(":", 1)
if len(parts) == 2:
try:
int(parts[1])
return normalized
except ValueError:
return "default"
return "cuda"
return "default"

def is_dedicated_cpu_request(self, requested: str | None) -> bool:
normalized = self.normalize_requested_execution_device(requested)
Expand All @@ -282,14 +293,37 @@ def _build_cpu_runtime_locked(self) -> NanoTTSService:
)
return self._cpu_runtime

def _build_cuda_runtime_locked(self, device: str) -> NanoTTSService:
    """Return the cached runtime for *device*, building it on first use.

    Caller must hold ``self._lock``; the cache itself is a plain dict.  The
    new runtime inherits every configuration value from the default runtime
    except the target device.
    """
    cached = self._cuda_runtimes.get(device)
    if cached is not None:
        return cached
    runtime = NanoTTSService(
        checkpoint_path=self.default_runtime.checkpoint_path,
        audio_tokenizer_path=self.default_runtime.audio_tokenizer_path,
        device=device,
        dtype=self.default_runtime.dtype or "auto",
        attn_implementation=self.default_runtime.attn_implementation or "auto",
        output_dir=self.default_runtime.output_dir,
        voice_presets=self.default_runtime.voice_presets,
    )
    self._cuda_runtimes[device] = runtime
    return runtime

def resolve_runtime(self, requested: str | None) -> tuple[NanoTTSService, str]:
    """Resolve the runtime to use for a request.

    Returns ``(runtime, device_label)`` where *device_label* is the string
    reported back to the client.  ``"default"`` always maps to the default
    runtime; ``"cpu"`` and ``"cuda[:N]"`` requests lazily build (and cache)
    a dedicated runtime under ``self._lock`` when the default runtime does
    not already satisfy the request.
    """
    normalized = self.normalize_requested_execution_device(requested)
    if normalized == "default":
        return self.default_runtime, str(self.default_runtime.device.type)
    if normalized == "cpu":
        if self.default_runtime.device.type == "cpu":
            return self.default_runtime, "cpu"
        with self._lock:
            return self._build_cpu_runtime_locked(), "cpu"
    if normalized.startswith("cuda"):
        if self.default_runtime.device.type == "cuda":
            default_device = str(self.default_runtime.device)
            # Reuse the default runtime only for a generic "cuda" request or
            # an exact device match; a mismatched "cuda:N" must get its own
            # runtime instead of silently running on the default GPU.
            if normalized in ("cuda", default_device):
                return self.default_runtime, default_device
        with self._lock:
            return self._build_cuda_runtime_locked(normalized), normalized
    # Defensive fallback: normalize() only emits the cases handled above.
    return self.default_runtime, str(self.default_runtime.device.type)

def _resolve_cpu_threads(self, cpu_threads: int | None) -> int:
if cpu_threads is None:
Expand Down Expand Up @@ -657,13 +691,28 @@ async def _persist_uploaded_prompt_audio(upload: UploadFile | None) -> tuple[str
return temp_path, _format_uploaded_prompt_display_name(original_filename)


def _build_cuda_options_html(cuda_available: bool, runtime_device: str) -> str:
if not cuda_available:
return ""
import torch
parts = []
for i in range(torch.cuda.device_count()):
device_name = f"cuda:{i}"
label = f"CUDA:{i}"
if runtime_device == device_name:
label += " (runtime)"
parts.append(f' <option value="{device_name}">{label}</option>')
return "\n".join(parts)


def _render_index_html(
*,
request: Request,
runtime: NanoTTSService,
demo_entries: list[DemoEntry],
warmup_status: str,
text_normalization_status: str,
cuda_available: bool = False,
) -> str:
base_path = request.scope.get("root_path", "").rstrip("/")
template = """
Expand Down Expand Up @@ -1112,12 +1161,22 @@ def _render_index_html(
Buffered generation keeps chunk order and decodes codec sub-batches no larger than the current TTS batch.
Realtime Streaming Decode keeps output order and uses the smallest active chunk-group width among auto batching, Max TTS Batch Size, and Max Codec Batch Size.
</div>
<div class="field">
<label for="cpu-thread-count">CPU Threads</label>
<input id="cpu-thread-count" type="number" min="1" step="1" value="4">
<div class="row">
<div class="field">
<label for="execution-device">Device</label>
<select id="execution-device">
<option value="default" selected>Default (__RUNTIME_DEVICE__)</option>
<option value="cpu">CPU</option>
__CUDA_OPTIONS__
</select>
</div>
<div class="field">
<label for="cpu-thread-count">CPU Threads</label>
<input id="cpu-thread-count" type="number" min="1" step="1" value="4">
</div>
</div>
<div class="meta">
This app is CPU-only. CPU Threads maps to torch.set_num_threads for that request.
Select inference device. Default uses the runtime device (__RUNTIME_DEVICE__). CPU Threads maps to torch.set_num_threads for CPU requests.
</div>
<div class="row">
<div class="field">
Expand Down Expand Up @@ -1709,6 +1768,7 @@ def _render_index_html(
formData.append("enable_text_normalization", document.getElementById("enable-text-normalization").checked ? "1" : "0");
formData.append("enable_normalize_tts_text", document.getElementById("enable-robust-text-normalization").checked ? "1" : "0");
formData.append("cpu_threads", document.getElementById("cpu-thread-count").value || String(DEFAULT_CPU_THREADS));
formData.append("execution_device", document.getElementById("execution-device").value);
return formData;
}

Expand Down Expand Up @@ -2164,6 +2224,9 @@ def _render_index_html(
"__TEXT_NORMALIZATION_STATUS__": text_normalization_status,
"__CHECKPOINT__": str(runtime.checkpoint_path),
"__AUDIO_TOKENIZER__": str(runtime.audio_tokenizer_path),
"__RUNTIME_DEVICE__": str(runtime.device),
"__CUDA_OPTIONS__": _build_cuda_options_html(cuda_available, str(runtime.device)),
"__CUDA_AVAILABLE__": json.dumps(cuda_available),
}
for placeholder, value in replacements.items():
template = template.replace(placeholder, value)
Expand All @@ -2175,6 +2238,7 @@ def _build_app(
warmup_manager: WarmupManager,
text_normalizer_manager: WeTextProcessingManager | None,
root_path: str | None,
cuda_available: bool = False,
) -> FastAPI:
app = FastAPI(title="MOSS-TTS-Nano Demo", root_path=root_path or "")
stream_jobs = StreamingJobManager()
Expand All @@ -2187,14 +2251,15 @@ def _resolve_voice_clone_text_chunks(
text: str,
voice_clone_max_text_tokens: int,
cpu_threads: int,
execution_device: str = "default",
) -> list[str]:
normalized_text = str(text or "").strip()
if not normalized_text:
return []

try:
chunks, _, _ = runtime_manager.call_with_runtime(
requested_execution_device="cpu",
requested_execution_device=execution_device,
cpu_threads=cpu_threads,
callback=lambda selected_runtime: selected_runtime.split_voice_clone_text(
text=normalized_text,
Expand Down Expand Up @@ -2293,6 +2358,7 @@ def _run_streaming_job(
tts_max_batch_size: int,
codec_max_batch_size: int,
cpu_threads: int,
execution_device: str = "default",
attn_implementation: str,
do_sample: bool,
text_temperature: float,
Expand All @@ -2305,7 +2371,8 @@ def _run_streaming_job(
seed: int | None,
) -> None:
try:
initial_execution_label = "cpu"
_normalized_device = RequestRuntimeManager.normalize_requested_execution_device(execution_device)
initial_execution_label = str(runtime.device) if _normalized_device == "default" else _normalized_device
with job.lock:
job.started_at = time.monotonic()
job.state = "running"
Expand Down Expand Up @@ -2334,7 +2401,7 @@ def _stream_factory(selected_runtime: NanoTTSService):
)

for event, resolved_execution_device, resolved_cpu_threads in runtime_manager.iter_with_runtime(
requested_execution_device="cpu",
requested_execution_device=execution_device,
cpu_threads=cpu_threads,
factory=_stream_factory,
):
Expand Down Expand Up @@ -2425,6 +2492,7 @@ async def index(request: Request):
text_normalization_status=_text_normalization_status_text(
text_normalizer_manager.snapshot() if text_normalizer_manager is not None else None
),
cuda_available=cuda_available,
)
)

Expand All @@ -2433,6 +2501,7 @@ async def health():
return {
"status": "ok",
"device": str(runtime.device),
"cuda_available": cuda_available,
"dtype": str(runtime.dtype),
"cpu_runtime_loaded": runtime_manager.is_cpu_runtime_loaded(),
"default_cpu_threads": runtime_manager.default_cpu_threads,
Expand Down Expand Up @@ -2509,6 +2578,7 @@ async def generate_stream_start(
codec_max_batch_size: int = Form(0),
enable_text_normalization: str = Form("1"),
enable_normalize_tts_text: str = Form("1"),
execution_device: str = Form("default"),
cpu_threads: int = Form(0),
attn_implementation: str = Form("model_default"),
do_sample: str = Form("1"),
Expand Down Expand Up @@ -2559,6 +2629,7 @@ async def generate_stream_start(
text=str(prepared_texts["text"]),
voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
cpu_threads=int(cpu_threads),
execution_device=execution_device,
)
job = stream_jobs.create()
with job.lock:
Expand All @@ -2577,6 +2648,7 @@ async def generate_stream_start(
"tts_max_batch_size": int(tts_max_batch_size),
"codec_max_batch_size": int(codec_max_batch_size),
"cpu_threads": int(cpu_threads),
"execution_device": execution_device,
"attn_implementation": attn_implementation,
"do_sample": _coerce_bool(do_sample, True),
"text_temperature": float(text_temperature),
Expand All @@ -2594,7 +2666,8 @@ async def generate_stream_start(
thread.start()
prompt_audio_cleanup_path = None

initial_execution_label = "cpu"
_normalized_device = RequestRuntimeManager.normalize_requested_execution_device(execution_device)
initial_execution_label = str(runtime.device) if _normalized_device == "default" else _normalized_device

return {
"stream_id": job.stream_id,
Expand Down Expand Up @@ -2720,6 +2793,7 @@ async def generate(
codec_max_batch_size: int = Form(0),
enable_text_normalization: str = Form("1"),
enable_normalize_tts_text: str = Form("1"),
execution_device: str = Form("default"),
cpu_threads: int = Form(0),
attn_implementation: str = Form("model_default"),
do_sample: str = Form("1"),
Expand Down Expand Up @@ -2791,7 +2865,7 @@ def _synthesize(selected_runtime: NanoTTSService):
)

result, resolved_execution_device, resolved_cpu_threads = runtime_manager.call_with_runtime(
requested_execution_device="cpu",
requested_execution_device=execution_device,
cpu_threads=cpu_threads,
callback=_synthesize,
)
Expand All @@ -2809,6 +2883,7 @@ def _synthesize(selected_runtime: NanoTTSService):
text=str(prepared_texts["text"]),
voice_clone_max_text_tokens=int(voice_clone_max_text_tokens),
cpu_threads=int(cpu_threads),
execution_device=execution_device,
)
generated_audio_path = str(result["audio_path"])
wav_bytes = _audio_to_wav_bytes(result["waveform_numpy"], int(result["sample_rate"]))
Expand Down Expand Up @@ -2847,7 +2922,7 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
default=str(DEFAULT_AUDIO_TOKENIZER_PATH),
)
parser.add_argument("--output-dir", "--output_dir", dest="output_dir", type=str, default=str(DEFAULT_OUTPUT_DIR))
parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "auto"])
parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "auto", "cuda"])
parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "float16", "bfloat16"])
parser.add_argument(
"--attn-implementation",
Expand All @@ -2867,9 +2942,18 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
level=logging.INFO,
)

resolved_runtime_device = "cpu"
if args.device != "cpu":
logging.info("CPU-only app mode: ignoring --device=%s and forcing cpu.", args.device)
import torch as _torch
resolved_runtime_device = args.device
if resolved_runtime_device == "auto":
resolved_runtime_device = "cuda" if _torch.cuda.is_available() else "cpu"
logging.info("auto device resolved to: %s", resolved_runtime_device)
elif resolved_runtime_device == "cuda" and not _torch.cuda.is_available():
logging.warning("--device=cuda specified but CUDA is not available, falling back to cpu.")
resolved_runtime_device = "cpu"

cuda_available = _torch.cuda.is_available()
if cuda_available:
logging.info("CUDA available: %s (device count: %d)", _torch.cuda.get_device_name(0), _torch.cuda.device_count())

runtime = NanoTTSService(
checkpoint_path=args.checkpoint_path,
Expand All @@ -2890,7 +2974,7 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
if args.share:
logging.warning("--share is ignored by the FastAPI-based Nano-TTS app.")

app = _build_app(runtime, warmup_manager, text_normalizer_manager, root_path)
app = _build_app(runtime, warmup_manager, text_normalizer_manager, root_path, cuda_available=cuda_available)
uvicorn.run(
app,
host=args.host,
Expand Down