43 changes: 40 additions & 3 deletions demo/realtime_model_inference_from_file.py
@@ -121,6 +121,13 @@ def parse_args():
default=1.5,
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)",
)
parser.add_argument(
"--quantization",
type=str,
default="fp16",
choices=["fp16", "8bit", "4bit"],
help="Quantization level: fp16 (default, ~20GB), 8bit (~12GB), or 4bit (~7GB)"
)

return parser.parse_args()

@@ -138,6 +145,14 @@ def main():
args.device = "cpu"

print(f"Using device: {args.device}")

# VRAM Detection and Quantization Info (NEW)
if args.device == "cuda":
available_vram = get_available_vram_gb()
print_vram_info(available_vram, args.model_path, args.quantization)
elif args.quantization != "fp16":
print(f"Warning: Quantization ({args.quantization}) only works with CUDA. Using full precision.")
args.quantization = "fp16"

# Initialize voice mapper
voice_mapper = VoiceMapper()
@@ -172,6 +187,15 @@ def main():
load_dtype = torch.float32
attn_impl_primary = "sdpa"
print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

# Get quantization configuration (NEW)
quant_config = get_quantization_config(args.quantization)

if quant_config:
print(f"Using {args.quantization} quantization...")
else:
print("Using full precision (fp16)...")

# Load model with device-specific logic
try:
if args.device == "mps":
@@ -183,12 +207,25 @@
)
model.to("mps")
elif args.device == "cuda":
# MODIFIED SECTION - Add quantization support
model_kwargs = {
"torch_dtype": load_dtype,
"device_map": "cuda",
"attn_implementation": attn_impl_primary,
}

# Add quantization config if specified
if quant_config:
model_kwargs.update(quant_config)

            model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                args.model_path,
                **model_kwargs,
            )

# Apply selective quantization if needed (NEW)
if args.quantization in ["8bit", "4bit"]:
model = apply_selective_quantization(model, args.quantization)
else: # cpu
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
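For clarity, here is a minimal sketch of what the CUDA loading call reduces to once get_quantization_config("4bit") has been merged into model_kwargs. The dtype and attention implementation below are placeholders for whatever load_dtype and attn_impl_primary were set to earlier in main(); only the quantization_config entry comes from the new helper.

import torch
from transformers import BitsAndBytesConfig

# Illustrative values only; load_dtype / attn_impl_primary are chosen by the
# surrounding demo code, not fixed here.
model_kwargs = {
    "torch_dtype": torch.bfloat16,   # placeholder for load_dtype
    "device_map": "cuda",
    "attn_implementation": "sdpa",   # placeholder for attn_impl_primary
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    ),
}
# model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
#     args.model_path, **model_kwargs
# )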
Empty file added utils/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions utils/quantization.py
@@ -0,0 +1,113 @@
"""Quantization utilities for VibeVoice models."""

import logging
from typing import Optional
import torch

logger = logging.getLogger(__name__)


def get_quantization_config(quantization: str = "fp16") -> Optional[dict]:
"""
Get quantization configuration for model loading.

Args:
quantization: Quantization level ("fp16", "8bit", or "4bit")

Returns:
dict: Quantization config for from_pretrained, or None for fp16
"""
if quantization == "fp16" or quantization == "full":
return None

if quantization == "8bit":
        try:
            import bitsandbytes as bnb  # noqa: F401 -- availability check only
            from transformers import BitsAndBytesConfig

            logger.info("Using 8-bit quantization (selective LLM only)")
            return {
                "quantization_config": BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0,
                )
            }
except ImportError:
logger.error(
"8-bit quantization requires bitsandbytes. "
"Install with: pip install bitsandbytes"
)
raise

elif quantization == "4bit":
try:
            import bitsandbytes as bnb  # noqa: F401 -- availability check only
from transformers import BitsAndBytesConfig

logger.info("Using 4-bit NF4 quantization (selective LLM only)")
return {
"quantization_config": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
}
except ImportError:
logger.error(
"4-bit quantization requires bitsandbytes. "
"Install with: pip install bitsandbytes"
)
raise

else:
raise ValueError(
f"Invalid quantization: {quantization}. "
f"Must be one of: fp16, 8bit, 4bit"
)


def apply_selective_quantization(model, quantization: str):
"""
Apply selective quantization only to safe components.

This function identifies which modules should be quantized and which
should remain at full precision for audio quality preservation.

    Args:
        model: The VibeVoice model
        quantization: Quantization level ("8bit" or "4bit")

    Returns:
        The model, with audio-critical components kept in a high-precision dtype.
    """
if quantization == "fp16":
return model

logger.info("Applying selective quantization...")

# Components to KEEP at full precision (audio-critical)
keep_fp_components = [
"diffusion_head",
"acoustic_connector",
"semantic_connector",
"acoustic_tokenizer",
"semantic_tokenizer",
"vae",
]

# Only quantize the LLM (Qwen2.5) component
quantize_components = ["llm", "language_model"]

for name, module in model.named_modules():
# Check if this module should stay at full precision
should_keep_fp = any(comp in name for comp in keep_fp_components)
should_quantize = any(comp in name for comp in quantize_components)

        if should_keep_fp:
            # Audio-critical modules: make sure they stay in a high-precision
            # dtype (fp32/bf16) rather than being left in a reduced precision.
            if hasattr(module, "weight") and module.weight.dtype not in (torch.float32, torch.bfloat16):
                module.weight.data = module.weight.data.to(torch.bfloat16)
            logger.debug(f"Keeping {name} at high precision (audio-critical)")

        elif should_quantize:
            # The actual int8/int4 conversion happens inside from_pretrained via
            # bitsandbytes; here we only log which modules it applies to.
            logger.debug(f"{name} quantized to {quantization} by bitsandbytes")

logger.info(f"✓ Selective {quantization} quantization applied")
logger.info(" • LLM: Quantized")
logger.info(" • Audio components: Full precision")

return model
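As an aside, bitsandbytes can also be told up front which submodules to leave un-quantized via llm_int8_skip_modules, which avoids any post-hoc pass over the loaded model. A possible sketch, assuming the module names match the keep_fp_components list above (they are taken from this diff, not verified against the model's actual module tree):

from transformers import BitsAndBytesConfig

# Hypothetical alternative: skip audio-critical modules at load time so
# bitsandbytes never converts them in the first place.
AUDIO_CRITICAL = [
    "diffusion_head",
    "acoustic_connector",
    "semantic_connector",
    "acoustic_tokenizer",
    "semantic_tokenizer",
]

config_8bit = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_skip_modules=AUDIO_CRITICAL,
)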
87 changes: 87 additions & 0 deletions utils/vram_utils.py
@@ -0,0 +1,87 @@
"""VRAM detection and quantization recommendation utilities."""

import torch
import logging

logger = logging.getLogger(__name__)


def get_available_vram_gb() -> float:
"""
Get available VRAM in GB.

Returns:
float: Available VRAM in GB, or 0 if no CUDA device available
"""
if not torch.cuda.is_available():
return 0.0

    try:
        # First CUDA device: available = total minus memory already allocated
        # by this process (other processes' usage is not accounted for here)
        device = torch.device("cuda:0")
        total = torch.cuda.get_device_properties(device).total_memory
        allocated = torch.cuda.memory_allocated(device)
        available = (total - allocated) / (1024**3)  # bytes -> GB
        return available
except Exception as e:
logger.warning(f"Could not detect VRAM: {e}")
return 0.0


def suggest_quantization(available_vram_gb: float, model_name: str = "VibeVoice-7B") -> str:
"""
Suggest quantization level based on available VRAM.

Args:
available_vram_gb: Available VRAM in GB
model_name: Name of the model being loaded

Returns:
str: Suggested quantization level ("fp16", "8bit", or "4bit")
"""
# VibeVoice-7B memory requirements (approximate)
# Full precision (fp16/bf16): ~20GB
# 8-bit quantization: ~12GB
# 4-bit quantization: ~7GB

if "1.5B" in model_name:
# 1.5B model is smaller, adjust thresholds
if available_vram_gb >= 8:
return "fp16"
elif available_vram_gb >= 6:
return "8bit"
else:
return "4bit"
else:
# Assume 7B model
if available_vram_gb >= 22:
return "fp16"
elif available_vram_gb >= 14:
return "8bit"
else:
return "4bit"


def print_vram_info(available_vram_gb: float, model_name: str, quantization: str = "fp16"):
"""
Print VRAM information and quantization recommendation.

Args:
available_vram_gb: Available VRAM in GB
model_name: Name of the model being loaded
quantization: Current quantization setting
"""
logger.info(f"Available VRAM: {available_vram_gb:.1f}GB")

suggested = suggest_quantization(available_vram_gb, model_name)

if suggested != quantization and quantization == "fp16":
logger.warning(
f"⚠️ Low VRAM detected ({available_vram_gb:.1f}GB). "
f"Recommended: --quantization {suggested}"
)
logger.warning(
f" Example: python demo/inference_from_file.py "
f"--model_path {model_name} --quantization {suggested} ..."
)
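
A rough usage sketch of how these helpers fit together (the VRAM figure and model name are illustrative):

from utils.vram_utils import get_available_vram_gb, suggest_quantization, print_vram_info

vram = get_available_vram_gb()                      # e.g. 10.7 on a 12 GB card
level = suggest_quantization(vram, "VibeVoice-7B")  # -> "4bit" when below 14 GB
print_vram_info(vram, "VibeVoice-7B", "fp16")       # warns and recommends `level`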