4 changes: 4 additions & 0 deletions acestep/core/generation/handler/diffusion.py
@@ -34,6 +34,7 @@ def _mlx_run_diffusion(
encoder_hidden_states_non_cover=None,
encoder_attention_mask_non_cover=None,
context_latents_non_cover=None,
progress_callback=None,
disable_tqdm: bool = False,
) -> Dict[str, Any]:
"""Run the MLX diffusion loop and return generated latents.
@@ -56,6 +57,8 @@ def _mlx_run_diffusion(
encoder_hidden_states_non_cover: Optional non-cover conditioning tensor.
encoder_attention_mask_non_cover: Unused; accepted for API compatibility.
context_latents_non_cover: Optional non-cover context latent tensor.
progress_callback: Optional diffusion-step callback receiving
``(current_step, total_steps, desc)``.
disable_tqdm: If True, suppress the diffusion progress bar.

Returns:
@@ -135,6 +138,7 @@ def _mlx_run_diffusion(
encoder_hidden_states_non_cover_np=enc_nc_np,
context_latents_non_cover_np=ctx_nc_np,
compile_model=getattr(self, "mlx_dit_compiled", False),
progress_callback=progress_callback,
disable_tqdm=disable_tqdm,
)

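For context, any callable with the documented (current_step, total_steps, desc) signature can serve as progress_callback here. A minimal caller-side sketch, with an illustrative function name and print-based reporting that are not part of this PR:

def log_diffusion_progress(current_step: int, total_steps: int, desc: str) -> None:
    """Print a simple percentage for each diffusion step."""
    if total_steps <= 0:  # guard against a zero-length schedule
        return
    pct = 100.0 * current_step / total_steps
    print(f"{desc}: step {current_step}/{total_steps} ({pct:.1f}%)")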
25 changes: 22 additions & 3 deletions acestep/core/generation/handler/generate_music_decode.py
@@ -114,8 +114,19 @@ def _decode_generate_music_pred_latents(
Returns:
Tuple of decoded waveforms, CPU latents, and updated time-cost payload.
"""
def _emit_decode_progress(value: float, desc: str) -> None:
if progress:
progress(value, desc=desc)

def _decode_chunk_progress(current: int, total: int, desc: str) -> None:
if total <= 0:
return
frac = min(1.0, max(0.0, current / total))
mapped = 0.82 + 0.16 * frac
_emit_decode_progress(mapped, desc)

if progress:
progress(0.8, desc="Decoding audio...")
_emit_decode_progress(0.8, "Preparing audio decode...")
logger.info("[generate_music] Decoding latents with VAE...")
start_time = time.time()
with torch.inference_mode():
@@ -160,12 +171,19 @@ def _decode_generate_music_pred_latents(
pred_latents_for_decode = pred_latents_for_decode.cpu()
self._empty_cache()
try:
_emit_decode_progress(0.82, "Decoding audio chunks...")
if use_tiled_decode:
logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
pred_wavs = self.tiled_decode(pred_latents_for_decode)
pred_wavs = self.tiled_decode(
pred_latents_for_decode,
progress_callback=_decode_chunk_progress,
)
elif using_mlx_vae:
try:
pred_wavs = self._mlx_vae_decode(pred_latents_for_decode)
pred_wavs = self._mlx_vae_decode(
pred_latents_for_decode,
progress_callback=_decode_chunk_progress,
)
except Exception as exc:
logger.warning(
f"[generate_music] MLX direct decode failed ({exc}), falling back to PyTorch"
@@ -194,6 +212,7 @@ def _decode_generate_music_pred_latents(
if torch.any(peak > 1.0):
pred_wavs = pred_wavs / peak.clamp(min=1.0)
self._empty_cache()
_emit_decode_progress(0.98, "Decoding audio chunks...")
gc.collect()
self._empty_cache()
end_time = time.time()
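The helpers above reserve the 0.82-0.98 slice of the overall progress bar for VAE chunk decoding: the chunk fraction is clamped to [0, 1] and mapped as 0.82 + 0.16 * frac, with a final 0.98 emit once decoding finishes. A standalone sketch of that arithmetic, with an illustrative function name:

def map_decode_chunk_to_ui(current: int, total: int) -> float:
    """Map a decode-chunk fraction onto the 0.82-0.98 UI progress window."""
    frac = min(1.0, max(0.0, current / total))  # clamp to [0, 1]
    return 0.82 + 0.16 * frac                   # 0 chunks -> 0.82, all chunks -> 0.98

# Example: halfway through 64 chunks reports 0.90 on the overall bar.
assert abs(map_decode_chunk_to_ui(32, 64) - 0.90) < 1e-9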
17 changes: 13 additions & 4 deletions acestep/core/generation/handler/generate_music_decode_test.py
@@ -115,14 +115,20 @@ def _max_memory_allocated(self):
"""Return deterministic max-memory value for debug logging."""
return 0.0

def _mlx_vae_decode(self, latents):
def _mlx_vae_decode(self, latents, progress_callback=None):
"""Return deterministic decoded waveform for MLX decode branch."""
_ = latents
if progress_callback is not None:
progress_callback(1, 64, "Decoding audio chunks...")
progress_callback(64, 64, "Decoding audio chunks...")
return torch.ones(1, 2, 8)

def tiled_decode(self, latents):
def tiled_decode(self, latents, progress_callback=None):
"""Return deterministic decoded waveform for tiled decode branch."""
_ = latents
if progress_callback is not None:
progress_callback(1, 64, "Decoding audio chunks...")
progress_callback(64, 64, "Decoding audio chunks...")
return torch.ones(1, 2, 8)


@@ -189,7 +195,11 @@ def _progress(value, desc=None):
self.assertAlmostEqual(updated_costs["vae_decode_time_cost"], 1.5, places=6)
self.assertAlmostEqual(updated_costs["total_time_cost"], 2.5, places=6)
self.assertAlmostEqual(updated_costs["offload_time_cost"], 0.25, places=6)
self.assertEqual(host.progress_calls[0][0], 0.8)
self.assertEqual(host.progress_calls[0], (0.8, "Preparing audio decode..."))
self.assertIn((0.82, "Decoding audio chunks..."), host.progress_calls)
progress_values = [value for value, _ in host.progress_calls]
self.assertEqual(progress_values, sorted(progress_values))
self.assertAlmostEqual(host.progress_calls[-1][0], 0.98, places=6)

def test_decode_pred_latents_restores_vae_device_on_decode_error(self):
"""It restores VAE device in the CPU-offload path even when decode raises."""
@@ -313,4 +323,3 @@ def __init__(self):

if __name__ == "__main__":
unittest.main()

87 changes: 82 additions & 5 deletions acestep/core/generation/handler/generate_music_execute.py
@@ -1,8 +1,10 @@
"""Execution helper for ``generate_music`` service invocation with progress tracking."""

import os
import queue
import threading
from typing import Any, Dict, List, Optional, Sequence
import time
from typing import Any, Callable, Dict, List, Optional, Sequence

from loguru import logger

@@ -15,6 +17,28 @@
class GenerateMusicExecuteMixin:
"""Run service generation under diffusion progress estimation lifecycle."""

@staticmethod
def _drain_runtime_progress_events(
progress_events: "queue.Queue[tuple[int, int, str]]",
emit_progress: Callable[[float, Optional[str]], float],
start: float,
end: float,
) -> bool:
"""Drain queued diffusion-step events and map them onto UI progress."""
drained = False
while True:
try:
current, total, desc = progress_events.get_nowait()
except queue.Empty:
return drained

if total <= 0:
continue
frac = min(1.0, max(0.0, current / total))
mapped = start + (end - start) * frac
emit_progress(mapped, desc)
drained = True

def _run_generate_music_service_with_progress(
self,
progress: Any,
@@ -44,9 +68,24 @@ def _run_generate_music_service_with_progress(
"""
infer_steps_for_progress = len(timesteps) if timesteps else inference_steps
progress_desc = f"Generating music (batch size: {actual_batch_size})..."
progress(0.52, desc=progress_desc)
stop_event = None
progress_thread = None
progress_events: "queue.Queue[tuple[int, int, str]]" = queue.Queue()
progress_state = {"value": 0.0}
progress_lock = threading.Lock()

def _emit_progress(value: float, desc: Optional[str] = None) -> float:
with progress_lock:
clamped = max(progress_state["value"], value)
progress_state["value"] = clamped
if progress is not None:
try:
progress(clamped, desc=desc)
except Exception as exc:
logger.debug("[generate_music] Ignoring progress callback error: {}", exc)
return clamped

_emit_progress(0.52, progress_desc)

# --- Timeout-wrapped service_generate ---
# Run the actual CUDA work in a child thread so we can join() with a
@@ -85,13 +124,14 @@ def _service_target():
chunk_mask_modes=service_inputs.get("chunk_mask_modes_batch"),
repaint_crossfade_frames=repaint_crossfade_frames,
repaint_injection_ratio=repaint_injection_ratio,
progress_callback=lambda current, total, desc: progress_events.put((current, total, desc)),
)
except Exception as exc:
_error["exc"] = exc

try:
stop_event, progress_thread = self._start_diffusion_progress_estimator(
progress=progress,
progress=_emit_progress,
start=0.52,
end=0.79,
infer_steps=infer_steps_for_progress,
@@ -106,7 +146,43 @@ def _service_target():
daemon=True,
)
gen_thread.start()
gen_thread.join(timeout=_DEFAULT_GENERATION_TIMEOUT)
deadline = time.monotonic() + _DEFAULT_GENERATION_TIMEOUT
poll_interval = 0.1
saw_runtime_progress = False
while gen_thread.is_alive():
remaining = deadline - time.monotonic()
if remaining <= 0:
break
gen_thread.join(timeout=min(poll_interval, remaining))
drained_runtime_progress = self._drain_runtime_progress_events(
progress_events=progress_events,
emit_progress=_emit_progress,
start=0.52,
end=0.79,
)
if drained_runtime_progress and not saw_runtime_progress:
saw_runtime_progress = True
if stop_event is not None:
stop_event.set()
if progress_thread is not None:
progress_thread.join(timeout=1.0)
progress_thread = None
if not gen_thread.is_alive():
break

drained_runtime_progress = self._drain_runtime_progress_events(
progress_events=progress_events,
emit_progress=_emit_progress,
start=0.52,
end=0.79,
)
if drained_runtime_progress and not saw_runtime_progress:
saw_runtime_progress = True
if stop_event is not None:
stop_event.set()
if progress_thread is not None:
progress_thread.join(timeout=1.0)
progress_thread = None

if gen_thread.is_alive():
logger.error(
Expand All @@ -122,11 +198,12 @@ def _service_target():
)
if "exc" in _error:
raise _error["exc"]
_emit_progress(0.79, progress_desc)

finally:
if stop_event is not None:
stop_event.set()
if progress_thread is not None:
progress_thread.join(timeout=1.0)

return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}
return {"outputs": _result["outputs"], "infer_steps_for_progress": infer_steps_for_progress}