def _xgrammar_getattr(name: str):
    """Fallback attribute hook installed as ``__getattr__`` on the stub module.

    Any ordinary attribute name resolves to the builtin ``object`` so that
    imports of arbitrary names from the stub succeed. Names beginning with a
    double underscore are reported as missing instead.

    Raises:
        AttributeError: if ``name`` starts with ``"__"``.
    """
    is_dunder_lookup = name[:2] == "__"
    if not is_dunder_lookup:
        return object
    raise AttributeError(name)
class KwargAwareToolParser(ToolParser):
    """Test parser whose constructor accepts ``chat_template_kwargs``.

    It remembers ``chat_template_kwargs['tool_format']`` and reports that
    value as the function name of the single tool call it returns, so a test
    can observe whether the kwargs reached the constructor.
    """

    def __init__(self, tokenizer, chat_template_kwargs=None):
        super().__init__(tokenizer)
        template_kwargs = chat_template_kwargs if chat_template_kwargs else {}
        self.tool_format = template_kwargs.get("tool_format")

    def extract_tool_calls(self, model_output, request):
        """Return one tool call named after ``tool_format`` (or ``"missing"``)."""
        reported_name = self.tool_format or "missing"
        only_call = ToolCall(
            function=FunctionCall(name=reported_name, arguments="{}"),
        )
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=[only_call],
            content=None,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text,
        current_text,
        delta_text,
        previous_token_ids,
        current_token_ids,
        delta_token_ids,
        request,
    ):
        """Streaming is not exercised by these tests; always returns ``None``."""
        return None
class FakeTokenizer:
    """Minimal in-memory tokenizer stand-in for parser tests.

    Every distinct text passed to ``encode`` is registered as exactly one
    token; ids are handed out sequentially starting at 1.
    """

    def __init__(self):
        self._vocab: dict[str, int] = {}
        self._next_token_id = 1

    def get_vocab(self):
        """Return the live text -> id mapping (not a copy)."""
        return self._vocab

    def encode(self, text: str, add_special_tokens: bool = False):
        """Return a one-element id list for ``text``, registering it if new.

        ``add_special_tokens`` is accepted but not used.
        """
        token_id = self._vocab.get(text)
        if token_id is None:
            token_id = self._next_token_id
            self._vocab[text] = token_id
            self._next_token_id = token_id + 1
        return [token_id]

    def decode(self, token_ids):
        """Concatenate the texts registered for ``token_ids``.

        Raises:
            KeyError: if an id was never produced by ``encode``.
        """
        id_to_text = {}
        for text, token_id in self._vocab.items():
            id_to_text[token_id] = text
        pieces = [id_to_text[token_id] for token_id in token_ids]
        return "".join(pieces)
