Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,8 @@ def test_parse_response(self, model_name):
prefix = prefix[0]
text = text[0]
response = text[len(prefix) :]
parsed = parse_response(processing_class, response)
tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
parsed = parse_response(tokenizer, response)
assert parsed == expected

def test_parse_response_with_reasoning_content(self, model_name):
Expand Down Expand Up @@ -746,7 +747,8 @@ def test_parse_response_with_reasoning_content(self, model_name):
prefix = prefix[0]
text = text[0]
response = text[len(prefix) :]
parsed = parse_response(processing_class, response)
tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
parsed = parse_response(tokenizer, response)
assert parsed == expected

def test_parse_response_tool_call(self, model_name):
Expand All @@ -772,7 +774,8 @@ def test_parse_response_tool_call(self, model_name):
prefix = prefix[0]
text = text[0]
response = text[len(prefix) :]
parsed = parse_response(processing_class, response)
tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
parsed = parse_response(tokenizer, response)
assert parsed == expected

def test_parse_response_tool_call_with_content(self, model_name):
Expand All @@ -797,7 +800,8 @@ def test_parse_response_tool_call_with_content(self, model_name):
prefix = prefix[0]
text = text[0]
response = text[len(prefix) :]
parsed = parse_response(processing_class, response)
tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
parsed = parse_response(tokenizer, response)
assert parsed == expected

def test_parse_response_tool_call_without_arguments(self, model_name):
Expand All @@ -823,7 +827,8 @@ def test_parse_response_tool_call_without_arguments(self, model_name):
prefix = prefix[0]
text = text[0]
response = text[len(prefix) :]
parsed = parse_response(processing_class, response)
tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
parsed = parse_response(tokenizer, response)
assert parsed == expected

def test_parse_response_multiple_tool_calls(self, model_name):
Expand Down Expand Up @@ -858,7 +863,8 @@ def test_parse_response_multiple_tool_calls(self, model_name):
prefix = prefix[0]
text = text[0]
response = text[len(prefix) :]
parsed = parse_response(processing_class, response)
tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
parsed = parse_response(tokenizer, response)
assert parsed == expected

def test_parse_response_malformed_tool_call(self, model_name):
Expand Down
10 changes: 3 additions & 7 deletions trl/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None:
tool_call["arguments"] = {}


def parse_response(processing_class: PreTrainedTokenizerBase | ProcessorMixin, ids: list[int]) -> dict:
def parse_response(tokenizer: PreTrainedTokenizerBase, ids: list[int]) -> dict:
r"""
Parse a token sequence into structured response dictionaries with fallback handling.

Expand All @@ -695,11 +695,9 @@ def parse_response(processing_class: PreTrainedTokenizerBase | ProcessorMixin, i
Also removes incorrectly appended EOS tokens from tool call content when present, and validates tool_calls to
ensure all required fields exist.

For VLM processors, automatically uses the inner tokenizer for parsing.

Args:
processing_class (`PreTrainedTokenizerBase` or VLM processor):
Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer).
Comment thread
cursor[bot] marked this conversation as resolved.
tokenizer (`PreTrainedTokenizerBase`):
Tokenizer with a `parse_response()` method.
ids (`list[int]`):
List of token sequences.

Expand All @@ -720,8 +718,6 @@ def parse_response(processing_class: PreTrainedTokenizerBase | ProcessorMixin, i
{'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}
```
"""
# VLM processors don't have parse_response directly; use the inner tokenizer
tokenizer = getattr(processing_class, "tokenizer", processing_class)
try:
parsed = tokenizer.parse_response(ids)
# Hotfix: remove incorrectly appended EOS token from tool calls
Expand Down
Loading