Merged (changes from 13 commits)
35 changes: 32 additions & 3 deletions xinference/model/llm/llama_cpp/core.py
@@ -16,7 +16,7 @@
import os
import pprint
import queue
-from typing import Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from packaging import version

@@ -25,11 +25,38 @@
from ...utils import check_dependency_available
from ..core import LLM, chat_context_var
from ..llm_family import LLMFamilyV2, LLMSpecV1
-from ..utils import ChatModelMixin
+from ..utils import ChatModelMixin, normalize_response_format

logger = logging.getLogger(__name__)


def _schema_to_grammar(schema: Dict[str, Any]) -> Optional[str]:
try:
import xllamacpp
except Exception as e: # pragma: no cover - optional dependency
logger.warning("json_schema provided but xllamacpp missing: %s", e)
return None
try:
return xllamacpp.json_schema_to_grammar(schema) # type: ignore[attr-defined]
except Exception as e: # pragma: no cover - conversion failure
logger.warning("Failed to convert json_schema to grammar for xllamacpp: %s", e)
return None


def _apply_response_format(generate_config: Dict[str, Any]) -> None:
response_format = generate_config.pop("response_format", None)
normalized = normalize_response_format(response_format)
if not normalized or normalized.get("type") != "json_schema":
return
schema_dict = normalized.get("schema_dict")
if not schema_dict:
return
generate_config.setdefault("json_schema", schema_dict)
grammar = _schema_to_grammar(schema_dict)
if grammar:
generate_config.setdefault("grammar", grammar)


class _Done:
pass

@@ -49,7 +76,7 @@ def __init__(
model_path: str,
llamacpp_model_config: Optional[dict] = None,
):
-        super().__init__(model_uid, model_family, model_path)
+        super().__init__(model_uid, model_family, model_path)  # type: ignore[call-arg]
self._llamacpp_model_config = self._sanitize_model_config(llamacpp_model_config)
self._llm = None
self._executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
@@ -246,6 +273,7 @@ def generate(
generate_config = generate_config or {}
if not generate_config.get("max_tokens") and XINFERENCE_MAX_TOKENS:
generate_config["max_tokens"] = XINFERENCE_MAX_TOKENS
_apply_response_format(generate_config)
stream = generate_config.get("stream", False)
q: queue.Queue = queue.Queue()

@@ -305,6 +333,7 @@ def chat(
generate_config = generate_config or {}
if not generate_config.get("max_tokens") and XINFERENCE_MAX_TOKENS:
generate_config["max_tokens"] = XINFERENCE_MAX_TOKENS
_apply_response_format(generate_config)
stream = generate_config.get("stream", False)

chat_template_kwargs = (
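The net effect of the new `_apply_response_format` hook on both `generate` and `chat` is easiest to see on a concrete config. A minimal sketch follows; the dict shape mirrors the tests below, the exact grammar string depends on the installed xllamacpp, and nothing extra is attached when the conversion fails:

# Sketch: what _apply_response_format does to a caller's generate_config.
# Assumes xllamacpp is installed; without it only "json_schema" is set.
from xinference.model.llm.llama_cpp.core import _apply_response_format

generate_config = {
    "max_tokens": 128,
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            }
        },
    },
}
_apply_response_format(generate_config)

# "response_format" has been popped and replaced by llama.cpp-native keys:
assert "response_format" not in generate_config
assert generate_config["json_schema"]["required"] == ["name"]
# With a working xllamacpp, a grammar string is attached as well:
# generate_config["grammar"] == <output of xllamacpp.json_schema_to_grammar>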
103 changes: 103 additions & 0 deletions xinference/model/llm/llama_cpp/tests/test_structured.py
@@ -0,0 +1,103 @@
import sys
from types import SimpleNamespace


def test_apply_response_format_sets_grammar(monkeypatch):
from xinference.model.llm.llama_cpp.core import _apply_response_format

fake_xllamacpp = SimpleNamespace(json_schema_to_grammar=lambda schema: "GRAMMAR")
monkeypatch.setitem(sys.modules, "xllamacpp", fake_xllamacpp)

cfg = {
"response_format": {
"type": "json_schema",
"json_schema": {
"schema": {
"type": "object",
"properties": {"a": {"type": "string"}},
"required": ["a"],
}
},
}
}

_apply_response_format(cfg)

assert "response_format" not in cfg
assert cfg["json_schema"]["required"] == ["a"]
assert cfg["grammar"] == "GRAMMAR"


def test_apply_response_format_handles_conversion_failure(monkeypatch):
from xinference.model.llm.llama_cpp.core import _apply_response_format

def _raise(_):
raise ValueError("bad schema")

fake_xllamacpp = SimpleNamespace(json_schema_to_grammar=_raise)
monkeypatch.setitem(sys.modules, "xllamacpp", fake_xllamacpp)

cfg = {
"response_format": {
"type": "json_schema",
"json_schema": {
"schema": {
"type": "object",
"properties": {"b": {"type": "string"}},
"required": ["b"],
}
},
}
}

_apply_response_format(cfg)

assert "response_format" not in cfg
assert cfg["json_schema"]["required"] == ["b"]
assert "grammar" not in cfg


def test_apply_response_format_ignores_non_schema():
from xinference.model.llm.llama_cpp.core import _apply_response_format

cfg = {"response_format": {"type": "json_object"}}
_apply_response_format(cfg)
assert "grammar" not in cfg
assert "json_schema" not in cfg


def test_apply_response_format_uses_real_xllamacpp_if_available():
    import importlib
    import importlib.util

    import pytest

    if importlib.util.find_spec("xllamacpp") is None:
        pytest.skip("xllamacpp not installed")

    xllamacpp = importlib.import_module("xllamacpp")
    if not hasattr(xllamacpp, "json_schema_to_grammar"):
        pytest.skip("xllamacpp does not expose json_schema_to_grammar")

from xinference.model.llm.llama_cpp.core import _apply_response_format

cfg = {
"response_format": {
"type": "json_schema",
"json_schema": {
"schema": {
"type": "object",
"properties": {"c": {"type": "integer"}},
"required": ["c"],
}
},
}
}

_apply_response_format(cfg)

assert "response_format" not in cfg
# Real xllamacpp should attach grammar alongside json_schema
assert "json_schema" in cfg
assert "grammar" in cfg and cfg["grammar"]
26 changes: 26 additions & 0 deletions xinference/model/llm/utils.py
@@ -1235,6 +1235,32 @@ def get_stop_token_ids_from_config_file(model_path: str) -> Optional[List[int]]:
return None


def normalize_response_format(
response_format: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
"""
Normalize OpenAI-style response_format into a simple dict.
Returns:
None if missing/unsupported, or a dict with keys:
- type: "json_schema" | "json_object"
- schema_dict: dict (only for json_schema)
"""
if not response_format or not isinstance(response_format, dict):
return None

fmt_type = response_format.get("type")
if fmt_type not in ("json_schema", "json_object"):
return None

normalized: Dict[str, Any] = {"type": fmt_type}
if fmt_type == "json_schema":
schema_block = response_format.get("json_schema") or {}
schema_dict = schema_block.get("schema_") or schema_block.get("schema")
if schema_dict:
normalized["schema_dict"] = schema_dict
return normalized


def parse_messages(messages: List[Dict]) -> Tuple:
"""
Some older models still follow the old way of parameter passing.
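A short sketch of the normalization contract, matching the docstring above; the `schema_` fallback presumably covers pydantic-derived dicts, where `schema` is a reserved attribute and the field is spelled with a trailing underscore:

from xinference.model.llm.utils import normalize_response_format

# json_schema: the nested schema is surfaced as "schema_dict"
fmt = {"type": "json_schema", "json_schema": {"schema": {"type": "object"}}}
assert normalize_response_format(fmt) == {
    "type": "json_schema",
    "schema_dict": {"type": "object"},
}

# "schema_" (the pydantic-alias spelling) is checked first
fmt = {"type": "json_schema", "json_schema": {"schema_": {"type": "object"}}}
assert normalize_response_format(fmt)["schema_dict"] == {"type": "object"}

# json_object carries no schema; anything else (or None) is unsupported
assert normalize_response_format({"type": "json_object"}) == {"type": "json_object"}
assert normalize_response_format({"type": "text"}) is None
assert normalize_response_format(None) is None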