-
Notifications
You must be signed in to change notification settings - Fork 1.2k
provide the ability to unload the LM prior to the DiT execution phase #1090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
fc44709
fd286f5
c96cdf3
2058f73
aa1c948
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |||||||||||||
| from dataclasses import dataclass, field, asdict | ||||||||||||||
| from loguru import logger | ||||||||||||||
| import torch | ||||||||||||||
| import gc | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| from acestep.audio_utils import AudioSaver, apply_fade, generate_uuid_from_params, normalize_audio, get_lora_weights_hash | ||||||||||||||
|
|
@@ -331,6 +332,42 @@ def _update_metadata_from_lm( | |||||||||||||
| return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
def _unload_lm_before_dit(llm_handler):
    """Free GPU memory held by the language model before the DiT phase.

    Best-effort teardown: calls ``llm_handler.unload()``, forces garbage
    collection, then releases cached CUDA blocks. Every failure is logged
    and swallowed so the subsequent DiT run can proceed regardless.

    Args:
        llm_handler: LM handler object exposing ``unload()`` (and,
            optionally, ``llm_backend`` / ``llm_initialized`` attributes
            used purely for logging). May be ``None``, in which case the
            function is a no-op.
    """
    if llm_handler is None:
        return

    def _log_cuda_mem(stage):
        # Report allocated/reserved CUDA memory in GiB for observability.
        alloc = torch.cuda.memory_allocated() / (1024 ** 3)
        reserved = torch.cuda.memory_reserved() / (1024 ** 3)
        logger.info("{}: allocated={:.2f} GB reserved={:.2f} GB", stage, alloc, reserved)

    logger.info("Unloading LM before DiT. backend={}, initialized={}",
                getattr(llm_handler, "llm_backend", None),
                getattr(llm_handler, "llm_initialized", None))

    if torch.cuda.is_available():
        _log_cuda_mem("Before LM unload")

    try:
        llm_handler.unload()
    except Exception as exc:
        # Keep going even if the handler's own teardown fails; gc +
        # empty_cache below can still reclaim most of the memory.
        logger.warning("llm_handler.unload() failed: {}", exc)

    gc.collect()

    try:
        if torch.cuda.is_available():
            # synchronize() first so in-flight kernels release their
            # allocations before we ask the caching allocator to shrink.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            try:
                # ipc_collect may be unavailable/unsupported on some
                # platforms; it is strictly optional cleanup.
                torch.cuda.ipc_collect()
            except Exception:
                pass

            _log_cuda_mem("After LM unload")
    except Exception as exc:
        # Previously a silent `pass`; log so cache-release failures are
        # visible while still never blocking the DiT phase.
        logger.warning("CUDA cache cleanup after LM unload failed: {}", exc)
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @_get_spaces_gpu_decorator(duration=180) | ||||||||||||||
| def generate_music( | ||||||||||||||
| dit_handler, | ||||||||||||||
|
|
@@ -422,8 +459,22 @@ def generate_music( | |||||||||||||
| # 3. use_cot_language=True: detect vocal language via CoT | ||||||||||||||
| # 4. use_cot_metas=True: fill missing metadata via CoT | ||||||||||||||
| need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas | ||||||||||||||
| use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks | ||||||||||||||
|
|
||||||||||||||
| # If this request needs the LM, but the LM was previously unloaded (for example | ||||||||||||||
| # after the LM->DiT handoff), try to reload it now. | ||||||||||||||
| lm_status = [] | ||||||||||||||
|
|
||||||||||||||
| request_needs_lm = (params.task_type not in skip_lm_tasks) and (params.thinking or need_lm_for_cot) | ||||||||||||||
|
|
||||||||||||||
| if request_needs_lm and llm_handler is not None and not llm_handler.llm_initialized: | ||||||||||||||
| logger.info("LM required but not initialized; attempting reload from saved config") | ||||||||||||||
| reload_status, reload_ok = llm_handler.reload_last_configuration() | ||||||||||||||
| lm_status.append(reload_status) | ||||||||||||||
|
|
||||||||||||||
| if not reload_ok: | ||||||||||||||
| logger.error(f"[generate_music] LM reload failed: {reload_status}") | ||||||||||||||
|
|
||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||||||||||||||
| use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks | ||||||||||||||
|
|
||||||||||||||
| if params.task_type in skip_lm_tasks: | ||||||||||||||
| logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly") | ||||||||||||||
|
|
@@ -609,6 +660,17 @@ def generate_music( | |||||||||||||
| if params.task_type in ("cover", "repaint", "lego", "extract"): | ||||||||||||||
| audio_duration = None | ||||||||||||||
|
|
||||||||||||||
| # Unload the LM now if option is enabled and the backend supports reload cleanly | ||||||||||||||
| unload_enabled = os.environ.get("ACESTEP_UNLOAD_LM_BEFORE_DIT", "").lower() in ("1", "true", "yes") | ||||||||||||||
| safe_unload_backends = {"pt", "vllm"} | ||||||||||||||
| current_backend = getattr(llm_handler, "llm_backend", None) if llm_handler is not None else None | ||||||||||||||
|
|
||||||||||||||
| if use_lm and unload_enabled: | ||||||||||||||
| if current_backend in safe_unload_backends: | ||||||||||||||
| _unload_lm_before_dit(llm_handler) | ||||||||||||||
| else: | ||||||||||||||
| logger.info("[generate_music] Skipping LM unload before DiT for unsupported backend={}", current_backend) | ||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
Comment on lines
+676
to
+680
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This unload path strands the other LM inference APIs. After Line 677 unloads the shared handler, Based on learnings, 🤖 Prompt for AI Agents
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Technically, yes, this is a lifecycle inconsistency. However it is not clear that these other APIs would even be hit from the normal UI flow here. I don't know enough at this point to determine if this is a real problem or not...
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainLet me check how those functions are wired into the Gradio UI to give you a concrete answer. 🏁 Script executed: Length of output: 66675
Here's what was found:
All three are separate clickable buttons in the same Gradio session sharing the same
Notably, Severity is bounded because:
The fix would be adding the same reload guard (check 🧠 Learnings used |
||||||||||||||
|
|
||||||||||||||
| # Phase 2: DiT music generation | ||||||||||||||
| # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation | ||||||||||||||
| dit_generate_kwargs = { | ||||||||||||||
|
|
||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.