chat_template_kwargs={"tool_format": tool_format}, + ) + + +def make_request() -> ChatCompletionRequest: + return ChatCompletionRequest( + model="test-model", + messages=[], + ) + + +def test_default_format_delegates_to_hermes(): + parser = make_parser("default") + + extracted = run_tool_extraction_nonstreaming( + parser, + '\n{"name":"get_weather","arguments":{"city":"Tokyo"}}\n', + make_request(), + ) + + assert extracted.tools_called + assert extracted.content is None + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == {"city": "Tokyo"} + + +def test_qwen3_format_delegates_to_qwen3xml(): + parser = make_parser("qwen3") + + extracted = run_tool_extraction_nonstreaming( + parser, + "\n\n" + "Tokyo\n" + "\n", + make_request(), + ) + + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == {"city": "Tokyo"} + + +def test_glm_format_matches_template_output(): + parser = make_parser("glm") + + extracted = run_tool_extraction_nonstreaming( + parser, + "get_weathercity" + "Beijing", + make_request(), + ) + + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "city": "Beijing" + } + + +def test_minimax_format_extracts_inline_invokes(): + parser = make_parser("minimax") + + extracted = run_tool_extraction_nonstreaming( + parser, + "Checking." + '' + 'Tokyo' + '5' + 'true' + '{"units":"metric"}' + "", + make_request(), + ) + + assert extracted.tools_called + assert extracted.content == "Checking." 
+ assert json.loads(extracted.tool_calls[0].function.arguments) == { + "city": "Tokyo", + "days": 5, + "enabled": True, + "filters": {"units": "metric"}, + } + + +def test_dsv32_format_honors_string_attribute(): + parser = make_parser("dsv32") + + extracted = run_tool_extraction_nonstreaming( + parser, + "Prefix" + "" + '' + 'Tokyo' + '5' + 'false' + "" + "", + make_request(), + ) + + assert extracted.tools_called + assert extracted.content == "Prefix" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "city": "Tokyo", + "days": 5, + "enabled": False, + } + + +def test_gptoss_format_extracts_multiple_calls(): + parser = make_parser("gptoss") + + extracted = run_tool_extraction_nonstreaming( + parser, + "Planning...\n" + "to=functions.get_weather json\n" + '{"location":"SF"}\n' + "" + "to=functions.get_time json\n" + '{"timezone":"UTC"}\n' + "", + make_request(), + ) + + assert extracted.tools_called + assert extracted.content == "Planning...\n" + assert [tool_call.function.name for tool_call in extracted.tool_calls] == [ + "get_weather", + "get_time", + ] + assert json.loads(extracted.tool_calls[0].function.arguments) == {"location": "SF"} + assert json.loads(extracted.tool_calls[1].function.arguments) == { + "timezone": "UTC" + } + + +def test_gptoss_format_with_assistant_prefix(): + parser = make_parser("gptoss") + + extracted = run_tool_extraction_nonstreaming( + parser, + 'assistant to=functions.get_weather json\n' + '{"location": "San Francisco, CA", "unit": "celsius"}\n' + "", + make_request(), + ) + + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "location": "San Francisco, CA", + "unit": "celsius", + } + + +def test_python_format_extracts_single_call(): + parser = make_parser("python") + + extracted = run_tool_extraction_nonstreaming( + parser, + '\nget_weather(city="SF")\n', + make_request(), + ) + + assert 
extracted.tools_called + assert extracted.content is None + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == {"city": "SF"} + + +def test_python_format_extracts_multiple_calls(): + parser = make_parser("python") + + extracted = run_tool_extraction_nonstreaming( + parser, + '\nget_weather(city="SF")\n' + "\n" + '\nget_time(timezone="UTC")\n', + make_request(), + ) + + assert extracted.tools_called + assert [tool_call.function.name for tool_call in extracted.tool_calls] == [ + "get_weather", + "get_time", + ] + assert json.loads(extracted.tool_calls[0].function.arguments) == {"city": "SF"} + assert json.loads(extracted.tool_calls[1].function.arguments) == { + "timezone": "UTC" + } + + +def test_python_format_accepts_nested_json_style_literals(): + parser = make_parser("python") + + extracted = run_tool_extraction_nonstreaming( + parser, + '\n' + 'get_weather(city="SF", meta={"enabled": true, "missing": null})\n' + '', + make_request(), + ) + + assert extracted.tools_called + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "city": "SF", + "meta": {"enabled": True, "missing": None}, + } + + +def test_custom_formats_do_not_stream_yet(): + parser = make_parser("python") + + delta = parser.extract_tool_calls_streaming( + previous_text="", + current_text="", + delta_text="", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=make_request(), + ) + + assert delta is None + + +def test_readme_default_example(): + parser = make_parser("default") + extracted = run_tool_extraction_nonstreaming( + parser, + '\n' + '{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}}\n' + "", + make_request(), + ) + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "location": "San Francisco, CA" + } + + +def 
test_readme_qwen3_example(): + parser = make_parser("qwen3") + extracted = run_tool_extraction_nonstreaming( + parser, + '\n' + '\n' + '\n' + 'San Francisco, CA\n' + '\n' + '\n' + '', + make_request(), + ) + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "location": "San Francisco, CA" + } + + +def test_readme_minimax_example(): + parser = make_parser("minimax") + extracted = run_tool_extraction_nonstreaming( + parser, + '\n' + '\n' + 'San Francisco, CA\n' + 'celsius\n' + '\n' + '', + make_request(), + ) + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "location": "San Francisco, CA", + "unit": "celsius", + } + + +def test_readme_glm_example(): + parser = make_parser("glm") + extracted = run_tool_extraction_nonstreaming( + parser, + "get_weather" + "locationSan Francisco, CA" + "", + make_request(), + ) + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "location": "San Francisco, CA" + } + + +def test_readme_dsv32_example(): + parser = make_parser("dsv32") + extracted = run_tool_extraction_nonstreaming( + parser, + '\n' + '\n' + 'San Francisco, CA\n' + 'celsius\n' + '\n' + '', + make_request(), + ) + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + args = json.loads(extracted.tool_calls[0].function.arguments) + assert args == {"location": "San Francisco, CA", "unit": "celsius"} + assert isinstance(args["location"], str) + assert isinstance(args["unit"], str) + + +def test_readme_gptoss_example(): + parser = make_parser("gptoss") + extracted = run_tool_extraction_nonstreaming( + parser, + 'assistant to=functions.get_weather json\n' + '{"location": "San Francisco, CA", 
"unit": "celsius"}\n' + "", + make_request(), + ) + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "location": "San Francisco, CA", + "unit": "celsius", + } + + +def test_readme_python_example(): + parser = make_parser("python") + extracted = run_tool_extraction_nonstreaming( + parser, + '\n' + 'get_weather(location="San Francisco, CA", unit="celsius")\n' + '', + make_request(), + ) + assert extracted.tools_called + assert extracted.tool_calls[0].function.name == "get_weather" + assert json.loads(extracted.tool_calls[0].function.arguments) == { + "location": "San Francisco, CA", + "unit": "celsius", + } diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index cecd1da1e554..8ba308934e2f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -599,7 +599,11 @@ async def chat_completion_stream_generator( try: if tool_choice_auto and self.tool_parser: tool_parsers: list[ToolParser | None] = [ - self.tool_parser(tokenizer) + self._create_tool_parser( + self.tool_parser, + tokenizer, + chat_template_kwargs=request.chat_template_kwargs, + ) ] * num_choices else: tool_parsers = [None] * num_choices @@ -1344,7 +1348,11 @@ async def chat_completion_full_generator( reasoning = None if self.tool_parser is not None: - tool_parser = self.tool_parser(tokenizer) + tool_parser = self._create_tool_parser( + self.tool_parser, + tokenizer, + chat_template_kwargs=request.chat_template_kwargs, + ) # NOTE: We use token_ids for openai tool parser tool_call_info = tool_parser.extract_tool_calls( "", @@ -1413,6 +1421,7 @@ async def chat_completion_full_generator( content=content, enable_auto_tools=self.enable_auto_tools, tool_parser_cls=self.tool_parser, + chat_template_kwargs=request.chat_template_kwargs, ) tool_call_class = ( MistralToolCall if isinstance(tokenizer, MistralTokenizer) 
else ToolCall diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1d89aa011af2..2c00a3c6369d 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import inspect import json import sys import time @@ -285,7 +286,7 @@ def __init__( def _get_tool_parser( self, tool_parser_name: str | None = None, enable_auto_tools: bool = False - ) -> Callable[[TokenizerLike], ToolParser] | None: + ) -> type[ToolParser] | None: """Get the tool parser based on the name.""" parser = None if not enable_auto_tools or tool_parser_name is None: @@ -1082,7 +1083,7 @@ async def _preprocess_chat( tool_dicts: list[dict[str, Any]] | None = None, documents: list[dict[str, str]] | None = None, chat_template_kwargs: dict[str, Any] | None = None, - tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, + tool_parser: type[ToolParser] | None = None, add_special_tokens: bool = False, ) -> tuple[ list[ConversationMessage], @@ -1153,7 +1154,11 @@ async def _preprocess_chat( "or Responses API requests." 
) raise NotImplementedError(msg) - request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore + request = self._create_tool_parser( + tool_parser, + tokenizer, + chat_template_kwargs=chat_template_kwargs, + ).adjust_request(request=request) if tokenizer is None: assert isinstance(request_prompt, str), ( @@ -1366,13 +1371,53 @@ def _get_data_parallel_rank(raw_request: Request | None) -> int | None: except ValueError: return None + @staticmethod + def _get_tool_parser_init_kwargs( + tool_parser_cls: type[ToolParser], + chat_template_kwargs: dict[str, Any] | None, + ) -> dict[str, Any]: + if not chat_template_kwargs: + return {} + + try: + init_signature = inspect.signature(tool_parser_cls.__init__) + except (TypeError, ValueError): + return {} + + if "chat_template_kwargs" in init_signature.parameters: + return {"chat_template_kwargs": chat_template_kwargs} + + if any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in init_signature.parameters.values() + ): + return {"chat_template_kwargs": chat_template_kwargs} + + return {} + + @classmethod + def _create_tool_parser( + cls, + tool_parser_cls: type[ToolParser], + tokenizer: TokenizerLike, + chat_template_kwargs: dict[str, Any] | None = None, + ) -> ToolParser: + return tool_parser_cls( + tokenizer, + **cls._get_tool_parser_init_kwargs( + tool_parser_cls, + chat_template_kwargs, + ), + ) + @staticmethod def _parse_tool_calls_from_content( request: ResponsesRequest | ChatCompletionRequest, tokenizer: TokenizerLike, enable_auto_tools: bool, - tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, + tool_parser_cls: type[ToolParser] | None, content: str | None = None, + chat_template_kwargs: dict[str, Any] | None = None, ) -> tuple[list[FunctionCall] | None, str | None]: function_calls = list[FunctionCall]() if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): @@ -1411,7 +1456,11 @@ def _parse_tool_calls_from_content( ): # Automatic Tool Call 
Parsing try: - tool_parser = tool_parser_cls(tokenizer) + tool_parser = OpenAIServing._create_tool_parser( + tool_parser_cls, + tokenizer, + chat_template_kwargs=chat_template_kwargs, + ) except RuntimeError as e: logger.exception("Error in tool parser creation.") raise e diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 81495a077754..83a67bd91943 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -855,6 +855,7 @@ def _make_response_output_items( content=content, enable_auto_tools=self.enable_auto_tools, tool_parser_cls=self.tool_parser, + chat_template_kwargs=request.chat_template_kwargs, ) if content: output_text = ResponseOutputText( diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 89e439dd53f5..0680d90b51e7 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -90,6 +90,10 @@ "minimax_tool_parser", "MinimaxToolParser", ), + "multi_format": ( + "multi_format_tool_parser", + "MultiFormatToolParser", + ), "mistral": ( "mistral_tool_parser", "MistralToolParser", diff --git a/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py new file mode 100644 index 000000000000..dcd8e8d428c1 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py @@ -0,0 +1,449 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +import json +from collections.abc import Sequence +from typing import Any + +import regex as re + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from 
class MultiFormatToolParser(ToolParser):
    """Tool parser that dispatches on ``chat_template_kwargs['tool_format']``.

    ``default`` (also used when no format is supplied) is delegated to
    ``Hermes2ProToolParser`` and ``qwen3`` to ``Qwen3XMLToolParser``. The
    ``minimax``, ``dsv32``, ``glm``, ``gptoss`` and ``python`` formats are
    parsed by this class and support non-streaming extraction only. Any other
    value results in "no tool calls" with the model output returned as content.
    """

    # NOTE(review): the delimiter text inside the literals below looks
    # incomplete in this revision (``_MINIMAX_START_TOKEN`` is empty, and the
    # invoke/parameter patterns expose one capture group although their
    # callers unpack two and three values per match) — confirm against the
    # intended markup before relying on these patterns.
    _MINIMAX_START_TOKEN = ""
    _MINIMAX_BLOCK_REGEX = re.compile(
        r"(.*?)",
        re.DOTALL,
    )
    _MINIMAX_INVOKE_REGEX = re.compile(
        r'(.*?)',
        re.DOTALL,
    )
    _MINIMAX_PARAMETER_REGEX = re.compile(
        r'(.*?)',
        re.DOTALL,
    )

    _GPTOSS_BLOCK_REGEX = re.compile(
        r"\s*(?:assistant\s+)?to=functions\.(\S+?)"
        r"(?:\s+json)?\s*\n(.*?)\n?\s*",
        re.DOTALL,
    )

    _PYTHON_BLOCK_REGEX = re.compile(
        r"(.*?)",
        re.DOTALL,
    )
    _GLM_BLOCK_REGEX = re.compile(
        r"(.*?)",
        re.DOTALL,
    )
    _GLM_ARG_REGEX = re.compile(
        r"(.*?)\s*(.*?)",
        re.DOTALL,
    )

    def __init__(
        self,
        tokenizer: TokenizerLike,
        chat_template_kwargs: dict[str, Any] | None = None,
    ):
        """Select the parsing strategy from ``chat_template_kwargs``.

        Args:
            tokenizer: Tokenizer passed to the base class and to any delegate.
            chat_template_kwargs: Request template kwargs; only the
                ``tool_format`` entry is read. A missing or falsy value means
                ``"default"``.
        """
        super().__init__(tokenizer)

        self.tool_format = str(
            (chat_template_kwargs or {}).get("tool_format") or "default"
        )
        self._delegate: ToolParser | None = None

        # Delegate parsers are imported lazily so only the selected one loads.
        if self.tool_format == "default":
            from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
                Hermes2ProToolParser,
            )

            self._delegate = Hermes2ProToolParser(tokenizer)
        elif self.tool_format == "qwen3":
            from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import (
                Qwen3XMLToolParser,
            )

            self._delegate = Qwen3XMLToolParser(tokenizer)

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Let the delegate adjust the request, else use the base behaviour."""
        if self._delegate is not None:
            return self._delegate.adjust_request(request)
        return super().adjust_request(request)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete model output.

        Delegated formats are forwarded unchanged. For the formats handled
        here, any exception raised while parsing is logged and the output is
        reported as plain content with no tool calls, so a malformed
        generation never fails the request.
        """
        if self._delegate is not None:
            return self._delegate.extract_tool_calls(model_output, request)

        try:
            if self.tool_format == "minimax":
                return self._extract_minimax_tool_calls(model_output)
            if self.tool_format == "dsv32":
                return self._extract_dsv32_tool_calls(model_output)
            if self.tool_format == "glm":
                return self._extract_glm_tool_calls(model_output, request)
            if self.tool_format == "gptoss":
                return self._extract_gptoss_tool_calls(model_output)
            if self.tool_format == "python":
                return self._extract_python_tool_calls(model_output)
        except Exception:
            # Deliberate best-effort boundary: log and fall through.
            logger.exception(
                "Error extracting tool calls for tool_format=%s.",
                self.tool_format,
            )

        return self._no_tool_calls(model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Stream via the delegate; formats parsed here return ``None``."""
        if self._delegate is not None:
            return self._delegate.extract_tool_calls_streaming(
                previous_text,
                current_text,
                delta_text,
                previous_token_ids,
                current_token_ids,
                delta_token_ids,
                request,
            )

        return None

    @staticmethod
    def _no_tool_calls(model_output: str) -> ExtractedToolCallInformation:
        """Build the "nothing was called" result carrying the raw output."""
        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )

    @staticmethod
    def _json_or_string(value: str) -> Any:
        """Decode stripped ``value`` as JSON, or return it as a plain string."""
        value = value.strip()
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value

    @staticmethod
    def _prefix_content(model_output: str, first_tool_index: int | None) -> str | None:
        """Return the text before the first tool call, or ``None``.

        ``None`` is returned when the index is ``None`` or not positive, and
        when the prefix contains only whitespace.
        """
        if first_tool_index is None or first_tool_index <= 0:
            return None
        content = model_output[:first_tool_index]
        return content if content.strip() else None

    @staticmethod
    def _tool_call(function_name: str, arguments: dict[str, Any]) -> ToolCall:
        """Wrap a name and argument mapping into a ``function`` tool call.

        Arguments are serialised with ``ensure_ascii=False`` so non-ASCII
        text is kept verbatim.
        """
        return ToolCall(
            type="function",
            function=FunctionCall(
                name=function_name,
                arguments=json.dumps(arguments, ensure_ascii=False),
            ),
        )

    def _extract_invoke_tool_calls(
        self,
        model_output: str,
        *,
        honor_string_flag: bool,
    ) -> ExtractedToolCallInformation:
        """Shared extraction for the invoke/parameter style formats.

        Args:
            model_output: Complete model output.
            honor_string_flag: When true, a parameter whose flag group equals
                ``"true"`` is kept as the raw string; every other parameter
                is decoded with ``_json_or_string``.
        """
        if self._MINIMAX_START_TOKEN not in model_output:
            return self._no_tool_calls(model_output)

        tool_calls: list[ToolCall] = []
        for block in self._MINIMAX_BLOCK_REGEX.findall(model_output):
            for function_name, invoke_body in self._MINIMAX_INVOKE_REGEX.findall(block):
                arguments: dict[str, Any] = {}
                for (
                    param_name,
                    string_flag,
                    param_value,
                ) in self._MINIMAX_PARAMETER_REGEX.findall(invoke_body):
                    if honor_string_flag and string_flag == "true":
                        arguments[param_name] = param_value
                    else:
                        arguments[param_name] = self._json_or_string(param_value)
                tool_calls.append(self._tool_call(function_name, arguments))

        if not tool_calls:
            return self._no_tool_calls(model_output)

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=self._prefix_content(
                model_output,
                model_output.find(self._MINIMAX_START_TOKEN),
            ),
        )

    def _extract_minimax_tool_calls(
        self,
        model_output: str,
    ) -> ExtractedToolCallInformation:
        """Extract ``minimax`` calls; every parameter value is JSON-or-string."""
        return self._extract_invoke_tool_calls(model_output, honor_string_flag=False)

    def _extract_dsv32_tool_calls(
        self,
        model_output: str,
    ) -> ExtractedToolCallInformation:
        """Extract ``dsv32`` calls, keeping string-flagged values verbatim."""
        return self._extract_invoke_tool_calls(model_output, honor_string_flag=True)

    def _extract_gptoss_tool_calls(
        self,
        model_output: str,
    ) -> ExtractedToolCallInformation:
        """Extract ``to=functions.<name>`` blocks whose body is a JSON object.

        Raises:
            ValueError: if a block body is not valid JSON
                (``json.JSONDecodeError``) or does not decode to an object.
        """
        matches = list(self._GPTOSS_BLOCK_REGEX.finditer(model_output))
        if not matches:
            return self._no_tool_calls(model_output)

        tool_calls: list[ToolCall] = []
        for match in matches:
            arguments = json.loads(match.group(2).strip())
            # Function arguments must be a mapping; a bare list/number/string
            # would otherwise be serialised as the argument payload.
            if not isinstance(arguments, dict):
                raise ValueError("gptoss tool call arguments must be a JSON object")
            tool_calls.append(self._tool_call(match.group(1), arguments))

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=self._prefix_content(model_output, matches[0].start()),
        )

    @staticmethod
    def _deserialize_glm_value(value: str) -> Any:
        """Best-effort decode: JSON first, then a Python literal, else text."""
        value = value.strip()
        # Both decoders are tried in order; any failure just moves on, and
        # the stripped text is the final fallback.
        try:
            return json.loads(value)
        except Exception:
            pass

        try:
            return ast.literal_eval(value)
        except Exception:
            pass

        return value

    @staticmethod
    def _glm_value_is_string(
        tool_name: str,
        arg_name: str,
        tools: list[ChatCompletionToolsParam] | None,
    ) -> bool:
        """Tell whether the request schema types ``arg_name`` as ``string``.

        Only the first tool named ``tool_name`` that has parameters is
        consulted. Returns ``False`` when no tools are given or none match.
        """
        if tools is None:
            return False
        for tool in tools:
            if tool.function.name != tool_name or tool.function.parameters is None:
                continue
            arg_type = (
                tool.function.parameters.get("properties", {})
                .get(arg_name, {})
                .get("type")
            )
            return arg_type == "string"
        return False

    def _extract_glm_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract ``glm`` calls: a function name followed by key/value pairs.

        Values stay raw strings when the request's tool schema declares the
        argument as ``string``; otherwise they are decoded with
        ``_deserialize_glm_value``. Blocks with an empty name are skipped.
        """
        matches = list(self._GLM_BLOCK_REGEX.finditer(model_output))
        if not matches:
            return self._no_tool_calls(model_output)

        tool_calls: list[ToolCall] = []
        for match in matches:
            block = match.group(1)
            first_arg_idx = block.find("")
            arguments: dict[str, Any] = {}
            if first_arg_idx == -1:
                function_name = block.strip()
            else:
                function_name = block[:first_arg_idx].strip()
                arg_block = block[first_arg_idx:]
                for key, value in self._GLM_ARG_REGEX.findall(arg_block):
                    arg_key = key.strip()
                    arg_value = value.strip()
                    if not self._glm_value_is_string(
                        function_name, arg_key, request.tools
                    ):
                        arg_value = self._deserialize_glm_value(arg_value)
                    arguments[arg_key] = arg_value

            if function_name:
                tool_calls.append(self._tool_call(function_name, arguments))

        if not tool_calls:
            return self._no_tool_calls(model_output)

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=self._prefix_content(model_output, matches[0].start()),
        )

    @staticmethod
    def _get_python_value(val: ast.expr) -> Any:
        """Convert a literal AST node into a JSON-serialisable Python value.

        Accepts constants, the bare names ``true``/``false``/``null`` (and
        their Python spellings), dicts with literal keys, lists, tuples
        (returned as lists) and signed numeric constants.

        Raises:
            ValueError: for any other expression.
        """
        convert = MultiFormatToolParser._get_python_value

        if isinstance(val, ast.Constant):
            return val.value
        if isinstance(val, ast.Name):
            # JSON-style spellings parse as plain names, not constants.
            if val.id in {"true", "True"}:
                return True
            if val.id in {"false", "False"}:
                return False
            if val.id in {"null", "None"}:
                return None
        if isinstance(val, ast.Dict):
            if not all(isinstance(k, ast.Constant) for k in val.keys):
                raise ValueError("Dict tool call arguments must have literal keys")
            return {
                k.value: convert(v)  # type: ignore
                for k, v in zip(val.keys, val.values)
            }
        if isinstance(val, (ast.List, ast.Tuple)):
            return [convert(v) for v in val.elts]
        if (
            isinstance(val, ast.UnaryOp)
            and isinstance(val.op, (ast.USub, ast.UAdd))
            and isinstance(val.operand, ast.Constant)
            and isinstance(val.operand.value, (int, float))
        ):
            operand = val.operand.value
            return -operand if isinstance(val.op, ast.USub) else operand
        raise ValueError("Tool call arguments must be literals")

    @staticmethod
    def _handle_python_tool(call: ast.Call) -> ToolCall:
        """Turn one parsed ``name(key=value, ...)`` call into a tool call.

        Raises:
            ValueError: if the callee is not a plain name, if positional or
                ``*`` arguments are present, if ``**`` unpacking is used, or
                if a value is not a supported literal.
        """
        if not isinstance(call.func, ast.Name):
            raise ValueError("Invalid tool call name")
        # Positional values cannot be mapped to parameter names here, so they
        # are rejected rather than silently discarded.
        if call.args:
            raise ValueError("Tool call arguments must be passed by keyword")

        arguments: dict[str, Any] = {}
        for keyword in call.keywords:
            # ``**mapping`` appears as a keyword whose name is None; storing
            # it would serialise as a bogus "null" argument.
            if keyword.arg is None:
                raise ValueError("Tool call arguments must not use ** unpacking")
            arguments[keyword.arg] = MultiFormatToolParser._get_python_value(
                keyword.value
            )
        return MultiFormatToolParser._tool_call(call.func.id, arguments)

    def _extract_python_tool_calls(
        self,
        model_output: str,
    ) -> ExtractedToolCallInformation:
        """Extract Python-call style tool calls from each matched block.

        Every statement inside a block must be a bare function call.

        Raises:
            SyntaxError: if a block is not valid Python.
            ValueError: if a statement is not a call or a call is unsupported.
        """
        matches = self._PYTHON_BLOCK_REGEX.findall(model_output)
        if not matches:
            return self._no_tool_calls(model_output)

        tool_calls: list[ToolCall] = []
        for block in matches:
            module = ast.parse(block.strip())
            for statement in module.body:
                if not isinstance(statement, ast.Expr) or not isinstance(
                    statement.value,
                    ast.Call,
                ):
                    raise ValueError(
                        "Expected Python function call(s) inside tags."
                    )
                tool_calls.append(self._handle_python_tool(statement.value))

        if not tool_calls:
            return self._no_tool_calls(model_output)

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=self._prefix_content(
                model_output,
                model_output.find(""),
            ),
        )