diff --git a/acestep/core/scoring/lm_score.py b/acestep/core/scoring/lm_score.py
index d67e43872..eca85e0de 100644
--- a/acestep/core/scoring/lm_score.py
+++ b/acestep/core/scoring/lm_score.py
@@ -6,8 +6,10 @@
 composite reward score.
 """
 import contextlib
+import gc
 import math
 import re
+import threading
 
 import torch
 import torch.nn.functional as F
@@ -67,6 +69,66 @@ def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
     return 1.0 / (1.0 + math.exp(-pmi / scale))
 
+
+def _empty_accelerator_cache(backend: str) -> None:
+    """Release cached accelerator memory scoped to the scoring backend.
+
+    Keeps cache eviction confined to the runtime actually used by the
+    scoring pass so non-target runtimes (e.g. CUDA when running MLX) are
+    not forced through an allocator reset. ``mlx`` -> MPS, ``vllm`` ->
+    CUDA / XPU, ``pt`` -> delegated to ``_load_model_context`` and never
+    reaches this helper.
+    """
+    if backend == "mlx":
+        if (
+            hasattr(torch.backends, "mps")
+            and torch.backends.mps.is_available()
+            and hasattr(torch, "mps")
+            and hasattr(torch.mps, "empty_cache")
+        ):
+            torch.mps.empty_cache()
+        return
+
+    if backend == "vllm":
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            torch.xpu.empty_cache()
+
+
+# Per-thread reentrancy depth state for _load_scoring_model_context. Keeping
+# this on ``threading.local`` (rather than on the handler) prevents concurrent
+# Autoscore calls on the same handler from seeing each other's depth counter
+# and clearing the cached scoring model out from under a peer thread. The
+# scoring forward path is strictly synchronous (``torch.no_grad`` + a plain
+# ``model(...)`` call) so thread-local storage is sufficient; there are no
+# async suspension points where a contextvar would be strictly required.
+_scoring_tls = threading.local()
+
+# Guards lazy creation of the per-handler RLock so two threads racing on the
+# first scoring call don't each install a different lock object.
+_scoring_lock_creation_lock = threading.Lock()
+
+
+def _get_scoring_handler_lock(llm_handler) -> threading.RLock:
+    """Return (and lazily create) a per-handler reentrant lock."""
+    lock = getattr(llm_handler, "_scoring_lock", None)
+    if lock is None:
+        with _scoring_lock_creation_lock:
+            lock = getattr(llm_handler, "_scoring_lock", None)
+            if lock is None:
+                lock = threading.RLock()
+                llm_handler._scoring_lock = lock
+    return lock
+
+
+def _scoring_depth_map() -> Dict[int, int]:
+    depths = getattr(_scoring_tls, "depths", None)
+    if depths is None:
+        depths = {}
+        _scoring_tls.depths = depths
+    return depths
+
+
 @contextlib.contextmanager
 def _load_scoring_model_context(llm_handler):
     """
@@ -79,38 +141,111 @@ def _load_scoring_model_context(llm_handler):
     that would otherwise stay on GPU permanently -- here we move it to
     GPU only for the duration of the scoring forward pass and move it
     back to CPU when done, freeing VRAM for DiT / VAE.
