diff --git a/tests/test_batched_engine_output_router.py b/tests/test_batched_engine_output_router.py
new file mode 100644
index 00000000..50e59a69
--- /dev/null
+++ b/tests/test_batched_engine_output_router.py
@@ -0,0 +1,301 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for BatchedEngine token-level output routing."""
+
+from collections.abc import AsyncIterator
+
+import pytest
+
+from vllm_mlx.engine.base import GenerationOutput
+from vllm_mlx.engine.batched import BatchedEngine
+
+
+class FakeTokenizer:
+    """Minimal tokenizer for OutputRouter detection and decoding."""
+
+    def __init__(self, vocab: dict[str, int]):
+        self._vocab = vocab
+        self._id_to_text = {v: k for k, v in vocab.items()}
+
+    def get_vocab(self) -> dict[str, int]:
+        return self._vocab
+
+    def decode(self, ids: list[int]) -> str:
+        return "".join(self._id_to_text.get(i, f"<unk:{i}>") for i in ids)
+
+
+# Harmony vocab IDs mirror the GPT-OSS tokenizer subset used by OutputRouter.
+HARMONY_VOCAB = {
+    "<|return|>": 200002,
+    "<|constrain|>": 200003,
+    "<|channel|>": 200005,
+    "<|start|>": 200006,
+    "<|end|>": 200007,
+    "<|message|>": 200008,
+    "<|call|>": 200012,
+    "<|endoftext|>": 200019,
+    "analysis": 35644,
+    "final": 17196,
+    "Reason": 2,
+    "ing": 3,
+    "Answer": 4,
+    "Fallback": 5,
+}
+
+QWEN3_VOCAB = {
+    "<think>": 248068,
+    "</think>": 248069,
+    "Reason": 2,
+    "Answer": 4,
+}
+
+GEMMA4_VOCAB = {
+    "<pad>": 0,
+    "<eos>": 1,
+    "<bos>": 2,
+    "<|tool>": 46,
+    "</tool>": 47,
+    "<|tool_call>": 48,
+    "</tool_call>": 49,
+    "<|tool_response>": 50,
+    "</tool_response>": 51,
+    '<|"|>': 52,
+    "<|channel>": 100,
+    "</channel>": 101,
+    "<|turn>": 105,
+    "</turn>": 106,
+    "thought": 45518,
+    "content": 3955,
+    "final": 10218,
+    "call": 6639,
+    ":": 236787,
+    "get": 828,
+    "_": 236779,
+    "weather": 19323,
+    "{": 236782,
+    "}": 236783,
+    "city": 13319,
+    "Tokyo": 89265,
+}
+
+
+def _make_engine(tokenizer: FakeTokenizer) -> BatchedEngine:
+    engine = BatchedEngine("fake-model")
+    engine._loaded = True
+    engine._tokenizer = tokenizer
+    engine._apply_chat_template = lambda *args, **kwargs: "prompt"
+    engine._compute_prefix_boundary = lambda *args, **kwargs: 0
+    return engine
+
+
+async def _collect(
+    outputs: AsyncIterator[GenerationOutput],
+) -> list[GenerationOutput]:
+    return [output async for output in outputs]
+
+
+@pytest.mark.asyncio
+async def test_stream_chat_routes_supported_tokenizer_channels():
+    """Supported tokenizers emit channel-tagged chunks and suppress controls."""
+    engine = _make_engine(FakeTokenizer(HARMONY_VOCAB))
+
+    async def fake_stream_generate(**kwargs):
+        yield GenerationOutput(
+            text="",
+            new_text="<|channel|>analysis<|message|>Reason",
+            tokens=[200005, 35644, 200008, 2],
+            finished=False,
+        )
+        yield GenerationOutput(
+            text="",
+            new_text="ing<|start|><|channel|>final<|message|>Answer",
+            tokens=[3, 200006, 200005, 17196, 200008, 4],
+            finished=True,
+            finish_reason="stop",
+        )
+
+    engine.stream_generate = fake_stream_generate
+
+    outputs = await _collect(
+        engine.stream_chat(messages=[{"role": "user", "content": "hi"}])
+    )
+
+    assert [(o.new_text, o.channel, o.finished) for o in outputs] == [
+        ("Reason", "reasoning", False),
+        ("ing", "reasoning", False),
+        ("Answer", "content", True),
+    ]
+    assert all("<|channel|>" not in output.new_text for output in outputs)
+    assert all(output.logprobs is None for output in outputs)
+
+
+@pytest.mark.asyncio
+async def test_stream_chat_keeps_think_tag_tokenizers_on_legacy_path():
+    """Think-tag routers are detected but not engine-enabled until validated."""
+    engine = _make_engine(FakeTokenizer(QWEN3_VOCAB))
+
+    async def fake_stream_generate(**kwargs):
+        yield GenerationOutput(
+            text="",
+            new_text="<think>Reason</think>Answer",
+            tokens=[248068, 2, 248069, 4],
+            finished=True,
+            finish_reason="stop",
+            channel=None,
+        )
+
+    engine.stream_generate = fake_stream_generate
+
+    outputs = await _collect(
+        engine.stream_chat(messages=[{"role": "user", "content": "hi"}])
+    )
+
+    assert len(outputs) == 1
+    assert outputs[0].new_text == "<think>Reason</think>Answer"
+    assert outputs[0].channel is None
+
+
+@pytest.mark.asyncio
+async def test_stream_chat_routes_tool_call_channel_on_finish():
+    """Truncated tool calls are drained as tool_call channel output."""
+    engine = _make_engine(FakeTokenizer(GEMMA4_VOCAB))
+
+    async def fake_stream_generate(**kwargs):
+        yield GenerationOutput(
+            text="",
+            new_text="<|tool_call>call:get_weather{city:Tokyo}",
+            tokens=[
+                48,
+                6639,
+                236787,
+                828,
+                236779,
+                19323,
+                236782,
+                13319,
+                236787,
+                89265,
+                236783,
+            ],
+            finished=True,
+            finish_reason="length",
+        )
+
+    engine.stream_generate = fake_stream_generate
+
+    outputs = await _collect(
+        engine.stream_chat(messages=[{"role": "user", "content": "hi"}])
+    )
+
+    assert [(o.channel, o.finished, o.finish_reason) for o in outputs] == [
+        ("tool_call", True, "length")
+    ]
+    assert "get_weather" in outputs[0].new_text
+    assert "Tokyo" in outputs[0].new_text
+    assert outputs[0].logprobs is None
+
+
+@pytest.mark.asyncio
+async def test_stream_chat_uses_incremental_new_text_for_single_token_events():
+    """Single-token routed chunks preserve scheduler detokenizer text."""
+    vocab = {
+        **HARMONY_VOCAB,
+        "decoded-wrong": 6,
+    }
+    tokenizer = FakeTokenizer(vocab)
+    tokenizer._id_to_text[6] = "decoded-wrong"
+    engine = _make_engine(tokenizer)
+
+    async def fake_stream_generate(**kwargs):
+        yield GenerationOutput(
+            text="",
+            new_text="decoded-right",
+            tokens=[6],
+            finished=True,
+            finish_reason="stop",
+        )
+
+    engine.stream_generate = fake_stream_generate
+
+    outputs = await _collect(
+        engine.stream_chat(messages=[{"role": "user", "content": "hi"}])
+    )
+
+    assert outputs[0].new_text == "decoded-right"
+    assert outputs[0].channel == "content"
+
+
+@pytest.mark.asyncio
+async def test_stream_chat_leaves_unsupported_tokenizer_on_legacy_path():
+    """Unsupported tokenizers preserve raw chunks with channel=None."""
+    engine = _make_engine(FakeTokenizer({"Hello": 1}))
+
+    async def fake_stream_generate(**kwargs):
+        yield GenerationOutput(
+            text="Hello",
+            new_text="Hello",
+            tokens=[1],
+            finished=True,
+            finish_reason="stop",
+            channel=None,
+        )
+
+    engine.stream_generate = fake_stream_generate
+
+    outputs = await _collect(
+        engine.stream_chat(messages=[{"role": "user", "content": "hi"}])
+    )
+
+    assert len(outputs) == 1
+    assert outputs[0].new_text == "Hello"
+    assert outputs[0].tokens == [1]
+    assert outputs[0].channel is None
+    assert outputs[0].finished is True
+
+
+@pytest.mark.asyncio
+async def test_stream_chat_falls_back_after_router_failure():
+    """A mid-stream router failure disables routing for later chunks."""
+    engine = _make_engine(FakeTokenizer(HARMONY_VOCAB))
+
+    class FailingRouter:
+        def feed(self, token_id):
+            raise RuntimeError("boom")
+
+    async def fake_outputs():
+        yield GenerationOutput(
+            text="",
+            new_text="Fallback",
+            tokens=[5],
+            finished=False,
+            channel=None,
+        )
+        yield GenerationOutput(
+            text="",
+            new_text="Answer",
+            tokens=[4],
+            finished=True,
+            finish_reason="stop",
+            channel=None,
+        )
+
+    outputs = await _collect(
+        engine._stream_with_output_router(fake_outputs(), FailingRouter())
+    )
+
+    assert [(o.new_text, o.channel, o.finished) for o in outputs] == [
+        ("Fallback", None, False),
+        ("Answer", None, True),
+    ]
+
+
+def test_create_output_router_catches_tokenizer_property_errors():
+    """Tokenizer access failures fall back to legacy parsing."""
+
+    class BrokenTokenizerEngine(BatchedEngine):
+        @property
+        def tokenizer(self):
+            raise RuntimeError("not loaded")
+
+    engine = BrokenTokenizerEngine("fake-model")
+
+    assert engine._create_output_router() is None
diff --git a/tests/test_output_router.py b/tests/test_output_router.py
index 882c25d9..366a8ee7 100644
--- a/tests/test_output_router.py
+++ b/tests/test_output_router.py
@@ -257,6 +257,51 @@ def test_content_after_tool_call(self, router):
         assert event.channel == Channel.CONTENT
 
 
+class TestFinalize:
+    """Test end-of-stream state draining."""
+
+    def test_finalize_emits_incomplete_tool_call(self, router):
+        """Unclosed TOOL_CALL state is drained as a TOOL_CALL event."""
+        router.feed(48)  # <|tool_call>
+        router.feed(6639)  # call
+        router.feed(236787)  # :
+        router.feed(828)  # get
+        router.feed(236779)  # _
+        router.feed(19323)  # weather
+        router.feed(236782)  # {
+        router.feed(13319)  # city
+        router.feed(236787)  # :
+        router.feed(52)  # <|"|>
+        router.feed(89265)  # Tokyo
+        router.feed(52)  # <|"|>
+        router.feed(236783)  # }
+
+        event = router.finalize()
+        assert event is not None
+        assert event.channel == Channel.TOOL_CALL
+        assert "get_weather" in event.text
+        assert "Tokyo" in event.text
+        assert router.state == RouterState.CONTENT
+        assert router._tool_tokens == []
+        assert router.finalize() is None
+
+    def test_finalize_drains_pending_harmony_analysis_message(self):
+        """Harmony analysis content is preserved if <|message|> is missing."""
+        router = OutputRouter.from_tokenizer(HARMONY_TOKENIZER)
+        assert router is not None
+
+        router.feed(200005)  # <|channel|>
+        router.feed(35644)  # analysis
+        router.feed(2)  # Reason
+        router.feed(3)  # ing
+
+        event = router.finalize()
+        assert event is not None
+        assert event.channel == Channel.REASONING
+        assert event.text == "Reasoning"
+        assert router.state == RouterState.CONTENT
+
+
 class TestOrphanTokens:
     """Test handling of orphaned/leaked special tokens."""
 
@@ -611,6 +656,7 @@ def test_reset_clears_state(self, router):
         assert router._tool_tokens == []
         assert router._pending_channel_style is None
         assert router._pending_message_channel is None
+        assert router._pending_control_tokens == []
 
     def test_multiple_requests(self, router):
         """Router works correctly across multiple reset cycles."""
diff --git a/tests/test_postprocessor.py b/tests/test_postprocessor.py
index caf3af61..3756c8c5 100644
--- a/tests/test_postprocessor.py
+++ b/tests/test_postprocessor.py
@@ -102,6 +102,19 @@ def test_reasoning_channel(self):
         assert events[0].type == "reasoning"
         assert events[0].reasoning == "thinking..."
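+
+    # A minimal sketch of the contract exercised below (illustrative only;
+    # it assumes the _make_cfg/_make_output helpers used by this file's
+    # other tests): a chunk that arrives with output.channel set must bypass
+    # the text-based reasoning parser entirely, e.g.
+    #   pp.process_chunk(_make_output("Hi", channel="content"))
+    # yields one plain content event even when a parser is configured.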
+ def test_channel_bypasses_legacy_reasoning_parser(self): + parser = MagicMock() + cfg = _make_cfg(reasoning_parser=parser) + pp = StreamingPostProcessor(cfg) + pp.reset() + + events = pp.process_chunk(_make_output("Hello", channel="content")) + + assert len(events) == 1 + assert events[0].type == "content" + assert events[0].content == "Hello" + parser.extract_reasoning_streaming.assert_not_called() + class TestStreamingPostProcessorReasoning: """Tests for text-based reasoning parser integration.""" diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 4a38617c..eb32114d 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -14,10 +14,12 @@ import functools import logging from collections.abc import AsyncIterator +from dataclasses import replace from typing import Any from ..api.tool_calling import convert_tools_for_template from ..api.utils import clean_output_text, extract_multimodal_content, is_mllm_model +from ..output_router import Channel, OutputRouter from ..utils.chat_template import apply_chat_template as shared_apply_chat_template from .base import BaseEngine, GenerationOutput @@ -52,6 +54,20 @@ def _probe_mllm_cache_type(language_model: Any) -> str | None: return type(sample).__name__ +_CHANNEL_TO_STRING = { + Channel.CONTENT: "content", + Channel.REASONING: "reasoning", + Channel.TOOL_CALL: "tool_call", +} + +_OUTPUT_ROUTER_ALLOWLIST = {"gemma4", "harmony"} + + +def _channel_name(channel: Channel) -> str: + """Convert router channel enum values to GenerationOutput.channel strings.""" + return _CHANNEL_TO_STRING[channel] + + def _compute_metal_cache_limit(soft_limit_bytes: int) -> int: """Pick a Metal free-cache size that scales with the device's working set. @@ -761,6 +777,7 @@ async def stream_generate( yield GenerationOutput( text=clean_output_text(output.output_text), new_text=output.new_text, + tokens=output.new_token_ids, prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens, finished=output.finished, @@ -927,6 +944,152 @@ def _compute_prefix_boundary( except Exception: return 0 + def _create_output_router(self) -> OutputRouter | None: + """Create a per-request token router for supported tokenizer formats.""" + try: + tokenizer = self.tokenizer + if tokenizer is None: + return None + router = OutputRouter.from_tokenizer(tokenizer) + if router is None: + return None + if router.map.format_tag not in _OUTPUT_ROUTER_ALLOWLIST: + return None + return router + # Unsupported tokenizers are expected to fall through to the legacy + # parser path; construction failures indicate the same non-router path. 
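+        # For example (illustrative, mirroring the tests above): a
+        # GPT-OSS/Harmony vocabulary yields format_tag == "harmony" and gets
+        # a router, while a Qwen3-style think-tag vocabulary is detected but
+        # filtered out by the allowlist until that path is validated.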
+ except Exception as e: + logger.debug("OutputRouter unavailable for this request: %s", e) + return None + + def _make_routed_output( + self, + source: GenerationOutput, + event, + *, + new_text: str | None = None, + finished: bool = False, + finish_reason: str | None = None, + ) -> GenerationOutput: + return GenerationOutput( + text=source.text, + new_text=event.text if new_text is None else new_text, + tokens=[event.token_id] if event.token_id is not None else [], + prompt_tokens=source.prompt_tokens, + completion_tokens=source.completion_tokens, + finished=finished, + finish_reason=finish_reason, + logprobs=None, + channel=_channel_name(event.channel), + ) + + def _routed_finish_sentinel(self, source: GenerationOutput) -> GenerationOutput: + return GenerationOutput( + text=source.text, + new_text="", + tokens=[], + prompt_tokens=source.prompt_tokens, + completion_tokens=source.completion_tokens, + finished=True, + finish_reason=source.finish_reason, + logprobs=source.logprobs, + channel=None, + ) + + def _finalize_output_router( + self, + router: OutputRouter, + source: GenerationOutput, + ) -> GenerationOutput | None: + try: + event = router.finalize() + except Exception as e: + # Unlike unavailable routers, mid-stream/finalize failures mean a + # selected router broke after consuming request bytes; warn loudly. + logger.warning("OutputRouter finalize failed; falling back: %s", e) + return None + if event is None: + return None + return self._make_routed_output( + source, + event, + finished=True, + finish_reason=source.finish_reason, + ) + + async def _stream_with_output_router( + self, + outputs: AsyncIterator[GenerationOutput], + router: OutputRouter | None, + ) -> AsyncIterator[GenerationOutput]: + """Attach semantic channels to streamed chat tokens when supported. + + This intentionally emits one GenerationOutput per routed token, even + when an upstream flush contains multiple tokens, so downstream + postprocessing sees clean channel boundaries. For the common + stream_interval=1 case, preserve the scheduler's incremental + detokenizer text instead of re-decoding the token in the router. + """ + if router is None: + async for output in outputs: + yield output + return + + async for output in outputs: + if router is None: + yield output + continue + + token_ids = output.tokens + if not token_ids: + yield output + continue + + routed_outputs: list[GenerationOutput] = [] + try: + for token_id in token_ids: + event = router.feed(token_id) + if event is None: + continue + event_text = output.new_text if len(token_ids) == 1 else event.text + routed_outputs.append( + self._make_routed_output( + output, + event, + new_text=event_text, + ) + ) + except Exception as e: + # Unlike unavailable routers, mid-stream failures mean a + # selected router broke after consuming request bytes; warn + # loudly and disable routing for the rest of this request. 
+                logger.warning(
+                    "OutputRouter failed; falling back to legacy parsers: %s", e
+                )
+                router = None
+                yield output
+                continue
+
+            if not routed_outputs:
+                if output.finished:
+                    finalized = self._finalize_output_router(router, output)
+                    yield finalized or self._routed_finish_sentinel(output)
+                continue
+
+            if output.finished:
+                finalized = self._finalize_output_router(router, output)
+                if finalized is None:
+                    routed_outputs[-1] = replace(
+                        routed_outputs[-1],
+                        finished=True,
+                        finish_reason=output.finish_reason,
+                    )
+                else:
+                    routed_outputs.append(finalized)
+
+            for routed in routed_outputs:
+                yield routed
+
     async def stream_chat(
         self,
         messages: list[dict[str, Any]],
@@ -986,7 +1149,8 @@ async def stream_chat(
         if prefix_boundary > 0:
             kwargs["prefix_boundary"] = prefix_boundary
 
-        async for output in self.stream_generate(
+        router = self._create_output_router()
+        stream = self.stream_generate(
             prompt=prompt,
             max_tokens=max_tokens,
             temperature=temperature,
@@ -994,7 +1158,8 @@ async def stream_chat(
             images=all_images if all_images else None,
             videos=all_videos if all_videos else None,
             **kwargs,
-        ):
+        )
+        async for output in self._stream_with_output_router(stream, router):
             yield output
 
     def get_stats(self) -> dict[str, Any]:
diff --git a/vllm_mlx/output_router.py b/vllm_mlx/output_router.py
index cc358461..9a513bfb 100644
--- a/vllm_mlx/output_router.py
+++ b/vllm_mlx/output_router.py
@@ -59,6 +59,8 @@ class RouterEvent:
 class TokenMap:
     """Special token ID mappings for a model family."""
 
+    format_tag: str = ""
+
     # Channel control (Gemma 4 style)
     channel_start: int | None = None  # <|channel> = 100
     channel_end: int | None = None  # </channel> = 101
@@ -126,6 +128,7 @@ def __init__(self, token_map: TokenMap, tokenizer: Any):
         self._tool_tokens: list[int] = []  # accumulated tool call token IDs
         self._pending_channel_style: str | None = None
         self._pending_message_channel: Channel | None = None
+        self._pending_control_tokens: list[int] = []
 
     def reset(self):
         """Reset state for a new request."""
@@ -133,6 +136,7 @@ def reset(self):
         self._tool_tokens = []
         self._pending_channel_style = None
         self._pending_message_channel = None
+        self._pending_control_tokens = []
 
     def feed(self, token_id: int) -> RouterEvent | None:
         """
@@ -153,12 +157,14 @@ def feed(self, token_id: int) -> RouterEvent | None:
             self.state = RouterState.AWAITING_CHANNEL_TYPE
             self._pending_channel_style = "harmony"
             self._pending_message_channel = None
+            self._pending_control_tokens = []
             return None
 
         if token_id == m.harmony_end or token_id == m.harmony_return:
             self.state = RouterState.CONTENT
             self._pending_channel_style = None
             self._pending_message_channel = None
+            self._pending_control_tokens = []
             return None
 
         if token_id in (m.harmony_start, m.harmony_call, m.harmony_constrain):
@@ -176,6 +182,7 @@ def feed(self, token_id: int) -> RouterEvent | None:
         # === Channel start: transition to AWAITING_CHANNEL_TYPE ===
         if token_id == m.channel_start:
             self.state = RouterState.AWAITING_CHANNEL_TYPE
+            self._pending_control_tokens = []
            return None  # suppress <|channel>
 
         # === Channel type word: set state based on which channel ===
@@ -192,6 +199,7 @@ def feed(self, token_id: int) -> RouterEvent | None:
 
                 self.state = RouterState.CONTENT
                 self._pending_channel_style = None
+                self._pending_control_tokens = []
                 text = self.tokenizer.decode([token_id])
                 return RouterEvent(Channel.CONTENT, token_id, text)
 
@@ -204,6 +212,7 @@ def feed(self, token_id: int) -> RouterEvent | None:
             else:
                 # Unknown channel type — treat as content
                 self.state = RouterState.CONTENT
+                self._pending_control_tokens = []
                 text = self.tokenizer.decode([token_id])
                 return RouterEvent(Channel.CONTENT, token_id, text)
 
@@ -217,6 +226,9 @@ def feed(self, token_id: int) -> RouterEvent | None:
                 )
                 self._pending_channel_style = None
                 self._pending_message_channel = None
+                self._pending_control_tokens = []
+            else:
+                self._pending_control_tokens.append(token_id)
             return None
 
         if token_id == m.harmony_message:
@@ -264,6 +276,41 @@ def feed(self, token_id: int) -> RouterEvent | None:
         else:
             return RouterEvent(Channel.CONTENT, token_id, text)
 
+    def finalize(self) -> RouterEvent | None:
+        """Drain any buffered state at stream end.
+
+        This is best-effort: complete channel transitions are still handled by
+        feed(), while finalize() preserves buffered tool calls or pending
+        Harmony pre-message text that would otherwise be dropped.
+        """
+        if self.state == RouterState.TOOL_CALL and self._tool_tokens:
+            token_id = self._tool_tokens[-1]
+            text = self.tokenizer.decode(self._tool_tokens)
+            self.state = RouterState.CONTENT
+            self._tool_tokens = []
+            return RouterEvent(Channel.TOOL_CALL, token_id, text)
+
+        if self.state in (
+            RouterState.AWAITING_CHANNEL_TYPE,
+            RouterState.AWAITING_MESSAGE,
+        ):
+            if self._pending_control_tokens:
+                channel = self._pending_message_channel or Channel.CONTENT
+                token_id = self._pending_control_tokens[-1]
+                text = self.tokenizer.decode(self._pending_control_tokens)
+                self.state = RouterState.CONTENT
+                self._pending_channel_style = None
+                self._pending_message_channel = None
+                self._pending_control_tokens = []
+                if text.strip():
+                    return RouterEvent(channel, token_id, text)
+            else:
+                self.state = RouterState.CONTENT
+                self._pending_channel_style = None
+                self._pending_message_channel = None
+
+        return None
+
     def feed_sequence(self, token_ids: list[int]) -> dict[str, str]:
         """
         Feed a complete token sequence and return separated channels.
@@ -286,6 +333,15 @@ def feed_sequence(self, token_ids: list[int]) -> dict[str, str]:
             elif event.channel == Channel.TOOL_CALL:
                 tool_calls.append(event.text)
 
+        event = self.finalize()
+        if event is not None:
+            if event.channel == Channel.CONTENT:
+                content += event.text
+            elif event.channel == Channel.REASONING:
+                reasoning += event.text
+            elif event.channel == Channel.TOOL_CALL:
+                tool_calls.append(event.text)
+
         return {
             "content": content.strip() or None,
             "reasoning": reasoning.strip() or None,
@@ -305,6 +361,7 @@ def from_tokenizer(cls, tokenizer: Any) -> "OutputRouter | None":
         # Gemma 4 detection: look for <|channel> and <|tool_call>
         if "<|channel>" in vocab and "<|tool_call>" in vocab:
             token_map = TokenMap(
+                format_tag="gemma4",
                 channel_start=vocab.get("<|channel>"),
                 channel_end=vocab.get("</channel>"),
                 thought_word=vocab.get("thought"),
@@ -335,6 +392,7 @@ def from_tokenizer(cls, tokenizer: Any) -> "OutputRouter | None":
         # GPT-OSS/Harmony detection: channel/message special tokens.
         if "<|channel|>" in vocab and "<|message|>" in vocab:
             token_map = TokenMap(
+                format_tag="harmony",
                 harmony_channel=vocab.get("<|channel|>"),
                 harmony_message=vocab.get("<|message|>"),
                 harmony_start=vocab.get("<|start|>"),
@@ -360,6 +418,7 @@ def from_tokenizer(cls, tokenizer: Any) -> "OutputRouter | None":
         # for Qwen3 (which has neither in its vocab).
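+        # Note: detection order matters. The Gemma-style and Harmony branches
+        # above take precedence, so this generic branch only fires for
+        # vocabularies whose routing signal is a bare <think>/</think> pair.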
if "" in vocab and "" in vocab: token_map = TokenMap( + format_tag="think", think_start=vocab.get(""), think_end=vocab.get(""), bos=vocab.get("<|begin▁of▁sentence|>") or vocab.get(""), diff --git a/vllm_mlx/service/postprocessor.py b/vllm_mlx/service/postprocessor.py index 4d14dd31..2f442e92 100644 --- a/vllm_mlx/service/postprocessor.py +++ b/vllm_mlx/service/postprocessor.py @@ -212,16 +212,15 @@ def process_chunk(self, output: GenerationOutput) -> list[StreamEvent]: return [] # Step 1: Separate content from reasoning - if output.channel: + if output.channel is not None: return self._process_channel_routed(delta_text, output) - elif self.reasoning_parser and self.enable_thinking is not False: + if self.reasoning_parser and self.enable_thinking is not False: # When enable_thinking is explicitly False, the model is told to # skip thinking and answer directly. Bypass the reasoning parser # so its implicit-think heuristic doesn't reroute the answer to # reasoning_content. return self._process_with_reasoning(delta_text, output) - else: - return self._process_standard(delta_text, output) + return self._process_standard(delta_text, output) def _process_channel_routed( self, delta_text: str, output: GenerationOutput