diff --git a/acestep/core/generation/handler/generate_music.py b/acestep/core/generation/handler/generate_music.py index 502f2a5ec..a6e89afde 100644 --- a/acestep/core/generation/handler/generate_music.py +++ b/acestep/core/generation/handler/generate_music.py @@ -5,6 +5,7 @@ """ import gc +import os import traceback from typing import Any, Dict, List, Optional, Union @@ -328,13 +329,26 @@ def generate_music( repainting_end=repainting_end, chunk_mask_mode=chunk_mask_mode, ) - vram_error = self._vram_preflight_check( - actual_batch_size=actual_batch_size, - audio_duration=audio_duration, - guidance_scale=guidance_scale, - ) - if vram_error is not None: - return vram_error + if torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + skip_preflight = os.environ.get( + "ACESTEP_SKIP_VRAM_PREFLIGHT", "", + ).lower() in ("1", "true", "yes") + if skip_preflight: + logger.warning( + "[generate_music] VRAM pre-flight check skipped via " + "ACESTEP_SKIP_VRAM_PREFLIGHT=1. If generation OOMs, " + "unset this variable to re-enable the safety check." + ) + else: + vram_error = self._vram_preflight_check( + actual_batch_size=actual_batch_size, + audio_duration=audio_duration, + guidance_scale=guidance_scale, + ) + if vram_error is not None: + return vram_error injection_ratio, resolved_cf_frames, resolved_wav_cf = ( _resolve_repaint_config(repaint_mode, repaint_strength)