Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions long_script_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
from pathlib import Path
from vibevoice import load_model # Replace with actual VibeVoice import

# -----------------------
# Configuration
# -----------------------
MODEL_NAME = "1.5B" # or your preferred model
SCRIPT_PATH = "long_script.txt" # path to your input script
OUTPUT_PATH = "output_audio.wav"
CHUNK_SIZE = 500 # number of tokens or characters per chunk (adjust for GPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------
# Load model
# -----------------------
print(f"Loading VibeVoice model ({MODEL_NAME})...")
model = load_model(MODEL_NAME).to(DEVICE) # Replace with actual load function
model.eval()

# -----------------------
# Load script
# -----------------------
with open(SCRIPT_PATH, "r", encoding="utf-8") as f:
script_text = f.read()

# -----------------------
# Generate audio in chunks
# -----------------------
def generate_long_script(model, text, chunk_size=CHUNK_SIZE):
"""Split long script and generate audio chunk by chunk."""
# Simple character-based chunking
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
audio_outputs = []

for i, chunk in enumerate(chunks):
print(f"Generating chunk {i+1}/{len(chunks)}...")
# Replace with your actual model.generate() or inference function
audio_chunk = model.generate(chunk) # Example placeholder
audio_outputs.append(audio_chunk.cpu())

# Concatenate audio tensors
final_audio = torch.cat(audio_outputs)
return final_audio

# -----------------------
# Generate and save
# -----------------------
final_audio = generate_long_script(model, script_text)

# Convert tensor to WAV file
import soundfile as sf
sf.write(OUTPUT_PATH, final_audio.numpy(), 22050) # adjust sample rate if needed

print(f"Audio generation complete! Saved to {OUTPUT_PATH}")