diff --git a/vllm_mlx/config/server_config.py b/vllm_mlx/config/server_config.py index da1d3a08..32a0c478 100644 --- a/vllm_mlx/config/server_config.py +++ b/vllm_mlx/config/server_config.py @@ -77,6 +77,9 @@ class ServerConfig: # --- Multi-model --- model_registry: Any = None + # --- On-demand model loading --- + enable_on_demand_loading: bool = False + # Singleton instance _config = ServerConfig() diff --git a/vllm_mlx/routes/chat.py b/vllm_mlx/routes/chat.py index 7eaf919e..ddf0079e 100644 --- a/vllm_mlx/routes/chat.py +++ b/vllm_mlx/routes/chat.py @@ -55,6 +55,7 @@ _validate_model_name, _validate_tool_call_params, _wait_with_disconnect, + ensure_model_loaded, get_engine, get_usage, ) @@ -146,6 +147,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re } ``` """ + await ensure_model_loaded(request.model) _validate_model_name(request.model) engine = get_engine(request.model) diff --git a/vllm_mlx/routes/completions.py b/vllm_mlx/routes/completions.py index f540a0f4..e9db4e3c 100644 --- a/vllm_mlx/routes/completions.py +++ b/vllm_mlx/routes/completions.py @@ -26,6 +26,7 @@ _resolve_top_p, _validate_model_name, _wait_with_disconnect, + ensure_model_loaded, get_engine, get_usage, ) @@ -41,6 +42,7 @@ ) async def create_completion(request: CompletionRequest, raw_request: Request): """Create a text completion.""" + await ensure_model_loaded(request.model) _validate_model_name(request.model) engine = get_engine(request.model) diff --git a/vllm_mlx/routes/models.py b/vllm_mlx/routes/models.py index ebac8c23..7c6a281c 100644 --- a/vllm_mlx/routes/models.py +++ b/vllm_mlx/routes/models.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """Model listing endpoints.""" +import os +from pathlib import Path + from fastapi import APIRouter, Depends, HTTPException from ..api.models import ModelInfo, ModelsResponse @@ -10,22 +13,63 @@ router = APIRouter() +def _cached_mlx_models() -> list[str]: + """Scan the local HuggingFace cache and return MLX model IDs.""" + hf_cache = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")) / "hub" + if not hf_cache.exists(): + return [] + + found = [] + for entry in hf_cache.iterdir(): + if not entry.name.startswith("models--"): + continue + # models--mlx-community--Qwen3-0.6B-8bit → mlx-community/Qwen3-0.6B-8bit + model_id = entry.name[len("models--"):].replace("--", "/", 1) + # Only include models that have MLX weight files + snapshot_dir = entry / "snapshots" + if not snapshot_dir.exists(): + continue + has_mlx = any( + f.suffix in (".safetensors", ".npz") + for snap in snapshot_dir.iterdir() + if snap.is_dir() + for f in snap.iterdir() + ) + # Skip known non-chat model types + _id_lower = model_id.lower() + if any(x in _id_lower for x in ("tts", "whisper", "asr", "sentence-transformer", "embed")): + continue + if has_mlx: + found.append(model_id) + return sorted(found) + + @router.get("/v1/models", dependencies=[Depends(verify_api_key)]) async def list_models() -> ModelsResponse: """List available models (supports multi-model).""" cfg = get_config() - + seen: set[str] = set() models = [] + + def _add(model_id: str) -> None: + if model_id and model_id not in seen: + seen.add(model_id) + models.append(ModelInfo(id=model_id)) + if cfg.model_registry: for entry in cfg.model_registry.list_entries(): - models.append(ModelInfo(id=entry.model_name)) + _add(entry.model_name) for alias in sorted(entry.aliases): - if alias != entry.model_name: - models.append(ModelInfo(id=alias)) + _add(alias) elif cfg.model_name: - 
models.append(ModelInfo(id=cfg.model_name))
-        if cfg.model_alias and cfg.model_alias != cfg.model_name:
-            models.append(ModelInfo(id=cfg.model_alias))
+        _add(cfg.model_name)
+        if cfg.model_alias:
+            _add(cfg.model_alias)
+
+    if cfg.enable_on_demand_loading:
+        for model_id in _cached_mlx_models():
+            _add(model_id)
+
     return ModelsResponse(data=models)
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index f81ad171..d4e50e92 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -38,6 +38,7 @@
 """
 import argparse
+import asyncio
 import gc
 import logging
 import os
@@ -210,6 +211,11 @@ def configure_logging(log_level: str) -> str:
 _pin_system_prompt: bool = False  # Auto-pin system prompt prefix cache blocks
 _pinned_system_prompt_hash: str | None = None  # Hash of pinned system prompt
 
+# On-demand model loading (--enable-on-demand-loading)
+_enable_on_demand_loading: bool = False
+_loading_model: str | None = None  # name of model currently being swapped in
+_swap_lock: asyncio.Lock | None = None  # lazy-initialized in the running event loop; see _get_swap_lock
+
 
 from .runtime.cache import (  # noqa: E402
     get_cache_dir as _get_cache_dir,  # noqa: F401
@@ -606,6 +612,78 @@ def factory(tools=None):
     _sync_config()
 
 
+def _get_swap_lock() -> asyncio.Lock:
+    """Return the module-level swap lock, creating it lazily on first call.
+
+    On Python 3.9 and earlier, asyncio.Lock() binds whatever loop
+    asyncio.get_event_loop() returns at construction time, which at import
+    time may not be the loop the server ends up running on; 3.10+ binds
+    lazily at first use. Deferring creation until swap_to_model runs
+    (always from async context) guarantees the lock belongs to the running
+    server loop on every supported version.
+    """
+    global _swap_lock
+    if _swap_lock is None:
+        _swap_lock = asyncio.Lock()
+    return _swap_lock
+
+
+def get_loading_model() -> str | None:
+    """Return the model name currently being swapped in, or None."""
+    return _loading_model
+
+
+async def swap_to_model(model_name: str) -> None:
+    """Hot-swap to model_name on demand, handling single-model and registry modes.
+
+    Single-model mode (registry has ≤1 entry): stops the current engine first to
+    free GPU memory, since running two large models simultaneously is not viable
+    on Apple Silicon.
+
+    Registry mode (≥2 entries): adds the new engine alongside existing ones without
+    stopping any; the operator explicitly started multiple models.
+
+    The asyncio.Lock serializes concurrent swap requests: the second acquirer
+    re-checks _is_model_loaded() inside the lock and returns immediately if the
+    first acquirer already finished loading.
+    """
+    global _engine, _model_name, _model_path, _model_alias, _loading_model
+
+    async with _get_swap_lock():
+        from .service.helpers import _is_model_loaded
+
+        if _is_model_loaded(model_name):
+            return
+
+        registry_mode = len(_model_registry) > 1
+        logger.info(
+            f"[swap_to_model] Loading '{model_name}' "
+            f"({'registry' if registry_mode else 'single-model'} mode)"
+        )
+        _loading_model = model_name
+        try:
+            if not registry_mode and _engine is not None:
+                old_engine, old_name = _engine, _model_name
+                _model_registry.remove(old_name)
+                logger.info(f"[swap_to_model] Stopping '{old_name}'")
+                await old_engine.stop()
+
+            # load_model() is synchronous: constructs BatchedEngine, updates
+            # globals (_engine, _model_name, …), and calls _sync_config().
+            # It does NOT start the engine; starting is always async.
+            load_model(model_name)
+
+            await _engine.start()
+
+            # Best-effort warmup to compile Metal shaders; non-fatal on failure.
+ try: + if not getattr(_engine, "_is_mllm", False): + _engine.generate_warmup() + except Exception as e: + logger.warning(f"[swap_to_model] Warmup failed (non-fatal): {e}") + + logger.info(f"[swap_to_model] '{model_name}' ready") + finally: + _loading_model = None # clear even on failure so swap is retryable + + def _sync_config() -> None: """Copy server globals into the ServerConfig singleton. @@ -647,6 +725,12 @@ def _sync_config() -> None: cfg.pinned_system_prompt_hash = _pinned_system_prompt_hash cfg.mcp_executor = _mcp_executor cfg.model_registry = _model_registry + # NOTE: enable_on_demand_loading is NOT synced here. Running as `-m vllm_mlx.server` + # registers the module as __main__, but `from ..server import swap_to_model` in + # helpers.py re-imports vllm_mlx.server as a separate module instance with + # _enable_on_demand_loading = False (the default). If we sync this field here, + # the second instance's _sync_config() calls would stomp the value set by main(). + # Instead, main() sets cfg.enable_on_demand_loading directly on the singleton. # Re-export for backward compatibility (test_streaming_pipeline_integration) @@ -916,15 +1000,29 @@ def main(): default=None, help="API key for cloud model (overrides environment variable).", ) + parser.add_argument( + "--enable-on-demand-loading", + action="store_true", + default=False, + help=( + "Auto-load unrecognised models from HuggingFace on first request. " + "Off by default — enabling allows any caller to trigger downloads. " + "Use with --api-key in production." + ), + ) args = parser.parse_args() uvicorn_log_level = configure_logging(args.log_level) # Set global configuration - global _api_key, _default_timeout, _rate_limiter + global _api_key, _default_timeout, _rate_limiter, _enable_on_demand_loading global _default_temperature, _default_top_p _api_key = args.api_key _default_timeout = args.timeout + _enable_on_demand_loading = args.enable_on_demand_loading + # Write directly to the singleton so swap_to_model (which runs in a re-imported + # vllm_mlx.server module instance) cannot overwrite this via _sync_config(). + get_config().enable_on_demand_loading = args.enable_on_demand_loading if args.default_temperature is not None: _default_temperature = args.default_temperature if args.default_top_p is not None: @@ -950,6 +1048,10 @@ def main(): else: logger.warning(" Rate limiting: DISABLED - Use --rate-limit to enable") logger.info(f" Request timeout: {args.timeout}s") + if _enable_on_demand_loading: + logger.warning(" On-demand loading: ENABLED (--enable-on-demand-loading)") + else: + logger.info(" On-demand loading: DISABLED (default)") logger.info("=" * 60) # Set MCP config for lifespan diff --git a/vllm_mlx/service/helpers.py b/vllm_mlx/service/helpers.py index 08ffaabc..56c8faa1 100644 --- a/vllm_mlx/service/helpers.py +++ b/vllm_mlx/service/helpers.py @@ -242,6 +242,51 @@ def _validate_model_name(request_model: str) -> None: ) +def _is_model_loaded(model_name: str) -> bool: + """Return True if model_name refers to the currently loaded model(s).""" + cfg = get_config() + if cfg.model_registry: + return model_name in cfg.model_registry or model_name == "default" + if not cfg.model_name: + return False + accepted = {cfg.model_name} + if cfg.model_alias: + accepted.add(cfg.model_alias) + if cfg.model_path: + accepted.add(cfg.model_path) + return model_name in accepted + + +async def ensure_model_loaded(model_name: str | None) -> None: + """Auto-load a model on demand, like ollama. 
+
+    If the requested model isn't currently loaded, swaps to it (downloading
+    from HuggingFace if needed). No-ops when model_name is empty/default or
+    the model is already loaded.
+
+    Requires --enable-on-demand-loading; off by default so that unauthenticated
+    callers cannot trigger HuggingFace downloads.
+    """
+    if not get_config().enable_on_demand_loading:
+        return
+    if not model_name or model_name == "default":
+        return
+    if _is_model_loaded(model_name):
+        return
+
+    from ..server import get_loading_model, swap_to_model
+
+    in_flight = get_loading_model()
+    if in_flight and in_flight != model_name:
+        raise HTTPException(
+            status_code=503,
+            detail=f"Model swap in progress: '{in_flight}'. Retry after swap completes.",
+            headers={"Retry-After": "30"},
+        )
+
+    await swap_to_model(model_name)
+
+
 # ── Tool call parsing ──────────────────────────────────────────────
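
Usage sketch (illustrative, not part of the diff): a minimal client showing the
on-demand flow end to end, assuming the server was started with
--enable-on-demand-loading and --api-key. The base URL, API key, and model ID
below are placeholders, and the requests package is assumed to be installed
client-side.

    import time

    import requests  # assumption: available in the client environment

    BASE = "http://localhost:8000"
    HEADERS = {"Authorization": "Bearer my-secret-key"}  # hypothetical key

    payload = {
        # Not loaded yet: the first request triggers the swap and blocks
        # until the engine is started and warmed up.
        "model": "mlx-community/Qwen3-0.6B-8bit",
        "messages": [{"role": "user", "content": "Hello"}],
    }

    while True:
        r = requests.post(
            f"{BASE}/v1/chat/completions",
            json=payload,
            headers=HEADERS,
            timeout=600,  # generous: first call may download + load the model
        )
        if r.status_code == 503:
            # A *different* model is mid-swap; honor Retry-After and retry.
            time.sleep(int(r.headers.get("Retry-After", "30")))
            continue
        r.raise_for_status()
        break

    print(r.json()["choices"][0]["message"]["content"])

Note the asymmetry in ensure_model_loaded: only a request for a different
model during an in-flight swap gets the 503; a request for the model currently
being swapped in proceeds to swap_to_model, waits on the lock, and completes
normally once loading finishes.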