diff --git a/miles/utils/chat_template_utils/templates/qwen3.5_fixed.jinja b/miles/utils/chat_template_utils/templates/qwen3.5_fixed.jinja index 7c06122223..07d0cdadbf 100644 --- a/miles/utils/chat_template_utils/templates/qwen3.5_fixed.jinja +++ b/miles/utils/chat_template_utils/templates/qwen3.5_fixed.jinja @@ -75,14 +75,11 @@ {%- endif %} {%- endif %} {%- endfor %} -{%- if ns.multi_step_tool %} - {{- raise_exception('No user query found in messages.') }} -{%- endif %} {%- for message in messages %} {%- set content = render_content(message.content, true)|trim %} {%- if message.role == "system" %} {%- if not loop.first %} - {{- '<|im_start|>user\n' + content + '<|im_end|>\n' }} + {{- raise_exception('System message must be at the beginning.') }} {%- endif %} {%- elif message.role == "user" %} {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} @@ -151,4 +148,4 @@ {%- else %} {{- '\n' }} {%- endif %} -{%- endif %} \ No newline at end of file +{%- endif %} diff --git a/miles/utils/chat_template_utils/tito_tokenizer.py b/miles/utils/chat_template_utils/tito_tokenizer.py index ce02bec5f1..48a564314a 100644 --- a/miles/utils/chat_template_utils/tito_tokenizer.py +++ b/miles/utils/chat_template_utils/tito_tokenizer.py @@ -1,15 +1,20 @@ """TITO tokenizer — incremental tokenization for pretokenized prefix reuse. ``TITOTokenizer`` computes incremental token IDs for non-assistant messages -(tool responses, system injections) that follow the assistant's generated -token sequence, then merges them with the pretokenized prefix — handling -model-specific boundary tokens at the junction. - -The default implementation uses a dummy-message diff: it tokenizes a -synthetic ``[dummy_user, dummy_assistant]`` base with and without the -appended messages, then takes the suffix difference as the incremental IDs. -Model-specific subclasses override ``merge_tokens`` to handle boundary -quirks at the junction. +(tool responses, user follow-ups, system injections) that follow the +assistant's generated token sequence, then merges them with the pretokenized +prefix — handling model-specific boundary tokens at the junction. + +The default implementation incrementally tokenizes appended non-assistant turns +with role-specific synthetic prefixes: + +- contiguous ``tool`` runs use ``[dummy_system, dummy_assistant]`` +- each ``user`` or ``system`` message uses ``[dummy_system]`` + +The appended suffix is processed left-to-right, then the generation prompt for +the next assistant turn is appended once at the end. Model-specific +subclasses only override ``merge_tokens`` for boundary quirks at the prefix +junction. """ from __future__ import annotations @@ -20,7 +25,7 @@ from miles.utils.chat_template_utils.template import apply_chat_template, assert_messages_append_only_with_allowed_role from miles.utils.chat_template_utils.token_seq_comparator import TokenSeqComparator -_DUMMY_USER: dict[str, Any] = {"role": "user", "content": "dummy"} +_DUMMY_SYSTEM: dict[str, Any] = {"role": "system", "content": "dummy system"} def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]: @@ -45,23 +50,12 @@ def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, An # --------------------------------------------------------------------------- -# Base / default tokenizer (dummy-prefix diff) +# Base / default tokenizer # --------------------------------------------------------------------------- class TITOTokenizer: - """Incremental tokenization and prefix merging using dummy-message diff. - - A synthetic base ``[dummy_user, dummy_assistant]`` simulates the assistant - turn boundary so that the diff captures the correct turn-transition tokens: - - 1. ``tokens_without`` = tokenize(base, add_generation_prompt=False) - 2. ``tokens_with`` = tokenize(base + appended, add_generation_prompt=True) - 3. ``incremental_ids = tokens_with[len(tokens_without):]`` - - Subclasses override ``merge_tokens`` to handle model-specific boundary - token quirks. - """ + """Incremental tokenization and prefix merging for appended non-assistant turns.""" max_trim_tokens: int = 0 trailing_token_ids: frozenset[int] = frozenset() @@ -87,6 +81,96 @@ def create_comparator(self) -> TokenSeqComparator: trim_trailing_ids=self.trailing_token_ids or None, ) + def _render_messages( + self, + messages: list[dict[str, Any]], + *, + add_generation_prompt: bool, + tools: list[dict[str, Any]] | None = None, + ) -> str: + return apply_chat_template( + messages, + tokenizer=self.tokenizer, + tokenize=False, + add_generation_prompt=add_generation_prompt, + tools=tools, + **self.chat_template_kwargs, + ) + + def _encode_text(self, text: str) -> list[int]: + return self.tokenizer.encode(text, add_special_tokens=False) + + def _split_appended_segments(self, appended_messages: list[dict[str, Any]]) -> list[list[dict[str, Any]]]: + segments: list[list[dict[str, Any]]] = [] + i = 0 + while i < len(appended_messages): + role = appended_messages[i]["role"] + # Many templates wrap a contiguous tool-response run as one logical + # block, so tool messages are diffed together instead of one-by-one. + if role == "tool": + j = i + 1 + while j < len(appended_messages) and appended_messages[j]["role"] == "tool": + j += 1 + segments.append(appended_messages[i:j]) + i = j + continue + if role in {"user", "system"}: + segments.append([appended_messages[i]]) + i += 1 + continue + raise ValueError(f"unsupported appended role for TITO segmentation: {role}") + + return segments + + def _tokenize_rendered_suffix( + self, + base_messages: list[dict[str, Any]], + appended_messages: list[dict[str, Any]], + *, + tools: list[dict[str, Any]] | None = None, + add_generation_prompt: bool = False, + ) -> list[int]: + """Render *base_messages* and *base_messages + appended_messages*, return + tokens for the suffix. + + When *add_generation_prompt* is True and *appended_messages* is empty, + this computes the generation-prompt suffix (the assistant opener tokens). + """ + text_without = self._render_messages(base_messages, add_generation_prompt=False, tools=tools) + text_with = self._render_messages( + base_messages + appended_messages, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + if not text_with.startswith(text_without): + roles = [msg["role"] for msg in appended_messages] if appended_messages else ["generation_prompt"] + raise ValueError(f"rendered suffix diff failed for {roles}") + return self._encode_text(text_with[len(text_without) :]) + + def _tokenize_tool_segment( + self, + appended_messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + ) -> list[int]: + # No dummy user to avoid cut think issues. + return self._tokenize_rendered_suffix( + [_DUMMY_SYSTEM, _build_dummy_assistant(appended_messages)], + appended_messages, + tools=tools, + ) + + def _tokenize_user_and_system_segment( + self, + appended_message: dict[str, Any], + tools: list[dict[str, Any]] | None = None, + ) -> list[int]: + # User/system single-message appends share one synthetic context. + return self._tokenize_rendered_suffix( + [_DUMMY_SYSTEM], + [appended_message], + tools=tools, + ) + def tokenize_additional_non_assistant( self, old_messages: list[dict[str, Any]], @@ -96,9 +180,9 @@ def tokenize_additional_non_assistant( """Compute incremental token IDs for non-assistant messages appended after the pretokenized prefix. - Only handles tool responses, system injections, etc. — never an - assistant message. Validates that *new_messages* is an append-only - extension of *old_messages* via + Handles tool responses, user, and system messages — + never an assistant message. Validates that *new_messages* is an + append-only extension of *old_messages* via ``assert_messages_append_only_with_allowed_role``. Args: @@ -114,29 +198,29 @@ def tokenize_additional_non_assistant( """ assert_messages_append_only_with_allowed_role(old_messages, new_messages, self.allowed_append_roles) appended_messages = new_messages[len(old_messages) :] - - dummy_assistant = _build_dummy_assistant(appended_messages) - base_messages = [_DUMMY_USER, dummy_assistant] - - tokens_without = apply_chat_template( - base_messages, - tokenizer=self.tokenizer, - tokenize=True, - add_generation_prompt=False, + incremental: list[int] = [] + + # Incremental non-assistant content is assembled segment-by-segment + # using the smallest synthetic context that preserves each role's + # boundary tokens. + for segment in self._split_appended_segments(appended_messages): + role = segment[0]["role"] + if role == "tool": + incremental.extend(self._tokenize_tool_segment(segment, tools)) + elif role == "user" or role == "system": + incremental.extend(self._tokenize_user_and_system_segment(segment[0], tools)) + else: + raise ValueError(f"unsupported appended role for TITO tokenization: {role}") + + # The next assistant opener depends on the full post-append history, so + # it is derived from the real messages once and appended only at the end. + return incremental + self._tokenize_rendered_suffix( + new_messages, + [], tools=tools, - **self.chat_template_kwargs, - ) - tokens_with = apply_chat_template( - base_messages + list(appended_messages), - tokenizer=self.tokenizer, - tokenize=True, add_generation_prompt=True, - tools=tools, - **self.chat_template_kwargs, ) - return list(tokens_with[len(tokens_without) :]) - def merge_tokens( self, old_messages: list[dict[str, Any]], diff --git a/tests/fast/router/test_session_pretokenized_e2e.py b/tests/fast/router/test_session_pretokenized_e2e.py index e5b7ed0419..68917a68f1 100644 --- a/tests/fast/router/test_session_pretokenized_e2e.py +++ b/tests/fast/router/test_session_pretokenized_e2e.py @@ -65,7 +65,10 @@ class ModelTemplateConfig: "Qwen/Qwen3-4B-Thinking-2507", try_get_fixed_chat_template("Qwen/Qwen3-4B-Thinking-2507"), ), - "qwen3.5-native": ModelTemplateConfig("Qwen/Qwen3.5-0.8B", None), + "qwen3.5-fixed": ModelTemplateConfig( + "Qwen/Qwen3.5-0.8B", + try_get_fixed_chat_template("Qwen/Qwen3.5-0.8B"), + ), "qwen3-next-instruct-native": ModelTemplateConfig("Qwen/Qwen3-Next-80B-A3B-Instruct", None), "qwen3-next-thinking-fixed": ModelTemplateConfig( "Qwen/Qwen3-Next-80B-A3B-Thinking", diff --git a/tests/fast/utils/chat_template_utils/test_pretokenized_chat.py b/tests/fast/utils/chat_template_utils/test_pretokenized_chat.py index 78177ee63b..d37142a6c0 100644 --- a/tests/fast/utils/chat_template_utils/test_pretokenized_chat.py +++ b/tests/fast/utils/chat_template_utils/test_pretokenized_chat.py @@ -13,7 +13,11 @@ from miles.utils.chat_template_utils.autofix import try_get_fixed_chat_template from miles.utils.chat_template_utils.template import load_hf_chat_template -from miles.utils.test_utils.chat_template_verify import assert_pretokenized_equals_standard, simulate_pretokenized_path +from miles.utils.test_utils.chat_template_verify import ( + assert_pretokenized_equals_standard, + simulate_pretokenized_path, + verify_append_only, +) from miles.utils.test_utils.mock_trajectories import ( MultiTurnTrajectory, MultiUserTurnThinkingTrajectory, @@ -70,21 +74,15 @@ def _load_fixed(hf_id: str) -> str: ) -def _to_pytest_params(cases, include_tools=True): +def _to_pytest_params(cases): """Convert (name, cls, n, tools) tuples to pytest.param list.""" - params = [] - for name, cls, n, tools in cases: - if include_tools: - params.append(pytest.param(cls, n, tools, id=name)) - else: - params.append(pytest.param(cls, n, tools, id=name)) - return params + return [pytest.param(cls, n, tools, id=name) for name, cls, n, tools in cases] -_STANDARD_CASES = _to_pytest_params(STANDARD_CASES) -_THINKING_CASES = _to_pytest_params(THINKING_CASES) -_INTERMEDIATE_SYSTEM_CASES = _to_pytest_params(INTERMEDIATE_SYSTEM_CASES) -_INTERMEDIATE_SYSTEM_THINKING_CASES = _to_pytest_params(INTERMEDIATE_SYSTEM_THINKING_CASES) +_STANDARD_PARAMS = _to_pytest_params(STANDARD_CASES) +_THINKING_PARAMS = _to_pytest_params(THINKING_CASES) +_INTERMEDIATE_SYSTEM_PARAMS = _to_pytest_params(INTERMEDIATE_SYSTEM_CASES) +_INTERMEDIATE_SYSTEM_THINKING_PARAMS = _to_pytest_params(INTERMEDIATE_SYSTEM_THINKING_CASES) # (chat_template, trajectory_cls, pretokenize_n) — original templates that break prefix invariant _MISMATCH_CASES = [ @@ -101,11 +99,53 @@ def _to_pytest_params(cases, include_tools=True): ), ] -# Template parametrization lists -all_template_ids = list(ALL_TEMPLATES.keys()) -all_template_values = list(ALL_TEMPLATES.values()) -thinking_template_ids = list(TEMPLATES_WITH_THINKING.keys()) -thinking_template_values = list(TEMPLATES_WITH_THINKING.values()) + +def _template_params(templates: dict[str, str]) -> list: + """Convert a {name: template_str} dict to a list of pytest.param(template_str, id=name).""" + return [pytest.param(v, id=k) for k, v in templates.items()] + + +# Intermediate-system compatibility: only qwen3.5_fixed is known to reject them. +# test_intermediate_system_probe_matrix locks this set against drift. +_INTERMEDIATE_SYSTEM_FORBIDDEN = {"qwen3.5_fixed"} +_INTERMEDIATE_SYSTEM_TEMPLATES = {k: v for k, v in ALL_TEMPLATES.items() if k not in _INTERMEDIATE_SYSTEM_FORBIDDEN} +_INTERMEDIATE_SYSTEM_THINKING_TEMPLATES = { + k: v for k, v in TEMPLATES_WITH_THINKING.items() if k not in _INTERMEDIATE_SYSTEM_FORBIDDEN +} + + +def _collect_intermediate_system_failures(template_id: str, chat_template: str) -> list[str]: + failures: list[str] = [] + for case_name, traj_cls, n, tools in INTERMEDIATE_SYSTEM_CASES: + result = verify_append_only(chat_template, deepcopy(traj_cls.MESSAGES), n, tools=tools, case_name=case_name) + if not result.passed: + failures.append(f"{case_name}: {result.error}") + + if template_id in TEMPLATES_WITH_THINKING: + for enable in (True, False): + suffix = "thinking_on" if enable else "thinking_off" + for case_name, traj_cls, n, tools in INTERMEDIATE_SYSTEM_THINKING_CASES: + full_case_name = f"{case_name}[{suffix}]" + result = verify_append_only( + chat_template, + deepcopy(traj_cls.MESSAGES), + n, + tools=tools, + case_name=full_case_name, + enable_thinking=enable, + ) + if not result.passed: + failures.append(f"{full_case_name}: {result.error}") + + return failures + + +def _format_failure_map(failure_map: dict[str, list[str]]) -> str: + lines: list[str] = [] + for template_id in sorted(failure_map): + lines.append(f"{template_id}:") + lines.extend(f" - {item}" for item in failure_map[template_id]) + return "\n".join(lines) # =========================================================================== @@ -113,8 +153,8 @@ def _to_pytest_params(cases, include_tools=True): # =========================================================================== -@pytest.mark.parametrize("chat_template", all_template_values, ids=all_template_ids) -@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _STANDARD_CASES) +@pytest.mark.parametrize("chat_template", _template_params(ALL_TEMPLATES)) +@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _STANDARD_PARAMS) def test_pretokenized_equals_standard(chat_template, trajectory_cls, pretokenize_n, tools): """Pretokenized incremental path produces same text as standard full render.""" assert_pretokenized_equals_standard( @@ -130,8 +170,8 @@ def test_pretokenized_equals_standard(chat_template, trajectory_cls, pretokenize # =========================================================================== -@pytest.mark.parametrize("chat_template", thinking_template_values, ids=thinking_template_ids) -@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _THINKING_CASES) +@pytest.mark.parametrize("chat_template", _template_params(TEMPLATES_WITH_THINKING)) +@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _THINKING_PARAMS) @pytest.mark.parametrize("enable_thinking", [True, False], ids=["thinking_on", "thinking_off"]) def test_pretokenized_thinking(chat_template, trajectory_cls, pretokenize_n, tools, enable_thinking): """Thinking-capable templates work with pretokenized path and enable_thinking flag.""" @@ -149,10 +189,29 @@ def test_pretokenized_thinking(chat_template, trajectory_cls, pretokenize_n, too # =========================================================================== -@pytest.mark.parametrize("chat_template", all_template_values, ids=all_template_ids) -@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _INTERMEDIATE_SYSTEM_CASES) +def test_intermediate_system_probe_matrix(): + """Probe ALL_TEMPLATES and lock the allow/forbid intermediate-system matrix.""" + failure_map: dict[str, list[str]] = {} + for template_id, chat_template in ALL_TEMPLATES.items(): + failures = _collect_intermediate_system_failures(template_id, chat_template) + if failures: + failure_map[template_id] = failures + + detected_forbidden = set(failure_map.keys()) + assert detected_forbidden == _INTERMEDIATE_SYSTEM_FORBIDDEN, ( + f"Intermediate-system forbidden set changed.\n" + f"expected={sorted(_INTERMEDIATE_SYSTEM_FORBIDDEN)}\n" + f"detected={sorted(detected_forbidden)}\n" + f"{_format_failure_map(failure_map)}" + ) + qwen35_failures = failure_map.get("qwen3.5_fixed", []) + assert any("System message must be at the beginning." in failure for failure in qwen35_failures), qwen35_failures + + +@pytest.mark.parametrize("chat_template", _template_params(_INTERMEDIATE_SYSTEM_TEMPLATES)) +@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _INTERMEDIATE_SYSTEM_PARAMS) def test_pretokenized_intermediate_system(chat_template, trajectory_cls, pretokenize_n, tools): - """All templates support intermediate system messages (converted to user role in fixed templates).""" + """Templates in the allowlist support intermediate system messages.""" assert_pretokenized_equals_standard( chat_template=chat_template, messages=deepcopy(trajectory_cls.MESSAGES), @@ -161,13 +220,13 @@ def test_pretokenized_intermediate_system(chat_template, trajectory_cls, pretoke ) -@pytest.mark.parametrize("chat_template", thinking_template_values, ids=thinking_template_ids) -@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _INTERMEDIATE_SYSTEM_THINKING_CASES) +@pytest.mark.parametrize("chat_template", _template_params(_INTERMEDIATE_SYSTEM_THINKING_TEMPLATES)) +@pytest.mark.parametrize("trajectory_cls,pretokenize_n,tools", _INTERMEDIATE_SYSTEM_THINKING_PARAMS) @pytest.mark.parametrize("enable_thinking", [True, False], ids=["thinking_on", "thinking_off"]) def test_pretokenized_intermediate_system_thinking( chat_template, trajectory_cls, pretokenize_n, tools, enable_thinking ): - """Thinking templates support intermediate system messages with thinking.""" + """Thinking templates in the allowlist support intermediate system messages.""" assert_pretokenized_equals_standard( chat_template=chat_template, messages=deepcopy(trajectory_cls.MESSAGES), @@ -204,7 +263,7 @@ def test_original_template_prefix_mismatch(chat_template, trajectory_cls, pretok _CROSS_USER_THINKING_N = last_user_index(MultiUserTurnThinkingTrajectory.MESSAGES) -@pytest.mark.parametrize("chat_template", thinking_template_values, ids=thinking_template_ids) +@pytest.mark.parametrize("chat_template", _template_params(TEMPLATES_WITH_THINKING)) @pytest.mark.parametrize("enable_thinking", [True, False], ids=["thinking_on", "thinking_off"]) def test_cross_user_turn_thinking_prefix_mismatch(chat_template, enable_thinking): """Thinking templates compress reasoning_content from earlier user turns, breaking prefix invariant.""" diff --git a/tests/fast/utils/chat_template_utils/test_template.py b/tests/fast/utils/chat_template_utils/test_template.py index 39ac412954..5ba907f2a9 100644 --- a/tests/fast/utils/chat_template_utils/test_template.py +++ b/tests/fast/utils/chat_template_utils/test_template.py @@ -119,6 +119,10 @@ def tokenizer(request) -> AutoTokenizer: # Trajectory / kwargs definitions # --------------------------------------------------------------------------- +_NO_INTERMEDIATE_SYSTEM_MODELS = { + "Qwen/Qwen3.5-4B", +} + _STANDARD_CASES = [ pytest.param(SingleToolTrajectory, {}, id="single_tool"), pytest.param(MultiTurnTrajectory, {}, id="multi_turn"), @@ -129,7 +133,7 @@ def tokenizer(request) -> AutoTokenizer: pytest.param(MultiTurnNoToolTrajectory, {}, id="multi_turn_no_tool"), ] -# Trajectories with intermediate system messages (Qwen3.5 uses fixed template). +# Trajectories with intermediate system messages. _INTERMEDIATE_SYSTEM_CASES = [ pytest.param(RetrySystemTrajectory, {}, id="retry_system"), pytest.param(IntermediateSystemTrajectory, {}, id="intermediate_system"), @@ -181,6 +185,8 @@ def test_standard(self, tokenizer, traj_cls, kwargs): @pytest.mark.parametrize("traj_cls, kwargs", _INTERMEDIATE_SYSTEM_CASES) def test_intermediate_system(self, tokenizer, traj_cls, kwargs): + if tokenizer.name_or_path in _NO_INTERMEDIATE_SYSTEM_MODELS: + pytest.skip(f"{tokenizer.name_or_path} intentionally forbids intermediate system messages") _assert_aligned(tokenizer, traj_cls, kwargs) @pytest.mark.parametrize("traj_cls, kwargs", _THINKING_CASES) @@ -189,6 +195,8 @@ def test_thinking(self, tokenizer, traj_cls, kwargs): @pytest.mark.parametrize("traj_cls, kwargs", _INTERMEDIATE_SYSTEM_THINKING_CASES) def test_intermediate_system_thinking(self, tokenizer, traj_cls, kwargs): + if tokenizer.name_or_path in _NO_INTERMEDIATE_SYSTEM_MODELS: + pytest.skip(f"{tokenizer.name_or_path} intentionally forbids intermediate system messages") _assert_aligned(tokenizer, traj_cls, kwargs) def test_json_string_arguments(self, tokenizer): diff --git a/tests/fast/utils/chat_template_utils/test_tito_tokenizer.py b/tests/fast/utils/chat_template_utils/test_tito_tokenizer.py index 1321260525..4172a089f6 100644 --- a/tests/fast/utils/chat_template_utils/test_tito_tokenizer.py +++ b/tests/fast/utils/chat_template_utils/test_tito_tokenizer.py @@ -26,16 +26,18 @@ - Default: plain concatenation (no boundary handling). TestTokenizeAdditional - Behavioral tests for tokenize_additional_non_assistant — the dummy-prefix - diff that computes incremental token IDs for appended non-assistant messages. + Behavioral tests for tokenize_additional_non_assistant — the role-segmented + synthetic-prefix diff that computes incremental token IDs for appended + non-assistant messages. ``test_produces_nonempty_incremental`` is parametrized over: _TOOL_TRAJECTORIES (trajectory classes) × _TITO_MODELS (qwen3, glm47) Split points are auto-detected by _find_tito_splits from message structure, so adding a trajectory to _TOOL_TRAJECTORIES automatically extends coverage. - Remaining tests verify append-only validation (reject prefix mutation, - fewer messages, or forbidden roles like assistant). + Remaining tests cover segmentation logic, generation-prompt timing, + reasoning-content shape, merge structure preservation, and append-only + validation (reject prefix mutation, fewer messages, or forbidden roles). TestFactory get_tito_tokenizer factory: string/enum dispatch, invalid input handling. @@ -43,16 +45,21 @@ from __future__ import annotations +from pathlib import Path + import pytest from transformers import AutoTokenizer +from miles.utils.chat_template_utils import MismatchType, apply_chat_template, try_get_fixed_chat_template from miles.utils.chat_template_utils.tito_tokenizer import ( GLM47TITOTokenizer, Qwen3TITOTokenizer, TITOTokenizer, TITOTokenizerType, + _build_dummy_assistant, get_tito_tokenizer, ) +from miles.utils.processing_utils import load_tokenizer from miles.utils.test_utils.mock_trajectories import ( IntermediateSystemTrajectory, LongChainTrajectory, @@ -68,13 +75,19 @@ # Tokenizer cache # --------------------------------------------------------------------------- -_TOK_CACHE: dict[str, AutoTokenizer] = {} +_TOK_CACHE: dict[tuple[str, str | None], AutoTokenizer] = {} def _get_tokenizer(model_id: str) -> AutoTokenizer: - if model_id not in _TOK_CACHE: - _TOK_CACHE[model_id] = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - return _TOK_CACHE[model_id] + chat_template_path = try_get_fixed_chat_template(model_id) + cache_key = (model_id, chat_template_path) + if cache_key not in _TOK_CACHE: + _TOK_CACHE[cache_key] = load_tokenizer( + model_id, + chat_template_path=chat_template_path, + trust_remote_code=True, + ) + return _TOK_CACHE[cache_key] # --------------------------------------------------------------------------- @@ -91,32 +104,28 @@ def _get_tokenizer(model_id: str) -> AutoTokenizer: } -# TODO: "user" is intentionally excluded — the dummy-prefix diff in -# tokenize_additional_non_assistant assumes appended messages don't change how -# earlier turns render, which breaks for user messages on context-sensitive -# templates (e.g. Qwen3's last_query_index). Only tool and system are safe. -_TOOL_AND_SYSTEM = ["tool", "system"] +_ALLOWED_APPEND_ROLES = ["tool", "user", "system"] @pytest.fixture(params=list(_TITO_MODELS.keys())) def tito(request) -> TITOTokenizer: model_id, cls = _TITO_MODELS[request.param] - return cls(_get_tokenizer(model_id), allowed_append_roles=_TOOL_AND_SYSTEM) + return cls(_get_tokenizer(model_id), allowed_append_roles=_ALLOWED_APPEND_ROLES) @pytest.fixture def qwen3_tito() -> Qwen3TITOTokenizer: - return Qwen3TITOTokenizer(_get_tokenizer("Qwen/Qwen3-4B"), allowed_append_roles=_TOOL_AND_SYSTEM) + return Qwen3TITOTokenizer(_get_tokenizer("Qwen/Qwen3-4B"), allowed_append_roles=_ALLOWED_APPEND_ROLES) @pytest.fixture def glm47_tito() -> GLM47TITOTokenizer: - return GLM47TITOTokenizer(_get_tokenizer("zai-org/GLM-4.7-Flash"), allowed_append_roles=_TOOL_AND_SYSTEM) + return GLM47TITOTokenizer(_get_tokenizer("zai-org/GLM-4.7-Flash"), allowed_append_roles=_ALLOWED_APPEND_ROLES) @pytest.fixture def default_tito() -> TITOTokenizer: - return TITOTokenizer(_get_tokenizer("Qwen/Qwen3-4B"), allowed_append_roles=_TOOL_AND_SYSTEM) + return TITOTokenizer(_get_tokenizer("Qwen/Qwen3-4B"), allowed_append_roles=_ALLOWED_APPEND_ROLES) # --------------------------------------------------------------------------- @@ -156,8 +165,8 @@ def _split_at(traj_cls, pos: int): """Split trajectory at *pos* into ``(old_msgs, new_msgs, tools)``. ``old_msgs = messages[:pos]`` — the pretokenized prefix (ends with assistant turn). - ``new_msgs`` extends through all subsequent non-assistant messages (tool/system), - stopping before the next assistant turn. + ``new_msgs`` extends through all subsequent non-assistant messages + (tool/user/system), stopping before the next assistant turn. """ msgs = traj_cls.MESSAGES end = pos @@ -281,7 +290,7 @@ def test_empty_prefix(self, qwen3_tito: Qwen3TITOTokenizer): # --------------------------------------------------------------------------- -# TestTokenizeAdditional — incremental tokenization via dummy-prefix diff +# TestTokenizeAdditional — incremental tokenization via role-segmented synthetic diff # # test_produces_nonempty_incremental is the scalable core: parametrized over # _TRAJ_CASES (trajectories × split points) × tito fixture (models). @@ -306,6 +315,108 @@ def test_produces_nonempty_incremental(self, tito: TITOTokenizer, traj_cls, pos) incremental = tito.tokenize_additional_non_assistant(old_msgs, new_msgs, tools) assert len(incremental) > 0 + def test_contiguous_tool_segment_is_tokenized_together(self, qwen3_tito: Qwen3TITOTokenizer): + old_msgs, new_msgs, tools = _split_at(MultiToolSingleTurnTrajectory, 3) + appended = new_msgs[len(old_msgs) :] + + segments = qwen3_tito._split_appended_segments(appended) + assert len(segments) == 1 + assert [msg["role"] for msg in segments[0]] == ["tool", "tool"] + + incremental = qwen3_tito.tokenize_additional_non_assistant(old_msgs, new_msgs, tools) + decoded = qwen3_tito.tokenizer.decode(incremental) + assert MultiToolSingleTurnTrajectory.MESSAGES[3]["content"] in decoded + assert MultiToolSingleTurnTrajectory.MESSAGES[4]["content"] in decoded + + def test_user_and_system_segments_are_singletons(self, default_tito: TITOTokenizer): + appended = [ + {"role": "system", "content": "Use JSON."}, + {"role": "user", "content": "Hello"}, + {"role": "tool", "tool_call_id": "call_1", "content": '{"ok": true}'}, + {"role": "tool", "tool_call_id": "call_2", "content": '{"ok": false}'}, + {"role": "user", "content": "Try again"}, + ] + + segments = default_tito._split_appended_segments(appended) + assert [[msg["role"] for msg in segment] for segment in segments] == [ + ["system"], + ["user"], + ["tool", "tool"], + ["user"], + ] + + def test_generation_prompt_is_appended_once_for_full_suffix(self, qwen3_tito: Qwen3TITOTokenizer): + old_msgs = list(SingleToolThinkingTrajectory.MESSAGES[:3]) + new_msgs = old_msgs + [ + SingleToolThinkingTrajectory.MESSAGES[3], + {"role": "user", "content": "Now check Shanghai too."}, + ] + tools = SingleToolThinkingTrajectory.TOOLS + + incremental = qwen3_tito.tokenize_additional_non_assistant(old_msgs, new_msgs, tools) + decoded = qwen3_tito.tokenizer.decode(incremental) + assert decoded.count(qwen3_tito._assistant_start_str) == 1 + assert decoded.endswith( + qwen3_tito.tokenizer.decode( + qwen3_tito._tokenize_rendered_suffix(new_msgs, [], tools=tools, add_generation_prompt=True) + ) + ) + + def test_qwen3_tool_dummy_assistant_preserves_reasoning_shape(self): + thinking_template_path = ( + Path(__file__).resolve().parents[4] + / "miles/utils/chat_template_utils/templates/qwen3_thinking_2507_and_next_fixed.jinja" + ) + thinking_tito = Qwen3TITOTokenizer( + load_tokenizer( + "Qwen/Qwen3-4B-Instruct-2507", + chat_template_path=str(thinking_template_path), + trust_remote_code=True, + ), + allowed_append_roles=_ALLOWED_APPEND_ROLES, + ) + tool_messages = [SingleToolThinkingTrajectory.MESSAGES[3]] + dummy_assistant = _build_dummy_assistant(tool_messages) + rendered = thinking_tito._render_messages( + [{"role": "system", "content": "dummy system"}, dummy_assistant], + add_generation_prompt=False, + tools=SingleToolThinkingTrajectory.TOOLS, + ) + + assert dummy_assistant["reasoning_content"] == " " + assert rendered.endswith( + '<|im_start|>assistant\n\n{"name": "dummy_func", "arguments": {}}\n<|im_end|>\n' + ) + + @pytest.mark.parametrize( + "traj_cls, pos", + [ + pytest.param(SingleToolTrajectory, 3, id="single-tool"), + pytest.param(RetrySystemTrajectory, 3, id="tool-plus-system"), + pytest.param(IntermediateSystemTrajectory, 3, id="intermediate-system"), + ], + ) + def test_qwen3_merge_preserves_non_assistant_structure(self, qwen3_tito: Qwen3TITOTokenizer, traj_cls, pos): + """Merged tokens may differ in assistant text, but not in tool/system structure.""" + old_msgs, new_msgs, tools = _split_at(traj_cls, pos) + pretokenized = apply_chat_template( + old_msgs, + tokenizer=qwen3_tito.tokenizer, + tokenize=True, + add_generation_prompt=False, + tools=tools, + ) + merged = qwen3_tito.merge_tokens(old_msgs, new_msgs, pretokenized, tools) + expected = apply_chat_template( + new_msgs, + tokenizer=qwen3_tito.tokenizer, + tokenize=True, + add_generation_prompt=True, + tools=tools, + ) + mismatches = qwen3_tito.create_comparator().compare_sequences(expected, merged) + assert all(m.type == MismatchType.ASSISTANT_TEXT for m in mismatches) + # -- Append-only validation (assert_messages_append_only_with_allowed_role is called internally) -- def test_rejects_prefix_mutation(self, qwen3_tito: Qwen3TITOTokenizer): diff --git a/tests/fast/utils/chat_template_utils/test_tito_tokenizer_model_matrix.py b/tests/fast/utils/chat_template_utils/test_tito_tokenizer_model_matrix.py new file mode 100644 index 0000000000..e846e8285b --- /dev/null +++ b/tests/fast/utils/chat_template_utils/test_tito_tokenizer_model_matrix.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass + +import pytest +from transformers import AutoTokenizer + +from miles.utils.chat_template_utils import MismatchType, apply_chat_template, try_get_fixed_chat_template +from miles.utils.chat_template_utils.tito_tokenizer import TITOTokenizer, TITOTokenizerType, get_tito_tokenizer +from miles.utils.processing_utils import load_tokenizer +from miles.utils.test_utils.mock_trajectories import ( + MultiUserTurnThinkingTrajectory, + SimpleNoToolTrajectory, + SingleToolThinkingTrajectory, + SingleToolTrajectory, +) + +TOOL_CALL_TEST_MODELS = [ + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen3-0.6B", + "Qwen/Qwen3.5-0.8B", + "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-Coder-30B-A3B-Instruct", + # "meta-llama/Llama-3.2-1B-Instruct", # Skipped: gated repo, requires HF_TOKEN in CI + "zai-org/GLM-4.7-Flash", + "mistralai/Mistral-7B-Instruct-v0.3", + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "MiniMaxAI/MiniMax-M2.5", + "internlm/internlm3-8b-instruct", + "THUDM/glm-4-9b-chat", + "moonshotai/Kimi-K2-Instruct", + "moonshotai/Kimi-K2.5", + "XiaomiMiMo/MiMo-7B-RL", + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", +] + +# Models excluded from TITO testing due to known template incompatibilities. +# Filtered out of parametrized test cases below. +_TITO_EXCLUDED_MODELS: dict[str, str] = { + "Qwen/Qwen3.5-0.8B": ( + "The qwen3.5 fixed template rejects non-first system messages with " + "'System message must be at the beginning'. TITO's synthetic bases " + "place system first, so this exclusion may be removable — needs testing." + ), + "deepseek-ai/DeepSeek-V3": ( + "TITO tokenizes each tool segment independently via _tokenize_tool_segment, " + "which causes DeepSeek-V3's template to emit extra " + "<|tool_outputs_begin|>/<|tool_outputs_end|> wrappers that differ from " + "full-conversation rendering." + ), +} +_TITO_TEST_MODELS = [m for m in TOOL_CALL_TEST_MODELS if m not in _TITO_EXCLUDED_MODELS] + +_ALLOWED_APPEND_ROLES = ["tool", "user", "system"] +_TOK_CACHE: dict[tuple[str, str | None], AutoTokenizer] = {} +_ASSISTANT_START_BY_MODEL: dict[str, str] = { + "Qwen/Qwen2.5-0.5B-Instruct": "<|im_start|>assistant\n", + "mistralai/Mistral-7B-Instruct-v0.3": "[/INST]", + "deepseek-ai/DeepSeek-V3": "<|Assistant|>", + "stepfun-ai/step3": "<|BOT|>assistant\n", + "MiniMaxAI/MiniMax-M2": "]~b]ai\n", + "MiniMaxAI/MiniMax-M2.5": "]~b]ai\n", + "internlm/internlm3-8b-instruct": "<|im_start|>assistant\n", + "THUDM/glm-4-9b-chat": "<|assistant|>", + "moonshotai/Kimi-K2-Instruct": "<|im_assistant|>assistant<|im_middle|>", + "moonshotai/Kimi-K2.5": "<|im_assistant|>assistant<|im_middle|>", + "XiaomiMiMo/MiMo-7B-RL": "<|im_start|>assistant\n", + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "<|im_start|>assistant\n", +} +_NO_SYSTEM_APPEND_MODELS = { + "deepseek-ai/DeepSeek-V3", + "stepfun-ai/step3", + "MiniMaxAI/MiniMax-M2", + "MiniMaxAI/MiniMax-M2.5", +} +_CONTENT_WHITESPACE_AGNOSTIC_MODELS = { + "stepfun-ai/step3", +} + + +@dataclass(frozen=True) +class AppendCase: + name: str + old_messages: list[dict] + appended_messages: list[dict] + tools: list[dict] | None + required_contents: tuple[str, ...] = () + + +_APPEND_CASES = [ + AppendCase( + name="single_tool", + old_messages=deepcopy(SingleToolTrajectory.MESSAGES[:3]), + appended_messages=deepcopy([SingleToolTrajectory.MESSAGES[3]]), + tools=deepcopy(SingleToolTrajectory.TOOLS), + required_contents=(SingleToolTrajectory.MESSAGES[3]["content"],), + ), + AppendCase( + name="single_user", + old_messages=deepcopy(MultiUserTurnThinkingTrajectory.MESSAGES[:5]), + appended_messages=deepcopy([MultiUserTurnThinkingTrajectory.MESSAGES[5]]), + tools=deepcopy(MultiUserTurnThinkingTrajectory.TOOLS), + required_contents=(MultiUserTurnThinkingTrajectory.MESSAGES[5]["content"],), + ), + AppendCase( + name="single_system", + old_messages=deepcopy(SimpleNoToolTrajectory.MESSAGES), + appended_messages=[{"role": "system", "content": "Please answer in one short sentence."}], + tools=None, + required_contents=("Please answer in one short sentence.",), + ), + AppendCase( + name="alternating_user_tool", + old_messages=deepcopy(SingleToolThinkingTrajectory.MESSAGES[:3]), + appended_messages=[ + deepcopy(SingleToolThinkingTrajectory.MESSAGES[3]), + {"role": "user", "content": "Now check Shanghai too."}, + { + "role": "tool", + "tool_call_id": "call_followup_1", + "content": '{"temperature": 30, "condition": "cloudy"}', + }, + {"role": "user", "content": "And tell me the date as well."}, + ], + tools=deepcopy(SingleToolThinkingTrajectory.TOOLS), + required_contents=( + SingleToolThinkingTrajectory.MESSAGES[3]["content"], + "Now check Shanghai too.", + '{"temperature": 30, "condition": "cloudy"}', + "And tell me the date as well.", + ), + ), +] + +_ALL_PARAMS = [ + pytest.param(model_name, case, id=f"{case.name}-{model_name}") + for model_name in _TITO_TEST_MODELS + for case in _APPEND_CASES + if not (case.name == "single_system" and model_name in _NO_SYSTEM_APPEND_MODELS) +] + + +def _resolve_tito_type(model_name: str) -> TITOTokenizerType: + lowered = model_name.lower() + if "qwen3" in lowered: + return TITOTokenizerType.QWEN3 + if "glm-4.7" in lowered: + return TITOTokenizerType.GLM47 + return TITOTokenizerType.DEFAULT + + +def _get_tokenizer(model_name: str) -> AutoTokenizer: + chat_template_path = try_get_fixed_chat_template(model_name) + cache_key = (model_name, chat_template_path) + if cache_key not in _TOK_CACHE: + _TOK_CACHE[cache_key] = load_tokenizer( + model_name, + chat_template_path=chat_template_path, + trust_remote_code=True, + ) + return _TOK_CACHE[cache_key] + + +def _get_tito(model_name: str, tokenizer: AutoTokenizer) -> TITOTokenizer: + tokenizer_type = _resolve_tito_type(model_name) + kwargs = { + "tokenizer_type": tokenizer_type, + "allowed_append_roles": _ALLOWED_APPEND_ROLES, + } + if tokenizer_type == TITOTokenizerType.DEFAULT: + kwargs["assistant_start_str"] = _ASSISTANT_START_BY_MODEL[model_name] + return get_tito_tokenizer(tokenizer, **kwargs) + + +def _render_ids( + tokenizer: AutoTokenizer, messages: list[dict], tools: list[dict] | None, *, add_generation_prompt: bool +) -> list[int]: + return apply_chat_template( + messages, + tokenizer=tokenizer, + tokenize=True, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + + +def _assert_only_assistant_mismatches(tito: TITOTokenizer, expected: list[int], merged: list[int]) -> None: + mismatches = tito.create_comparator().compare_sequences(expected, merged) + bad = [m for m in mismatches if m.type != MismatchType.ASSISTANT_TEXT] + assert not bad, [m.to_dict() for m in bad] + + +def _assert_contents_in_order( + incremental_text: str, required_contents: tuple[str, ...], *, model_name: str, case_name: str +) -> None: + if model_name in _CONTENT_WHITESPACE_AGNOSTIC_MODELS: + incremental_text = "".join(incremental_text.split()) + required_contents = tuple("".join(content.split()) for content in required_contents) + cursor = 0 + for content in required_contents: + found = incremental_text.find(content, cursor) + assert found >= 0, f"{model_name=} {case_name=} missing ordered content {content!r}" + cursor = found + len(content) + + +def _run_case(model_name: str, case: AppendCase) -> tuple[TITOTokenizer, list[int], list[int], str]: + tokenizer = _get_tokenizer(model_name) + tito = _get_tito(model_name, tokenizer) + old_messages = deepcopy(case.old_messages) + new_messages = old_messages + deepcopy(case.appended_messages) + try: + expected = _render_ids(tokenizer, new_messages, case.tools, add_generation_prompt=True) + pretokenized = _render_ids(tokenizer, old_messages, case.tools, add_generation_prompt=False) + except Exception as exc: + pytest.skip(f"{model_name} cannot render case {case.name}: {type(exc).__name__}: {exc}") + merged = tito.merge_tokens(old_messages, new_messages, pretokenized, case.tools) + incremental_text = tokenizer.decode(tito.tokenize_additional_non_assistant(old_messages, new_messages, case.tools)) + return tito, merged, expected, incremental_text + + +@pytest.mark.parametrize(("model_name", "case"), _ALL_PARAMS) +def test_appended_non_assistant_content_preserved(model_name: str, case: AppendCase): + tito, merged, expected, incremental_text = _run_case(model_name, case) + _assert_only_assistant_mismatches(tito, expected, merged) + _assert_contents_in_order(incremental_text, case.required_contents, model_name=model_name, case_name=case.name)