diff --git a/mlx_lm/tokenizer_utils.py b/mlx_lm/tokenizer_utils.py index c7e50fbe7..2c1e91996 100644 --- a/mlx_lm/tokenizer_utils.py +++ b/mlx_lm/tokenizer_utils.py @@ -549,6 +549,8 @@ def _infer_tool_parser(chat_template): """Attempt to auto-infer a tool parser from the chat template.""" if not isinstance(chat_template, str): return None + elif "|DSML|tool_calls" in chat_template: + return "deepseek_v4" elif "" in chat_template: return "minimax_m2" elif "<|tool_call>" in chat_template and "" in chat_template: diff --git a/mlx_lm/tool_parsers/deepseek_v4.py b/mlx_lm/tool_parsers/deepseek_v4.py new file mode 100644 index 000000000..f64847ed4 --- /dev/null +++ b/mlx_lm/tool_parsers/deepseek_v4.py @@ -0,0 +1,41 @@ +import json +from typing import Any + +import regex as re + +_dsml = "|DSML|" + +# Match from the DSML tag, not the official text sentinel's leading newlines. +# The server detects tool calls with token sequences, and DS4 can tokenize the +# same text as either "\n\n" or as part of the previous token, e.g. ".\n\n". +# Stop before ">" as well because the closing bracket can merge with "\n". 
# Sentinels delimiting a DSML tool-call section. The start sentinel
# deliberately stops before the closing ">" of the opening tag, so the text
# handed to parse_tool_call may begin with a stray ">".
# NOTE(review): the closing-tag spelling ("</" + _dsml + name + ">") used for
# tool_call_end and in the patterns below is inferred from the opening tags;
# confirm it against the model's chat template.
tool_call_start: str = f"<{_dsml}tool_calls"
tool_call_end: str = f"</{_dsml}tool_calls>"

_dsml_escaped = re.escape(_dsml)

# Both patterns are anchored on their closing tag. Without that anchor the
# lazy "(.*?)" group sits at the very end of the pattern and always captures
# the empty string, which silently drops every invoke body and every
# parameter value.
_invoke_re = re.compile(
    rf'<{_dsml_escaped}invoke name="([^"]+)">(.*?)</{_dsml_escaped}invoke>',
    re.DOTALL,
)
_param_re = re.compile(
    rf'<{_dsml_escaped}parameter name="([^"]+)" string="(true|false)">(.*?)'
    rf"</{_dsml_escaped}parameter>",
    re.DOTALL,
)


def _parse_arguments(body: str) -> dict:
    """Collect the parameters of one invoke body into a name -> value dict."""
    arguments = {}
    for param in _param_re.finditer(body):
        pname, string_flag, raw = param.group(1, 2, 3)
        # string="true" values are kept verbatim; anything else is JSON.
        arguments[pname] = raw if string_flag == "true" else json.loads(raw)
    return arguments


def parse_tool_call(text: str, tools: Any = None):
    """Parse one or more DSML tool calls from ``text``.

    Args:
        text: The text found between ``tool_call_start`` and
            ``tool_call_end``. Anything outside the ``invoke`` elements
            (such as a leading ">") is ignored.
        tools: Not read by this parser.

    Returns:
        A single ``{"name": ..., "arguments": {...}}`` dict when exactly one
        invoke element is found, otherwise a list of such dicts in the order
        they appear.

    Raises:
        ValueError: If ``text`` contains no invoke element.
        json.JSONDecodeError: If a parameter marked ``string="false"`` does
            not hold valid JSON (this is a ``ValueError`` subclass).
    """
    calls = [
        {"name": invoke.group(1), "arguments": _parse_arguments(invoke.group(2))}
        for invoke in _invoke_re.finditer(text)
    ]

    if not calls:
        raise ValueError("No tool calls found in DSML block")
    return calls[0] if len(calls) == 1 else calls
+ test_case = ( + '>\n<|DSML|invoke name="get_weather">\n' + '<|DSML|parameter name="location" string="true">Beijing\n' + '<|DSML|parameter name="num_results" string="false">5\n' + '\n' + ) + result = deepseek_v4.parse_tool_call(test_case, None) + self.assertEqual(result["name"], "get_weather") + self.assertEqual(result["arguments"]["location"], "Beijing") + self.assertEqual(result["arguments"]["num_results"], 5) + + # Multiple calls + test_case = ( + '\n<|DSML|invoke name="search">\n' + '<|DSML|parameter name="query" string="true">weather\n' + '\n' + '<|DSML|invoke name="read_file">\n' + '<|DSML|parameter name="path" string="true">/tmp/test.txt\n' + '\n' + ) + result = deepseek_v4.parse_tool_call(test_case, None) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertEqual( + result[0], {"name": "search", "arguments": {"query": "weather"}} + ) + self.assertEqual( + result[1], {"name": "read_file", "arguments": {"path": "/tmp/test.txt"}} + ) + if __name__ == "__main__": unittest.main()