From f0744812eeb7cb4787828cc73ef1139643113080 Mon Sep 17 00:00:00 2001
From: 1larity
Date: Sat, 28 Mar 2026 09:50:18 +0000
Subject: [PATCH 1/6] feat(gradio): surface live generation progress across
 LM, diffusion, and decode

---
 acestep/core/generation/handler/diffusion.py  |   4 +
 .../handler/generate_music_decode.py          |  24 +++-
 .../handler/generate_music_decode_test.py     |  13 +-
 .../handler/generate_music_execute.py         |  87 +++++++++++-
 .../handler/generate_music_execute_test.py    | 104 ++++++++++++++-
 .../handler/mlx_vae_decode_native.py          |  24 +++-
 .../generation/handler/service_generate.py    |   6 +-
 .../handler/service_generate_execute.py       |   7 +-
 .../handler/service_generate_execute_test.py  |  41 ++++++
 .../handler/service_generate_test.py          |   3 +
 acestep/core/generation/handler/vae_decode.py |  13 +-
 .../generation/handler/vae_decode_chunks.py   |  64 ++++++++-
 .../handler/vae_decode_mixin_test.py          |   3 +-
 .../handler/vae_decode_test_helpers.py        |   6 +-
 acestep/llm_inference.py                      |  79 ++++++++++-
 acestep/llm_inference_cfg_fixes_test.py       | 124 +++++++++++++-----
 .../models/base/modeling_acestep_v15_base.py  |   3 +
 acestep/models/mlx/dit_generate.py            |   7 +-
 .../models/sft/modeling_acestep_v15_base.py   |   3 +
 .../turbo/modeling_acestep_v15_turbo.py       |   8 +-
 20 files changed, 557 insertions(+), 66 deletions(-)

diff --git a/acestep/core/generation/handler/diffusion.py b/acestep/core/generation/handler/diffusion.py
index 66cf067f1..cf7b14e8e 100644
--- a/acestep/core/generation/handler/diffusion.py
+++ b/acestep/core/generation/handler/diffusion.py
@@ -34,6 +34,7 @@ def _mlx_run_diffusion(
         encoder_hidden_states_non_cover=None,
         encoder_attention_mask_non_cover=None,
         context_latents_non_cover=None,
+        progress_callback=None,
         disable_tqdm: bool = False,
     ) -> Dict[str, Any]:
         """Run the MLX diffusion loop and return generated latents.
@@ -56,6 +57,8 @@ def _mlx_run_diffusion(
             encoder_hidden_states_non_cover: Optional non-cover conditioning tensor.
             encoder_attention_mask_non_cover: Unused; accepted for API compatibility.
             context_latents_non_cover: Optional non-cover context latent tensor.
+            progress_callback: Optional diffusion-step callback receiving
+                ``(current_step, total_steps, desc)``.
             disable_tqdm: If True, suppress the diffusion progress bar.

         Returns:
@@ -135,6 +138,7 @@ def _mlx_run_diffusion(
             encoder_hidden_states_non_cover_np=enc_nc_np,
             context_latents_non_cover_np=ctx_nc_np,
             compile_model=getattr(self, "mlx_dit_compiled", False),
+            progress_callback=progress_callback,
             disable_tqdm=disable_tqdm,
         )

diff --git a/acestep/core/generation/handler/generate_music_decode.py b/acestep/core/generation/handler/generate_music_decode.py
index 1ec9b0fb3..4b10b42dc 100644
--- a/acestep/core/generation/handler/generate_music_decode.py
+++ b/acestep/core/generation/handler/generate_music_decode.py
@@ -114,8 +114,19 @@ def _decode_generate_music_pred_latents(
         Returns:
             Tuple of decoded waveforms, CPU latents, and updated time-cost payload.
         """
+        def _emit_decode_progress(value: float, desc: str) -> None:
+            if progress:
+                progress(value, desc=desc)
+
+        def _decode_chunk_progress(current: int, total: int, desc: str) -> None:
+            if total <= 0:
+                return
+            frac = min(1.0, max(0.0, current / total))
+            mapped = 0.8 + 0.18 * frac
+            _emit_decode_progress(mapped, desc)
+
         if progress:
-            progress(0.8, desc="Decoding audio...")
+            _emit_decode_progress(0.8, "Decoding audio...")
         logger.info("[generate_music] Decoding latents with VAE...")
         start_time = time.time()
         with torch.inference_mode():
@@ -162,10 +173,16 @@ def _decode_generate_music_pred_latents(
         try:
             if use_tiled_decode:
                 logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
-                pred_wavs = self.tiled_decode(pred_latents_for_decode)
+                pred_wavs = self.tiled_decode(
+                    pred_latents_for_decode,
+                    progress_callback=_decode_chunk_progress,
+                )
             elif using_mlx_vae:
                 try:
-                    pred_wavs = self._mlx_vae_decode(pred_latents_for_decode)
+                    pred_wavs = self._mlx_vae_decode(
+                        pred_latents_for_decode,
+                        progress_callback=_decode_chunk_progress,
+                    )
                 except Exception as exc:
                     logger.warning(
                         f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch"
@@ -194,6 +211,7 @@ def _decode_generate_music_pred_latents(
         if torch.any(peak > 1.0):
             pred_wavs = pred_wavs / peak.clamp(min=1.0)
         self._empty_cache()
+        _emit_decode_progress(0.98, "Decoding audio chunks...")
         gc.collect()
         self._empty_cache()
         end_time = time.time()
diff --git a/acestep/core/generation/handler/generate_music_decode_test.py b/acestep/core/generation/handler/generate_music_decode_test.py
index ece612ff1..9eef2d19d 100644
--- a/acestep/core/generation/handler/generate_music_decode_test.py
+++ b/acestep/core/generation/handler/generate_music_decode_test.py
@@ -115,14 +115,20 @@ def _max_memory_allocated(self):
         """Return deterministic max-memory value for debug logging."""
         return 0.0

-    def _mlx_vae_decode(self, latents):
+    def _mlx_vae_decode(self, latents, progress_callback=None):
         """Return deterministic decoded waveform for MLX decode branch."""
         _ = latents
+        if progress_callback is not None:
+            progress_callback(1, 4, "Decoding audio chunks...")
+            progress_callback(4, 4, "Decoding audio chunks...")
         return torch.ones(1, 2, 8)

-    def tiled_decode(self, latents):
+    def tiled_decode(self, latents, progress_callback=None):
         """Return deterministic decoded waveform for tiled decode branch."""
         _ = latents
+        if progress_callback is not None:
+            progress_callback(1, 4, "Decoding audio chunks...")
+            progress_callback(4, 4, "Decoding audio chunks...")
         return torch.ones(1, 2, 8)


@@ -190,6 +196,8 @@ def _progress(value, desc=None):
         self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6)
         self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6)
         self.assertEqual(host.progress_calls[0][0], 0.8)
+        self.assertTrue(any(desc == "Decoding audio chunks..." for _, desc in host.progress_calls))
+        self.assertAlmostEqual(host.progress_calls[-1][0], 0.98, places=6)

     def test_decode_pred_latents_restores_vae_device_on_decode_error(self):
         """It restores VAE device in the CPU-offload path even when decode raises."""
@@ -313,4 +321,3 @@ def __init__(self):

 if __name__ == "__main__":
     unittest.main()
-
diff --git a/acestep/core/generation/handler/generate_music_execute.py b/acestep/core/generation/handler/generate_music_execute.py
index a9059699e..2b34ea651 100644
--- a/acestep/core/generation/handler/generate_music_execute.py
+++ b/acestep/core/generation/handler/generate_music_execute.py
@@ -1,8 +1,10 @@
 """Execution helper for ``generate_music`` service invocation with progress tracking."""

 import os
+import queue
 import threading
-from typing import Any, Dict, List, Optional, Sequence
+import time
+from typing import Any, Callable, Dict, List, Optional, Sequence

 from loguru import logger

@@ -15,6 +17,28 @@
 class GenerateMusicExecuteMixin:
     """Run service generation under diffusion progress estimation lifecycle."""

+    @staticmethod
+    def _drain_runtime_progress_events(
+        progress_events: "queue.Queue[tuple[int, int, str]]",
+        emit_progress: Callable[[float, Optional[str]], float],
+        start: float,
+        end: float,
+    ) -> bool:
+        """Drain queued diffusion-step events and map them onto UI progress."""
+        drained = False
+        while True:
+            try:
+                current, total, desc = progress_events.get_nowait()
+            except queue.Empty:
+                return drained
+
+            if total <= 0:
+                continue
+            frac = min(1.0, max(0.0, current / total))
+            mapped = start + (end - start) * frac
+            emit_progress(mapped, desc)
+            drained = True
+
     def _run_generate_music_service_with_progress(
         self,
         progress: Any,
@@ -44,9 +68,24 @@ def _run_generate_music_service_with_progress(
         """
         infer_steps_for_progress = len(timesteps) if timesteps else inference_steps
         progress_desc = f"Generating music (batch size: {actual_batch_size})..."
-        progress(0.52, desc=progress_desc)
         stop_event = None
         progress_thread = None
+        progress_events: "queue.Queue[tuple[int, int, str]]" = queue.Queue()
+        progress_state = {"value": 0.0}
+        progress_lock = threading.Lock()
+
+        def _emit_progress(value: float, desc: Optional[str] = None) -> float:
+            with progress_lock:
+                clamped = max(progress_state["value"], value)
+                progress_state["value"] = clamped
+                if progress is not None:
+                    try:
+                        progress(clamped, desc=desc)
+                    except Exception as exc:
+                        logger.debug("[generate_music] Ignoring progress callback error: {}", exc)
+                return clamped
+
+        _emit_progress(0.52, progress_desc)

         # --- Timeout-wrapped service_generate ---
         # Run the actual CUDA work in a child thread so we can join() with a
@@ -85,13 +124,14 @@ def _service_target():
                     chunk_mask_modes=service_inputs.get("chunk_mask_modes_batch"),
                     repaint_crossfade_frames=repaint_crossfade_frames,
                     repaint_injection_ratio=repaint_injection_ratio,
+                    progress_callback=lambda current, total, desc: progress_events.put((current, total, desc)),
                 )
             except Exception as exc:
                 _error["exc"] = exc

         try:
             stop_event, progress_thread = self._start_diffusion_progress_estimator(
-                progress=progress,
+                progress=_emit_progress,
                 start=0.52,
                 end=0.79,
                 infer_steps=infer_steps_for_progress,
@@ -106,7 +146,43 @@ def _service_target():
                 daemon=True,
             )
             gen_thread.start()
-            gen_thread.join(timeout=_DEFAULT_GENERATION_TIMEOUT)
+            deadline = time.monotonic() + _DEFAULT_GENERATION_TIMEOUT
+            poll_interval = 0.1
+            saw_runtime_progress = False
+            while gen_thread.is_alive():
+                remaining = deadline - time.monotonic()
+                if remaining <= 0:
+                    break
+                gen_thread.join(timeout=min(poll_interval, remaining))
+                drained_runtime_progress = self._drain_runtime_progress_events(
+                    progress_events=progress_events,
+                    emit_progress=_emit_progress,
+                    start=0.52,
+                    end=0.79,
+                )
+                if drained_runtime_progress and not saw_runtime_progress:
+                    saw_runtime_progress = True
+                    if stop_event is not None:
+                        stop_event.set()
+                    if progress_thread is not None:
+                        progress_thread.join(timeout=1.0)
+                        progress_thread = None
+                if not gen_thread.is_alive():
+                    break
+
+            drained_runtime_progress = self._drain_runtime_progress_events(
+                progress_events=progress_events,
+                emit_progress=_emit_progress,
+                start=0.52,
+                end=0.79,
+            )
+            if drained_runtime_progress and not saw_runtime_progress:
+                saw_runtime_progress = True
+                if stop_event is not None:
+                    stop_event.set()
+                if progress_thread is not None:
+                    progress_thread.join(timeout=1.0)
+                    progress_thread = None

             if gen_thread.is_alive():
                 logger.error(
@@ -122,6 +198,7 @@ def _service_target():
             )
             if "exc" in _error:
                 raise _error["exc"]
+            _emit_progress(0.79, progress_desc)

         finally:
             if stop_event is not None:
@@ -129,4 +206,4 @@ def _service_target():
             if progress_thread is not None:
                 progress_thread.join(timeout=1.0)

-        return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
\ No newline at end of file
+        return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
diff --git a/acestep/core/generation/handler/generate_music_execute_test.py b/acestep/core/generation/handler/generate_music_execute_test.py
index aead22214..50afc2d2e 100644
--- a/acestep/core/generation/handler/generate_music_execute_test.py
+++ b/acestep/core/generation/handler/generate_music_execute_test.py
@@ -1,6 +1,7 @@
 """Unit tests for ``generate_music`` execution helper mixin."""

 import unittest
+import time

 from acestep.core.generation.handler.generate_music_execute import GenerateMusicExecuteMixin

@@ -12,12 +13,19 @@ def __init__(self):
         """Capture calls for assertions."""
         self.started = False
         self.stopped = False
+        self.stop_calls = 0
         self.service_calls = 0
+        self.emit_runtime_progress = True
+        self.estimator_progress_values = []
+        self.service_delay_sec = 0.0

     def _start_diffusion_progress_estimator(self, **kwargs):
         """Return fake stop event/thread handles used by helper lifecycle."""
-        _ = kwargs
         self.started = True
+        progress = kwargs["progress"]
+        desc = kwargs["desc"]
+        for value in self.estimator_progress_values:
+            progress(value, desc=desc)

     class _Stop:
         """Minimal stop-event stand-in used by the test host."""
@@ -29,6 +37,7 @@ def __init__(self, host):
         def set(self):
             """Mark progress lifecycle as stopped."""
             self.host.stopped = True
+            self.host.stop_calls += 1

     class _Thread:
         """Minimal thread stand-in exposing a ``join`` method."""
@@ -41,8 +50,13 @@ def join(self, timeout=None):

     def service_generate(self, **kwargs):
         """Record service invocation and return minimal output payload."""
-        _ = kwargs
+        callback = kwargs.get("progress_callback")
         self.service_calls += 1
+        if self.service_delay_sec > 0:
+            time.sleep(self.service_delay_sec)
+        if callback is not None and self.emit_runtime_progress:
+            callback(1, 4, "DiT diffusion...")
+            callback(4, 4, "DiT diffusion...")
         return {"target_latents": "ok"}


@@ -52,8 +66,11 @@ class GenerateMusicExecuteMixinTests(unittest.TestCase):
     def test_run_service_with_progress_invokes_service_and_stops_estimator(self):
         """Helper should call service once and always stop progress estimator."""
         host = _Host()
+        host.emit_runtime_progress = False
+        host.estimator_progress_values = [0.63, 0.79]
+        updates = []
         out = host._run_generate_music_service_with_progress(
-            progress=lambda *args, **kwargs: None,
+            progress=lambda value, desc=None: updates.append((value, desc)),
             actual_batch_size=1,
             audio_duration=10.0,
             inference_steps=8,
@@ -83,8 +100,89 @@ def test_run_service_with_progress_invokes_service_and_stops_estimator(self):
         )
         self.assertTrue(host.started)
         self.assertTrue(host.stopped)
+        self.assertEqual(host.stop_calls, 1)
         self.assertEqual(host.service_calls, 1)
         self.assertEqual(out["outputs"]["target_latents"], "ok")
+        self.assertAlmostEqual(updates[-1][0], 0.79, places=6)
+
+    def test_runtime_progress_events_are_forwarded_to_ui_progress(self):
+        """Step-level runtime progress should stop the estimator and reach phase completion."""
+        host = _Host()
+        updates = []
+
+        host._run_generate_music_service_with_progress(
+            progress=lambda value, desc=None: updates.append((value, desc)),
+            actual_batch_size=1,
+            audio_duration=10.0,
+            inference_steps=8,
+            timesteps=None,
+            service_inputs={
+                "captions_batch": ["c"],
+                "lyrics_batch": ["l"],
+                "metas_batch": ["m"],
+                "vocal_languages_batch": ["en"],
+                "target_wavs_tensor": None,
+                "repainting_start_batch": [0.0],
+                "repainting_end_batch": [1.0],
+                "instructions_batch": ["i"],
+                "audio_code_hints_batch": None,
+                "should_return_intermediate": True,
+            },
+            refer_audios=None,
+            guidance_scale=7.0,
+            actual_seed_list=[1],
+            audio_cover_strength=1.0,
+            cover_noise_strength=0.0,
+            use_adg=False,
+            cfg_interval_start=0.0,
+            cfg_interval_end=1.0,
+            shift=1.0,
+            infer_method="ode",
+        )
+
+        self.assertTrue(any(desc == "DiT diffusion..." for _, desc in updates))
+        self.assertEqual(host.stop_calls, 2)
+        self.assertAlmostEqual(updates[-1][0], 0.79, places=6)
+
+    def test_runtime_progress_handoff_stays_monotonic_after_estimator_advances(self):
+        """Runtime callbacks should not drive the UI backwards after estimator progress."""
+        host = _Host()
+        host.estimator_progress_values = [0.68]
+        updates = []
+
+        host._run_generate_music_service_with_progress(
+            progress=lambda value, desc=None: updates.append((value, desc)),
+            actual_batch_size=1,
+            audio_duration=10.0,
+            inference_steps=8,
+            timesteps=None,
+            service_inputs={
+                "captions_batch": ["c"],
+                "lyrics_batch": ["l"],
+                "metas_batch": ["m"],
+                "vocal_languages_batch": ["en"],
+                "target_wavs_tensor": None,
+                "repainting_start_batch": [0.0],
+                "repainting_end_batch": [1.0],
+                "instructions_batch": ["i"],
+                "audio_code_hints_batch": None,
+                "should_return_intermediate": True,
+            },
+            refer_audios=None,
+            guidance_scale=7.0,
+            actual_seed_list=[1],
+            audio_cover_strength=1.0,
+            cover_noise_strength=0.0,
+            use_adg=False,
+            cfg_interval_start=0.0,
+            cfg_interval_end=1.0,
+            shift=1.0,
+            infer_method="ode",
+        )
+
+        progress_values = [value for value, _ in updates]
+        self.assertEqual(progress_values, sorted(progress_values))
+        self.assertGreaterEqual(progress_values[2], 0.68)


 if __name__ == "__main__":
diff --git a/acestep/core/generation/handler/mlx_vae_decode_native.py b/acestep/core/generation/handler/mlx_vae_decode_native.py
index 23b5c21ea..eb9cefc8f 100644
--- a/acestep/core/generation/handler/mlx_vae_decode_native.py
+++ b/acestep/core/generation/handler/mlx_vae_decode_native.py
@@ -28,7 +28,7 @@ def _resolve_mlx_decode_fn(self):
             raise RuntimeError("MLX VAE decode requested but mlx_vae is not initialized.")
         return self.mlx_vae.decode

-    def _mlx_vae_decode(self, latents_torch):
+    def _mlx_vae_decode(self, latents_torch, progress_callback=None):
         """Decode batched PyTorch latents using native MLX VAE decode.

         Args:
@@ -52,7 +52,18 @@ def _mlx_vae_decode(self, latents_torch):
         decode_fn = self._resolve_mlx_decode_fn()
         audio_parts = []
         for idx in range(batch_size):
-            decoded = self._mlx_decode_single(latents_mx[idx : idx + 1], decode_fn=decode_fn)
+            item_progress_callback = None
+            if progress_callback is not None:
+                def _item_progress(current, total, desc, offset=idx, batch_total=batch_size):
+                    progress_callback(offset * total + current, batch_total * total, desc)
+
+                item_progress_callback = _item_progress
+
+            decoded = self._mlx_decode_single(
+                latents_mx[idx : idx + 1],
+                decode_fn=decode_fn,
+                progress_callback=item_progress_callback,
+            )
             if decoded.dtype != mx.float32:
                 decoded = decoded.astype(mx.float32)
             mx.eval(decoded)
@@ -71,7 +82,7 @@ def _mlx_vae_decode(self, latents_torch):
         )
         return torch.from_numpy(audio_ncl)

-    def _mlx_decode_single(self, z_nlc, decode_fn=None):
+    def _mlx_decode_single(self, z_nlc, decode_fn=None, progress_callback=None):
         """Decode a single MLX latent sample with optional tiling.

         Args:
@@ -91,7 +102,10 @@ def _mlx_decode_single(self, z_nlc, decode_fn=None):
         mlx_overlap = 64

         if latent_frames <= mlx_chunk:
-            return decode_fn(z_nlc)
+            out = decode_fn(z_nlc)
+            if progress_callback is not None:
+                progress_callback(1, 1, "Decoding audio chunks...")
+            return out

         stride = mlx_chunk - 2 * mlx_overlap
         num_steps = math.ceil(latent_frames / stride)
@@ -115,5 +129,7 @@ def _mlx_decode_single(self, z_nlc, decode_fn=None):
             audio_len = audio_chunk.shape[1]
             end_idx = audio_len - trim_end if trim_end > 0 else audio_len
             decoded_parts.append(audio_chunk[:, trim_start:end_idx, :])
+            if progress_callback is not None:
+                progress_callback(idx + 1, num_steps, "Decoding audio chunks...")

         return mx.concatenate(decoded_parts, axis=1)
diff --git a/acestep/core/generation/handler/service_generate.py b/acestep/core/generation/handler/service_generate.py
index 7dc3a045a..00fc188c0 100644
--- a/acestep/core/generation/handler/service_generate.py
+++ b/acestep/core/generation/handler/service_generate.py
@@ -5,7 +5,7 @@
 and output attachment without owning model internals.
 """

-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch

@@ -47,6 +47,7 @@ def service_generate(
         chunk_mask_modes: Optional[List[str]] = None,
         repaint_crossfade_frames: int = 10,
         repaint_injection_ratio: float = 0.5,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> Dict[str, Any]:
         """Generate music latents and metadata from text/audio conditioning inputs.

@@ -76,6 +77,8 @@ def service_generate(
             timesteps: Optional explicit diffusion timestep sequence.
             repaint_crossfade_frames: Crossfade width (latent frames) at repaint
                 boundaries for boundary blending. ~0.4s at 25 Hz.
+            progress_callback: Optional diffusion-step callback receiving
+                ``(current_step, total_steps, desc)``.

         Returns:
             Dict[str, Any]: Service output payload containing generated latents,
@@ -134,6 +137,7 @@ def service_generate(
             timesteps=timesteps,
             repaint_crossfade_frames=repaint_crossfade_frames,
             repaint_injection_ratio=repaint_injection_ratio,
+            progress_callback=progress_callback,
         )
         outputs, encoder_hidden_states, encoder_attention_mask, context_latents = (
             self._execute_service_generate_diffusion(
diff --git a/acestep/core/generation/handler/service_generate_execute.py b/acestep/core/generation/handler/service_generate_execute.py
index 4ad6f6d31..836688246 100644
--- a/acestep/core/generation/handler/service_generate_execute.py
+++ b/acestep/core/generation/handler/service_generate_execute.py
@@ -1,7 +1,7 @@
 """Execution helpers for service generation diffusion and output assembly."""

 import random
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 from loguru import logger
@@ -77,6 +77,7 @@ def _build_service_generate_kwargs(
         timesteps: Optional[List[float]],
         repaint_crossfade_frames: int = 10,
         repaint_injection_ratio: float = 0.5,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> Dict[str, Any]:
         """Build kwargs passed to model generation backends."""
         repaint_mask = payload.get("repaint_mask")
@@ -106,6 +107,8 @@ def _build_service_generate_kwargs(
             "cfg_interval_start": cfg_interval_start,
             "cfg_interval_end": cfg_interval_end,
             "shift": shift,
+            "use_progress_bar": not getattr(self, "disable_tqdm", False),
+            "progress_callback": progress_callback,
             "repaint_mask": repaint_mask,
             "clean_src_latents": clean_src_latents,
             "repaint_crossfade_frames": repaint_crossfade_frames,
@@ -197,6 +200,8 @@ def _execute_service_generate_diffusion(
                 encoder_hidden_states_non_cover=enc_hs_nc,
                 encoder_attention_mask_non_cover=enc_am_nc,
                 context_latents_non_cover=ctx_nc,
+                progress_callback=generate_kwargs.get("progress_callback"),
+                disable_tqdm=not generate_kwargs.get("use_progress_bar", True),
             )
             _tc = outputs.get("time_costs", {})
             logger.info(
diff --git a/acestep/core/generation/handler/service_generate_execute_test.py b/acestep/core/generation/handler/service_generate_execute_test.py
index 8af115425..77c759a08 100644
--- a/acestep/core/generation/handler/service_generate_execute_test.py
+++ b/acestep/core/generation/handler/service_generate_execute_test.py
@@ -15,6 +15,7 @@ class _Host(ServiceGenerateExecuteMixin, ServiceGenerateOutputsMixin):
     def __init__(self):
         """Initialize static runtime fields for helper-method tests."""
         self.device = "cpu"
+        self.disable_tqdm = False
         self.silence_latent = torch.zeros(1, 4, 4, dtype=torch.float32)


@@ -58,6 +59,46 @@ def test_build_generate_kwargs_adds_timesteps_tensor(self):
         self.assertEqual(kwargs["infer_steps"], 16)
         self.assertEqual(kwargs["timesteps"].dtype, torch.float32)
         self.assertEqual(kwargs["timesteps"].device.type, "cpu")
+        self.assertTrue(kwargs["use_progress_bar"])
+        self.assertIsNone(kwargs["progress_callback"])
+
+    def test_build_generate_kwargs_forwards_runtime_progress_callback(self):
+        """Runtime callbacks should be forwarded into model-generation kwargs."""
+        host = _Host()
+        payload = {
+            "text_hidden_states": torch.zeros(1, 2),
+            "text_attention_mask": torch.ones(1, 2),
+            "lyric_hidden_states": torch.zeros(1, 2),
+            "lyric_attention_mask": torch.ones(1, 2),
+            "refer_audio_acoustic_hidden_states_packed": torch.zeros(1, 2),
+            "refer_audio_order_mask": torch.zeros(1, dtype=torch.long),
+            "src_latents": torch.zeros(1, 4, 4),
+            "chunk_mask": torch.ones(1, 4, dtype=torch.bool),
+            "is_covers": torch.tensor([True]),
+            "non_cover_text_hidden_states": None,
+            "non_cover_text_attention_masks": None,
+            "precomputed_lm_hints_25Hz": None,
+        }
+        callback = lambda current, total, desc: (current, total, desc)
+
+        kwargs = host._build_service_generate_kwargs(
+            payload=payload,
+            seed_param=123,
+            infer_steps=16,
+            guidance_scale=7.0,
+            audio_cover_strength=1.0,
+            cover_noise_strength=0.0,
+            infer_method="ode",
+            use_adg=False,
+            cfg_interval_start=0.0,
+            cfg_interval_end=1.0,
+            shift=1.0,
+            timesteps=None,
+            progress_callback=callback,
+        )
+
+        self.assertIs(kwargs["progress_callback"], callback)
+        self.assertTrue(kwargs["use_progress_bar"])

     def test_attach_service_outputs_persists_required_fields(self):
         """Attached payload fields should be available to downstream handlers."""
diff --git a/acestep/core/generation/handler/service_generate_test.py b/acestep/core/generation/handler/service_generate_test.py
index 6287010ad..4495c6176 100644
--- a/acestep/core/generation/handler/service_generate_test.py
+++ b/acestep/core/generation/handler/service_generate_test.py
@@ -128,10 +128,12 @@ def test_service_generate_forwards_runtime_controls_to_build_and_execute(self):
         """It forwards runtime tuning controls to downstream helper invocations."""
         host = _Host()
         custom_timesteps = [1.0, 0.5, 0.0]
+        progress_callback = lambda current, total, desc: (current, total, desc)
         host.service_generate(
             captions="cap", lyrics="lyr", guidance_scale=9.5, audio_cover_strength=0.7,
             cover_noise_strength=0.2, use_adg=True, cfg_interval_start=0.1,
             cfg_interval_end=0.9, shift=1.3, infer_method="sde", timesteps=custom_timesteps, return_intermediate=False,
+            progress_callback=progress_callback,
         )
         build_kwargs = host.calls["_build_service_generate_kwargs"]
         self.assertEqual(build_kwargs["guidance_scale"], 9.5)
@@ -139,6 +141,7 @@ def test_service_generate_forwards_runtime_controls_to_build_and_execute(self):
         self.assertEqual(build_kwargs["cover_noise_strength"], 0.2)
         self.assertTrue(build_kwargs["use_adg"])
         self.assertEqual(build_kwargs["timesteps"], custom_timesteps)
+        self.assertIs(build_kwargs["progress_callback"], progress_callback)
         self.assertFalse(host.calls["_attach_service_generate_outputs"]["return_intermediate"])
         execute_kwargs = host.calls["_execute_service_generate_diffusion"]
         self.assertEqual(execute_kwargs["infer_method"], "sde")
diff --git a/acestep/core/generation/handler/vae_decode.py b/acestep/core/generation/handler/vae_decode.py
index c3ea76174..03a00872a 100644
--- a/acestep/core/generation/handler/vae_decode.py
+++ b/acestep/core/generation/handler/vae_decode.py
@@ -19,6 +19,7 @@ def tiled_decode(
         chunk_size: Optional[int] = None,
         overlap: int = 64,
         offload_wav_to_cpu: Optional[bool] = None,
+        progress_callback=None,
     ):
         """Decode latents using tiling to reduce VRAM usage.

@@ -32,6 +33,8 @@ def tiled_decode(
             overlap: Overlap in latent frames between adjacent windows.
             offload_wav_to_cpu: Whether decoded waveform chunks should be
                 offloaded to CPU immediately to reduce VRAM pressure.
+            progress_callback: Optional callback receiving
+                ``(current_step, total_steps, desc)`` for chunk decode progress.

         Returns:
             Decoded waveform tensor shaped ``[batch, audio_channels, samples]``.
@@ -39,7 +42,7 @@ def tiled_decode(
         # ---- MLX fast path (macOS Apple Silicon) ----
         if self.use_mlx_vae and self.mlx_vae is not None:
             try:
-                result = self._mlx_vae_decode(latents)
+                result = self._mlx_vae_decode(latents, progress_callback=progress_callback)
                 return result
             except Exception as exc:
                 logger.warning(
@@ -74,7 +77,13 @@ def tiled_decode(
             overlap = min(overlap, _mps_overlap)

         try:
-            return self._tiled_decode_inner(latents, chunk_size, overlap, offload_wav_to_cpu)
+            return self._tiled_decode_inner(
+                latents,
+                chunk_size,
+                overlap,
+                offload_wav_to_cpu,
+                progress_callback=progress_callback,
+            )
         except (NotImplementedError, RuntimeError) as exc:
             if not _is_mps:
                 raise
diff --git a/acestep/core/generation/handler/vae_decode_chunks.py b/acestep/core/generation/handler/vae_decode_chunks.py
index 8564de9ad..108ddd7fc 100644
--- a/acestep/core/generation/handler/vae_decode_chunks.py
+++ b/acestep/core/generation/handler/vae_decode_chunks.py
@@ -10,7 +10,7 @@
 class VaeDecodeChunksMixin:
     """Implement chunked decode strategies for GPU and CPU-offload modes."""

-    def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
+    def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu, progress_callback=None):
         """Run tiled decode with adaptive overlap and OOM fallbacks."""
         bsz, _channels, latent_frames = latents.shape

@@ -20,7 +20,18 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
             per_sample_results = []
             for b_idx in range(bsz):
                 single = latents[b_idx : b_idx + 1]
-                decoded = self._tiled_decode_inner(single, chunk_size, overlap, offload_wav_to_cpu)
+                sample_progress = None
+                if progress_callback is not None:
+                    def _sample_progress(current, total, desc, offset=b_idx):
+                        progress_callback(offset * total + current, bsz * total, desc)
+                    sample_progress = _sample_progress
+                decoded = self._tiled_decode_inner(
+                    single,
+                    chunk_size,
+                    overlap,
+                    offload_wav_to_cpu,
+                    progress_callback=sample_progress,
+                )
                 per_sample_results.append(decoded.cpu() if decoded.device.type != "cpu" else decoded)
                 self._empty_cache()
             result = torch.cat(per_sample_results, dim=0)
@@ -46,6 +57,8 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
                 decoder_output = self.vae.decode(latents)
                 result = decoder_output.sample
                 del decoder_output
+                if progress_callback is not None:
+                    progress_callback(1, 1, "Decoding audio chunks...")
                 return result
             except torch.cuda.OutOfMemoryError:
                 logger.warning("[tiled_decode] OOM on direct decode, falling back to CPU VAE decode")
@@ -60,7 +73,15 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):

         if offload_wav_to_cpu:
             try:
-                return self._tiled_decode_offload_cpu(latents, bsz, latent_frames, stride, overlap, num_steps)
+                return self._tiled_decode_offload_cpu(
+                    latents,
+                    bsz,
+                    latent_frames,
+                    stride,
+                    overlap,
+                    num_steps,
+                    progress_callback=progress_callback,
+                )
             except torch.cuda.OutOfMemoryError:
                 logger.warning(
                     f"[tiled_decode] OOM during offload_cpu decode with chunk_size={chunk_size}, "
@@ -70,7 +91,13 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
             return self._decode_on_cpu(latents)

         try:
-            return self._tiled_decode_gpu(latents, stride, overlap, num_steps)
+            return self._tiled_decode_gpu(
+                latents,
+                stride,
+                overlap,
+                num_steps,
+                progress_callback=progress_callback,
+            )
         except torch.cuda.OutOfMemoryError:
             logger.warning(
                 f"[tiled_decode] OOM during GPU decode with chunk_size={chunk_size}, "
@@ -78,13 +105,21 @@ def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
             )
             self._empty_cache()
             try:
-                return self._tiled_decode_offload_cpu(latents, bsz, latent_frames, stride, overlap, num_steps)
+                return self._tiled_decode_offload_cpu(
+                    latents,
+                    bsz,
+                    latent_frames,
+                    stride,
+                    overlap,
+                    num_steps,
+                    progress_callback=progress_callback,
+                )
             except torch.cuda.OutOfMemoryError:
                 logger.warning("[tiled_decode] OOM even with offload path, falling back to full CPU VAE decode")
                 self._empty_cache()
                 return self._decode_on_cpu(latents)

-    def _tiled_decode_gpu(self, latents, stride, overlap, num_steps):
+    def _tiled_decode_gpu(self, latents, stride, overlap, num_steps, progress_callback=None):
         """Decode chunks and keep decoded audio tensors on GPU."""
         decoded_audio_list = []
         upsample_factor = None
@@ -112,10 +147,21 @@ def _tiled_decode_gpu(self, latents, stride, overlap, num_steps):
             end_idx = audio_len - trim_end if trim_end > 0 else audio_len
             audio_core = audio_chunk[:, :, trim_start:end_idx]
             decoded_audio_list.append(audio_core)
+            if progress_callback is not None:
+                progress_callback(i + 1, num_steps, "Decoding audio chunks...")

         return torch.cat(decoded_audio_list, dim=-1)

-    def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap, num_steps):
+    def _tiled_decode_offload_cpu(
+        self,
+        latents,
+        bsz,
+        latent_frames,
+        stride,
+        overlap,
+        num_steps,
+        progress_callback=None,
+    ):
         """Decode chunks on GPU and copy trimmed audio cores to a CPU buffer."""
         first_core_end = min(stride, latent_frames)
         first_win_end = min(latent_frames, first_core_end + overlap)
@@ -138,6 +184,8 @@ def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap
         first_audio_core = first_audio_chunk[:, :, :first_end_idx]
         audio_write_pos = first_audio_core.shape[-1]
         final_audio[:, :, :audio_write_pos] = first_audio_core.cpu()
+        if progress_callback is not None:
+            progress_callback(1, num_steps, "Decoding audio chunks...")

         del first_audio_chunk, first_audio_core, first_latent_chunk

@@ -164,6 +212,8 @@ def _tiled_decode_offload_cpu(self, latents, bsz, latent_frames, stride, overlap
             core_len = audio_core.shape[-1]
             final_audio[:, :, audio_write_pos : audio_write_pos + core_len] = audio_core.cpu()
             audio_write_pos += core_len
+            if progress_callback is not None:
+                progress_callback(i + 1, num_steps, "Decoding audio chunks...")

             del audio_chunk, audio_core, latent_chunk

diff --git a/acestep/core/generation/handler/vae_decode_mixin_test.py b/acestep/core/generation/handler/vae_decode_mixin_test.py
index e48030b42..13f91d7a9 100644
--- a/acestep/core/generation/handler/vae_decode_mixin_test.py
+++ b/acestep/core/generation/handler/vae_decode_mixin_test.py
@@ -46,8 +46,9 @@ def test_tiled_decode_falls_back_when_mlx_decode_fails(self):
         host.use_mlx_vae = True
         host.mlx_vae = object()

-        def _mlx_raise(_latents):
+        def _mlx_raise(_latents, progress_callback=None):
             """Raise MLX failure to exercise fallback path."""
+            _ = progress_callback
             raise ValueError("mlx failed")

         host._mlx_vae_decode = _mlx_raise
diff --git a/acestep/core/generation/handler/vae_decode_test_helpers.py b/acestep/core/generation/handler/vae_decode_test_helpers.py
index 4eccd3167..51ed26664 100644
--- a/acestep/core/generation/handler/vae_decode_test_helpers.py
+++ b/acestep/core/generation/handler/vae_decode_test_helpers.py
@@ -66,12 +66,13 @@ def _should_offload_wav_to_cpu(self):
         """Return deterministic offload policy used by default path."""
         return False

-    def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu):
+    def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu, progress_callback=None):
         """Record routed args and return sentinel audio tensor."""
         _ = latents
         self.recorded["chunk_size"] = chunk_size
         self.recorded["overlap"] = overlap
         self.recorded["offload"] = offload_wav_to_cpu
+        self.recorded["progress_callback"] = progress_callback
         return torch.ones(1, 2, 8)

     def _tiled_decode_cpu_fallback(self, latents):
@@ -79,9 +80,10 @@ def _tiled_decode_cpu_fallback(self, latents):
         _ = latents
         return torch.full((1, 2, 8), 2.0)

-    def _mlx_vae_decode(self, latents):
+    def _mlx_vae_decode(self, latents, progress_callback=None):
         """Return MLX sentinel tensor for MLX path assertions."""
         _ = latents
+        self.recorded["progress_callback"] = progress_callback
         return torch.full((1, 2, 6), 3.0)


diff --git a/acestep/llm_inference.py b/acestep/llm_inference.py
index 42d7b8ad8..47d0eec32 100644
--- a/acestep/llm_inference.py
+++ b/acestep/llm_inference.py
@@ -9,7 +9,7 @@
 import time
 import random
 import warnings
-from typing import Optional, Dict, Any, Tuple, List, Union
+from typing import Callable, Optional, Dict, Any, Tuple, List, Union
 from contextlib import contextmanager

 import yaml
@@ -874,6 +874,7 @@ def _run_vllm(
         lyrics: str = "",
         cot_text: str = "",
         seeds: Optional[List[int]] = None,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> Union[str, List[str]]:
         """
         Unified vllm generation function supporting both single and batch modes.
@@ -959,6 +960,8 @@ def _run_vllm(
                 output_texts.append(str(output))

         # Return single string for single mode, list for batch mode
+        if progress_callback is not None:
+            progress_callback(1, 1, "vLLM generation")
         return output_texts[0] if not is_batch else output_texts

     def _run_pt_single(
@@ -982,6 +985,7 @@ def _run_pt_single(
         caption: str,
         lyrics: str,
         cot_text: str,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> str:
         """Internal helper function for single-item PyTorch generation."""
         inputs = self.llm_tokenizer(
@@ -1060,6 +1064,7 @@ def _run_pt_single(
                 pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
                 streamer=None,
                 constrained_processor=constrained_processor,
+                progress_callback=progress_callback,
             )

             # Extract only the conditional output (first in batch)
@@ -1077,6 +1082,7 @@ def _run_pt_single(
                 pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
                 streamer=None,
                 constrained_processor=constrained_processor,
+                progress_callback=progress_callback,
             )
         else:
             # Generate without CFG using native generate() parameters
@@ -1143,6 +1149,7 @@ def _run_pt(
         lyrics: str = "",
         cot_text: str = "",
         seeds: Optional[List[int]] = None,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> Union[str, List[str]]:
         """
         Unified PyTorch generation function supporting both single and batch modes.
@@ -1169,6 +1176,13 @@ def _run_pt(
             elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                 torch.mps.manual_seed(seeds[i])

+            item_progress_callback = None
+            if progress_callback is not None:
+                def _item_progress(current, total, desc, index=i, batch_total=len(formatted_prompt_list)):
+                    progress_callback(index * total + current, batch_total * total, desc)
+
+                item_progress_callback = _item_progress
+
             # Generate using single-item method with batch-mode defaults
             output_text = self._run_pt_single(
                 formatted_prompt=formatted_prompt,
@@ -1190,6 +1204,7 @@ def _run_pt(
                 caption=caption,
                 lyrics=lyrics,
                 cot_text=cot_text,
+                progress_callback=item_progress_callback,
             )
             output_texts.append(output_text)

@@ -1219,6 +1234,7 @@ def _run_pt(
             caption=caption,
             lyrics=lyrics,
             cot_text=cot_text,
+            progress_callback=progress_callback,
         )

     def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
@@ -1340,6 +1356,20 @@ def progress(*args, **kwargs):
         else:
             seeds = seeds[:actual_batch_size]

+        def _make_phase_progress_callback(
+            start: float,
+            end: float,
+            default_desc: str,
+        ) -> Callable[[int, int, str], None]:
+            def _callback(current: int, total: int, desc: str) -> None:
+                if total <= 0:
+                    return
+                frac = min(1.0, max(0.0, current / total))
+                mapped = start + (end - start) * frac
+                progress(mapped, desc or default_desc)
+
+            return _callback
+
         # ========== PHASE 1: CoT Generation ==========
         # Skip CoT if all metadata are user-provided OR caption is already formatted
         progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...")
@@ -1380,6 +1410,11 @@ def progress(*args, **kwargs):
                     use_constrained_decoding=use_constrained_decoding,
                     constrained_decoding_debug=constrained_decoding_debug,
                     stop_at_reasoning=True,  # Always stop at </think> in Phase 1
+                    progress_callback=_make_phase_progress_callback(
+                        0.1,
+                        0.3,
+                        "LLM metadata generation",
+                    ),
                 )
                 phase1_time = time.time() - phase1_start

@@ -1468,7 +1503,12 @@ def progress(*args, **kwargs):
             formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
             logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")

-            progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
+            progress(0.31, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
+            phase2_progress_callback = _make_phase_progress_callback(
+                0.31,
+                0.5,
+                "LLM audio code generation",
+            )
             if is_batch:
                 # Batch mode: generate codes for all items
                 formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
@@ -1492,6 +1532,7 @@ def progress(*args, **kwargs):
                             lyrics=lyrics,
                             cot_text=cot_text,
                             seeds=seeds,
+                            progress_callback=phase2_progress_callback,
                         )
                     elif self.llm_backend == "mlx":
                         codes_outputs = self._run_mlx(
@@ -1510,6 +1551,7 @@ def progress(*args, **kwargs):
                             lyrics=lyrics,
                             cot_text=cot_text,
                             seeds=seeds,
+                            progress_callback=phase2_progress_callback,
                         )
                     else:  # pt backend
                         codes_outputs = self._run_pt(
@@ -1528,6 +1570,7 @@ def progress(*args, **kwargs):
                             lyrics=lyrics,
                             cot_text=cot_text,
                             seeds=seeds,
+                            progress_callback=phase2_progress_callback,
                         )
                 except Exception as e:
                     error_msg = f"Error in batch codes generation: {str(e)}"
@@ -1602,6 +1645,7 @@ def progress(*args, **kwargs):
                     use_constrained_decoding=use_constrained_decoding,
                     constrained_decoding_debug=constrained_decoding_debug,
                     stop_at_reasoning=False,  # Generate codes until EOS
+                    progress_callback=phase2_progress_callback,
                 )

                 if not codes_output_text:
@@ -2321,6 +2365,7 @@ def generate_from_formatted_prompt(
         use_constrained_decoding: bool = True,
         constrained_decoding_debug: bool = False,
         stop_at_reasoning: bool = False,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> Tuple[str, str]:
         """
         Generate raw LM text output from a pre-built formatted prompt.
@@ -2394,6 +2439,7 @@ def generate_from_formatted_prompt(
                 caption=caption,
                 lyrics=lyrics,
                 cot_text=cot_text,
+                progress_callback=progress_callback,
             )
             self._clear_accelerator_cache()
             return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
@@ -2420,6 +2466,7 @@ def generate_from_formatted_prompt(
                 caption=caption,
                 lyrics=lyrics,
                 cot_text=cot_text,
+                progress_callback=progress_callback,
             )
             self._clear_accelerator_cache()
             return output_text, f"✅ Generated successfully (mlx) | length={len(output_text)}"
@@ -2445,6 +2492,7 @@ def generate_from_formatted_prompt(
                 caption=caption,
                 lyrics=lyrics,
                 cot_text=cot_text,
+                progress_callback=progress_callback,
             )
             self._clear_accelerator_cache()
             return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
@@ -2486,6 +2534,7 @@ def _generate_with_constrained_decoding(
         pad_token_id: int,
         streamer: Optional[BaseStreamer],
         constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> torch.Tensor:
         """
         Custom generation loop with constrained decoding support (non-CFG).
@@ -2558,6 +2607,8 @@ def _generate_with_constrained_decoding(
             # Update streamer
             if streamer is not None:
                 streamer.put(next_tokens_unsqueezed)
+            if progress_callback is not None:
+                progress_callback(step + 1, max_new_tokens, "LLM Constrained Decoding")

             if should_stop:
                 break
@@ -2583,6 +2634,7 @@ def _generate_with_cfg_custom(
         pad_token_id: int,
         streamer: Optional[BaseStreamer],
         constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> torch.Tensor:
         """
         Custom CFG generation loop that:
@@ -2729,6 +2781,8 @@ def _generate_with_cfg_custom(
             # Update streamer
             if streamer is not None:
                 streamer.put(next_tokens_unsqueezed)  # Stream conditional tokens
+            if progress_callback is not None:
+                progress_callback(step + 1, max_new_tokens, "LLM CFG Generation")

             # Stop generation only when ALL sequences have finished
             if seq_finished.all():
@@ -3005,6 +3059,7 @@ def _run_mlx_batch_native(
         lyrics: str,
         cot_text: str,
         seeds: Optional[List[int]] = None,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> List[str]:
         """
         Optimized native MLX batch generation for codes phase.
@@ -3322,6 +3377,8 @@ def _clone_cache_list(cache_list):
                     item_last_logits[i] = logits_out[:, -1:, :]

                 pbar.update(1)
+                if progress_callback is not None:
+                    progress_callback(step + 1, max_new_tokens, f"MLX {cfg_label}Batch Gen (native, n={batch_size})")

                 # Periodic memory cleanup
                 if step % 256 == 0 and step > 0:
@@ -3370,6 +3427,7 @@ def _run_mlx_single_native(
         caption: str,
         lyrics: str,
         cot_text: str,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> str:
         """
         Optimized native MLX generation using mlx-lm infrastructure.
@@ -3628,6 +3686,8 @@ def _run_mlx_single_native(
                 new_tokens.append(token_id)
                 all_token_ids.append(token_id)
                 pbar.update(1)
+                if progress_callback is not None:
+                    progress_callback(step + 1, max_new_tokens, tqdm_desc)

                 # Update constrained processor FSM state
                 if constrained_processor is not None:
@@ -3693,6 +3753,7 @@ def _run_mlx_single(
         caption: str,
         lyrics: str,
         cot_text: str,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> str:
         """
         MLX-accelerated single-item generation.
@@ -3723,6 +3784,7 @@ def _run_mlx_single(
                 caption=caption,
                 lyrics=lyrics,
                 cot_text=cot_text,
+                progress_callback=progress_callback,
             )
         except Exception as _native_err:
             logger.warning(
@@ -3872,6 +3934,8 @@ def _run_mlx_single(
                 new_tokens.append(token_id)
                 all_token_ids.append(token_id)
                 pbar.update(1)
+                if progress_callback is not None:
+                    progress_callback(step + 1, max_new_tokens, tqdm_desc)

                 # Update constrained processor state
                 if constrained_processor is not None:
@@ -3934,6 +3998,7 @@ def _run_mlx(
         lyrics: str = "",
         cot_text: str = "",
         seeds: Optional[List[int]] = None,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ) -> Union[str, List[str]]:
         """
         Unified MLX generation function supporting both single and batch modes.
@@ -3983,6 +4048,7 @@ def _run_mlx(
                     lyrics=lyrics,
                     cot_text=cot_text,
                     seeds=seeds,
+                    progress_callback=progress_callback,
                 )
             except Exception as e:
                 logger.warning(
@@ -3998,6 +4064,13 @@ def _run_mlx(
                 if seeds and i < len(seeds):
                     mx.random.seed(seeds[i])

+                item_progress_callback = None
+                if progress_callback is not None:
+                    def _item_progress(current, total, desc, index=i, batch_total=batch_size):
+                        progress_callback(index * total + current, batch_total * total, desc)
+
+                    item_progress_callback = _item_progress
+
                 output_text = self._run_mlx_single(
                     formatted_prompt=formatted_prompt,
                     temperature=temperature,
@@ -4018,6 +4091,7 @@ def _run_mlx(
                     caption=caption,
                     lyrics=lyrics,
                     cot_text=cot_text,
+                    progress_callback=item_progress_callback,
                 )
                 output_texts.append(output_text)
             return output_texts
@@ -4044,6 +4118,7 @@ def _run_mlx(
             caption=caption,
             lyrics=lyrics,
             cot_text=cot_text,
+            progress_callback=progress_callback,
         )

     # =========================================================================
diff --git a/acestep/llm_inference_cfg_fixes_test.py b/acestep/llm_inference_cfg_fixes_test.py
index 60893a814..4eaa92fe0 100644
--- a/acestep/llm_inference_cfg_fixes_test.py
+++ b/acestep/llm_inference_cfg_fixes_test.py
@@ -77,9 +77,10 @@ class TestCotCfgScaleFixed(unittest.TestCase):

     def test_cot_phase_uses_cfg_scale_1(self):
         """generate_from_formatted_prompt called during CoT must receive cfg_scale=1.0."""
-        handler = LLMHandler()
+        handler = _make_handler()
         handler.llm_initialized = True
         handler.llm_backend = "pt"
+        handler.llm = MagicMock()

         captured_cfg = {}

@@ -88,34 +89,23 @@ def fake_run_pt(formatted_prompts, temperature, cfg_scale, **kwargs):
             return "metadata"

         with patch.object(handler, "_run_pt", side_effect=fake_run_pt):
-            with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"):
-                with patch.object(
-                    handler,
-                    "_format_metadata_as_cot",
-                    return_value="",
-                ):
-                    with patch.object(
-                        handler,
-                        "build_formatted_prompt_with_cot",
-                        return_value="PROMPT_WITH_COT",
-                    ):
-                        # Simulate Phase 1 CoT call via generate_from_formatted_prompt
-                        # by calling the internal helper directly
-                        handler.generate_from_formatted_prompt(
-                            formatted_prompt="PROMPT",
-                            cfg={
-                                "temperature": 0.6,
-                                "cfg_scale": 1.0,  # already 1.0 for CoT
-                                "negative_prompt": "NO USER INPUT",
-                                "top_k": None,
-                                "top_p": None,
-                                "repetition_penalty": 1.0,
-                                "target_duration": None,
-                                "generation_phase": "cot",
-                                "caption": "test",
-                                "lyrics": "test",
-                            },
-                        )
+            # Simulate Phase 1 CoT call via generate_from_formatted_prompt by
+            # invoking the helper directly with a current-valid handler state.
+            handler.generate_from_formatted_prompt(
+                formatted_prompt="PROMPT",
+                cfg={
+                    "temperature": 0.6,
+                    "cfg_scale": 1.0,  # already 1.0 for CoT
+                    "negative_prompt": "NO USER INPUT",
+                    "top_k": None,
+                    "top_p": None,
+                    "repetition_penalty": 1.0,
+                    "target_duration": None,
+                    "generation_phase": "cot",
+                    "caption": "test",
+                    "lyrics": "test",
+                },
+            )

         # cfg_scale must be 1.0 for CoT – captured during _run_pt call
         self.assertEqual(captured_cfg.get("cfg_scale"), 1.0)
@@ -136,7 +126,7 @@ def capturing_gen(formatted_prompt, cfg=None, **kwargs):

         with patch.object(handler, "generate_from_formatted_prompt", side_effect=capturing_gen):
             with patch.object(handler, "build_formatted_prompt", return_value="P"):
-                with patch.object(handler, "_parse_metadata_from_cot", return_value={}):
+                with patch.object(handler, "parse_lm_output", return_value=({}, "")):
                     with patch.object(handler, "_format_metadata_as_cot", return_value=""):
                         with patch.object(
                             handler, "build_formatted_prompt_with_cot", return_value="P2"
@@ -471,5 +461,79 @@ def fake_call(input_ids, **kwargs):
         self.assertEqual(result[0, -1].item(), eos_id)


+@unittest.skipIf(LLMHandler is None, f"llm_inference import unavailable: {_IMPORT_ERROR}")
+class TestLmProgressCallbacks(unittest.TestCase):
+    """Progress callback coverage for LM token generation and phase mapping."""
+
+    def test_generate_with_cfg_custom_emits_token_progress(self):
+        """CFG token loop should emit monotonic per-token progress callbacks."""
+        handler = _make_handler()
+        eos_id = 1
+        vocab_size = 10
+        handler.llm_tokenizer.eos_token_id = eos_id
+
+        model = MagicMock()
+        logits = torch.full((2, 1, vocab_size), -100.0)
+        logits[:, :, 2] = 100.0
+        outputs = SimpleNamespace(logits=logits, past_key_values=None)
+        model.return_value = outputs
+        model.generation_config = MagicMock()
+        model.generation_config.use_cache = False
+        handler.llm = model
+
+        updates = []
+        handler._generate_with_cfg_custom(
+            batch_input_ids=torch.zeros((2, 3), dtype=torch.long),
+            batch_attention_mask=None,
+            max_new_tokens=2,
+            temperature=1.0,
+            cfg_scale=1.0,
+            top_k=None,
+            top_p=None,
+            repetition_penalty=1.0,
+            pad_token_id=eos_id,
+            streamer=None,
+            progress_callback=lambda current, total, desc: updates.append((current, total, desc)),
+        )
+
+        self.assertEqual(updates, [(1, 2, "LLM CFG Generation"), (2, 2, "LLM CFG Generation")])
+
+    def test_generate_with_stop_condition_maps_phase_progress_to_outer_callback(self):
+        """Phase callbacks should surface token progress onto the outer progress callback."""
+        handler = LLMHandler()
+        handler.llm_initialized = True
+        handler.llm_backend = "pt"
+
+        updates = []
+
+        def fake_generate_from_formatted_prompt(*args, **kwargs):
+            progress_callback = kwargs.get("progress_callback")
+            if progress_callback is not None:
+                progress_callback(5, 10, "LLM CFG Generation")
+                progress_callback(10, 10, "LLM CFG Generation")
+            return "bpm: 120\n", "ok"
+
+        with patch.object(handler, "generate_from_formatted_prompt", side_effect=fake_generate_from_formatted_prompt):
+            with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"):
+                with patch.object(handler, "parse_lm_output", return_value=({"bpm": 120}, "")):
+                    handler.generate_with_stop_condition(
+                        caption="test caption",
+                        lyrics="test lyrics",
+                        cfg_scale=2.0,
+                        temperature=0.6,
+                        negative_prompt="",
+                        top_k=None,
+                        top_p=None,
+                        repetition_penalty=1.0,
+                        infer_type="dit",
+                        progress=lambda value, desc=None: updates.append((value, desc)),
+                    )
+
+        llm_cfg_updates = [value for value, desc in updates if desc == "LLM CFG Generation"]
+        self.assertTrue(llm_cfg_updates)
+        self.assertGreaterEqual(min(llm_cfg_updates), 0.1)
+        self.assertLessEqual(max(llm_cfg_updates), 0.3)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/acestep/models/base/modeling_acestep_v15_base.py b/acestep/models/base/modeling_acestep_v15_base.py
index 088159df2..d34dd2c30 100644
--- a/acestep/models/base/modeling_acestep_v15_base.py
+++ b/acestep/models/base/modeling_acestep_v15_base.py
@@ -1854,6 +1854,7 @@ def generate_audio(
         precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
         audio_codes: Optional[torch.FloatTensor] = None,
         use_progress_bar: bool = True,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
         use_adg: bool = False,
         shift: float = 1.0,
         cover_noise_strength: float = 0.0,
@@ -2037,6 +2038,8 @@ def generate_audio(
                 xt = _repaint_step_injection(
                     xt, clean_src_latents, repaint_mask, t_after_step, noise,
                 )
+            if progress_callback is not None:
+                progress_callback(step_idx + 1, infer_steps, "DiT diffusion...")

         x_gen = xt
         if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0:
diff --git a/acestep/models/mlx/dit_generate.py b/acestep/models/mlx/dit_generate.py
index 454e6601a..c7b40fe5c 100644
--- a/acestep/models/mlx/dit_generate.py
+++ b/acestep/models/mlx/dit_generate.py
@@ -5,7 +5,7 @@

 import logging
 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 from tqdm import tqdm
@@ -134,6 +134,7 @@ def mlx_generate_diffusion(
     encoder_hidden_states_non_cover_np: Optional[np.ndarray] = None,
     context_latents_non_cover_np: Optional[np.ndarray] = None,
     compile_model: bool = False,
+    progress_callback: Optional[Callable[[int, int, str], None]] = None,
     disable_tqdm: bool = False,
 ) -> Dict[str, object]:
     """Run the complete MLX diffusion loop with optional CFG guidance.
@@ -160,6 +161,8 @@ def mlx_generate_diffusion(
         encoder_hidden_states_non_cover_np: optional [B, enc_L, D] for non-cover.
         context_latents_non_cover_np: optional [B, T, C] for non-cover.
         compile_model: If True, compile the decoder step with ``mx.compile``.
+        progress_callback: Optional diffusion-step callback receiving
+            ``(current_step, total_steps, desc)``.
         disable_tqdm: If True, suppress the diffusion progress bar.

     Returns:
@@ -302,6 +305,8 @@ def _raw_step(xt, t, tr, enc, ctx):

         xt = xt - vt * dt_arr
         mx.eval(xt)
+        if progress_callback is not None:
+            progress_callback(step_idx + 1, num_steps, "DiT diffusion...")

     diff_end = time.time()
     total_end = time.time()
diff --git a/acestep/models/sft/modeling_acestep_v15_base.py b/acestep/models/sft/modeling_acestep_v15_base.py
index 8980d4200..5148aa4d0 100644
--- a/acestep/models/sft/modeling_acestep_v15_base.py
+++ b/acestep/models/sft/modeling_acestep_v15_base.py
@@ -1854,6 +1854,7 @@ def generate_audio(
         precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
         audio_codes: Optional[torch.FloatTensor] = None,
         use_progress_bar: bool = True,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
         use_adg: bool = False,
         shift: float = 1.0,
         timesteps: Optional[torch.Tensor] = None,
@@ -2045,6 +2046,8 @@ def generate_audio(
                 xt = _repaint_step_injection(
                     xt, clean_src_latents, repaint_mask, t_after_step, noise,
                 )
+            if progress_callback is not None:
+                progress_callback(step_idx + 1, infer_steps, "DiT diffusion...")

         x_gen = xt
         if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0:
diff --git a/acestep/models/turbo/modeling_acestep_v15_turbo.py b/acestep/models/turbo/modeling_acestep_v15_turbo.py
index 69f88a69f..e5efbfdc1 100644
--- a/acestep/models/turbo/modeling_acestep_v15_turbo.py
+++ b/acestep/models/turbo/modeling_acestep_v15_turbo.py
@@ -1844,6 +1844,8 @@ def generate_audio(
         non_cover_text_attention_mask: Optional[torch.FloatTensor] = None,
         precomputed_lm_hints_25Hz: Optional[torch.FloatTensor] = None,
         audio_codes: Optional[torch.FloatTensor] = None,
+        use_progress_bar: bool = True,
+        progress_callback: Optional[Callable[[int, int, str], None]] = None,
         shift: float = 3.0,
         timesteps: Optional[torch.Tensor] = None,
         cover_noise_strength: float = 0.0,
@@ -2023,6 +2025,8 @@ def generate_audio(
             # On final step, directly compute x0 from noise
             if step_idx == num_steps - 1:
                 xt = self.get_x0_from_noise(xt, vt, t_curr_tensor)
+                if progress_callback is not None:
+                    progress_callback(step_idx + 1, num_steps, "DiT diffusion...")
                 break

             # Update x_t based on inference method
@@ -2046,6 +2050,8 @@ def generate_audio(
                 xt = _repaint_step_injection(
                     xt, clean_src_latents, repaint_mask, t_after_step, noise,
                 )
+            if progress_callback is not None:
+                progress_callback(step_idx + 1, num_steps, "DiT diffusion...")

         x_gen = xt
         if repaint_mask is not None and clean_src_latents is not None and repaint_crossfade_frames > 0:
@@ -2214,4 +2220,4 @@ def test_forward(model, seed=42):
     # model = model.float()
     model = model.to("cuda")
     model = model.bfloat16()
-    test_forward(model)
\ No newline at end of file
+    test_forward(model)

From 23e90c05e90b1c63dbbedc6357f87f8db45d9905 Mon Sep 17 00:00:00 2001
From: 1larity
Date: Sat, 28 Mar 2026 10:14:24 +0000
Subject: [PATCH 2/6] fix(gradio): expose decode phase in generation progress

---
 acestep/core/generation/handler/generate_music_decode.py      | 3 ++-
 acestep/core/generation/handler/generate_music_decode_test.py | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/acestep/core/generation/handler/generate_music_decode.py b/acestep/core/generation/handler/generate_music_decode.py
index 4b10b42dc..d232d826d 100644
--- a/acestep/core/generation/handler/generate_music_decode.py
+++ b/acestep/core/generation/handler/generate_music_decode.py
@@ -126,7 +126,7 @@ def _decode_chunk_progress(current: int, total: int, desc: str) -> None:
             _emit_decode_progress(mapped, desc)

         if progress:
-            _emit_decode_progress(0.8, "Decoding audio...")
+            _emit_decode_progress(0.8, "Preparing audio decode...")
         logger.info("[generate_music] Decoding latents with VAE...")
         start_time = time.time()
         with torch.inference_mode():
@@ -171,6 +171,7 @@ def _decode_chunk_progress(current: int, total: int, desc: str) -> None:
             pred_latents_for_decode = pred_latents_for_decode.cpu()
             self._empty_cache()
         try:
+            _emit_decode_progress(0.82, "Decoding audio chunks...")
             if use_tiled_decode:
                 logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
                 pred_wavs = self.tiled_decode(
diff --git a/acestep/core/generation/handler/generate_music_decode_test.py b/acestep/core/generation/handler/generate_music_decode_test.py
index 9eef2d19d..508314837 100644
--- a/acestep/core/generation/handler/generate_music_decode_test.py
+++ b/acestep/core/generation/handler/generate_music_decode_test.py
@@ -195,8 +195,8 @@ def _progress(value, desc=None):
         self.assertAlmostEqual(updated_costs["vae_decode_time_cost"], 1.5, places=6)
         self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6)
         self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6)
-        self.assertEqual(host.progress_calls[0][0], 0.8)
-        self.assertTrue(any(desc == "Decoding audio chunks..." for _, desc in host.progress_calls))
+        self.assertEqual(host.progress_calls[0], (0.8, "Preparing audio decode..."))
+        self.assertIn((0.82, "Decoding audio chunks..."), host.progress_calls)
         self.assertAlmostEqual(host.progress_calls[-1][0], 0.98, places=6)

     def test_decode_pred_latents_restores_vae_device_on_decode_error(self):

From 966c482748291fd1ab45dd4204aa5b544b6fab07 Mon Sep 17 00:00:00 2001
From: 1larity
Date: Sat, 28 Mar 2026 10:22:08 +0000
Subject: [PATCH 3/6] fix(gradio): keep decode progress monotonic

---
 acestep/core/generation/handler/generate_music_decode.py      |  2 +-
 acestep/core/generation/handler/generate_music_decode_test.py | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/acestep/core/generation/handler/generate_music_decode.py b/acestep/core/generation/handler/generate_music_decode.py
index d232d826d..077edd311 100644
--- a/acestep/core/generation/handler/generate_music_decode.py
+++ b/acestep/core/generation/handler/generate_music_decode.py
@@ -122,7 +122,7 @@ def _decode_chunk_progress(current: int, total: int, desc: str) -> None:
             if total <= 0:
                 return
             frac = min(1.0, max(0.0, current / total))
-            mapped = 0.8 + 0.18 * frac
+            mapped = 0.82 + 0.16 * frac
             _emit_decode_progress(mapped, desc)

         if progress:
diff --git a/acestep/core/generation/handler/generate_music_decode_test.py b/acestep/core/generation/handler/generate_music_decode_test.py
index 508314837..ae1c9c1d9 100644
--- a/acestep/core/generation/handler/generate_music_decode_test.py
+++ b/acestep/core/generation/handler/generate_music_decode_test.py
@@ -119,16 +119,16 @@ def _mlx_vae_decode(self, latents, progress_callback=None):
         """Return deterministic decoded waveform for MLX decode branch."""
         _ = latents
         if progress_callback is not None:
-            progress_callback(1, 4, "Decoding audio chunks...")
-            progress_callback(4, 4, "Decoding audio chunks...")
+            progress_callback(1, 64, "Decoding audio chunks...")
+            progress_callback(64, 64, "Decoding audio chunks...")
         return torch.ones(1, 2, 8)

     def tiled_decode(self, latents, progress_callback=None):
         """Return deterministic decoded waveform for tiled decode branch."""
         _ = latents
         if progress_callback is not None:
-            progress_callback(1, 4, "Decoding audio chunks...")
-            progress_callback(4, 4, "Decoding audio chunks...")
+            progress_callback(1, 64, "Decoding audio chunks...")
+            progress_callback(64, 64, "Decoding audio chunks...")
         return torch.ones(1, 2, 8)


@@ -197,6 +197,8 @@ def _progress(value, desc=None):
         self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6)
         self.assertEqual(host.progress_calls[0], (0.8, "Preparing audio decode..."))
         self.assertIn((0.82, "Decoding audio chunks..."), host.progress_calls)
+        progress_values = [value for value, _ in host.progress_calls]
+        self.assertEqual(progress_values, sorted(progress_values))
         self.assertAlmostEqual(host.progress_calls[-1][0], 0.98, places=6)

     def test_decode_pred_latents_restores_vae_device_on_decode_error(self):

From 6216131eaa2a130b8596a0eb53ad2c91b07dfaab Mon Sep 17 00:00:00 2001
From: 1larity
Date: Sat, 28 Mar 2026 10:45:20 +0000
Subject: [PATCH 4/6] test(gradio): lock in diffusion progress callback parity

---
 acestep/models/mlx/dit_generate_test.py       | 90 +++++++++++++++++++
 ...odeling_acestep_v15_turbo_progress_test.py | 90 +++++++++++++++++++
 2 files changed, 180 insertions(+)
 create mode 100644 acestep/models/mlx/dit_generate_test.py
 create mode 100644 acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py

diff --git a/acestep/models/mlx/dit_generate_test.py b/acestep/models/mlx/dit_generate_test.py
new file mode 100644
index 000000000..fe01789d8
--- /dev/null
+++ b/acestep/models/mlx/dit_generate_test.py
@@ -0,0 +1,90 @@
+import sys
+import types
+import unittest
+from unittest.mock import patch
+
+import numpy as np
+
+from acestep.models.mlx.dit_generate import mlx_generate_diffusion
+
+
+class _FakeRandom:
+    def normal(self, shape, key=None):
+        _ = key
+        return np.zeros(shape, dtype=np.float32)
+
+    def key(self, seed):
+        return seed
+
+
+class _FakeDecoder:
+    def __call__(
+        self,
+        hidden_states,
+        timestep,
+        timestep_r,
+        encoder_hidden_states,
+        context_latents,
+        cache=None,
+        use_cache=False,
+    ):
+        _ = timestep
+        _ = timestep_r
+        _ = encoder_hidden_states
+        _ = context_latents
+        _ = use_cache
+        return np.zeros_like(hidden_states), cache
+
+
+class MlxGenerateDiffusionProgressTests(unittest.TestCase):
+    def test_progress_callback_fires_once_per_diffusion_step(self):
+        fake_mx_core = types.ModuleType("mlx.core")
+        fake_mx_core.array = lambda value: np.array(value, dtype=np.float32)
+        fake_mx_core.concatenate = lambda values, axis=0: np.concatenate(values, axis=axis)
+        fake_mx_core.broadcast_to = lambda value, shape: np.broadcast_to(value, shape)
+        fake_mx_core.full = lambda shape, fill_value: np.full(shape, fill_value, dtype=np.float32)
+        fake_mx_core.eval = lambda value: value
+        fake_mx_core.random = _FakeRandom()
+
+        fake_mlx_pkg = types.ModuleType("mlx")
+        fake_mlx_pkg.core = fake_mx_core
+
+        fake_dit_model = types.ModuleType("acestep.models.mlx.dit_model")
+
+        class _FakeCache:
+            pass
+
+        fake_dit_model.MLXCrossAttentionCache = _FakeCache
+
+        updates = []
+        with patch.dict(
+            sys.modules,
+            {
+                "mlx": fake_mlx_pkg,
+                "mlx.core": fake_mx_core,
+                "acestep.models.mlx.dit_model": fake_dit_model,
+            },
+        ):
+            result = mlx_generate_diffusion(
+                mlx_decoder=_FakeDecoder(),
+                encoder_hidden_states_np=np.zeros((1, 2, 3), dtype=np.float32),
+                context_latents_np=np.zeros((1, 2, 3), dtype=np.float32),
+                src_latents_shape=(1, 2, 3),
+                timesteps=[1.0, 0.75, 0.5],
+                progress_callback=lambda step, total, desc: updates.append((step, total, desc)),
+                disable_tqdm=True,
+            )
+
+        self.assertEqual(
+            updates,
+            [
+                (1, 3, "DiT diffusion..."),
+                (2, 3, "DiT diffusion..."),
+                (3, 3, "DiT diffusion..."),
+            ],
+        )
+        self.assertEqual(tuple(result["target_latents"].shape), (1, 2, 3))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py b/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py
new file mode 100644
index 000000000..28214fbb4
--- /dev/null
+++ b/acestep/models/turbo/modeling_acestep_v15_turbo_progress_test.py
@@ -0,0 +1,90 @@
+import unittest
+
+import torch
+
+from acestep.models.turbo.modeling_acestep_v15_turbo import AceStepConditionGenerationModel
+
+
+class _FakeDecoder:
+    def __call__(
+        self,
+        hidden_states,
+        timestep,
+        timestep_r,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask,
+        context_latents,
+        use_cache,
+        past_key_values,
+    ):
+        _ = timestep
+        _ = timestep_r
+        _ = attention_mask
+        _ = encoder_hidden_states
+        _ = encoder_attention_mask
+        _ = context_latents
+        _ = use_cache
+        return torch.zeros_like(hidden_states), past_key_values
+
+
+class _FakeTurboHost:
+    def __init__(self):
+        self.decoder = _FakeDecoder()
+
+    def prepare_condition(self, **kwargs):
+        src_latents = kwargs["src_latents"]
+        attention_mask = kwargs["attention_mask"]
+        return src_latents, attention_mask, src_latents
+
+    def prepare_noise(self, context_latents, seed):
+        _ = seed
+        return torch.zeros_like(context_latents)
+
+    def get_x0_from_noise(self, zt, vt, t):
+        return zt - vt * t.unsqueeze(-1).unsqueeze(-1)
+
+    def renoise(self, x, t, noise=None):
+        _ = t
+        _ = noise
+        return x
+
+
+class AceStepTurboProgressTests(unittest.TestCase):
+    def test_progress_callback_fires_once_per_step_including_final_step(self):
+        host = _FakeTurboHost()
+        updates = []
+        base = torch.zeros((1, 2, 2), dtype=torch.float32)
+        mask = torch.ones((1, 2), dtype=torch.float32)
+        cover_mask = torch.zeros((1,), dtype=torch.float32)
+
+        AceStepConditionGenerationModel.generate_audio(
+            host,
+            text_hidden_states=base,
+            text_attention_mask=mask,
+            lyric_hidden_states=base,
+            lyric_attention_mask=mask,
+            refer_audio_acoustic_hidden_states_packed=base,
+            refer_audio_order_mask=mask,
+            src_latents=base,
+            chunk_masks=base,
+            is_covers=cover_mask,
+            silence_latent=base,
+            attention_mask=mask,
+            seed=0,
+            infer_method="ode",
+            timesteps=[1.0, 0.5],
+            progress_callback=lambda step, total, desc: updates.append((step, total, desc)),
+        )
+
+        self.assertEqual(
+            updates,
+            [
+                (1, 2, "DiT diffusion..."),
+                (2, 2, "DiT diffusion..."),
+            ],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()

From 0bb0095d2bd67b51511c17a0f1a682f471705eab Mon Sep 17 00:00:00 2001
From: 1larity
Date: Sat, 28 Mar 2026 11:20:42 +0000
Subject: [PATCH 5/6] fix(gradio): tighten progress callback monotonicity

---
 .../handler/generate_music_execute.py         | 10 +--
 .../handler/generate_music_execute_test.py    | 69 +++++++++++++++++++
 .../handler/service_generate_execute_test.py  |  3 +-
 .../generation/handler/vae_decode_chunks.py   | 13 +++-
 .../handler/vae_decode_chunks_test.py         | 26 +++++++
 acestep/llm_inference.py                      |  3 +
 acestep/llm_inference_cfg_fixes_test.py       | 39 +++++++++++
 7 files changed, 154 insertions(+), 9 deletions(-)

diff --git a/acestep/core/generation/handler/generate_music_execute.py b/acestep/core/generation/handler/generate_music_execute.py
index 2b34ea651..88fdb9b88 100644
--- a/acestep/core/generation/handler/generate_music_execute.py
+++ b/acestep/core/generation/handler/generate_music_execute.py
@@ -78,11 +78,11 @@ def _emit_progress(value: float, desc: Optional[str] = None) -> float:
             with progress_lock:
                 clamped = max(progress_state["value"], value)
progress_state["value"] = clamped - if progress is not None: - try: - progress(clamped, desc=desc) - except Exception as exc: - logger.debug("[generate_music] Ignoring progress callback error: {}", exc) + if progress is not None: + try: + progress(clamped, desc=desc) + except Exception as exc: + logger.debug("[generate_music] Ignoring progress callback error: {}", exc) return clamped _emit_progress(0.52, progress_desc) diff --git a/acestep/core/generation/handler/generate_music_execute_test.py b/acestep/core/generation/handler/generate_music_execute_test.py index 50afc2d2e..a105f6c02 100644 --- a/acestep/core/generation/handler/generate_music_execute_test.py +++ b/acestep/core/generation/handler/generate_music_execute_test.py @@ -2,6 +2,7 @@ import unittest import time +import threading from acestep.core.generation.handler.generate_music_execute import GenerateMusicExecuteMixin @@ -184,6 +185,74 @@ def test_runtime_progress_handoff_stays_monotonic_after_estimator_advances(self) self.assertEqual(progress_values, sorted(progress_values)) self.assertGreaterEqual(progress_values[2], 0.68) + def test_runtime_progress_does_not_publish_behind_blocked_estimator_callback(self): + """A blocked estimator callback should not allow a stale publish after runtime progress.""" + + class _AsyncEstimatorHost(_Host): + def _start_diffusion_progress_estimator(self, **kwargs): + self.started = True + progress = kwargs["progress"] + desc = kwargs["desc"] + + class _Stop: + def __init__(self, host): + self.host = host + + def set(self): + self.host.stopped = True + self.host.stop_calls += 1 + + thread = threading.Thread( + target=lambda: progress(0.63, desc=desc), + name="test-estimator", + daemon=True, + ) + thread.start() + return _Stop(self), thread + + host = _AsyncEstimatorHost() + blocked_estimator = threading.Event() + updates = [] + + def progress(value, desc=None): + if value == 0.63 and not blocked_estimator.is_set(): + blocked_estimator.set() + time.sleep(0.05) + updates.append((value, desc)) + + host._run_generate_music_service_with_progress( + progress=progress, + actual_batch_size=1, + audio_duration=10.0, + inference_steps=8, + timesteps=None, + service_inputs={ + "captions_batch": ["c"], + "lyrics_batch": ["l"], + "metas_batch": ["m"], + "vocal_languages_batch": ["en"], + "target_wavs_tensor": None, + "repainting_start_batch": [0.0], + "repainting_end_batch": [1.0], + "instructions_batch": ["i"], + "audio_code_hints_batch": None, + "should_return_intermediate": True, + }, + refer_audios=None, + guidance_scale=7.0, + actual_seed_list=[1], + audio_cover_strength=1.0, + cover_noise_strength=0.0, + use_adg=False, + cfg_interval_start=0.0, + cfg_interval_end=1.0, + shift=1.0, + infer_method="ode", + ) + + progress_values = [value for value, _ in updates] + self.assertEqual(progress_values, sorted(progress_values)) + if __name__ == "__main__": unittest.main() diff --git a/acestep/core/generation/handler/service_generate_execute_test.py b/acestep/core/generation/handler/service_generate_execute_test.py index 77c759a08..ed2d61df5 100644 --- a/acestep/core/generation/handler/service_generate_execute_test.py +++ b/acestep/core/generation/handler/service_generate_execute_test.py @@ -79,7 +79,8 @@ def test_build_generate_kwargs_forwards_runtime_progress_callback(self): "non_cover_text_attention_masks": None, "precomputed_lm_hints_25Hz": None, } - callback = lambda current, total, desc: (current, total, desc) + def callback(current, total, desc): + return (current, total, desc) kwargs = 
host._build_service_generate_kwargs( payload=payload, diff --git a/acestep/core/generation/handler/vae_decode_chunks.py b/acestep/core/generation/handler/vae_decode_chunks.py index 108ddd7fc..3ee50cb44 100644 --- a/acestep/core/generation/handler/vae_decode_chunks.py +++ b/acestep/core/generation/handler/vae_decode_chunks.py @@ -13,6 +13,13 @@ class VaeDecodeChunksMixin: def _tiled_decode_inner(self, latents, chunk_size, overlap, offload_wav_to_cpu, progress_callback=None): """Run tiled decode with adaptive overlap and OOM fallbacks.""" bsz, _channels, latent_frames = latents.shape + completed_steps = 0 + + def _monotonic_progress(current, total, desc): + nonlocal completed_steps + completed_steps = max(completed_steps, current) + if progress_callback is not None: + progress_callback(completed_steps, total, desc) # Batch-sequential decode keeps peak VRAM stable across batch sizes. if bsz > 1: @@ -80,7 +87,7 @@ def _sample_progress(current, total, desc, offset=b_idx): stride, overlap, num_steps, - progress_callback=progress_callback, + progress_callback=_monotonic_progress, ) except torch.cuda.OutOfMemoryError: logger.warning( @@ -96,7 +103,7 @@ def _sample_progress(current, total, desc, offset=b_idx): stride, overlap, num_steps, - progress_callback=progress_callback, + progress_callback=_monotonic_progress, ) except torch.cuda.OutOfMemoryError: logger.warning( @@ -112,7 +119,7 @@ def _sample_progress(current, total, desc, offset=b_idx): stride, overlap, num_steps, - progress_callback=progress_callback, + progress_callback=_monotonic_progress, ) except torch.cuda.OutOfMemoryError: logger.warning("[tiled_decode] OOM even with offload path, falling back to full CPU VAE decode") diff --git a/acestep/core/generation/handler/vae_decode_chunks_test.py b/acestep/core/generation/handler/vae_decode_chunks_test.py index d44bc9da3..566a0f41c 100644 --- a/acestep/core/generation/handler/vae_decode_chunks_test.py +++ b/acestep/core/generation/handler/vae_decode_chunks_test.py @@ -74,6 +74,32 @@ def _oom(*args, **kwargs): self.assertTrue(torch.equal(out, torch.full((1, 2, 7), 9.0))) self.assertEqual(host.decode_on_cpu_calls, 1) + def test_gpu_to_offload_fallback_keeps_progress_monotonic(self): + """Offload retry progress should not rewind after GPU OOM progress has started.""" + host = _ChunksHost() + updates = [] + + def _gpu_oom(*args, **kwargs): + kwargs["progress_callback"](3, 4, "Decoding audio chunks...") + raise torch.cuda.OutOfMemoryError("gpu oom") + + def _offload_ok(*args, **kwargs): + kwargs["progress_callback"](1, 4, "Decoding audio chunks...") + kwargs["progress_callback"](4, 4, "Decoding audio chunks...") + return torch.ones(1, 2, 5) + + host._tiled_decode_gpu = _gpu_oom + host._tiled_decode_offload_cpu = _offload_ok + host._tiled_decode_inner( + torch.zeros(1, 4, 20), + chunk_size=8, + overlap=2, + offload_wav_to_cpu=False, + progress_callback=lambda current, total, desc: updates.append((current, total, desc)), + ) + + self.assertEqual([current for current, _, _ in updates], [3, 3, 4]) + if __name__ == "__main__": unittest.main() diff --git a/acestep/llm_inference.py b/acestep/llm_inference.py index 47d0eec32..fa1e2a3da 100644 --- a/acestep/llm_inference.py +++ b/acestep/llm_inference.py @@ -1434,6 +1434,7 @@ def _callback(current: int, total: int, desc: str) -> None: logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}") else: logger.info(f"Phase 1 completed in {phase1_time:.2f}s. 
Generated metadata: {list(metadata.keys())}") + progress(0.3, "LLM metadata generation complete") else: # Use user-provided metadata if is_batch: @@ -1600,6 +1601,7 @@ def _callback(current: int, total: int, desc: str) -> None: metadata_list.append(metadata.copy()) # Same metadata for all phase2_time = time.time() - phase2_start + progress(0.5, "LLM audio code generation complete") # Log results codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list] @@ -1665,6 +1667,7 @@ def _callback(current: int, total: int, desc: str) -> None: } phase2_time = time.time() - phase2_start + progress(0.5, "LLM audio code generation complete") # Parse audio codes from output (metadata should be same as Phase 1) _, audio_codes = self.parse_lm_output(codes_output_text) diff --git a/acestep/llm_inference_cfg_fixes_test.py b/acestep/llm_inference_cfg_fixes_test.py index 4eaa92fe0..598fcb4f7 100644 --- a/acestep/llm_inference_cfg_fixes_test.py +++ b/acestep/llm_inference_cfg_fixes_test.py @@ -533,6 +533,45 @@ def fake_generate_from_formatted_prompt(*args, **kwargs): self.assertTrue(llm_cfg_updates) self.assertGreaterEqual(min(llm_cfg_updates), 0.1) self.assertLessEqual(max(llm_cfg_updates), 0.3) + self.assertIn((0.3, "LLM metadata generation complete"), updates) + + def test_generate_with_stop_condition_closes_audio_code_phase_band(self): + """Phase 2 should emit an explicit terminal progress update on early success.""" + handler = LLMHandler() + handler.llm_initialized = True + handler.llm_backend = "pt" + + updates = [] + call_count = {"count": 0} + + def fake_generate_from_formatted_prompt(*args, **kwargs): + call_count["count"] += 1 + progress_callback = kwargs.get("progress_callback") + if progress_callback is not None: + progress_callback(5, 10, "LLM CFG Generation") + if call_count["count"] == 1: + return "bpm: 120\n", "ok" + return "<|audio_code_1|><|audio_code_2|>", "ok" + + with patch.object(handler, "generate_from_formatted_prompt", side_effect=fake_generate_from_formatted_prompt): + with patch.object(handler, "build_formatted_prompt", return_value="PROMPT"): + with patch.object(handler, "parse_lm_output", side_effect=[({"bpm": 120}, ""), ({}, "<|audio_code_1|><|audio_code_2|>")]): + with patch.object(handler, "_format_metadata_as_cot", return_value="cot"): + with patch.object(handler, "build_formatted_prompt_with_cot", return_value="PROMPT2"): + handler.generate_with_stop_condition( + caption="test caption", + lyrics="test lyrics", + cfg_scale=2.0, + temperature=0.6, + negative_prompt="", + top_k=None, + top_p=None, + repetition_penalty=1.0, + infer_type="llm_dit", + progress=lambda value, desc=None: updates.append((value, desc)), + ) + + self.assertIn((0.5, "LLM audio code generation complete"), updates) if __name__ == "__main__": From 014e99eddd2eaa8761d666bfb29362180f02ef13 Mon Sep 17 00:00:00 2001 From: 1larity Date: Sat, 28 Mar 2026 11:41:44 +0000 Subject: [PATCH 6/6] fix(gradio): tighten generation progress handoff --- .../handler/generate_music_execute.py | 18 ++- .../generate_music_execute_thread_test.py | 92 +++++++++++++++ acestep/llm_inference.py | 27 ++++- acestep/llm_inference_batch_progress_test.py | 105 ++++++++++++++++++ 4 files changed, 235 insertions(+), 7 deletions(-) create mode 100644 acestep/core/generation/handler/generate_music_execute_thread_test.py create mode 100644 acestep/llm_inference_batch_progress_test.py diff --git a/acestep/core/generation/handler/generate_music_execute.py 
b/acestep/core/generation/handler/generate_music_execute.py index 88fdb9b88..2cfca4b06 100644 --- a/acestep/core/generation/handler/generate_music_execute.py +++ b/acestep/core/generation/handler/generate_music_execute.py @@ -129,6 +129,16 @@ def _service_target(): except Exception as exc: _error["exc"] = exc + def _stop_progress_estimator_if_finished() -> None: + """Stop the estimator thread but keep its handle until it actually exits.""" + nonlocal progress_thread + if progress_thread is None: + return + progress_thread.join(timeout=1.0) + if hasattr(progress_thread, "is_alive") and progress_thread.is_alive(): + return + progress_thread = None + try: stop_event, progress_thread = self._start_diffusion_progress_estimator( progress=_emit_progress, @@ -164,9 +174,7 @@ def _service_target(): saw_runtime_progress = True if stop_event is not None: stop_event.set() - if progress_thread is not None: - progress_thread.join(timeout=1.0) - progress_thread = None + _stop_progress_estimator_if_finished() if not gen_thread.is_alive(): break @@ -180,9 +188,7 @@ def _service_target(): saw_runtime_progress = True if stop_event is not None: stop_event.set() - if progress_thread is not None: - progress_thread.join(timeout=1.0) - progress_thread = None + _stop_progress_estimator_if_finished() if gen_thread.is_alive(): logger.error( diff --git a/acestep/core/generation/handler/generate_music_execute_thread_test.py b/acestep/core/generation/handler/generate_music_execute_thread_test.py new file mode 100644 index 000000000..056410ddb --- /dev/null +++ b/acestep/core/generation/handler/generate_music_execute_thread_test.py @@ -0,0 +1,92 @@ +"""Regression coverage for progress estimator thread lifecycle.""" + +import unittest + +from acestep.core.generation.handler.generate_music_execute import GenerateMusicExecuteMixin + + +class _Host(GenerateMusicExecuteMixin): + """Minimal host for exercising estimator shutdown behavior.""" + + def __init__(self): + self.stop_calls = 0 + self.thread = None + + def _start_diffusion_progress_estimator(self, **kwargs): + class _Stop: + def __init__(self, host): + self.host = host + + def set(self): + self.host.stop_calls += 1 + + class _Thread: + def __init__(self): + self.join_calls = 0 + self._alive = True + + def join(self, timeout=None): + _ = timeout + self.join_calls += 1 + if self.join_calls >= 2: + self._alive = False + + def is_alive(self): + return self._alive + + self.thread = _Thread() + return _Stop(self), self.thread + + def service_generate(self, **kwargs): + progress_callback = kwargs["progress_callback"] + progress_callback(1, 4, "DiT diffusion...") + return {"target_latents": "ok"} + + +class GenerateMusicExecuteThreadLifecycleTests(unittest.TestCase): + """Ensure estimator thread handles survive timed joins until actual shutdown.""" + + def test_runtime_progress_keeps_estimator_thread_handle_until_thread_exits(self): + host = _Host() + updates = [] + + out = host._run_generate_music_service_with_progress( + progress=lambda value, desc=None: updates.append((value, desc)), + actual_batch_size=1, + audio_duration=10.0, + inference_steps=8, + timesteps=None, + service_inputs={ + "captions_batch": ["c"], + "lyrics_batch": ["l"], + "metas_batch": ["m"], + "vocal_languages_batch": ["en"], + "target_wavs_tensor": None, + "repainting_start_batch": [0.0], + "repainting_end_batch": [1.0], + "instructions_batch": ["i"], + "audio_code_hints_batch": None, + "should_return_intermediate": True, + }, + refer_audios=None, + guidance_scale=7.0, + actual_seed_list=[1], + 
audio_cover_strength=1.0, + cover_noise_strength=0.0, + use_adg=False, + cfg_interval_start=0.0, + cfg_interval_end=1.0, + shift=1.0, + infer_method="ode", + ) + + self.assertEqual(out["outputs"]["target_latents"], "ok") + self.assertEqual(host.stop_calls, 2) + self.assertIsNotNone(host.thread) + self.assertEqual(host.thread.join_calls, 2) + self.assertFalse(host.thread.is_alive()) + self.assertTrue(any(desc == "DiT diffusion..." for _, desc in updates)) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/llm_inference.py b/acestep/llm_inference.py index fa1e2a3da..f92c5deda 100644 --- a/acestep/llm_inference.py +++ b/acestep/llm_inference.py @@ -1165,6 +1165,7 @@ def _run_pt( # loads to GPU once and offloads once, instead of per-item. if is_batch: output_texts = [] + batch_total = len(formatted_prompt_list) with self._load_model_context(): for i, formatted_prompt in enumerate(formatted_prompt_list): @@ -1177,8 +1178,10 @@ def _run_pt( torch.mps.manual_seed(seeds[i]) item_progress_callback = None + item_progress_state = {"current": 0, "total": 0, "desc": ""} if progress_callback is not None: - def _item_progress(current, total, desc, index=i, batch_total=len(formatted_prompt_list)): + def _item_progress(current, total, desc, index=i): + item_progress_state.update(current=current, total=total, desc=desc) progress_callback(index * total + current, batch_total * total, desc) item_progress_callback = _item_progress @@ -1208,6 +1211,16 @@ def _item_progress(current, total, desc, index=i, batch_total=len(formatted_prom ) output_texts.append(output_text) + if ( + progress_callback is not None + and item_progress_state["total"] > 0 + and item_progress_state["current"] < item_progress_state["total"] + ): + progress_callback( + (i + 1) * item_progress_state["total"], + batch_total * item_progress_state["total"], + item_progress_state["desc"], + ) return output_texts @@ -4068,8 +4081,10 @@ def _run_mlx( mx.random.seed(seeds[i]) item_progress_callback = None + item_progress_state = {"current": 0, "total": 0, "desc": ""} if progress_callback is not None: def _item_progress(current, total, desc, index=i, batch_total=batch_size): + item_progress_state.update(current=current, total=total, desc=desc) progress_callback(index * total + current, batch_total * total, desc) item_progress_callback = _item_progress @@ -4097,6 +4112,16 @@ def _item_progress(current, total, desc, index=i, batch_total=batch_size): progress_callback=item_progress_callback, ) output_texts.append(output_text) + if ( + progress_callback is not None + and item_progress_state["total"] > 0 + and item_progress_state["current"] < item_progress_state["total"] + ): + progress_callback( + (i + 1) * item_progress_state["total"], + batch_size * item_progress_state["total"], + item_progress_state["desc"], + ) return output_texts # Single mode diff --git a/acestep/llm_inference_batch_progress_test.py b/acestep/llm_inference_batch_progress_test.py new file mode 100644 index 000000000..680ab5bb7 --- /dev/null +++ b/acestep/llm_inference_batch_progress_test.py @@ -0,0 +1,105 @@ +"""Regression tests for sequential batch progress closure.""" + +import sys +import types +import unittest +from contextlib import nullcontext +from unittest.mock import patch + +try: + from acestep.llm_inference import LLMHandler + + _IMPORT_ERROR = None +except ImportError as exc: # pragma: no cover - dependency guard + LLMHandler = None + _IMPORT_ERROR = exc + + +def _make_handler() -> "LLMHandler": + """Return a minimal handler with model-loading context 
stubbed out.""" + handler = LLMHandler() + handler._load_model_context = lambda: nullcontext() + return handler + + +@unittest.skipIf(LLMHandler is None, f"llm_inference import unavailable: {_IMPORT_ERROR}") +class SequentialBatchProgressTests(unittest.TestCase): + """Sequential batch generation should close each item before the next begins.""" + + def test_run_pt_closes_early_item_progress_before_next_item(self): + handler = _make_handler() + updates = [] + + def fake_run_pt_single(*args, **kwargs): + progress_callback = kwargs["progress_callback"] + if len(updates) == 0: + progress_callback(2, 10, "LLM CFG Generation") + return "first" + progress_callback(1, 10, "LLM CFG Generation") + return "second" + + with patch.object(handler, "_run_pt_single", side_effect=fake_run_pt_single): + out = handler._run_pt( + formatted_prompts=["p1", "p2"], + temperature=0.6, + cfg_scale=1.0, + negative_prompt="", + top_k=None, + top_p=None, + repetition_penalty=1.0, + progress_callback=lambda current, total, desc: updates.append((current, total, desc)), + ) + + self.assertEqual(out, ["first", "second"]) + self.assertEqual( + updates, + [ + (2, 20, "LLM CFG Generation"), + (10, 20, "LLM CFG Generation"), + (11, 20, "LLM CFG Generation"), + (20, 20, "LLM CFG Generation"), + ], + ) + + def test_run_mlx_closes_early_item_progress_before_next_item(self): + handler = _make_handler() + updates = [] + fake_mlx = types.ModuleType("mlx") + fake_mx_core = types.ModuleType("mlx.core") + fake_mx_core.random = types.SimpleNamespace(seed=lambda *_: None) + + def fake_run_mlx_single(*args, **kwargs): + progress_callback = kwargs["progress_callback"] + if len(updates) == 0: + progress_callback(3, 12, "LLM CFG Generation") + return "first" + progress_callback(2, 12, "LLM CFG Generation") + return "second" + + with patch.dict(sys.modules, {"mlx": fake_mlx, "mlx.core": fake_mx_core}): + with patch.object(handler, "_run_mlx_single", side_effect=fake_run_mlx_single): + out = handler._run_mlx( + formatted_prompts=["p1", "p2"], + temperature=0.6, + cfg_scale=1.0, + negative_prompt="", + top_k=None, + top_p=None, + repetition_penalty=1.0, + progress_callback=lambda current, total, desc: updates.append((current, total, desc)), + ) + + self.assertEqual(out, ["first", "second"]) + self.assertEqual( + updates, + [ + (3, 24, "LLM CFG Generation"), + (12, 24, "LLM CFG Generation"), + (14, 24, "LLM CFG Generation"), + (24, 24, "LLM CFG Generation"), + ], + ) + + +if __name__ == "__main__": + unittest.main()
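
Reviewer notes (illustrative sketches, not part of the patch series).

Patches 2, 3, and 5 all lean on the same two primitives: mapping a sub-phase's (current, total) step counter into a fixed band of the parent progress bar, and clamping published values under a lock so racing estimator and runtime callbacks can never move the bar backwards. A minimal standalone sketch of both, assuming only the standard library; every name below is invented for illustration and is not the repository's API:

import threading

def make_band_callback(publish, start, end):
    # Map a sub-phase's (current, total) steps into [start, end] of the
    # parent bar, mirroring the 0.82..0.98 decode band from patches 2 and 3.
    def _callback(current, total, desc):
        if total <= 0:
            return
        frac = min(1.0, max(0.0, current / total))
        publish(start + (end - start) * frac, desc)
    return _callback

def make_monotonic_publisher(publish):
    # Publish while still holding the lock: clamping alone is not enough,
    # because a caller that computed a smaller clamped value could otherwise
    # publish it after a larger one (the stale-publish race patch 5 tests).
    lock = threading.Lock()
    state = {"value": 0.0}
    def _publish(value, desc=None):
        with lock:
            state["value"] = max(state["value"], value)
            publish(state["value"], desc)
    return _publish

# Example wiring: out-of-order chunk reports are clamped forward, so the
# printed sequence stays monotone within the decode band.
progress = make_monotonic_publisher(lambda v, d: print(f"{v:.3f} {d}"))
decode_cb = make_band_callback(progress, 0.82, 0.98)
for step in (1, 3, 2, 4):
    decode_cb(step, 4, "Decoding audio chunks...")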
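A second recurring pattern, from patch 6's llm_inference changes: when batch items run sequentially, each item's step counter is projected into one global (batch_total * per_item_total) range, and an item that returns before reporting its final step is explicitly closed out so the next item's first report cannot appear to rewind. A sketch under the same assumptions (hypothetical names, not the repository's API):

def run_batch(items, run_item, publish):
    # Sequentially run items, projecting each item's (current, total)
    # progress into one global range and closing any item that stops
    # reporting before its final step.
    batch_total = len(items)
    for index, item in enumerate(items):
        last = {"current": 0, "total": 0, "desc": ""}

        def on_step(current, total, desc, index=index, last=last):
            # Defaults bind the per-iteration index and state dict.
            last.update(current=current, total=total, desc=desc)
            publish(index * total + current, batch_total * total, desc)

        run_item(item, on_step)
        if last["total"] > 0 and last["current"] < last["total"]:
            publish((index + 1) * last["total"],
                    batch_total * last["total"], last["desc"])

# Demo: two items of 10 steps; the first stops reporting at step 2, so it is
# closed at (10, 20) before item two begins -- the sequence the new
# regression tests lock in.
log = []
run_batch(
    ["a", "b"],
    lambda item, cb: cb(2 if item == "a" else 10, 10, "LLM CFG Generation"),
    lambda cur, tot, desc: log.append((cur, tot)),
)
assert log == [(2, 20), (10, 20), (20, 20)]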
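Finally, the estimator-shutdown change in patch 6 reduces to one rule: after a timed join, only discard the thread handle if the join actually completed, so a later call can retry instead of orphaning a thread that is still publishing progress. A sketch with hypothetical names:

def stop_estimator_if_finished(stop_event, thread, timeout=1.0):
    # Signal shutdown, join with a timeout, and keep the handle alive
    # unless the thread has really exited.
    stop_event.set()
    thread.join(timeout=timeout)
    return None if not thread.is_alive() else thread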