acestep/core/generation/handler/diffusion.py (4 additions, 0 deletions)
@@ -34,6 +34,7 @@ def _mlx_run_diffusion(
         encoder_hidden_states_non_cover=None,
         encoder_attention_mask_non_cover=None,
         context_latents_non_cover=None,
+        progress_callback=None,
         disable_tqdm: bool = False,
         sampler_mode: str = "euler",
         velocity_norm_threshold: float = 0.0,
@@ -59,6 +60,8 @@ def _mlx_run_diffusion(
             encoder_hidden_states_non_cover: Optional non-cover conditioning tensor.
             encoder_attention_mask_non_cover: Unused; accepted for API compatibility.
             context_latents_non_cover: Optional non-cover context latent tensor.
+            progress_callback: Optional diffusion-step callback receiving
+                ``(current_step, total_steps, desc)``.
             disable_tqdm: If True, suppress the diffusion progress bar.
             sampler_mode: Sampler algorithm — ``"euler"`` or ``"heun"``.
             velocity_norm_threshold: Velocity norm clamping threshold (0 = disabled).
@@ -141,6 +144,7 @@ def _mlx_run_diffusion(
             encoder_hidden_states_non_cover_np=enc_nc_np,
             context_latents_non_cover_np=ctx_nc_np,
             compile_model=getattr(self, "mlx_dit_compiled", False),
+            progress_callback=progress_callback,
             disable_tqdm=disable_tqdm,
             sampler_mode=sampler_mode,
             velocity_norm_threshold=velocity_norm_threshold,
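The new `progress_callback` parameter threads a per-step hook through the MLX diffusion loop with the contract stated in the docstring: it is invoked as `(current_step, total_steps, desc)`. A minimal consumer might look like the sketch below; the `report` function and its wiring are illustrative, not part of this patch.

def report(current_step: int, total_steps: int, desc: str) -> None:
    """Illustrative progress_callback: print one line per diffusion step."""
    if total_steps > 0:
        pct = 100.0 * current_step / total_steps
        print(f"{desc}: step {current_step}/{total_steps} ({pct:.0f}%)")

# Hypothetical wiring into the method patched above:
# handler._mlx_run_diffusion(..., progress_callback=report, disable_tqdm=True)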
acestep/core/generation/handler/generate_music_decode.py (22 additions, 3 deletions)
@@ -122,8 +122,19 @@ def _decode_generate_music_pred_latents(
         Returns:
             Tuple of decoded waveforms, CPU latents, and updated time-cost payload.
         """
+        def _emit_decode_progress(value: float, desc: str) -> None:
+            if progress:
+                progress(value, desc=desc)
+
+        def _decode_chunk_progress(current: int, total: int, desc: str) -> None:
+            if total <= 0:
+                return
+            frac = min(1.0, max(0.0, current / total))
+            mapped = 0.82 + 0.16 * frac
+            _emit_decode_progress(mapped, desc)
+
         if progress:
-            progress(0.8, desc="Decoding audio...")
+            _emit_decode_progress(0.8, "Preparing audio decode...")
         logger.info("[generate_music] Decoding latents with VAE...")
         start_time = time.time()
         with torch.inference_mode():
@@ -168,12 +179,19 @@ def _decode_generate_music_pred_latents(
             pred_latents_for_decode = pred_latents_for_decode.cpu()
         self._empty_cache()
         try:
+            _emit_decode_progress(0.82, "Decoding audio chunks...")
             if use_tiled_decode:
                 logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
-                pred_wavs = self.tiled_decode(pred_latents_for_decode)
+                pred_wavs = self.tiled_decode(
+                    pred_latents_for_decode,
+                    progress_callback=_decode_chunk_progress,
+                )
             elif using_mlx_vae:
                 try:
-                    pred_wavs = self._mlx_vae_decode(pred_latents_for_decode)
+                    pred_wavs = self._mlx_vae_decode(
+                        pred_latents_for_decode,
+                        progress_callback=_decode_chunk_progress,
+                    )
                 except Exception as exc:
                     logger.warning(
                         f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch"
@@ -202,6 +220,7 @@ def _decode_generate_music_pred_latents(
         if torch.any(peak > 1.0):
             pred_wavs = pred_wavs / peak.clamp(min=1.0)
         self._empty_cache()
+        _emit_decode_progress(0.98, "Decoding audio chunks...")
         gc.collect()
         self._empty_cache()
         end_time = time.time()
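`_decode_chunk_progress` maps per-chunk decode progress linearly into the slice of the overall progress bar reserved for VAE decoding: chunk fraction 0.0 lands at 0.82, just after the 0.8 "Preparing audio decode..." tick, and fraction 1.0 lands at 0.98, the value asserted last in the tests that follow. A self-contained check of that mapping (the helper name is illustrative):

def map_decode_progress(current: int, total: int) -> float:
    """Mirror of the patch's mapping: chunk fraction -> [0.82, 0.98]."""
    frac = min(1.0, max(0.0, current / total))
    return 0.82 + 0.16 * frac

assert abs(map_decode_progress(0, 64) - 0.82) < 1e-9   # first report: window start
assert abs(map_decode_progress(32, 64) - 0.90) < 1e-9  # halfway through the chunks
assert abs(map_decode_progress(64, 64) - 0.98) < 1e-9  # last chunk: window end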
acestep/core/generation/handler/generate_music_decode_test.py (13 additions, 4 deletions)
@@ -115,14 +115,20 @@ def _max_memory_allocated(self):
"""Return deterministic max-memory value for debug logging."""
return 0.0

def _mlx_vae_decode(self, latents):
def _mlx_vae_decode(self, latents, progress_callback=None):
"""Return deterministic decoded waveform for MLX decode branch."""
_ = latents
if progress_callback is not None:
progress_callback(1, 64, "Decoding audio chunks...")
progress_callback(64, 64, "Decoding audio chunks...")
return torch.ones(1, 2, 8)

def tiled_decode(self, latents):
def tiled_decode(self, latents, progress_callback=None):
"""Return deterministic decoded waveform for tiled decode branch."""
_ = latents
if progress_callback is not None:
progress_callback(1, 64, "Decoding audio chunks...")
progress_callback(64, 64, "Decoding audio chunks...")
return torch.ones(1, 2, 8)


@@ -189,7 +195,11 @@ def _progress(value, desc=None):
        self.assertAlmostEqual(updated_costs["vae_decode_time_cost"], 1.5, places=6)
        self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6)
        self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6)
-        self.assertEqual(host.progress_calls[0][0], 0.8)
+        self.assertEqual(host.progress_calls[0], (0.8, "Preparing audio decode..."))
+        self.assertIn((0.82, "Decoding audio chunks..."), host.progress_calls)
+        progress_values = [value for value, _ in host.progress_calls]
+        self.assertEqual(progress_values, sorted(progress_values))
+        self.assertAlmostEqual(host.progress_calls[-1][0], 0.98, places=6)

    def test_decode_pred_latents_restores_vae_device_on_decode_error(self):
        """It restores VAE device in the CPU-offload path even when decode raises."""
@@ -313,4 +323,3 @@ def __init__(self):

 if __name__ == "__main__":
     unittest.main()
-
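The stubs call `progress_callback(1, 64, ...)` and `progress_callback(64, 64, ...)` to mimic the first and last reports of a 64-chunk decode. A real chunked decoder following the same `(current, total, desc)` convention could look roughly like this sketch; `decode_in_chunks`, its `decode_fn` argument, and the chunk size are hypothetical, not taken from the repository:

import torch

def decode_in_chunks(latents, decode_fn, chunk_size=64, progress_callback=None):
    """Decode latents chunk by chunk, reporting (current, total, desc) per chunk."""
    chunks = latents.split(chunk_size, dim=-1)  # split along the time axis
    total = len(chunks)
    decoded = []
    for i, chunk in enumerate(chunks, start=1):
        decoded.append(decode_fn(chunk))  # single-chunk decode primitive
        if progress_callback is not None:
            progress_callback(i, total, "Decoding audio chunks...")
    return torch.cat(decoded, dim=-1)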
acestep/core/generation/handler/generate_music_execute.py (40 additions, 6 deletions)
@@ -1,11 +1,14 @@
"""Execution helper for ``generate_music`` service invocation with progress tracking."""

import os
import threading
from typing import Any, Dict, List, Optional, Sequence
import time
from typing import Any, Callable, Dict, List, Optional, Sequence

import threading
from loguru import logger

from .runtime_progress_relay import RuntimeProgressRelay

# Maximum wall-clock seconds to wait for service_generate before declaring a hang.
# Generous default: most generations finish in 30-120s, but large batches on slow
# GPUs can take several minutes. Override via ACESTEP_GENERATION_TIMEOUT env var.
@@ -47,9 +50,10 @@ def _run_generate_music_service_with_progress(
"""
infer_steps_for_progress = len(timesteps) if timesteps else inference_steps
progress_desc = f"Generating music (batch size: {actual_batch_size})..."
progress(0.52, desc=progress_desc)
stop_event = None
progress_thread = None
relay = RuntimeProgressRelay(progress=progress, start=0.52, end=0.79)
relay.emit_progress(0.52, progress_desc)

# --- Timeout-wrapped service_generate ---
# Run the actual CUDA work in a child thread so we can join() with a
@@ -91,13 +95,19 @@ def _service_target():
                     chunk_mask_modes=service_inputs.get("chunk_mask_modes_batch"),
                     repaint_crossfade_frames=repaint_crossfade_frames,
                     repaint_injection_ratio=repaint_injection_ratio,
+                    progress_callback=relay.enqueue,
                 )
             except Exception as exc:
                 _error["exc"] = exc

+        def _stop_progress_estimator_if_finished() -> None:
+            """Stop the estimator thread but keep its handle until it actually exits."""
+            nonlocal progress_thread
+            progress_thread = RuntimeProgressRelay.stop_estimator_if_finished(progress_thread)
+
         try:
             stop_event, progress_thread = self._start_diffusion_progress_estimator(
-                progress=progress,
+                progress=relay.emit_progress,
                 start=0.52,
                 end=0.79,
                 infer_steps=infer_steps_for_progress,
@@ -112,7 +122,29 @@ def _service_target():
                 daemon=True,
             )
             gen_thread.start()
-            gen_thread.join(timeout=_DEFAULT_GENERATION_TIMEOUT)
+            deadline = time.monotonic() + _DEFAULT_GENERATION_TIMEOUT
+            poll_interval = 0.1
+            saw_runtime_progress = False
+            while gen_thread.is_alive():
+                remaining = deadline - time.monotonic()
+                if remaining <= 0:
+                    break
+                gen_thread.join(timeout=min(poll_interval, remaining))
+                drained_runtime_progress = relay.drain()
+                if drained_runtime_progress and not saw_runtime_progress:
+                    saw_runtime_progress = True
+                    if stop_event is not None:
+                        stop_event.set()
+                    _stop_progress_estimator_if_finished()
+                if not gen_thread.is_alive():
+                    break
+
+            drained_runtime_progress = relay.drain()
+            if drained_runtime_progress and not saw_runtime_progress:
+                saw_runtime_progress = True
+                if stop_event is not None:
+                    stop_event.set()
+                _stop_progress_estimator_if_finished()

             if gen_thread.is_alive():
                 logger.error(
@@ -128,11 +160,13 @@ def _service_target():
                 )
             if "exc" in _error:
                 raise _error["exc"]
+            relay.emit_progress(0.79, progress_desc)

         finally:
+            relay.shutdown()
             if stop_event is not None:
                 stop_event.set()
             if progress_thread is not None:
                 progress_thread.join(timeout=1.0)

-    return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
+    return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}