From fc44709544ba7f1e68f885d7c7acf117d221590d Mon Sep 17 00:00:00 2001
From: 5kft <5kft@users.noreply.github.com>
Date: Sat, 11 Apr 2026 08:14:14 -0700
Subject: [PATCH 1/5] provide the ability to unload the LM prior to the DiT
 execution phase

---
 acestep/inference.py     | 57 +++++++++++++++++++++++++++++++-
 acestep/llm_inference.py | 71 ++++++++++++++++++++++++++++++++++------
 2 files changed, 117 insertions(+), 11 deletions(-)

diff --git a/acestep/inference.py b/acestep/inference.py
index 63f2074e4..b9ea581f9 100644
--- a/acestep/inference.py
+++ b/acestep/inference.py
@@ -14,6 +14,7 @@
 from dataclasses import dataclass, field, asdict
 from loguru import logger
 import torch
+import gc
 
 from acestep.audio_utils import AudioSaver, apply_fade, generate_uuid_from_params, normalize_audio, get_lora_weights_hash
 
@@ -331,6 +332,42 @@ def _update_metadata_from_lm(
     return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
 
 
+def _unload_lm_before_dit(llm_handler):
+    if llm_handler is None:
+        return
+
+    logger.info("Unloading LM before DiT. backend={}, initialized={}",
+                getattr(llm_handler, "llm_backend", None),
+                getattr(llm_handler, "llm_initialized", None))
+
+    if torch.cuda.is_available():
+        alloc = torch.cuda.memory_allocated() / (1024 ** 3)
+        reserved = torch.cuda.memory_reserved() / (1024 ** 3)
+        logger.info("Before LM unload: allocated={:.2f} GB reserved={:.2f} GB", alloc, reserved)
+
+    try:
+        llm_handler.unload()
+    except Exception as exc:
+        logger.warning("llm_handler.unload() failed: {}", exc)
+
+    gc.collect()
+
+    try:
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+            try:
+                torch.cuda.ipc_collect()
+            except Exception:
+                pass
+
+            alloc = torch.cuda.memory_allocated() / (1024 ** 3)
+            reserved = torch.cuda.memory_reserved() / (1024 ** 3)
+            logger.info("After LM unload: allocated={:.2f} GB reserved={:.2f} GB", alloc, reserved)
+    except Exception:
+        pass
+
+
 @_get_spaces_gpu_decorator(duration=180)
 def generate_music(
     dit_handler,
@@ -422,8 +459,22 @@ def generate_music(
     # 3. use_cot_language=True: detect vocal language via CoT
     # 4. use_cot_metas=True: fill missing metadata via CoT
     need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
-    use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
-
+
+    # If this request needs the LM, but the LM was previously unloaded (for example
+    # after the LM->DiT handoff), try to reload it now.
+    lm_status = []
+
+    request_needs_lm = (params.task_type not in skip_lm_tasks) and (params.thinking or need_lm_for_cot)
+
+    if request_needs_lm and llm_handler is not None and not llm_handler.llm_initialized:
+        logger.info("LM required but not initialized; attempting reload from saved config")
+        reload_status, reload_ok = llm_handler.reload_last_configuration()
+        lm_status.append(reload_status)
+
+        if not reload_ok:
+            logger.error(f"[generate_music] LM reload failed: {reload_status}")
+
+    use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
 
     if params.task_type in skip_lm_tasks:
         logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
@@ -609,6 +660,10 @@ def generate_music(
     if params.task_type in ("cover", "repaint", "lego", "extract"):
         audio_duration = None
 
+    # Unload the LM now if option is enabled
+    if use_lm and os.environ.get("ACESTEP_UNLOAD_LM_BEFORE_DIT", "").lower() in ("1", "true", "yes"):
+        _unload_lm_before_dit(llm_handler)
+
     # Phase 2: DiT music generation
     # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
     dit_generate_kwargs = {
diff --git a/acestep/llm_inference.py b/acestep/llm_inference.py
index 57894a3f6..17ef1da19 100644
--- a/acestep/llm_inference.py
+++ b/acestep/llm_inference.py
@@ -64,6 +64,7 @@ def __init__(self, persistent_storage_path: Optional[str] = None):
         self.dtype = torch.float32
         self.offload_to_cpu = False
         self.disable_tqdm = os.environ.get("ACESTEP_DISABLE_TQDM", "").lower() in ("1", "true", "yes") or not (hasattr(sys.stderr, 'isatty') and sys.stderr.isatty())
+        self._last_init_config = None
 
         # HuggingFace Space persistent storage support
         if persistent_storage_path is None and self.IS_HUGGINGFACE_SPACE:
@@ -119,22 +120,44 @@ def unload(self) -> None:
         try:
             if self.llm_backend == "vllm":
                 try:
-                    if hasattr(self.llm, "reset"):
-                        self.llm.reset()
-                except Exception:
-                    pass
-            self._cleanup_torch_distributed_state()
+                    if self.llm is not None:
+                        if hasattr(self.llm, "exit"):
+                            logger.info("[LLM vLLM] Calling nanovllm exit() for hard teardown")
+                            self.llm.exit()
+                        elif hasattr(self.llm, "reset"):
+                            logger.info("[LLM vLLM] exit() missing, falling back to reset()")
+                            self.llm.reset()
+                except Exception as exc:
+                    logger.warning(f"[LLM vLLM] Error during vLLM teardown: {exc}")
+
+                try:
+                    self._cleanup_torch_distributed_state()
+                except Exception as exc:
+                    logger.warning(f"[LLM vLLM] torch distributed cleanup failed: {exc}")
+
             self.llm = None
             self.llm_tokenizer = None
             self.constrained_processor = None
             self.llm_initialized = False
-            self.llm_backend = None
+
             self._hf_model_for_scoring = None
             self._mlx_model = None
             self._mlx_model_path = None
+            gc.collect()
+
             if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
+                try:
+                    torch.cuda.synchronize()
+                except Exception:
+                    pass
+                try:
+                    torch.cuda.empty_cache()
+                except Exception:
+                    pass
+                try:
+                    torch.cuda.ipc_collect()
+                except Exception:
+                    pass
             elif hasattr(torch, "mps") and torch.backends.mps.is_available():
                 if hasattr(torch.mps, "synchronize"):
                     torch.mps.synchronize()
@@ -143,8 +166,27 @@ def unload(self) -> None:
                 if hasattr(torch.mps, "empty_cache"):
                     torch.mps.empty_cache()
             elif hasattr(torch, "xpu") and torch.xpu.is_available():
                 torch.xpu.empty_cache()
                 torch.xpu.synchronize()
-        except Exception:
-            pass
+        except Exception as exc:
+            logger.warning(f"[LLM] unload failed: {exc}")
+
+    def reload_last_configuration(self) -> Tuple[str, bool]:
+        """Recreate the LM from the last successful initialize() configuration."""
+        if not self._last_init_config:
+            return "❌ No previous LM initialization config available", False
+
+        cfg = dict(self._last_init_config)
+
+        logger.info("[LLM] Reloading last configuration: backend={} model={} device={}",
+                    cfg.get("backend"), cfg.get("lm_model_path"), cfg.get("device"))
+
+        return self.initialize(
+            checkpoint_dir=cfg["checkpoint_dir"],
+            lm_model_path=cfg["lm_model_path"],
+            backend=cfg["backend"],
+            device=cfg["device"],
+            offload_to_cpu=cfg["offload_to_cpu"],
+            dtype=cfg["dtype"],
+        )
 
     def _cleanup_torch_distributed_state(self) -> None:
         """Destroy default torch distributed process group when already initialized."""
@@ -695,6 +737,15 @@ def initialize(
 
         vllm_fallback_note = None
 
+        self._last_init_config = {
+            "checkpoint_dir": checkpoint_dir,
+            "lm_model_path": lm_model_path,
+            "backend": backend,
+            "device": device,
+            "offload_to_cpu": offload_to_cpu,
+            "dtype": dtype,
+        }
+
         # Initialize based on user-selected backend
         if backend == "vllm":
             _warn_if_prerelease_python()

From fd286f56dfc7518a9992eaab08427a3002986fd9 Mon Sep 17 00:00:00 2001
From: 5kft <5kft@users.noreply.github.com>
Date: Sat, 11 Apr 2026 10:02:35 -0700
Subject: [PATCH 2/5] gate unload to vllm and pt; properly persist reload state
 after initialization

---
 acestep/inference.py     | 13 +++++++--
 acestep/llm_inference.py | 62 ++++++++++++++++++++++++++++++++++------
 2 files changed, 63 insertions(+), 12 deletions(-)

diff --git a/acestep/inference.py b/acestep/inference.py
index b9ea581f9..de106d944 100644
--- a/acestep/inference.py
+++ b/acestep/inference.py
@@ -660,9 +660,16 @@ def generate_music(
     if params.task_type in ("cover", "repaint", "lego", "extract"):
         audio_duration = None
 
-    # Unload the LM now if option is enabled
-    if use_lm and os.environ.get("ACESTEP_UNLOAD_LM_BEFORE_DIT", "").lower() in ("1", "true", "yes"):
-        _unload_lm_before_dit(llm_handler)
+    # Unload the LM now if option is enabled and the backend supports reload cleanly
+    unload_enabled = os.environ.get("ACESTEP_UNLOAD_LM_BEFORE_DIT", "").lower() in ("1", "true", "yes")
+    safe_unload_backends = {"pt", "vllm"}
+    current_backend = getattr(llm_handler, "llm_backend", None) if llm_handler is not None else None
+
+    if use_lm and unload_enabled:
+        if current_backend in safe_unload_backends:
+            _unload_lm_before_dit(llm_handler)
+        else:
+            logger.info("[generate_music] Skipping LM unload before DiT for unsupported backend={}", current_backend)
 
     # Phase 2: DiT music generation
     # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
diff --git a/acestep/llm_inference.py b/acestep/llm_inference.py
index 17ef1da19..6a0aa06c7 100644
--- a/acestep/llm_inference.py
+++ b/acestep/llm_inference.py
@@ -81,6 +81,24 @@ def __init__(self, persistent_storage_path: Optional[str] = None):
         self._mlx_model = None
         self._mlx_model_path = None
 
+    def _save_last_init_config(
+        self,
+        checkpoint_dir: str,
+        lm_model_path: str,
+        device: str,
+        offload_to_cpu: bool,
+        dtype: Optional[torch.dtype],
+    ) -> None:
+        """Persist the last successfully initialized LM configuration."""
+        self._last_init_config = {
+            "checkpoint_dir": checkpoint_dir,
+            "lm_model_path": lm_model_path,
+            "backend": self.llm_backend,
+            "device": device,
+            "offload_to_cpu": offload_to_cpu,
+            "dtype": dtype,
+        }
+
     def _clear_accelerator_cache(self) -> None:
         """Release freed accelerator memory back to the driver.
 
@@ -701,6 +719,13 @@ def initialize(
             logger.info("Attempting MLX backend for Apple Silicon acceleration...")
             mlx_success, mlx_status = self._load_mlx_model(full_lm_model_path)
             if mlx_success:
+                self._save_last_init_config(
+                    checkpoint_dir=checkpoint_dir,
+                    lm_model_path=lm_model_path,
+                    device=device,
+                    offload_to_cpu=offload_to_cpu,
+                    dtype=dtype,
+                )
                 return mlx_status, True
             else:
                 logger.warning(f"MLX backend failed: {mlx_status}")
@@ -711,6 +736,13 @@ def initialize(
                 if not success:
                     return status_msg, False
                 status_msg = f"✅ 5Hz LM initialized (PyTorch fallback from MLX)\nModel: {full_lm_model_path}\nBackend: PyTorch"
+                self._save_last_init_config(
+                    checkpoint_dir=checkpoint_dir,
+                    lm_model_path=lm_model_path,
+                    device=device,
+                    offload_to_cpu=offload_to_cpu,
+                    dtype=dtype,
+                )
                 return status_msg, True
             # else: backend was "vllm" on MPS, continue to vllm attempt below
         elif backend == "mlx":
@@ -720,6 +752,13 @@ def initialize(
             if not success:
                 return status_msg, False
             status_msg = f"✅ 5Hz LM initialized (PyTorch fallback, MLX not available)\nModel: {full_lm_model_path}\nBackend: PyTorch"
+            self._save_last_init_config(
+                checkpoint_dir=checkpoint_dir,
+                lm_model_path=lm_model_path,
+                device=device,
+                offload_to_cpu=offload_to_cpu,
+                dtype=dtype,
+            )
             return status_msg, True
 
         if backend == "vllm" and device != "cuda":
@@ -737,15 +776,6 @@ def initialize(
 
         vllm_fallback_note = None
 
-        self._last_init_config = {
-            "checkpoint_dir": checkpoint_dir,
-            "lm_model_path": lm_model_path,
-            "backend": backend,
-            "device": device,
-            "offload_to_cpu": offload_to_cpu,
-            "dtype": dtype,
-        }
-
         # Initialize based on user-selected backend
         if backend == "vllm":
             _warn_if_prerelease_python()
@@ -784,6 +814,13 @@ def initialize(
                     logger.warning("vllm failed on MPS, trying MLX backend...")
                     mlx_success, mlx_status = self._load_mlx_model(full_lm_model_path)
                     if mlx_success:
+                        self._save_last_init_config(
+                            checkpoint_dir=checkpoint_dir,
+                            lm_model_path=lm_model_path,
+                            device=device,
+                            offload_to_cpu=offload_to_cpu,
+                            dtype=dtype,
+                        )
                         return mlx_status, True
                     logger.warning(f"MLX also failed: {mlx_status}, falling back to PyTorch")
                     logger.warning("Falling back to PyTorch backend")
@@ -800,6 +837,13 @@ def initialize(
             if vllm_preflight_warning is not None:
                 status_msg += f"\nNote: {vllm_preflight_warning}"
 
+            self._save_last_init_config(
+                checkpoint_dir=checkpoint_dir,
+                lm_model_path=lm_model_path,
+                device=device,
+                offload_to_cpu=offload_to_cpu,
+                dtype=dtype,
+            )
             return status_msg, True
 
         except Exception as e:

From c96cdf3e2d24e731811be4271949fbf20a36a357 Mon Sep 17 00:00:00 2001
From: 5kft <5kft@users.noreply.github.com>
Date: Sat, 11 Apr 2026 10:27:11 -0700
Subject: [PATCH 3/5] consolidate accelerator cleanup logic and add logging;
 gate the unload on resident LM state

---
 acestep/inference.py     | 21 +++++++-------
 acestep/llm_inference.py | 61 ++++++++++++++++++++--------------------
 2 files changed, 41 insertions(+), 41 deletions(-)

diff --git a/acestep/inference.py b/acestep/inference.py
index de106d944..350ef809d 100644
--- a/acestep/inference.py
+++ b/acestep/inference.py
@@ -352,20 +352,19 @@ def _unload_lm_before_dit(llm_handler):
 
     gc.collect()
 
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-            try:
-                torch.cuda.ipc_collect()
-            except Exception:
-                pass
-
+    if llm_handler is not None:
+        try:
+            llm_handler._clear_accelerator_cache("[LM handoff unload]")
+        except Exception as exc:
+            logger.warning("[LM handoff unload] accelerator cleanup failed: {}", exc)
+
+    if torch.cuda.is_available():
+        try:
             alloc = torch.cuda.memory_allocated() / (1024 ** 3)
             reserved = torch.cuda.memory_reserved() / (1024 ** 3)
             logger.info("After LM unload: allocated={:.2f} GB reserved={:.2f} GB", alloc, reserved)
-    except Exception:
-        pass
+        except Exception as exc:
+            logger.warning("[LM handoff unload] failed to query CUDA memory stats: {}", exc)
 
 
 @_get_spaces_gpu_decorator(duration=180)
@@ -665,7 +664,7 @@ def generate_music(
     safe_unload_backends = {"pt", "vllm"}
     current_backend = getattr(llm_handler, "llm_backend", None) if llm_handler is not None else None
 
-    if use_lm and unload_enabled:
+    if unload_enabled and llm_handler is not None and llm_handler.llm_initialized:
         if current_backend in safe_unload_backends:
             _unload_lm_before_dit(llm_handler)
         else:
             logger.info("[generate_music] Skipping LM unload before DiT for unsupported backend={}", current_backend)
diff --git a/acestep/llm_inference.py b/acestep/llm_inference.py
index 6a0aa06c7..11060a872 100644
--- a/acestep/llm_inference.py
+++ b/acestep/llm_inference.py
@@ -99,12 +99,12 @@ def _save_last_init_config(
             "dtype": dtype,
         }
 
-    def _clear_accelerator_cache(self) -> None:
+    def _clear_accelerator_cache(self, context: str = "[LLM]") -> None:
         """Release freed accelerator memory back to the driver.
 
         Synchronises the device *before* releasing cached blocks so that
         every in-flight async write has landed and the freed blocks are
-        actually reclaimable. Supports CUDA, XPU (Intel), and MPS 
+        actually reclaimable. Supports CUDA, XPU (Intel), and MPS
         (Apple Silicon) backends.
         """
         try:
@@ -122,16 +122,38 @@ def _clear_accelerator_cache(self) -> None:
             active_device = "mps"
 
         if active_device == "cuda" and torch.cuda.is_available():
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
+            try:
+                torch.cuda.synchronize()
+            except Exception as exc:
+                logger.warning("{} torch.cuda.synchronize() failed: {}", context, exc)
+            try:
+                torch.cuda.empty_cache()
+            except Exception as exc:
+                logger.warning("{} torch.cuda.empty_cache() failed: {}", context, exc)
+            try:
+                torch.cuda.ipc_collect()
+            except Exception as exc:
+                logger.warning("{} torch.cuda.ipc_collect() failed: {}", context, exc)
         elif active_device == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
-            torch.xpu.synchronize()
-            torch.xpu.empty_cache()
+            try:
+                torch.xpu.synchronize()
+            except Exception as exc:
+                logger.warning("{} torch.xpu.synchronize() failed: {}", context, exc)
+            try:
+                torch.xpu.empty_cache()
+            except Exception as exc:
+                logger.warning("{} torch.xpu.empty_cache() failed: {}", context, exc)
         elif active_device == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
             if hasattr(torch.mps, "synchronize"):
-                torch.mps.synchronize()
+                try:
+                    torch.mps.synchronize()
+                except Exception as exc:
+                    logger.warning("{} torch.mps.synchronize() failed: {}", context, exc)
             if hasattr(torch.mps, "empty_cache"):
-                torch.mps.empty_cache()
+                try:
+                    torch.mps.empty_cache()
+                except Exception as exc:
+                    logger.warning("{} torch.mps.empty_cache() failed: {}", context, exc)
 
     def unload(self) -> None:
         """Release LM weights/tokenizer and clear caches to free memory."""
@@ -162,28 +184,7 @@ def unload(self) -> None:
             self._mlx_model = None
             self._mlx_model_path = None
             gc.collect()
-
-            if torch.cuda.is_available():
-                try:
-                    torch.cuda.synchronize()
-                except Exception:
-                    pass
-                try:
-                    torch.cuda.empty_cache()
-                except Exception:
-                    pass
-                try:
-                    torch.cuda.ipc_collect()
-                except Exception:
-                    pass
-            elif hasattr(torch, "mps") and torch.backends.mps.is_available():
-                if hasattr(torch.mps, "synchronize"):
-                    torch.mps.synchronize()
-                if hasattr(torch.mps, "empty_cache"):
-                    torch.mps.empty_cache()
-            elif hasattr(torch, "xpu") and torch.xpu.is_available():
-                torch.xpu.empty_cache()
-                torch.xpu.synchronize()
+            self._clear_accelerator_cache("[LLM unload]")
         except Exception as exc:
             logger.warning(f"[LLM] unload failed: {exc}")

From 2058f7381dc0e2123d17af0508983c53d2014da3 Mon Sep 17 00:00:00 2001
From: 5kft <5kft@users.noreply.github.com>
Date: Sat, 11 Apr 2026 10:45:27 -0700
Subject: [PATCH 4/5] fail fast when the LM reload fails; add docstring

---
 acestep/inference.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/acestep/inference.py b/acestep/inference.py
index 350ef809d..4ed705ed0 100644
--- a/acestep/inference.py
+++ b/acestep/inference.py
@@ -333,6 +333,7 @@ def _update_metadata_from_lm(
 
 
 def _unload_lm_before_dit(llm_handler):
+    """Unload a resident LM before the DiT phase to free accelerator memory."""
     if llm_handler is None:
         return
 
@@ -471,7 +472,14 @@ def generate_music(
         lm_status.append(reload_status)
 
         if not reload_ok:
-            logger.error(f"[generate_music] LM reload failed: {reload_status}")
+            logger.error("[generate_music] LM reload failed: {}", reload_status)
+            return GenerationResult(
+                audios=[],
+                status_message=f"❌ LM reload failed: {reload_status}",
+                extra_outputs={},
+                success=False,
+                error=reload_status,
+            )
 
     use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks

From aa1c94802457bc466fcdff4e81ca56715cd64447 Mon Sep 17 00:00:00 2001
From: 5kft <5kft@users.noreply.github.com>
Date: Sat, 11 Apr 2026 11:04:54 -0700
Subject: [PATCH 5/5] gate reload/fail-fast to actual handoff unloads

---
 acestep/inference.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/acestep/inference.py b/acestep/inference.py
index 4ed705ed0..7df3b99bc 100644
--- a/acestep/inference.py
+++ b/acestep/inference.py
@@ -466,7 +466,8 @@ def generate_music(
 
     request_needs_lm = (params.task_type not in skip_lm_tasks) and (params.thinking or need_lm_for_cot)
 
-    if request_needs_lm and llm_handler is not None and not llm_handler.llm_initialized:
+    if (request_needs_lm and llm_handler is not None and not llm_handler.llm_initialized and
+            getattr(llm_handler, "_last_init_config", None) is not None):
         logger.info("LM required but not initialized; attempting reload from saved config")
         reload_status, reload_ok = llm_handler.reload_last_configuration()
         lm_status.append(reload_status)
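
Usage sketch (not part of the patch series): the handoff unload is opt-in via
the ACESTEP_UNLOAD_LM_BEFORE_DIT environment variable that generate_music()
checks in the diffs above. The function and attribute names below come from
those diffs; the surrounding driver code is hypothetical and shown only to
illustrate the intended flow.

    import os

    # Opt in before serving requests; generate_music() reads this flag via
    # os.environ.get (PATCH 1/5) and calls _unload_lm_before_dit() after the
    # LM phase when the resident backend is "pt" or "vllm" (PATCH 2/5, 3/5).
    os.environ["ACESTEP_UNLOAD_LM_BEFORE_DIT"] = "1"

    # result = generate_music(dit_handler, llm_handler, params)  # hypothetical call
    #
    # After such a run, llm_handler.llm_initialized is False. The next request
    # that needs the LM triggers llm_handler.reload_last_configuration()
    # (PATCH 1/5), which replays the configuration saved by
    # _save_last_init_config() during the last successful initialize()
    # (PATCH 2/5); if no configuration was saved, the reload and its
    # fail-fast path are skipped (PATCH 5/5).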