diff --git a/acestep/inference.py b/acestep/inference.py
index 63f2074e4..7df3b99bc 100644
--- a/acestep/inference.py
+++ b/acestep/inference.py
@@ -14,6 +14,7 @@ from dataclasses import dataclass, field, asdict
 
 from loguru import logger
 import torch
+import gc
 
 from acestep.audio_utils import AudioSaver, apply_fade, generate_uuid_from_params, normalize_audio, get_lora_weights_hash
 
@@ -331,6 +332,42 @@ def _update_metadata_from_lm(
     return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
 
 
+def _unload_lm_before_dit(llm_handler):
+    """Unload a resident LM before the DiT phase to free accelerator memory."""
+    if llm_handler is None:
+        return
+
+    logger.info("Unloading LM before DiT. backend={}, initialized={}",
+                getattr(llm_handler, "llm_backend", None),
+                getattr(llm_handler, "llm_initialized", None))
+
+    if torch.cuda.is_available():
+        alloc = torch.cuda.memory_allocated() / (1024 ** 3)
+        reserved = torch.cuda.memory_reserved() / (1024 ** 3)
+        logger.info("Before LM unload: allocated={:.2f} GB reserved={:.2f} GB", alloc, reserved)
+
+    try:
+        llm_handler.unload()
+    except Exception as exc:
+        logger.warning("llm_handler.unload() failed: {}", exc)
+
+    gc.collect()
+
+    if llm_handler is not None:
+        try:
+            llm_handler._clear_accelerator_cache("[LM handoff unload]")
+        except Exception as exc:
+            logger.warning("[LM handoff unload] accelerator cleanup failed: {}", exc)
+
+    if torch.cuda.is_available():
+        try:
+            alloc = torch.cuda.memory_allocated() / (1024 ** 3)
+            reserved = torch.cuda.memory_reserved() / (1024 ** 3)
+            logger.info("After LM unload: allocated={:.2f} GB reserved={:.2f} GB", alloc, reserved)
+        except Exception as exc:
+            logger.warning("[LM handoff unload] failed to query CUDA memory stats: {}", exc)
+
+
 @_get_spaces_gpu_decorator(duration=180)
 def generate_music(
     dit_handler,
@@ -422,8 +459,30 @@ def generate_music(
     # 3. use_cot_language=True: detect vocal language via CoT
    # 4. use_cot_metas=True: fill missing metadata via CoT
     need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
-    use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
+
+    # If this request needs the LM, but the LM was previously unloaded (for example
+    # after the LM->DiT handoff), try to reload it now.
+    lm_status = []
+
+    request_needs_lm = (params.task_type not in skip_lm_tasks) and (params.thinking or need_lm_for_cot)
+
+    if (request_needs_lm and llm_handler is not None and not llm_handler.llm_initialized and
+            getattr(llm_handler, "_last_init_config", None) is not None):
+        logger.info("LM required but not initialized; attempting reload from saved config")
+        reload_status, reload_ok = llm_handler.reload_last_configuration()
+        lm_status.append(reload_status)
+
+        if not reload_ok:
+            logger.error("[generate_music] LM reload failed: {}", reload_status)
+            return GenerationResult(
+                audios=[],
+                status_message=f"❌ LM reload failed: {reload_status}",
+                extra_outputs={},
+                success=False,
+                error=reload_status,
+            )
+
+    use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
 
     if params.task_type in skip_lm_tasks:
         logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
@@ -609,6 +668,17 @@ def generate_music(
     if params.task_type in ("cover", "repaint", "lego", "extract"):
         audio_duration = None
 
+    # Unload the LM now if option is enabled and the backend supports reload cleanly
+    unload_enabled = os.environ.get("ACESTEP_UNLOAD_LM_BEFORE_DIT", "").lower() in ("1", "true", "yes")
+    safe_unload_backends = {"pt", "vllm"}
+    current_backend = getattr(llm_handler, "llm_backend", None) if llm_handler is not None else None
+
+    if unload_enabled and llm_handler is not None and llm_handler.llm_initialized:
+        if current_backend in safe_unload_backends:
+            _unload_lm_before_dit(llm_handler)
+        else:
+            logger.info("[generate_music] Skipping LM unload before DiT for unsupported backend={}", current_backend)
+
     # Phase 2: DiT music generation
     # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
     dit_generate_kwargs = {
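The new branch above is gated purely on an environment variable and the handler's backend string. A minimal sketch of that decision logic, lifted out of the patch for illustration (should_unload_lm and SAFE_UNLOAD_BACKENDS are names invented here; in the patch the checks live inline in generate_music):

    import os
    from typing import Optional

    # Backends the patch treats as safe to unload and later reload.
    SAFE_UNLOAD_BACKENDS = {"pt", "vllm"}

    def should_unload_lm(llm_backend: Optional[str], llm_initialized: bool) -> bool:
        enabled = os.environ.get("ACESTEP_UNLOAD_LM_BEFORE_DIT", "").lower() in ("1", "true", "yes")
        return enabled and llm_initialized and llm_backend in SAFE_UNLOAD_BACKENDS

    os.environ["ACESTEP_UNLOAD_LM_BEFORE_DIT"] = "1"
    print(should_unload_lm("vllm", True))   # True  -> _unload_lm_before_dit() runs
    print(should_unload_lm("mlx", True))    # False -> unload is skipped for this backend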
""" try: @@ -103,48 +122,90 @@ def _clear_accelerator_cache(self) -> None: active_device = "mps" if active_device == "cuda" and torch.cuda.is_available(): - torch.cuda.synchronize() - torch.cuda.empty_cache() + try: + torch.cuda.synchronize() + except Exception as exc: + logger.warning("{} torch.cuda.synchronize() failed: {}", context, exc) + try: + torch.cuda.empty_cache() + except Exception as exc: + logger.warning("{} torch.cuda.empty_cache() failed: {}", context, exc) + try: + torch.cuda.ipc_collect() + except Exception as exc: + logger.warning("{} torch.cuda.ipc_collect() failed: {}", context, exc) elif active_device == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available(): - torch.xpu.synchronize() - torch.xpu.empty_cache() + try: + torch.xpu.synchronize() + except Exception as exc: + logger.warning("{} torch.xpu.synchronize() failed: {}", context, exc) + try: + torch.xpu.empty_cache() + except Exception as exc: + logger.warning("{} torch.xpu.empty_cache() failed: {}", context, exc) elif active_device == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): if hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() + try: + torch.mps.synchronize() + except Exception as exc: + logger.warning("{} torch.mps.synchronize() failed: {}", context, exc) if hasattr(torch.mps, "empty_cache"): - torch.mps.empty_cache() + try: + torch.mps.empty_cache() + except Exception as exc: + logger.warning("{} torch.mps.empty_cache() failed: {}", context, exc) def unload(self) -> None: """Release LM weights/tokenizer and clear caches to free memory.""" try: if self.llm_backend == "vllm": try: - if hasattr(self.llm, "reset"): - self.llm.reset() - except Exception: - pass - self._cleanup_torch_distributed_state() + if self.llm is not None: + if hasattr(self.llm, "exit"): + logger.info("[LLM vLLM] Calling nanovllm exit() for hard teardown") + self.llm.exit() + elif hasattr(self.llm, "reset"): + logger.info("[LLM vLLM] exit() missing, falling back to reset()") + self.llm.reset() + except Exception as exc: + logger.warning(f"[LLM vLLM] Error during vLLM teardown: {exc}") + + try: + self._cleanup_torch_distributed_state() + except Exception as exc: + logger.warning(f"[LLM vLLM] torch distributed cleanup failed: {exc}") + self.llm = None self.llm_tokenizer = None self.constrained_processor = None self.llm_initialized = False - self.llm_backend = None + self._hf_model_for_scoring = None self._mlx_model = None self._mlx_model_path = None + gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - elif hasattr(torch, "mps") and torch.backends.mps.is_available(): - if hasattr(torch.mps, "synchronize"): - torch.mps.synchronize() - if hasattr(torch.mps, "empty_cache"): - torch.mps.empty_cache() - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - torch.xpu.empty_cache() - torch.xpu.synchronize() - except Exception: - pass + self._clear_accelerator_cache("[LLM unload]") + except Exception as exc: + logger.warning(f"[LLM] unload failed: {exc}") + + def reload_last_configuration(self) -> Tuple[str, bool]: + """Recreate the LM from the last successful initialize() configuration.""" + if not self._last_init_config: + return "❌ No previous LM initialization config available", False + + cfg = dict(self._last_init_config) + + logger.info("[LLM] Reloading last configuration: backend={} model={} device={}", + cfg.get("backend"), cfg.get("lm_model_path"), cfg.get("device")) + + return self.initialize( + checkpoint_dir=cfg["checkpoint_dir"], + 
lm_model_path=cfg["lm_model_path"], + backend=cfg["backend"], + device=cfg["device"], + offload_to_cpu=cfg["offload_to_cpu"], + dtype=cfg["dtype"], + ) def _cleanup_torch_distributed_state(self) -> None: """Destroy default torch distributed process group when already initialized.""" @@ -659,6 +720,13 @@ def initialize( logger.info("Attempting MLX backend for Apple Silicon acceleration...") mlx_success, mlx_status = self._load_mlx_model(full_lm_model_path) if mlx_success: + self._save_last_init_config( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + device=device, + offload_to_cpu=offload_to_cpu, + dtype=dtype, + ) return mlx_status, True else: logger.warning(f"MLX backend failed: {mlx_status}") @@ -669,6 +737,13 @@ def initialize( if not success: return status_msg, False status_msg = f"✅ 5Hz LM initialized (PyTorch fallback from MLX)\nModel: {full_lm_model_path}\nBackend: PyTorch" + self._save_last_init_config( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + device=device, + offload_to_cpu=offload_to_cpu, + dtype=dtype, + ) return status_msg, True # else: backend was "vllm" on MPS, continue to vllm attempt below elif backend == "mlx": @@ -678,6 +753,13 @@ def initialize( if not success: return status_msg, False status_msg = f"✅ 5Hz LM initialized (PyTorch fallback, MLX not available)\nModel: {full_lm_model_path}\nBackend: PyTorch" + self._save_last_init_config( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + device=device, + offload_to_cpu=offload_to_cpu, + dtype=dtype, + ) return status_msg, True if backend == "vllm" and device != "cuda": @@ -733,6 +815,13 @@ def initialize( logger.warning("vllm failed on MPS, trying MLX backend...") mlx_success, mlx_status = self._load_mlx_model(full_lm_model_path) if mlx_success: + self._save_last_init_config( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + device=device, + offload_to_cpu=offload_to_cpu, + dtype=dtype, + ) return mlx_status, True logger.warning(f"MLX also failed: {mlx_status}, falling back to PyTorch") logger.warning("Falling back to PyTorch backend") @@ -749,6 +838,13 @@ def initialize( if vllm_preflight_warning is not None: status_msg += f"\nNote: {vllm_preflight_warning}" + self._save_last_init_config( + checkpoint_dir=checkpoint_dir, + lm_model_path=lm_model_path, + device=device, + offload_to_cpu=offload_to_cpu, + dtype=dtype, + ) return status_msg, True except Exception as e:
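Taken together, the llm_inference.py changes give the handler a save/unload/reload lifecycle: initialize() records its arguments via _save_last_init_config(), unload() tears the model down without clearing llm_backend or _last_init_config, and reload_last_configuration() replays the saved arguments through initialize(). A rough usage sketch, assuming llm_handler is an already-constructed handler instance from acestep/llm_inference.py and that the checkpoint/model values shown are placeholders:

    # llm_handler: handler instance defined in acestep/llm_inference.py (constructed elsewhere).
    status, ok = llm_handler.initialize(
        checkpoint_dir="./checkpoints",     # placeholder paths/values for illustration
        lm_model_path="acestep-5hz-lm",
        backend="pt",
        device="cuda",
        offload_to_cpu=False,
        dtype=None,
    )
    assert ok, status

    # Free accelerator memory before the DiT phase; _last_init_config and llm_backend survive.
    llm_handler.unload()
    assert not llm_handler.llm_initialized

    # A later request that needs the LM rebuilds it from the saved configuration.
    status, ok = llm_handler.reload_last_configuration()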