Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,6 @@
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if ns.multi_step_tool %}
{{- raise_exception('No user query found in messages.') }}
{%- endif %}
{%- for message in messages %}
{%- set content = render_content(message.content, true)|trim %}
{%- if message.role == "system" %}
Expand Down
173 changes: 135 additions & 38 deletions miles/utils/chat_template_utils/tito_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
token sequence, then merges them with the pretokenized prefix — handling
model-specific boundary tokens at the junction.

The default implementation uses a dummy-message diff: it tokenizes a
synthetic ``[dummy_user, dummy_assistant]`` base with and without the
appended messages, then takes the suffix difference as the incremental IDs.
Model-specific subclasses override ``merge_tokens`` to handle boundary
quirks at the junction.
The default implementation incrementally tokenizes appended non-assistant turns
with role-specific synthetic prefixes:

- contiguous ``tool`` runs use ``[dummy_system, dummy_assistant]``
- each ``user`` message uses ``[dummy_system, dummy_user]``
- each ``system`` message uses ``[dummy_system]``

The appended suffix is processed left-to-right, then the generation prompt for
the next assistant turn is appended once at the end. Model-specific
subclasses only override ``merge_tokens`` for boundary quirks at the prefix
junction.
"""

from __future__ import annotations
Expand All @@ -20,7 +26,8 @@
from miles.utils.chat_template_utils.template import apply_chat_template, assert_messages_append_only_with_allowed_role
from miles.utils.chat_template_utils.token_seq_comparator import TokenSeqComparator

_DUMMY_USER: dict[str, Any] = {"role": "user", "content": "dummy"}
# Synthetic placeholder messages used as minimal rendering context when
# diffing chat-template output. Their content is arbitrary filler; only the
# roles matter for reproducing each role's boundary tokens.
_DUMMY_SYSTEM: dict[str, Any] = {"role": "system", "content": "dummy system"}
_DUMMY_USER: dict[str, Any] = {"role": "user", "content": "dummy user"}


def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, Any]:
Expand All @@ -45,23 +52,12 @@ def _build_dummy_assistant(tool_responses: list[dict[str, Any]]) -> dict[str, An


# ---------------------------------------------------------------------------
# Base / default tokenizer (dummy-prefix diff)
# Base / default tokenizer
# ---------------------------------------------------------------------------


class TITOTokenizer:
"""Incremental tokenization and prefix merging using dummy-message diff.

A synthetic base ``[dummy_user, dummy_assistant]`` simulates the assistant
turn boundary so that the diff captures the correct turn-transition tokens:

1. ``tokens_without`` = tokenize(base, add_generation_prompt=False)
2. ``tokens_with`` = tokenize(base + appended, add_generation_prompt=True)
3. ``incremental_ids = tokens_with[len(tokens_without):]``

Subclasses override ``merge_tokens`` to handle model-specific boundary
token quirks.
"""
"""Incremental tokenization and prefix merging for appended non-assistant turns."""

max_trim_tokens: int = 0
trailing_token_ids: frozenset[int] = frozenset()
Expand All @@ -87,6 +83,105 @@ def create_comparator(self) -> TokenSeqComparator:
trim_trailing_ids=self.trailing_token_ids or None,
)

def _render_messages(
    self,
    messages: list[dict[str, Any]],
    *,
    add_generation_prompt: bool,
    tools: list[dict[str, Any]] | None = None,
) -> str:
    """Render *messages* through the chat template to text (no tokenization).

    Thin wrapper around ``apply_chat_template`` that pins ``tokenize=False``
    and forwards this tokenizer plus the configured template kwargs.
    """
    fixed_kwargs: dict[str, Any] = {
        "tokenizer": self.tokenizer,
        "tokenize": False,
        "add_generation_prompt": add_generation_prompt,
        "tools": tools,
    }
    # Both expansions are kept separate so a key collision with
    # chat_template_kwargs still raises TypeError, as before.
    return apply_chat_template(messages, **fixed_kwargs, **self.chat_template_kwargs)

def _encode_text(self, text: str) -> list[int]:
return self.tokenizer.encode(text, add_special_tokens=False)

def _split_appended_segments(self, appended_messages: list[dict[str, Any]]) -> list[list[dict[str, Any]]]:
segments: list[list[dict[str, Any]]] = []
i = 0
while i < len(appended_messages):
role = appended_messages[i]["role"]
# Many templates wrap a contiguous tool-response run as one logical
# block, so tool messages are diffed together instead of one-by-one.
if role == "tool":
j = i + 1
while j < len(appended_messages) and appended_messages[j]["role"] == "tool":
j += 1
segments.append(appended_messages[i:j])
i = j
continue
if role in {"user", "system"}:
segments.append([appended_messages[i]])
i += 1
continue
raise ValueError(f"unsupported appended role for TITO segmentation: {role}")

return segments

def _tokenize_rendered_suffix(
self,
base_messages: list[dict[str, Any]],
appended_messages: list[dict[str, Any]],
*,
tools: list[dict[str, Any]] | None = None,
add_generation_prompt: bool = False,
) -> list[int]:
"""Render *base_messages* and *base_messages + appended_messages*, return
tokens for the suffix.

