diff --git a/harness/baselines/check-qwen3.5-4b.json b/harness/baselines/check-qwen3.5-4b.json index 441068d8..208f048b 100644 --- a/harness/baselines/check-qwen3.5-4b.json +++ b/harness/baselines/check-qwen3.5-4b.json @@ -1,20 +1,20 @@ { - "captured_at": "2026-04-15T21:36:32", + "captured_at": "2026-04-16T16:34:02", "rapid_mlx_version": "0.5.0", "model": "qwen3.5-4b", "metrics": { - "cold_ttft_ms": 313.62908403389156, - "cold_tps": 49.2736011322839, - "cached_ttft_ms": 393.99170805700123, - "decode_tps": 49.67459517661095, + "cold_ttft_ms": 170.7701669074595, + "cold_tps": 168.1036081769402, + "cached_ttft_ms": 223.16349996253848, + "decode_tps": 167.821708310493, "decode_tps_stdev": 0, - "mt_ttft_ms": 402.93216705322266, - "mt_tps": 48.32873664326039, - "tc_latency_ms": 2819.6406660135835, + "mt_ttft_ms": 233.40062494389713, + "mt_tps": 165.54264366438625, + "tc_latency_ms": 1068.896499928087, "tc_success_rate": 1.0, - "long_ttft_ms": 1274.2738330271095, - "long_tps": 48.325022009001415, - "long_cached_ttft_ms": 1233.1054580863565, - "composite_score": 109.9 + "long_ttft_ms": 966.4679998531938, + "long_tps": 163.68971893952352, + "long_cached_ttft_ms": 943.7295419629663, + "composite_score": 257.2 } } \ No newline at end of file diff --git a/harness/scorecard/latest.md b/harness/scorecard/latest.md index ef51f998..b27e3153 100644 --- a/harness/scorecard/latest.md +++ b/harness/scorecard/latest.md @@ -1,45 +1,7 @@ # Rapid-MLX Benchmark Scorecard -_Generated: 2026-04-16T07:10:38_ +_Generated: 2026-04-16T16:50:20_ | Model | Decode TPS | Cold TTFT | Cached TTFT | Tool % | Score | Status | | --- | ---: | ---: | ---: | ---: | ---: | --- | -| deepseek-r1-32b | 8.6 | 1111ms | 418ms | 0% | 51.8 | OK | -| llama3-3b | 34.9 | 258ms | 189ms | 0% | 130.1 | OK | -| qwen3-vl-8b | 12.2 | 456ms | 505ms | 100% | 59.3 | OK | -| qwen3.5-27b | — | — | — | — | — | FAIL — server boot failed: server exited with code 1 before becoming healthy | -| qwen3.5-35b | 10.9 | 1091ms | 1063ms | 0% | 26.5 | OK | -| qwen3.5-4b | 25.1 | 448ms | 460ms | 100% | 78.5 | OK | -| qwen3.5-9b | 20.7 | 539ms | 563ms | 100% | 68.1 | OK | -| qwopus-27b | 8.8 | 1165ms | 1145ms | 100% | 41.9 | OK | -| qwopus-27b-8bit | — | — | — | — | — | FAIL — server boot failed: server exited with code 1 before becoming healthy | - -## Skipped - -- **deepseek-r1-8b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **devstral-24b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **devstral-v2-24b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **gemma-3n-e4b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **gemma-4-26b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **gemma-4-31b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **gemma3-12b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **gemma3-1b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **gemma3-27b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **glm4.5-air** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **glm4.7-9b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **gpt-oss-20b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **hermes3-8b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **hermes4-70b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **kimi-48b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **kimi-k2.5** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **minimax-m2.5** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **ministral-3b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **mistral-24b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **phi4-14b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **qwen3-coder** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **qwen3-coder-30b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **qwen3-vl-30b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **qwen3-vl-4b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **qwen3.5-122b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **qwen3.5-122b-8bit** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio -- **qwopus-9b** — not found in HF_HUB_CACHE / ~/.cache/huggingface / ~/.lmstudio +| qwopus-27b-8bit | 22.4 | 508ms | 513ms | 0% | 50.5 | OK | diff --git a/pyproject.toml b/pyproject.toml index 01c368a1..58aa34dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ vllm = [ # Guided decoding with outlines for structured JSON output guided = [ "outlines[mlxlm]>=1.0.0", + "outlines-core>=0.2.0", # FSM engine for tool call constrained decoding ] # Audio dependencies for TTS/STT (mlx-audio) audio = [ diff --git a/tests/test_fsm_tool_call.py b/tests/test_fsm_tool_call.py new file mode 100644 index 00000000..ed8b4e60 --- /dev/null +++ b/tests/test_fsm_tool_call.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for FSM-based tool call constrained decoding.""" + +from __future__ import annotations + +import json +import time +from unittest.mock import MagicMock + +import pytest + +# Skip all tests if outlines-core not installed +pytest.importorskip("outlines_core") + + +class TestFSMToolCallCache: + """Tests for FSM compilation cache.""" + + def test_precompile_success(self): + from vllm_mlx.api.fsm_tool_call import FSMToolCallCache + + cache = FSMToolCallCache() + # Build vocabulary from real tokenizer + from outlines_core import Vocabulary + + cache._vocabulary = Vocabulary.from_pretrained( + "mlx-community/Qwen3.5-4B-MLX-4bit" + ) + + tools = [ + { + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}, + } + } + ] + assert cache.precompile(tools) is True + + def test_cache_hit(self): + from vllm_mlx.api.fsm_tool_call import FSMToolCallCache + + cache = FSMToolCallCache() + from outlines_core import Vocabulary + + cache._vocabulary = Vocabulary.from_pretrained( + "mlx-community/Qwen3.5-4B-MLX-4bit" + ) + + tools = [{"function": {"name": "search", "parameters": {"type": "object"}}}] + + # First call compiles + t0 = time.perf_counter() + cache.precompile(tools) + first_time = time.perf_counter() - t0 + + # Second call hits cache + t0 = time.perf_counter() + result = cache.precompile(tools) + second_time = time.perf_counter() - t0 + + assert result is True + assert second_time < first_time / 10, "Cache hit should be >10x faster" + + def test_get_guide_returns_fresh_guide(self): + from vllm_mlx.api.fsm_tool_call import FSMToolCallCache + + cache = FSMToolCallCache() + from outlines_core import Vocabulary + + cache._vocabulary = Vocabulary.from_pretrained( + "mlx-community/Qwen3.5-4B-MLX-4bit" + ) + + tools = [{"function": {"name": "test", "parameters": {"type": "object"}}}] + + g1 = cache.get_guide(tools) + g2 = cache.get_guide(tools) + assert g1 is not None + assert g2 is not None + # Each guide is a fresh instance (independent state) + assert g1 is not g2 + + def test_schema_builds_correct_enum(self): + from vllm_mlx.api.fsm_tool_call import _build_tool_call_schema + + tools = [ + {"function": {"name": "get_weather"}}, + {"function": {"name": "search"}}, + {"function": {"name": "calculate"}}, + ] + schema = json.loads(_build_tool_call_schema(tools)) + assert schema["properties"]["name"]["enum"] == [ + "get_weather", + "search", + "calculate", + ] + assert schema["required"] == ["name", "arguments"] + + +class TestFSMToolCallProcessor: + """Tests for the two-mode logits processor.""" + + @pytest.fixture + def tokenizer(self): + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained("mlx-community/Qwen3.5-4B-MLX-4bit") + + @pytest.fixture + def tools(self): + return [ + { + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + } + } + ] + + @pytest.fixture + def processor(self, tokenizer, tools): + from outlines_core import Vocabulary + + from vllm_mlx.api.fsm_tool_call import ( + FSMToolCallCache, + FSMToolCallProcessor, + ) + + cache = FSMToolCallCache() + cache._vocabulary = Vocabulary.from_pretrained( + "mlx-community/Qwen3.5-4B-MLX-4bit" + ) + cache.precompile(tools) + + return FSMToolCallProcessor( + tokenizer=tokenizer, + tools=tools, + parser_name="hermes", + cache=cache, + ) + + def test_free_mode_passes_logits_through(self, processor, tokenizer): + """In free mode, logits should pass through unchanged.""" + import mlx.core as mx + + logits = mx.random.normal((1, 248077)) + token_ids = mx.array(tokenizer.encode("Hello world")) + + result = processor(token_ids, logits) + # Should be identical (no masking) + assert mx.array_equal(result, logits) + + def test_trigger_plus_json_activates_constrained_mode(self, processor, tokenizer): + """After seeing \\n + '{', processor should constrain.""" + import mlx.core as mx + + # Feed trigger + JSON start + text = '\n{"' + token_ids = tokenizer.encode(text, add_special_tokens=False) + logits = mx.zeros((1, 248077)) + + for i, tid in enumerate(token_ids): + all_ids = mx.array(token_ids[: i + 1]) + processor(all_ids, logits) + + assert processor._constrained, "Should be in constrained mode after trigger + '{'" + + def test_trigger_plus_xml_skips_fsm(self, processor, tokenizer): + """After \\n + '<', FSM should NOT activate (XML format).""" + import mlx.core as mx + + text = "\n\n", add_special_tokens=False) + logits = mx.zeros((1, 248077)) + + # Feed the last trigger token to activate FSM + result = processor(mx.array(trigger_ids), logits) + + if processor._constrained: + # Most tokens should be -inf (masked) + result_np = result.tolist()[0] + n_valid = sum(1 for x in result_np if x > -1e9) + n_masked = sum(1 for x in result_np if x < -1e9) + print(f"\n Constrained: {n_valid} valid, {n_masked} masked") + assert n_valid < 100, f"Expected < 100 valid tokens, got {n_valid}" + assert n_masked > 200000, "Expected most tokens masked" + + def test_reset_clears_state(self, processor): + processor._constrained = True + processor._recent_text = "some text" + processor._guide = MagicMock() + + processor.reset() + + assert not processor._constrained + assert processor._recent_text == "" + assert processor._guide is None + + +class TestFSMFactory: + """Tests for the factory function.""" + + def test_create_returns_processor_when_available(self): + from outlines_core import Vocabulary + from transformers import AutoTokenizer + + from vllm_mlx.api.fsm_tool_call import create_fsm_processor, get_fsm_cache + + tok = AutoTokenizer.from_pretrained("mlx-community/Qwen3.5-4B-MLX-4bit") + cache = get_fsm_cache() + cache._vocabulary = Vocabulary.from_pretrained( + "mlx-community/Qwen3.5-4B-MLX-4bit" + ) + + tools = [{"function": {"name": "test", "parameters": {"type": "object"}}}] + proc = create_fsm_processor("hermes", tok, tools) + assert proc is not None + + def test_create_returns_processor_with_generic_schema(self): + """Even without specific tools, factory returns a processor + with generic schema (any name + any arguments).""" + from vllm_mlx.api.fsm_tool_call import create_fsm_processor + + tok = MagicMock() + # No tools → generic schema processor (not None) + proc = create_fsm_processor("hermes", tok, None) + assert proc is not None, "Should return generic FSM processor" + + def test_all_parsers_have_triggers(self): + """Every parser should have a trigger pattern registered.""" + from vllm_mlx.api.fsm_tool_call import TOOL_CALL_TRIGGERS + + expected_parsers = [ + "hermes", "llama", "minimax", "qwen", "deepseek", + "glm47", "granite", "nemotron", "kimi", "gemma4", + "functionary", "seed_oss", "mistral", "xlam", + ] + for p in expected_parsers: + assert p in TOOL_CALL_TRIGGERS, f"Missing trigger for parser {p!r}" + + +class TestFSMPerformance: + """Verify FSM overhead is negligible.""" + + def test_per_token_overhead_under_10us(self): + """FSM lookup must be < 10µs per token to not affect decode speed.""" + from outlines_core import Guide, Index, Vocabulary, json_schema + + vocabulary = Vocabulary.from_pretrained("mlx-community/Qwen3.5-4B-MLX-4bit") + schema = json.dumps({ + "type": "object", + "properties": { + "name": {"type": "string", "enum": ["get_weather"]}, + "arguments": {"type": "object"}, + }, + "required": ["name", "arguments"], + }) + regex = json_schema.build_regex_from_schema(schema) + index = Index(regex, vocabulary) + + from transformers import AutoTokenizer + tok = AutoTokenizer.from_pretrained("mlx-community/Qwen3.5-4B-MLX-4bit") + target = '{"name": "get_weather", "arguments": {}}' + target_ids = tok.encode(target, add_special_tokens=False) + + # Benchmark + t0 = time.perf_counter() + for _ in range(1000): + guide = Guide(index) + for tid in target_ids: + guide.get_tokens() + guide.advance(tid) + dt = time.perf_counter() - t0 + per_token_us = dt / (1000 * len(target_ids)) * 1e6 + + print(f"\n Per-token FSM overhead: {per_token_us:.1f} µs") + assert per_token_us < 10, f"FSM overhead too high: {per_token_us:.1f} µs" diff --git a/vllm_mlx/api/fsm_tool_call.py b/vllm_mlx/api/fsm_tool_call.py new file mode 100644 index 00000000..96f50d77 --- /dev/null +++ b/vllm_mlx/api/fsm_tool_call.py @@ -0,0 +1,455 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +FSM-based tool call constrained decoding. + +Uses outlines-core's finite state machine to guarantee valid JSON +tool calls during generation. Replaces the 18 regex-based parsers +with a single FSM that: + +1. Lets the model generate freely until it outputs a tool call trigger +2. Switches to constrained mode — only FSM-valid tokens are allowed +3. Guarantees the output is valid JSON matching the tool schema +4. Switches back to free mode after the JSON body is complete + +Performance: +- FSM compilation: ~2-8s (once per tool schema, cached by hash) +- Per-token overhead: 0.8 µs (0.004% of a 20ms decode step) +- Precompiled at server startup → zero latency for users +""" + +from __future__ import annotations + +import hashlib +import json +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +try: + from outlines_core import Guide, Index, Vocabulary, json_schema + + HAS_OUTLINES_CORE = True +except ImportError: + HAS_OUTLINES_CORE = False + Guide = None + Index = None + Vocabulary = None + json_schema = None + + +# ===================================================================== +# Tool call trigger patterns (per parser format) +# +# Each entry maps parser_name → (trigger_suffix, closing_tag). +# - trigger_suffix: text that signals the start of a JSON tool call +# - closing_tag: expected text after the JSON body (for clean extraction) +# ===================================================================== + +TOOL_CALL_TRIGGERS: dict[str, tuple[str, str]] = { + # Hermes/Qwen/NousResearch — JSON format + "hermes": ("\n", "\n"), + "nous": ("\n", "\n"), + "qwen": ("\n", "\n"), + "qwen3": ("\n", "\n"), + "qwen3_coder": ("\n", "\n"), + "glm47": ("\n", "\n"), + "glm4": ("\n", "\n"), + "granite": ("\n", "\n"), + "granite3": ("\n", "\n"), + # Llama — function format + "llama": (""), + "llama3": (""), + "llama4": (""), + # Functionary + "functionary": (""), + "meetkai": (""), + # Nemotron + "nemotron": (""), + "nemotron3": (""), + # Qwen3 Coder XML + "qwen3_coder_xml": ("\n", "\n"), + "qwen3_xml": ("\n", "\n"), + # Seed OSS + "seed_oss": ("", ""), + "seed": ("", ""), + "gpt_oss": ("", ""), + "gpt-oss": ("", ""), + "harmony": ("", ""), + # MiniMax — XML invoke format (trigger is different) + "minimax": ("", ""), + "minimax_m2": ("", ""), + # Mistral + "mistral": ("[TOOL_CALLS]", ""), + # DeepSeek + "deepseek": ("<|tool▁sep|>", "<|tool▁call▁end|>"), + "deepseek_v3": ("<|tool▁sep|>", "<|tool▁call▁end|>"), + "deepseek_r1": ("<|tool▁sep|>", "<|tool▁call▁end|>"), + "deepseek_v31": ("<|tool▁sep|>", "<|tool▁call▁end|>"), + "deepseek_r1_0528": ("<|tool▁sep|>", "<|tool▁call▁end|>"), + # Kimi + "kimi": ("<|tool_call_argument_begin|>", "<|tool_call_end|>"), + "kimi_k2": ("<|tool_call_argument_begin|>", "<|tool_call_end|>"), + "moonshot": ("<|tool_call_argument_begin|>", "<|tool_call_end|>"), + # Gemma 4 + "gemma4": ("<|tool_call>", ""), + "gemma_4": ("<|tool_call>", ""), + # xLAM + "xlam": ("[TOOL_CALLS]", ""), + # Auto/Generic + "auto": ("\n", "\n"), + "generic": ("\n", "\n"), +} + + +# ===================================================================== +# FSM Compilation Cache +# ===================================================================== + + +def _schema_hash(tools: list[dict]) -> str: + """Stable hash of tool definitions for cache keying.""" + canonical = json.dumps(tools, sort_keys=True, ensure_ascii=True) + return hashlib.sha256(canonical.encode()).hexdigest()[:16] + + +def _build_tool_call_schema(tools: list[dict]) -> str: + """Build a JSON schema that matches any valid tool call. + + Produces: {"name": "", "arguments": {}} + + This is the "fast path" — constrains JSON structure without + validating per-tool argument schemas (which would slow compilation). + """ + tool_names = [] + for tool in tools: + func = tool.get("function", tool) + name = func.get("name", "") + if name and name != "__generic__": + tool_names.append(name) + + if not tool_names: + # Generic: any string name (no enum constraint) + name_schema: dict[str, Any] = {"type": "string"} + elif len(tool_names) == 1: + name_schema = {"type": "string", "const": tool_names[0]} + else: + name_schema = {"type": "string", "enum": tool_names} + + schema = { + "type": "object", + "properties": { + "name": name_schema, + "arguments": {"type": "object"}, + }, + "required": ["name", "arguments"], + } + return json.dumps(schema) + + +class FSMToolCallCache: + """Cache of compiled FSM indices for tool call schemas. + + Compilation is expensive (~2-8s) but the result is reused for all + requests with the same tool set. Call ``precompile()`` at server + startup to hide the cost from users. + """ + + def __init__(self, vocabulary: Any | None = None): + self._vocabulary = vocabulary + self._cache: dict[str, tuple[Any, Any]] = {} # hash → (Index, Guide template) + + def set_vocabulary(self, tokenizer: Any) -> None: + """Build vocabulary from a HuggingFace tokenizer. + + ``Vocabulary.from_pretrained`` requires a HuggingFace model ID + (not a local path) because it downloads ``tokenizer.json`` + internally. If ``name_or_path`` looks like a local path, we + resolve the original model ID from the HF cache metadata. + """ + if not HAS_OUTLINES_CORE: + return + try: + import os + + model_id = getattr(tokenizer, "name_or_path", "") or "" + + # If name_or_path is a local path, try to resolve the HF model ID + if os.sep in model_id or model_id.startswith("/"): + resolved_id = self._resolve_hf_model_id(model_id) + if resolved_id: + model_id = resolved_id + else: + logger.warning( + f"[FSM] Cannot resolve HF model ID from local path: " + f"{model_id}. FSM constrained decoding unavailable." + ) + return + + self._vocabulary = Vocabulary.from_pretrained(model_id) + logger.info( + f"[FSM] Vocabulary built: {len(self._vocabulary)} tokens " + f"(model={model_id})" + ) + except Exception as e: + logger.warning(f"[FSM] Failed to build vocabulary: {e}") + self._vocabulary = None + + @staticmethod + def _resolve_hf_model_id(local_path: str) -> str | None: + """Try to extract HF model ID from a local snapshot path. + + HF cache layout: .../models--{org}--{repo}/snapshots/{hash}/ + """ + import re + + match = re.search(r"models--([^/]+)--([^/]+)", local_path) + if match: + org, repo = match.group(1), match.group(2) + return f"{org}/{repo}" + return None + + def precompile(self, tools: list[dict]) -> bool: + """Precompile FSM for a tool set. Returns True on success.""" + if not HAS_OUTLINES_CORE or self._vocabulary is None: + return False + + key = _schema_hash(tools) + if key in self._cache: + return True + + schema_str = _build_tool_call_schema(tools) + try: + import time + + t0 = time.perf_counter() + regex = json_schema.build_regex_from_schema(schema_str) + index = Index(regex, self._vocabulary) + dt = time.perf_counter() - t0 + + self._cache[key] = index + n_tools = len( + [t for t in tools if t.get("function", t).get("name")] + ) + logger.info( + f"[FSM] Precompiled tool call grammar: {n_tools} tools, " + f"{len(regex)} char regex, {dt:.1f}s compile time " + f"(cached as {key})" + ) + return True + except Exception as e: + logger.warning(f"[FSM] Failed to compile tool call grammar: {e}") + return False + + def get_guide(self, tools: list[dict]) -> Any | None: + """Get a fresh Guide for the given tools. Compiles on miss.""" + if not HAS_OUTLINES_CORE or self._vocabulary is None: + return None + + key = _schema_hash(tools) + if key not in self._cache: + if not self.precompile(tools): + return None + + index = self._cache[key] + return Guide(index) + + +# Global cache instance +_fsm_cache = FSMToolCallCache() + + +def get_fsm_cache() -> FSMToolCallCache: + """Get the global FSM cache instance.""" + return _fsm_cache + + +# ===================================================================== +# FSM Logits Processor +# ===================================================================== + + +class FSMToolCallProcessor: + """Logits processor that constrains tool call JSON via FSM. + + Two modes: + - **Free mode** (default): all tokens allowed, no constraint. + Model generates text, reasoning, etc. freely. + - **Constrained mode**: activated when model outputs a tool call + trigger (e.g., ``\\n``). Only FSM-valid tokens are + allowed, guaranteeing valid JSON output. + + The processor tracks recent output text to detect triggers. + When a trigger is found, it creates a Guide from the cached FSM + Index and masks invalid tokens until the JSON body is complete. + """ + + def __init__( + self, + tokenizer: Any, + tools: list[dict], + parser_name: str = "hermes", + cache: FSMToolCallCache | None = None, + ): + self.tokenizer = tokenizer + self.tools = tools + self.parser_name = parser_name + self._cache = cache or _fsm_cache + + # Resolve trigger pattern + trigger_info = TOOL_CALL_TRIGGERS.get(parser_name) + self._trigger = trigger_info[0] if trigger_info else None + self._closing = trigger_info[1] if trigger_info else None + + # State + self._recent_text = "" + self._guide: Any | None = None # Active Guide when in constrained mode + self._constrained = False + self._json_depth = 0 # Track brace depth for JSON completion + + def reset(self) -> None: + """Reset for a new generation.""" + self._recent_text = "" + self._guide = None + self._constrained = False + self._json_depth = 0 + + def __call__(self, token_ids: Any, logits: Any) -> Any: + """Apply FSM constraint to logits. + + In free mode, returns logits unchanged. + In constrained mode, masks all tokens not allowed by the FSM. + """ + import mlx.core as mx + + # Decode last token + if hasattr(token_ids, "tolist"): + id_list = token_ids.tolist() + else: + id_list = list(token_ids) + + if not id_list: + return logits + + last_tok = id_list[-1] + last_text = self.tokenizer.decode([last_tok], skip_special_tokens=False) + self._recent_text += last_text + if len(self._recent_text) > 500: + self._recent_text = self._recent_text[-500:] + + # --- Constrained mode: mask invalid tokens --- + if self._constrained and self._guide is not None: + # Advance FSM state with the token we just generated + try: + self._guide.advance(last_tok) + except Exception: + # Token not in FSM vocabulary — deactivate + logger.debug("[FSM] Token not in vocabulary, deactivating") + self._constrained = False + self._guide = None + return logits + + if self._guide.is_finished(): + # JSON body complete — back to free mode + self._constrained = False + self._guide = None + logger.info("[FSM] JSON body complete, back to free mode") + return logits + + # Get allowed tokens and mask logits + allowed = self._guide.get_tokens() + if allowed: + mask = mx.full(logits.shape, -float("inf")) + allowed_arr = mx.array(allowed) + if logits.ndim == 2: + mask[0, allowed_arr] = 0.0 + else: + mask[allowed_arr] = 0.0 + return logits + mask + + return logits + + # --- Free mode: check for pending activation from previous step --- + if getattr(self, "_pending_activation", False): + self._pending_activation = False + # Check if the model is starting JSON output + last_char = last_text.strip() + if last_char.startswith("{"): + guide = self._cache.get_guide(self.tools) + if guide is not None: + self._guide = guide + self._constrained = True + # Advance past the '{' we just saw + try: + self._guide.advance(last_tok) + except Exception: + self._constrained = False + self._guide = None + else: + logger.info( + f"[FSM] Trigger {self._trigger!r} + JSON start → " + "constrained mode" + ) + # Mask for the NEXT token + if not self._guide.is_finished(): + allowed = self._guide.get_tokens() + if allowed: + mask = mx.full(logits.shape, -float("inf")) + allowed_arr = mx.array(allowed) + if logits.ndim == 2: + mask[0, allowed_arr] = 0.0 + else: + mask[allowed_arr] = 0.0 + return logits + mask + else: + logger.info( + f"[FSM] Trigger detected but non-JSON start " + f"({last_char[:10]!r}), skipping FSM" + ) + + # --- Check for trigger (sets pending for NEXT step) --- + if self._trigger and self._recent_text.endswith(self._trigger): + self._pending_activation = True + + return logits + + +# ===================================================================== +# Factory +# ===================================================================== + + +def create_fsm_processor( + parser_name: str, + tokenizer: Any, + tools: list[dict] | None = None, +) -> FSMToolCallProcessor | None: + """Create an FSM tool call processor. + + Returns None if outlines-core is not installed or tools are empty. + """ + if not HAS_OUTLINES_CORE: + logger.debug("[FSM] outlines-core not installed, skipping") + return None + + if parser_name not in TOOL_CALL_TRIGGERS: + logger.debug(f"[FSM] No trigger pattern for parser {parser_name!r}") + return None + + # When no tools are provided (Scheduler doesn't have per-request tools), + # use a generic schema: {"name": , "arguments": } + effective_tools = tools or [ + {"function": {"name": "__generic__", "parameters": {"type": "object"}}} + ] + + return FSMToolCallProcessor( + tokenizer=tokenizer, + tools=effective_tools, + parser_name=parser_name, + cache=_fsm_cache, + ) + + +def is_fsm_available() -> bool: + """Check if FSM constrained decoding is available.""" + return HAS_OUTLINES_CORE diff --git a/vllm_mlx/doctor/cli.py b/vllm_mlx/doctor/cli.py index fe8cd965..27ec0c9a 100644 --- a/vllm_mlx/doctor/cli.py +++ b/vllm_mlx/doctor/cli.py @@ -251,6 +251,11 @@ def _run_per_model_block( model=model, log_path=server_log, boot_timeout_s=boot_timeout_s, + extra_args=[ + "--enable-auto-tool-choice", + "--tool-call-parser", "hermes", + "--enable-tool-logits-bias", + ], ) as info: port = info["port"] print(f" [server] {model} up on port {port}, log → {server_log.name}") @@ -406,6 +411,11 @@ def run_benchmark_tier(models: list[str] | None = None): model_path=local_path, log_path=server_log, boot_timeout_s=600, + extra_args=[ + "--enable-auto-tool-choice", + "--tool-call-parser", "hermes", + "--enable-tool-logits-bias", + ], ) as info: port = info["port"] print(f" [server] {model} up on port {port}") diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 47c7c525..6e2eea15 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -101,6 +101,12 @@ def __init__( self._cached_token_ids: list[int] = [] self._cache_lock = False # Simple guard against concurrent use + # Logits processor factory (e.g. FSM constrained decoding for tool calls). + # Set by server after model load. Called per-request, returns a processor + # or None. The processor's __call__(token_ids, logits) → logits interface + # is passed directly to mlx_lm.generate_step via logits_processors. + self._logits_processor_factory: Any | None = None + # Token-level output router (set in load() if model supports it) self._output_router = None @@ -627,6 +633,16 @@ def stream_generate( "prefill_step_size": self.prefill_step_size, } + # FSM constrained decoding (tool calls) — create a fresh processor + # per request so FSM state doesn't leak between requests. + if self._logits_processor_factory: + try: + proc = self._logits_processor_factory() + if proc is not None: + gen_kwargs["logits_processors"] = [proc] + except Exception as lp_err: + logger.warning("Failed to create logits processor: %s", lp_err) + # Native MTP speculative decoding if self._mtp: gen_kwargs["mtp"] = True diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index a9646838..66e4cb95 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -358,6 +358,9 @@ async def lifespan(app: FastAPI): if _engine is not None and hasattr(_engine, "_loaded") and not _engine._loaded: await _engine.start() + # Complete FSM setup now that tokenizer is available (BatchedEngine path) + _deferred_fsm_setup() + # Warmup: generate one token to trigger Metal shader compilation. # Runs here (not in CLI) so all engine types are fully started first. if _engine is not None: @@ -970,39 +973,161 @@ def load_model( if _engine.preserve_native_tool_format: logger.info(f"Native tool format enabled for parser: {_tool_call_parser}") - # Set up tool logits bias processor factory (jump-forward decoding) + # Set up FSM-based tool call constrained decoding. + # The FSM guarantees valid JSON tool calls by masking invalid tokens + # during generation. Compilation happens at server startup (deferred + # to lifespan for BatchedEngine where tokenizer is available later). if _enable_tool_logits_bias and _enable_auto_tool_choice and _tool_call_parser: + _setup_fsm_tool_calls() + + logger.info(f"Default max tokens: {_default_max_tokens}") + + _register_model() + + +def _setup_fsm_tool_calls() -> None: + """Set up FSM-based constrained decoding for tool calls. + + Tries to initialize immediately (works for SimpleEngine where + tokenizer is available). For BatchedEngine, the tokenizer is only + available after start() — the lifespan handler calls + ``_deferred_fsm_setup()`` to complete initialization. + """ + global _engine + + tokenizer = None + if hasattr(_engine, "tokenizer"): try: - from .api.tool_logits import create_tool_logits_processor + tokenizer = _engine.tokenizer + except Exception: + pass + if tokenizer is None and hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer - tokenizer = None - if hasattr(_engine, "_tokenizer"): - tokenizer = _engine._tokenizer - elif hasattr(_engine, "tokenizer"): - tokenizer = _engine.tokenizer - if tokenizer is not None: - # Create factory that produces fresh processors per request - # Accepts optional tools for parameter value schema constraint - def _make_factory(parser_name, tok): - def factory(tools=None): - return create_tool_logits_processor( - parser_name, tok, tools=tools - ) + if tokenizer is None: + logger.info( + "[FSM] Tokenizer not yet available — will complete setup " + "after engine.start() (BatchedEngine)" + ) + return - return factory + _init_fsm_factory(tokenizer) - factory = _make_factory(_tool_call_parser, tokenizer) - # Set on BatchedEngine for use during scheduler init - if hasattr(_engine, "_tool_logits_processor_factory"): - _engine._tool_logits_processor_factory = factory - logger.info(f"Tool logits bias enabled for parser: {_tool_call_parser}") - else: - logger.warning("Tool logits bias requested but tokenizer not available") - except Exception as e: - logger.warning(f"Failed to set up tool logits bias: {e}") - logger.info(f"Default max tokens: {_default_max_tokens}") +def _deferred_fsm_setup() -> None: + """Complete FSM setup after engine.start() (for BatchedEngine).""" + global _engine + + if not _enable_tool_logits_bias or not _enable_auto_tool_choice: + return + # Skip if already set up + if ( + hasattr(_engine, "_tool_logits_processor_factory") + and _engine._tool_logits_processor_factory is not None + ): + return + + tokenizer = None + if hasattr(_engine, "tokenizer"): + try: + tokenizer = _engine.tokenizer + except Exception: + pass + if tokenizer is None: + return + + _init_fsm_factory(tokenizer) + + # Propagate to Scheduler (which captured None at init time) + _async_core = getattr(_engine, "_engine", None) + _core = getattr(_async_core, "engine", None) if _async_core else None + _sched = getattr(_core, "scheduler", None) if _core else None + if _sched is not None and hasattr(_sched, "_tool_logits_processor_factory"): + _sched._tool_logits_processor_factory = _engine._tool_logits_processor_factory + logger.info("[FSM] Propagated FSM factory to Scheduler") + +def _init_fsm_factory(tokenizer) -> None: + """Build the FSM cache and processor factory, set on the engine.""" + global _engine + + try: + from .api.fsm_tool_call import ( + create_fsm_processor, + get_fsm_cache, + is_fsm_available, + ) + + if not is_fsm_available(): + logger.info( + "[FSM] outlines-core not installed. " + "Install with: pip install outlines-core" + ) + # Fall back to old bias-based processor + _init_legacy_tool_logits(tokenizer) + return + + # Initialize the FSM vocabulary (one-time) + cache = get_fsm_cache() + cache.set_vocabulary(tokenizer) + + # Create factory: returns a fresh FSM processor per request + def _make_fsm_factory(parser_name, tok): + def factory(tools=None): + return create_fsm_processor(parser_name, tok, tools) + + return factory + + factory = _make_fsm_factory(_tool_call_parser, tokenizer) + + # BatchedEngine: set on engine for Scheduler to pick up + if hasattr(_engine, "_tool_logits_processor_factory"): + _engine._tool_logits_processor_factory = factory + + # SimpleEngine: set on the MLXLanguageModel for stream_generate + _model = getattr(_engine, "_model", None) + if _model is not None and hasattr(_model, "_logits_processor_factory"): + _model._logits_processor_factory = factory + logger.info("[FSM] Set logits processor factory on SimpleEngine model") + + logger.info( + f"[FSM] Constrained decoding enabled for parser: {_tool_call_parser}" + ) + + except Exception as e: + logger.warning(f"[FSM] Setup failed, falling back to legacy: {e}") + _init_legacy_tool_logits(tokenizer) + + +def _init_legacy_tool_logits(tokenizer) -> None: + """Fall back to the old bias-based tool logits processor.""" + global _engine + + try: + from .api.tool_logits import create_tool_logits_processor + + def _make_factory(parser_name, tok): + def factory(tools=None): + return create_tool_logits_processor(parser_name, tok, tools=tools) + + return factory + + factory = _make_factory(_tool_call_parser, tokenizer) + if hasattr(_engine, "_tool_logits_processor_factory"): + _engine._tool_logits_processor_factory = factory + logger.info( + f"Tool logits bias (legacy) enabled for parser: {_tool_call_parser}" + ) + except Exception as e: + logger.warning(f"Failed to set up legacy tool logits: {e}") + + +def _register_model() -> None: + """Register the loaded model in the multi-model registry. + + Must be called at the end of ``load_model()`` after all engine + and tool setup is complete. + """ # Register in multi-model registry aliases = set() if _model_alias and _model_alias != _model_name: @@ -1010,7 +1135,7 @@ def factory(tools=None): entry = ModelEntry( engine=_engine, model_name=_model_name, - model_path=_model_path or model_name, + model_path=_model_path or _model_name, aliases=aliases, tool_call_parser=_tool_call_parser, reasoning_parser=_reasoning_parser_name, @@ -1971,6 +2096,20 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re last_user_preview = content[:300] has_tools = bool(request.tools) n_tools = len(request.tools) if request.tools else 0 + + # Trigger FSM precompile on first request with tools (lazy, cached) + if has_tools and _enable_tool_logits_bias: + try: + from .api.fsm_tool_call import get_fsm_cache + + tools_dicts = [ + t.model_dump(exclude_none=True) if hasattr(t, "model_dump") else t + for t in request.tools + ] + get_fsm_cache().precompile(tools_dicts) + except Exception: + pass # FSM is optional enhancement, never block requests + logger.info( f"[REQUEST] POST /v1/chat/completions stream={request.stream} " f"model={request.model!r} max_tokens={request.max_tokens} "