55 changes: 55 additions & 0 deletions tests/generate/tokenizer_adapter_test.py
@@ -52,6 +52,61 @@ def test_default_tokenizer(self):
self.assertEqual(tokenizer.tokenizer_type, 'sentencepiece')
self.assertEqual(tokenizer._tokenizer_type, adapter.TokenizerType.SP)

def test_tokenize_hf(self):
model = 'meta-llama/Meta-Llama-3-8B-Instruct'
tokenizer = adapter.Tokenizer(
tokenizer_type='huggingface', tokenizer_path=model
)

text = 'hello world'
tokens = tokenizer.tokenize(text, add_eos=True)

# Check BOS
self.assertEqual(tokens[0], tokenizer.bos_id())
# Check EOS
self.assertEqual(tokens[-1], tokenizer.eos_id())

# Check no double BOS (assuming "hello" doesn't tokenize to BOS)
self.assertNotEqual(tokens[1], tokenizer.bos_id())

# Test add_bos=False
tokens_no_bos = tokenizer.tokenize(text, add_bos=False, add_eos=True)
self.assertNotEqual(tokens_no_bos[0], tokenizer.bos_id())
self.assertEqual(tokens_no_bos[-1], tokenizer.eos_id())

# Decode back
decoded = tokenizer.decode(tokens)
# TokenizerAdapter.decode forwards to self._tokenizer.decode without extra
# kwargs, so the decoded text may still contain special tokens
# (HF's skip_special_tokens defaults to False). Sanity-check the round trip
# instead of asserting an exact string.
self.assertIn('hello', decoded)
self.assertGreater(len(tokens), 2)

def test_special_eos_token(self):
model = 'meta-llama/Meta-Llama-3-8B-Instruct'
# Use a token that definitely exists and is different from default EOS
special_eos = 'world'
tokenizer = adapter.Tokenizer(
tokenizer_type='huggingface',
tokenizer_path=model,
special_eos_token=special_eos,
)

text = 'hello'
tokens = tokenizer.tokenize(text, add_eos=True)

# Check EOS is the special token
self.assertEqual(tokens[-1], tokenizer.eos_id())

# Verify the ID matches what we expect from the tokenizer directly
hf_tokenizer = tokenizer.tokenizer
expected_id = hf_tokenizer.convert_tokens_to_ids(special_eos)
self.assertEqual(tokenizer.eos_id(), expected_id)
self.assertNotEqual(tokenizer.eos_id(), hf_tokenizer.eos_token_id)


if __name__ == '__main__':
absltest.main()
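
For reference, a minimal usage sketch of the behaviour these tests cover. The import path and the '<|eot_id|>' override below are illustrative assumptions, not taken from this diff:

from tunix.generate import tokenizer_adapter as adapter

# Wrap the Hugging Face Llama-3 tokenizer and override which token id is
# treated as EOS (sketch only; any token that exists in the vocab works).
tokenizer = adapter.Tokenizer(
    tokenizer_type='huggingface',
    tokenizer_path='meta-llama/Meta-Llama-3-8B-Instruct',
    special_eos_token='<|eot_id|>',  # assumed override, not part of this diff
)

# add_bos/add_eos control whether BOS/EOS ids are appended; the trailing id
# is the overridden EOS rather than the tokenizer's default eos_token_id.
tokens = tokenizer.tokenize('hello world', add_bos=True, add_eos=True)
assert tokens[-1] == tokenizer.eos_id()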
45 changes: 36 additions & 9 deletions tunix/generate/tokenizer_adapter.py
@@ -20,7 +20,6 @@

from etils import epath
import numpy as np

import sentencepiece as spm


@@ -219,6 +218,7 @@ def __init__(
add_bos: bool | None = True,
add_eos: bool | None = True,
hf_access_token: str | None = None,
special_eos_token: str | None = None,
):

self.tokenizer_type = tokenizer_type
@@ -248,11 +248,26 @@ def __init__(
raise ValueError(f'Unsupported tokenizer_type: {tokenizer_type}')
super().__init__(tokenizer)

self._special_eos_id = None
if special_eos_token:
if tokenizer_type == 'huggingface':
self._special_eos_id = tokenizer.convert_tokens_to_ids(
special_eos_token
)
elif tokenizer_type == 'sentencepiece':
self._special_eos_id = tokenizer.PieceToId(special_eos_token)

def eos_id(self) -> int:
if self._special_eos_id is not None:
return self._special_eos_id
return super().eos_id()

def tokenize(
self,
example: str,
prefix: str = '',
suffix: str = '',
add_bos: bool = True,
add_eos: bool = True,
) -> np.ndarray:
"""The tokenization function.
@@ -261,22 +276,34 @@ def tokenize(
example: Input string to tokenize.
prefix: Prefix to add to the input string.
suffix: Suffix to add to the input string.
add_bos: If True, add a "beginning of sentence" token at the start of the
output sequence.
add_eos: If True, add an "end of sentence" token at the end of the output
sequence.

Returns:
Tokens corresponding to the input string.
"""
int_list = []
if self.bos_id():
int_list.append(self.bos_id())
if self.tokenizer_type == 'huggingface':
int_list.extend(
self.encode(prefix + example + suffix, add_special_tokens=False)
)
bos_id = self.bos_id()
if add_bos and bos_id is not None:
int_list.append(bos_id)

text = prefix + example + suffix
if self._tokenizer_type == TokenizerType.HF:
int_list.extend(self.encode(text, add_special_tokens=False))
else:
# sentencepiece
int_list.extend(self.tokenizer.EncodeAsIds(prefix + example + suffix))
int_list.extend(self.encode(text))

if bos_id is not None:
if add_bos:
# Deduplicate BOS tokens if the tokenizer added one and we added one.
int_list = self.dedup_bos_ids(int_list)
else:
# Remove BOS token if the tokenizer added one.
while int_list and int_list[0] == bos_id:
int_list.pop(0)

if add_eos:
int_list.append(self.eos_id())
return np.array(int_list, dtype=np.int32)
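
Note: dedup_bos_ids is called above but defined outside this hunk. A minimal sketch of the behaviour implied by the call site, assuming it simply collapses repeated leading BOS ids (the real helper may differ):

def dedup_bos_ids(self, ids: list[int]) -> list[int]:
  """Collapses repeated leading BOS ids into a single one (sketch only)."""
  bos_id = self.bos_id()
  # Drop extra BOS ids while at least two of them sit at the front.
  while len(ids) > 1 and ids[0] == bos_id and ids[1] == bos_id:
    ids.pop(0)
  return ids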