When *add_generation_prompt* is True and *appended_messages* is empty,
this computes the generation-prompt suffix (the assistant opener tokens).
"""
text_without = self._render_messages(base_messages, add_generation_prompt=False, tools=tools)
text_with = self._render_messages(
base_messages + appended_messages,
add_generation_prompt=add_generation_prompt,
tools=tools,
)
if not text_with.startswith(text_without):
roles = [msg["role"] for msg in appended_messages] if appended_messages else ["generation_prompt"]
raise ValueError(f"rendered suffix diff failed for {roles}")
return self._encode_text(text_with[len(text_without) :])

def _tokenize_tool_segment(
    self,
    appended_messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
) -> list[int]:
    """Incremental token IDs for a contiguous run of ``tool`` messages.

    The synthetic base is a dummy system turn plus an assistant turn built
    from the tool responses (see ``_build_dummy_assistant``), so the diff
    captures the tool block's boundary tokens.
    """
    synthetic_base = [_DUMMY_SYSTEM, _build_dummy_assistant(appended_messages)]
    return self._tokenize_rendered_suffix(synthetic_base, appended_messages, tools=tools)

def _tokenize_user_segment(
    self,
    appended_message: dict[str, Any],
    tools: list[dict[str, Any]] | None = None,
) -> list[int]:
    """Incremental token IDs for a single ``user`` message.

    Diffed against a ``[dummy_system, dummy_user]`` base so the rendered
    suffix carries the user-turn boundary tokens.
    """
    synthetic_base = [_DUMMY_SYSTEM, _DUMMY_USER]
    return self._tokenize_rendered_suffix(synthetic_base, [appended_message], tools=tools)

def _tokenize_system_segment(
    self,
    appended_message: dict[str, Any],
    tools: list[dict[str, Any]] | None = None,
) -> list[int]:
    """Incremental token IDs for a single ``system`` message.

    Diffed against a lone ``[dummy_system]`` base so the rendered suffix
    carries the system-turn boundary tokens.
    """
    synthetic_base = [_DUMMY_SYSTEM]
    return self._tokenize_rendered_suffix(synthetic_base, [appended_message], tools=tools)

def tokenize_additional_non_assistant(
self,
old_messages: list[dict[str, Any]],
Expand Down Expand Up @@ -114,29 +209,31 @@ def tokenize_additional_non_assistant(
"""
assert_messages_append_only_with_allowed_role(old_messages, new_messages, self.allowed_append_roles)
appended_messages = new_messages[len(old_messages) :]

dummy_assistant = _build_dummy_assistant(appended_messages)
base_messages = [_DUMMY_USER, dummy_assistant]

tokens_without = apply_chat_template(
base_messages,
tokenizer=self.tokenizer,
tokenize=True,
add_generation_prompt=False,
incremental: list[int] = []

# Incremental non-assistant content is assembled segment-by-segment
# using the smallest synthetic context that preserves each role's
# boundary tokens.
for segment in self._split_appended_segments(appended_messages):
role = segment[0]["role"]
if role == "tool":
incremental.extend(self._tokenize_tool_segment(segment, tools))
elif role == "user":
incremental.extend(self._tokenize_user_segment(segment[0], tools))
elif role == "system":
incremental.extend(self._tokenize_system_segment(segment[0], tools))
else:
raise ValueError(f"unsupported appended role for TITO tokenization: {role}")

# The next assistant opener depends on the full post-append history, so
# it is derived from the real messages once and appended only at the end.
return incremental + self._tokenize_rendered_suffix(
new_messages,
[],
tools=tools,
**self.chat_template_kwargs,
)
tokens_with = apply_chat_template(
base_messages + list(appended_messages),
tokenizer=self.tokenizer,
tokenize=True,
add_generation_prompt=True,
tools=tools,
**self.chat_template_kwargs,
)

return list(tokens_with[len(tokens_without) :])

def merge_tokens(
self,
old_messages: list[dict[str, Any]],
Expand Down
5 changes: 4 additions & 1 deletion tests/fast/router/test_session_pretokenized_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ class ModelTemplateConfig:
"Qwen/Qwen3-4B-Thinking-2507",
try_get_fixed_chat_template("Qwen/Qwen3-4B-Thinking-2507"),
),
"qwen3.5-native": ModelTemplateConfig("Qwen/Qwen3.5-0.8B", None),
"qwen3.5-fixed": ModelTemplateConfig(
"Qwen/Qwen3.5-0.8B",
try_get_fixed_chat_template("Qwen/Qwen3.5-0.8B"),
),
"qwen3-next-instruct-native": ModelTemplateConfig("Qwen/Qwen3-Next-80B-A3B-Instruct", None),
"qwen3-next-thinking-fixed": ModelTemplateConfig(
"Qwen/Qwen3-Next-80B-A3B-Thinking",
Expand Down
Loading
Loading