242 changes: 189 additions & 53 deletions acestep/core/scoring/lm_score.py
@@ -6,8 +6,10 @@
composite reward score.
"""
import contextlib
import gc
import math
import re
import threading

import torch
import torch.nn.functional as F
@@ -67,6 +69,66 @@ def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
return 1.0 / (1.0 + math.exp(-pmi / scale))


def _empty_accelerator_cache(backend: str) -> None:
"""Release cached accelerator memory scoped to the scoring backend.

Keeps cache eviction confined to the runtime actually used by the
scoring pass so non-target runtimes (e.g. CUDA when running MLX) are
not forced through an allocator reset. ``mlx`` -> MPS, ``vllm`` ->
CUDA / XPU, ``pt`` -> delegated to ``_load_model_context`` and never
reaches this helper.
"""
if backend == "mlx":
if (
hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
and hasattr(torch, "mps")
and hasattr(torch.mps, "empty_cache")
):
torch.mps.empty_cache()
return

if backend == "vllm":
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.empty_cache()


# Per-thread reentrancy depth state for _load_scoring_model_context. Keeping
# this on ``threading.local`` (rather than on the handler) prevents concurrent
# Autoscore calls on the same handler from seeing each other's depth counter
# and clearing the cached scoring model out from under a peer thread. The
# scoring forward path is strictly synchronous (``torch.no_grad`` + a plain
# ``model(...)`` call) so thread-local storage is sufficient; there are no
# async suspension points where a contextvar would be strictly required.
_scoring_tls = threading.local()

# Guards lazy creation of the per-handler RLock so two threads racing on the
# first scoring call don't each install a different lock object.
_scoring_lock_creation_lock = threading.Lock()


def _get_scoring_handler_lock(llm_handler) -> threading.RLock:
"""Return (and lazily create) a per-handler reentrant lock."""
lock = getattr(llm_handler, "_scoring_lock", None)
if lock is None:
with _scoring_lock_creation_lock:
lock = getattr(llm_handler, "_scoring_lock", None)
if lock is None:
lock = threading.RLock()
llm_handler._scoring_lock = lock
return lock


def _scoring_depth_map() -> Dict[int, int]:
depths = getattr(_scoring_tls, "depths", None)
if depths is None:
depths = {}
_scoring_tls.depths = depths
return depths


@contextlib.contextmanager
def _load_scoring_model_context(llm_handler):
"""
@@ -79,38 +141,105 @@ def _load_scoring_model_context(llm_handler):
that would otherwise stay on GPU permanently -- here we move it to GPU
only for the duration of the scoring forward pass and move it back to
CPU when done, freeing VRAM for DiT / VAE.

The context is **reentrant per thread and serialised across threads**.
A single Autoscore pass performs many forward passes per sample
(per-metadata recall, caption PMI, lyrics PMI -- each with a
conditional and an unconditional prompt). Without reentrancy every
one of those forward passes would migrate a multi-GB model
CPU↔accelerator, which on Apple Silicon unified memory (MLX)
accumulates fragmented MPS allocations until the system runs out of
memory (issue #1081). By hoisting a single outer context around the
whole scoring pass we reduce that to exactly one migration per sample.

Reentrancy state is kept in thread-local storage keyed by handler id so
two concurrent Autoscore calls on the same handler do not see each
other's depth counter. The outermost entry from each thread acquires
a handler-level ``RLock`` so concurrent calls serialise the shared
load/offload transition (and, on MLX, the cached-model drop) instead
of racing on ``_hf_model_for_scoring``.

On MLX with ``offload_to_cpu`` enabled we additionally drop the cached
HF scoring model on outermost exit so the ~8 GB duplicate PyTorch copy
of the LM does not remain resident between generations. The next
scoring call re-materialises it via ``AutoModelForCausalLM.from_pretrained``
-- weights are cached by the HF hub so this is a fast rematerialisation,
not a redownload.
"""
backend = getattr(llm_handler, "llm_backend", "pt")

if backend == "pt":
# pt backend: _load_model_context already handles GPU <-> CPU and
# has its own device-based reentrancy guard (llm_inference.py).
with llm_handler._load_model_context():
yield
return

# vllm / mlx: manage the cached HF model ourselves. Use a per-thread
# reentrancy depth counter so nested entries from the same thread
# become no-ops. Only the outermost entry from each thread moves the
# model to the accelerator and only the outermost exit moves it back
# (and, for MLX, drops it).
depths = _scoring_depth_map()
key = id(llm_handler)
depth = depths.get(key, 0)

if depth > 0:
depths[key] = depth + 1
try:
yield
finally:
depths[key] -= 1
if depths[key] <= 0:
depths.pop(key, None)
return

# Outermost entry for this thread: serialise with other threads on the
# same handler so the shared ``_hf_model_for_scoring`` load/offload
# transition is atomic. The RLock is safe to re-acquire from the same
    # thread in the (unlikely) event that a nested path ends up here.
lock = _get_scoring_handler_lock(llm_handler)
with lock:
model = llm_handler.get_hf_model_for_scoring()
if model is None:
yield
return

offload = getattr(llm_handler, "offload_to_cpu", False)
device = llm_handler.device if hasattr(llm_handler, "device") else "cpu"

logger.info(f"[scoring] Loading HF scoring model to {device}")
model.to(device)

depths[key] = 1
try:
yield
finally:
depths.pop(key, None)
if offload and hasattr(model, "to"):
logger.info("[scoring] Offloading HF scoring model to CPU")
model.to("cpu")
_empty_accelerator_cache(backend)

# On MLX with offload enabled, the HF scoring model is a
# *separate* ~8 GB PyTorch copy of the LM (the MLX model itself
# cannot be used for torch teacher-forcing scoring). Keeping
# it resident on CPU between scoring passes doubles the LM
# footprint in unified memory and -- combined with repeated
# .to("mps") / .to("cpu") migrations -- pushes 32 GB Macs past
# their limit (issue #1081). Drop the cached copy so unified
# memory is returned to the OS; it will be re-loaded from the
# HF cache on the next scoring call.
if backend == "mlx" and offload:
llm_handler._hf_model_for_scoring = None
del model
gc.collect()
_empty_accelerator_cache("mlx")
logger.info(
"[scoring] Released cached HF scoring model on MLX "
"backend (will be reloaded on next Autoscore call)"
)


def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
@@ -412,40 +541,47 @@ def calculate_pmi_score_per_condition(
formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
try:
# Batch *all* scoring forward passes under a single load/offload
# cycle. The inner ``_get_logits_and_target_for_scoring`` calls
# still enter ``_load_scoring_model_context``, but the reentrancy
# guard turns those nested entries into no-ops -- so the cached HF
# scoring model is moved CPU↔accelerator exactly once per Autoscore
# pass instead of once per condition (issue #1081).
with _load_scoring_model_context(llm_handler):
# 1. Calculate Recall for Metadata Fields
if metadata and isinstance(metadata, dict):
scores = {}
# Define which fields use which metric
metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
metadata_pmi_keys = ['caption']
for key in metadata_recall_keys:
if key in metadata and metadata[key] is not None:
recall_metadata = {key: metadata[key]}
field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
scores.update(field_scores)

# 2. Calculate PMI for Caption
for key in metadata_pmi_keys:
if key in metadata and metadata[key] is not None:
cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
target_text = f"<think>\n{cot_yaml}\n</think>\n"

log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)

pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
scores[key] = pmi_normalized

# 3. Calculate PMI for Lyrics
if lyrics:
target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"

log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)

prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)

scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)

Comment on lines +550 to 591
⚠️ Potential issue | 🟡 Minor

Add a focused regression test for the new lifecycle contract.

This now depends on three invariants that are easy to regress silently: one outer load/offload per scoring pass, nested _load_scoring_model_context() entries staying no-op, and MLX outermost exit clearing _hf_model_for_scoring. I’d add a fake-handler test around this block before merging.

Based on learnings: "AI-Agent Workflow: Add/update focused tests."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/core/scoring/lm_score.py` around lines 478 - 519, Add a focused
regression test that constructs a fake llm_handler and exercises the Autoscore
scoring path that wraps the block using _load_scoring_model_context, asserting
three invariants: (1) the model load/offload cycle is invoked exactly once for
the outer scoring pass (mock the load/offload side-effects on the fake handler),
(2) nested re-entries into _load_scoring_model_context (triggered by calls like
_get_logits_and_target_for_scoring or _calculate_log_prob) are no-ops (verify
they do not increment load/unload counters), and (3) after the outer context
exits the module-level _hf_model_for_scoring is cleared/None. Use the same call
sites as in the diff (invoke the scoring code path via llm_handler and calls to
_calculate_log_prob/_calculate_metadata_recall) so the test exercises the exact
lifecycle contract.
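
A minimal sketch of such a test, assuming the context manager behaves as in this diff and using a fake handler that exposes only the attributes `_load_scoring_model_context` reads (`llm_backend`, `offload_to_cpu`, `device`, `get_hf_model_for_scoring`, `_hf_model_for_scoring`). The `FakeModel` / `FakeHandler` names and the test name are illustrative, not part of the PR:

```python
from acestep.core.scoring.lm_score import _load_scoring_model_context


class FakeModel:
    """Stand-in for the HF scoring model; records .to() migrations."""

    def __init__(self):
        self.moves = []

    def to(self, device):
        self.moves.append(device)
        return self


class FakeHandler:
    """Stand-in for the LLM handler with the attributes lm_score reads."""

    llm_backend = "mlx"
    offload_to_cpu = True
    device = "mps"

    def __init__(self):
        self._hf_model_for_scoring = FakeModel()

    def get_hf_model_for_scoring(self):
        return self._hf_model_for_scoring


def test_single_load_offload_cycle_and_mlx_release():
    handler = FakeHandler()
    model = handler._hf_model_for_scoring

    with _load_scoring_model_context(handler):
        # Invariant 2: nested entries (as triggered by _calculate_log_prob /
        # _calculate_metadata_recall) are no-ops -- no extra migrations.
        with _load_scoring_model_context(handler):
            pass
        assert model.moves == ["mps"]

    # Invariant 1: exactly one load/offload cycle for the outer scoring pass.
    assert model.moves == ["mps", "cpu"]
    # Invariant 3: on MLX with offload enabled, the cached copy is dropped.
    assert handler._hf_model_for_scoring is None
```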

if not scores:
return {}, 0.0, "❌ No conditions to evaluate"