diff --git a/xinference/model/llm/llama_cpp/core.py b/xinference/model/llm/llama_cpp/core.py
index 3f06d2f899..3e619f0e42 100644
--- a/xinference/model/llm/llama_cpp/core.py
+++ b/xinference/model/llm/llama_cpp/core.py
@@ -16,7 +16,7 @@
import os
import pprint
import queue
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from packaging import version
@@ -25,11 +25,41 @@
from ...utils import check_dependency_available
from ..core import LLM, chat_context_var
from ..llm_family import LLMFamilyV2, LLMSpecV1
-from ..utils import ChatModelMixin
+from ..utils import ChatModelMixin, normalize_response_format
logger = logging.getLogger(__name__)
+def _schema_to_grammar(schema: Dict[str, Any]) -> Optional[str]:
+    try:
+        import xllamacpp
+    except Exception as e:  # pragma: no cover - optional dependency
+        logger.warning("json_schema provided but xllamacpp missing: %s", e)
+        return None
+    try:
+        return xllamacpp.json_schema_to_grammar(schema)  # type: ignore[attr-defined]
+    except Exception as e:  # pragma: no cover - conversion failure
+        logger.warning("Failed to convert json_schema to grammar for xllamacpp: %s", e)
+        return None
+
+
+def _apply_response_format(generate_config: Dict[str, Any]) -> None:
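+    # Pops ``response_format`` from ``generate_config`` and, for json_schema
+    # requests, injects the equivalent ``grammar`` (or raw ``json_schema``)
+    # option for xllamacpp; mutates ``generate_config`` in place.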
+    response_format = generate_config.pop("response_format", None)
+    normalized = normalize_response_format(response_format)
+    if not normalized or normalized.get("type") != "json_schema":
+        return
+    schema_dict = normalized.get("schema_dict")
+    if not schema_dict:
+        return
+    grammar = _schema_to_grammar(schema_dict)
+    if grammar:
+        # xllamacpp rejects configs containing both json_schema and grammar
+        generate_config.pop("json_schema", None)
+        generate_config["grammar"] = grammar
+    else:
+        # Fall back to passing the schema through unchanged as ``json_schema``
+        # when no grammar could be produced.
+        generate_config.setdefault("json_schema", schema_dict)
+
+
class _Done:
    pass
@@ -49,7 +79,7 @@ def __init__(
        model_path: str,
        llamacpp_model_config: Optional[dict] = None,
    ):
-        super().__init__(model_uid, model_family, model_path)
+        super().__init__(model_uid, model_family, model_path)  # type: ignore[call-arg]
        self._llamacpp_model_config = self._sanitize_model_config(llamacpp_model_config)
        self._llm = None
        self._executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
@@ -246,6 +276,7 @@ def generate(
        generate_config = generate_config or {}
        if not generate_config.get("max_tokens") and XINFERENCE_MAX_TOKENS:
            generate_config["max_tokens"] = XINFERENCE_MAX_TOKENS
+        _apply_response_format(generate_config)
        stream = generate_config.get("stream", False)
        q: queue.Queue = queue.Queue()
@@ -305,6 +336,7 @@ def chat(
        generate_config = generate_config or {}
        if not generate_config.get("max_tokens") and XINFERENCE_MAX_TOKENS:
            generate_config["max_tokens"] = XINFERENCE_MAX_TOKENS
+        _apply_response_format(generate_config)
        stream = generate_config.get("stream", False)
        chat_template_kwargs = (
diff --git a/xinference/model/llm/llama_cpp/tests/test_structured.py b/xinference/model/llm/llama_cpp/tests/test_structured.py
new file mode 100644
index 0000000000..1f2c5bc499
--- /dev/null
+++ b/xinference/model/llm/llama_cpp/tests/test_structured.py
@@ -0,0 +1,193 @@
+import importlib
+import importlib.util
+import json
+import sys
+from enum import Enum
+from types import SimpleNamespace
+from typing import Any, Dict
+
+import openai
+import pytest
+from pydantic import BaseModel
+
+from xinference.client import Client
+
+from ..core import _apply_response_format
+
+
+class CarType(str, Enum):
+ sedan = "sedan"
+ suv = "SuV"
+ truck = "Truck"
+ coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
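+# The helper below accepts either a plain string or an OpenAI-style list of
+# content blocks, and drops a leading <think>...</think> section that some
+# reasoning models (e.g. qwen3) emit before the JSON payload.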
+def _load_json_from_message(message: Any) -> Dict[str, Any]:
+    def _strip_think(text: str) -> str:
+        stripped = text.lstrip()
+        if stripped.startswith("<think>"):
+            if "</think>" in stripped:
+                stripped = stripped.split("</think>", 1)[1]
+            else:
+                stripped = stripped.split("<think>", 1)[1]
+        return stripped.lstrip()
+
+    raw_content = message.content
+    if isinstance(raw_content, str):
+        return json.loads(_strip_think(raw_content))
+
+    if isinstance(raw_content, list):
+        text_blocks = []
+        for block in raw_content:
+            if isinstance(block, dict):
+                if block.get("type") == "text" and "text" in block:
+                    text_blocks.append(_strip_think(block["text"]))
+                continue
+
+            block_type = getattr(block, "type", None)
+            block_text = getattr(block, "text", None)
+            if block_type == "text" and block_text:
+                text_blocks.append(_strip_think(block_text))
+
+        if text_blocks:
+            return json.loads("".join(text_blocks))
+
+    pytest.fail(f"Unexpected message content format: {raw_content!r}")
+    raise AssertionError("Unreachable")
+
+
+def test_apply_response_format_sets_grammar(monkeypatch):
+    fake_xllamacpp = SimpleNamespace(json_schema_to_grammar=lambda schema: "GRAMMAR")
+    monkeypatch.setitem(sys.modules, "xllamacpp", fake_xllamacpp)
+
+    cfg = {
+        "response_format": {
+            "type": "json_schema",
+            "json_schema": {
+                "schema": {
+                    "type": "object",
+                    "properties": {"a": {"type": "string"}},
+                    "required": ["a"],
+                }
+            },
+        }
+    }
+
+    _apply_response_format(cfg)
+
+    assert "response_format" not in cfg
+    assert "json_schema" not in cfg
+    assert cfg["grammar"] == "GRAMMAR"
+
+
+def test_apply_response_format_handles_conversion_failure(monkeypatch):
+    def _raise(_):
+        raise ValueError("bad schema")
+
+    fake_xllamacpp = SimpleNamespace(json_schema_to_grammar=_raise)
+    monkeypatch.setitem(sys.modules, "xllamacpp", fake_xllamacpp)
+
+    cfg = {
+        "response_format": {
+            "type": "json_schema",
+            "json_schema": {
+                "schema": {
+                    "type": "object",
+                    "properties": {"b": {"type": "string"}},
+                    "required": ["b"],
+                }
+            },
+        }
+    }
+
+    _apply_response_format(cfg)
+
+    assert "response_format" not in cfg
+    assert cfg["json_schema"]["required"] == ["b"]
+    assert "grammar" not in cfg
+
+
+def test_apply_response_format_ignores_non_schema(monkeypatch):
+ cfg = {"response_format": {"type": "json_object"}}
+ _apply_response_format(cfg)
+ assert "grammar" not in cfg
+ assert "json_schema" not in cfg
+
+
+def test_apply_response_format_uses_real_xllamacpp_if_available():
+    if importlib.util.find_spec("xllamacpp") is None:
+        pytest.skip("xllamacpp not installed")
+    xllamacpp = importlib.import_module("xllamacpp")
+    if not hasattr(xllamacpp, "json_schema_to_grammar"):
+        pytest.skip("xllamacpp does not expose json_schema_to_grammar")
+
+    cfg = {
+        "response_format": {
+            "type": "json_schema",
+            "json_schema": {
+                "schema": {
+                    "type": "object",
+                    "properties": {"c": {"type": "integer"}},
+                    "required": ["c"],
+                }
+            },
+        }
+    }
+
+    _apply_response_format(cfg)
+
+    assert "response_format" not in cfg
+    # Real xllamacpp should prefer grammar to avoid passing both
+    assert "json_schema" not in cfg
+    assert "grammar" in cfg and cfg["grammar"]
+
+
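+# End-to-end check against a running Xinference endpoint (``setup`` fixture):
+# launches a small qwen3 GGUF model and requests a json_schema-constrained
+# completion through the OpenAI-compatible API.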
+def test_llamacpp_qwen3_json_schema(setup):
+    endpoint, _ = setup
+    client = Client(endpoint)
+    model_uid = client.launch_model(
+        model_name="qwen3",
+        model_engine="llama.cpp",
+        model_size_in_billions="0_6",
+        model_format="ggufv2",
+        quantization="Q4_K_M",
+        n_gpu=None,
+    )
+
+    try:
+        api_client = openai.Client(api_key="not empty", base_url=f"{endpoint}/v1")
+        completion = api_client.chat.completions.create(
+            model=model_uid,
+            messages=[
+                {
+                    "role": "user",
+                    "content": (
+                        "Generate a JSON containing the brand, model, and car_type of"
+                        " an iconic 90s car."
+                    ),
+                }
+            ],
+            temperature=0,
+            max_tokens=128,
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "car-description",
+                    "schema": CarDescription.model_json_schema(),
+                },
+            },
+        )
+
+        parsed = _load_json_from_message(completion.choices[0].message)
+        car_description = CarDescription.model_validate(parsed)
+        assert car_description.brand
+        assert car_description.model
+    finally:
+        if model_uid is not None:
+            client.terminate_model(model_uid)
diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py
index 420598dbba..2fc54d9d3a 100644
--- a/xinference/model/llm/utils.py
+++ b/xinference/model/llm/utils.py
@@ -1235,6 +1235,32 @@ def get_stop_token_ids_from_config_file(model_path: str) -> Optional[List[int]]:
    return None
+def normalize_response_format(
+    response_format: Optional[Dict[str, Any]],
+) -> Optional[Dict[str, Any]]:
+    """
+    Normalize OpenAI-style response_format into a simple dict.
+
+    Returns:
+        None if missing/unsupported, or a dict with keys:
+        - type: "json_schema" | "json_object"
+        - schema_dict: dict (only for json_schema)
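+
+    Example (illustrative):
+        {"type": "json_schema", "json_schema": {"schema": {...}}}
+        -> {"type": "json_schema", "schema_dict": {...}}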
+ """
+ if not response_format or not isinstance(response_format, dict):
+ return None
+
+ fmt_type = response_format.get("type")
+ if fmt_type not in ("json_schema", "json_object"):
+ return None
+
+ normalized: Dict[str, Any] = {"type": fmt_type}
+ if fmt_type == "json_schema":
+ schema_block = response_format.get("json_schema") or {}
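+        # Pydantic-backed request models may expose the schema as ``schema_``
+        # (to avoid clashing with reserved attribute names), while plain dicts
+        # use ``schema``; accept both.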
+        schema_dict = schema_block.get("schema_") or schema_block.get("schema")
+        if schema_dict:
+            normalized["schema_dict"] = schema_dict
+    return normalized
+
+
def parse_messages(messages: List[Dict]) -> Tuple:
    """
    Some older models still follow the old way of parameter passing.