diff --git a/acestep/core/generation/handler/diffusion.py b/acestep/core/generation/handler/diffusion.py
index 228beaf70..b4a1efdd7 100644
--- a/acestep/core/generation/handler/diffusion.py
+++ b/acestep/core/generation/handler/diffusion.py
@@ -34,6 +34,7 @@ def _mlx_run_diffusion(
encoder_hidden_states_non_cover=None,
encoder_attention_mask_non_cover=None,
context_latents_non_cover=None,
+ progress_callback=None,
disable_tqdm: bool = False,
sampler_mode: str = "euler",
velocity_norm_threshold: float = 0.0,
@@ -59,6 +60,8 @@ def _mlx_run_diffusion(
encoder_hidden_states_non_cover: Optional non-cover conditioning tensor.
encoder_attention_mask_non_cover: Unused; accepted for API compatibility.
context_latents_non_cover: Optional non-cover context latent tensor.
+ progress_callback: Optional diffusion-step callback receiving
+ ``(current_step, total_steps, desc)``.
disable_tqdm: If True, suppress the diffusion progress bar.
sampler_mode: Sampler algorithm — ``"euler"`` or ``"heun"``.
velocity_norm_threshold: Velocity norm clamping threshold (0 = disabled).
@@ -141,6 +144,7 @@ def _mlx_run_diffusion(
encoder_hidden_states_non_cover_np=enc_nc_np,
context_latents_non_cover_np=ctx_nc_np,
compile_model=getattr(self, "mlx_dit_compiled", False),
+ progress_callback=progress_callback,
disable_tqdm=disable_tqdm,
sampler_mode=sampler_mode,
velocity_norm_threshold=velocity_norm_threshold,
diff --git a/acestep/core/generation/handler/generate_music_decode.py b/acestep/core/generation/handler/generate_music_decode.py
index b7330503c..e04704e28 100644
--- a/acestep/core/generation/handler/generate_music_decode.py
+++ b/acestep/core/generation/handler/generate_music_decode.py
@@ -122,8 +122,19 @@ def _decode_generate_music_pred_latents(
Returns:
Tuple of decoded waveforms, CPU latents, and updated time-cost payload.
"""
+ def _emit_decode_progress(value: float, desc: str) -> None:
+ if progress:
+ progress(value, desc=desc)
+
+ def _decode_chunk_progress(current: int, total: int, desc: str) -> None:
+ if total <= 0:
+ return
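+        # VAE decode owns the 0.82-0.98 slice of the overall progress bar; map the chunk fraction into it.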
+ frac = min(1.0, max(0.0, current / total))
+ mapped = 0.82 + 0.16 * frac
+ _emit_decode_progress(mapped, desc)
+
if progress:
- progress(0.8, desc="Decoding audio...")
+ _emit_decode_progress(0.8, "Preparing audio decode...")
logger.info("[generate_music] Decoding latents with VAE...")
start_time = time.time()
with torch.inference_mode():
@@ -168,12 +179,19 @@ def _decode_generate_music_pred_latents(
pred_latents_for_decode = pred_latents_for_decode.cpu()
self._empty_cache()
try:
+ _emit_decode_progress(0.82, "Decoding audio chunks...")
if use_tiled_decode:
logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
- pred_wavs = self.tiled_decode(pred_latents_for_decode)
+ pred_wavs = self.tiled_decode(
+ pred_latents_for_decode,
+ progress_callback=_decode_chunk_progress,
+ )
elif using_mlx_vae:
try:
- pred_wavs = self._mlx_vae_decode(pred_latents_for_decode)
+ pred_wavs = self._mlx_vae_decode(
+ pred_latents_for_decode,
+ progress_callback=_decode_chunk_progress,
+ )
except Exception as exc:
logger.warning(
f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch"
@@ -202,6 +220,7 @@ def _decode_generate_music_pred_latents(
if torch.any(peak > 1.0):
pred_wavs = pred_wavs / peak.clamp(min=1.0)
self._empty_cache()
+ _emit_decode_progress(0.98, "Decoding audio chunks...")
gc.collect()
self._empty_cache()
end_time = time.time()
diff --git a/acestep/core/generation/handler/generate_music_decode_test.py b/acestep/core/generation/handler/generate_music_decode_test.py
index ece612ff1..ae1c9c1d9 100644
--- a/acestep/core/generation/handler/generate_music_decode_test.py
+++ b/acestep/core/generation/handler/generate_music_decode_test.py
@@ -115,14 +115,20 @@ def _max_memory_allocated(self):
"""Return deterministic max-memory value for debug logging."""
return 0.0
- def _mlx_vae_decode(self, latents):
+ def _mlx_vae_decode(self, latents, progress_callback=None):
"""Return deterministic decoded waveform for MLX decode branch."""
_ = latents
+ if progress_callback is not None:
+ progress_callback(1, 64, "Decoding audio chunks...")
+ progress_callback(64, 64, "Decoding audio chunks...")
return torch.ones(1, 2, 8)
- def tiled_decode(self, latents):
+ def tiled_decode(self, latents, progress_callback=None):
"""Return deterministic decoded waveform for tiled decode branch."""
_ = latents
+ if progress_callback is not None:
+ progress_callback(1, 64, "Decoding audio chunks...")
+ progress_callback(64, 64, "Decoding audio chunks...")
return torch.ones(1, 2, 8)
@@ -189,7 +195,11 @@ def _progress(value, desc=None):
self.assertAlmostEqual(updated_costs["vae_decode_time_cost"], 1.5, places=6)
self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6)
self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6)
- self.assertEqual(host.progress_calls[0][0], 0.8)
+ self.assertEqual(host.progress_calls[0], (0.8, "Preparing audio decode..."))
+ self.assertIn((0.82, "Decoding audio chunks..."), host.progress_calls)
+ progress_values = [value for value, _ in host.progress_calls]
+ self.assertEqual(progress_values, sorted(progress_values))
+ self.assertAlmostEqual(host.progress_calls[-1][0], 0.98, places=6)
def test_decode_pred_latents_restores_vae_device_on_decode_error(self):
"""It restores VAE device in the CPU-offload path even when decode raises."""
@@ -313,4 +323,3 @@ def __init__(self):
if __name__ == "__main__":
unittest.main()
-
diff --git a/acestep/core/generation/handler/generate_music_execute.py b/acestep/core/generation/handler/generate_music_execute.py
index b8deecf45..c016c8835 100644
--- a/acestep/core/generation/handler/generate_music_execute.py
+++ b/acestep/core/generation/handler/generate_music_execute.py
@@ -1,11 +1,14 @@
"""Execution helper for ``generate_music`` service invocation with progress tracking."""
import os
-import threading
-from typing import Any, Dict, List, Optional, Sequence
+import threading
+import time
+from typing import Any, Callable, Dict, List, Optional, Sequence
from loguru import logger
+from .runtime_progress_relay import RuntimeProgressRelay
+
# Maximum wall-clock seconds to wait for service_generate before declaring a hang.
# Generous default: most generations finish in 30-120s, but large batches on slow
# GPUs can take several minutes. Override via ACESTEP_GENERATION_TIMEOUT env var.
@@ -47,9 +50,10 @@ def _run_generate_music_service_with_progress(
"""
infer_steps_for_progress = len(timesteps) if timesteps else inference_steps
progress_desc = f"Generating music (batch size: {actual_batch_size})..."
- progress(0.52, desc=progress_desc)
stop_event = None
progress_thread = None
+ relay = RuntimeProgressRelay(progress=progress, start=0.52, end=0.79)
+ relay.emit_progress(0.52, progress_desc)
# --- Timeout-wrapped service_generate ---
# Run the actual CUDA work in a child thread so we can join() with a
@@ -91,13 +95,19 @@ def _service_target():
chunk_mask_modes=service_inputs.get("chunk_mask_modes_batch"),
repaint_crossfade_frames=repaint_crossfade_frames,
repaint_injection_ratio=repaint_injection_ratio,
+ progress_callback=relay.enqueue,
)
except Exception as exc:
_error["exc"] = exc
+ def _stop_progress_estimator_if_finished() -> None:
+ """Stop the estimator thread but keep its handle until it actually exits."""
+ nonlocal progress_thread
+ progress_thread = RuntimeProgressRelay.stop_estimator_if_finished(progress_thread)
+
try:
stop_event, progress_thread = self._start_diffusion_progress_estimator(
- progress=progress,
+ progress=relay.emit_progress,
start=0.52,
end=0.79,
infer_steps=infer_steps_for_progress,
@@ -112,7 +122,29 @@ def _service_target():
daemon=True,
)
gen_thread.start()
- gen_thread.join(timeout=_DEFAULT_GENERATION_TIMEOUT)
+ deadline = time.monotonic() + _DEFAULT_GENERATION_TIMEOUT
+ poll_interval = 0.1
+ saw_runtime_progress = False
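+        # Join in short slices so queued runtime step events can be drained and republished from this thread while the worker runs.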
+ while gen_thread.is_alive():
+ remaining = deadline - time.monotonic()
+ if remaining <= 0:
+ break
+ gen_thread.join(timeout=min(poll_interval, remaining))
+ drained_runtime_progress = relay.drain()
+ if drained_runtime_progress and not saw_runtime_progress:
+ saw_runtime_progress = True
+ if stop_event is not None:
+ stop_event.set()
+ _stop_progress_estimator_if_finished()
+ if not gen_thread.is_alive():
+ break
+
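+        # Final drain picks up any step events emitted between the last poll and worker-thread exit.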
+ drained_runtime_progress = relay.drain()
+ if drained_runtime_progress and not saw_runtime_progress:
+ saw_runtime_progress = True
+ if stop_event is not None:
+ stop_event.set()
+ _stop_progress_estimator_if_finished()
if gen_thread.is_alive():
logger.error(
@@ -128,11 +160,13 @@ def _service_target():
)
if "exc" in _error:
raise _error["exc"]
+ relay.emit_progress(0.79, progress_desc)
finally:
+ relay.shutdown()
if stop_event is not None:
stop_event.set()
if progress_thread is not None:
progress_thread.join(timeout=1.0)
- return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
\ No newline at end of file
+ return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
diff --git a/acestep/core/generation/handler/generate_music_execute_test.py b/acestep/core/generation/handler/generate_music_execute_test.py
index aead22214..a105f6c02 100644
--- a/acestep/core/generation/handler/generate_music_execute_test.py
+++ b/acestep/core/generation/handler/generate_music_execute_test.py
@@ -1,6 +1,8 @@
"""Unit tests for ``generate_music`` execution helper mixin."""
import unittest
+import time
+import threading
from acestep.core.generation.handler.generate_music_execute import GenerateMusicExecuteMixin
@@ -12,12 +14,19 @@ def __init__(self):
"""Capture calls for assertions."""
self.started = False
self.stopped = False
+ self.stop_calls = 0
self.service_calls = 0
+ self.emit_runtime_progress = True
+ self.estimator_progress_values = []
+ self.service_delay_sec = 0.0
def _start_diffusion_progress_estimator(self, **kwargs):
"""Return fake stop event/thread handles used by helper lifecycle."""
- _ = kwargs
self.started = True
+ progress = kwargs["progress"]
+ desc = kwargs["desc"]
+ for value in self.estimator_progress_values:
+ progress(value, desc=desc)
class _Stop:
"""Minimal stop-event stand-in used by the test host."""
@@ -29,6 +38,7 @@ def __init__(self, host):
def set(self):
"""Mark progress lifecycle as stopped."""
self.host.stopped = True
+ self.host.stop_calls += 1
class _Thread:
"""Minimal thread stand-in exposing a ``join`` method."""
@@ -41,8 +51,13 @@ def join(self, timeout=None):
def service_generate(self, **kwargs):
"""Record service invocation and return minimal output payload."""
- _ = kwargs
+ callback = kwargs.get("progress_callback")
self.service_calls += 1
+ if self.service_delay_sec > 0:
+ time.sleep(self.service_delay_sec)
+ if callback is not None and self.emit_runtime_progress:
+ callback(1, 4, "DiT diffusion...")
+ callback(4, 4, "DiT diffusion...")
return {"target_latents": "ok"}
@@ -52,8 +67,11 @@ class GenerateMusicExecuteMixinTests(unittest.TestCase):
def test_run_service_with_progress_invokes_service_and_stops_estimator(self):
"""Helper should call service once and always stop progress estimator."""
host = _Host()
+ host.emit_runtime_progress = False
+ host.estimator_progress_values = [0.63, 0.79]
+ updates = []
out = host._run_generate_music_service_with_progress(
- progress=lambda *args, **kwargs: None,
+ progress=lambda value, desc=None: updates.append((value, desc)),
actual_batch_size=1,
audio_duration=10.0,
inference_steps=8,
@@ -83,8 +101,157 @@ def test_run_service_with_progress_invokes_service_and_stops_estimator(self):
)
self.assertTrue(host.started)
self.assertTrue(host.stopped)
+ self.assertEqual(host.stop_calls, 1)
self.assertEqual(host.service_calls, 1)
self.assertEqual(out["outputs"]["target_latents"], "ok")
+ self.assertAlmostEqual(updates[-1][0], 0.79, places=6)
+
+ def test_runtime_progress_events_are_forwarded_to_ui_progress(self):
+ """Step-level runtime progress should stop the estimator and reach phase completion."""
+ host = _Host()
+ updates = []
+
+ host._run_generate_music_service_with_progress(
+ progress=lambda value, desc=None: updates.append((value, desc)),
+ actual_batch_size=1,
+ audio_duration=10.0,
+ inference_steps=8,
+ timesteps=None,
+ service_inputs={
+ "captions_batch": ["c"],
+ "lyrics_batch": ["l"],
+ "metas_batch": ["m"],
+ "vocal_languages_batch": ["en"],
+ "target_wavs_tensor": None,
+ "repainting_start_batch": [0.0],
+ "repainting_end_batch": [1.0],
+ "instructions_batch": ["i"],
+ "audio_code_hints_batch": None,
+ "should_return_intermediate": True,
+ },
+ refer_audios=None,
+ guidance_scale=7.0,
+ actual_seed_list=[1],
+ audio_cover_strength=1.0,
+ cover_noise_strength=0.0,
+ use_adg=False,
+ cfg_interval_start=0.0,
+ cfg_interval_end=1.0,
+ shift=1.0,
+ infer_method="ode",
+ )
+
+ self.assertTrue(any(desc == "DiT diffusion..." for _, desc in updates))
+ self.assertEqual(host.stop_calls, 2)
+ self.assertAlmostEqual(updates[-1][0], 0.79, places=6)
+
+ def test_runtime_progress_handoff_stays_monotonic_after_estimator_advances(self):
+ """Runtime callbacks should not drive the UI backwards after estimator progress."""
+ host = _Host()
+ host.estimator_progress_values = [0.68]
+ updates = []
+
+ host._run_generate_music_service_with_progress(
+ progress=lambda value, desc=None: updates.append((value, desc)),
+ actual_batch_size=1,
+ audio_duration=10.0,
+ inference_steps=8,
+ timesteps=None,
+ service_inputs={
+ "captions_batch": ["c"],
+ "lyrics_batch": ["l"],
+ "metas_batch": ["m"],
+ "vocal_languages_batch": ["en"],
+ "target_wavs_tensor": None,
+ "repainting_start_batch": [0.0],
+ "repainting_end_batch": [1.0],
+ "instructions_batch": ["i"],
+ "audio_code_hints_batch": None,
+ "should_return_intermediate": True,
+ },
+ refer_audios=None,
+ guidance_scale=7.0,
+ actual_seed_list=[1],
+ audio_cover_strength=1.0,
+ cover_noise_strength=0.0,
+ use_adg=False,
+ cfg_interval_start=0.0,
+ cfg_interval_end=1.0,
+ shift=1.0,
+ infer_method="ode",
+ )
+
+ progress_values = [value for value, _ in updates]
+ self.assertEqual(progress_values, sorted(progress_values))
+ self.assertGreaterEqual(progress_values[2], 0.68)
+
+ def test_runtime_progress_does_not_publish_behind_blocked_estimator_callback(self):
+ """A blocked estimator callback should not allow a stale publish after runtime progress."""
+
+ class _AsyncEstimatorHost(_Host):
+ def _start_diffusion_progress_estimator(self, **kwargs):
+ self.started = True
+ progress = kwargs["progress"]
+ desc = kwargs["desc"]
+
+ class _Stop:
+ def __init__(self, host):
+ self.host = host
+
+ def set(self):
+ self.host.stopped = True
+ self.host.stop_calls += 1
+
+ thread = threading.Thread(
+ target=lambda: progress(0.63, desc=desc),
+ name="test-estimator",
+ daemon=True,
+ )
+ thread.start()
+ return _Stop(self), thread
+
+ host = _AsyncEstimatorHost()
+ blocked_estimator = threading.Event()
+ updates = []
+
+ def progress(value, desc=None):
+ if value == 0.63 and not blocked_estimator.is_set():
+ blocked_estimator.set()
+ time.sleep(0.05)
+ updates.append((value, desc))
+
+ host._run_generate_music_service_with_progress(
+ progress=progress,
+ actual_batch_size=1,
+ audio_duration=10.0,
+ inference_steps=8,
+ timesteps=None,
+ service_inputs={
+ "captions_batch": ["c"],
+ "lyrics_batch": ["l"],
+ "metas_batch": ["m"],
+ "vocal_languages_batch": ["en"],
+ "target_wavs_tensor": None,
+ "repainting_start_batch": [0.0],
+ "repainting_end_batch": [1.0],
+ "instructions_batch": ["i"],
+ "audio_code_hints_batch": None,
+ "should_return_intermediate": True,
+ },
+ refer_audios=None,
+ guidance_scale=7.0,
+ actual_seed_list=[1],
+ audio_cover_strength=1.0,
+ cover_noise_strength=0.0,
+ use_adg=False,
+ cfg_interval_start=0.0,
+ cfg_interval_end=1.0,
+ shift=1.0,
+ infer_method="ode",
+ )
+
+ progress_values = [value for value, _ in updates]
+ self.assertEqual(progress_values, sorted(progress_values))
if __name__ == "__main__":
diff --git a/acestep/core/generation/handler/generate_music_execute_thread_test.py b/acestep/core/generation/handler/generate_music_execute_thread_test.py
new file mode 100644
index 000000000..056410ddb
--- /dev/null
+++ b/acestep/core/generation/handler/generate_music_execute_thread_test.py
@@ -0,0 +1,92 @@
+"""Regression coverage for progress estimator thread lifecycle."""
+
+import unittest
+
+from acestep.core.generation.handler.generate_music_execute import GenerateMusicExecuteMixin
+
+
+class _Host(GenerateMusicExecuteMixin):
+ """Minimal host for exercising estimator shutdown behavior."""
+
+ def __init__(self):
+ self.stop_calls = 0
+ self.thread = None
+
+ def _start_diffusion_progress_estimator(self, **kwargs):
+ class _Stop:
+ def __init__(self, host):
+ self.host = host
+
+ def set(self):
+ self.host.stop_calls += 1
+
+ class _Thread:
+ def __init__(self):
+ self.join_calls = 0
+ self._alive = True
+
+ def join(self, timeout=None):
+ _ = timeout
+ self.join_calls += 1
+ if self.join_calls >= 2:
+ self._alive = False
+
+ def is_alive(self):
+ return self._alive
+
+ self.thread = _Thread()
+ return _Stop(self), self.thread
+
+ def service_generate(self, **kwargs):
+ progress_callback = kwargs["progress_callback"]
+ progress_callback(1, 4, "DiT diffusion...")
+ return {"target_latents": "ok"}
+
+
+class GenerateMusicExecuteThreadLifecycleTests(unittest.TestCase):
+ """Ensure estimator thread handles survive timed joins until actual shutdown."""
+
+ def test_runtime_progress_keeps_estimator_thread_handle_until_thread_exits(self):
+ host = _Host()
+ updates = []
+
+ out = host._run_generate_music_service_with_progress(
+ progress=lambda value, desc=None: updates.append((value, desc)),
+ actual_batch_size=1,
+ audio_duration=10.0,
+ inference_steps=8,
+ timesteps=None,
+ service_inputs={
+ "captions_batch": ["c"],
+ "lyrics_batch": ["l"],
+ "metas_batch": ["m"],
+ "vocal_languages_batch": ["en"],
+ "target_wavs_tensor": None,
+ "repainting_start_batch": [0.0],
+ "repainting_end_batch": [1.0],
+ "instructions_batch": ["i"],
+ "audio_code_hints_batch": None,
+ "should_return_intermediate": True,
+ },
+ refer_audios=None,
+ guidance_scale=7.0,
+ actual_seed_list=[1],
+ audio_cover_strength=1.0,
+ cover_noise_strength=0.0,
+ use_adg=False,
+ cfg_interval_start=0.0,
+ cfg_interval_end=1.0,
+ shift=1.0,
+ infer_method="ode",
+ )
+
+ self.assertEqual(out["outputs"]["target_latents"], "ok")
+ self.assertEqual(host.stop_calls, 2)
+ self.assertIsNotNone(host.thread)
+ self.assertEqual(host.thread.join_calls, 2)
+ self.assertFalse(host.thread.is_alive())
+ self.assertTrue(any(desc == "DiT diffusion..." for _, desc in updates))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/acestep/core/generation/handler/mlx_vae_decode_native.py b/acestep/core/generation/handler/mlx_vae_decode_native.py
index 23b5c21ea..eb9cefc8f 100644
--- a/acestep/core/generation/handler/mlx_vae_decode_native.py
+++ b/acestep/core/generation/handler/mlx_vae_decode_native.py
@@ -28,7 +28,7 @@ def _resolve_mlx_decode_fn(self):
raise RuntimeError("MLX VAE decode requested but mlx_vae is not initialized.")
return self.mlx_vae.decode
- def _mlx_vae_decode(self, latents_torch):
+ def _mlx_vae_decode(self, latents_torch, progress_callback=None):
"""Decode batched PyTorch latents using native MLX VAE decode.
Args:
@@ -52,7 +52,18 @@ def _mlx_vae_decode(self, latents_torch):
decode_fn = self._resolve_mlx_decode_fn()
audio_parts = []
for idx in range(batch_size):
- decoded = self._mlx_decode_single(latents_mx[idx : idx + 1], decode_fn=decode_fn)
+ item_progress_callback = None
+ if progress_callback is not None:
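+                # Fold this item's chunk progress into a single (current, total) pair spanning the whole batch.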
+ def _item_progress(current, total, desc, offset=idx, batch_total=batch_size):
+ progress_callback(offset * total + current, batch_total * total, desc)
+
+ item_progress_callback = _item_progress
+
+ decoded = self._mlx_decode_single(
+ latents_mx[idx : idx + 1],
+ decode_fn=decode_fn,
+ progress_callback=item_progress_callback,
+ )
if decoded.dtype != mx.float32:
decoded = decoded.astype(mx.float32)
mx.eval(decoded)
@@ -71,7 +82,7 @@ def _mlx_vae_decode(self, latents_torch):
)
return torch.from_numpy(audio_ncl)
- def _mlx_decode_single(self, z_nlc, decode_fn=None):
+ def _mlx_decode_single(self, z_nlc, decode_fn=None, progress_callback=None):
"""Decode a single MLX latent sample with optional tiling.
Args:
@@ -91,7 +102,10 @@ def _mlx_decode_single(self, z_nlc, decode_fn=None):
mlx_overlap = 64
if latent_frames <= mlx_chunk:
- return decode_fn(z_nlc)
+ out = decode_fn(z_nlc)
+ if progress_callback is not None:
+ progress_callback(1, 1, "Decoding audio chunks...")
+ return out
stride = mlx_chunk - 2 * mlx_overlap
num_steps = math.ceil(latent_frames / stride)
@@ -115,5 +129,7 @@ def _mlx_decode_single(self, z_nlc, decode_fn=None):
audio_len = audio_chunk.shape[1]
end_idx = audio_len - trim_end if trim_end > 0 else audio_len
decoded_parts.append(audio_chunk[:, trim_start:end_idx, :])
+ if progress_callback is not None:
+ progress_callback(idx + 1, num_steps, "Decoding audio chunks...")
return mx.concatenate(decoded_parts, axis=1)
diff --git a/acestep/core/generation/handler/runtime_progress_relay.py b/acestep/core/generation/handler/runtime_progress_relay.py
new file mode 100644
index 000000000..04a9470a5
--- /dev/null
+++ b/acestep/core/generation/handler/runtime_progress_relay.py
@@ -0,0 +1,81 @@
+"""Runtime progress relay helpers for generation execution."""
+
+import queue
+import threading
+from typing import Callable, Optional
+
+from loguru import logger
+
+
+class RuntimeProgressRelay:
+ """Bridge runtime step events onto monotonic UI progress updates."""
+
+ def __init__(
+ self,
+ *,
+        progress: Optional[Callable[..., None]],
+ start: float,
+ end: float,
+ ) -> None:
+ self._progress = progress
+ self._start = start
+ self._end = end
+ self._events: "queue.Queue[tuple[int, int, str] | None]" = queue.Queue()
+ self._lock = threading.Lock()
+ self._value = 0.0
+ self._active = True
+
+ def enqueue(self, current: int, total: int, desc: str) -> None:
+ """Queue a runtime progress event while the relay is active."""
+ with self._lock:
+ if not self._active:
+ return
+ self._events.put((current, total, desc))
+
+ def emit_progress(self, value: float, desc: Optional[str] = None) -> float:
+ """Emit a monotonic progress update to the UI callback."""
+ with self._lock:
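+            # Never publish a value below the last one sent, so late estimator ticks cannot move the bar backwards.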
+ clamped = max(self._value, value)
+ self._value = clamped
+ if self._progress is not None:
+ try:
+ self._progress(clamped, desc=desc)
+ except Exception as exc:
+ logger.debug("[generate_music] Ignoring progress callback error: {}", exc)
+ return clamped
+
+ def drain(self) -> bool:
+ """Drain queued runtime events and map them onto the configured UI range."""
+ drained = False
+ while True:
+ try:
+ item = self._events.get_nowait()
+ except queue.Empty:
+ return drained
+
+ if item is None:
+ return drained
+
+ current, total, desc = item
+ if total <= 0:
+ continue
+ frac = min(1.0, max(0.0, current / total))
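+            # Linearly map the step fraction into the [start, end] range configured for this phase.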
+ mapped = self._start + (self._end - self._start) * frac
+ self.emit_progress(mapped, desc)
+ drained = True
+
+ def shutdown(self) -> None:
+ """Disable future runtime event forwarding."""
+ with self._lock:
+ self._active = False
+ self._events.put(None)
+
+ @staticmethod
+ def stop_estimator_if_finished(progress_thread):
+ """Return the estimator thread handle only while it is still alive."""
+ if progress_thread is None:
+ return None
+ progress_thread.join(timeout=1.0)
+ if hasattr(progress_thread, "is_alive") and progress_thread.is_alive():
+ return progress_thread
+ return None
diff --git a/acestep/core/generation/handler/runtime_progress_relay_test.py b/acestep/core/generation/handler/runtime_progress_relay_test.py
new file mode 100644
index 000000000..6d3fc7e8a
--- /dev/null
+++ b/acestep/core/generation/handler/runtime_progress_relay_test.py
@@ -0,0 +1,43 @@
+"""Unit tests for runtime progress relay helpers."""
+
+import unittest
+
+from acestep.core.generation.handler.runtime_progress_relay import RuntimeProgressRelay
+
+
+class RuntimeProgressRelayTests(unittest.TestCase):
+ """Verify runtime progress relay mapping and shutdown behavior."""
+
+ def test_drain_keeps_progress_monotonic(self):
+ updates = []
+ relay = RuntimeProgressRelay(
+ progress=lambda value, desc=None: updates.append((value, desc)),
+ start=0.52,
+ end=0.79,
+ )
+
+ relay.emit_progress(0.68, "estimator")
+ relay.enqueue(1, 4, "DiT diffusion...")
+ relay.enqueue(4, 4, "DiT diffusion...")
+
+ self.assertTrue(relay.drain())
+ self.assertEqual([value for value, _ in updates], sorted(value for value, _ in updates))
+ self.assertAlmostEqual(updates[-1][0], 0.79, places=6)
+
+ def test_shutdown_ignores_late_runtime_events(self):
+ updates = []
+ relay = RuntimeProgressRelay(
+ progress=lambda value, desc=None: updates.append((value, desc)),
+ start=0.52,
+ end=0.79,
+ )
+
+ relay.shutdown()
+ relay.enqueue(4, 4, "DiT diffusion...")
+
+ self.assertFalse(relay.drain())
+ self.assertEqual(updates, [])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/acestep/core/generation/handler/service_generate.py b/acestep/core/generation/handler/service_generate.py
index 2fcc6cab1..d56f899e4 100644
--- a/acestep/core/generation/handler/service_generate.py
+++ b/acestep/core/generation/handler/service_generate.py
@@ -5,7 +5,7 @@
and output attachment without owning model internals.
"""
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import torch
@@ -47,6 +47,7 @@ def service_generate(
chunk_mask_modes: Optional[List[str]] = None,
repaint_crossfade_frames: int = 10,
repaint_injection_ratio: float = 0.5,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
sampler_mode: str = "euler",
velocity_norm_threshold: float = 0.0,
velocity_ema_factor: float = 0.0,
@@ -79,6 +80,8 @@ def service_generate(
timesteps: Optional explicit diffusion timestep sequence.
repaint_crossfade_frames: Crossfade width (latent frames) at repaint
boundaries for boundary blending. ~0.4s at 25 Hz.
+ progress_callback: Optional diffusion-step callback receiving
+ ``(current_step, total_steps, desc)``.
sampler_mode: Sampler algorithm — ``"euler"`` or ``"heun"``.
velocity_norm_threshold: Velocity norm clamping threshold (0 = disabled).
velocity_ema_factor: Velocity EMA smoothing factor (0 = disabled).
@@ -140,6 +143,7 @@ def service_generate(
timesteps=timesteps,
repaint_crossfade_frames=repaint_crossfade_frames,
repaint_injection_ratio=repaint_injection_ratio,
+ progress_callback=progress_callback,
sampler_mode=sampler_mode,
velocity_norm_threshold=velocity_norm_threshold,
velocity_ema_factor=velocity_ema_factor,
diff --git a/acestep/core/generation/handler/service_generate_execute.py b/acestep/core/generation/handler/service_generate_execute.py
index 5946a6132..2bf35c0f3 100644
--- a/acestep/core/generation/handler/service_generate_execute.py
+++ b/acestep/core/generation/handler/service_generate_execute.py
@@ -1,7 +1,7 @@
"""Execution helpers for service generation diffusion and output assembly."""
import random
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from loguru import logger
@@ -77,6 +77,7 @@ def _build_service_generate_kwargs(
timesteps: Optional[List[float]],
repaint_crossfade_frames: int = 10,
repaint_injection_ratio: float = 0.5,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
sampler_mode: str = "euler",
velocity_norm_threshold: float = 0.0,
velocity_ema_factor: float = 0.0,
@@ -109,6 +110,8 @@ def _build_service_generate_kwargs(
"cfg_interval_start": cfg_interval_start,
"cfg_interval_end": cfg_interval_end,
"shift": shift,
+ "use_progress_bar": not getattr(self, "disable_tqdm", False),
+ "progress_callback": progress_callback,
"repaint_mask": repaint_mask,
"clean_src_latents": clean_src_latents,
"repaint_crossfade_frames": repaint_crossfade_frames,
@@ -203,6 +206,8 @@ def _execute_service_generate_diffusion(
encoder_hidden_states_non_cover=enc_hs_nc,
encoder_attention_mask_non_cover=enc_am_nc,
context_latents_non_cover=ctx_nc,
+ progress_callback=generate_kwargs.get("progress_callback"),
+ disable_tqdm=not generate_kwargs.get("use_progress_bar", True),
sampler_mode=generate_kwargs.get("sampler_mode", "euler"),
velocity_norm_threshold=generate_kwargs.get("velocity_norm_threshold", 0.0),
velocity_ema_factor=generate_kwargs.get("velocity_ema_factor", 0.0),
diff --git a/acestep/core/generation/handler/service_generate_execute_test.py b/acestep/core/generation/handler/service_generate_execute_test.py
index 8af115425..ed2d61df5 100644
--- a/acestep/core/generation/handler/service_generate_execute_test.py
+++ b/acestep/core/generation/handler/service_generate_execute_test.py
@@ -15,6 +15,7 @@ class _Host(ServiceGenerateExecuteMixin, ServiceGenerateOutputsMixin):
def __init__(self):
"""Initialize static runtime fields for helper-method tests."""
self.device = "cpu"
+ self.disable_tqdm = False
self.silence_latent = torch.zeros(1, 4, 4, dtype=torch.float32)
@@ -58,6 +59,47 @@ def test_build_generate_kwargs_adds_timesteps_tensor(self):
self.assertEqual(kwargs["infer_steps"], 16)
self.assertEqual(kwargs["timesteps"].dtype, torch.float32)
self.assertEqual(kwargs["timesteps"].device.type, "cpu")
+ self.assertTrue(kwargs["use_progress_bar"])
+ self.assertIsNone(kwargs["progress_callback"])
+
+ def test_build_generate_kwargs_forwards_runtime_progress_callback(self):
+ """Runtime callbacks should be forwarded into model-generation kwargs."""
+ host = _Host()
+ payload = {
+ "text_hidden_states": torch.zeros(1, 2),
+ "text_attention_mask": torch.ones(1, 2),
+ "lyric_hidden_states": torch.zeros(1, 2),
+ "lyric_attention_mask": torch.ones(1, 2),
+ "refer_audio_acoustic_hidden_states_packed": torch.zeros(1, 2),
+ "refer_audio_order_mask": torch.zeros(1, dtype=torch.long),
+ "src_latents": torch.zeros(1, 4, 4),
+ "chunk_mask": torch.ones(1, 4, dtype=torch.bool),
+ "is_covers": torch.tensor([True]),
+ "non_cover_text_hidden_states": None,
+ "non_cover_text_attention_masks": None,
+ "precomputed_lm_hints_25Hz": None,
+ }
+
+        def callback(current, total, desc):
+ return (current, total, desc)
+
+ kwargs = host._build_service_generate_kwargs(
+ payload=payload,
+ seed_param=123,
+ infer_steps=16,
+ guidance_scale=7.0,
+ audio_cover_strength=1.0,
+ cover_noise_strength=0.0,
+ infer_method="ode",
+ use_adg=False,
+ cfg_interval_start=0.0,
+ cfg_interval_end=1.0,
+ shift=1.0,
+ timesteps=None,
+ progress_callback=callback,
+ )
+
+ self.assertIs(kwargs["progress_callback"], callback)
+ self.assertTrue(kwargs["use_progress_bar"])
def test_attach_service_outputs_persists_required_fields(self):
"""Attached payload fields should be available to downstream handlers."""
diff --git a/acestep/core/generation/handler/service_generate_test.py b/acestep/core/generation/handler/service_generate_test.py
index 6287010ad..4495c6176 100644
--- a/acestep/core/generation/handler/service_generate_test.py
+++ b/acestep/core/generation/handler/service_generate_test.py
@@ -128,10 +128,12 @@ def test_service_generate_forwards_runtime_controls_to_build_and_execute(self):
"""It forwards runtime tuning controls to downstream helper invocations."""
host = _Host()
custom_timesteps = [1.0, 0.5, 0.0]
+ progress_callback = lambda current, total, desc: (current, total, desc)
host.service_generate(
captions="cap", lyrics="lyr", guidance_scale=9.5, audio_cover_strength=0.7,
cover_noise_strength=0.2, use_adg=True, cfg_interval_start=0.1, cfg_interval_end=0.9,
shift=1.3, infer_method="sde", timesteps=custom_timesteps, return_intermediate=False,
+ progress_callback=progress_callback,
)
build_kwargs = host.calls["_build_service_generate_kwargs"]
self.assertEqual(build_kwargs["guidance_scale"], 9.5)
@@ -139,6 +141,7 @@ def test_service_generate_forwards_runtime_controls_to_build_and_execute(self):
self.assertEqual(build_kwargs["cover_noise_strength"], 0.2)
self.assertTrue(build_kwargs["use_adg"])
self.assertEqual(build_kwargs["timesteps"], custom_timesteps)
+ self.assertIs(build_kwargs["progress_callback"], progress_callback)
self.assertFalse(host.calls["_attach_service_generate_outputs"]["return_intermediate"])
execute_kwargs = host.calls["_execute_service_generate_diffusion"]
self.assertEqual(execute_kwargs["infer_method"], "sde")
diff --git a/acestep/core/generation/handler/vae_decode.py b/acestep/core/generation/handler/vae_decode.py
index c3ea76174..03a00872a 100644
--- a/acestep/core/generation/handler/vae_decode.py
+++ b/acestep/core/generation/handler/vae_decode.py
@@ -19,6 +19,7 @@ def tiled_decode(
chunk_size: Optional[int] = None,
overlap: int = 64,
offload_wav_to_cpu: Optional[bool] = None,
+ progress_callback=None,
):
"""Decode latents using tiling to reduce VRAM usage.
@@ -32,6 +33,8 @@ def tiled_decode(
overlap: Overlap in latent frames between adjacent windows.
offload_wav_to_cpu: Whether decoded waveform chunks should be
offloaded to CPU immediately to reduce VRAM pressure.
+ progress_callback: Optional callback receiving
+ ``(current_step, total_steps, desc)`` for chunk decode progress.
Returns:
Decoded waveform tensor shaped ``[batch, audio_channels, samples]``.
@@ -39,7 +42,7 @@ def tiled_decode(
# ---- MLX fast path (macOS Apple Silicon) ----
if self.use_mlx_vae and self.mlx_vae is not None:
try:
- result = self._mlx_vae_decode(latents)
+ result = self._mlx_vae_decode(latents, progress_callback=progress_callback)
return result
except Exception as exc:
logger.warning(
@@ -74,7 +77,13 @@ def tiled_decode(
overlap = min(overlap, _mps_overlap)
try:
- return self._tiled_decode_inner(latents, chunk_size, overlap, offload_wav_to_cpu)
+ return self._tiled_decode_inner(
+ latents,
+ chunk_size,
+ overlap,
+ offload_wav_to_cpu,
+ progress_callback=progress_callback,
+ )
except (NotImplementedError, RuntimeError) as exc:
if not _is_mps:
raise
diff --git a/acestep/core/generation/handler/vae_decode_chunks.py b/acestep/core/generation/handler/vae_decode_chunks.py
index 8564de9ad..3ee50cb44 100644
--- a/acestep/core/generation/handler/vae_decode_chunks.py
+++ b/acestep/core/generation/handler/vae_decode_chunks.py
@@ -10,9 +10,16 @@
class VaeDecodeChunksMixin:
"""Implement chunked decode strategies for GPU and CPU-offload modes."""
- def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
+ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu, progress_callback=None):
"""Run tiled decode with adaptive overlap and OOM fallbacks."""
bsz, _channels, latent_frames = latents.shape
+ completed_steps = 0
+
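+        # OOM fallbacks restart chunk counting from 1; keep reporting the highest step seen so the bar never rewinds.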
+ def _monotonic_progress(current, total, desc):
+ nonlocal completed_steps
+ completed_steps = max(completed_steps, current)
+ if progress_callback is not None:
+ progress_callback(completed_steps, total, desc)
# Batch-sequential decode keeps peak VRAM stable across batch sizes.
if bsz > 1:
@@ -20,7 +27,18 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
per_sample_results = []
for b_idx in range(bsz):
single = latents[b_idx : b_idx + 1]
- decoded = self._tiled_decode_inner(single, chunk_size, overlap, offload_wav_to_cpu)
+ sample_progress = None
+ if progress_callback is not None:
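+                    # Offset each sample's chunk progress so the combined (current, total) spans the full batch.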
+ def _sample_progress(current, total, desc, offset=b_idx):
+ progress_callback(offset * total + current, bsz * total, desc)
+ sample_progress = _sample_progress
+ decoded = self._tiled_decode_inner(
+ single,
+ chunk_size,
+ overlap,
+ offload_wav_to_cpu,
+ progress_callback=sample_progress,
+ )
per_sample_results.append(decoded.cpu() if decoded.device.type != "cpu" else decoded)
self._empty_cache()
result = torch.cat(per_sample_results, dim=0)
@@ -46,6 +64,8 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
decoder_output = self.vae.decode(latents)
result = decoder_output.sample
del decoder_output
+ if progress_callback is not None:
+ progress_callback(1, 1, "Decoding audio chunks...")
return result
except torch.cuda.OutOfMemoryError:
logger.warning("[tiled_decode] OOM on direct decode, falling back to CPU VAE decode")
@@ -60,7 +80,15 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
if offload_wav_to_cpu:
try:
- return self._tiled_decode_offload_cpu(latents, bsz, latent_frames, stride, overlap, num_steps)
+ return self._tiled_decode_offload_cpu(
+ latents,
+ bsz,
+ latent_frames,
+ stride,
+ overlap,
+ num_steps,
+ progress_callback=_monotonic_progress,
+ )
except torch.cuda.OutOfMemoryError:
logger.warning(
f"[tiled_decode] OOM during offload_cpu decode with chunk_size={chunk_size}, "
@@ -70,7 +98,13 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
return self._decode_on_cpu(latents)
try:
- return self._tiled_decode_gpu(latents, stride, overlap, num_steps)
+ return self._tiled_decode_gpu(
+ latents,
+ stride,
+ overlap,
+ num_steps,
+ progress_callback=_monotonic_progress,
+ )
except torch.cuda.OutOfMemoryError:
logger.warning(
f"[tiled_decode] OOM during GPU decode with chunk_size={chunk_size}, "
@@ -78,13 +112,21 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
)
self._empty_cache()
try:
- return self._tiled_decode_offload_cpu(latents, bsz, latent_frames, stride, overlap, num_steps)
+ return self._tiled_decode_offload_cpu(
+ latents,
+ bsz,
+ latent_frames,
+ stride,
+ overlap,
+ num_steps,
+ progress_callback=_monotonic_progress,
+ )
except torch.cuda.OutOfMemoryError:
logger.warning("[tiled_decode] OOM even with offload path, falling back to full CPU VAE decode")
self._empty_cache()
return self._decode_on_cpu(latents)
- def _tiled_decode_gpu(self, latents, stride, overlap, num_steps):
+ def _tiled_decode_gpu(self, latents, stride, overlap, num_steps, progress_callback=None):
"""Decode chunks and keep decoded audio tensors on GPU."""
decoded_audio_list = []
upsample_factor = None
@@ -112,10 +154,21 @@ def _tiled_decode_gpu(self, latents, stride, overlap, num_steps):
end_idx = audio_len - trim_end if trim_end > 0 else audio_len
audio_core = audio_chunk[:, :, trim_start:end_idx]
decoded_audio_list.append(audio_core)
+ if progress_callback is not None:
+ progress_callback(i + 1, num_steps, "Decoding audio chunks...")
return torch.cat(decoded_audio_list, dim=-1)
- def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap, num_steps):
+ def _tiled_decode_offload_cpu(
+ self,
+ latents,
+ bsz,
+ latent_frames,
+ stride,
+ overlap,
+ num_steps,
+ progress_callback=None,
+ ):
"""Decode chunks on GPU and copy trimmed audio cores to a CPU buffer."""
first_core_end = min(stride, latent_frames)
first_win_end = min(latent_frames, first_core_end + overlap)
@@ -138,6 +191,8 @@ def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap
first_audio_core = first_audio_chunk[:, :, :first_end_idx]
audio_write_pos = first_audio_core.shape[-1]
final_audio[:, :, :audio_write_pos] = first_audio_core.cpu()
+ if progress_callback is not None:
+ progress_callback(1, num_steps, "Decoding audio chunks...")
del first_audio_chunk, first_audio_core, first_latent_chunk
@@ -164,6 +219,8 @@ def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap
core_len = audio_core.shape[-1]
final_audio[:, :, audio_write_pos : audio_write_pos + core_len] = audio_core.cpu()
audio_write_pos += core_len
+ if progress_callback is not None:
+ progress_callback(i + 1, num_steps, "Decoding audio chunks...")
del audio_chunk, audio_core, latent_chunk
diff --git a/acestep/core/generation/handler/vae_decode_chunks_test.py b/acestep/core/generation/handler/vae_decode_chunks_test.py
index d44bc9da3..566a0f41c 100644
--- a/acestep/core/generation/handler/vae_decode_chunks_test.py
+++ b/acestep/core/generation/handler/vae_decode_chunks_test.py
@@ -74,6 +74,32 @@ def _oom(*args, **kwargs):
self.assertTrue(torch.equal(out, torch.full((1, 2, 7), 9.0)))
self.assertEqual(host.decode_on_cpu_calls, 1)
+ def test_gpu_to_offload_fallback_keeps_progress_monotonic(self):
+ """Offload retry progress should not rewind after GPU OOM progress has started."""
+ host = _ChunksHost()
+ updates = []
+
+ def _gpu_oom(*args, **kwargs):
+ kwargs["progress_callback"](3, 4, "Decoding audio chunks...")
+ raise torch.cuda.OutOfMemoryError("gpu oom")
+
+ def _offload_ok(*args, **kwargs):
+ kwargs["progress_callback"](1, 4, "Decoding audio chunks...")
+ kwargs["progress_callback"](4, 4, "Decoding audio chunks...")
+ return torch.ones(1, 2, 5)
+
+ host._tiled_decode_gpu = _gpu_oom
+ host._tiled_decode_offload_cpu = _offload_ok
+ host._tiled_decode_inner(
+ torch.zeros(1, 4, 20),
+ chunk_size=8,
+ overlap=2,
+ offload_wav_to_cpu=False,
+ progress_callback=lambda current, total, desc: updates.append((current, total, desc)),
+ )
+
+ self.assertEqual([current for current, _, _ in updates], [3, 3, 4])
+
if __name__ == "__main__":
unittest.main()
diff --git a/acestep/core/generation/handler/vae_decode_mixin_test.py b/acestep/core/generation/handler/vae_decode_mixin_test.py
index e48030b42..13f91d7a9 100644
--- a/acestep/core/generation/handler/vae_decode_mixin_test.py
+++ b/acestep/core/generation/handler/vae_decode_mixin_test.py
@@ -46,8 +46,9 @@ def test_tiled_decode_falls_back_when_mlx_decode_fails(self):
host.use_mlx_vae = True
host.mlx_vae = object()
- def _mlx_raise(_latents):
+ def _mlx_raise(_latents, progress_callback=None):
"""Raise MLX failure to exercise fallback path."""
+ _ = progress_callback
raise ValueError("mlx failed")
host._mlx_vae_decode = _mlx_raise
diff --git a/acestep/core/generation/handler/vae_decode_test_helpers.py b/acestep/core/generation/handler/vae_decode_test_helpers.py
index 4eccd3167..51ed26664 100644
--- a/acestep/core/generation/handler/vae_decode_test_helpers.py
+++ b/acestep/core/generation/handler/vae_decode_test_helpers.py
@@ -66,12 +66,13 @@ def _should_offload_wav_to_cpu(self):
"""Return deterministic offload policy used by default path."""
return False
- def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
+ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu, progress_callback=None):
"""Record routed args and return sentinel audio tensor."""
_ = latents
self.recorded["chunk_size"] = chunk_size
self.recorded["overlap"] = overlap
self.recorded["offload"] = offload_wav_to_cpu
+ self.recorded["progress_callback"] = progress_callback
return torch.ones(1, 2, 8)
def _tiled_decode_cpu_fallback(self, latents):
@@ -79,9 +80,10 @@ def _tiled_decode_cpu_fallback(self, latents):
_ = latents
return torch.full((1, 2, 8), 2.0)
- def _mlx_vae_decode(self, latents):
+ def _mlx_vae_decode(self, latents, progress_callback=None):
"""Return MLX sentinel tensor for MLX path assertions."""
_ = latents
+ self.recorded["progress_callback"] = progress_callback
return torch.full((1, 2, 6), 3.0)
diff --git a/acestep/llm_inference.py b/acestep/llm_inference.py
index 42d7b8ad8..f92c5deda 100644
--- a/acestep/llm_inference.py
+++ b/acestep/llm_inference.py
@@ -9,7 +9,7 @@
import time
import random
import warnings
-from typing import Optional, Dict, Any, Tuple, List, Union
+from typing import Callable, Optional, Dict, Any, Tuple, List, Union
from contextlib import contextmanager
import yaml
@@ -874,6 +874,7 @@ def _run_vllm(
lyrics: str = "",
cot_text: str = "",
seeds: Optional[List[int]] = None,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> Union[str, List[str]]:
"""
Unified vllm generation function supporting both single and batch modes.
@@ -959,6 +960,8 @@ def _run_vllm(
output_texts.append(str(output))
# Return single string for single mode, list for batch mode
+ if progress_callback is not None:
+ progress_callback(1, 1, "vLLM generation")
return output_texts[0] if not is_batch else output_texts
def _run_pt_single(
@@ -982,6 +985,7 @@ def _run_pt_single(
caption: str,
lyrics: str,
cot_text: str,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> str:
"""Internal helper function for single-item PyTorch generation."""
inputs = self.llm_tokenizer(
@@ -1060,6 +1064,7 @@ def _run_pt_single(
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
streamer=None,
constrained_processor=constrained_processor,
+ progress_callback=progress_callback,
)
# Extract only the conditional output (first in batch)
@@ -1077,6 +1082,7 @@ def _run_pt_single(
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
streamer=None,
constrained_processor=constrained_processor,
+ progress_callback=progress_callback,
)
else:
# Generate without CFG using native generate() parameters
@@ -1143,6 +1149,7 @@ def _run_pt(
lyrics: str = "",
cot_text: str = "",
seeds: Optional[List[int]] = None,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> Union[str, List[str]]:
"""
Unified PyTorch generation function supporting both single and batch modes.
@@ -1158,6 +1165,7 @@ def _run_pt(
# loads to GPU once and offloads once, instead of per-item.
if is_batch:
output_texts = []
+ batch_total = len(formatted_prompt_list)
with self._load_model_context():
for i, formatted_prompt in enumerate(formatted_prompt_list):
@@ -1169,6 +1177,15 @@ def _run_pt(
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
torch.mps.manual_seed(seeds[i])
+ item_progress_callback = None
+ item_progress_state = {"current": 0, "total": 0, "desc": ""}
+ if progress_callback is not None:
+ def _item_progress(current, total, desc, index=i):
+ item_progress_state.update(current=current, total=total, desc=desc)
+ progress_callback(index * total + current, batch_total * total, desc)
+
+ item_progress_callback = _item_progress
+
# Generate using single-item method with batch-mode defaults
output_text = self._run_pt_single(
formatted_prompt=formatted_prompt,
@@ -1190,9 +1207,20 @@ def _run_pt(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
+ progress_callback=item_progress_callback,
)
output_texts.append(output_text)
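+            # If the item stopped before reporting its final step, close out its slice of the batch progress.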
+ if (
+ progress_callback is not None
+ and item_progress_state["total"] > 0
+ and item_progress_state["current"] < item_progress_state["total"]
+ ):
+ progress_callback(
+ (i + 1) * item_progress_state["total"],
+ batch_total * item_progress_state["total"],
+ item_progress_state["desc"],
+ )
return output_texts
@@ -1219,6 +1247,7 @@ def _run_pt(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
+ progress_callback=progress_callback,
)
def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
@@ -1340,6 +1369,20 @@ def progress(*args, **kwargs):
else:
seeds = seeds[:actual_batch_size]
+ def _make_phase_progress_callback(
+ start: float,
+ end: float,
+ default_desc: str,
+ ) -> Callable[[int, int, str], None]:
+ def _callback(current: int, total: int, desc: str) -> None:
+ if total <= 0:
+ return
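+            # Map the step fraction into this phase's [start, end] slice of the overall progress bar.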
+ frac = min(1.0, max(0.0, current / total))
+ mapped = start + (end - start) * frac
+ progress(mapped, desc or default_desc)
+
+ return _callback
+
# ========== PHASE 1: CoT Generation ==========
# Skip CoT if all metadata are user-provided OR caption is already formatted
progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...")
@@ -1380,6 +1423,11 @@ def progress(*args, **kwargs):
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
stop_at_reasoning=True, # Always stop at in Phase 1
+ progress_callback=_make_phase_progress_callback(
+ 0.1,
+ 0.3,
+ "LLM metadata generation",
+ ),
)
phase1_time = time.time() - phase1_start
@@ -1399,6 +1447,7 @@ def progress(*args, **kwargs):
logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
else:
logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
+ progress(0.3, "LLM metadata generation complete")
else:
# Use user-provided metadata
if is_batch:
@@ -1468,7 +1517,12 @@ def progress(*args, **kwargs):
formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
- progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
+ progress(0.31, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
+ phase2_progress_callback = _make_phase_progress_callback(
+ 0.31,
+ 0.5,
+ "LLM audio code generation",
+ )
if is_batch:
# Batch mode: generate codes for all items
formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
@@ -1492,6 +1546,7 @@ def progress(*args, **kwargs):
lyrics=lyrics,
cot_text=cot_text,
seeds=seeds,
+ progress_callback=phase2_progress_callback,
)
elif self.llm_backend == "mlx":
codes_outputs = self._run_mlx(
@@ -1510,6 +1565,7 @@ def progress(*args, **kwargs):
lyrics=lyrics,
cot_text=cot_text,
seeds=seeds,
+ progress_callback=phase2_progress_callback,
)
else: # pt backend
codes_outputs = self._run_pt(
@@ -1528,6 +1584,7 @@ def progress(*args, **kwargs):
lyrics=lyrics,
cot_text=cot_text,
seeds=seeds,
+ progress_callback=phase2_progress_callback,
)
except Exception as e:
error_msg = f"Error in batch codes generation: {str(e)}"
@@ -1557,6 +1614,7 @@ def progress(*args, **kwargs):
metadata_list.append(metadata.copy()) # Same metadata for all
phase2_time = time.time() - phase2_start
+ progress(0.5, "LLM audio code generation complete")
# Log results
codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
@@ -1602,6 +1660,7 @@ def progress(*args, **kwargs):
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
stop_at_reasoning=False, # Generate codes until EOS
+ progress_callback=phase2_progress_callback,
)
if not codes_output_text:
@@ -1621,6 +1680,7 @@ def progress(*args, **kwargs):
}
phase2_time = time.time() - phase2_start
+ progress(0.5, "LLM audio code generation complete")
# Parse audio codes from output (metadata should be same as Phase 1)
_, audio_codes = self.parse_lm_output(codes_output_text)
@@ -2321,6 +2381,7 @@ def generate_from_formatted_prompt(
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
stop_at_reasoning: bool = False,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> Tuple[str, str]:
"""
Generate raw LM text output from a pre-built formatted prompt.
@@ -2394,6 +2455,7 @@ def generate_from_formatted_prompt(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
+ progress_callback=progress_callback,
)
self._clear_accelerator_cache()
return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
@@ -2420,6 +2482,7 @@ def generate_from_formatted_prompt(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
+ progress_callback=progress_callback,
)
self._clear_accelerator_cache()
return output_text, f"✅ Generated successfully (mlx) | length={len(output_text)}"
@@ -2445,6 +2508,7 @@ def generate_from_formatted_prompt(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
+ progress_callback=progress_callback,
)
self._clear_accelerator_cache()
return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
@@ -2486,6 +2550,7 @@ def _generate_with_constrained_decoding(
pad_token_id: int,
streamer: Optional[BaseStreamer],
constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> torch.Tensor:
"""
Custom generation loop with constrained decoding support (non-CFG).
@@ -2558,6 +2623,8 @@ def _generate_with_constrained_decoding(
# Update streamer
if streamer is not None:
streamer.put(next_tokens_unsqueezed)
+ if progress_callback is not None:
+ progress_callback(step + 1, max_new_tokens, "LLM Constrained Decoding")
if should_stop:
break
@@ -2583,6 +2650,7 @@ def _generate_with_cfg_custom(
pad_token_id: int,
streamer: Optional[BaseStreamer],
constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> torch.Tensor:
"""
Custom CFG generation loop that:
@@ -2729,6 +2797,8 @@ def _generate_with_cfg_custom(
# Update streamer
if streamer is not None:
streamer.put(next_tokens_unsqueezed) # Stream conditional tokens
+ if progress_callback is not None:
+ progress_callback(step + 1, max_new_tokens, "LLM CFG Generation")
# Stop generation only when ALL sequences have finished
if seq_finished.all():
@@ -3005,6 +3075,7 @@ def _run_mlx_batch_native(
lyrics: str,
cot_text: str,
seeds: Optional[List[int]] = None,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> List[str]:
"""
Optimized native MLX batch generation for codes phase.
@@ -3322,6 +3393,8 @@ def _clone_cache_list(cache_list):
item_last_logits[i] = logits_out[:, -1:, :]
pbar.update(1)
+ if progress_callback is not None:
+ progress_callback(step + 1, max_new_tokens, f"MLX {cfg_label}Batch Gen (native, n={batch_size})")
# Periodic memory cleanup
if step % 256 == 0 and step > 0:
@@ -3370,6 +3443,7 @@ def _run_mlx_single_native(
caption: str,
lyrics: str,
cot_text: str,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> str:
"""
Optimized native MLX generation using mlx-lm infrastructure.
@@ -3628,6 +3702,8 @@ def _run_mlx_single_native(
new_tokens.append(token_id)
all_token_ids.append(token_id)
pbar.update(1)
+ if progress_callback is not None:
+ progress_callback(step + 1, max_new_tokens, tqdm_desc)
# Update constrained processor FSM state
if constrained_processor is not None:
@@ -3693,6 +3769,7 @@ def _run_mlx_single(
caption: str,
lyrics: str,
cot_text: str,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> str:
"""
MLX-accelerated single-item generation.
@@ -3723,6 +3800,7 @@ def _run_mlx_single(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
+ progress_callback=progress_callback,
)
except Exception as _native_err:
logger.warning(
@@ -3872,6 +3950,8 @@ def _run_mlx_single(
new_tokens.append(token_id)
all_token_ids.append(token_id)
pbar.update(1)
+ if progress_callback is not None:
+ progress_callback(step + 1, max_new_tokens, tqdm_desc)
# Update constrained processor state
if constrained_processor is not None:
@@ -3934,6 +4014,7 @@ def _run_mlx(
lyrics: str = "",
cot_text: str = "",
seeds: Optional[List[int]] = None,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> Union[str, List[str]]:
"""
Unified MLX generation function supporting both single and batch modes.
@@ -3983,6 +4064,7 @@ def _run_mlx(
lyrics=lyrics,
cot_text=cot_text,
seeds=seeds,
+ progress_callback=progress_callback,
)
except Exception as e:
logger.warning(
@@ -3998,6 +4080,15 @@ def _run_mlx(
if seeds and i < len(seeds):
mx.random.seed(seeds[i])
+ item_progress_callback = None
+ item_progress_state = {"current": 0, "total": 0, "desc": ""}
+ if progress_callback is not None:
+ def _item_progress(current, total, desc, index=i, batch_total=batch_size):
+ item_progress_state.update(current=current, total=total, desc=desc)
+ progress_callback(index * total + current, batch_total * total, desc)
+
+ item_progress_callback = _item_progress
+
output_text = self._run_mlx_single(
formatted_prompt=formatted_prompt,
temperature=temperature,
@@ -4018,8 +4109,19 @@ def _run_mlx(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
+ progress_callback=item_progress_callback,
)
output_texts.append(output_text)
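+            # If the item stopped before reporting its final step, close out its slice of the batch progress.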
+ if (
+ progress_callback is not None
+ and item_progress_state["total"] > 0
+ and item_progress_state["current"] < item_progress_state["total"]
+ ):
+ progress_callback(
+ (i + 1) * item_progress_state["total"],
+ batch_size * item_progress_state["total"],
+ item_progress_state["desc"],
+ )
return output_texts
# Single mode
@@ -4044,6 +4146,7 @@ def _run_mlx(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
+ progress_callback=progress_callback,
)
# =========================================================================
diff --git a/acestep/llm_inference_batch_progress_test.py b/acestep/llm_inference_batch_progress_test.py
new file mode 100644
index 000000000..680ab5bb7
--- /dev/null
+++ b/acestep/llm_inference_batch_progress_test.py
@@ -0,0 +1,105 @@
+"""Regression tests for sequential batch progress closure."""
+
+import sys
+import types
+import unittest
+from contextlib import nullcontext
+from unittest.mock import patch
+
+try:
+ from acestep.llm_inference import LLMHandler
+
+ _IMPORT_ERROR = None
+except ImportError as exc: # pragma: no cover - dependency guard
+ LLMHandler = None
+ _IMPORT_ERROR = exc
+
+
+def _make_handler() -> "LLMHandler":
+ """Return a minimal handler with model-loading context stubbed out."""
+ handler = LLMHandler()
+ handler._load_model_context = lambda: nullcontext()
+ return handler
+
+
+@unittest.skipIf(LLMHandler is None, f"llm_inference import unavailable: {_IMPORT_ERROR}")
+class SequentialBatchProgressTests(unittest.TestCase):
+ """Sequential batch generation should close each item before the next begins."""
+
+ def test_run_pt_closes_early_item_progress_before_next_item(self):
+ handler = _make_handler()
+ updates = []
+
+ def fake_run_pt_single(*args, **kwargs):
+ progress_callback = kwargs["progress_callback"]
+ if len(updates) == 0:
+ progress_callback(2, 10, "LLM CFG Generation")
+ return "first"
+ progress_callback(1, 10, "LLM CFG Generation")
+ return "second"
+
+ with patch.object(handler, "_run_pt_single", side_effect=fake_run_pt_single):
+ out = handler._run_pt(
+ formatted_prompts=["p1", "p2"],
+ temperature=0.6,
+ cfg_scale=1.0,
+ negative_prompt="",
+ top_k=None,
+ top_p=None,
+ repetition_penalty=1.0,
+ progress_callback=lambda current, total, desc: updates.append((current, total, desc)),
+ )
+
+ self.assertEqual(out, ["first", "second"])
+ self.assertEqual(
+ updates,
+ [
+ (2, 20, "LLM CFG Generation"),
+ (10, 20, "LLM CFG Generation"),
+ (11, 20, "LLM CFG Generation"),
+ (20, 20, "LLM CFG Generation"),
+ ],
+ )
+
+ def test_run_mlx_closes_early_item_progress_before_next_item(self):
+ handler = _make_handler()
+ updates = []
+ fake_mlx = types.ModuleType("mlx")
+ fake_mx_core = types.ModuleType("mlx.core")
+ fake_mx_core.random = types.SimpleNamespace(seed=lambda *_: None)
+
+ def fake_run_mlx_single(*args, **kwargs):
+ progress_callback = kwargs["progress_callback"]
+ if len(updates) == 0:
+ progress_callback(3, 12, "LLM CFG Generation")
+ return "first"
+ progress_callback(2, 12, "LLM CFG Generation")
+ return "second"
+
+ with patch.dict(sys.modules, {"mlx": fake_mlx, "mlx.core": fake_mx_core}):
+ with patch.object(handler, "_run_mlx_single", side_effect=fake_run_mlx_single):
+ out = handler._run_mlx(
+ formatted_prompts=["p1", "p2"],
+ temperature=0.6,
+ cfg_scale=1.0,
+ negative_prompt="",
+ top_k=None,
+ top_p=None,
+ repetition_penalty=1.0,
+ progress_callback=lambda current, total, desc: updates.append((current, total, desc)),
+ )
+
+ self.assertEqual(out, ["first", "second"])
+ self.assertEqual(
+ updates,
+ [
+ (3, 24, "LLM CFG Generation"),
+ (12, 24, "LLM CFG Generation"),
+ (14, 24, "LLM CFG Generation"),
+ (24, 24, "LLM CFG Generation"),
+ ],
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/acestep/llm_inference_cfg_fixes_test.py b/acestep/llm_inference_cfg_fixes_test.py
index 60893a814..598fcb4f7 100644
--- a/acestep/llm_inference_cfg_fixes_test.py
+++ b/acestep/llm_inference_cfg_fixes_test.py
@@ -77,9 +77,10 @@ class TestCotCfgScaleFixed(unittest.TestCase):
def test_cot_phase_uses_cfg_scale_1(self):
"""generate_from_formatted_prompt called during CoT must receive cfg_scale=1.0."""
- handler = LLMHandler()
+ handler = _make_handler()
handler.llm_initialized = True
handler.llm_backend = "pt"
+ handler.llm = MagicMock()
captured_cfg = {}
@@ -88,34 +89,23 @@ def fake_run_pt(formatted_prompts, temperature, cfg_scale, **kwargs):
return "metadata"
with patch.object(handler, "_run_pt", side_effect=fake_run_pt):
- with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"):
- with patch.object(
- handler,
- "_format_metadata_as_cot",
- return_value="",
- ):
- with patch.object(
- handler,
- "build_formatted_prompt_with_cot",
- return_value="PROMPT_WITH_COT",
- ):
- # Simulate Phase 1 CoT call via generate_from_formatted_prompt
- # by calling the internal helper directly
- handler.generate_from_formatted_prompt(
- formatted_prompt="PROMPT",
- cfg={
- "temperature": 0.6,
- "cfg_scale": 1.0, # already 1.0 for CoT
- "negative_prompt": "NO USER INPUT",
- "top_k": None,
- "top_p": None,
- "repetition_penalty": 1.0,
- "target_duration": None,
- "generation_phase": "cot",
- "caption": "test",
- "lyrics": "test",
- },
- )
+ # Simulate the Phase 1 CoT call by invoking generate_from_formatted_prompt
+ # directly against the stubbed handler state.
+ handler.generate_from_formatted_prompt(
+ formatted_prompt="PROMPT",
+ cfg={
+ "temperature": 0.6,
+ "cfg_scale": 1.0, # already 1.0 for CoT
+ "negative_prompt": "NO USER INPUT",
+ "top_k": None,
+ "top_p": None,
+ "repetition_penalty": 1.0,
+ "target_duration": None,
+ "generation_phase": "cot",
+ "caption": "test",
+ "lyrics": "test",
+ },
+ )
# cfg_scale must be 1.0 for CoT – captured during _run_pt call
self.assertEqual(captured_cfg.get("cfg_scale"), 1.0)
@@ -136,7 +126,7 @@ def capturing_gen(formatted_prompt, cfg=None, **kwargs):
with patch.object(handler, "generate_from_formatted_prompt", side_effect=capturing_gen):
with patch.object(handler, "build_formatted_prompt", return_value="P"):
- with patch.object(handler, "_parse_metadata_from_cot", return_value={}):
+ with patch.object(handler, "parse_lm_output", return_value=({}, "")):
with patch.object(handler, "_format_metadata_as_cot", return_value=""):
with patch.object(
handler, "build_formatted_prompt_with_cot", return_value="P2"
@@ -471,5 +461,118 @@ def fake_call(input_ids, **kwargs):
self.assertEqual(result[0, -1].item(), eos_id)
+@unittest.skipIf(LLMHandler is None, f"llm_inference import unavailable: {_IMPORT_ERROR}")
+class TestLmProgressCallbacks(unittest.TestCase):
+ """Progress callback coverage for LM token generation and phase mapping."""
+
+ def test_generate_with_cfg_custom_emits_token_progress(self):
+ """CFG token loop should emit monotonic per-token progress callbacks."""
+ handler = _make_handler()
+ eos_id = 1
+ vocab_size = 10
+ handler.llm_tokenizer.eos_token_id = eos_id
+
+ model = MagicMock()
+ logits = torch.full((2, 1, vocab_size), -100.0)
+ logits[:, :, 2] = 100.0
+ outputs = SimpleNamespace(logits=logits, past_key_values=None)
+ model.return_value = outputs
+ model.generation_config = MagicMock()
+ model.generation_config.use_cache = False
+ handler.llm = model
+
+ updates = []
+ handler._generate_with_cfg_custom(
+ batch_input_ids=torch.zeros((2, 3), dtype=torch.long),
+ batch_attention_mask=None,
+ max_new_tokens=2,
+ temperature=1.0,
+ cfg_scale=1.0,
+ top_k=None,
+ top_p=None,
+ repetition_penalty=1.0,
+ pad_token_id=eos_id,
+ streamer=None,
+ progress_callback=lambda current, total, desc: updates.append((current, total, desc)),
+ )
+
+ self.assertEqual(updates, [(1, 2, "LLM CFG Generation"), (2, 2, "LLM CFG Generation")])
+
+ def test_generate_with_stop_condition_maps_phase_progress_to_outer_callback(self):
+ """Phase callbacks should surface token progress onto the outer progress callback."""
+ handler = LLMHandler()
+ handler.llm_initialized = True
+ handler.llm_backend = "pt"
+
+ updates = []
+
+ def fake_generate_from_formatted_prompt(*args, **kwargs):
+ progress_callback = kwargs.get("progress_callback")
+ if progress_callback is not None:
+ progress_callback(5, 10, "LLM CFG Generation")
+ progress_callback(10, 10, "LLM CFG Generation")
+ return "bpm: 120\n", "ok"
+
+ with patch.object(handler, "generate_from_formatted_prompt", side_effect=fake_generate_from_formatted_prompt):
+ with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"):
+ with patch.object(handler, "parse_lm_output", return_value=({"bpm": 120}, "")):
+ handler.generate_with_stop_condition(
+ caption="test caption",
+ lyrics="test lyrics",
+ cfg_scale=2.0,
+ temperature=0.6,
+ negative_prompt="",
+ top_k=None,
+ top_p=None,
+ repetition_penalty=1.0,
+ infer_type="dit",
+ progress=lambda value, desc=None: updates.append((value, desc)),
+ )
+
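+ # Phase 1 token progress is expected to land inside the 0.1-0.3 band of the
+ # outer progress bar, followed by an explicit completion update at 0.3.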
+ llm_cfg_updates = [value for value, desc in updates if desc == "LLM CFG Generation"]
+ self.assertTrue(llm_cfg_updates)
+ self.assertGreaterEqual(min(llm_cfg_updates), 0.1)
+ self.assertLessEqual(max(llm_cfg_updates), 0.3)
+ self.assertIn((0.3, "LLM metadata generation complete"), updates)
+
+ def test_generate_with_stop_condition_closes_audio_code_phase_band(self):
+ """Phase 2 should emit an explicit terminal progress update on early success."""
+ handler = LLMHandler()
+ handler.llm_initialized = True
+ handler.llm_backend = "pt"
+
+ updates = []
+ call_count = {"count": 0}
+
+ def fake_generate_from_formatted_prompt(*args, **kwargs):
+ call_count["count"] += 1
+ progress_callback = kwargs.get("progress_callback")
+ if progress_callback is not None:
+ progress_callback(5, 10, "LLM CFG Generation")
+ if call_count["count"] == 1:
+ return "bpm: 120\n", "ok"
+ return "<|audio_code_1|><|audio_code_2|>", "ok"
+
+ with patch.object(handler, "generate_from_formatted_prompt", side_effect=fake_generate_from_formatted_prompt):
+ with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"):
+ with patch.object(handler, "parse_lm_output", side_effect=[({"bpm": 120}, ""), ({}, "<|audio_code_1|><|audio_code_2|>")]):
+ with patch.object(handler, "_format_metadata_as_cot", return_value="cot"):
+ with patch.object(handler, "build_formatted_prompt_with_cot", return_value="PROMPT2"):
+ handler.generate_with_stop_condition(
+ caption="test caption",
+ lyrics="test lyrics",
+ cfg_scale=2.0,
+ temperature=0.6,
+ negative_prompt="",
+ top_k=None,
+ top_p=None,
+ repetition_penalty=1.0,
+ infer_type="llm_dit",
+ progress=lambda value, desc=None: updates.append((value, desc)),
+ )
+
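+ # Even though the fake Phase 2 call stops at 5/10 tokens, the audio code
+ # phase band must still be closed at 0.5.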
+ self.assertIn((0.5, "LLM audio code generation complete"), updates)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/acestep/models/base/modeling_acestep_v15_base.py b/acestep/models/base/modeling_acestep_v15_base.py
index 47d515493..3ef73321a 100644
--- a/acestep/models/base/modeling_acestep_v15_base.py
+++ b/acestep/models/base/modeling_acestep_v15_base.py
@@ -1854,6 +1854,7 @@ def generate_audio(
precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
audio_codes: Optional[torch.FloatTensor] = None,
use_progress_bar: bool = True,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
use_adg: bool = False,
shift: float = 1.0,
cover_noise_strength: float = 0.0,
@@ -2129,6 +2130,8 @@ def generate_audio(
xt = _repaint_step_injection(
xt, clean_src_latents, repaint_mask, t_after_step, noise,
)
+ if progress_callback is not None:
+ progress_callback(step_idx + 1, infer_steps, "DiT diffusion...")
x_gen = xt
if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0:
diff --git a/acestep/models/mlx/dit_generate.py b/acestep/models/mlx/dit_generate.py
index 46253e961..def0caf24 100644
--- a/acestep/models/mlx/dit_generate.py
+++ b/acestep/models/mlx/dit_generate.py
@@ -19,7 +19,7 @@
import logging
import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from tqdm import tqdm
@@ -150,6 +150,7 @@ def mlx_generate_diffusion(
encoder_hidden_states_non_cover_np: Optional[np.ndarray] = None,
context_latents_non_cover_np: Optional[np.ndarray] = None,
compile_model: bool = False,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
disable_tqdm: bool = False,
sampler_mode: str = "euler",
velocity_norm_threshold: float = 0.0,
@@ -179,6 +180,8 @@ def mlx_generate_diffusion(
encoder_hidden_states_non_cover_np: optional [B, enc_L, D] for non-cover.
context_latents_non_cover_np: optional [B, T, C] for non-cover.
compile_model: If True, compile the decoder step with ``mx.compile``.
+ progress_callback: Optional diffusion-step callback receiving
+ ``(current_step, total_steps, desc)``.
disable_tqdm: If True, suppress the diffusion progress bar.
sampler_mode: Sampler algorithm — ``"euler"`` (first-order, default) or
``"heun"`` (second-order predictor-corrector for cleaner output).
@@ -401,6 +404,8 @@ def _apply_stabilisation(vt_guided, xt_current, prev_velocity):
xt = xt - vt * dt_arr
mx.eval(xt)
+ if progress_callback is not None:
+ progress_callback(step_idx + 1, num_steps, "DiT diffusion...")
prev_vt = vt # store for EMA
diff --git a/acestep/models/mlx/dit_generate_test.py b/acestep/models/mlx/dit_generate_test.py
new file mode 100644
index 000000000..fe01789d8
--- /dev/null
+++ b/acestep/models/mlx/dit_generate_test.py
@@ -0,0 +1,90 @@
+import sys
+import types
+import unittest
+from unittest.mock import patch
+
+import numpy as np
+
+from acestep.models.mlx.dit_generate import mlx_generate_diffusion
+
+
+class _FakeRandom:
+ def normal(self, shape, key=None):
+ _ = key
+ return np.zeros(shape, dtype=np.float32)
+
+ def key(self, seed):
+ return seed
+
+
+class _FakeDecoder:
+ def __call__(
+ self,
+ hidden_states,
+ timestep,
+ timestep_r,
+ encoder_hidden_states,
+ context_latents,
+ cache=None,
+ use_cache=False,
+ ):
+ _ = timestep
+ _ = timestep_r
+ _ = encoder_hidden_states
+ _ = context_latents
+ _ = use_cache
+ return np.zeros_like(hidden_states), cache
+
+
+class MlxGenerateDiffusionProgressTests(unittest.TestCase):
+ def test_progress_callback_fires_once_per_diffusion_step(self):
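+ # Stub mlx.core with NumPy-backed fakes and replace the MLX DiT cache
+ # module so the sampler loop can run on plain NumPy arrays.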
+ fake_mx_core = types.ModuleType("mlx.core")
+ fake_mx_core.array = lambda value: np.array(value, dtype=np.float32)
+ fake_mx_core.concatenate = lambda values, axis=0: np.concatenate(values, axis=axis)
+ fake_mx_core.broadcast_to = lambda value, shape: np.broadcast_to(value, shape)
+ fake_mx_core.full = lambda shape, fill_value: np.full(shape, fill_value, dtype=np.float32)
+ fake_mx_core.eval = lambda value: value
+ fake_mx_core.random = _FakeRandom()
+
+ fake_mlx_pkg = types.ModuleType("mlx")
+ fake_mlx_pkg.core = fake_mx_core
+
+ fake_dit_model = types.ModuleType("acestep.models.mlx.dit_model")
+
+ class _FakeCache:
+ pass
+
+ fake_dit_model.MLXCrossAttentionCache = _FakeCache
+
+ updates = []
+ with patch.dict(
+ sys.modules,
+ {
+ "mlx": fake_mlx_pkg,
+ "mlx.core": fake_mx_core,
+ "acestep.models.mlx.dit_model": fake_dit_model,
+ },
+ ):
+ result = mlx_generate_diffusion(
+ mlx_decoder=_FakeDecoder(),
+ encoder_hidden_states_np=np.zeros((1, 2, 3), dtype=np.float32),
+ context_latents_np=np.zeros((1, 2, 3), dtype=np.float32),
+ src_latents_shape=(1, 2, 3),
+ timesteps=[1.0, 0.75, 0.5],
+ progress_callback=lambda step, total, desc: updates.append((step, total, desc)),
+ disable_tqdm=True,
+ )
+
+ self.assertEqual(
+ updates,
+ [
+ (1, 3, "DiT diffusion..."),
+ (2, 3, "DiT diffusion..."),
+ (3, 3, "DiT diffusion..."),
+ ],
+ )
+ self.assertEqual(tuple(result["target_latents"].shape), (1, 2, 3))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/acestep/models/sft/modeling_acestep_v15_base.py b/acestep/models/sft/modeling_acestep_v15_base.py
index 3310a6792..280379d0f 100644
--- a/acestep/models/sft/modeling_acestep_v15_base.py
+++ b/acestep/models/sft/modeling_acestep_v15_base.py
@@ -1854,6 +1854,7 @@ def generate_audio(
precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
audio_codes: Optional[torch.FloatTensor] = None,
use_progress_bar: bool = True,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
use_adg: bool = False,
shift: float = 1.0,
timesteps: Optional[torch.Tensor] = None,
@@ -2133,6 +2134,8 @@ def generate_audio(
xt = _repaint_step_injection(
xt, clean_src_latents, repaint_mask, t_after_step, noise,
)
+ if progress_callback is not None:
+ progress_callback(step_idx + 1, infer_steps, "DiT diffusion...")
x_gen = xt
if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0:
diff --git a/acestep/models/turbo/modeling_acestep_v15_turbo.py b/acestep/models/turbo/modeling_acestep_v15_turbo.py
index 5ec0bdbb6..510c09688 100644
--- a/acestep/models/turbo/modeling_acestep_v15_turbo.py
+++ b/acestep/models/turbo/modeling_acestep_v15_turbo.py
@@ -1844,6 +1844,8 @@ def generate_audio(
non_cover_text_attention_mask: Optional[torch.FloatTensor] = None,
precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
audio_codes: Optional[torch.FloatTensor] = None,
+ use_progress_bar: bool = True,
+ progress_callback: Optional[Callable[[int, int, str], None]] = None,
shift: float = 3.0,
timesteps: Optional[torch.Tensor] = None,
cover_noise_strength: float = 0.0,
@@ -2045,6 +2047,8 @@ def generate_audio(
# On final step, directly compute x0 from noise
if step_idx == num_steps - 1:
xt = self.get_x0_from_noise(xt, vt, t_curr_tensor)
+ if progress_callback is not None:
+ progress_callback(step_idx + 1, num_steps, "DiT diffusion...")
prev_vt = vt
break
@@ -2104,6 +2108,8 @@ def generate_audio(
xt = _repaint_step_injection(
xt, clean_src_latents, repaint_mask, t_after_step, noise,
)
+ if progress_callback is not None:
+ progress_callback(step_idx + 1, num_steps, "DiT diffusion...")
x_gen = xt
if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0:
@@ -2272,4 +2278,4 @@ def test_forward(model, seed=42):
# model = model.float()
model = model.to("cuda")
model = model.bfloat16()
- test_forward(model)
\ No newline at end of file
+ test_forward(model)
diff --git a/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py b/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py
new file mode 100644
index 000000000..28214fbb4
--- /dev/null
+++ b/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py
@@ -0,0 +1,90 @@
+import unittest
+
+import torch
+
+from acestep.models.turbo.modeling_acestep_v15_turbo import AceStepConditionGenerationModel
+
+
+class _FakeDecoder:
+ def __call__(
+ self,
+ hidden_states,
+ timestep,
+ timestep_r,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ context_latents,
+ use_cache,
+ past_key_values,
+ ):
+ _ = timestep
+ _ = timestep_r
+ _ = attention_mask
+ _ = encoder_hidden_states
+ _ = encoder_attention_mask
+ _ = context_latents
+ _ = use_cache
+ return torch.zeros_like(hidden_states), past_key_values
+
+
+class _FakeTurboHost:
+ def __init__(self):
+ self.decoder = _FakeDecoder()
+
+ def prepare_condition(self, **kwargs):
+ src_latents = kwargs["src_latents"]
+ attention_mask = kwargs["attention_mask"]
+ return src_latents, attention_mask, src_latents
+
+ def prepare_noise(self, context_latents, seed):
+ _ = seed
+ return torch.zeros_like(context_latents)
+
+ def get_x0_from_noise(self, zt, vt, t):
+ return zt - vt * t.unsqueeze(-1).unsqueeze(-1)
+
+ def renoise(self, x, t, noise=None):
+ _ = t
+ _ = noise
+ return x
+
+
+class AceStepTurboProgressTests(unittest.TestCase):
+ def test_progress_callback_fires_once_per_step_including_final_step(self):
+ host = _FakeTurboHost()
+ updates = []
+ base = torch.zeros((1, 2, 2), dtype=torch.float32)
+ mask = torch.ones((1, 2), dtype=torch.float32)
+ cover_mask = torch.zeros((1,), dtype=torch.float32)
+
+ AceStepConditionGenerationModel.generate_audio(
+ host,
+ text_hidden_states=base,
+ text_attention_mask=mask,
+ lyric_hidden_states=base,
+ lyric_attention_mask=mask,
+ refer_audio_acoustic_hidden_states_packed=base,
+ refer_audio_order_mask=mask,
+ src_latents=base,
+ chunk_masks=base,
+ is_covers=cover_mask,
+ silence_latent=base,
+ attention_mask=mask,
+ seed=0,
+ infer_method="ode",
+ timesteps=[1.0, 0.5],
+ progress_callback=lambda step, total, desc: updates.append((step, total, desc)),
+ )
+
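+ # Two timesteps drive two diffusion steps; the final step must also emit
+ # a progress update before the loop exits.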
+ self.assertEqual(
+ updates,
+ [
+ (1, 2, "DiT diffusion..."),
+ (2, 2, "DiT diffusion..."),
+ ],
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()