+
+    The context is **reentrant per thread and serialised across threads**.
+    A single Autoscore pass performs many forward passes per sample
+    (per-metadata recall, caption PMI, lyrics PMI -- each with a
+    conditional and an unconditional prompt). Without reentrancy every
+    one of those forward passes would migrate a multi-GB model
+    CPU↔accelerator, which on Apple Silicon unified memory (MLX)
+    accumulates fragmented MPS allocations until the system runs out of
+    memory (issue #1081). By hoisting a single outer context around the
+    whole scoring pass we reduce that to exactly one migration per sample.
+
+    Reentrancy state is kept in thread-local storage keyed by handler id so
+    two concurrent Autoscore calls on the same handler do not see each
+    other's depth counter. The outermost entry from each thread acquires
+    a handler-level ``RLock`` so concurrent calls serialise the shared
+    load/offload transition (and, on MLX, the cached-model drop) instead
+    of racing on ``_hf_model_for_scoring``.
+
+    On MLX with ``offload_to_cpu`` enabled we additionally drop the cached
+    HF scoring model on outermost exit so the ~8 GB duplicate PyTorch copy
+    of the LM does not remain resident between generations. The next
+    scoring call re-materialises it via ``AutoModelForCausalLM.from_pretrained``
+    -- weights are cached by the HF hub so this is a fast rematerialisation,
+    not a redownload.
     """
     backend = getattr(llm_handler, "llm_backend", "pt")
 
     if backend == "pt":
-        # pt backend: _load_model_context already handles GPU <-> CPU
+        # pt backend: _load_model_context already handles GPU <-> CPU and
+        # has its own device-based reentrancy guard (llm_inference.py).
         with llm_handler._load_model_context():
             yield
         return
 
-    # vllm / mlx: manage the cached HF model ourselves
-    model = llm_handler.get_hf_model_for_scoring()
-    if model is None:
-        yield
+    # vllm / mlx: manage the cached HF model ourselves. Use a per-thread
+    # reentrancy depth counter so nested entries from the same thread
+    # become no-ops. Only the outermost entry from each thread moves the
+    # model to the accelerator and only the outermost exit moves it back
+    # (and, for MLX, drops it).
+    depths = _scoring_depth_map()
+    key = id(llm_handler)
+    depth = depths.get(key, 0)
+
+    if depth > 0:
+        depths[key] = depth + 1
+        try:
+            yield
+        finally:
+            depths[key] -= 1
+            if depths[key] <= 0:
+                depths.pop(key, None)
         return
 
-    offload = getattr(llm_handler, "offload_to_cpu", False)
-    device = llm_handler.device if hasattr(llm_handler, "device") else "cpu"
+    # Outermost entry for this thread: serialise with other threads on the
+    # same handler so the shared ``_hf_model_for_scoring`` load/offload
+    # transition is atomic. The RLock is safe to re-acquire from the same
+    # thread in the (unlikely) event that a nested path ends up here.
+    lock = _get_scoring_handler_lock(llm_handler)
+    with lock:
+        model = llm_handler.get_hf_model_for_scoring()
+        if model is None:
+            yield
+            return
 
-    if offload and hasattr(model, "to"):
-        logger.info(f"[scoring] Loading HF scoring model to {device}")
-        model.to(device)
+        offload = getattr(llm_handler, "offload_to_cpu", False)
+        device = llm_handler.device if hasattr(llm_handler, "device") else "cpu"
 
-    try:
-        yield
-    finally:
         if offload and hasattr(model, "to"):
-            logger.info("[scoring] Offloading HF scoring model to CPU")
-            model.to("cpu")
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
-                torch.mps.empty_cache()
+            logger.info(f"[scoring] Loading HF scoring model to {device}")
+            model.to(device)
+
+        depths[key] = 1
+        try:
+            yield
+        finally:
+            depths.pop(key, None)
+            # Nested try/finally so that a failing ``model.to("cpu")`` does
+            # not skip the MLX cached-model release below. If the offload
+            # raises and we leave ``_hf_model_for_scoring`` cached, the
+            # unified-memory leak this PR is meant to fix simply moves from
+            # the happy path to the failure path.
+            try:
+                if offload and hasattr(model, "to"):
+                    logger.info("[scoring] Offloading HF scoring model to CPU")
+                    model.to("cpu")
+                    _empty_accelerator_cache(backend)
+            finally:
+                # On MLX with offload enabled, the HF scoring model is a
+                # *separate* ~8 GB PyTorch copy of the LM (the MLX model itself
+                # cannot be used for torch teacher-forcing scoring). Keeping
+                # it resident on CPU between scoring passes doubles the LM
+                # footprint in unified memory and -- combined with repeated
+                # .to("mps") / .to("cpu") migrations -- pushes 32 GB Macs past
+                # their limit (issue #1081). Drop the cached copy so unified
+                # memory is returned to the OS; it will be re-loaded from the
+                # HF cache on the next scoring call.
+                if backend == "mlx" and offload:
+                    llm_handler._hf_model_for_scoring = None
+                    del model
+                    gc.collect()
+                    _empty_accelerator_cache("mlx")
+                    logger.info(
+                        "[scoring] Released cached HF scoring model on MLX "
+                        "backend (will be reloaded on next Autoscore call)"
+                    )
 
 
 def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
@@ -412,40 +547,47 @@
     formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
     prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
     try:
-        # 1. Calculate Recall for Metadata Fields
-        if metadata and isinstance(metadata, dict):
-            scores = {}
-            # Define which fields use which metric
-            metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
-            metadata_pmi_keys = ['caption']
-            for key in metadata_recall_keys:
-                if key in metadata and metadata[key] is not None:
-                    recall_metadata = {key: metadata[key]}
-                    field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
-                    scores.update(field_scores)
-
-            # 2. Calculate PMI for Caption
-            for key in metadata_pmi_keys:
-                if key in metadata and metadata[key] is not None:
-                    cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
-                    target_text = f"\n{cot_yaml}\n\n"
-
-                    log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
-                    log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
-
-                    pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
-                    scores[key] = pmi_normalized
-
-        # 3. Calculate PMI for Lyrics
-        if lyrics:
-            target_text = f"\n\n# Lyric\n{lyrics}\n"
-
-            log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
-
-            prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
-            log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
-
-            scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
+        # Batch *all* scoring forward passes under a single load/offload
+        # cycle. The inner ``_get_logits_and_target_for_scoring`` calls
+        # still enter ``_load_scoring_model_context``, but the reentrancy
+        # guard turns those nested entries into no-ops -- so the cached HF
+        # scoring model is moved CPU↔accelerator exactly once per Autoscore
+        # pass instead of once per condition (issue #1081).
+        with _load_scoring_model_context(llm_handler):
+            # 1. Calculate Recall for Metadata Fields
+            if metadata and isinstance(metadata, dict):
+                scores = {}
+                # Define which fields use which metric
+                metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
+                metadata_pmi_keys = ['caption']
+                for key in metadata_recall_keys:
+                    if key in metadata and metadata[key] is not None:
+                        recall_metadata = {key: metadata[key]}
+                        field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
+                        scores.update(field_scores)
+
+                # 2. Calculate PMI for Caption
+                for key in metadata_pmi_keys:
+                    if key in metadata and metadata[key] is not None:
+                        cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
+                        target_text = f"\n{cot_yaml}\n\n"
+
+                        log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
+                        log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
+
+                        pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
+                        scores[key] = pmi_normalized
+
+            # 3. Calculate PMI for Lyrics
+            if lyrics:
+                target_text = f"\n\n# Lyric\n{lyrics}\n"
+
+                log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
+
+                prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
+                log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
+
+                scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
 
         if not scores:
             return {}, 0.0, "❌ No conditions to evaluate"
diff --git a/acestep/core/scoring/lm_score_context_test.py b/acestep/core/scoring/lm_score_context_test.py
new file mode 100644
index 000000000..91dc4e20f
--- /dev/null
+++ b/acestep/core/scoring/lm_score_context_test.py
@@ -0,0 +1,218 @@
+"""Regression tests for the ``_load_scoring_model_context`` lifecycle.
+
+These protect the Autoscore unified-memory fix from issue #1081 by
+asserting three invariants that are easy to regress silently:
+
+1. ``_load_scoring_model_context`` performs exactly one load/offload
+   cycle per outermost entry, regardless of how deeply nested the
+   inner re-entries are.
+2. Nested re-entries from the same thread are no-ops and do not
+   trigger additional CPU↔accelerator migrations.
+3. On the MLX backend with ``offload_to_cpu=True`` the cached HF
+   scoring model is released on outermost exit so the ~8 GB duplicate
+   PyTorch copy does not remain resident between generations.
+"""
+
+import threading
+import unittest
+
+from acestep.core.scoring.lm_score import _load_scoring_model_context
+
+
+class _FakeModel:
+    """Recorder stand-in for an ``AutoModelForCausalLM``.
+
+    Captures every ``.to(device)`` invocation so tests can assert the
+    exact migration sequence produced by a scoring context.
+    """
+
+    def __init__(self):
+        """Initialise with an empty migration history."""
+        self.device_calls = []
+
+    def to(self, device):
+        """Record the target device and return ``self`` (chain-friendly)."""
+        self.device_calls.append(str(device))
+        return self
+
+
+class _FakeHandler:
+    """Handler stub exposing just enough surface for the scoring context.
+
+    Stands in for ``LLMInferenceHandler`` so the context-manager tests
+    avoid loading any real model weights. The handler tracks how many
+    times ``get_hf_model_for_scoring`` is queried, which lets the nested-entry
+    tests prove that inner re-entries do not re-fetch the model.
+ """ + + def __init__(self, backend: str, offload: bool = True): + """Configure backend / offload flags and install a fake model. + + Args: + backend: One of ``"mlx"``, ``"vllm"``, ``"pt"``. Selects the + branch of ``_load_scoring_model_context`` under test. + offload: Whether ``offload_to_cpu`` should be reported as + enabled. Controls whether the context manager performs + any ``.to()`` migrations and whether the MLX release + path fires on outermost exit. + """ + self.llm_backend = backend + self.llm_initialized = True + self.offload_to_cpu = offload + # Synthetic target device -- ``_FakeModel.to`` is a recorder so + # the string never has to correspond to a real torch device. + self.device = "cuda" + self._hf_model_for_scoring = _FakeModel() + self.get_calls = 0 + + def get_hf_model_for_scoring(self): + """Return the cached fake model and bump the query counter.""" + self.get_calls += 1 + return self._hf_model_for_scoring + + +class LoadScoringModelContextTests(unittest.TestCase): + """Lifecycle-contract tests for ``_load_scoring_model_context``.""" + + def test_single_load_offload_per_outer_context_mlx(self): + """One outer entry should trigger exactly one load and one offload.""" + handler = _FakeHandler("mlx") + model = handler._hf_model_for_scoring + with _load_scoring_model_context(handler): + pass + # Exactly one load (to accelerator) and one offload (to cpu). + self.assertEqual( + model.device_calls, ["cuda", "cpu"], + "expected exactly one load+offload cycle", + ) + + def test_nested_entries_are_noops_mlx(self): + """Nested re-entries must not move the model again.""" + handler = _FakeHandler("mlx") + model = handler._hf_model_for_scoring + with _load_scoring_model_context(handler): + calls_after_outer_load = list(model.device_calls) + # Deep nesting (simulating _get_logits_and_target_for_scoring + # called many times per Autoscore pass). + with _load_scoring_model_context(handler): + with _load_scoring_model_context(handler): + # Inner contexts must not migrate the model again. + self.assertEqual(model.device_calls, calls_after_outer_load) + # After the outermost exit the handler has dropped the cached + # model, so use ``model`` (the captured reference) to verify the + # offload. + self.assertEqual( + model.device_calls, ["cuda", "cpu"], + "nested entries should not add extra migrations", + ) + + def test_mlx_outermost_exit_drops_cached_model(self): + """On MLX+offload, outer exit must clear ``_hf_model_for_scoring``.""" + handler = _FakeHandler("mlx", offload=True) + # Use ``getattr(..., None)`` so the test is resilient to either + # cleanup style: setting the attribute to ``None`` or deleting it. + self.assertIsNotNone(getattr(handler, "_hf_model_for_scoring", None)) + with _load_scoring_model_context(handler): + # Still cached while we're inside the context. + self.assertIsNotNone(getattr(handler, "_hf_model_for_scoring", None)) + # Released after outermost exit so unified memory is returned to OS. + self.assertIsNone(getattr(handler, "_hf_model_for_scoring", None)) + + def test_vllm_outermost_exit_keeps_cached_model(self): + """vllm backend must NOT drop the cached HF model (CUDA is fine).""" + handler = _FakeHandler("vllm", offload=True) + cached = handler._hf_model_for_scoring + with _load_scoring_model_context(handler): + pass + # vllm keeps the cached HF scoring model between passes; the + # MLX-specific release path must not fire here. 
+        self.assertIs(handler._hf_model_for_scoring, cached)
+
+    def test_mlx_no_offload_keeps_cached_model(self):
+        """Without offload_to_cpu the MLX release path must not fire."""
+        handler = _FakeHandler("mlx", offload=False)
+        cached = handler._hf_model_for_scoring
+        with _load_scoring_model_context(handler):
+            pass
+        # Not offloading means no load/offload transitions and no drop.
+        self.assertIs(handler._hf_model_for_scoring, cached)
+        self.assertEqual(cached.device_calls, [])
+
+    def test_get_hf_called_only_by_outermost_entry(self):
+        """Nested entries must not re-query ``get_hf_model_for_scoring``."""
+        handler = _FakeHandler("mlx")
+        with _load_scoring_model_context(handler):
+            with _load_scoring_model_context(handler):
+                with _load_scoring_model_context(handler):
+                    pass
+        # Outer entry queries once; nested entries are pure no-ops.
+        self.assertEqual(handler.get_calls, 1)
+
+    def test_outermost_detection_is_thread_local(self):
+        """Each thread must track its own outermost-entry depth.
+
+        The reentrancy depth is kept in ``threading.local`` keyed by
+        handler id so that two Autoscore calls on the same handler do
+        not see each other's counter. This test runs two worker
+        threads that each enter nested contexts on a shared handler; if
+        depth tracking regressed to a process-global counter one thread
+        could mis-identify itself as "nested" and skip the outer model
+        fetch. Uses ``offload=False`` so the MLX release path does not
+        drop the cached model between threads and confuse the count.
+        """
+        handler = _FakeHandler("mlx", offload=False)
+        start = threading.Barrier(2)
+
+        def worker():
+            # Synchronise thread start so both worker threads race into
+            # the handler-level lock simultaneously.
+            start.wait()
+            with _load_scoring_model_context(handler):
+                with _load_scoring_model_context(handler):
+                    pass
+
+        t1 = threading.Thread(target=worker)
+        t2 = threading.Thread(target=worker)
+        t1.start()
+        t2.start()
+        t1.join()
+        t2.join()
+
+        # Both threads must have counted as outermost, so
+        # ``get_hf_model_for_scoring`` must have fired twice. If depth
+        # tracking leaked across threads, the second thread would see
+        # ``depth > 0`` at outer entry, fall into the nested no-op
+        # branch, and never call ``get_hf_model_for_scoring``.
+        self.assertEqual(handler.get_calls, 2)
+
+    def test_mlx_cleanup_runs_when_offload_raises(self):
+        """MLX cached-model release must run even if model.to("cpu") fails.
+
+        If ``model.to("cpu")`` raises inside the outermost exit, the
+        ``_hf_model_for_scoring = None`` cleanup must still execute so that
+        the ~8 GB HF scoring model is not leaked on the exception path.
+ """ + + class _FailingOffloadModel(_FakeModel): + def to(self, device): + super().to(device) + if device == "cpu": + raise RuntimeError("simulated offload failure") + return self + + handler = _FakeHandler("mlx", offload=True) + handler._hf_model_for_scoring = _FailingOffloadModel() + + with self.assertRaises(RuntimeError): + with _load_scoring_model_context(handler): + pass + + self.assertIsNone( + getattr(handler, "_hf_model_for_scoring", None), + "MLX cleanup must still release the cached model after a " + "failed .to('cpu') call so unified memory is returned to the OS", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/core/scoring/scoring_test.py b/acestep/core/scoring/scoring_test.py index 6abab9b5f..6345f5a94 100644 --- a/acestep/core/scoring/scoring_test.py +++ b/acestep/core/scoring/scoring_test.py @@ -138,5 +138,10 @@ def test_metadata_aggregation(self): self.assertAlmostEqual(total, 0.75, places=2) +# Regression tests for ``_load_scoring_model_context`` (issue #1081) live in +# ``acestep/core/scoring/lm_score_context_test.py`` to keep this file under +# the project's 200 LOC hard cap. + + if __name__ == "__main__": unittest.main()