3 changes: 3 additions & 0 deletions vllm_mlx/config/server_config.py
@@ -77,6 +77,9 @@ class ServerConfig:
# --- Multi-model ---
model_registry: Any = None

# --- On-demand model loading ---
enable_on_demand_loading: bool = False


# Singleton instance
_config = ServerConfig()
2 changes: 2 additions & 0 deletions vllm_mlx/routes/chat.py
@@ -55,6 +55,7 @@
_validate_model_name,
_validate_tool_call_params,
_wait_with_disconnect,
ensure_model_loaded,
get_engine,
get_usage,
)
@@ -146,6 +147,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
}
```
"""
await ensure_model_loaded(request.model)
_validate_model_name(request.model)
engine = get_engine(request.model)

2 changes: 2 additions & 0 deletions vllm_mlx/routes/completions.py
@@ -26,6 +26,7 @@
_resolve_top_p,
_validate_model_name,
_wait_with_disconnect,
ensure_model_loaded,
get_engine,
get_usage,
)
@@ -41,6 +42,7 @@
)
async def create_completion(request: CompletionRequest, raw_request: Request):
"""Create a text completion."""
await ensure_model_loaded(request.model)
_validate_model_name(request.model)
engine = get_engine(request.model)

58 changes: 51 additions & 7 deletions vllm_mlx/routes/models.py
@@ -1,6 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
"""Model listing endpoints."""

import os
from pathlib import Path

from fastapi import APIRouter, Depends, HTTPException

from ..api.models import ModelInfo, ModelsResponse
@@ -10,22 +13,63 @@
router = APIRouter()


def _cached_mlx_models() -> list[str]:
"""Scan the local HuggingFace cache and return MLX model IDs."""
hf_cache = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")) / "hub"
if not hf_cache.exists():
return []

found = []
for entry in hf_cache.iterdir():
if not entry.name.startswith("models--"):
continue
# models--mlx-community--Qwen3-0.6B-8bit → mlx-community/Qwen3-0.6B-8bit
model_id = entry.name[len("models--"):].replace("--", "/", 1)
# Only include models that have MLX weight files
snapshot_dir = entry / "snapshots"
if not snapshot_dir.exists():
continue
has_mlx = any(
f.suffix in (".safetensors", ".npz")
for snap in snapshot_dir.iterdir()
if snap.is_dir()
for f in snap.iterdir()
)
# Skip known non-chat model types
_id_lower = model_id.lower()
if any(x in _id_lower for x in ("tts", "whisper", "asr", "sentence-transformer", "embed")):
continue
if has_mlx:
found.append(model_id)
return sorted(found)
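
The scan above relies on the standard HuggingFace hub cache layout, in which the `/` of a repo id is encoded as `--` in the directory name. A minimal sketch of that mapping, mirroring the transform in the code (the entry name is the one used in the comment above):

```python
# Illustration only: how a hub cache directory name maps back to a repo id.
entry_name = "models--mlx-community--Qwen3-0.6B-8bit"
model_id = entry_name[len("models--"):].replace("--", "/", 1)
assert model_id == "mlx-community/Qwen3-0.6B-8bit"
```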


@router.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models() -> ModelsResponse:
"""List available models (supports multi-model)."""
cfg = get_config()

seen: set[str] = set()
models = []

def _add(model_id: str) -> None:
if model_id and model_id not in seen:
seen.add(model_id)
models.append(ModelInfo(id=model_id))

if cfg.model_registry:
for entry in cfg.model_registry.list_entries():
models.append(ModelInfo(id=entry.model_name))
_add(entry.model_name)
for alias in sorted(entry.aliases):
if alias != entry.model_name:
models.append(ModelInfo(id=alias))
_add(alias)
elif cfg.model_name:
models.append(ModelInfo(id=cfg.model_name))
if cfg.model_alias and cfg.model_alias != cfg.model_name:
models.append(ModelInfo(id=cfg.model_alias))
_add(cfg.model_name)
if cfg.model_alias:
_add(cfg.model_alias)

if cfg.enable_on_demand_loading:
for model_id in _cached_mlx_models():
_add(model_id)

return ModelsResponse(data=models)


104 changes: 103 additions & 1 deletion vllm_mlx/server.py
@@ -38,6 +38,7 @@
"""

import argparse
import asyncio
import gc
import logging
import os
@@ -210,6 +211,11 @@ def configure_logging(log_level: str) -> str:
_pin_system_prompt: bool = False # Auto-pin system prompt prefix cache blocks
_pinned_system_prompt_hash: str | None = None # Hash of pinned system prompt

# On-demand model loading (--enable-on-demand-loading)
_enable_on_demand_loading: bool = False
_loading_model: str | None = None # name of model currently being swapped in
_swap_lock: asyncio.Lock | None = None # lazy-initialized; must not be created before event loop


from .runtime.cache import ( # noqa: E402
get_cache_dir as _get_cache_dir, # noqa: F401
@@ -606,6 +612,78 @@ def factory(tools=None):
_sync_config()


def _get_swap_lock() -> asyncio.Lock:
"""Return the module-level swap lock, creating it lazily on first call.

An asyncio.Lock created at module import time (before any event loop exists)
can end up bound to the wrong loop on Python versions that capture a loop at
construction, so the lock is created lazily instead. swap_to_model is always
called from async context, so a running loop is guaranteed here.
"""
global _swap_lock
if _swap_lock is None:
_swap_lock = asyncio.Lock()
return _swap_lock


def get_loading_model() -> str | None:
"""Return the model name currently being swapped in, or None."""
return _loading_model


async def swap_to_model(model_name: str) -> None:
"""Hot-swap to model_name on demand, handling single-model and registry modes.

Single-model mode (registry has ≤1 entry): stops the current engine first to
free GPU memory — running two large models simultaneously is not viable on
Apple Silicon.

Registry mode (≥2 entries): adds the new engine alongside existing ones without
stopping any; the operator explicitly started multiple models.

The asyncio.Lock serializes concurrent swap requests: the second acquirer
re-checks _is_model_loaded() inside the lock and returns immediately if the
first acquirer already finished loading.
"""
global _engine, _model_name, _model_path, _model_alias, _loading_model

async with _get_swap_lock():
from .service.helpers import _is_model_loaded

if _is_model_loaded(model_name):
return

registry_mode = len(_model_registry) > 1
logger.info(
f"[swap_to_model] Loading '{model_name}' "
f"({'registry' if registry_mode else 'single-model'} mode)"
)
_loading_model = model_name
try:
if not registry_mode and _engine is not None:
old_engine, old_name = _engine, _model_name
_model_registry.remove(old_name)
logger.info(f"[swap_to_model] Stopping '{old_name}'")
await old_engine.stop()

# load_model() is synchronous: constructs BatchedEngine, updates
# globals (_engine, _model_name, …), and calls _sync_config().
# It does NOT start the engine — that is always async.
load_model(model_name)

await _engine.start()

# Best-effort warmup to compile Metal shaders; non-fatal on failure.
try:
if not getattr(_engine, "_is_mllm", False):
_engine.generate_warmup()
except Exception as e:
logger.warning(f"[swap_to_model] Warmup failed (non-fatal): {e}")

logger.info(f"[swap_to_model] '{model_name}' ready")
finally:
_loading_model = None # clear even on failure so swap is retryable
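
A minimal, self-contained sketch of the serialization behavior described in the docstring, using stand-in names (no real engine or registry): two concurrent callers request the same model, the second waits on the lock, re-checks, and returns without doing a second load.

```python
import asyncio

_lock: asyncio.Lock | None = None
_loaded: set[str] = set()

def _get_lock() -> asyncio.Lock:
    # Created lazily so the lock is bound to the running event loop.
    global _lock
    if _lock is None:
        _lock = asyncio.Lock()
    return _lock

async def swap(name: str) -> None:
    async with _get_lock():
        if name in _loaded:       # second acquirer sees the first one's result
            return
        await asyncio.sleep(0.1)  # stand-in for stop() / load_model() / start()
        _loaded.add(name)

async def main() -> None:
    await asyncio.gather(swap("m1"), swap("m1"))  # only one load actually runs
    print(_loaded)                                # {'m1'}

asyncio.run(main())
```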


def _sync_config() -> None:
"""Copy server globals into the ServerConfig singleton.

@@ -647,6 +725,12 @@ def _sync_config() -> None:
cfg.pinned_system_prompt_hash = _pinned_system_prompt_hash
cfg.mcp_executor = _mcp_executor
cfg.model_registry = _model_registry
# NOTE: enable_on_demand_loading is NOT synced here. Running as `-m vllm_mlx.server`
# registers the module as __main__, but `from ..server import swap_to_model` in
# helpers.py re-imports vllm_mlx.server as a separate module instance with
# _enable_on_demand_loading = False (the default). If we sync this field here,
# the second instance's _sync_config() calls would stomp the value set by main().
# Instead, main() sets cfg.enable_on_demand_loading directly on the singleton.
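
The gotcha this note describes is a general property of `python -m`; a hedged sketch of the two module copies involved (purely illustrative, not part of the diff):

```python
import sys

# Under `python -m vllm_mlx.server`, the entry file runs as the module "__main__".
# When helpers.py later does `from ..server import swap_to_model`, Python imports
# vllm_mlx.server as a second, distinct module object whose globals start at their
# defaults, so a flag set on the __main__ copy is invisible to the imported copy.
running_copy = sys.modules["__main__"]               # the copy main() mutates
imported_copy = sys.modules.get("vllm_mlx.server")   # the copy helpers.py sees
# running_copy is imported_copy  ->  False (two separate sets of globals)
```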


# Re-export for backward compatibility (test_streaming_pipeline_integration)
@@ -916,15 +1000,29 @@ def main():
default=None,
help="API key for cloud model (overrides environment variable).",
)
parser.add_argument(
"--enable-on-demand-loading",
action="store_true",
default=False,
help=(
"Auto-load unrecognised models from HuggingFace on first request. "
"Off by default — enabling allows any caller to trigger downloads. "
"Use with --api-key in production."
),
)

args = parser.parse_args()
uvicorn_log_level = configure_logging(args.log_level)

# Set global configuration
global _api_key, _default_timeout, _rate_limiter
global _api_key, _default_timeout, _rate_limiter, _enable_on_demand_loading
global _default_temperature, _default_top_p
_api_key = args.api_key
_default_timeout = args.timeout
_enable_on_demand_loading = args.enable_on_demand_loading
# Write directly to the singleton so swap_to_model (which runs in a re-imported
# vllm_mlx.server module instance) cannot overwrite this via _sync_config().
get_config().enable_on_demand_loading = args.enable_on_demand_loading
if args.default_temperature is not None:
_default_temperature = args.default_temperature
if args.default_top_p is not None:
@@ -950,6 +1048,10 @@
else:
logger.warning(" Rate limiting: DISABLED - Use --rate-limit to enable")
logger.info(f" Request timeout: {args.timeout}s")
if _enable_on_demand_loading:
logger.warning(" On-demand loading: ENABLED (--enable-on-demand-loading)")
else:
logger.info(" On-demand loading: DISABLED (default)")
logger.info("=" * 60)

# Set MCP config for lifespan
45 changes: 45 additions & 0 deletions vllm_mlx/service/helpers.py
@@ -242,6 +242,51 @@ def _validate_model_name(request_model: str) -> None:
)


def _is_model_loaded(model_name: str) -> bool:
"""Return True if model_name refers to the currently loaded model(s)."""
cfg = get_config()
if cfg.model_registry:
return model_name in cfg.model_registry or model_name == "default"
if not cfg.model_name:
return False
accepted = {cfg.model_name}
if cfg.model_alias:
accepted.add(cfg.model_alias)
if cfg.model_path:
accepted.add(cfg.model_path)
return model_name in accepted


async def ensure_model_loaded(model_name: str | None) -> None:
"""Auto-load a model on demand, like ollama.

If the requested model isn't currently loaded, swaps to it (downloading
from HuggingFace if needed). No-ops when model_name is empty/default or
the model is already loaded.

Requires --enable-on-demand-loading; off by default to prevent unauthenticated
HuggingFace downloads.
"""
if not get_config().enable_on_demand_loading:
return
if not model_name or model_name == "default":
return
if _is_model_loaded(model_name):
return

from ..server import get_loading_model, swap_to_model

in_flight = get_loading_model()
if in_flight and in_flight != model_name:
raise HTTPException(
status_code=503,
detail=f"Model swap in progress: '{in_flight}'. Retry after swap completes.",
headers={"Retry-After": "30"},
)

await swap_to_model(model_name)
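
From a client's perspective, the result is Ollama-style behavior: name any cached or downloadable MLX model and the server loads it on first use, answering 503 with a Retry-After header while a different swap is in flight. A hedged client sketch (base URL, port, and API key are assumptions; the /v1/chat/completions path is inferred from the OpenAI-compatible route naming, not shown verbatim in this diff):

```python
# Assumes the server was started with --enable-on-demand-loading (and ideally --api-key).
import time

import requests  # any HTTP client works

BASE = "http://localhost:8000"                    # assumed host/port
HEADERS = {"Authorization": "Bearer YOUR_KEY"}    # only needed if --api-key is set

payload = {
    "model": "mlx-community/Qwen3-0.6B-8bit",     # not loaded yet: triggers a swap
    "messages": [{"role": "user", "content": "hello"}],
}

while True:
    resp = requests.post(f"{BASE}/v1/chat/completions", json=payload, headers=HEADERS)
    if resp.status_code == 503:
        # A different model swap is in progress; honor the server's Retry-After hint.
        time.sleep(int(resp.headers.get("Retry-After", "30")))
        continue
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])
    break
```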


# ── Tool call parsing ──────────────────────────────────────────────

