64 changes: 63 additions & 1 deletion acestep/inference.py
@@ -14,6 +14,7 @@
from dataclasses import dataclass, field, asdict
from loguru import logger
import torch
import gc


from acestep.audio_utils import AudioSaver, apply_fade, generate_uuid_from_params, normalize_audio, get_lora_weights_hash
@@ -331,6 +332,42 @@ def _update_metadata_from_lm(
    return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics


def _unload_lm_before_dit(llm_handler):
    if llm_handler is None:
        return

    logger.info("Unloading LM before DiT. backend={}, initialized={}",
                getattr(llm_handler, "llm_backend", None),
                getattr(llm_handler, "llm_initialized", None))

    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / (1024 ** 3)
        reserved = torch.cuda.memory_reserved() / (1024 ** 3)
        logger.info("Before LM unload: allocated={:.2f} GB reserved={:.2f} GB", alloc, reserved)

    try:
        llm_handler.unload()
    except Exception as exc:
        logger.warning("llm_handler.unload() failed: {}", exc)

    gc.collect()

    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            try:
                torch.cuda.ipc_collect()
            except Exception:
                pass

            alloc = torch.cuda.memory_allocated() / (1024 ** 3)
            reserved = torch.cuda.memory_reserved() / (1024 ** 3)
            logger.info("After LM unload: allocated={:.2f} GB reserved={:.2f} GB", alloc, reserved)
    except Exception:
        pass


@_get_spaces_gpu_decorator(duration=180)
def generate_music(
    dit_handler,
@@ -422,8 +459,22 @@ def generate_music(
    # 3. use_cot_language=True: detect vocal language via CoT
    # 4. use_cot_metas=True: fill missing metadata via CoT
    need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
    use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks

    # If this request needs the LM, but the LM was previously unloaded (for example
    # after the LM->DiT handoff), try to reload it now.
    lm_status = []

    request_needs_lm = (params.task_type not in skip_lm_tasks) and (params.thinking or need_lm_for_cot)

    if request_needs_lm and llm_handler is not None and not llm_handler.llm_initialized:
        logger.info("LM required but not initialized; attempting reload from saved config")
        reload_status, reload_ok = llm_handler.reload_last_configuration()
        lm_status.append(reload_status)

        if not reload_ok:
            logger.error(f"[generate_music] LM reload failed: {reload_status}")

    use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks

    if params.task_type in skip_lm_tasks:
        logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
@@ -609,6 +660,17 @@ def generate_music(
    if params.task_type in ("cover", "repaint", "lego", "extract"):
        audio_duration = None

    # Unload the LM now if option is enabled and the backend supports reload cleanly
    unload_enabled = os.environ.get("ACESTEP_UNLOAD_LM_BEFORE_DIT", "").lower() in ("1", "true", "yes")
    safe_unload_backends = {"pt", "vllm"}
    current_backend = getattr(llm_handler, "llm_backend", None) if llm_handler is not None else None

    if use_lm and unload_enabled:
        if current_backend in safe_unload_backends:
            _unload_lm_before_dit(llm_handler)
        else:
            logger.info("[generate_music] Skipping LM unload before DiT for unsupported backend={}", current_backend)
Comment on lines +676 to +680
coderabbitai[bot] (Contributor) commented on Apr 11, 2026:

⚠️ Potential issue | 🟠 Major

This unload path strands the other LM inference APIs.

After Line 677 unloads the shared handler, understand_music(), create_sample(), and format_sample() in this same module still short-circuit on not llm_handler.llm_initialized and never call reload_last_configuration(). With the flag enabled, a generate_music() call can therefore make the next LM-only request fail until something else reinitializes the model. Please reuse the same “ensure LM ready” path there, or centralize it in a shared helper.

Based on learnings, acestep/inference.py should lazy load models when needed to minimize memory allocations.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/inference.py` around lines 675-679, the unload path in
generate_music leaves the shared llm_handler unloaded while understand_music,
create_sample, and format_sample still short-circuit on not
llm_handler.llm_initialized; change those functions to reuse the same "ensure LM
ready" logic (or extract a shared helper, e.g., ensure_llm_ready()) that checks
llm_handler.llm_initialized and calls reload_last_configuration() when needed
instead of returning early, and update generate_music to call that same helper
after _unload_lm_before_dit so all LM-only APIs lazily reload the model when
required (reference: generate_music, understand_music, create_sample,
format_sample, _unload_lm_before_dit, reload_last_configuration,
llm_handler.llm_initialized).
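
For illustration only, a minimal sketch of the shared helper suggested above could look like the following. The name _ensure_llm_ready and its return convention are assumptions for this sketch, not code in this PR; it simply combines the existing llm_initialized check with reload_last_configuration():

def _ensure_llm_ready(llm_handler) -> bool:
    """Best-effort check that the shared LM is usable, reloading it from the
    last saved configuration if it was unloaded earlier (e.g. after the
    LM->DiT handoff). Returns True when the LM can serve requests."""
    if llm_handler is None:
        return False
    if llm_handler.llm_initialized:
        return True
    reload_status, reload_ok = llm_handler.reload_last_configuration()
    if not reload_ok:
        logger.warning("[ensure_llm_ready] LM reload failed: {}", reload_status)
    return reload_ok

generate_music(), understand_music(), create_sample(), and format_sample() could then call this helper wherever they currently test llm_handler.llm_initialized directly.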

Author (@5kft) replied:

Technically, yes, this is a lifecycle inconsistency. However it is not clear that these other APIs would even be hit from the normal UI flow here. I don't know enough at this point to determine if this is a real problem or not...

coderabbitai[bot] (Contributor) replied:

🧩 Analysis chain

Let me check how those functions are wired into the Gradio UI to give you a concrete answer.


🏁 Script executed:

#!/bin/bash
echo "=== understand_music usages ==="
rg -n "understand_music" --type py -C2

echo ""
echo "=== create_sample usages ==="
rg -n "create_sample" --type py -C2

echo ""
echo "=== format_sample usages ==="
rg -n "format_sample" --type py -C2

Length of output: 66675


@5kft, the search results give a concrete answer: yes, this is a real (though narrow) issue in the normal Gradio UI flow.

Here's what was found:

Function            Called from UI   How
understand_music    ✅ Yes            llm_analysis_actions.py ("Analyze Audio" button), metadata_loading.py (load metadata event)
create_sample       ✅ Yes            llm_sample_actions.py → wired to "Create Sample" button click in generation_mode_wiring.py
format_sample       ✅ Yes            llm_format_actions.py → "Format Caption/Lyrics" buttons

All three are separate clickable buttons in the same Gradio session sharing the same llm_handler. So the failure scenario is:

  1. User clicks Generate (with ACESTEP_UNLOAD_LM_BEFORE_DIT=1) → LM unloaded after DiT handoff ✅
  2. User then clicks Create Sample, Format, or Analyze in the same session → llm_handler.llm_initialized is False → silent failure or error ❌

Notably, metadata_loading.py line 257 just silently skips with if llm_handler.llm_initialized: rather than surfacing any error.

Severity is bounded because:

  • The flag ACESTEP_UNLOAD_LM_BEFORE_DIT=1 must be explicitly opted into
  • The generate_music reload path already works correctly
  • These auxiliary actions are optional helpers

The fix would be adding the same reload guard (check _last_init_config, call reload_last_configuration()) in the three functions before the not llm_handler.llm_initialized early-return, or extracting a shared helper. Would you like me to plan those changes?
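
For illustration, the reload guard described above could look roughly like this inside one of the LM-only entry points such as understand_music(). The surrounding signature and return values are placeholders, not the project's actual code:

    # Illustrative guard; the real function's signature and error handling differ.
    if llm_handler is None:
        return "❌ 5Hz LM handler not available", None
    if not llm_handler.llm_initialized:
        # The LM may have been unloaded by a previous generate_music() call with
        # ACESTEP_UNLOAD_LM_BEFORE_DIT=1; try to bring it back before giving up.
        reload_status, reload_ok = llm_handler.reload_last_configuration()
        if not reload_ok:
            return f"❌ LM not initialized and reload failed: {reload_status}", None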


🧠 Learnings used
Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: .github/copilot-instructions.md:0-0
Timestamp: 2026-03-20T01:31:33.304Z
Learning: Applies to acestep/handler.py,acestep/inference.py : Use batch operations supported up to 8 songs for audio generation

Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: .github/copilot-instructions.md:0-0
Timestamp: 2026-03-24T04:42:41.285Z
Learning: Applies to acestep/{handler,inference,llm_inference}.py : Lazy load models when needed to minimize memory allocations

Learnt from: chigkim
Repo: ace-step/ACE-Step-1.5 PR: 362
File: cli.py:982-982
Timestamp: 2026-02-08T22:50:13.134Z
Learning: In cli.py, the --backend argument defaults to "vllm" on non-MPS systems. However, the actual fallback from "vllm" to "pt" on non-CUDA devices is handled in LLMHandler.initialize() (acestep/llm_inference.py lines 524-528), not in the CLI. This separation allows the CLI to express user preference while LLMHandler performs capability detection.

Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: .github/copilot-instructions.md:0-0
Timestamp: 2026-03-20T01:31:25.142Z
Learning: Applies to acestep/{handler,inference}.py : Batch operations are supported up to 8 songs

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 811
File: docs/en/GRADIO_GUIDE.md:0-0
Timestamp: 2026-03-11T12:03:00.308Z
Learning: In ace-step/ACE-Step-1.5 GRADIO_GUIDE.md, the "How It Integrates With 5Hz LM" section explicitly notes that when external mode is active, some LM-assisted workflows (notably Create Sample) still check `llm_handler.llm_initialized` in both the Gradio handler path and `create_sample()`, and therefore may still require local 5Hz LM initialization until the remaining routing is completed. The doc describes the external-mode paths as "intended behavior" rather than currently fully functional. Do not flag this remaining initialization requirement as a new bug in documentation reviews.

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 774
File: acestep/ui/gradio/interfaces/training_dataset_tab_label_preview.py:70-95
Timestamp: 2026-03-05T23:12:28.095Z
Learning: In ace-step/ACE-Step-1.5, the hardcoded English placeholder strings in `acestep/ui/gradio/interfaces/training_dataset_tab_label_preview.py` (e.g., `"Music description..."` for `edit_caption`, `"pop, electronic, dance..."` for `edit_genre`, `Dropdown` choices `["Use Global Ratio", "Caption", "Genre"]` and its default value, and the lyrics placeholder `"[Verse 1]\nLyrics here...\n\n[Chorus]\n..."`) are pre-existing localization gaps carried forward from before the decomposition refactor. They are deferred to a follow-up PR and should not be flagged as new issues introduced by decomposition PRs.

Learnt from: tonyjohnvan
Repo: ace-step/ACE-Step-1.5 PR: 858
File: acestep/customized_vllm/__init__.py:125-128
Timestamp: 2026-03-17T21:18:29.789Z
Learning: In `acestep/customized_vllm/__init__.py` and `acestep/llm_inference.py`, the ACE-Step vLLM path always runs with `batch_size=1` (one prompt per `LLM.generate()` call). Because of this, `[sampling_params] * len(prompts)` only ever produces a single-element list, and sharing one `MetadataConstrainedLogitsProcessor` / `logits_processor_update_state` instance across all slots is safe by design. Do not flag the shared logits-processor reference as a mutable-state / batch-corruption bug.

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 769
File: acestep/api/train_api_dataset_auto_label_async_route.py:158-169
Timestamp: 2026-03-04T22:40:44.930Z
Learning: In ace-step/ACE-Step-1.5, the `builder.label_all_samples(...)` call in both `train_api_dataset_auto_label_async_route.py` and `train_api_dataset_auto_label_sync_route.py` passes `chunk_size`, `batch_size`, and `sample_labeled_callback` kwargs that are not declared in the current `label_all_samples` signature in `acestep/training/dataset_builder_modules/label_all.py`. This is a pre-existing issue carried forward unchanged from the pre-decomposition `train_api_dataset_service.py` and is deferred to a follow-up PR.

Learnt from: tonyjohnvan
Repo: ace-step/ACE-Step-1.5 PR: 858
File: acestep/customized_vllm/__init__.py:1-406
Timestamp: 2026-03-17T21:14:13.573Z
Learning: In `acestep/customized_vllm/__init__.py`, the author (tonyjohnvan) intentionally keeps all inference engine internals (ForwardState, thread-local context, GenerationSlot, CachePool, _EngineConfig, LLM, generation loop) in a single ~400 LOC file with clear section comments. They consider splitting into sub-modules to add import complexity for no functional benefit. Do not raise module-size refactor suggestions for this file.

Learnt from: tonyjohnvan
Repo: ace-step/ACE-Step-1.5 PR: 858
File: acestep/customized_vllm/pipeline.py:1-423
Timestamp: 2026-03-17T21:17:04.611Z
Learning: In `acestep/customized_vllm/pipeline.py`, the author (tonyjohnvan) intentionally keeps all inference pipeline internals (sampling helpers _filter_by_top_k/_filter_by_nucleus/sample_tokens, model init, KV cache provisioning _provision_kv_storage, transfer buffers _init_transfer_buffers, prefill/decode execution _execute_prefill/_execute_autoregressive/_forward_pass, and CUDA graph capture _compile_execution_graphs) in a single ~420 LOC file with clear section comments. They consider splitting into sub-modules to add import complexity for no functional benefit. Do not raise module-size refactor suggestions for this file.

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 837
File: acestep/core/generation/handler/init_service_loader_components.py:44-47
Timestamp: 2026-03-14T10:35:13.950Z
Learning: In `acestep/core/generation/handler/init_service_loader.py` (origin/main, pre-extraction), `_load_text_encoder_and_tokenizer()` already used `self.text_encoder.to("cpu").to(self.dtype)` in the CPU-offload branch — using `self.dtype` (which can be bfloat16) instead of `torch.float32`. This is a pre-existing behavior, not introduced by PR `#837`. PR `#837` moves it unchanged into `init_service_loader_components.py` and intends to fix it there. Do not flag the original `init_service_loader.py` pattern as a new regression in future reviews.

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 808
File: acestep/inference.py:25-26
Timestamp: 2026-03-11T10:56:30.827Z
Learning: In ace-step/ACE-Step-1.5, acestep/inference.py’s lyric-density guard now uses a Unicode-aware tokenizer (_extract_lyric_word_tokens with str.isalnum()) and the counting helpers (_count_lyric_words/_count_lyric_syllables) ignore tag lines. A non‑Latin regression test exists, so do not re‑flag ASCII-only tokenization for this path in future reviews.

Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: AGENTS.md:0-0
Timestamp: 2026-02-28T14:14:55.765Z
Learning: AI-Agent Workflow: Add/update focused tests. Self-review only changed hunks for regressions and scope creep. Summarize risk, validation, and non-target impact in PR notes.

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 0
File: :0-0
Timestamp: 2026-03-20T01:04:19.217Z
Learning: In `acestep/text_tasks/external_lm_captioning.py`, `apply_user_metadata_overrides()` intentionally writes both the canonical field name and its runtime alias in a single pass (e.g., `plan.keyscale` + `plan.key_scale`, `plan.timesignature` + `plan.time_signature`, `plan.language` + `plan.vocal_language`). Invalid `bpm`/`duration` conversion errors are swallowed and logged via `loguru.logger.debug()`. Do not flag either of these patterns as bugs in future reviews.

Learnt from: CR
Repo: ace-step/ACE-Step-1.5 PR: 0
File: .github/copilot-instructions.md:0-0
Timestamp: 2026-03-24T14:09:24.530Z
Learning: Applies to acestep/{handler,inference,audio_utils}.py : Target 4GB VRAM as the memory constraint and minimize memory allocations accordingly

Learnt from: 1larity
Repo: ace-step/ACE-Step-1.5 PR: 508
File: acestep/core/generation/handler/memory_utils.py:1-18
Timestamp: 2026-02-13T11:03:36.432Z
Learning: Enforce a module size guideline for Python files: only raise module-size concerns when a file exceeds 200 lines of code (LOC). Do not raise for files in the 150–200 LOC range, and accept files below 200 LOC without justification to split. Apply this across all Python files (pattern **/*.py) to maintain consistency in reviews; when a file exceeds 200 LOC, suggest splitting into smaller modules to improve readability and maintainability.


    # Phase 2: DiT music generation
    # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
    dit_generate_kwargs = {
115 changes: 105 additions & 10 deletions acestep/llm_inference.py
@@ -64,6 +64,7 @@ def __init__(self, persistent_storage_path: Optional[str] = None):
        self.dtype = torch.float32
        self.offload_to_cpu = False
        self.disable_tqdm = os.environ.get("ACESTEP_DISABLE_TQDM", "").lower() in ("1", "true", "yes") or not (hasattr(sys.stderr, 'isatty') and sys.stderr.isatty())
        self._last_init_config = None

        # HuggingFace Space persistent storage support
        if persistent_storage_path is None and self.IS_HUGGINGFACE_SPACE:
@@ -80,6 +81,24 @@
        self._mlx_model = None
        self._mlx_model_path = None

    def _save_last_init_config(
        self,
        checkpoint_dir: str,
        lm_model_path: str,
        device: str,
        offload_to_cpu: bool,
        dtype: Optional[torch.dtype],
    ) -> None:
        """Persist the last successfully initialized LM configuration."""
        self._last_init_config = {
            "checkpoint_dir": checkpoint_dir,
            "lm_model_path": lm_model_path,
            "backend": self.llm_backend,
            "device": device,
            "offload_to_cpu": offload_to_cpu,
            "dtype": dtype,
        }

    def _clear_accelerator_cache(self) -> None:
        """Release freed accelerator memory back to the driver.

@@ -119,22 +138,44 @@ def unload(self) -> None:
        try:
            if self.llm_backend == "vllm":
                try:
                    if hasattr(self.llm, "reset"):
                        self.llm.reset()
                except Exception:
                    pass
                self._cleanup_torch_distributed_state()
                    if self.llm is not None:
                        if hasattr(self.llm, "exit"):
                            logger.info("[LLM vLLM] Calling nanovllm exit() for hard teardown")
                            self.llm.exit()
                        elif hasattr(self.llm, "reset"):
                            logger.info("[LLM vLLM] exit() missing, falling back to reset()")
                            self.llm.reset()
                except Exception as exc:
                    logger.warning(f"[LLM vLLM] Error during vLLM teardown: {exc}")

                try:
                    self._cleanup_torch_distributed_state()
                except Exception as exc:
                    logger.warning(f"[LLM vLLM] torch distributed cleanup failed: {exc}")

            self.llm = None
            self.llm_tokenizer = None
            self.constrained_processor = None
            self.llm_initialized = False
            self.llm_backend = None
            self._hf_model_for_scoring = None
            self._mlx_model = None
            self._mlx_model_path = None

            gc.collect()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                try:
                    torch.cuda.synchronize()
                except Exception:
                    pass
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass
                try:
                    torch.cuda.ipc_collect()
                except Exception:
                    pass
            elif hasattr(torch, "mps") and torch.backends.mps.is_available():
                if hasattr(torch.mps, "synchronize"):
                    torch.mps.synchronize()
@@ -143,8 +184,27 @@
            elif hasattr(torch, "xpu") and torch.xpu.is_available():
                torch.xpu.empty_cache()
                torch.xpu.synchronize()
        except Exception:
            pass
        except Exception as exc:
            logger.warning(f"[LLM] unload failed: {exc}")

    def reload_last_configuration(self) -> Tuple[str, bool]:
        """Recreate the LM from the last successful initialize() configuration."""
        if not self._last_init_config:
            return "❌ No previous LM initialization config available", False

        cfg = dict(self._last_init_config)

        logger.info("[LLM] Reloading last configuration: backend={} model={} device={}",
                    cfg.get("backend"), cfg.get("lm_model_path"), cfg.get("device"))

        return self.initialize(
            checkpoint_dir=cfg["checkpoint_dir"],
            lm_model_path=cfg["lm_model_path"],
            backend=cfg["backend"],
            device=cfg["device"],
            offload_to_cpu=cfg["offload_to_cpu"],
            dtype=cfg["dtype"],
        )

    def _cleanup_torch_distributed_state(self) -> None:
        """Destroy default torch distributed process group when already initialized."""
@@ -659,6 +719,13 @@ def initialize(
logger.info("Attempting MLX backend for Apple Silicon acceleration...")
mlx_success, mlx_status = self._load_mlx_model(full_lm_model_path)
if mlx_success:
self._save_last_init_config(
checkpoint_dir=checkpoint_dir,
lm_model_path=lm_model_path,
device=device,
offload_to_cpu=offload_to_cpu,
dtype=dtype,
)
return mlx_status, True
else:
logger.warning(f"MLX backend failed: {mlx_status}")
Expand All @@ -669,6 +736,13 @@ def initialize(
                if not success:
                    return status_msg, False
                status_msg = f"✅ 5Hz LM initialized (PyTorch fallback from MLX)\nModel: {full_lm_model_path}\nBackend: PyTorch"
                self._save_last_init_config(
                    checkpoint_dir=checkpoint_dir,
                    lm_model_path=lm_model_path,
                    device=device,
                    offload_to_cpu=offload_to_cpu,
                    dtype=dtype,
                )
                return status_msg, True
            # else: backend was "vllm" on MPS, continue to vllm attempt below
        elif backend == "mlx":
@@ -678,6 +752,13 @@
            if not success:
                return status_msg, False
            status_msg = f"✅ 5Hz LM initialized (PyTorch fallback, MLX not available)\nModel: {full_lm_model_path}\nBackend: PyTorch"
            self._save_last_init_config(
                checkpoint_dir=checkpoint_dir,
                lm_model_path=lm_model_path,
                device=device,
                offload_to_cpu=offload_to_cpu,
                dtype=dtype,
            )
            return status_msg, True

        if backend == "vllm" and device != "cuda":
@@ -733,6 +814,13 @@
logger.warning("vllm failed on MPS, trying MLX backend...")
mlx_success, mlx_status = self._load_mlx_model(full_lm_model_path)
if mlx_success:
self._save_last_init_config(
checkpoint_dir=checkpoint_dir,
lm_model_path=lm_model_path,
device=device,
offload_to_cpu=offload_to_cpu,
dtype=dtype,
)
return mlx_status, True
logger.warning(f"MLX also failed: {mlx_status}, falling back to PyTorch")
logger.warning("Falling back to PyTorch backend")
Expand All @@ -749,6 +837,13 @@ def initialize(
            if vllm_preflight_warning is not None:
                status_msg += f"\nNote: {vllm_preflight_warning}"

            self._save_last_init_config(
                checkpoint_dir=checkpoint_dir,
                lm_model_path=lm_model_path,
                device=device,
                offload_to_cpu=offload_to_cpu,
                dtype=dtype,
            )
            return status_msg, True

        except Exception as e: