diff --git a/acestep/core/generation/handler/diffusion.py b/acestep/core/generation/handler/diffusion.py index 228beaf70..b4a1efdd7 100644 --- a/acestep/core/generation/handler/diffusion.py +++ b/acestep/core/generation/handler/diffusion.py @@ -34,6 +34,7 @@ def _mlx_run_diffusion( encoder_hidden_states_non_cover=None, encoder_attention_mask_non_cover=None, context_latents_non_cover=None, + progress_callback=None, disable_tqdm: bool = False, sampler_mode: str = "euler", velocity_norm_threshold: float = 0.0, @@ -59,6 +60,8 @@ def _mlx_run_diffusion( encoder_hidden_states_non_cover: Optional non-cover conditioning tensor. encoder_attention_mask_non_cover: Unused; accepted for API compatibility. context_latents_non_cover: Optional non-cover context latent tensor. + progress_callback: Optional diffusion-step callback receiving + ``(current_step, total_steps, desc)``. disable_tqdm: If True, suppress the diffusion progress bar. sampler_mode: Sampler algorithm — ``"euler"`` or ``"heun"``. velocity_norm_threshold: Velocity norm clamping threshold (0 = disabled). @@ -141,6 +144,7 @@ def _mlx_run_diffusion( encoder_hidden_states_non_cover_np=enc_nc_np, context_latents_non_cover_np=ctx_nc_np, compile_model=getattr(self, "mlx_dit_compiled", False), + progress_callback=progress_callback, disable_tqdm=disable_tqdm, sampler_mode=sampler_mode, velocity_norm_threshold=velocity_norm_threshold, diff --git a/acestep/core/generation/handler/generate_music_decode.py b/acestep/core/generation/handler/generate_music_decode.py index b7330503c..e04704e28 100644 --- a/acestep/core/generation/handler/generate_music_decode.py +++ b/acestep/core/generation/handler/generate_music_decode.py @@ -122,8 +122,19 @@ def _decode_generate_music_pred_latents( Returns: Tuple of decoded waveforms, CPU latents, and updated time-cost payload. """ + def _emit_decode_progress(value: float, desc: str) -> None: + if progress: + progress(value, desc=desc) + + def _decode_chunk_progress(current: int, total: int, desc: str) -> None: + if total <= 0: + return + frac = min(1.0, max(0.0, current / total)) + mapped = 0.82 + 0.16 * frac + _emit_decode_progress(mapped, desc) + if progress: - progress(0.8, desc="Decoding audio...") + _emit_decode_progress(0.8, "Preparing audio decode...") logger.info("[generate_music] Decoding latents with VAE...") start_time = time.time() with torch.inference_mode(): @@ -168,12 +179,19 @@ def _decode_generate_music_pred_latents( pred_latents_for_decode = pred_latents_for_decode.cpu() self._empty_cache() try: + _emit_decode_progress(0.82, "Decoding audio chunks...") if use_tiled_decode: logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...") - pred_wavs = self.tiled_decode(pred_latents_for_decode) + pred_wavs = self.tiled_decode( + pred_latents_for_decode, + progress_callback=_decode_chunk_progress, + ) elif using_mlx_vae: try: - pred_wavs = self._mlx_vae_decode(pred_latents_for_decode) + pred_wavs = self._mlx_vae_decode( + pred_latents_for_decode, + progress_callback=_decode_chunk_progress, + ) except Exception as exc: logger.warning( f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch" @@ -202,6 +220,7 @@ def _decode_generate_music_pred_latents( if torch.any(peak > 1.0): pred_wavs = pred_wavs / peak.clamp(min=1.0) self._empty_cache() + _emit_decode_progress(0.98, "Decoding audio chunks...") gc.collect() self._empty_cache() end_time = time.time() diff --git a/acestep/core/generation/handler/generate_music_decode_test.py b/acestep/core/generation/handler/generate_music_decode_test.py index ece612ff1..ae1c9c1d9 100644 --- a/acestep/core/generation/handler/generate_music_decode_test.py +++ b/acestep/core/generation/handler/generate_music_decode_test.py @@ -115,14 +115,20 @@ def _max_memory_allocated(self): """Return deterministic max-memory value for debug logging.""" return 0.0 - def _mlx_vae_decode(self, latents): + def _mlx_vae_decode(self, latents, progress_callback=None): """Return deterministic decoded waveform for MLX decode branch.""" _ = latents + if progress_callback is not None: + progress_callback(1, 64, "Decoding audio chunks...") + progress_callback(64, 64, "Decoding audio chunks...") return torch.ones(1, 2, 8) - def tiled_decode(self, latents): + def tiled_decode(self, latents, progress_callback=None): """Return deterministic decoded waveform for tiled decode branch.""" _ = latents + if progress_callback is not None: + progress_callback(1, 64, "Decoding audio chunks...") + progress_callback(64, 64, "Decoding audio chunks...") return torch.ones(1, 2, 8) @@ -189,7 +195,11 @@ def _progress(value, desc=None): self.assertAlmostEqual(updated_costs["vae_decode_time_cost"], 1.5, places=6) self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6) self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6) - self.assertEqual(host.progress_calls[0][0], 0.8) + self.assertEqual(host.progress_calls[0], (0.8, "Preparing audio decode...")) + self.assertIn((0.82, "Decoding audio chunks..."), host.progress_calls) + progress_values = [value for value, _ in host.progress_calls] + self.assertEqual(progress_values, sorted(progress_values)) + self.assertAlmostEqual(host.progress_calls[-1][0], 0.98, places=6) def test_decode_pred_latents_restores_vae_device_on_decode_error(self): """It restores VAE device in the CPU-offload path even when decode raises.""" @@ -313,4 +323,3 @@ def __init__(self): if __name__ == "__main__": unittest.main() - diff --git a/acestep/core/generation/handler/generate_music_execute.py b/acestep/core/generation/handler/generate_music_execute.py index b8deecf45..c016c8835 100644 --- a/acestep/core/generation/handler/generate_music_execute.py +++ b/acestep/core/generation/handler/generate_music_execute.py @@ -1,11 +1,14 @@ """Execution helper for ``generate_music`` service invocation with progress tracking.""" import os -import threading -from typing import Any, Dict, List, Optional, Sequence +import time +from typing import Any, Callable, Dict, List, Optional, Sequence +import threading from loguru import logger +from .runtime_progress_relay import RuntimeProgressRelay + # Maximum wall-clock seconds to wait for service_generate before declaring a hang. # Generous default: most generations finish in 30-120s, but large batches on slow # GPUs can take several minutes. Override via ACESTEP_GENERATION_TIMEOUT env var. @@ -47,9 +50,10 @@ def _run_generate_music_service_with_progress( """ infer_steps_for_progress = len(timesteps) if timesteps else inference_steps progress_desc = f"Generating music (batch size: {actual_batch_size})..." - progress(0.52, desc=progress_desc) stop_event = None progress_thread = None + relay = RuntimeProgressRelay(progress=progress, start=0.52, end=0.79) + relay.emit_progress(0.52, progress_desc) # --- Timeout-wrapped service_generate --- # Run the actual CUDA work in a child thread so we can join() with a @@ -91,13 +95,19 @@ def _service_target(): chunk_mask_modes=service_inputs.get("chunk_mask_modes_batch"), repaint_crossfade_frames=repaint_crossfade_frames, repaint_injection_ratio=repaint_injection_ratio, + progress_callback=relay.enqueue, ) except Exception as exc: _error["exc"] = exc + def _stop_progress_estimator_if_finished() -> None: + """Stop the estimator thread but keep its handle until it actually exits.""" + nonlocal progress_thread + progress_thread = RuntimeProgressRelay.stop_estimator_if_finished(progress_thread) + try: stop_event, progress_thread = self._start_diffusion_progress_estimator( - progress=progress, + progress=relay.emit_progress, start=0.52, end=0.79, infer_steps=infer_steps_for_progress, @@ -112,7 +122,29 @@ def _service_target(): daemon=True, ) gen_thread.start() - gen_thread.join(timeout=_DEFAULT_GENERATION_TIMEOUT) + deadline = time.monotonic() + _DEFAULT_GENERATION_TIMEOUT + poll_interval = 0.1 + saw_runtime_progress = False + while gen_thread.is_alive(): + remaining = deadline - time.monotonic() + if remaining <= 0: + break + gen_thread.join(timeout=min(poll_interval, remaining)) + drained_runtime_progress = relay.drain() + if drained_runtime_progress and not saw_runtime_progress: + saw_runtime_progress = True + if stop_event is not None: + stop_event.set() + _stop_progress_estimator_if_finished() + if not gen_thread.is_alive(): + break + + drained_runtime_progress = relay.drain() + if drained_runtime_progress and not saw_runtime_progress: + saw_runtime_progress = True + if stop_event is not None: + stop_event.set() + _stop_progress_estimator_if_finished() if gen_thread.is_alive(): logger.error( @@ -128,11 +160,13 @@ def _service_target(): ) if "exc" in _error: raise _error["exc"] + relay.emit_progress(0.79, progress_desc) finally: + relay.shutdown() if stop_event is not None: stop_event.set() if progress_thread is not None: progress_thread.join(timeout=1.0) - return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress} \ No newline at end of file + return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress} diff --git a/acestep/core/generation/handler/generate_music_execute_test.py b/acestep/core/generation/handler/generate_music_execute_test.py index aead22214..a105f6c02 100644 --- a/acestep/core/generation/handler/generate_music_execute_test.py +++ b/acestep/core/generation/handler/generate_music_execute_test.py @@ -1,6 +1,8 @@ """Unit tests for ``generate_music`` execution helper mixin.""" import unittest +import time +import threading from acestep.core.generation.handler.generate_music_execute import GenerateMusicExecuteMixin @@ -12,12 +14,19 @@ def __init__(self): """Capture calls for assertions.""" self.started = False self.stopped = False + self.stop_calls = 0 self.service_calls = 0 + self.emit_runtime_progress = True + self.estimator_progress_values = [] + self.service_delay_sec = 0.0 def _start_diffusion_progress_estimator(self, **kwargs): """Return fake stop event/thread handles used by helper lifecycle.""" - _ = kwargs self.started = True + progress = kwargs["progress"] + desc = kwargs["desc"] + for value in self.estimator_progress_values: + progress(value, desc=desc) class _Stop: """Minimal stop-event stand-in used by the test host.""" @@ -29,6 +38,7 @@ def __init__(self, host): def set(self): """Mark progress lifecycle as stopped.""" self.host.stopped = True + self.host.stop_calls += 1 class _Thread: """Minimal thread stand-in exposing a ``join`` method.""" @@ -41,8 +51,13 @@ def join(self, timeout=None): def service_generate(self, **kwargs): """Record service invocation and return minimal output payload.""" - _ = kwargs + callback = kwargs.get("progress_callback") self.service_calls += 1 + if self.service_delay_sec > 0: + time.sleep(self.service_delay_sec) + if callback is not None and self.emit_runtime_progress: + callback(1, 4, "DiT diffusion...") + callback(4, 4, "DiT diffusion...") return {"target_latents": "ok"} @@ -52,8 +67,11 @@ class GenerateMusicExecuteMixinTests(unittest.TestCase): def test_run_service_with_progress_invokes_service_and_stops_estimator(self): """Helper should call service once and always stop progress estimator.""" host = _Host() + host.emit_runtime_progress = False + host.estimator_progress_values = [0.63, 0.79] + updates = [] out = host._run_generate_music_service_with_progress( - progress=lambda *args, **kwargs: None, + progress=lambda value, desc=None: updates.append((value, desc)), actual_batch_size=1, audio_duration=10.0, inference_steps=8, @@ -83,8 +101,157 @@ def test_run_service_with_progress_invokes_service_and_stops_estimator(self): ) self.assertTrue(host.started) self.assertTrue(host.stopped) + self.assertEqual(host.stop_calls, 1) self.assertEqual(host.service_calls, 1) self.assertEqual(out["outputs"]["target_latents"], "ok") + self.assertAlmostEqual(updates[-1][0], 0.79, places=6) + + def test_runtime_progress_events_are_forwarded_to_ui_progress(self): + """Step-level runtime progress should stop the estimator and reach phase completion.""" + host = _Host() + updates = [] + + host._run_generate_music_service_with_progress( + progress=lambda value, desc=None: updates.append((value, desc)), + actual_batch_size=1, + audio_duration=10.0, + inference_steps=8, + timesteps=None, + service_inputs={ + "captions_batch": ["c"], + "lyrics_batch": ["l"], + "metas_batch": ["m"], + "vocal_languages_batch": ["en"], + "target_wavs_tensor": None, + "repainting_start_batch": [0.0], + "repainting_end_batch": [1.0], + "instructions_batch": ["i"], + "audio_code_hints_batch": None, + "should_return_intermediate": True, + }, + refer_audios=None, + guidance_scale=7.0, + actual_seed_list=[1], + audio_cover_strength=1.0, + cover_noise_strength=0.0, + use_adg=False, + cfg_interval_start=0.0, + cfg_interval_end=1.0, + shift=1.0, + infer_method="ode", + ) + + self.assertTrue(any(desc == "DiT diffusion..." for _, desc in updates)) + self.assertEqual(host.stop_calls, 2) + self.assertAlmostEqual(updates[-1][0], 0.79, places=6) + + def test_runtime_progress_handoff_stays_monotonic_after_estimator_advances(self): + """Runtime callbacks should not drive the UI backwards after estimator progress.""" + host = _Host() + host.estimator_progress_values = [0.68] + updates = [] + + host._run_generate_music_service_with_progress( + progress=lambda value, desc=None: updates.append((value, desc)), + actual_batch_size=1, + audio_duration=10.0, + inference_steps=8, + timesteps=None, + service_inputs={ + "captions_batch": ["c"], + "lyrics_batch": ["l"], + "metas_batch": ["m"], + "vocal_languages_batch": ["en"], + "target_wavs_tensor": None, + "repainting_start_batch": [0.0], + "repainting_end_batch": [1.0], + "instructions_batch": ["i"], + "audio_code_hints_batch": None, + "should_return_intermediate": True, + }, + refer_audios=None, + guidance_scale=7.0, + actual_seed_list=[1], + audio_cover_strength=1.0, + cover_noise_strength=0.0, + use_adg=False, + cfg_interval_start=0.0, + cfg_interval_end=1.0, + shift=1.0, + infer_method="ode", + ) + + progress_values = [value for value, _ in updates] + self.assertEqual(progress_values, sorted(progress_values)) + self.assertGreaterEqual(progress_values[2], 0.68) + + def test_runtime_progress_does_not_publish_behind_blocked_estimator_callback(self): + """A blocked estimator callback should not allow a stale publish after runtime progress.""" + + class _AsyncEstimatorHost(_Host): + def _start_diffusion_progress_estimator(self, **kwargs): + self.started = True + progress = kwargs["progress"] + desc = kwargs["desc"] + + class _Stop: + def __init__(self, host): + self.host = host + + def set(self): + self.host.stopped = True + self.host.stop_calls += 1 + + thread = threading.Thread( + target=lambda: progress(0.63, desc=desc), + name="test-estimator", + daemon=True, + ) + thread.start() + return _Stop(self), thread + + host = _AsyncEstimatorHost() + blocked_estimator = threading.Event() + updates = [] + + def progress(value, desc=None): + if value == 0.63 and not blocked_estimator.is_set(): + blocked_estimator.set() + time.sleep(0.05) + updates.append((value, desc)) + + host._run_generate_music_service_with_progress( + progress=progress, + actual_batch_size=1, + audio_duration=10.0, + inference_steps=8, + timesteps=None, + service_inputs={ + "captions_batch": ["c"], + "lyrics_batch": ["l"], + "metas_batch": ["m"], + "vocal_languages_batch": ["en"], + "target_wavs_tensor": None, + "repainting_start_batch": [0.0], + "repainting_end_batch": [1.0], + "instructions_batch": ["i"], + "audio_code_hints_batch": None, + "should_return_intermediate": True, + }, + refer_audios=None, + guidance_scale=7.0, + actual_seed_list=[1], + audio_cover_strength=1.0, + cover_noise_strength=0.0, + use_adg=False, + cfg_interval_start=0.0, + cfg_interval_end=1.0, + shift=1.0, + infer_method="ode", + ) + + progress_values = [value for value, _ in updates] + self.assertEqual(progress_values, sorted(progress_values)) if __name__ == "__main__": diff --git a/acestep/core/generation/handler/generate_music_execute_thread_test.py b/acestep/core/generation/handler/generate_music_execute_thread_test.py new file mode 100644 index 000000000..056410ddb --- /dev/null +++ b/acestep/core/generation/handler/generate_music_execute_thread_test.py @@ -0,0 +1,92 @@ +"""Regression coverage for progress estimator thread lifecycle.""" + +import unittest + +from acestep.core.generation.handler.generate_music_execute import GenerateMusicExecuteMixin + + +class _Host(GenerateMusicExecuteMixin): + """Minimal host for exercising estimator shutdown behavior.""" + + def __init__(self): + self.stop_calls = 0 + self.thread = None + + def _start_diffusion_progress_estimator(self, **kwargs): + class _Stop: + def __init__(self, host): + self.host = host + + def set(self): + self.host.stop_calls += 1 + + class _Thread: + def __init__(self): + self.join_calls = 0 + self._alive = True + + def join(self, timeout=None): + _ = timeout + self.join_calls += 1 + if self.join_calls >= 2: + self._alive = False + + def is_alive(self): + return self._alive + + self.thread = _Thread() + return _Stop(self), self.thread + + def service_generate(self, **kwargs): + progress_callback = kwargs["progress_callback"] + progress_callback(1, 4, "DiT diffusion...") + return {"target_latents": "ok"} + + +class GenerateMusicExecuteThreadLifecycleTests(unittest.TestCase): + """Ensure estimator thread handles survive timed joins until actual shutdown.""" + + def test_runtime_progress_keeps_estimator_thread_handle_until_thread_exits(self): + host = _Host() + updates = [] + + out = host._run_generate_music_service_with_progress( + progress=lambda value, desc=None: updates.append((value, desc)), + actual_batch_size=1, + audio_duration=10.0, + inference_steps=8, + timesteps=None, + service_inputs={ + "captions_batch": ["c"], + "lyrics_batch": ["l"], + "metas_batch": ["m"], + "vocal_languages_batch": ["en"], + "target_wavs_tensor": None, + "repainting_start_batch": [0.0], + "repainting_end_batch": [1.0], + "instructions_batch": ["i"], + "audio_code_hints_batch": None, + "should_return_intermediate": True, + }, + refer_audios=None, + guidance_scale=7.0, + actual_seed_list=[1], + audio_cover_strength=1.0, + cover_noise_strength=0.0, + use_adg=False, + cfg_interval_start=0.0, + cfg_interval_end=1.0, + shift=1.0, + infer_method="ode", + ) + + self.assertEqual(out["outputs"]["target_latents"], "ok") + self.assertEqual(host.stop_calls, 2) + self.assertIsNotNone(host.thread) + self.assertEqual(host.thread.join_calls, 2) + self.assertFalse(host.thread.is_alive()) + self.assertTrue(any(desc == "DiT diffusion..." for _, desc in updates)) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/core/generation/handler/mlx_vae_decode_native.py b/acestep/core/generation/handler/mlx_vae_decode_native.py index 23b5c21ea..eb9cefc8f 100644 --- a/acestep/core/generation/handler/mlx_vae_decode_native.py +++ b/acestep/core/generation/handler/mlx_vae_decode_native.py @@ -28,7 +28,7 @@ def _resolve_mlx_decode_fn(self): raise RuntimeError("MLX VAE decode requested but mlx_vae is not initialized.") return self.mlx_vae.decode - def _mlx_vae_decode(self, latents_torch): + def _mlx_vae_decode(self, latents_torch, progress_callback=None): """Decode batched PyTorch latents using native MLX VAE decode. Args: @@ -52,7 +52,18 @@ def _mlx_vae_decode(self, latents_torch): decode_fn = self._resolve_mlx_decode_fn() audio_parts = [] for idx in range(batch_size): - decoded = self._mlx_decode_single(latents_mx[idx : idx + 1], decode_fn=decode_fn) + item_progress_callback = None + if progress_callback is not None: + def _item_progress(current, total, desc, offset=idx, batch_total=batch_size): + progress_callback(offset * total + current, batch_total * total, desc) + + item_progress_callback = _item_progress + + decoded = self._mlx_decode_single( + latents_mx[idx : idx + 1], + decode_fn=decode_fn, + progress_callback=item_progress_callback, + ) if decoded.dtype != mx.float32: decoded = decoded.astype(mx.float32) mx.eval(decoded) @@ -71,7 +82,7 @@ def _mlx_vae_decode(self, latents_torch): ) return torch.from_numpy(audio_ncl) - def _mlx_decode_single(self, z_nlc, decode_fn=None): + def _mlx_decode_single(self, z_nlc, decode_fn=None, progress_callback=None): """Decode a single MLX latent sample with optional tiling. Args: @@ -91,7 +102,10 @@ def _mlx_decode_single(self, z_nlc, decode_fn=None): mlx_overlap = 64 if latent_frames <= mlx_chunk: - return decode_fn(z_nlc) + out = decode_fn(z_nlc) + if progress_callback is not None: + progress_callback(1, 1, "Decoding audio chunks...") + return out stride = mlx_chunk - 2 * mlx_overlap num_steps = math.ceil(latent_frames / stride) @@ -115,5 +129,7 @@ def _mlx_decode_single(self, z_nlc, decode_fn=None): audio_len = audio_chunk.shape[1] end_idx = audio_len - trim_end if trim_end > 0 else audio_len decoded_parts.append(audio_chunk[:, trim_start:end_idx, :]) + if progress_callback is not None: + progress_callback(idx + 1, num_steps, "Decoding audio chunks...") return mx.concatenate(decoded_parts, axis=1) diff --git a/acestep/core/generation/handler/runtime_progress_relay.py b/acestep/core/generation/handler/runtime_progress_relay.py new file mode 100644 index 000000000..04a9470a5 --- /dev/null +++ b/acestep/core/generation/handler/runtime_progress_relay.py @@ -0,0 +1,81 @@ +"""Runtime progress relay helpers for generation execution.""" + +import queue +import threading +from typing import Callable, Optional + +from loguru import logger + + +class RuntimeProgressRelay: + """Bridge runtime step events onto monotonic UI progress updates.""" + + def __init__( + self, + *, + progress: Optional[Callable[[float], None]], + start: float, + end: float, + ) -> None: + self._progress = progress + self._start = start + self._end = end + self._events: "queue.Queue[tuple[int, int, str] | None]" = queue.Queue() + self._lock = threading.Lock() + self._value = 0.0 + self._active = True + + def enqueue(self, current: int, total: int, desc: str) -> None: + """Queue a runtime progress event while the relay is active.""" + with self._lock: + if not self._active: + return + self._events.put((current, total, desc)) + + def emit_progress(self, value: float, desc: Optional[str] = None) -> float: + """Emit a monotonic progress update to the UI callback.""" + with self._lock: + clamped = max(self._value, value) + self._value = clamped + if self._progress is not None: + try: + self._progress(clamped, desc=desc) + except Exception as exc: + logger.debug("[generate_music] Ignoring progress callback error: {}", exc) + return clamped + + def drain(self) -> bool: + """Drain queued runtime events and map them onto the configured UI range.""" + drained = False + while True: + try: + item = self._events.get_nowait() + except queue.Empty: + return drained + + if item is None: + return drained + + current, total, desc = item + if total <= 0: + continue + frac = min(1.0, max(0.0, current / total)) + mapped = self._start + (self._end - self._start) * frac + self.emit_progress(mapped, desc) + drained = True + + def shutdown(self) -> None: + """Disable future runtime event forwarding.""" + with self._lock: + self._active = False + self._events.put(None) + + @staticmethod + def stop_estimator_if_finished(progress_thread): + """Return the estimator thread handle only while it is still alive.""" + if progress_thread is None: + return None + progress_thread.join(timeout=1.0) + if hasattr(progress_thread, "is_alive") and progress_thread.is_alive(): + return progress_thread + return None diff --git a/acestep/core/generation/handler/runtime_progress_relay_test.py b/acestep/core/generation/handler/runtime_progress_relay_test.py new file mode 100644 index 000000000..6d3fc7e8a --- /dev/null +++ b/acestep/core/generation/handler/runtime_progress_relay_test.py @@ -0,0 +1,43 @@ +"""Unit tests for runtime progress relay helpers.""" + +import unittest + +from acestep.core.generation.handler.runtime_progress_relay import RuntimeProgressRelay + + +class RuntimeProgressRelayTests(unittest.TestCase): + """Verify runtime progress relay mapping and shutdown behavior.""" + + def test_drain_keeps_progress_monotonic(self): + updates = [] + relay = RuntimeProgressRelay( + progress=lambda value, desc=None: updates.append((value, desc)), + start=0.52, + end=0.79, + ) + + relay.emit_progress(0.68, "estimator") + relay.enqueue(1, 4, "DiT diffusion...") + relay.enqueue(4, 4, "DiT diffusion...") + + self.assertTrue(relay.drain()) + self.assertEqual([value for value, _ in updates], sorted(value for value, _ in updates)) + self.assertAlmostEqual(updates[-1][0], 0.79, places=6) + + def test_shutdown_ignores_late_runtime_events(self): + updates = [] + relay = RuntimeProgressRelay( + progress=lambda value, desc=None: updates.append((value, desc)), + start=0.52, + end=0.79, + ) + + relay.shutdown() + relay.enqueue(4, 4, "DiT diffusion...") + + self.assertFalse(relay.drain()) + self.assertEqual(updates, []) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/core/generation/handler/service_generate.py b/acestep/core/generation/handler/service_generate.py index 2fcc6cab1..d56f899e4 100644 --- a/acestep/core/generation/handler/service_generate.py +++ b/acestep/core/generation/handler/service_generate.py @@ -5,7 +5,7 @@ and output attachment without owning model internals. """ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -47,6 +47,7 @@ def service_generate( chunk_mask_modes: Optional[List[str]] = None, repaint_crossfade_frames: int = 10, repaint_injection_ratio: float = 0.5, + progress_callback: Optional[Callable[[int, int, str], None]] = None, sampler_mode: str = "euler", velocity_norm_threshold: float = 0.0, velocity_ema_factor: float = 0.0, @@ -79,6 +80,8 @@ def service_generate( timesteps: Optional explicit diffusion timestep sequence. repaint_crossfade_frames: Crossfade width (latent frames) at repaint boundaries for boundary blending. ~0.4s at 25 Hz. + progress_callback: Optional diffusion-step callback receiving + ``(current_step, total_steps, desc)``. sampler_mode: Sampler algorithm — ``"euler"`` or ``"heun"``. velocity_norm_threshold: Velocity norm clamping threshold (0 = disabled). velocity_ema_factor: Velocity EMA smoothing factor (0 = disabled). @@ -140,6 +143,7 @@ def service_generate( timesteps=timesteps, repaint_crossfade_frames=repaint_crossfade_frames, repaint_injection_ratio=repaint_injection_ratio, + progress_callback=progress_callback, sampler_mode=sampler_mode, velocity_norm_threshold=velocity_norm_threshold, velocity_ema_factor=velocity_ema_factor, diff --git a/acestep/core/generation/handler/service_generate_execute.py b/acestep/core/generation/handler/service_generate_execute.py index 5946a6132..2bf35c0f3 100644 --- a/acestep/core/generation/handler/service_generate_execute.py +++ b/acestep/core/generation/handler/service_generate_execute.py @@ -1,7 +1,7 @@ """Execution helpers for service generation diffusion and output assembly.""" import random -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from loguru import logger @@ -77,6 +77,7 @@ def _build_service_generate_kwargs( timesteps: Optional[List[float]], repaint_crossfade_frames: int = 10, repaint_injection_ratio: float = 0.5, + progress_callback: Optional[Callable[[int, int, str], None]] = None, sampler_mode: str = "euler", velocity_norm_threshold: float = 0.0, velocity_ema_factor: float = 0.0, @@ -109,6 +110,8 @@ def _build_service_generate_kwargs( "cfg_interval_start": cfg_interval_start, "cfg_interval_end": cfg_interval_end, "shift": shift, + "use_progress_bar": not getattr(self, "disable_tqdm", False), + "progress_callback": progress_callback, "repaint_mask": repaint_mask, "clean_src_latents": clean_src_latents, "repaint_crossfade_frames": repaint_crossfade_frames, @@ -203,6 +206,8 @@ def _execute_service_generate_diffusion( encoder_hidden_states_non_cover=enc_hs_nc, encoder_attention_mask_non_cover=enc_am_nc, context_latents_non_cover=ctx_nc, + progress_callback=generate_kwargs.get("progress_callback"), + disable_tqdm=not generate_kwargs.get("use_progress_bar", True), sampler_mode=generate_kwargs.get("sampler_mode", "euler"), velocity_norm_threshold=generate_kwargs.get("velocity_norm_threshold", 0.0), velocity_ema_factor=generate_kwargs.get("velocity_ema_factor", 0.0), diff --git a/acestep/core/generation/handler/service_generate_execute_test.py b/acestep/core/generation/handler/service_generate_execute_test.py index 8af115425..ed2d61df5 100644 --- a/acestep/core/generation/handler/service_generate_execute_test.py +++ b/acestep/core/generation/handler/service_generate_execute_test.py @@ -15,6 +15,7 @@ class _Host(ServiceGenerateExecuteMixin, ServiceGenerateOutputsMixin): def __init__(self): """Initialize static runtime fields for helper-method tests.""" self.device = "cpu" + self.disable_tqdm = False self.silence_latent = torch.zeros(1, 4, 4, dtype=torch.float32) @@ -58,6 +59,47 @@ def test_build_generate_kwargs_adds_timesteps_tensor(self): self.assertEqual(kwargs["infer_steps"], 16) self.assertEqual(kwargs["timesteps"].dtype, torch.float32) self.assertEqual(kwargs["timesteps"].device.type, "cpu") + self.assertTrue(kwargs["use_progress_bar"]) + self.assertIsNone(kwargs["progress_callback"]) + + def test_build_generate_kwargs_forwards_runtime_progress_callback(self): + """Runtime callbacks should be forwarded into model-generation kwargs.""" + host = _Host() + payload = { + "text_hidden_states": torch.zeros(1, 2), + "text_attention_mask": torch.ones(1, 2), + "lyric_hidden_states": torch.zeros(1, 2), + "lyric_attention_mask": torch.ones(1, 2), + "refer_audio_acoustic_hidden_states_packed": torch.zeros(1, 2), + "refer_audio_order_mask": torch.zeros(1, dtype=torch.long), + "src_latents": torch.zeros(1, 4, 4), + "chunk_mask": torch.ones(1, 4, dtype=torch.bool), + "is_covers": torch.tensor([True]), + "non_cover_text_hidden_states": None, + "non_cover_text_attention_masks": None, + "precomputed_lm_hints_25Hz": None, + } + def callback(current, total, desc): + return (current, total, desc) + + kwargs = host._build_service_generate_kwargs( + payload=payload, + seed_param=123, + infer_steps=16, + guidance_scale=7.0, + audio_cover_strength=1.0, + cover_noise_strength=0.0, + infer_method="ode", + use_adg=False, + cfg_interval_start=0.0, + cfg_interval_end=1.0, + shift=1.0, + timesteps=None, + progress_callback=callback, + ) + + self.assertIs(kwargs["progress_callback"], callback) + self.assertTrue(kwargs["use_progress_bar"]) def test_attach_service_outputs_persists_required_fields(self): """Attached payload fields should be available to downstream handlers.""" diff --git a/acestep/core/generation/handler/service_generate_test.py b/acestep/core/generation/handler/service_generate_test.py index 6287010ad..4495c6176 100644 --- a/acestep/core/generation/handler/service_generate_test.py +++ b/acestep/core/generation/handler/service_generate_test.py @@ -128,10 +128,12 @@ def test_service_generate_forwards_runtime_controls_to_build_and_execute(self): """It forwards runtime tuning controls to downstream helper invocations.""" host = _Host() custom_timesteps = [1.0, 0.5, 0.0] + progress_callback = lambda current, total, desc: (current, total, desc) host.service_generate( captions="cap", lyrics="lyr", guidance_scale=9.5, audio_cover_strength=0.7, cover_noise_strength=0.2, use_adg=True, cfg_interval_start=0.1, cfg_interval_end=0.9, shift=1.3, infer_method="sde", timesteps=custom_timesteps, return_intermediate=False, + progress_callback=progress_callback, ) build_kwargs = host.calls["_build_service_generate_kwargs"] self.assertEqual(build_kwargs["guidance_scale"], 9.5) @@ -139,6 +141,7 @@ def test_service_generate_forwards_runtime_controls_to_build_and_execute(self): self.assertEqual(build_kwargs["cover_noise_strength"], 0.2) self.assertTrue(build_kwargs["use_adg"]) self.assertEqual(build_kwargs["timesteps"], custom_timesteps) + self.assertIs(build_kwargs["progress_callback"], progress_callback) self.assertFalse(host.calls["_attach_service_generate_outputs"]["return_intermediate"]) execute_kwargs = host.calls["_execute_service_generate_diffusion"] self.assertEqual(execute_kwargs["infer_method"], "sde") diff --git a/acestep/core/generation/handler/vae_decode.py b/acestep/core/generation/handler/vae_decode.py index c3ea76174..03a00872a 100644 --- a/acestep/core/generation/handler/vae_decode.py +++ b/acestep/core/generation/handler/vae_decode.py @@ -19,6 +19,7 @@ def tiled_decode( chunk_size: Optional[int] = None, overlap: int = 64, offload_wav_to_cpu: Optional[bool] = None, + progress_callback=None, ): """Decode latents using tiling to reduce VRAM usage. @@ -32,6 +33,8 @@ def tiled_decode( overlap: Overlap in latent frames between adjacent windows. offload_wav_to_cpu: Whether decoded waveform chunks should be offloaded to CPU immediately to reduce VRAM pressure. + progress_callback: Optional callback receiving + ``(current_step, total_steps, desc)`` for chunk decode progress. Returns: Decoded waveform tensor shaped ``[batch, audio_channels, samples]``. @@ -39,7 +42,7 @@ def tiled_decode( # ---- MLX fast path (macOS Apple Silicon) ---- if self.use_mlx_vae and self.mlx_vae is not None: try: - result = self._mlx_vae_decode(latents) + result = self._mlx_vae_decode(latents, progress_callback=progress_callback) return result except Exception as exc: logger.warning( @@ -74,7 +77,13 @@ def tiled_decode( overlap = min(overlap, _mps_overlap) try: - return self._tiled_decode_inner(latents, chunk_size, overlap, offload_wav_to_cpu) + return self._tiled_decode_inner( + latents, + chunk_size, + overlap, + offload_wav_to_cpu, + progress_callback=progress_callback, + ) except (NotImplementedError, RuntimeError) as exc: if not _is_mps: raise diff --git a/acestep/core/generation/handler/vae_decode_chunks.py b/acestep/core/generation/handler/vae_decode_chunks.py index 8564de9ad..3ee50cb44 100644 --- a/acestep/core/generation/handler/vae_decode_chunks.py +++ b/acestep/core/generation/handler/vae_decode_chunks.py @@ -10,9 +10,16 @@ class VaeDecodeChunksMixin: """Implement chunked decode strategies for GPU and CPU-offload modes.""" - def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu): + def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu, progress_callback=None): """Run tiled decode with adaptive overlap and OOM fallbacks.""" bsz, _channels, latent_frames = latents.shape + completed_steps = 0 + + def _monotonic_progress(current, total, desc): + nonlocal completed_steps + completed_steps = max(completed_steps, current) + if progress_callback is not None: + progress_callback(completed_steps, total, desc) # Batch-sequential decode keeps peak VRAM stable across batch sizes. if bsz > 1: @@ -20,7 +27,18 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu): per_sample_results = [] for b_idx in range(bsz): single = latents[b_idx : b_idx + 1] - decoded = self._tiled_decode_inner(single, chunk_size, overlap, offload_wav_to_cpu) + sample_progress = None + if progress_callback is not None: + def _sample_progress(current, total, desc, offset=b_idx): + progress_callback(offset * total + current, bsz * total, desc) + sample_progress = _sample_progress + decoded = self._tiled_decode_inner( + single, + chunk_size, + overlap, + offload_wav_to_cpu, + progress_callback=sample_progress, + ) per_sample_results.append(decoded.cpu() if decoded.device.type != "cpu" else decoded) self._empty_cache() result = torch.cat(per_sample_results, dim=0) @@ -46,6 +64,8 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu): decoder_output = self.vae.decode(latents) result = decoder_output.sample del decoder_output + if progress_callback is not None: + progress_callback(1, 1, "Decoding audio chunks...") return result except torch.cuda.OutOfMemoryError: logger.warning("[tiled_decode] OOM on direct decode, falling back to CPU VAE decode") @@ -60,7 +80,15 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu): if offload_wav_to_cpu: try: - return self._tiled_decode_offload_cpu(latents, bsz, latent_frames, stride, overlap, num_steps) + return self._tiled_decode_offload_cpu( + latents, + bsz, + latent_frames, + stride, + overlap, + num_steps, + progress_callback=_monotonic_progress, + ) except torch.cuda.OutOfMemoryError: logger.warning( f"[tiled_decode] OOM during offload_cpu decode with chunk_size={chunk_size}, " @@ -70,7 +98,13 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu): return self._decode_on_cpu(latents) try: - return self._tiled_decode_gpu(latents, stride, overlap, num_steps) + return self._tiled_decode_gpu( + latents, + stride, + overlap, + num_steps, + progress_callback=_monotonic_progress, + ) except torch.cuda.OutOfMemoryError: logger.warning( f"[tiled_decode] OOM during GPU decode with chunk_size={chunk_size}, " @@ -78,13 +112,21 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu): ) self._empty_cache() try: - return self._tiled_decode_offload_cpu(latents, bsz, latent_frames, stride, overlap, num_steps) + return self._tiled_decode_offload_cpu( + latents, + bsz, + latent_frames, + stride, + overlap, + num_steps, + progress_callback=_monotonic_progress, + ) except torch.cuda.OutOfMemoryError: logger.warning("[tiled_decode] OOM even with offload path, falling back to full CPU VAE decode") self._empty_cache() return self._decode_on_cpu(latents) - def _tiled_decode_gpu(self, latents, stride, overlap, num_steps): + def _tiled_decode_gpu(self, latents, stride, overlap, num_steps, progress_callback=None): """Decode chunks and keep decoded audio tensors on GPU.""" decoded_audio_list = [] upsample_factor = None @@ -112,10 +154,21 @@ def _tiled_decode_gpu(self, latents, stride, overlap, num_steps): end_idx = audio_len - trim_end if trim_end > 0 else audio_len audio_core = audio_chunk[:, :, trim_start:end_idx] decoded_audio_list.append(audio_core) + if progress_callback is not None: + progress_callback(i + 1, num_steps, "Decoding audio chunks...") return torch.cat(decoded_audio_list, dim=-1) - def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap, num_steps): + def _tiled_decode_offload_cpu( + self, + latents, + bsz, + latent_frames, + stride, + overlap, + num_steps, + progress_callback=None, + ): """Decode chunks on GPU and copy trimmed audio cores to a CPU buffer.""" first_core_end = min(stride, latent_frames) first_win_end = min(latent_frames, first_core_end + overlap) @@ -138,6 +191,8 @@ def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap first_audio_core = first_audio_chunk[:, :, :first_end_idx] audio_write_pos = first_audio_core.shape[-1] final_audio[:, :, :audio_write_pos] = first_audio_core.cpu() + if progress_callback is not None: + progress_callback(1, num_steps, "Decoding audio chunks...") del first_audio_chunk, first_audio_core, first_latent_chunk @@ -164,6 +219,8 @@ def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap core_len = audio_core.shape[-1] final_audio[:, :, audio_write_pos : audio_write_pos + core_len] = audio_core.cpu() audio_write_pos += core_len + if progress_callback is not None: + progress_callback(i + 1, num_steps, "Decoding audio chunks...") del audio_chunk, audio_core, latent_chunk diff --git a/acestep/core/generation/handler/vae_decode_chunks_test.py b/acestep/core/generation/handler/vae_decode_chunks_test.py index d44bc9da3..566a0f41c 100644 --- a/acestep/core/generation/handler/vae_decode_chunks_test.py +++ b/acestep/core/generation/handler/vae_decode_chunks_test.py @@ -74,6 +74,32 @@ def _oom(*args, **kwargs): self.assertTrue(torch.equal(out, torch.full((1, 2, 7), 9.0))) self.assertEqual(host.decode_on_cpu_calls, 1) + def test_gpu_to_offload_fallback_keeps_progress_monotonic(self): + """Offload retry progress should not rewind after GPU OOM progress has started.""" + host = _ChunksHost() + updates = [] + + def _gpu_oom(*args, **kwargs): + kwargs["progress_callback"](3, 4, "Decoding audio chunks...") + raise torch.cuda.OutOfMemoryError("gpu oom") + + def _offload_ok(*args, **kwargs): + kwargs["progress_callback"](1, 4, "Decoding audio chunks...") + kwargs["progress_callback"](4, 4, "Decoding audio chunks...") + return torch.ones(1, 2, 5) + + host._tiled_decode_gpu = _gpu_oom + host._tiled_decode_offload_cpu = _offload_ok + host._tiled_decode_inner( + torch.zeros(1, 4, 20), + chunk_size=8, + overlap=2, + offload_wav_to_cpu=False, + progress_callback=lambda current, total, desc: updates.append((current, total, desc)), + ) + + self.assertEqual([current for current, _, _ in updates], [3, 3, 4]) + if __name__ == "__main__": unittest.main() diff --git a/acestep/core/generation/handler/vae_decode_mixin_test.py b/acestep/core/generation/handler/vae_decode_mixin_test.py index e48030b42..13f91d7a9 100644 --- a/acestep/core/generation/handler/vae_decode_mixin_test.py +++ b/acestep/core/generation/handler/vae_decode_mixin_test.py @@ -46,8 +46,9 @@ def test_tiled_decode_falls_back_when_mlx_decode_fails(self): host.use_mlx_vae = True host.mlx_vae = object() - def _mlx_raise(_latents): + def _mlx_raise(_latents, progress_callback=None): """Raise MLX failure to exercise fallback path.""" + _ = progress_callback raise ValueError("mlx failed") host._mlx_vae_decode = _mlx_raise diff --git a/acestep/core/generation/handler/vae_decode_test_helpers.py b/acestep/core/generation/handler/vae_decode_test_helpers.py index 4eccd3167..51ed26664 100644 --- a/acestep/core/generation/handler/vae_decode_test_helpers.py +++ b/acestep/core/generation/handler/vae_decode_test_helpers.py @@ -66,12 +66,13 @@ def _should_offload_wav_to_cpu(self): """Return deterministic offload policy used by default path.""" return False - def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu): + def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu, progress_callback=None): """Record routed args and return sentinel audio tensor.""" _ = latents self.recorded["chunk_size"] = chunk_size self.recorded["overlap"] = overlap self.recorded["offload"] = offload_wav_to_cpu + self.recorded["progress_callback"] = progress_callback return torch.ones(1, 2, 8) def _tiled_decode_cpu_fallback(self, latents): @@ -79,9 +80,10 @@ def _tiled_decode_cpu_fallback(self, latents): _ = latents return torch.full((1, 2, 8), 2.0) - def _mlx_vae_decode(self, latents): + def _mlx_vae_decode(self, latents, progress_callback=None): """Return MLX sentinel tensor for MLX path assertions.""" _ = latents + self.recorded["progress_callback"] = progress_callback return torch.full((1, 2, 6), 3.0) diff --git a/acestep/llm_inference.py b/acestep/llm_inference.py index 42d7b8ad8..f92c5deda 100644 --- a/acestep/llm_inference.py +++ b/acestep/llm_inference.py @@ -9,7 +9,7 @@ import time import random import warnings -from typing import Optional, Dict, Any, Tuple, List, Union +from typing import Callable, Optional, Dict, Any, Tuple, List, Union from contextlib import contextmanager import yaml @@ -874,6 +874,7 @@ def _run_vllm( lyrics: str = "", cot_text: str = "", seeds: Optional[List[int]] = None, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> Union[str, List[str]]: """ Unified vllm generation function supporting both single and batch modes. @@ -959,6 +960,8 @@ def _run_vllm( output_texts.append(str(output)) # Return single string for single mode, list for batch mode + if progress_callback is not None: + progress_callback(1, 1, "vLLM generation") return output_texts[0] if not is_batch else output_texts def _run_pt_single( @@ -982,6 +985,7 @@ def _run_pt_single( caption: str, lyrics: str, cot_text: str, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> str: """Internal helper function for single-item PyTorch generation.""" inputs = self.llm_tokenizer( @@ -1060,6 +1064,7 @@ def _run_pt_single( pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id, streamer=None, constrained_processor=constrained_processor, + progress_callback=progress_callback, ) # Extract only the conditional output (first in batch) @@ -1077,6 +1082,7 @@ def _run_pt_single( pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id, streamer=None, constrained_processor=constrained_processor, + progress_callback=progress_callback, ) else: # Generate without CFG using native generate() parameters @@ -1143,6 +1149,7 @@ def _run_pt( lyrics: str = "", cot_text: str = "", seeds: Optional[List[int]] = None, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> Union[str, List[str]]: """ Unified PyTorch generation function supporting both single and batch modes. @@ -1158,6 +1165,7 @@ def _run_pt( # loads to GPU once and offloads once, instead of per-item. if is_batch: output_texts = [] + batch_total = len(formatted_prompt_list) with self._load_model_context(): for i, formatted_prompt in enumerate(formatted_prompt_list): @@ -1169,6 +1177,15 @@ def _run_pt( elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): torch.mps.manual_seed(seeds[i]) + item_progress_callback = None + item_progress_state = {"current": 0, "total": 0, "desc": ""} + if progress_callback is not None: + def _item_progress(current, total, desc, index=i): + item_progress_state.update(current=current, total=total, desc=desc) + progress_callback(index * total + current, batch_total * total, desc) + + item_progress_callback = _item_progress + # Generate using single-item method with batch-mode defaults output_text = self._run_pt_single( formatted_prompt=formatted_prompt, @@ -1190,9 +1207,20 @@ def _run_pt( caption=caption, lyrics=lyrics, cot_text=cot_text, + progress_callback=item_progress_callback, ) output_texts.append(output_text) + if ( + progress_callback is not None + and item_progress_state["total"] > 0 + and item_progress_state["current"] < item_progress_state["total"] + ): + progress_callback( + (i + 1) * item_progress_state["total"], + batch_total * item_progress_state["total"], + item_progress_state["desc"], + ) return output_texts @@ -1219,6 +1247,7 @@ def _run_pt( caption=caption, lyrics=lyrics, cot_text=cot_text, + progress_callback=progress_callback, ) def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool: @@ -1340,6 +1369,20 @@ def progress(*args, **kwargs): else: seeds = seeds[:actual_batch_size] + def _make_phase_progress_callback( + start: float, + end: float, + default_desc: str, + ) -> Callable[[int, int, str], None]: + def _callback(current: int, total: int, desc: str) -> None: + if total <= 0: + return + frac = min(1.0, max(0.0, current / total)) + mapped = start + (end - start) * frac + progress(mapped, desc or default_desc) + + return _callback + # ========== PHASE 1: CoT Generation ========== # Skip CoT if all metadata are user-provided OR caption is already formatted progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...") @@ -1380,6 +1423,11 @@ def progress(*args, **kwargs): use_constrained_decoding=use_constrained_decoding, constrained_decoding_debug=constrained_decoding_debug, stop_at_reasoning=True, # Always stop at in Phase 1 + progress_callback=_make_phase_progress_callback( + 0.1, + 0.3, + "LLM metadata generation", + ), ) phase1_time = time.time() - phase1_start @@ -1399,6 +1447,7 @@ def progress(*args, **kwargs): logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}") else: logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}") + progress(0.3, "LLM metadata generation complete") else: # Use user-provided metadata if is_batch: @@ -1468,7 +1517,12 @@ def progress(*args, **kwargs): formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text) logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}") - progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...") + progress(0.31, f"Phase 2: Generating audio codes for {actual_batch_size} items...") + phase2_progress_callback = _make_phase_progress_callback( + 0.31, + 0.5, + "LLM audio code generation", + ) if is_batch: # Batch mode: generate codes for all items formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size @@ -1492,6 +1546,7 @@ def progress(*args, **kwargs): lyrics=lyrics, cot_text=cot_text, seeds=seeds, + progress_callback=phase2_progress_callback, ) elif self.llm_backend == "mlx": codes_outputs = self._run_mlx( @@ -1510,6 +1565,7 @@ def progress(*args, **kwargs): lyrics=lyrics, cot_text=cot_text, seeds=seeds, + progress_callback=phase2_progress_callback, ) else: # pt backend codes_outputs = self._run_pt( @@ -1528,6 +1584,7 @@ def progress(*args, **kwargs): lyrics=lyrics, cot_text=cot_text, seeds=seeds, + progress_callback=phase2_progress_callback, ) except Exception as e: error_msg = f"Error in batch codes generation: {str(e)}" @@ -1557,6 +1614,7 @@ def progress(*args, **kwargs): metadata_list.append(metadata.copy()) # Same metadata for all phase2_time = time.time() - phase2_start + progress(0.5, "LLM audio code generation complete") # Log results codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list] @@ -1602,6 +1660,7 @@ def progress(*args, **kwargs): use_constrained_decoding=use_constrained_decoding, constrained_decoding_debug=constrained_decoding_debug, stop_at_reasoning=False, # Generate codes until EOS + progress_callback=phase2_progress_callback, ) if not codes_output_text: @@ -1621,6 +1680,7 @@ def progress(*args, **kwargs): } phase2_time = time.time() - phase2_start + progress(0.5, "LLM audio code generation complete") # Parse audio codes from output (metadata should be same as Phase 1) _, audio_codes = self.parse_lm_output(codes_output_text) @@ -2321,6 +2381,7 @@ def generate_from_formatted_prompt( use_constrained_decoding: bool = True, constrained_decoding_debug: bool = False, stop_at_reasoning: bool = False, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> Tuple[str, str]: """ Generate raw LM text output from a pre-built formatted prompt. @@ -2394,6 +2455,7 @@ def generate_from_formatted_prompt( caption=caption, lyrics=lyrics, cot_text=cot_text, + progress_callback=progress_callback, ) self._clear_accelerator_cache() return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}" @@ -2420,6 +2482,7 @@ def generate_from_formatted_prompt( caption=caption, lyrics=lyrics, cot_text=cot_text, + progress_callback=progress_callback, ) self._clear_accelerator_cache() return output_text, f"✅ Generated successfully (mlx) | length={len(output_text)}" @@ -2445,6 +2508,7 @@ def generate_from_formatted_prompt( caption=caption, lyrics=lyrics, cot_text=cot_text, + progress_callback=progress_callback, ) self._clear_accelerator_cache() return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}" @@ -2486,6 +2550,7 @@ def _generate_with_constrained_decoding( pad_token_id: int, streamer: Optional[BaseStreamer], constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> torch.Tensor: """ Custom generation loop with constrained decoding support (non-CFG). @@ -2558,6 +2623,8 @@ def _generate_with_constrained_decoding( # Update streamer if streamer is not None: streamer.put(next_tokens_unsqueezed) + if progress_callback is not None: + progress_callback(step + 1, max_new_tokens, "LLM Constrained Decoding") if should_stop: break @@ -2583,6 +2650,7 @@ def _generate_with_cfg_custom( pad_token_id: int, streamer: Optional[BaseStreamer], constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> torch.Tensor: """ Custom CFG generation loop that: @@ -2729,6 +2797,8 @@ def _generate_with_cfg_custom( # Update streamer if streamer is not None: streamer.put(next_tokens_unsqueezed) # Stream conditional tokens + if progress_callback is not None: + progress_callback(step + 1, max_new_tokens, "LLM CFG Generation") # Stop generation only when ALL sequences have finished if seq_finished.all(): @@ -3005,6 +3075,7 @@ def _run_mlx_batch_native( lyrics: str, cot_text: str, seeds: Optional[List[int]] = None, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> List[str]: """ Optimized native MLX batch generation for codes phase. @@ -3322,6 +3393,8 @@ def _clone_cache_list(cache_list): item_last_logits[i] = logits_out[:, -1:, :] pbar.update(1) + if progress_callback is not None: + progress_callback(step + 1, max_new_tokens, f"MLX {cfg_label}Batch Gen (native, n={batch_size})") # Periodic memory cleanup if step % 256 == 0 and step > 0: @@ -3370,6 +3443,7 @@ def _run_mlx_single_native( caption: str, lyrics: str, cot_text: str, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> str: """ Optimized native MLX generation using mlx-lm infrastructure. @@ -3628,6 +3702,8 @@ def _run_mlx_single_native( new_tokens.append(token_id) all_token_ids.append(token_id) pbar.update(1) + if progress_callback is not None: + progress_callback(step + 1, max_new_tokens, tqdm_desc) # Update constrained processor FSM state if constrained_processor is not None: @@ -3693,6 +3769,7 @@ def _run_mlx_single( caption: str, lyrics: str, cot_text: str, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> str: """ MLX-accelerated single-item generation. @@ -3723,6 +3800,7 @@ def _run_mlx_single( caption=caption, lyrics=lyrics, cot_text=cot_text, + progress_callback=progress_callback, ) except Exception as _native_err: logger.warning( @@ -3872,6 +3950,8 @@ def _run_mlx_single( new_tokens.append(token_id) all_token_ids.append(token_id) pbar.update(1) + if progress_callback is not None: + progress_callback(step + 1, max_new_tokens, tqdm_desc) # Update constrained processor state if constrained_processor is not None: @@ -3934,6 +4014,7 @@ def _run_mlx( lyrics: str = "", cot_text: str = "", seeds: Optional[List[int]] = None, + progress_callback: Optional[Callable[[int, int, str], None]] = None, ) -> Union[str, List[str]]: """ Unified MLX generation function supporting both single and batch modes. @@ -3983,6 +4064,7 @@ def _run_mlx( lyrics=lyrics, cot_text=cot_text, seeds=seeds, + progress_callback=progress_callback, ) except Exception as e: logger.warning( @@ -3998,6 +4080,15 @@ def _run_mlx( if seeds and i < len(seeds): mx.random.seed(seeds[i]) + item_progress_callback = None + item_progress_state = {"current": 0, "total": 0, "desc": ""} + if progress_callback is not None: + def _item_progress(current, total, desc, index=i, batch_total=batch_size): + item_progress_state.update(current=current, total=total, desc=desc) + progress_callback(index * total + current, batch_total * total, desc) + + item_progress_callback = _item_progress + output_text = self._run_mlx_single( formatted_prompt=formatted_prompt, temperature=temperature, @@ -4018,8 +4109,19 @@ def _run_mlx( caption=caption, lyrics=lyrics, cot_text=cot_text, + progress_callback=item_progress_callback, ) output_texts.append(output_text) + if ( + progress_callback is not None + and item_progress_state["total"] > 0 + and item_progress_state["current"] < item_progress_state["total"] + ): + progress_callback( + (i + 1) * item_progress_state["total"], + batch_size * item_progress_state["total"], + item_progress_state["desc"], + ) return output_texts # Single mode @@ -4044,6 +4146,7 @@ def _run_mlx( caption=caption, lyrics=lyrics, cot_text=cot_text, + progress_callback=progress_callback, ) # ========================================================================= diff --git a/acestep/llm_inference_batch_progress_test.py b/acestep/llm_inference_batch_progress_test.py new file mode 100644 index 000000000..680ab5bb7 --- /dev/null +++ b/acestep/llm_inference_batch_progress_test.py @@ -0,0 +1,105 @@ +"""Regression tests for sequential batch progress closure.""" + +import sys +import types +import unittest +from contextlib import nullcontext +from unittest.mock import patch + +try: + from acestep.llm_inference import LLMHandler + + _IMPORT_ERROR = None +except ImportError as exc: # pragma: no cover - dependency guard + LLMHandler = None + _IMPORT_ERROR = exc + + +def _make_handler() -> "LLMHandler": + """Return a minimal handler with model-loading context stubbed out.""" + handler = LLMHandler() + handler._load_model_context = lambda: nullcontext() + return handler + + +@unittest.skipIf(LLMHandler is None, f"llm_inference import unavailable: {_IMPORT_ERROR}") +class SequentialBatchProgressTests(unittest.TestCase): + """Sequential batch generation should close each item before the next begins.""" + + def test_run_pt_closes_early_item_progress_before_next_item(self): + handler = _make_handler() + updates = [] + + def fake_run_pt_single(*args, **kwargs): + progress_callback = kwargs["progress_callback"] + if len(updates) == 0: + progress_callback(2, 10, "LLM CFG Generation") + return "first" + progress_callback(1, 10, "LLM CFG Generation") + return "second" + + with patch.object(handler, "_run_pt_single", side_effect=fake_run_pt_single): + out = handler._run_pt( + formatted_prompts=["p1", "p2"], + temperature=0.6, + cfg_scale=1.0, + negative_prompt="", + top_k=None, + top_p=None, + repetition_penalty=1.0, + progress_callback=lambda current, total, desc: updates.append((current, total, desc)), + ) + + self.assertEqual(out, ["first", "second"]) + self.assertEqual( + updates, + [ + (2, 20, "LLM CFG Generation"), + (10, 20, "LLM CFG Generation"), + (11, 20, "LLM CFG Generation"), + (20, 20, "LLM CFG Generation"), + ], + ) + + def test_run_mlx_closes_early_item_progress_before_next_item(self): + handler = _make_handler() + updates = [] + fake_mlx = types.ModuleType("mlx") + fake_mx_core = types.ModuleType("mlx.core") + fake_mx_core.random = types.SimpleNamespace(seed=lambda *_: None) + + def fake_run_mlx_single(*args, **kwargs): + progress_callback = kwargs["progress_callback"] + if len(updates) == 0: + progress_callback(3, 12, "LLM CFG Generation") + return "first" + progress_callback(2, 12, "LLM CFG Generation") + return "second" + + with patch.dict(sys.modules, {"mlx": fake_mlx, "mlx.core": fake_mx_core}): + with patch.object(handler, "_run_mlx_single", side_effect=fake_run_mlx_single): + out = handler._run_mlx( + formatted_prompts=["p1", "p2"], + temperature=0.6, + cfg_scale=1.0, + negative_prompt="", + top_k=None, + top_p=None, + repetition_penalty=1.0, + progress_callback=lambda current, total, desc: updates.append((current, total, desc)), + ) + + self.assertEqual(out, ["first", "second"]) + self.assertEqual( + updates, + [ + (3, 24, "LLM CFG Generation"), + (12, 24, "LLM CFG Generation"), + (14, 24, "LLM CFG Generation"), + (24, 24, "LLM CFG Generation"), + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/llm_inference_cfg_fixes_test.py b/acestep/llm_inference_cfg_fixes_test.py index 60893a814..598fcb4f7 100644 --- a/acestep/llm_inference_cfg_fixes_test.py +++ b/acestep/llm_inference_cfg_fixes_test.py @@ -77,9 +77,10 @@ class TestCotCfgScaleFixed(unittest.TestCase): def test_cot_phase_uses_cfg_scale_1(self): """generate_from_formatted_prompt called during CoT must receive cfg_scale=1.0.""" - handler = LLMHandler() + handler = _make_handler() handler.llm_initialized = True handler.llm_backend = "pt" + handler.llm = MagicMock() captured_cfg = {} @@ -88,34 +89,23 @@ def fake_run_pt(formatted_prompts, temperature, cfg_scale, **kwargs): return "metadata" with patch.object(handler, "_run_pt", side_effect=fake_run_pt): - with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"): - with patch.object( - handler, - "_format_metadata_as_cot", - return_value="", - ): - with patch.object( - handler, - "build_formatted_prompt_with_cot", - return_value="PROMPT_WITH_COT", - ): - # Simulate Phase 1 CoT call via generate_from_formatted_prompt - # by calling the internal helper directly - handler.generate_from_formatted_prompt( - formatted_prompt="PROMPT", - cfg={ - "temperature": 0.6, - "cfg_scale": 1.0, # already 1.0 for CoT - "negative_prompt": "NO USER INPUT", - "top_k": None, - "top_p": None, - "repetition_penalty": 1.0, - "target_duration": None, - "generation_phase": "cot", - "caption": "test", - "lyrics": "test", - }, - ) + # Simulate Phase 1 CoT call via generate_from_formatted_prompt by + # invoking the helper directly with a current-valid handler state. + handler.generate_from_formatted_prompt( + formatted_prompt="PROMPT", + cfg={ + "temperature": 0.6, + "cfg_scale": 1.0, # already 1.0 for CoT + "negative_prompt": "NO USER INPUT", + "top_k": None, + "top_p": None, + "repetition_penalty": 1.0, + "target_duration": None, + "generation_phase": "cot", + "caption": "test", + "lyrics": "test", + }, + ) # cfg_scale must be 1.0 for CoT – captured during _run_pt call self.assertEqual(captured_cfg.get("cfg_scale"), 1.0) @@ -136,7 +126,7 @@ def capturing_gen(formatted_prompt, cfg=None, **kwargs): with patch.object(handler, "generate_from_formatted_prompt", side_effect=capturing_gen): with patch.object(handler, "build_formatted_prompt", return_value="P"): - with patch.object(handler, "_parse_metadata_from_cot", return_value={}): + with patch.object(handler, "parse_lm_output", return_value=({}, "")): with patch.object(handler, "_format_metadata_as_cot", return_value=""): with patch.object( handler, "build_formatted_prompt_with_cot", return_value="P2" @@ -471,5 +461,118 @@ def fake_call(input_ids, **kwargs): self.assertEqual(result[0, -1].item(), eos_id) +@unittest.skipIf(LLMHandler is None, f"llm_inference import unavailable: {_IMPORT_ERROR}") +class TestLmProgressCallbacks(unittest.TestCase): + """Progress callback coverage for LM token generation and phase mapping.""" + + def test_generate_with_cfg_custom_emits_token_progress(self): + """CFG token loop should emit monotonic per-token progress callbacks.""" + handler = _make_handler() + eos_id = 1 + vocab_size = 10 + handler.llm_tokenizer.eos_token_id = eos_id + + model = MagicMock() + logits = torch.full((2, 1, vocab_size), -100.0) + logits[:, :, 2] = 100.0 + outputs = SimpleNamespace(logits=logits, past_key_values=None) + model.return_value = outputs + model.generation_config = MagicMock() + model.generation_config.use_cache = False + handler.llm = model + + updates = [] + handler._generate_with_cfg_custom( + batch_input_ids=torch.zeros((2, 3), dtype=torch.long), + batch_attention_mask=None, + max_new_tokens=2, + temperature=1.0, + cfg_scale=1.0, + top_k=None, + top_p=None, + repetition_penalty=1.0, + pad_token_id=eos_id, + streamer=None, + progress_callback=lambda current, total, desc: updates.append((current, total, desc)), + ) + + self.assertEqual(updates, [(1, 2, "LLM CFG Generation"), (2, 2, "LLM CFG Generation")]) + + def test_generate_with_stop_condition_maps_phase_progress_to_outer_callback(self): + """Phase callbacks should surface token progress onto the outer progress callback.""" + handler = LLMHandler() + handler.llm_initialized = True + handler.llm_backend = "pt" + + updates = [] + + def fake_generate_from_formatted_prompt(*args, **kwargs): + progress_callback = kwargs.get("progress_callback") + if progress_callback is not None: + progress_callback(5, 10, "LLM CFG Generation") + progress_callback(10, 10, "LLM CFG Generation") + return "bpm: 120\n", "ok" + + with patch.object(handler, "generate_from_formatted_prompt", side_effect=fake_generate_from_formatted_prompt): + with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"): + with patch.object(handler, "parse_lm_output", return_value=({"bpm": 120}, "")): + handler.generate_with_stop_condition( + caption="test caption", + lyrics="test lyrics", + cfg_scale=2.0, + temperature=0.6, + negative_prompt="", + top_k=None, + top_p=None, + repetition_penalty=1.0, + infer_type="dit", + progress=lambda value, desc=None: updates.append((value, desc)), + ) + + llm_cfg_updates = [value for value, desc in updates if desc == "LLM CFG Generation"] + self.assertTrue(llm_cfg_updates) + self.assertGreaterEqual(min(llm_cfg_updates), 0.1) + self.assertLessEqual(max(llm_cfg_updates), 0.3) + self.assertIn((0.3, "LLM metadata generation complete"), updates) + + def test_generate_with_stop_condition_closes_audio_code_phase_band(self): + """Phase 2 should emit an explicit terminal progress update on early success.""" + handler = LLMHandler() + handler.llm_initialized = True + handler.llm_backend = "pt" + + updates = [] + call_count = {"count": 0} + + def fake_generate_from_formatted_prompt(*args, **kwargs): + call_count["count"] += 1 + progress_callback = kwargs.get("progress_callback") + if progress_callback is not None: + progress_callback(5, 10, "LLM CFG Generation") + if call_count["count"] == 1: + return "bpm: 120\n", "ok" + return "<|audio_code_1|><|audio_code_2|>", "ok" + + with patch.object(handler, "generate_from_formatted_prompt", side_effect=fake_generate_from_formatted_prompt): + with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"): + with patch.object(handler, "parse_lm_output", side_effect=[({"bpm": 120}, ""), ({}, "<|audio_code_1|><|audio_code_2|>")]): + with patch.object(handler, "_format_metadata_as_cot", return_value="cot"): + with patch.object(handler, "build_formatted_prompt_with_cot", return_value="PROMPT2"): + handler.generate_with_stop_condition( + caption="test caption", + lyrics="test lyrics", + cfg_scale=2.0, + temperature=0.6, + negative_prompt="", + top_k=None, + top_p=None, + repetition_penalty=1.0, + infer_type="llm_dit", + progress=lambda value, desc=None: updates.append((value, desc)), + ) + + self.assertIn((0.5, "LLM audio code generation complete"), updates) + + if __name__ == "__main__": unittest.main() diff --git a/acestep/models/base/modeling_acestep_v15_base.py b/acestep/models/base/modeling_acestep_v15_base.py index 47d515493..3ef73321a 100644 --- a/acestep/models/base/modeling_acestep_v15_base.py +++ b/acestep/models/base/modeling_acestep_v15_base.py @@ -1854,6 +1854,7 @@ def generate_audio( precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None, audio_codes: Optional[torch.FloatTensor] = None, use_progress_bar: bool = True, + progress_callback: Optional[Callable[[int, int, str], None]] = None, use_adg: bool = False, shift: float = 1.0, cover_noise_strength: float = 0.0, @@ -2129,6 +2130,8 @@ def generate_audio( xt = _repaint_step_injection( xt, clean_src_latents, repaint_mask, t_after_step, noise, ) + if progress_callback is not None: + progress_callback(step_idx + 1, infer_steps, "DiT diffusion...") x_gen = xt if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0: diff --git a/acestep/models/mlx/dit_generate.py b/acestep/models/mlx/dit_generate.py index 46253e961..def0caf24 100644 --- a/acestep/models/mlx/dit_generate.py +++ b/acestep/models/mlx/dit_generate.py @@ -19,7 +19,7 @@ import logging import time -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np from tqdm import tqdm @@ -150,6 +150,7 @@ def mlx_generate_diffusion( encoder_hidden_states_non_cover_np: Optional[np.ndarray] = None, context_latents_non_cover_np: Optional[np.ndarray] = None, compile_model: bool = False, + progress_callback: Optional[Callable[[int, int, str], None]] = None, disable_tqdm: bool = False, sampler_mode: str = "euler", velocity_norm_threshold: float = 0.0, @@ -179,6 +180,8 @@ def mlx_generate_diffusion( encoder_hidden_states_non_cover_np: optional [B, enc_L, D] for non-cover. context_latents_non_cover_np: optional [B, T, C] for non-cover. compile_model: If True, compile the decoder step with ``mx.compile``. + progress_callback: Optional diffusion-step callback receiving + ``(current_step, total_steps, desc)``. disable_tqdm: If True, suppress the diffusion progress bar. sampler_mode: Sampler algorithm — ``"euler"`` (first-order, default) or ``"heun"`` (second-order predictor-corrector for cleaner output). @@ -401,6 +404,8 @@ def _apply_stabilisation(vt_guided, xt_current, prev_velocity): xt = xt - vt * dt_arr mx.eval(xt) + if progress_callback is not None: + progress_callback(step_idx + 1, num_steps, "DiT diffusion...") prev_vt = vt # store for EMA diff --git a/acestep/models/mlx/dit_generate_test.py b/acestep/models/mlx/dit_generate_test.py new file mode 100644 index 000000000..fe01789d8 --- /dev/null +++ b/acestep/models/mlx/dit_generate_test.py @@ -0,0 +1,90 @@ +import sys +import types +import unittest +from unittest.mock import patch + +import numpy as np + +from acestep.models.mlx.dit_generate import mlx_generate_diffusion + + +class _FakeRandom: + def normal(self, shape, key=None): + _ = key + return np.zeros(shape, dtype=np.float32) + + def key(self, seed): + return seed + + +class _FakeDecoder: + def __call__( + self, + hidden_states, + timestep, + timestep_r, + encoder_hidden_states, + context_latents, + cache=None, + use_cache=False, + ): + _ = timestep + _ = timestep_r + _ = encoder_hidden_states + _ = context_latents + _ = use_cache + return np.zeros_like(hidden_states), cache + + +class MlxGenerateDiffusionProgressTests(unittest.TestCase): + def test_progress_callback_fires_once_per_diffusion_step(self): + fake_mx_core = types.ModuleType("mlx.core") + fake_mx_core.array = lambda value: np.array(value, dtype=np.float32) + fake_mx_core.concatenate = lambda values, axis=0: np.concatenate(values, axis=axis) + fake_mx_core.broadcast_to = lambda value, shape: np.broadcast_to(value, shape) + fake_mx_core.full = lambda shape, fill_value: np.full(shape, fill_value, dtype=np.float32) + fake_mx_core.eval = lambda value: value + fake_mx_core.random = _FakeRandom() + + fake_mlx_pkg = types.ModuleType("mlx") + fake_mlx_pkg.core = fake_mx_core + + fake_dit_model = types.ModuleType("acestep.models.mlx.dit_model") + + class _FakeCache: + pass + + fake_dit_model.MLXCrossAttentionCache = _FakeCache + + updates = [] + with patch.dict( + sys.modules, + { + "mlx": fake_mlx_pkg, + "mlx.core": fake_mx_core, + "acestep.models.mlx.dit_model": fake_dit_model, + }, + ): + result = mlx_generate_diffusion( + mlx_decoder=_FakeDecoder(), + encoder_hidden_states_np=np.zeros((1, 2, 3), dtype=np.float32), + context_latents_np=np.zeros((1, 2, 3), dtype=np.float32), + src_latents_shape=(1, 2, 3), + timesteps=[1.0, 0.75, 0.5], + progress_callback=lambda step, total, desc: updates.append((step, total, desc)), + disable_tqdm=True, + ) + + self.assertEqual( + updates, + [ + (1, 3, "DiT diffusion..."), + (2, 3, "DiT diffusion..."), + (3, 3, "DiT diffusion..."), + ], + ) + self.assertEqual(tuple(result["target_latents"].shape), (1, 2, 3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/models/sft/modeling_acestep_v15_base.py b/acestep/models/sft/modeling_acestep_v15_base.py index 3310a6792..280379d0f 100644 --- a/acestep/models/sft/modeling_acestep_v15_base.py +++ b/acestep/models/sft/modeling_acestep_v15_base.py @@ -1854,6 +1854,7 @@ def generate_audio( precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None, audio_codes: Optional[torch.FloatTensor] = None, use_progress_bar: bool = True, + progress_callback: Optional[Callable[[int, int, str], None]] = None, use_adg: bool = False, shift: float = 1.0, timesteps: Optional[torch.Tensor] = None, @@ -2133,6 +2134,8 @@ def generate_audio( xt = _repaint_step_injection( xt, clean_src_latents, repaint_mask, t_after_step, noise, ) + if progress_callback is not None: + progress_callback(step_idx + 1, infer_steps, "DiT diffusion...") x_gen = xt if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0: diff --git a/acestep/models/turbo/modeling_acestep_v15_turbo.py b/acestep/models/turbo/modeling_acestep_v15_turbo.py index 5ec0bdbb6..510c09688 100644 --- a/acestep/models/turbo/modeling_acestep_v15_turbo.py +++ b/acestep/models/turbo/modeling_acestep_v15_turbo.py @@ -1844,6 +1844,8 @@ def generate_audio( non_cover_text_attention_mask: Optional[torch.FloatTensor] = None, precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None, audio_codes: Optional[torch.FloatTensor] = None, + use_progress_bar: bool = True, + progress_callback: Optional[Callable[[int, int, str], None]] = None, shift: float = 3.0, timesteps: Optional[torch.Tensor] = None, cover_noise_strength: float = 0.0, @@ -2045,6 +2047,8 @@ def generate_audio( # On final step, directly compute x0 from noise if step_idx == num_steps - 1: xt = self.get_x0_from_noise(xt, vt, t_curr_tensor) + if progress_callback is not None: + progress_callback(step_idx + 1, num_steps, "DiT diffusion...") prev_vt = vt break @@ -2104,6 +2108,8 @@ def generate_audio( xt = _repaint_step_injection( xt, clean_src_latents, repaint_mask, t_after_step, noise, ) + if progress_callback is not None: + progress_callback(step_idx + 1, num_steps, "DiT diffusion...") x_gen = xt if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0: @@ -2272,4 +2278,4 @@ def test_forward(model, seed=42): # model = model.float() model = model.to("cuda") model = model.bfloat16() - test_forward(model) \ No newline at end of file + test_forward(model) diff --git a/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py b/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py new file mode 100644 index 000000000..28214fbb4 --- /dev/null +++ b/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py @@ -0,0 +1,90 @@ +import unittest + +import torch + +from acestep.models.turbo.modeling_acestep_v15_turbo import AceStepConditionGenerationModel + + +class _FakeDecoder: + def __call__( + self, + hidden_states, + timestep, + timestep_r, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + context_latents, + use_cache, + past_key_values, + ): + _ = timestep + _ = timestep_r + _ = attention_mask + _ = encoder_hidden_states + _ = encoder_attention_mask + _ = context_latents + _ = use_cache + return torch.zeros_like(hidden_states), past_key_values + + +class _FakeTurboHost: + def __init__(self): + self.decoder = _FakeDecoder() + + def prepare_condition(self, **kwargs): + src_latents = kwargs["src_latents"] + attention_mask = kwargs["attention_mask"] + return src_latents, attention_mask, src_latents + + def prepare_noise(self, context_latents, seed): + _ = seed + return torch.zeros_like(context_latents) + + def get_x0_from_noise(self, zt, vt, t): + return zt - vt * t.unsqueeze(-1).unsqueeze(-1) + + def renoise(self, x, t, noise=None): + _ = t + _ = noise + return x + + +class AceStepTurboProgressTests(unittest.TestCase): + def test_progress_callback_fires_once_per_step_including_final_step(self): + host = _FakeTurboHost() + updates = [] + base = torch.zeros((1, 2, 2), dtype=torch.float32) + mask = torch.ones((1, 2), dtype=torch.float32) + cover_mask = torch.zeros((1,), dtype=torch.float32) + + AceStepConditionGenerationModel.generate_audio( + host, + text_hidden_states=base, + text_attention_mask=mask, + lyric_hidden_states=base, + lyric_attention_mask=mask, + refer_audio_acoustic_hidden_states_packed=base, + refer_audio_order_mask=mask, + src_latents=base, + chunk_masks=base, + is_covers=cover_mask, + silence_latent=base, + attention_mask=mask, + seed=0, + infer_method="ode", + timesteps=[1.0, 0.5], + progress_callback=lambda step, total, desc: updates.append((step, total, desc)), + ) + + self.assertEqual( + updates, + [ + (1, 2, "DiT diffusion..."), + (2, 2, "DiT diffusion..."), + ], + ) + + +if __name__ == "__main__": + unittest.main()