
feat(tokenization): add encode_message to tokenize messages one by one #39507

Open · wants to merge 5 commits into main

Conversation

@pco111 commented on Jul 18, 2025

What does this PR do?
This PR introduces a new method, tokenizer.encode_message, to the base tokenizer class. This method allows a single chat message to be tokenized at a time while correctly handling the conversational context provided by conversation_history. This is particularly useful for message-by-message streaming applications, where re-tokenizing the entire conversation history for each new message is inefficient.
The new method works by applying the chat template to the full conversation (history + new message) and then programmatically isolating the tokens that correspond to the new message. This ensures that all special tokens, roles, and formatting are applied correctly according to the model's chat template, maintaining consistency with apply_chat_template.
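
A rough usage sketch of the proposed method (the checkpoint, exact keyword names, and return type below are assumptions based on this description, not the final API):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example checkpoint

history = []
all_ids = []
for msg in [
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]:
    # Encode only the new message in the context of what came before,
    # instead of re-templating and re-tokenizing the whole conversation.
    all_ids += tok.encode_message(msg, conversation_history=history)
    history.append(msg)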

Fixes #39417
Before submitting
[x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
[x] Did you read the contributor guideline, Pull Request section?
[x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
[x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
[x] Did you write any new necessary tests?

Who can review?
@ArthurZucker @Rocketknight1

@ArthurZucker (Collaborator)

I like this! cc @Rocketknight1 if you can have a look!

@Rocketknight1 (Member) commented on Jul 21, 2025

Hi @pco111, this is a cool idea, but I'm not sure about some of the details! In particular, the interaction with add_generation_prompt is awkward. If we set that to True, then a common scenario is that the conversation_history will be tokenized like this, where the "generation prompt" is the last line:

<im_start>user
message<im_end>
<im_start>assistant

But in this case, encode_message() will treat <im_start>assistant as part of the history, and remove it from the encoded message, and then the encoded message will be incomplete. I'm not sure what the best solution is - maybe always set add_generation_prompt to False?
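
A rough illustration of the failure mode using apply_chat_template directly (the checkpoint is an assumption; this is not the PR's code):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example checkpoint

history = [{"role": "user", "content": "message"}]
new_message = {"role": "assistant", "content": "reply"}

# History rendered with a generation prompt ends with the "<im_start>assistant" line.
prefix_ids = tok.apply_chat_template(history, add_generation_prompt=True)
# Full conversation, including the new assistant message.
full_ids = tok.apply_chat_template(history + [new_message], add_generation_prompt=False)

# Slicing off len(prefix_ids) also strips the "<im_start>assistant" tokens that
# actually open the new message, so the encoded message comes out incomplete.
message_ids = full_ids[len(prefix_ids):]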

…arameter and add the corresponding error handling. Update the document to reflect this change and verify the error handling in the test.
@pco111 (Author) commented on Jul 21, 2025


Hi @Rocketknight1,

Thank you so much for your insightful feedback! You've pointed out a very important edge case with add_generation_prompt that I had overlooked.

Following your thoughts, I've opted for a clearer and more robust approach:

  1. Explicitly Disallowed add_generation_prompt: The encode_message method now raises a ValueError if add_generation_prompt is passed. This prevents any ambiguity.
  2. Updated Documentation: The docstring for encode_message now clearly states that it does not handle the generation prompt and advises users on how to add it separately if needed.
  3. Updated Tests: The tests have been updated to reflect this new design. There is now a test to ensure that the ValueError is raised correctly.

Thank you again for guiding me toward a better solution! I've pushed the new changes for your review.
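
A hypothetical sketch of how the method behaves under this design (method name as proposed in the PR; the exact signature and checkpoint are assumptions):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example checkpoint
history = [{"role": "user", "content": "Hi there!"}]

# Normal use: encode_message never appends a generation prompt itself.
ids = tok.encode_message({"role": "assistant", "content": "Hello!"}, conversation_history=history)

# Passing add_generation_prompt is rejected outright, avoiding the ambiguity above.
try:
    tok.encode_message(
        {"role": "assistant", "content": "Hello!"},
        conversation_history=history,
        add_generation_prompt=True,
    )
except ValueError:
    pass  # raised by design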

@Rocketknight1 (Member) left a comment


Made some comments! Also, check the CI on Github - you may need to run make fixup to get the style tests to pass.

Comment on lines 1771 to 1772
    if conversation_history is None:
        conversation_history = []

In the case where conversation_history is None, presumably you just want to return the output of apply_chat_template() without changes?
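
Something along these lines inside the method would cover that case (an illustrative fragment only; the surrounding signature and variable names are assumptions, not the actual diff):

if not conversation_history:
    # With no history there is no prefix to strip, so the chat-template output
    # for the single message is already the encoded message.
    return self.apply_chat_template([message], add_generation_prompt=False, tokenize=True)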

@@ -1695,6 +1695,89 @@ def apply_chat_template(
        else:
            return rendered_chat

    def _encode_message(

I'm not sure we need a separate helper function! This can be folded into the main function to keep things simpler.

@@ -3253,7 +3336,7 @@ def pad(
    pad_to_multiple_of (`int`, *optional*):
        If set will pad the sequence to a multiple of the provided value.

-       This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+       This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability

"Tensor Cores" is correct, so we don't want this change.

@@ -375,3 +376,34 @@ def test_training_new_tokenizer_edge_cases(self):
        tokenizer = PreTrainedTokenizerFast(tokenizer_object=_tokenizer)
        toy_text_iterator = ("a" for _ in range(1000))
        tokenizer.train_new_from_iterator(text_iterator=toy_text_iterator, length=1000, vocab_size=50)


class ChatTemplateTest(unittest.TestCase):

There are some other chat template tests in existing test classes already, so this should probably go in one of those rather than making a new class!
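
For instance, a consistency check along these lines could sit in an existing test class rather than a new one (a sketch only; the checkpoint and exact assertion are assumptions about what such a test might cover):

import unittest

from transformers import AutoTokenizer


class TokenizerUtilsTest(unittest.TestCase):  # existing class, per the later discussion
    def test_encode_message_matches_full_template(self):
        tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example checkpoint
        conversation = [
            {"role": "user", "content": "Hi!"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "How are you?"},
        ]
        # Encoding message by message should reproduce the full-template encoding.
        ids, history = [], []
        for msg in conversation:
            ids += tok.encode_message(msg, conversation_history=history)
            history.append(msg)
        full_ids = tok.apply_chat_template(conversation, add_generation_prompt=False)
        self.assertEqual(ids, full_ids)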

pco111 added 3 commits July 22, 2025 14:25
… the empty dialogue history, and ensure that the chat template can be applied correctly when the dialogue history is empty. Update the document to reflect these changes.
…simplified, and the functional integrity of the `encode_message` method is ensured. Update the document to reflect these changes.
@pco111 (Author) commented on Jul 22, 2025


Hi @Rocketknight1,

Thank you for the detailed and helpful feedback! I've updated the PR according to all your suggestions:

  • The _encode_message helper has been folded into the main encode_message function.
  • An optimization has been added to handle empty conversation history directly.
  • The "Tensor Cores" typo has been corrected.
  • The new tests have been moved into the existing TokenizerUtilsTest class.

All local checks (make fixup and pytest) are passing. The code should be in much better shape now. Thanks again for your guidance!

Development

Successfully merging this pull request may close these issues.

Option to tokenize messages one after the other
3 participants