diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py
index 79c38ba1ea..33b15d6d0f 100644
--- a/tests/test_chat_template_utils.py
+++ b/tests/test_chat_template_utils.py
@@ -716,7 +716,8 @@ def test_parse_response(self, model_name):
         prefix = prefix[0]
         text = text[0]
         response = text[len(prefix) :]
-        parsed = parse_response(processing_class, response)
+        tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
+        parsed = parse_response(tokenizer, response)
         assert parsed == expected
 
     def test_parse_response_with_reasoning_content(self, model_name):
@@ -746,7 +747,8 @@ def test_parse_response_with_reasoning_content(self, model_name):
         prefix = prefix[0]
         text = text[0]
         response = text[len(prefix) :]
-        parsed = parse_response(processing_class, response)
+        tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
+        parsed = parse_response(tokenizer, response)
         assert parsed == expected
 
     def test_parse_response_tool_call(self, model_name):
@@ -772,7 +774,8 @@ def test_parse_response_tool_call(self, model_name):
         prefix = prefix[0]
         text = text[0]
         response = text[len(prefix) :]
-        parsed = parse_response(processing_class, response)
+        tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
+        parsed = parse_response(tokenizer, response)
         assert parsed == expected
 
     def test_parse_response_tool_call_with_content(self, model_name):
@@ -797,7 +800,8 @@ def test_parse_response_tool_call_with_content(self, model_name):
         prefix = prefix[0]
         text = text[0]
         response = text[len(prefix) :]
-        parsed = parse_response(processing_class, response)
+        tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
+        parsed = parse_response(tokenizer, response)
         assert parsed == expected
 
     def test_parse_response_tool_call_without_arguments(self, model_name):
@@ -823,7 +827,8 @@ def test_parse_response_tool_call_without_arguments(self, model_name):
         prefix = prefix[0]
         text = text[0]
         response = text[len(prefix) :]
-        parsed = parse_response(processing_class, response)
+        tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
+        parsed = parse_response(tokenizer, response)
         assert parsed == expected
 
     def test_parse_response_multiple_tool_calls(self, model_name):
@@ -858,7 +863,8 @@ def test_parse_response_multiple_tool_calls(self, model_name):
         prefix = prefix[0]
         text = text[0]
         response = text[len(prefix) :]
-        parsed = parse_response(processing_class, response)
+        tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
+        parsed = parse_response(tokenizer, response)
         assert parsed == expected
 
     def test_parse_response_malformed_tool_call(self, model_name):
diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py
index d6079b838b..5b63e43e17 100644
--- a/trl/chat_template_utils.py
+++ b/trl/chat_template_utils.py
@@ -685,7 +685,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None:
             tool_call["arguments"] = {}
 
 
-def parse_response(processing_class: PreTrainedTokenizerBase | ProcessorMixin, ids: list[int]) -> dict:
+def parse_response(tokenizer: PreTrainedTokenizerBase, ids: list[int]) -> dict:
     r"""
     Parse a token sequence into structured response dictionaries with fallback handling.
 
@@ -695,11 +695,9 @@ def parse_response(processing_class: PreTrainedTokenizerBase | ProcessorMixin, i
     Also removes incorrectly appended EOS tokens from tool call content when present, and validates tool_calls to
     ensure all required fields exist.
 
-    For VLM processors, automatically uses the inner tokenizer for parsing.
-
     Args:
-        processing_class (`PreTrainedTokenizerBase` or VLM processor):
-            Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer).
+        tokenizer (`PreTrainedTokenizerBase`):
+            Tokenizer with a `parse_response()` method.
         ids (`list[int]`):
             List of token sequences.
 
@@ -720,8 +718,6 @@ def parse_response(processing_class: PreTrainedTokenizerBase | ProcessorMixin, i
     {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}
     ```
    """
-    # VLM processors don't have parse_response directly; use the inner tokenizer
-    tokenizer = getattr(processing_class, "tokenizer", processing_class)
    try:
        parsed = tokenizer.parse_response(ids)
        # Hotfix: remove incorrectly appended EOS token from tool calls
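For reviewers, a minimal call-site sketch of the narrowed signature (the checkpoint names and the `generated_ids` variable are illustrative placeholders, not part of this change):

```python
from transformers import AutoProcessor, AutoTokenizer

from trl.chat_template_utils import parse_response

generated_ids = [1, 2, 3]  # placeholder: token ids produced by a model's generate()

# Text-only model: the tokenizer is passed directly, as before.
tokenizer = AutoTokenizer.from_pretrained("some-text-model")  # hypothetical checkpoint
parsed = parse_response(tokenizer, generated_ids)

# VLM: parse_response() no longer unwraps the processor, so the caller must
# pass the inner tokenizer explicitly, mirroring what the updated tests do.
processor = AutoProcessor.from_pretrained("some-vlm-model")  # hypothetical checkpoint
parsed = parse_response(processor.tokenizer, generated_ids)
```

Dropping the `getattr(processing_class, "tokenizer", processing_class)` unwrapping keeps the type hint honest and moves the VLM special case to the call site, which the updated tests now exercise via `self.is_vlm`.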