38 changes: 35 additions & 3 deletions xinference/model/llm/llama_cpp/core.py
@@ -16,7 +16,7 @@
import os
import pprint
import queue
from typing import Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from packaging import version

@@ -25,11 +25,41 @@
from ...utils import check_dependency_available
from ..core import LLM, chat_context_var
from ..llm_family import LLMFamilyV2, LLMSpecV1
from ..utils import ChatModelMixin
from ..utils import ChatModelMixin, normalize_response_format

logger = logging.getLogger(__name__)


def _schema_to_grammar(schema: Dict[str, Any]) -> Optional[str]:
try:
import xllamacpp
except Exception as e: # pragma: no cover - optional dependency
logger.warning("json_schema provided but xllamacpp missing: %s", e)
return None
try:
return xllamacpp.json_schema_to_grammar(schema) # type: ignore[attr-defined]
except Exception as e: # pragma: no cover - conversion failure
logger.warning("Failed to convert json_schema to grammar for xllamacpp: %s", e)
return None


def _apply_response_format(generate_config: Dict[str, Any]) -> None:
response_format = generate_config.pop("response_format", None)
normalized = normalize_response_format(response_format)
if not normalized or normalized.get("type") != "json_schema":
return
schema_dict = normalized.get("schema_dict")
if not schema_dict:
return
grammar = _schema_to_grammar(schema_dict)
if grammar:
# xllamacpp rejects configs containing both json_schema and grammar
generate_config.pop("json_schema", None)
generate_config["grammar"] = grammar
else:
generate_config.setdefault("json_schema", schema_dict)


class _Done:
pass

@@ -49,7 +79,7 @@ def __init__(
model_path: str,
llamacpp_model_config: Optional[dict] = None,
):
super().__init__(model_uid, model_family, model_path)
super().__init__(model_uid, model_family, model_path) # type: ignore[call-arg]
self._llamacpp_model_config = self._sanitize_model_config(llamacpp_model_config)
self._llm = None
self._executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
@@ -246,6 +276,7 @@ def generate(
generate_config = generate_config or {}
if not generate_config.get("max_tokens") and XINFERENCE_MAX_TOKENS:
generate_config["max_tokens"] = XINFERENCE_MAX_TOKENS
_apply_response_format(generate_config)
stream = generate_config.get("stream", False)
q: queue.Queue = queue.Queue()

@@ -305,6 +336,7 @@ def chat(
generate_config = generate_config or {}
if not generate_config.get("max_tokens") and XINFERENCE_MAX_TOKENS:
generate_config["max_tokens"] = XINFERENCE_MAX_TOKENS
_apply_response_format(generate_config)
stream = generate_config.get("stream", False)

chat_template_kwargs = (
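For orientation, here is a minimal sketch (not part of the diff) of how the new _apply_response_format helper rewrites a generate_config in place; the commented results assume xllamacpp resolves the schema to a GBNF grammar string, as _schema_to_grammar attempts above.

from xinference.model.llm.llama_cpp.core import _apply_response_format

cfg = {
    "max_tokens": 64,
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "schema": {
                "type": "object",
                "properties": {"a": {"type": "string"}},
                "required": ["a"],
            }
        },
    },
}

_apply_response_format(cfg)

# "response_format" is always consumed. If xllamacpp is installed and the
# conversion succeeds, the config carries only a grammar (never both keys,
# since xllamacpp rejects configs containing both json_schema and grammar):
#   cfg == {"max_tokens": 64, "grammar": "<GBNF string>"}
# If xllamacpp is missing or the conversion fails, the raw schema is kept:
#   cfg == {"max_tokens": 64, "json_schema": {"type": "object", ...}}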
193 changes: 193 additions & 0 deletions xinference/model/llm/llama_cpp/tests/test_structured.py
@@ -0,0 +1,193 @@
import importlib
import importlib.util
import json
import sys
from enum import Enum
from types import SimpleNamespace
from typing import Any, Dict

import openai
import pytest
from pydantic import BaseModel

from xinference.client import Client

from ..core import _apply_response_format


class CarType(str, Enum):
sedan = "sedan"
suv = "SuV"
truck = "Truck"
coupe = "Coupe"


class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType


def _load_json_from_message(message: Any) -> Dict[str, Any]:
def _strip_think(text: str) -> str:
stripped = text.lstrip()
if stripped.startswith("<think>"):
if "</think>" in stripped:
stripped = stripped.split("</think>", 1)[1]
else:
stripped = stripped.split("<think>", 1)[1]
return stripped.lstrip()

raw_content = message.content
if isinstance(raw_content, str):
return json.loads(_strip_think(raw_content))

if isinstance(raw_content, list):
text_blocks = []
for block in raw_content:
if isinstance(block, dict):
if block.get("type") == "text" and "text" in block:
text_blocks.append(_strip_think(block["text"]))
continue

block_type = getattr(block, "type", None)
block_text = getattr(block, "text", None)
if block_type == "text" and block_text:
text_blocks.append(_strip_think(block_text))

if text_blocks:
return json.loads("".join(text_blocks))

pytest.fail(f"Unexpected message content format: {raw_content!r}")
raise AssertionError("Unreachable")


def test_apply_response_format_sets_grammar(monkeypatch):
fake_xllamacpp = SimpleNamespace(json_schema_to_grammar=lambda schema: "GRAMMAR")
monkeypatch.setitem(sys.modules, "xllamacpp", fake_xllamacpp)

cfg = {
"response_format": {
"type": "json_schema",
"json_schema": {
"schema": {
"type": "object",
"properties": {"a": {"type": "string"}},
"required": ["a"],
}
},
}
}

_apply_response_format(cfg)

assert "response_format" not in cfg
assert "json_schema" not in cfg
assert cfg["grammar"] == "GRAMMAR"


def test_apply_response_format_handles_conversion_failure(monkeypatch):
def _raise(_):
raise ValueError("bad schema")

fake_xllamacpp = SimpleNamespace(json_schema_to_grammar=_raise)
monkeypatch.setitem(sys.modules, "xllamacpp", fake_xllamacpp)

cfg = {
"response_format": {
"type": "json_schema",
"json_schema": {
"schema": {
"type": "object",
"properties": {"b": {"type": "string"}},
"required": ["b"],
}
},
}
}

_apply_response_format(cfg)

assert "response_format" not in cfg
assert cfg["json_schema"]["required"] == ["b"]
assert "grammar" not in cfg


def test_apply_response_format_ignores_non_schema():
cfg = {"response_format": {"type": "json_object"}}
_apply_response_format(cfg)
assert "grammar" not in cfg
assert "json_schema" not in cfg


def test_apply_response_format_uses_real_xllamacpp_if_available():
if importlib.util.find_spec("xllamacpp") is None:
pytest.skip("xllamacpp not installed")
xllamacpp = importlib.import_module("xllamacpp")
if not hasattr(xllamacpp, "json_schema_to_grammar"):
pytest.skip("xllamacpp does not expose json_schema_to_grammar")

cfg = {
"response_format": {
"type": "json_schema",
"json_schema": {
"schema": {
"type": "object",
"properties": {"c": {"type": "integer"}},
"required": ["c"],
}
},
}
}

_apply_response_format(cfg)

assert "response_format" not in cfg
# Real xllamacpp should prefer grammar to avoid passing both
assert "json_schema" not in cfg
assert "grammar" in cfg and cfg["grammar"]


def test_llamacpp_qwen3_json_schema(setup):
endpoint, _ = setup
client = Client(endpoint)
model_uid = client.launch_model(
model_name="qwen3",
model_engine="llama.cpp",
model_size_in_billions="0_6",
model_format="ggufv2",
quantization="Q4_K_M",
n_gpu=None,
)

try:
api_client = openai.Client(api_key="not empty", base_url=f"{endpoint}/v1")
completion = api_client.chat.completions.create(
model=model_uid,
messages=[
{
"role": "user",
"content": (
"Generate a JSON containing the brand, model, and car_type of"
" an iconic 90s car."
),
}
],
temperature=0,
max_tokens=128,
response_format={
"type": "json_schema",
"json_schema": {
"name": "car-description",
"schema": CarDescription.model_json_schema(),
},
},
)

parsed = _load_json_from_message(completion.choices[0].message)
car_description = CarDescription.model_validate(parsed)
assert car_description.brand
assert car_description.model
finally:
if model_uid is not None:
client.terminate_model(model_uid)
26 changes: 26 additions & 0 deletions xinference/model/llm/utils.py
@@ -1235,6 +1235,32 @@ def get_stop_token_ids_from_config_file(model_path: str) -> Optional[List[int]]:
return None


def normalize_response_format(
response_format: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
"""
Normalize OpenAI-style response_format into a simple dict.
Returns:
None if missing/unsupported, or a dict with keys:
- type: "json_schema" | "json_object"
- schema_dict: dict (only for json_schema)
"""
if not response_format or not isinstance(response_format, dict):
return None

fmt_type = response_format.get("type")
if fmt_type not in ("json_schema", "json_object"):
return None

normalized: Dict[str, Any] = {"type": fmt_type}
if fmt_type == "json_schema":
schema_block = response_format.get("json_schema") or {}
schema_dict = schema_block.get("schema_") or schema_block.get("schema")
if schema_dict:
normalized["schema_dict"] = schema_dict
return normalized


def parse_messages(messages: List[Dict]) -> Tuple:
"""
Some older models still follow the old way of parameter passing.
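To close the loop, a short sketch of what the new normalize_response_format helper returns for common inputs, following directly from the function body above:

from xinference.model.llm.utils import normalize_response_format

# Plain JSON mode carries only its type.
normalize_response_format({"type": "json_object"})
# -> {"type": "json_object"}

# json_schema mode also extracts the schema dict, read from either the
# "schema" key or the "schema_" spelling some clients emit.
normalize_response_format(
    {"type": "json_schema", "json_schema": {"schema": {"type": "object"}}}
)
# -> {"type": "json_schema", "schema_dict": {"type": "object"}}

# Missing or unsupported formats normalize to None.
normalize_response_format(None)              # -> None
normalize_response_format({"type": "text"})  # -> None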