diff --git a/acestep/core/scoring/lm_score.py b/acestep/core/scoring/lm_score.py
index d67e43872..eca85e0de 100644
--- a/acestep/core/scoring/lm_score.py
+++ b/acestep/core/scoring/lm_score.py
@@ -6,8 +6,10 @@
composite reward score.
"""
import contextlib
+import gc
import math
import re
+import threading
import torch
import torch.nn.functional as F
@@ -67,6 +69,66 @@ def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
return 1.0 / (1.0 + math.exp(-pmi / scale))
+def _empty_accelerator_cache(backend: str) -> None:
+ """Release cached accelerator memory scoped to the scoring backend.
+
+ Keeps cache eviction confined to the runtime actually used by the
+ scoring pass so non-target runtimes (e.g. CUDA when running MLX) are
+ not forced through an allocator reset. ``mlx`` -> MPS, ``vllm`` ->
+ CUDA / XPU, ``pt`` -> delegated to ``_load_model_context`` and never
+ reaches this helper.
+ """
+ if backend == "mlx":
+ if (
+ hasattr(torch.backends, "mps")
+ and torch.backends.mps.is_available()
+ and hasattr(torch, "mps")
+ and hasattr(torch.mps, "empty_cache")
+ ):
+ torch.mps.empty_cache()
+ return
+
+ if backend == "vllm":
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
+ torch.xpu.empty_cache()
+
+
+# Per-thread reentrancy depth state for _load_scoring_model_context. Keeping
+# this on ``threading.local`` (rather than on the handler) prevents concurrent
+# Autoscore calls on the same handler from seeing each other's depth counter
+# and clearing the cached scoring model out from under a peer thread. The
+# scoring forward path is strictly synchronous (``torch.no_grad`` + a plain
+# ``model(...)`` call) so thread-local storage is sufficient; there are no
+# async suspension points where a contextvar would be strictly required.
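+# Shape sketch: ``_scoring_tls.depths`` maps ``id(handler)`` -> nesting depth
+# for the current thread, e.g. ``{id(handler): 2}`` while that thread is two
+# context levels deep on a single handler.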
+_scoring_tls = threading.local()
+
+# Guards lazy creation of the per-handler RLock so two threads racing on the
+# first scoring call don't each install a different lock object.
+_scoring_lock_creation_lock = threading.Lock()
+
+
+def _get_scoring_handler_lock(llm_handler) -> threading.RLock:
+ """Return (and lazily create) a per-handler reentrant lock."""
+ lock = getattr(llm_handler, "_scoring_lock", None)
+ if lock is None:
+ with _scoring_lock_creation_lock:
+ lock = getattr(llm_handler, "_scoring_lock", None)
+ if lock is None:
+ lock = threading.RLock()
+ llm_handler._scoring_lock = lock
+ return lock
+
+
+def _scoring_depth_map() -> Dict[int, int]:
+ depths = getattr(_scoring_tls, "depths", None)
+ if depths is None:
+ depths = {}
+ _scoring_tls.depths = depths
+ return depths
+
+
@contextlib.contextmanager
def _load_scoring_model_context(llm_handler):
"""
@@ -79,38 +141,111 @@ def _load_scoring_model_context(llm_handler):
that would otherwise stay on GPU permanently -- here we move it to GPU
only for the duration of the scoring forward pass and move it back to
CPU when done, freeing VRAM for DiT / VAE.
+
+ The context is **reentrant per thread and serialised across threads**.
+ A single Autoscore pass performs many forward passes per sample
+ (per-metadata recall, caption PMI, lyrics PMI -- each with a
+ conditional and an unconditional prompt). Without reentrancy every
+ one of those forward passes would migrate a multi-GB model
+ CPU↔accelerator, which on Apple Silicon unified memory (MLX)
+ accumulates fragmented MPS allocations until the system runs out of
+ memory (issue #1081). By hoisting a single outer context around the
+ whole scoring pass we reduce that to exactly one migration per sample.
+
+ Reentrancy state is kept in thread-local storage keyed by handler id so
+ two concurrent Autoscore calls on the same handler do not see each
+ other's depth counter. The outermost entry from each thread acquires
+    a handler-level ``RLock`` that is held for the whole scoring pass, so
+    concurrent calls serialise the shared load/offload transition (and, on
+    MLX, the cached-model drop) instead of racing on
+    ``_hf_model_for_scoring``.
+
+ On MLX with ``offload_to_cpu`` enabled we additionally drop the cached
+ HF scoring model on outermost exit so the ~8 GB duplicate PyTorch copy
+ of the LM does not remain resident between generations. The next
+ scoring call re-materialises it via ``AutoModelForCausalLM.from_pretrained``
+    -- the weights already sit in the local Hugging Face cache, so this is a
+    reload from disk rather than a redownload.
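+
+    Typical hoisted usage (sketch; ``_calculate_log_prob`` is the scoring
+    helper defined later in this module)::
+
+        with _load_scoring_model_context(llm_handler):
+            # Nested entries inside the helpers are no-ops, so the model
+            # migrates CPU <-> accelerator exactly once for this block.
+            cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
+            uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)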
"""
backend = getattr(llm_handler, "llm_backend", "pt")
if backend == "pt":
- # pt backend: _load_model_context already handles GPU <-> CPU
+ # pt backend: _load_model_context already handles GPU <-> CPU and
+ # has its own device-based reentrancy guard (llm_inference.py).
with llm_handler._load_model_context():
yield
return
- # vllm / mlx: manage the cached HF model ourselves
- model = llm_handler.get_hf_model_for_scoring()
- if model is None:
- yield
+ # vllm / mlx: manage the cached HF model ourselves. Use a per-thread
+ # reentrancy depth counter so nested entries from the same thread
+ # become no-ops. Only the outermost entry from each thread moves the
+ # model to the accelerator and only the outermost exit moves it back
+ # (and, for MLX, drops it).
+ depths = _scoring_depth_map()
+ key = id(llm_handler)
+ depth = depths.get(key, 0)
+
+ if depth > 0:
+ depths[key] = depth + 1
+ try:
+ yield
+ finally:
+ depths[key] -= 1
+ if depths[key] <= 0:
+ depths.pop(key, None)
return
- offload = getattr(llm_handler, "offload_to_cpu", False)
- device = llm_handler.device if hasattr(llm_handler, "device") else "cpu"
+ # Outermost entry for this thread: serialise with other threads on the
+ # same handler so the shared ``_hf_model_for_scoring`` load/offload
+ # transition is atomic. The RLock is safe to re-acquire from the same
+    # thread in the (unlikely) event that a nested path ends up here.
+ lock = _get_scoring_handler_lock(llm_handler)
+ with lock:
+ model = llm_handler.get_hf_model_for_scoring()
+ if model is None:
+ yield
+ return
- if offload and hasattr(model, "to"):
- logger.info(f"[scoring] Loading HF scoring model to {device}")
- model.to(device)
+ offload = getattr(llm_handler, "offload_to_cpu", False)
+ device = llm_handler.device if hasattr(llm_handler, "device") else "cpu"
- try:
- yield
- finally:
if offload and hasattr(model, "to"):
- logger.info("[scoring] Offloading HF scoring model to CPU")
- model.to("cpu")
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
- torch.mps.empty_cache()
+ logger.info(f"[scoring] Loading HF scoring model to {device}")
+ model.to(device)
+
+ depths[key] = 1
+ try:
+ yield
+ finally:
+ depths.pop(key, None)
+ # Nested try/finally so that a failing ``model.to("cpu")`` does
+ # not skip the MLX cached-model release below. If the offload
+ # raises and we leave ``_hf_model_for_scoring`` cached, the
+ # unified-memory leak this PR is meant to fix simply moves from
+ # the happy path to the failure path.
+ try:
+ if offload and hasattr(model, "to"):
+ logger.info("[scoring] Offloading HF scoring model to CPU")
+ model.to("cpu")
+ _empty_accelerator_cache(backend)
+ finally:
+ # On MLX with offload enabled, the HF scoring model is a
+ # *separate* ~8 GB PyTorch copy of the LM (the MLX model itself
+ # cannot be used for torch teacher-forcing scoring). Keeping
+ # it resident on CPU between scoring passes doubles the LM
+ # footprint in unified memory and -- combined with repeated
+ # .to("mps") / .to("cpu") migrations -- pushes 32 GB Macs past
+ # their limit (issue #1081). Drop the cached copy so unified
+ # memory is returned to the OS; it will be re-loaded from the
+ # HF cache on the next scoring call.
+ if backend == "mlx" and offload:
+ llm_handler._hf_model_for_scoring = None
+ del model
+ gc.collect()
+ _empty_accelerator_cache("mlx")
+ logger.info(
+ "[scoring] Released cached HF scoring model on MLX "
+ "backend (will be reloaded on next Autoscore call)"
+ )
def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
@@ -412,40 +547,47 @@ def calculate_pmi_score_per_condition(
formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
try:
- # 1. Calculate Recall for Metadata Fields
- if metadata and isinstance(metadata, dict):
- scores = {}
- # Define which fields use which metric
- metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
- metadata_pmi_keys = ['caption']
- for key in metadata_recall_keys:
- if key in metadata and metadata[key] is not None:
- recall_metadata = {key: metadata[key]}
- field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
- scores.update(field_scores)
-
- # 2. Calculate PMI for Caption
- for key in metadata_pmi_keys:
- if key in metadata and metadata[key] is not None:
- cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
- target_text = f"\n{cot_yaml}\n\n"
-
- log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
- log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
-
- pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
- scores[key] = pmi_normalized
-
- # 3. Calculate PMI for Lyrics
- if lyrics:
- target_text = f"\n\n# Lyric\n{lyrics}\n"
-
- log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
-
- prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
- log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
-
- scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
+ # Batch *all* scoring forward passes under a single load/offload
+ # cycle. The inner ``_get_logits_and_target_for_scoring`` calls
+ # still enter ``_load_scoring_model_context``, but the reentrancy
+ # guard turns those nested entries into no-ops -- so the cached HF
+ # scoring model is moved CPU↔accelerator exactly once per Autoscore
+ # pass instead of once per condition (issue #1081).
+ with _load_scoring_model_context(llm_handler):
+            scores = {}  # up front so a lyrics-only call cannot hit a NameError below
+            # 1. Calculate Recall for Metadata Fields
+            if metadata and isinstance(metadata, dict):
+ # Define which fields use which metric
+ metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
+ metadata_pmi_keys = ['caption']
+ for key in metadata_recall_keys:
+ if key in metadata and metadata[key] is not None:
+ recall_metadata = {key: metadata[key]}
+ field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
+ scores.update(field_scores)
+
+ # 2. Calculate PMI for Caption
+ for key in metadata_pmi_keys:
+ if key in metadata and metadata[key] is not None:
+ cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
+ target_text = f"\n{cot_yaml}\n\n"
+
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
+
+ pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
+ scores[key] = pmi_normalized
+
+ # 3. Calculate PMI for Lyrics
+ if lyrics:
+ target_text = f"\n\n# Lyric\n{lyrics}\n"
+
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
+
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
+
+ scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
if not scores:
return {}, 0.0, "❌ No conditions to evaluate"
diff --git a/acestep/core/scoring/lm_score_context_test.py b/acestep/core/scoring/lm_score_context_test.py
new file mode 100644
index 000000000..91dc4e20f
--- /dev/null
+++ b/acestep/core/scoring/lm_score_context_test.py
@@ -0,0 +1,218 @@
+"""Regression tests for the ``_load_scoring_model_context`` lifecycle.
+
+These protect the Autoscore unified-memory fix from issue #1081 by
+asserting three invariants that are easy to regress silently:
+
+1. ``_load_scoring_model_context`` performs exactly one load/offload
+ cycle per outermost entry, regardless of how deeply nested the
+ inner re-entries are.
+2. Nested re-entries from the same thread are no-ops and do not
+ trigger additional CPU↔accelerator migrations.
+3. On the MLX backend with ``offload_to_cpu=True`` the cached HF
+ scoring model is released on outermost exit so the ~8 GB duplicate
+ PyTorch copy does not remain resident between generations.
+"""
+
+import threading
+import unittest
+
+from acestep.core.scoring.lm_score import _load_scoring_model_context
+
+
+class _FakeModel:
+ """Recorder stand-in for an ``AutoModelForCausalLM``.
+
+ Captures every ``.to(device)`` invocation so tests can assert the
+ exact migration sequence produced by a scoring context.
+ """
+
+ def __init__(self):
+ """Initialise with an empty migration history."""
+ self.device_calls = []
+
+ def to(self, device):
+ """Record the target device and return ``self`` (chain-friendly)."""
+ self.device_calls.append(str(device))
+ return self
+
+
+class _FakeHandler:
+ """Handler stub exposing just enough surface for the scoring context.
+
+ Stands in for ``LLMInferenceHandler`` so the context-manager tests
+ avoid loading any real model weights. The handler tracks how many
+    times ``get_hf_model_for_scoring`` is queried, which lets the
+    nested-entry tests prove that inner re-entries do not re-fetch the model.
+ """
+
+ def __init__(self, backend: str, offload: bool = True):
+ """Configure backend / offload flags and install a fake model.
+
+ Args:
+ backend: One of ``"mlx"``, ``"vllm"``, ``"pt"``. Selects the
+ branch of ``_load_scoring_model_context`` under test.
+ offload: Whether ``offload_to_cpu`` should be reported as
+ enabled. Controls whether the context manager performs
+ any ``.to()`` migrations and whether the MLX release
+ path fires on outermost exit.
+ """
+ self.llm_backend = backend
+ self.llm_initialized = True
+ self.offload_to_cpu = offload
+ # Synthetic target device -- ``_FakeModel.to`` is a recorder so
+ # the string never has to correspond to a real torch device.
+ self.device = "cuda"
+ self._hf_model_for_scoring = _FakeModel()
+ self.get_calls = 0
+
+ def get_hf_model_for_scoring(self):
+ """Return the cached fake model and bump the query counter."""
+ self.get_calls += 1
+ return self._hf_model_for_scoring
+
+
+class LoadScoringModelContextTests(unittest.TestCase):
+ """Lifecycle-contract tests for ``_load_scoring_model_context``."""
+
+ def test_single_load_offload_per_outer_context_mlx(self):
+ """One outer entry should trigger exactly one load and one offload."""
+ handler = _FakeHandler("mlx")
+ model = handler._hf_model_for_scoring
+ with _load_scoring_model_context(handler):
+ pass
+ # Exactly one load (to accelerator) and one offload (to cpu).
+ self.assertEqual(
+ model.device_calls, ["cuda", "cpu"],
+ "expected exactly one load+offload cycle",
+ )
+
+ def test_nested_entries_are_noops_mlx(self):
+ """Nested re-entries must not move the model again."""
+ handler = _FakeHandler("mlx")
+ model = handler._hf_model_for_scoring
+ with _load_scoring_model_context(handler):
+ calls_after_outer_load = list(model.device_calls)
+ # Deep nesting (simulating _get_logits_and_target_for_scoring
+ # called many times per Autoscore pass).
+ with _load_scoring_model_context(handler):
+ with _load_scoring_model_context(handler):
+ # Inner contexts must not migrate the model again.
+ self.assertEqual(model.device_calls, calls_after_outer_load)
+ # After the outermost exit the handler has dropped the cached
+ # model, so use ``model`` (the captured reference) to verify the
+ # offload.
+ self.assertEqual(
+ model.device_calls, ["cuda", "cpu"],
+ "nested entries should not add extra migrations",
+ )
+
+ def test_mlx_outermost_exit_drops_cached_model(self):
+ """On MLX+offload, outer exit must clear ``_hf_model_for_scoring``."""
+ handler = _FakeHandler("mlx", offload=True)
+ # Use ``getattr(..., None)`` so the test is resilient to either
+ # cleanup style: setting the attribute to ``None`` or deleting it.
+ self.assertIsNotNone(getattr(handler, "_hf_model_for_scoring", None))
+ with _load_scoring_model_context(handler):
+ # Still cached while we're inside the context.
+ self.assertIsNotNone(getattr(handler, "_hf_model_for_scoring", None))
+ # Released after outermost exit so unified memory is returned to OS.
+ self.assertIsNone(getattr(handler, "_hf_model_for_scoring", None))
+
+ def test_vllm_outermost_exit_keeps_cached_model(self):
+ """vllm backend must NOT drop the cached HF model (CUDA is fine)."""
+ handler = _FakeHandler("vllm", offload=True)
+ cached = handler._hf_model_for_scoring
+ with _load_scoring_model_context(handler):
+ pass
+ # vllm keeps the cached HF scoring model between passes; the
+ # MLX-specific release path must not fire here.
+ self.assertIs(handler._hf_model_for_scoring, cached)
+
+ def test_mlx_no_offload_keeps_cached_model(self):
+ """Without offload_to_cpu the MLX release path must not fire."""
+ handler = _FakeHandler("mlx", offload=False)
+ cached = handler._hf_model_for_scoring
+ with _load_scoring_model_context(handler):
+ pass
+ # Not offloading means no load/offload transitions and no drop.
+ self.assertIs(handler._hf_model_for_scoring, cached)
+ self.assertEqual(cached.device_calls, [])
+
+ def test_get_hf_called_only_by_outermost_entry(self):
+ """Nested entries must not re-query ``get_hf_model_for_scoring``."""
+ handler = _FakeHandler("mlx")
+ with _load_scoring_model_context(handler):
+ with _load_scoring_model_context(handler):
+ with _load_scoring_model_context(handler):
+ pass
+ # Outer entry queries once; nested entries are pure no-ops.
+ self.assertEqual(handler.get_calls, 1)
+
+ def test_outermost_detection_is_thread_local(self):
+ """Each thread must track its own outermost-entry depth.
+
+ The reentrancy depth is kept in ``threading.local`` keyed by
+ handler id so that two Autoscore calls on the same handler do
+ not see each other's counter. This test runs two worker
+ threads that each enter nested contexts on a shared handler; if
+ depth tracking regressed to a process-global counter one thread
+ could mis-identify itself as "nested" and skip the outer model
+ fetch. Uses ``offload=False`` so the MLX release path does not
+ drop the cached model between threads and confuse the count.
+ """
+ handler = _FakeHandler("mlx", offload=False)
+ start = threading.Barrier(2)
+
+ def worker():
+ # Synchronise thread start so both worker threads race into
+ # the handler-level lock simultaneously.
+ start.wait()
+ with _load_scoring_model_context(handler):
+ with _load_scoring_model_context(handler):
+ pass
+
+ t1 = threading.Thread(target=worker)
+ t2 = threading.Thread(target=worker)
+ t1.start()
+ t2.start()
+ t1.join()
+ t2.join()
+
+ # Both threads must have counted as outermost, so
+ # ``get_hf_model_for_scoring`` must have fired twice. If depth
+        # tracking leaked across threads, the second thread could see
+        # ``depth > 0`` at outer entry, fall into the nested no-op
+        # branch, and skip ``get_hf_model_for_scoring`` entirely.
+ self.assertEqual(handler.get_calls, 2)
+
+ def test_mlx_cleanup_runs_when_offload_raises(self):
+ """MLX cached-model release must run even if model.to("cpu") fails.
+
+ If ``model.to("cpu")`` raises inside the outermost exit, the
+ ``_hf_model_for_scoring = None`` cleanup must still execute so that
+ the ~8 GB HF scoring model is not leaked on the exception path.
+ """
+
+ class _FailingOffloadModel(_FakeModel):
+ def to(self, device):
+ super().to(device)
+ if device == "cpu":
+ raise RuntimeError("simulated offload failure")
+ return self
+
+ handler = _FakeHandler("mlx", offload=True)
+ handler._hf_model_for_scoring = _FailingOffloadModel()
+
+ with self.assertRaises(RuntimeError):
+ with _load_scoring_model_context(handler):
+ pass
+
+ self.assertIsNone(
+ getattr(handler, "_hf_model_for_scoring", None),
+ "MLX cleanup must still release the cached model after a "
+ "failed .to('cpu') call so unified memory is returned to the OS",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/acestep/core/scoring/scoring_test.py b/acestep/core/scoring/scoring_test.py
index 6abab9b5f..6345f5a94 100644
--- a/acestep/core/scoring/scoring_test.py
+++ b/acestep/core/scoring/scoring_test.py
@@ -138,5 +138,10 @@ def test_metadata_aggregation(self):
self.assertAlmostEqual(total, 0.75, places=2)
+# Regression tests for ``_load_scoring_model_context`` (issue #1081) live in
+# ``acestep/core/scoring/lm_score_context_test.py`` to keep this file under
+# the project's 200 LOC hard cap.
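+#
+# They can be run directly with, for example:
+#   python -m unittest acestep.core.scoring.lm_score_context_test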
+
+
if __name__ == "__main__":
unittest.main()