Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions trl/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None:
tool_call["arguments"] = {}


def parse_response(processing_class: PreTrainedTokenizer | ProcessorMixin, ids: list[int]) -> dict:
def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict:
r"""
Parse a token sequence into structured response dictionaries with fallback handling.

Expand All @@ -637,11 +637,9 @@ def parse_response(processing_class: PreTrainedTokenizer | ProcessorMixin, ids:
Also removes incorrectly appended EOS tokens from tool call content when present, and validates tool_calls to
ensure all required fields exist.

For VLM processors, automatically uses the inner tokenizer for parsing.

Args:
processing_class (`PreTrainedTokenizer` or VLM processor):
Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer).
Comment thread
cursor[bot] marked this conversation as resolved.
tokenizer (`PreTrainedTokenizer`):
Tokenizer with a `parse_response()` method.
ids (`list[int]`):
List of token sequences.

Expand All @@ -662,8 +660,6 @@ def parse_response(processing_class: PreTrainedTokenizer | ProcessorMixin, ids:
{'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}
```
"""
# VLM processors don't have parse_response directly; use the inner tokenizer
tokenizer = getattr(processing_class, "tokenizer", processing_class)
try:
parsed = tokenizer.parse_response(ids)
# Hotfix: remove incorrectly appended EOS token from tool calls
Expand Down
7 changes: 3 additions & 4 deletions trl/experimental/dppo/dppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,8 @@ async def _run_async_tools(async_coros):
completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx]

# Decode post-tool completions
post_tool_completions = [
parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids
]
tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class
post_tool_completions = [parse_response(tokenizer, ids) if ids else {} for ids in post_tool_ids]

for idx in range(len(idxs_with_tool)):
idx_with_tool = idxs_with_tool[idx]
Expand Down Expand Up @@ -674,7 +673,7 @@ def _generate(self, prompts: list):
and hasattr(tokenizer, "response_schema") # attribute not set by default for now
and tokenizer.response_schema is not None # only works if the tokenizer has a schema
):
completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids]
completions = [[parse_response(tokenizer, ids)] for ids in completion_ids]
else:
contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
completions = [[{"role": "assistant", "content": content}] for content in contents]
Expand Down
9 changes: 4 additions & 5 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1657,10 +1657,9 @@ async def _run_async_tools(async_coros):
pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool
completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx]

# Decode post-tool completions.
post_tool_completions = [
parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids
]
# Decode post-tool completions
tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class
post_tool_completions = [parse_response(tokenizer, ids) if ids else {} for ids in post_tool_ids]

# Add post-tool completions to the existing completions
for idx in range(len(idxs_with_tool)):
Expand Down Expand Up @@ -1716,7 +1715,7 @@ def _generate(self, prompts: list):
and hasattr(tokenizer, "response_schema") # attribute not set by default for now
and tokenizer.response_schema is not None # only works if the tokenizer has a schema
):
completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids]
completions = [[parse_response(tokenizer, ids)] for ids in completion_ids]
else:
contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
completions = [[{"role": "assistant", "content": content}] for content in contents]
Expand Down
Loading